diff --git a/.github/workflows/gradle-all.yml b/.github/workflows/gradle-all.yml index 8540539bb..abbd14106 100644 --- a/.github/workflows/gradle-all.yml +++ b/.github/workflows/gradle-all.yml @@ -5,18 +5,71 @@ on: # but only for the non master/1.0.x branches push: branches-ignore: - - 1.0.x + - 1.1.x - master jobs: build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + othertest: + needs: [build] runs-on: ${{ matrix.os }} strategy: matrix: os: [ ubuntu-latest ] - jdk: [ 1.8, 11, 14 ] + jdk: [ 1.8, 11, 17 ] fail-fast: false steps: @@ -34,12 +87,66 @@ jobs: - name: Grant execute permission for gradlew run: chmod +x gradlew - name: Build with Gradle - run: ./gradlew clean build + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew - name: Publish Packages to Artifactory if: ${{ matrix.jdk == '1.8' }} - run: ./gradlew -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PversionSuffix="-${githubRef#refs/heads/}-SNAPSHOT" -PbuildNumber="${buildNumber}" artifactoryPublish --stacktrace + run: | + githubRef="${githubRef#refs/heads/}" + githubRef="${githubRef////-}" + ./gradlew -PversionSuffix="-${githubRef}-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToGitHubPackagesRepository --no-daemon --stacktrace env: - bintrayUser: ${{ secrets.bintrayUser }} - bintrayKey: ${{ secrets.bintrayKey }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} githubRef: ${{ github.ref }} buildNumber: ${{ github.run_number }} \ No newline at end of file diff --git a/.github/workflows/gradle-main.yml b/.github/workflows/gradle-main.yml index d8ba3c3d5..33bca8e72 100644 --- a/.github/workflows/gradle-main.yml +++ b/.github/workflows/gradle-main.yml @@ -2,21 +2,74 @@ name: Main Branches Java CI on: # Trigger the workflow on push - # but only for the master/1.0.x branch + # but only for the master/1.1.x branch push: branches: - master - - 1.0.x + - 1.1.x jobs: build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + othertest: + needs: [build] runs-on: ${{ matrix.os }} strategy: matrix: os: [ ubuntu-latest ] - jdk: [ 1.8, 11, 14 ] + jdk: [ 1.8, 11, 17 ] fail-fast: false steps: @@ -34,14 +87,69 @@ jobs: - name: Grant execute permission for gradlew run: chmod +x gradlew - name: Build with Gradle - run: ./gradlew clean build + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew - name: Publish Packages to Artifactory if: ${{ matrix.jdk == '1.8' }} - run: ./gradlew -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PversionSuffix="-SNAPSHOT" -PbuildNumber="${buildNumber}" artifactoryPublish --stacktrace + run: ./gradlew -PversionSuffix="-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToSonatypeRepository --no-daemon --stacktrace env: - bintrayUser: ${{ secrets.bintrayUser }} - bintrayKey: ${{ secrets.bintrayKey }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} - name: Aggregate test reports with ciMate if: always() continue-on-error: true diff --git a/.github/workflows/gradle-pr.yml b/.github/workflows/gradle-pr.yml index 994450faf..cecca085f 100644 --- a/.github/workflows/gradle-pr.yml +++ b/.github/workflows/gradle-pr.yml @@ -4,13 +4,93 @@ on: [pull_request] jobs: build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + jcstress: + needs: [build] runs-on: ${{ matrix.os }} strategy: matrix: os: [ ubuntu-latest ] - jdk: [ 1.8, 11, 14 ] + jdk: [ 1.8, 11, 17 ] fail-fast: false steps: @@ -28,4 +108,4 @@ jobs: - name: Grant execute permission for gradlew run: chmod +x gradlew - name: Build with Gradle - run: ./gradlew clean build \ No newline at end of file + run: ./gradlew jcstress --no-daemon \ No newline at end of file diff --git a/.github/workflows/gradle-release.yml b/.github/workflows/gradle-release.yml index 08f2698dc..922eb0e3e 100644 --- a/.github/workflows/gradle-release.yml +++ b/.github/workflows/gradle-release.yml @@ -32,13 +32,13 @@ jobs: - name: Grant execute permission for gradlew run: chmod +x gradlew - name: Build with Gradle - run: ./gradlew clean build - - name: Publish Packages to Bintray - run: ./gradlew -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" -Pversion="${githubRef#refs/tags/}" -PbuildNumber="${buildNumber}" bintrayUpload + run: ./gradlew clean build -x test + - name: Publish Packages to Sonotype + run: ./gradlew -Pversion="${githubRef#refs/tags/}" -PbuildNumber="${buildNumber}" sign publishMavenPublicationToSonatypeRepository env: - bintrayUser: ${{ secrets.bintrayUser }} - bintrayKey: ${{ secrets.bintrayKey }} - sonatypeUsername: ${{ secrets.sonatypeUsername }} - sonatypePassword: ${{ secrets.sonatypePassword }} githubRef: ${{ github.ref }} - buildNumber: ${{ github.run_number }} \ No newline at end of file + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} \ No newline at end of file diff --git a/README.md b/README.md index a721313c7..7ed3244b8 100644 --- a/README.md +++ b/README.md @@ -15,21 +15,22 @@ Learn more at http://rsocket.io ## Build and Binaries -[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=develop)](https://travis-ci.org/rsocket/rsocket-java) +[![Build Status](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml/badge.svg?branch=master)](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml) -⚠️ The `master` branch is now dedicated to development of the `1.1.x` line. +⚠️ The `master` branch is now dedicated to development of the `1.2.x` line. -Releases are available via Maven Central. +Releases and milestones are available via Maven Central. Example: ```groovy repositories { - mavenCentral() + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } // Reactor milestones (if needed) } dependencies { - implementation 'io.rsocket:rsocket-core:1.1.0-M1' - implementation 'io.rsocket:rsocket-transport-netty:1.1.0-M1' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` @@ -39,11 +40,12 @@ Example: ```groovy repositories { - maven { url 'https://oss.jfrog.org/oss-snapshot-local' } + maven { url 'https://maven.pkg.github.com/rsocket/rsocket-java' } + maven { url 'https://repo.spring.io/snapshot' } // Reactor snapshots (if needed) } dependencies { - implementation 'io.rsocket:rsocket-core:1.1.0-SNAPSHOT' - implementation 'io.rsocket:rsocket-transport-netty:1.1.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` @@ -131,7 +133,7 @@ For bugs, questions and discussions please use the [Github Issues](https://githu ## LICENSE -Copyright 2015-2018 the original author or authors. +Copyright 2015-2020 the original author or authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/benchmarks/README.md b/benchmarks/README.md index 6ba6755a6..656e2de4b 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -17,7 +17,7 @@ Specify extra profilers: Prominent profilers (for full list call `jmhProfilers` task): - comp - JitCompilations, tune your iterations - stack - which methods used most time -- gc - print garbage collection stats +- gc - print garbage collection defaultWeightedStats - hs_thr - thread usage Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 0b8bc601b..74e571d1f 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -14,8 +14,8 @@ dependencies { compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" compileOnly "io.rsocket:rsocket-transport-netty:${perfBaselineVersion}" - implementation "org.openjdk.jmh:jmh-core:1.21" - annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.21" + implementation "org.openjdk.jmh:jmh-core:1.35" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.35" current project(':rsocket-core') current project(':rsocket-transport-local') diff --git a/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java deleted file mode 100644 index 6b4f3f624..000000000 --- a/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java +++ /dev/null @@ -1,43 +0,0 @@ -package io.rsocket.core; - -import io.netty.util.collection.IntObjectMap; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.infra.Blackhole; - -@BenchmarkMode(Mode.Throughput) -@Fork( - value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} - ) -@Warmup(iterations = 10) -@Measurement(iterations = 10) -@State(Scope.Thread) -public class StreamIdSupplierPerf { - @Benchmark - public void benchmarkStreamId(Input input) { - int i = input.supplier.nextStreamId(input.map); - input.bh.consume(i); - } - - @State(Scope.Benchmark) - public static class Input { - Blackhole bh; - IntObjectMap map; - StreamIdSupplier supplier; - - @Setup - public void setup(Blackhole bh) { - this.supplier = StreamIdSupplier.clientSupplier(); - this.bh = bh; - this.map = new SynchronizedIntObjectHashMap(); - } - } -} diff --git a/build.gradle b/build.gradle index 3c63cf58b..2971a7767 100644 --- a/build.gradle +++ b/build.gradle @@ -15,12 +15,12 @@ */ plugins { - id 'com.github.sherter.google-java-format' version '0.8' apply false - id 'com.jfrog.artifactory' version '4.15.2' apply false - id 'com.jfrog.bintray' version '1.8.5' apply false - id 'me.champeau.gradle.jmh' version '0.5.0' apply false - id 'io.spring.dependency-management' version '1.0.9.RELEASE' apply false + id 'com.github.sherter.google-java-format' version '0.9' apply false + id 'me.champeau.jmh' version '0.7.1' apply false + id 'io.spring.dependency-management' version '1.1.0' apply false id 'io.morethan.jmhreport' version '0.9.0' apply false + id 'io.github.reyerizo.gradle.jcstress' version '0.8.15' apply false + id 'com.github.vlsi.gradle-extensions' version '1.89' apply false } boolean isCiServer = ["CI", "CONTINUOUS_INTEGRATION", "TRAVIS", "CIRCLECI", "bamboo_planKey", "GITHUB_ACTION"].with { @@ -31,20 +31,23 @@ boolean isCiServer = ["CI", "CONTINUOUS_INTEGRATION", "TRAVIS", "CIRCLECI", "bam subprojects { apply plugin: 'io.spring.dependency-management' apply plugin: 'com.github.sherter.google-java-format' - - ext['reactor-bom.version'] = '2020.0.0-M2' - ext['logback.version'] = '1.2.3' - ext['findbugs.version'] = '3.0.2' - ext['netty-bom.version'] = '4.1.50.Final' - ext['netty-boringssl.version'] = '2.0.30.Final' - ext['hdrhistogram.version'] = '2.1.10' - ext['mockito.version'] = '3.2.0' - ext['slf4j.version'] = '1.7.25' - ext['jmh.version'] = '1.21' - ext['junit.version'] = '5.5.2' - ext['hamcrest.version'] = '1.3' - ext['micrometer.version'] = '1.0.6' - ext['assertj.version'] = '3.11.1' + apply plugin: 'com.github.vlsi.gradle-extensions' + + ext['reactor-bom.version'] = '2022.0.7-SNAPSHOT' + ext['logback.version'] = '1.2.13' + ext['netty-bom.version'] = '4.1.117.Final' + ext['netty-boringssl.version'] = '2.0.69.Final' + ext['hdrhistogram.version'] = '2.1.12' + ext['mockito.version'] = '4.11.0' + ext['slf4j.version'] = '1.7.36' + ext['jmh.version'] = '1.36' + ext['junit.version'] = '5.9.3' + ext['micrometer.version'] = '1.11.12' + ext['micrometer-tracing.version'] = '1.1.13' + ext['assertj.version'] = '3.24.2' + ext['netflix.limits.version'] = '0.3.6' + ext['bouncycastle-bcpkix.version'] = '1.70' + ext['awaitility.version'] = '4.2.0' group = "io.rsocket" @@ -67,22 +70,23 @@ subprojects { mavenBom "io.projectreactor:reactor-bom:${ext['reactor-bom.version']}" mavenBom "io.netty:netty-bom:${ext['netty-bom.version']}" mavenBom "org.junit:junit-bom:${ext['junit.version']}" + mavenBom "io.micrometer:micrometer-bom:${ext['micrometer.version']}" + mavenBom "io.micrometer:micrometer-tracing-bom:${ext['micrometer-tracing.version']}" } dependencies { + dependency "com.netflix.concurrency-limits:concurrency-limits-core:${ext['netflix.limits.version']}" dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" - dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}" + dependency "org.bouncycastle:bcpkix-jdk15on:${ext['bouncycastle-bcpkix.version']}" dependency "org.assertj:assertj-core:${ext['assertj.version']}" dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}" dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}" + dependency "org.awaitility:awaitility:${ext['awaitility.version']}" dependencySet(group: 'org.mockito', version: ext['mockito.version']) { entry 'mockito-junit-jupiter' entry 'mockito-core' } - // TODO: Remove after JUnit5 migration - dependency 'junit:junit:4.12' - dependency "org.hamcrest:hamcrest-library:${ext['hamcrest.version']}" dependencySet(group: 'org.openjdk.jmh', version: ext['jmh.version']) { entry 'jmh-core' entry 'jmh-generator-annprocess' @@ -97,8 +101,19 @@ subprojects { mavenCentral() maven { - url 'https://repo.spring.io/libs-snapshot' + url 'https://repo.spring.io/milestone' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + includeGroup "io.micrometer" + } + } + + maven { + url 'https://repo.spring.io/snapshot' content { + includeGroup "io.micrometer" includeGroup "io.projectreactor" includeGroup "io.projectreactor.netty" } @@ -107,6 +122,7 @@ subprojects { if (version.endsWith('SNAPSHOT') || project.hasProperty('versionSuffix')) { maven { url 'https://repo.spring.io/libs-snapshot' } maven { url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' } + mavenLocal() } } @@ -115,6 +131,7 @@ subprojects { } plugins.withType(JavaPlugin) { + compileJava { sourceCompatibility = 1.8 @@ -143,8 +160,9 @@ subprojects { test { useJUnitPlatform() testLogging { - events "FAILED" + events "PASSED", "FAILED" showExceptions true + showCauses true exceptionFormat "FULL" stackTraceFilters "ENTRY_POINT" maxGranularity 3 @@ -187,7 +205,7 @@ subprojects { if (JavaVersion.current().isJava9Compatible()) { println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" println "Java 9+: enabling leak detection [ADVANCED]" - jvmArgs = ["-XX:MaxGCPauseMillis=20", "-Dio.netty.leakDetection.level=ADVANCED"] + jvmArgs = ["-XX:MaxGCPauseMillis=20", "-Dio.netty.leakDetection.level=ADVANCED", "-Dio.netty.leakDetection.samplingInterval=32"] } systemProperty("java.awt.headless", "true") @@ -242,4 +260,31 @@ description = 'RSocket: Stream Oriented Messaging Passing with Reactive Stream S repositories { mavenCentral() + + maven { url 'https://repo.spring.io/snapshot' } + mavenLocal() +} + +configurations { + adoc +} + +dependencies { + adoc "io.micrometer:micrometer-docs-generator-spans:1.0.0-SNAPSHOT" + adoc "io.micrometer:micrometer-docs-generator-metrics:1.0.0-SNAPSHOT" +} + +task generateObservabilityDocs(dependsOn: ["generateObservabilityMetricsDocs", "generateObservabilitySpansDocs"]) { +} + +task generateObservabilityMetricsDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.metrics.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} + +task generateObservabilitySpansDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.spans.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() } diff --git a/gradle.properties b/gradle.properties index 18e7c5584..d138852c5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,5 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=1.1.0 -perfBaselineVersion=1.0.1 +version=1.2.0 +perfBaselineVersion=1.1.4 diff --git a/gradle/artifactory.gradle b/gradle/artifactory.gradle deleted file mode 100644 index cdffb2741..000000000 --- a/gradle/artifactory.gradle +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey')) { - - subprojects { - plugins.withId('com.jfrog.artifactory') { - artifactory { - publish { - contextUrl = 'https://oss.jfrog.org' - - repository { - repoKey = 'oss-snapshot-local' - - // Credentials for oss.jfrog.org are a user's Bintray credentials - username = project.property('bintrayUser') - password = project.property('bintrayKey') - } - - defaults { - publications(publishing.publications.maven) - } - - if (project.hasProperty('buildNumber')) { - clientConfig.info.setBuildNumber(project.property('buildNumber').toString()) - } - } - } - tasks.named("artifactoryPublish").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } - } - } - } -} diff --git a/gradle/bintray.gradle b/gradle/bintray.gradle deleted file mode 100644 index 5015f94e4..000000000 --- a/gradle/bintray.gradle +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey') && - project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { - - subprojects { - plugins.withId('com.jfrog.bintray') { - bintray { - user = project.property('bintrayUser') - key = project.property('bintrayKey') - - publications = ['maven'] - publish = true - override = true - - pkg { - repo = 'RSocket' - name = 'rsocket-java' - licenses = ['Apache-2.0'] - - issueTrackerUrl = 'https://github.com/rsocket/rsocket-java/issues' - websiteUrl = 'https://github.com/rsocket/rsocket-java' - vcsUrl = 'https://github.com/rsocket/rsocket-java.git' - - githubRepo = 'rsocket/rsocket-java' //Optional Github repository - githubReleaseNotesFile = 'README.md' //Optional Github readme file - - version { - name = project.version - released = new Date() - vcsTag = project.version - - gpg { - sign = true - } - - mavenCentralSync { - user = project.property('sonatypeUsername') - password = project.property('sonatypePassword') - } - } - } - } - tasks.named("bintrayUpload").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } - } - } - } -} diff --git a/gradle/github-pkg.gradle b/gradle/github-pkg.gradle new file mode 100644 index 000000000..f53413766 --- /dev/null +++ b/gradle/github-pkg.gradle @@ -0,0 +1,21 @@ +subprojects { + + plugins.withType(MavenPublishPlugin) { + publishing { + repositories { + maven { + name = "GitHubPackages" + url = uri("https://maven.pkg.github.com/rsocket/rsocket-java") + credentials { + username = project.findProperty("gpr.user") ?: System.getenv("GITHUB_ACTOR") + password = project.findProperty("gpr.key") ?: System.getenv("GITHUB_TOKEN") + } + } + } + } + + tasks.named("publish").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + } +} \ No newline at end of file diff --git a/gradle/publications.gradle b/gradle/publications.gradle index b12d9e9c2..9e8dd6d88 100644 --- a/gradle/publications.gradle +++ b/gradle/publications.gradle @@ -1,5 +1,5 @@ -apply from: "${rootDir}/gradle/artifactory.gradle" -apply from: "${rootDir}/gradle/bintray.gradle" +apply from: "${rootDir}/gradle/github-pkg.gradle" +apply from: "${rootDir}/gradle/sonotype.gradle" subprojects { plugins.withType(MavenPublishPlugin) { @@ -21,30 +21,15 @@ subprojects { } } developers { - developer { - id = 'robertroeser' - name = 'Robert Roeser' - email = 'robert@netifi.com' - } - developer { - id = 'rdegnan' - name = 'Ryland Degnan' - email = 'ryland@netifi.com' - } - developer { - id = 'yschimke' - name = 'Yuri Schimke' - email = 'yuri@schimke.ee' - } developer { id = 'OlegDokuka' name = 'Oleh Dokuka' - email = 'oleh@netifi.com' + email = 'oleh.dokuka@icloud.com' } developer { - id = 'mostroverkhov' - name = 'Maksym Ostroverkhov' - email = 'm.ostroverkhov@gmail.com' + id = 'rstoyanchev' + name = 'Rossen Stoyanchev' + email = 'rstoyanchev@vmware.com' } } scm { diff --git a/gradle/sonotype.gradle b/gradle/sonotype.gradle new file mode 100644 index 000000000..f339079b0 --- /dev/null +++ b/gradle/sonotype.gradle @@ -0,0 +1,36 @@ +subprojects { + if (project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { + plugins.withType(MavenPublishPlugin) { + plugins.withType(SigningPlugin) { + + signing { + //requiring signature if there is a publish task that is not to MavenLocal + required { gradle.taskGraph.allTasks.any { it.name.toLowerCase().contains("publish") && !it.name.contains("MavenLocal") } } + def signingKey = project.findProperty("signingKey") + def signingPassword = project.findProperty("signingPassword") + + useInMemoryPgpKeys(signingKey, signingPassword) + + afterEvaluate { + sign publishing.publications.maven + } + } + + publishing { + repositories { + maven { + name = "sonatype" + url = project.version.contains("-SNAPSHOT") + ? "https://oss.sonatype.org/content/repositories/snapshots/" + : "https://oss.sonatype.org/service/local/staging/deploy/maven2" + credentials { + username project.findProperty("sonatypeUsername") + password project.findProperty("sonatypePassword") + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 5c2d1cf01..249e5832f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 8f6e03af5..774fae876 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,5 @@ -#Mon Jun 08 20:22:21 EEST 2020 -distributionUrl=https\://services.gradle.org/distributions/gradle-6.5-all.zip distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 83f2acfdc..a69d9cb6c 100755 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,78 +17,113 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -97,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -105,84 +140,101 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. # For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=$((i+1)) + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=$(save "$@") +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" -fi +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 24467a141..53a6b238d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -14,7 +14,7 @@ @rem limitations under the License. @rem -@if "%DEBUG%" == "" @echo off +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -25,10 +25,13 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @@ -37,7 +40,7 @@ if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -51,7 +54,7 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% @@ -61,38 +64,26 @@ echo location of your Java installation. goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle index 2efc20a91..a75ab3bc8 100755 --- a/rsocket-bom/build.gradle +++ b/rsocket-bom/build.gradle @@ -16,8 +16,7 @@ plugins { id 'java-platform' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } description = 'RSocket Java Bill of materials.' diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index 41adbd7a8..da5b69b14 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' id 'io.morethan.jmhreport' - id 'me.champeau.gradle.jmh' + id 'me.champeau.jmh' + id 'io.github.reyerizo.gradle.jcstress' } dependencies { @@ -29,19 +29,32 @@ dependencies { implementation 'org.slf4j:slf4j-api' + testImplementation (project(":rsocket-transport-local")) testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' - testImplementation 'org.mockito:mockito-core' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.awaitility:awaitility' testRuntimeOnly 'ch.qos.logback:logback-classic' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + jcstressImplementation(project(":rsocket-test")) + jcstressImplementation 'org.slf4j:slf4j-api' + jcstressImplementation "ch.qos.logback:logback-classic" + jcstressImplementation 'io.projectreactor:reactor-test' } -description = "Core functionality for the RSocket library" \ No newline at end of file +jcstress { + mode = 'sanity' //sanity, quick, default, tough + jcstressDependency = "org.openjdk.jcstress:jcstress-core:0.16" +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.core") + } +} + +description = "Core functionality for the RSocket library" diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..e91be2451 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,115 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class FireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport(testDuplexConnection, StreamIdSupplier.clientSupplier()); + + final FireAndForgetRequesterMono source = source(); + + abstract FireAndForgetRequesterMono source(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java new file mode 100644 index 000000000..ef79d344d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java @@ -0,0 +1,604 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.ResolvingOperator.EMPTY_SUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.EMPTY_UNSUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.READY; +import static io.rsocket.core.ResolvingOperator.TERMINATED; +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.BiConsumer; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.IIIIIII_Result; +import org.openjdk.jcstress.infra.results.IIIIII_Result; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; + +public abstract class ReconnectMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscription stressSubscription = new StressSubscription<>(); + + final Mono source = source(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + volatile int onValueExpire; + + static final AtomicIntegerFieldUpdater ON_VALUE_EXPIRE = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueExpire"); + + volatile int onValueReceived; + + static final AtomicIntegerFieldUpdater ON_VALUE_RECEIVED = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueReceived"); + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source, + (__) -> ON_VALUE_EXPIRE.incrementAndGet(BaseStressTest.this), + (__, ___) -> ON_VALUE_RECEIVED.incrementAndGet(BaseStressTest.this)); + + abstract Mono source(); + + int state() { + final BiConsumer[] subscribers = reconnectMono.resolvingInner.subscribers; + if (subscribers == EMPTY_UNSUBSCRIBED) { + return 0; + } else if (subscribers == EMPTY_SUBSCRIBED) { + return 1; + } else if (subscribers == READY) { + return 2; + } else if (subscribers == TERMINATED) { + return 3; + } else { + return 4; + } + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before value is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onComplete but before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after value is delivered") + @State + public static class ExpireValueOnRacingDisposeAndNext extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before error is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onError") + @State + public static class ExpireValueOnRacingDisposeAndError extends BaseStressTest { + + { + Hooks.onErrorDropped(t -> {}); + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onError(new RuntimeException("boom")); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + Hooks.resetOnErrorDropped(); + + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 0, 1, 2"}, + expect = ACCEPTABLE, + desc = "Invalidate happens before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Invalidate happens after value is delivered") + @State + public static class ExpireValueOnRacingInvalidateAndNextComplete extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndInvalidate extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate1() { + reconnectMono.invalidate(); + } + + @Actor + void invalidate2() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndDispose extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 2, 2, 0, 1"}, + expect = ACCEPTABLE) + @State + public static class DeliversValueToAllSubscribersUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNextAndComplete() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.requestsCount; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 1, 1, 1, 1, 4"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed after invalidate") + @Outcome( + id = {"1, 0, 2, 2, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed before invalidate and received value") + @State + public static class InvalidateAndSubscribeUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 2, 1, 2, 2"}, + expect = ACCEPTABLE, + desc = "Subscribed again after invalidate") + @Outcome( + id = {"1, 0, 1, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Subscribed before invalidate") + @State + public static class InvalidateAndBlockUnderRace extends BaseStressTest { + + String receivedValue; + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + receivedValue = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals("value1") ? 1 : receivedValue.equals("value2") ? 2 : -1; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRace extends BaseStressTest { + + StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void subscribe1() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void subscribe2() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.values.get(0).equals(stressSubscriber2.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class SubscribeBlockConnectRace extends BaseStressTest { + + String receivedValue; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block() { + receivedValue = reconnectMono.block(); + } + + @Actor + void subscribe() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void connect() { + reconnectMono.resolvingInner.connect(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals(stressSubscriber.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoBlocksRace extends BaseStressTest { + + String receivedValue1; + String receivedValue2; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block1() { + receivedValue1 = reconnectMono.block(); + } + + @Actor + void block2() { + receivedValue2 = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue1.equals(receivedValue2) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java new file mode 100644 index 000000000..1dde77b34 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java @@ -0,0 +1,650 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import java.util.stream.IntStream; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class RequestResponseRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(initialRequest()); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker; + + final TestRequesterResponderSupport requesterResponderSupport; + + final RequestResponseRequesterMono source; + + BaseStressTest(RequesterLeaseTracker requesterLeaseTracker) { + this.requesterLeaseTracker = requesterLeaseTracker; + this.requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + this.source = source(); + } + + abstract RequestResponseRequesterMono source(); + + abstract long initialRequest(); + } + + abstract static class BaseStressTestWithLease extends BaseStressTest { + + BaseStressTestWithLease(int maximumAllowedAwaitingPermitHandlersNumber) { + super(new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber)); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTestWithLease { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + public TwoSubscribesRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return Long.MAX_VALUE; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + this.source.handleNext(nextFrame, false, true); + nextFrame.release(); + + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + this.stressSubscriber1.values.forEach(Payload::release); + + r.r5 = this.source.payload.refCnt() + nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelRaceStressTest extends BaseStressTestWithLease { + + public SubscribeAndRequestAndCancelRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + public SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest() { + super(1); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "NoLeaseError delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancel extends BaseStressTest { + + SubscribeAndRequestAndCancel() { + super(null); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + r.r5 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundNextRaceStressTest extends BaseStressTestWithLease { + + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + + CancelWithInboundNextRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundNext() { + this.source.handleNext(this.nextFrame, false, true); + this.nextFrame.release(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt() + this.nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundCompleteRaceStressTest extends BaseStressTestWithLease { + + CancelWithInboundCompleteRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundComplete() { + this.source.handleComplete(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 2, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 3, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first. inbound error dropped") + @State + public static class CancelWithInboundErrorRaceStressTest extends BaseStressTestWithLease { + + static final RuntimeException ERROR = new RuntimeException("Test"); + + CancelWithInboundErrorRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundError() { + this.source.handleError(ERROR); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..5de7eb4b9 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,288 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLL_Result; + +public abstract class SlowFireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker = + new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber()); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + + final SlowFireAndForgetRequesterMono source = source(); + + abstract SlowFireAndForgetRequesterMono source(); + + abstract int maximumAllowedAwaitingPermitHandlersNumber(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelWithDeferredLeaseRaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 1; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "no lease error delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndCancelWithDeferredLease2RaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java new file mode 100644 index 000000000..883077f77 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java @@ -0,0 +1,472 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static reactor.core.publisher.Operators.addCap; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +public class StressSubscriber implements CoreSubscriber { + + enum Operation { + ON_NEXT, + ON_ERROR, + ON_COMPLETE, + ON_SUBSCRIBE + } + + final Context context; + final int requestedFusionMode; + + int fusionMode; + Subscription subscription; + + public Throwable error; + public boolean done; + + public List droppedErrors = new CopyOnWriteArrayList<>(); + + public List values = new ArrayList<>(); + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscriber.class, "requested"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "wip"); + + public volatile Operation guard; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater GUARD = + AtomicReferenceFieldUpdater.newUpdater(StressSubscriber.class, Operation.class, "guard"); + + public volatile boolean concurrentOnNext; + + public volatile boolean concurrentOnError; + + public volatile boolean concurrentOnComplete; + + public volatile boolean concurrentOnSubscribe; + + public volatile int onNextCalls; + + public BlockingQueue q = new LinkedBlockingDeque<>(); + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextCalls"); + + public volatile int onNextDiscarded; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_DISCARDED = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextDiscarded"); + + public volatile int onErrorCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_ERROR_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onErrorCalls"); + + public volatile int onCompleteCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_COMPLETE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onCompleteCalls"); + + public volatile int onSubscribeCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_SUBSCRIBE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onSubscribeCalls"); + + /** Build a {@link StressSubscriber} that makes an unbounded request upon subscription. */ + public StressSubscriber() { + this(Long.MAX_VALUE, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest) { + this(initRequest, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest, int requestedFusionMode) { + this.requestedFusionMode = requestedFusionMode; + this.context = + Operators.enableOnDiscard( + Context.of( + "reactor.onErrorDropped.local", + (Consumer) throwable -> droppedErrors.add(throwable)), + (__) -> ON_NEXT_DISCARDED.incrementAndGet(this)); + REQUESTED.lazySet(this, initRequest | Long.MIN_VALUE); + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (!GUARD.compareAndSet(this, null, Operation.ON_SUBSCRIBE)) { + concurrentOnSubscribe = true; + subscription.cancel(); + } else { + final boolean isValid = Operators.validate(this.subscription, subscription); + if (isValid) { + this.subscription = subscription; + } + GUARD.compareAndSet(this, Operation.ON_SUBSCRIBE, null); + + if (this.requestedFusionMode > 0 && subscription instanceof Fuseable.QueueSubscription) { + final int m = + ((Fuseable.QueueSubscription) subscription).requestFusion(this.requestedFusionMode); + final long requested = this.requested; + this.fusionMode = m; + if (m != Fuseable.NONE) { + if (requested == Long.MAX_VALUE) { + subscription.cancel(); + } + drain(); + return; + } + } + + if (isValid) { + long delivered = 0; + for (; ; ) { + long s = requested; + if (s == Long.MAX_VALUE) { + subscription.cancel(); + break; + } + + long r = s & Long.MAX_VALUE; + long toRequest = r - delivered; + if (toRequest > 0) { + subscription.request(toRequest); + delivered = r; + } + + if (REQUESTED.compareAndSet(this, s, 0)) { + break; + } + } + } + } + ON_SUBSCRIBE_CALLS.incrementAndGet(this); + } + + @Override + public void onNext(T value) { + if (fusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + if (!GUARD.compareAndSet(this, null, Operation.ON_NEXT)) { + concurrentOnNext = true; + } else { + values.add(value); + GUARD.compareAndSet(this, Operation.ON_NEXT, null); + } + ON_NEXT_CALLS.incrementAndGet(this); + } + + @Override + public void onError(Throwable throwable) { + if (!GUARD.compareAndSet(this, null, Operation.ON_ERROR)) { + concurrentOnError = true; + } else { + GUARD.compareAndSet(this, Operation.ON_ERROR, null); + } + + if (done) { + throw new IllegalStateException("Already done"); + } + + error = throwable; + done = true; + q.offer(throwable); + ON_ERROR_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + @Override + public void onComplete() { + if (!GUARD.compareAndSet(this, null, Operation.ON_COMPLETE)) { + concurrentOnComplete = true; + } else { + GUARD.compareAndSet(this, Operation.ON_COMPLETE, null); + } + if (done) { + throw new IllegalStateException("Already done"); + } + + done = true; + ON_COMPLETE_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + public void request(long n) { + if (Operators.validate(n)) { + for (; ; ) { + final long s = this.requested; + if (s == 0) { + this.subscription.request(n); + return; + } + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = addCap(r, n); + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + } + + public void cancel() { + for (; ; ) { + long s = this.requested; + if (s == 0) { + this.subscription.cancel(); + return; + } + + if (REQUESTED.compareAndSet(this, s, Long.MAX_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + + @SuppressWarnings("unchecked") + private void drain() { + final int previousState = markWorkAdded(); + if (isFinalized(previousState)) { + ((Queue) this.subscription).clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + final Subscription s = this.subscription; + final Queue q = (Queue) s; + + int expectedState = previousState + 1; + for (; ; ) { + long r = this.requested & Long.MAX_VALUE; + long e = 0L; + + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + boolean done = this.done; + + final T t = q.poll(); + final boolean empty = t == null; + + if (checkTerminated(done, empty)) { + if (!empty) { + values.add(t); + } + return; + } + + if (empty) { + break; + } + + values.add(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + boolean done = this.done; + boolean empty = q.isEmpty(); + + if (checkTerminated(done, empty)) { + return; + } + } + + if (e != 0) { + ON_NEXT_CALLS.addAndGet(this, (int) e); + if (r != Long.MAX_VALUE) { + produce(e); + } + } + + expectedState = markWorkDone(expectedState); + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + boolean checkTerminated(boolean done, boolean empty) { + final long state = this.requested; + if (state == Long.MAX_VALUE) { + this.subscription.cancel(); + clearAndFinalize(); + return true; + } + + if (done && empty) { + clearAndFinalize(); + return true; + } + + return false; + } + + final void produce(long produced) { + for (; ; ) { + final long s = this.requested; + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = r - produced; + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + return; + } + } + } + + @SuppressWarnings("unchecked") + final void clearAndFinalize() { + final Queue q = (Queue) this.subscription; + for (; ; ) { + final int state = this.wip; + + q.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + final int markWorkAdded() { + for (; ; ) { + final int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + int nextState = state + 1; + if ((nextState & Integer.MAX_VALUE) == 0) { + return state; + } + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + final int markWorkDone(int expectedState) { + for (; ; ) { + final int state = this.wip; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + if (WIP.compareAndSet(this, state, 0)) { + return 0; + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return (state & Integer.MAX_VALUE) > 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java new file mode 100644 index 000000000..3b51b8ef6 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Operators; + +public class StressSubscription implements Subscription { + + CoreSubscriber actual; + + public volatile int subscribes; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater SUBSCRIBES = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "subscribes"); + + public volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscription.class, "requested"); + + public volatile int requestsCount; + + @SuppressWarnings("rawtype s") + static final AtomicIntegerFieldUpdater REQUESTS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "requestsCount"); + + public volatile boolean cancelled; + + void subscribe(CoreSubscriber actual) { + this.actual = actual; + actual.onSubscribe(this); + SUBSCRIBES.getAndIncrement(this); + } + + @Override + public void request(long n) { + REQUESTS_COUNT.incrementAndGet(this); + Operators.addCap(REQUESTED, this, n); + } + + @Override + public void cancel() { + cancelled = true; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..420da66ba --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,39 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import reactor.util.annotation.Nullable; + +public class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + public TestRequesterResponderSupport( + DuplexConnection connection, StreamIdSupplier streamIdSupplier) { + this(connection, streamIdSupplier, null); + } + + public TestRequesterResponderSupport( + DuplexConnection connection, + StreamIdSupplier streamIdSupplier, + @Nullable RequesterLeaseTracker requesterLeaseTracker) { + super( + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + __ -> null); + this.requesterLeaseTracker = requesterLeaseTracker; + } + + @Override + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java new file mode 100644 index 000000000..22c478979 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java @@ -0,0 +1,155 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class UnpooledByteBufPayload extends AbstractReferenceCounted implements Payload { + + private final ByteBuf data; + private final ByteBuf metadata; + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(data, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload(ByteBufUtil.writeUtf8(allocator, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create(data, metadata, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload( + ByteBufUtil.writeUtf8(allocator, data), + metadata == null ? null : ByteBufUtil.writeUtf8(allocator, metadata)); + } + + public UnpooledByteBufPayload(ByteBuf data, @Nullable ByteBuf metadata) { + this.data = data; + this.metadata = metadata; + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public UnpooledByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public UnpooledByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public UnpooledByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public UnpooledByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link UnpooledByteBufPayload#ensureAccessible()} to try to guard against + * using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java new file mode 100644 index 000000000..a2d9fcf4d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java @@ -0,0 +1,1733 @@ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.core.StressSubscriber; +import io.rsocket.utils.FastLogger; +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.Expect; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; +import org.openjdk.jcstress.infra.results.LLL_Result; +import org.openjdk.jcstress.infra.results.L_Result; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.util.Logger; + +public abstract class UnboundedProcessorStressTest { + + static { + Hooks.onErrorDropped(t -> {}); + } + + final Logger logger = new FastLogger(getClass().getName()); + + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(logger); + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeFusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.onCompleteCalls > 0 && stressSubscriber.onErrorCalls > 0) { + throw new RuntimeException("boom"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke24StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class Smoke21FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with onComplete()") + @State + public static class Smoke30StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke31StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.concurrentOnNext || stressSubscriber.concurrentOnComplete) { + throw new ConcurrentModificationException("boo"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke32StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0, 5", + "1, 1, 0, 5", + "2, 1, 0, 5", + "3, 1, 0, 5", + "4, 1, 0, 5", + "5, 1, 0, 5", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke33StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + final ByteBuf byteBuf5 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(5); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.tryEmitNormal(byteBuf1); + unboundedProcessor.tryEmitPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.tryEmitPrioritized(byteBuf3); + unboundedProcessor.tryEmitNormal(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.tryEmitFinal(byteBuf5); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + r.r4 = stressSubscriber.values.get(stressSubscriber.values.size() - 1).readByte(); + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = + byteBuf1.refCnt() + + byteBuf2.refCnt() + + byteBuf3.refCnt() + + byteBuf4.refCnt() + + byteBuf5.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"1"}, + expect = Expect.ACCEPTABLE) + @State + public static class SubscribeVsSubscribeStressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(0, Fuseable.NONE); + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(0, Fuseable.NONE); + + @Actor + public void subscribe1() { + unboundedProcessor.subscribe(stressSubscriber1); + } + + @Actor + public void subscribe2() { + unboundedProcessor.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(L_Result r) { + r.r1 = stressSubscriber1.onErrorCalls + stressSubscriber2.onErrorCalls; + + checkOutcomes(this, r.toString(), logger); + } + } + + static void checkOutcomes(Object instance, String result, Logger logger) { + if (Arrays.stream(instance.getClass().getDeclaredAnnotationsByType(Outcome.class)) + .flatMap(o -> Arrays.stream(o.id())) + .noneMatch(s -> s.equalsIgnoreCase(result))) { + throw new RuntimeException(result + " " + logger); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java new file mode 100644 index 000000000..f0b209552 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java @@ -0,0 +1,118 @@ +package io.rsocket.resume; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LL_Result; +import reactor.core.Disposable; + +public class InMemoryResumableFramesStoreStressTest { + boolean storeClosed; + + InMemoryResumableFramesStore store = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 128); + boolean processorClosed; + UnboundedProcessor processor = new UnboundedProcessor(() -> processorClosed = true); + + void subscribe() { + store.saveFrames(processor).subscribe(); + store.onClose().subscribe(null, t -> storeClosed = true, () -> storeClosed = true); + } + + @JCStressTest + @Outcome( + id = {"true, true"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends InMemoryResumableFramesStoreStressTest { + + Disposable d1; + + final ByteBuf b1 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello1"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello2")); + final ByteBuf b2 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 3, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello3"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello4")); + final ByteBuf b3 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 5, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello5"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello6")); + + final ByteBuf c1 = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new ConnectionErrorException("closed")); + + { + subscribe(); + d1 = store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer1() { + processor.tryEmitNormal(b1); + processor.tryEmitNormal(b2); + processor.tryEmitNormal(b3); + } + + @Actor + public void producer2() { + processor.tryEmitFinal(c1); + } + + @Actor + public void producer3() { + d1.dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer4() { + store.releaseFrames(0); + store.releaseFrames(0); + store.releaseFrames(0); + } + + @Arbiter + public void arbiter(LL_Result r) { + r.r1 = storeClosed; + r.r2 = processorClosed; + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java new file mode 100644 index 000000000..c301d87cf --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java @@ -0,0 +1,137 @@ +package io.rsocket.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import reactor.util.Logger; + +/** + * Implementation of {@link Logger} which is based on the {@link ThreadLocal} based queue which + * collects all the events on the per-thread basis.
Such logger is designed to have all events + * stored during the stress-test run and then sorted and printed out once all the Threads completed + * execution (inside the {@link org.openjdk.jcstress.annotations.Arbiter} annotated method.
+ * Note, this implementation only supports trace-level logs and ignores all others, it is intended + * to be used by {@link reactor.core.publisher.StateLogger}. + */ +public class FastLogger implements Logger { + + final Map> queues = new ConcurrentHashMap<>(); + + final ThreadLocal> logsQueueLocal = + ThreadLocal.withInitial( + () -> { + final ArrayList logs = new ArrayList<>(100); + queues.put(Thread.currentThread(), logs); + return logs; + }); + + private final String name; + + public FastLogger(String name) { + this.name = name; + } + + @Override + public String toString() { + return queues + .values() + .stream() + .flatMap(List::stream) + .sorted( + Comparator.comparingLong( + s -> { + Pattern pattern = Pattern.compile("\\[(.*?)]"); + Matcher matcher = pattern.matcher(s); + matcher.find(); + return Long.parseLong(matcher.group(1)); + })) + .collect(Collectors.joining("\n")); + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean isTraceEnabled() { + return true; + } + + @Override + public void trace(String msg) { + logsQueueLocal.get().add(String.format("[%s] %s", System.nanoTime(), msg)); + } + + @Override + public void trace(String format, Object... arguments) { + trace(String.format(format, arguments)); + } + + @Override + public void trace(String msg, Throwable t) { + trace(String.format("%s, %s", msg, Arrays.toString(t.getStackTrace()))); + } + + @Override + public boolean isDebugEnabled() { + return false; + } + + @Override + public void debug(String msg) {} + + @Override + public void debug(String format, Object... arguments) {} + + @Override + public void debug(String msg, Throwable t) {} + + @Override + public boolean isInfoEnabled() { + return false; + } + + @Override + public void info(String msg) {} + + @Override + public void info(String format, Object... arguments) {} + + @Override + public void info(String msg, Throwable t) {} + + @Override + public boolean isWarnEnabled() { + return false; + } + + @Override + public void warn(String msg) {} + + @Override + public void warn(String format, Object... arguments) {} + + @Override + public void warn(String msg, Throwable t) {} + + @Override + public boolean isErrorEnabled() { + return false; + } + + @Override + public void error(String msg) {} + + @Override + public void error(String format, Object... arguments) {} + + @Override + public void error(String msg, Throwable t) {} +} diff --git a/rsocket-core/src/jcstress/resources/logback.xml b/rsocket-core/src/jcstress/resources/logback.xml new file mode 100644 index 000000000..e5877552c --- /dev/null +++ b/rsocket-core/src/jcstress/resources/logback.xml @@ -0,0 +1,39 @@ + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 6190d24e3..fe91f4bf0 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,41 +18,30 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import java.net.SocketAddress; import java.nio.channels.ClosedChannelException; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** Represents a connection with input/output that the protocol uses. */ public interface DuplexConnection extends Availability, Closeable { /** - * Sends the source of Frames on this connection and returns the {@code Publisher} representing - * the result of this send. + * Delivers the given frame to the underlying transport connection. This method is non-blocking + * and can be safely executed from multiple threads. This method does not provide any flow-control + * mechanism. * - *

Flow control - * - *

The passed {@code Publisher} must - * - * @param frames Stream of {@code Frame}s to send on the connection. - * @return {@code Publisher} that completes when all the frames are written on the connection - * successfully and errors when it fails. - * @throws NullPointerException if {@code frames} is {@code null} + * @param streamId to which the given frame relates + * @param frame with the encoded content */ - Mono send(Publisher frames); + void sendFrame(int streamId, ByteBuf frame); /** - * Sends a single {@code Frame} on this connection and returns the {@code Publisher} representing - * the result of this send. + * Send an error frame and after it is successfully sent, close the connection. * - * @param frame {@code Frame} to send. - * @return {@code Publisher} that completes when the frame is written on the connection - * successfully and errors when it fails. + * @param errorException to encode in the error frame */ - default Mono sendOne(ByteBuf frame) { - return send(Mono.just(frame)); - } + void sendErrorAndClose(RSocketErrorException errorException); /** * Returns a stream of all {@code Frame}s received on this connection. @@ -60,7 +49,7 @@ default Mono sendOne(ByteBuf frame) { *

Completion * *

Returned {@code Publisher} MUST never emit a completion event ({@link - * Subscriber#onComplete()}. + * Subscriber#onComplete()}). * *

Error * @@ -86,6 +75,17 @@ default Mono sendOne(ByteBuf frame) { */ ByteBufAllocator alloc(); + /** + * Return the remote address that this connection is connected to. The returned {@link + * SocketAddress} varies by transport type and should be downcast to obtain more detailed + * information. For TCP and WebSocket, the address type is {@link java.net.InetSocketAddress}. For + * local transport, it is {@link io.rsocket.transport.local.LocalSocketAddress}. + * + * @return the address + * @since 1.1 + */ + SocketAddress remoteAddress(); + @Override default double availability() { return isDisposed() ? 0.0 : 1.0; diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java index 773c93dc2..b05241365 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,8 +34,7 @@ public interface RSocket extends Availability, Closeable { * handled, otherwise errors. */ default Mono fireAndForget(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Fire-and-Forget not implemented.")); + return RSocketAdapter.fireAndForget(payload); } /** @@ -46,8 +45,7 @@ default Mono fireAndForget(Payload payload) { * response. */ default Mono requestResponse(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); + return RSocketAdapter.requestResponse(payload); } /** @@ -57,8 +55,7 @@ default Mono requestResponse(Payload payload) { * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. */ default Flux requestStream(Payload payload) { - payload.release(); - return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); + return RSocketAdapter.requestStream(payload); } /** @@ -68,7 +65,7 @@ default Flux requestStream(Payload payload) { * @return Stream of response payloads. */ default Flux requestChannel(Publisher payloads) { - return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); + return RSocketAdapter.requestChannel(payloads); } /** @@ -79,8 +76,7 @@ default Flux requestChannel(Publisher payloads) { * handled, otherwise errors. */ default Mono metadataPush(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); + return RSocketAdapter.metadataPush(payload); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java new file mode 100644 index 000000000..b5a64b8dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Package private class with default implementations for use in {@link RSocket}. The main purpose + * is to hide static {@link UnsupportedOperationException} declarations. + * + * @since 1.0.3 + */ +class RSocketAdapter { + + private static final Mono UNSUPPORTED_FIRE_AND_FORGET = + Mono.error(new UnsupportedInteractionException("Fire-and-Forget")); + + private static final Mono UNSUPPORTED_REQUEST_RESPONSE = + Mono.error(new UnsupportedInteractionException("Request-Response")); + + private static final Flux UNSUPPORTED_REQUEST_STREAM = + Flux.error(new UnsupportedInteractionException("Request-Stream")); + + private static final Flux UNSUPPORTED_REQUEST_CHANNEL = + Flux.error(new UnsupportedInteractionException("Request-Channel")); + + private static final Mono UNSUPPORTED_METADATA_PUSH = + Mono.error(new UnsupportedInteractionException("Metadata-Push")); + + static Mono fireAndForget(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_FIRE_AND_FORGET; + } + + static Mono requestResponse(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_RESPONSE; + } + + static Flux requestStream(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_STREAM; + } + + static Flux requestChannel(Publisher payloads) { + return RSocketAdapter.UNSUPPORTED_REQUEST_CHANNEL; + } + + static Mono metadataPush(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_METADATA_PUSH; + } + + private static class UnsupportedInteractionException extends RuntimeException { + + private static final long serialVersionUID = 5084623297446471999L; + + UnsupportedInteractionException(String interactionName) { + super(interactionName + " not implemented.", null, false, false); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java index 4b07f04c7..e19d31924 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java @@ -20,15 +20,13 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameUtil; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; import io.rsocket.plugins.InitializingInterceptorRegistry; +import java.net.SocketAddress; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -49,21 +47,14 @@ */ class ClientServerInputMultiplexer implements CoreSubscriber, Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - private static final InitializingInterceptorRegistry emptyInterceptorRegistry = - new InitializingInterceptorRegistry(); - - private final InternalDuplexConnection setupReceiver; private final InternalDuplexConnection serverReceiver; private final InternalDuplexConnection clientReceiver; - private final DuplexConnection setupConnection; private final DuplexConnection serverConnection; private final DuplexConnection clientConnection; private final DuplexConnection source; private final boolean isClient; private Subscription s; - private boolean setupReceived; private Throwable t; @@ -71,45 +62,25 @@ class ClientServerInputMultiplexer implements CoreSubscriber, Closeable private static final AtomicIntegerFieldUpdater STATE = AtomicIntegerFieldUpdater.newUpdater(ClientServerInputMultiplexer.class, "state"); - public ClientServerInputMultiplexer(DuplexConnection source) { - this(source, emptyInterceptorRegistry, false); - } - public ClientServerInputMultiplexer( DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { this.source = source; this.isClient = isClient; - source = registry.initConnection(Type.SOURCE, source); - if (!isClient) { - setupReceiver = new InternalDuplexConnection(this, source); - setupConnection = registry.initConnection(Type.SETUP, setupReceiver); - } else { - setupReceiver = null; - setupConnection = null; - } - serverReceiver = new InternalDuplexConnection(this, source); - clientReceiver = new InternalDuplexConnection(this, source); - serverConnection = registry.initConnection(Type.SERVER, serverReceiver); - clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); + this.serverReceiver = new InternalDuplexConnection(Type.SERVER, this, source); + this.clientReceiver = new InternalDuplexConnection(Type.CLIENT, this, source); + this.serverConnection = registry.initConnection(Type.SERVER, serverReceiver); + this.clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); } - public DuplexConnection asClientServerConnection() { - return source; - } - - public DuplexConnection asServerConnection() { + DuplexConnection asServerConnection() { return serverConnection; } - public DuplexConnection asClientConnection() { + DuplexConnection asClientConnection() { return clientConnection; } - public DuplexConnection asSetupConnection() { - return setupConnection; - } - @Override public void dispose() { source.dispose(); @@ -129,12 +100,7 @@ public Mono onClose() { public void onSubscribe(Subscription s) { if (Operators.validate(this.s, s)) { this.s = s; - if (isClient) { - s.request(Long.MAX_VALUE); - } else { - // request first SetupFrame - s.request(1); - } + s.request(Long.MAX_VALUE); } } @@ -144,12 +110,6 @@ public void onNext(ByteBuf frame) { final Type type; if (streamId == 0) { switch (FrameHeaderCodec.frameType(frame)) { - case SETUP: - case RESUME: - case RESUME_OK: - type = Type.SETUP; - setupReceived = true; - break; case LEASE: case KEEPALIVE: case ERROR: @@ -163,19 +123,8 @@ public void onNext(ByteBuf frame) { } else { type = Type.CLIENT; } - if (!isClient && type != Type.SETUP && !setupReceived) { - final IllegalStateException error = - new IllegalStateException("SETUP or LEASE frame must be received before any others."); - this.s.cancel(); - onError(error); - } switch (type) { - case SETUP: - final InternalDuplexConnection setupReceiver = this.setupReceiver; - setupReceiver.onNext(frame); - setupReceiver.onComplete(); - break; case CLIENT: clientReceiver.onNext(frame); break; @@ -192,16 +141,6 @@ public void onComplete() { return; } - if (!isClient) { - if (!setupReceived) { - setupReceiver.onComplete(); - } - - if (previousState == 1) { - return; - } - } - if (clientReceiver.isSubscribed()) { clientReceiver.onComplete(); } @@ -219,16 +158,6 @@ public void onError(Throwable t) { return; } - if (!isClient) { - if (!setupReceived) { - setupReceiver.onError(t); - } - - if (previousState == 1) { - return; - } - } - if (clientReceiver.isSubscribed()) { clientReceiver.onError(t); } @@ -243,17 +172,8 @@ boolean notifyRequested() { return false; } - if (isClient) { - if (currentState == 2) { - source.receive().subscribe(this); - } - } else { - if (currentState == 1) { - source.receive().subscribe(this); - } else if (currentState == 3) { - // means setup was consumed and we got request from client and server multiplexers - s.request(Long.MAX_VALUE); - } + if (currentState == 2) { + source.receive().subscribe(this); } return true; @@ -275,11 +195,35 @@ int incrementAndGetCheckingState() { } } + @Override + public String toString() { + return "ClientServerInputMultiplexer{" + + "serverReceiver=" + + serverReceiver + + ", clientReceiver=" + + clientReceiver + + ", serverConnection=" + + serverConnection + + ", clientConnection=" + + clientConnection + + ", source=" + + source + + ", isClient=" + + isClient + + ", s=" + + s + + ", t=" + + t + + ", state=" + + state + + '}'; + } + private static class InternalDuplexConnection extends Flux implements Subscription, DuplexConnection { + private final Type type; private final ClientServerInputMultiplexer clientServerInputMultiplexer; private final DuplexConnection source; - private final boolean debugEnabled; private volatile int state; static final AtomicIntegerFieldUpdater STATE = @@ -288,10 +232,12 @@ private static class InternalDuplexConnection extends Flux CoreSubscriber actual; public InternalDuplexConnection( - ClientServerInputMultiplexer clientServerInputMultiplexer, DuplexConnection source) { + Type type, + ClientServerInputMultiplexer clientServerInputMultiplexer, + DuplexConnection source) { + this.type = type; this.clientServerInputMultiplexer = clientServerInputMultiplexer; this.source = source; - this.debugEnabled = LOGGER.isDebugEnabled(); } @Override @@ -339,32 +285,18 @@ void onError(Throwable t) { } @Override - public Mono send(Publisher frame) { - if (debugEnabled) { - return Flux.from(frame) - .doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))) - .as(source::send); - } - - return source.send(frame); + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); } @Override - public Mono sendOne(ByteBuf frame) { - if (debugEnabled) { - LOGGER.debug("sending -> " + FrameUtil.toString(frame)); - } - - return source.sendOne(frame); + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); } @Override public Flux receive() { - if (debugEnabled) { - return this.doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); - } else { - return this; - } + return this; } @Override @@ -372,6 +304,11 @@ public ByteBufAllocator alloc() { return source.alloc(); } + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + @Override public void dispose() { source.dispose(); @@ -395,5 +332,17 @@ public Mono onClose() { public double availability() { return source.availability(); } + + @Override + public String toString() { + return "InternalDuplexConnection{" + + "type=" + + type + + ", source=" + + source + + ", state=" + + state + + '}'; + } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java new file mode 100644 index 000000000..3477b8d6d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java @@ -0,0 +1,49 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import java.nio.channels.ClosedChannelException; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +abstract class ClientSetup { + abstract Mono> init(DuplexConnection connection); +} + +class DefaultClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> sink.onRequest(__ -> sink.success(Tuples.of(Unpooled.EMPTY_BUFFER, connection)))); + } +} + +class ResumableClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> { + sink.onRequest( + __ -> { + new SetupHandlingDuplexConnection(connection, sink); + }); + + Disposable subscribe = + connection + .onClose() + .doFinally(__ -> sink.error(new ClosedChannelException())) + .subscribe(); + sink.onCancel( + () -> { + subscribe.dispose(); + connection.dispose(); + connection.receive().subscribe(); + }); + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java index 4dc250158..82a02268d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java @@ -35,6 +35,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoOperator; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; import reactor.util.context.Context; @@ -45,13 +46,16 @@ */ class DefaultRSocketClient extends ResolvingOperator implements CoreSubscriber, CorePublisher, RSocketClient { - static final Consumer DISCARD_ELEMENTS_CONSUMER = - referenceCounted -> { - if (referenceCounted.refCnt() > 0) { - try { - referenceCounted.release(); - } catch (IllegalReferenceCountException e) { - // ignored + static final Consumer DISCARD_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + ReferenceCounted referenceCounted = ((ReferenceCounted) data); + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } } } }; @@ -65,6 +69,8 @@ class DefaultRSocketClient extends ResolvingOperator final Mono source; + final Sinks.Empty onDisposeSink; + volatile Subscription s; static final AtomicReferenceFieldUpdater S = @@ -72,12 +78,18 @@ class DefaultRSocketClient extends ResolvingOperator DefaultRSocketClient(Mono source) { this.source = unwrapReconnectMono(source); + this.onDisposeSink = Sinks.empty(); } private Mono unwrapReconnectMono(Mono source) { return source instanceof ReconnectMono ? ((ReconnectMono) source).getSource() : source; } + @Override + public Mono onClose() { + return this.onDisposeSink.asMono(); + } + @Override public Mono source() { return Mono.fromDirect(this); @@ -194,6 +206,12 @@ protected void doOnValueExpired(RSocket value) { @Override protected void doOnDispose() { Operators.terminate(S, this); + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onDisposeSink::tryEmitError, onDisposeSink::tryEmitEmpty); + } else { + onDisposeSink.tryEmitEmpty(); + } } static final class FlatMapMain implements CoreSubscriber, Context, Scannable { @@ -435,8 +453,8 @@ public void accept(RSocket rSocket, Throwable t) { @Override public void request(long n) { - this.main.request(n); super.request(n); + this.main.request(n); } public void cancel() { diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java index 3d7a3dfa7..a5d527f5c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -20,12 +20,12 @@ import static io.rsocket.core.SendUtils.sendReleasingPayload; import static io.rsocket.core.StateUtils.*; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.time.Duration; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; @@ -50,7 +50,9 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio final int mtu; final int maxFrameLength; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { this.allocator = requesterResponderSupport.getAllocator(); @@ -58,15 +60,23 @@ final class FireAndForgetRequesterMono extends Mono implements Subscriptio this.mtu = requesterResponderSupport.getMtu(); this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("FireAndForgetMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); return; } @@ -77,14 +87,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - p.release(); - actual.onError( + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + actual.onError(e); return; } @@ -94,26 +118,54 @@ public void subscribe(CoreSubscriber actual) { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - actual.onError(Exceptions.unwrap(t)); + + actual.onError(ut); return; } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + try { if (isTerminated(this.state)) { p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + return; } sendReleasingPayload( - streamId, FrameType.REQUEST_FNF, mtu, p, this.sendProcessor, this.allocator, true); + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + actual.onError(e); return; } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + actual.onComplete(); } @@ -133,24 +185,51 @@ public Void block(Duration m) { return block(); } + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ @Override @Nullable public Void block() { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - throw new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; } final Payload p = this.payload; try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); - throw new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + throw e; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw Exceptions.propagate(e); } @@ -159,25 +238,47 @@ public Void block() { streamId = this.requesterResponderSupport.getNextStreamId(); } catch (Throwable t) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + p.release(); + throw Exceptions.propagate(t); } + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + try { sendReleasingPayload( streamId, FrameType.REQUEST_FNF, this.mtu, this.payload, - this.sendProcessor, + this.connection, this.allocator, true); } catch (Throwable e) { lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + throw Exceptions.propagate(e); } lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + return null; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java index 3a2363d47..e76fdf9ed 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -21,12 +21,15 @@ import io.netty.util.ReferenceCountUtil; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; final class FireAndForgetResponderSubscriber implements CoreSubscriber, ResponderFrameHandler { @@ -42,6 +45,8 @@ final class FireAndForgetResponderSubscriber final RSocket handler; final int maxInboundPayloadSize; + @Nullable final RequestInterceptor requestInterceptor; + CompositeByteBuf frames; private FireAndForgetResponderSubscriber() { @@ -51,6 +56,19 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = 0; this.requesterResponderSupport = null; this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = null; } @@ -65,6 +83,7 @@ private FireAndForgetResponderSubscriber() { this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.frames = ReassemblyUtils.addFollowingFrame( @@ -81,11 +100,21 @@ public void onNext(Void voidVal) {} @Override public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + logger.debug("Dropped Outbound error", t); } @Override - public void onComplete() {} + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, null); + } + } @Override public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { @@ -95,11 +124,17 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); } catch (IllegalStateException t) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); this.frames = null; frames.release(); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_FNF, t); + } + logger.debug("Reassembly has failed", t); return; } @@ -114,6 +149,12 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas frames.release(); } catch (Throwable t) { ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + logger.debug("Reassembly has failed", t); return; } @@ -127,9 +168,16 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas public final void handleCancel() { final CompositeByteBuf frames = this.frames; if (frames != null) { - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + this.frames = null; frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java new file mode 100644 index 000000000..03ab7c257 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java @@ -0,0 +1,20 @@ +package io.rsocket.core; + +/** Handler which enables async lease permits issuing */ +interface LeasePermitHandler { + + /** + * Called by {@link RequesterLeaseTracker} when there is an available lease + * + * @return {@code true} to indicate that lease permit was consumed successfully + */ + boolean handlePermit(); + + /** + * Called by {@link RequesterLeaseTracker} when there are no lease permit available at the moment + * and the list of awaiting {@link LeasePermitHandler} reached the configured limit + * + * @param t associated lease permit rejection exception + */ + void handlePermitError(Throwable t); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java new file mode 100644 index 000000000..ad4b36e3a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.lease.LeaseSender; +import reactor.core.publisher.Flux; + +public final class LeaseSpec { + + LeaseSender sender = Flux::never; + int maxPendingRequests = 256; + + LeaseSpec() {} + + public LeaseSpec sender(LeaseSender sender) { + this.sender = sender; + return this; + } + + /** + * Setup the maximum queued requests waiting for lease to be available. The default value is 256 + * + * @param maxPendingRequests if set to 0 the requester will terminate the request immediately if + * no leases is available + */ + public LeaseSpec maxPendingRequests(int maxPendingRequests) { + this.maxPendingRequests = maxPendingRequests; + return this; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java new file mode 100644 index 000000000..7b5d8f6c2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java @@ -0,0 +1,72 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameUtil; +import java.net.SocketAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +class LoggingDuplexConnection implements DuplexConnection { + + private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + + final DuplexConnection source; + + LoggingDuplexConnection(DuplexConnection source) { + this.source = source; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); + + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + LOGGER.debug("sending -> " + e.getClass().getSimpleName() + ": " + e.getMessage()); + + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return source + .receive() + .doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + static DuplexConnection wrapIfEnabled(DuplexConnection source) { + if (LOGGER.isDebugEnabled()) { + return new LoggingDuplexConnection(source); + } + + return source; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java index 3a53b0ad8..e2512e995 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java @@ -22,9 +22,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.MetadataPushFrameCodec; -import io.rsocket.internal.UnboundedProcessor; import java.time.Duration; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import reactor.core.CoreSubscriber; @@ -43,13 +43,13 @@ final class MetadataPushRequesterMono extends Mono implements Scannable { final ByteBufAllocator allocator; final Payload payload; final int maxFrameLength; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; MetadataPushRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { this.allocator = requesterResponderSupport.getAllocator(); this.payload = payload; this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); } @Override @@ -109,7 +109,7 @@ public void subscribe(CoreSubscriber actual) { final ByteBuf requestFrame = MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); - this.sendProcessor.onNext(requestFrame); + this.connection.sendFrame(0, requestFrame); Operators.complete(actual); } @@ -120,6 +120,11 @@ public Void block(Duration m) { return block(); } + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ @Override @Nullable public Void block() { @@ -133,15 +138,16 @@ public Void block() { try { final boolean hasMetadata = p.hasMetadata(); metadata = p.metadata(); - if (hasMetadata) { + if (!hasMetadata) { lazyTerminate(STATE, this); p.release(); - throw new IllegalArgumentException("Metadata push does not support metadata field"); + throw new IllegalArgumentException("Metadata push should have metadata field present"); } if (!isValidMetadata(this.maxFrameLength, metadata)) { lazyTerminate(STATE, this); p.release(); - throw new IllegalArgumentException("Too Big Payload size"); + throw new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); @@ -166,7 +172,7 @@ public Void block() { final ByteBuf requestFrame = MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); - this.sendProcessor.onNext(requestFrame); + this.connection.sendFrame(0, requestFrame); return null; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java index 81392e661..32e3c229d 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java @@ -15,12 +15,13 @@ */ package io.rsocket.core; +import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.RSocket; import org.reactivestreams.Publisher; -import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; /** * Contract for performing RSocket requests. @@ -74,7 +75,22 @@ * @since 1.1 * @see io.rsocket.loadbalance.LoadbalanceRSocketClient */ -public interface RSocketClient extends Disposable { +public interface RSocketClient extends Closeable { + + /** + * Connect to the remote rsocket endpoint, if not yet connected. This method is a shortcut for + * {@code RSocketClient#source().subscribe()}. + * + * @return {@code true} if an attempt to connect was triggered or if already connected, or {@code + * false} if the client is terminated. + */ + default boolean connect() { + throw new NotImplementedException(); + } + + default Mono onClose() { + return Mono.error(new NotImplementedException()); + } /** Return the underlying source used to obtain a shared {@link RSocket} connection. */ Mono source(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java index cc94f4102..ae8b7da97 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java @@ -41,11 +41,21 @@ public RSocket rsocket() { return rsocket; } + @Override + public boolean connect() { + throw new UnsupportedOperationException("Connect does not apply to a server side RSocket"); + } + @Override public Mono source() { return Mono.just(rsocket); } + @Override + public Mono onClose() { + return rsocket.onClose(); + } + @Override public Mono fireAndForget(Mono payloadMono) { return payloadMono.flatMap(rsocket::fireAndForget); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java index 0058106bc..de494c4e3 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -29,13 +29,14 @@ import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.keepalive.KeepAliveHandler; -import io.rsocket.lease.LeaseStats; -import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumableFramesStore; import io.rsocket.transport.ClientTransport; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; @@ -46,6 +47,7 @@ import java.util.function.Supplier; import reactor.core.Disposable; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; import reactor.util.function.Tuples; import reactor.util.retry.Retry; @@ -58,18 +60,20 @@ *

{@code
  * import io.rsocket.transport.netty.client.TcpClientTransport;
  *
- * RSocketClient client =
- *         RSocketConnector.createRSocketClient(TcpClientTransport.create("localhost", 7000));
+ * Mono source =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
  * }
* *

To customize connection settings before connecting: * *

{@code
- * RSocketClient client =
+ * Mono source =
  *         RSocketConnector.create()
  *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
  *                 .dataMimeType("application/cbor")
- *                 .toRSocketClient(TcpClientTransport.create("localhost", 7000));
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
  * }
*/ public class RSocketConnector { @@ -89,7 +93,8 @@ public class RSocketConnector { private Retry retrySpec; private Resume resume; - private Supplier> leasesSupplier; + + @Nullable private Consumer leaseConfigurer; private int mtu = 0; private int maxInboundPayloadSize = Integer.MAX_VALUE; @@ -109,7 +114,7 @@ public static RSocketConnector create() { * Static factory method to connect with default settings, effectively a shortcut for: * *
-   * RSocketConnector.create().connectWith(transport);
+   * RSocketConnector.create().connect(transport);
    * 
* * @param transport the transport of choice to connect with @@ -398,18 +403,43 @@ public RSocketConnector resume(Resume resume) { * *
{@code
    * Mono rocketMono =
-   *         RSocketConnector.create().lease(Leases::new).connect(transport);
+   *         RSocketConnector.create()
+   *                         .lease()
+   *                         .connect(transport);
    * }
* *

By default this is not enabled. * - * @param supplier supplier for a {@link Leases} * @return the same instance for method chaining * @see Lease * Semantics */ - public RSocketConnector lease(Supplier> supplier) { - this.leasesSupplier = supplier; + public RSocketConnector lease() { + return lease((config -> {})); + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create()
+   *                         .lease(spec -> spec.maxPendingRequests(128))
+   *                         .connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; return this; } @@ -519,7 +549,12 @@ public Mono connect(Supplier transportSupplier) { assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); return ct; }) - .flatMap(transport -> transport.connect()); + .flatMap(transport -> transport.connect()) + .map( + sourceConnection -> + interceptors.initConnection( + DuplexConnectionInterceptor.Type.SOURCE, sourceConnection)) + .map(source -> LoggingDuplexConnection.wrapIfEnabled(source)); return connectionMono .flatMap( @@ -530,65 +565,24 @@ public Mono connect(Supplier transportSupplier) { .doOnError(ex -> connection.dispose()) .doOnCancel(connection::dispose)) .flatMap( - tuple -> { - DuplexConnection connection = tuple.getT1(); - Payload setupPayload = tuple.getT2(); + tuple2 -> { + DuplexConnection sourceConnection = tuple2.getT1(); + Payload setupPayload = tuple2.getT2(); + boolean leaseEnabled = leaseConfigurer != null; + boolean resumeEnabled = resume != null; + // TODO: add LeaseClientSetup + ClientSetup clientSetup = new DefaultClientSetup(); ByteBuf resumeToken; - KeepAliveHandler keepAliveHandler; - DuplexConnection wrappedConnection; - if (resume != null) { + if (resumeEnabled) { resumeToken = resume.getTokenSupplier().get(); - ClientRSocketSession session = - new ClientRSocketSession( - connection, - resume.getSessionDuration(), - resume.getRetry(), - resume.getStoreFactory(CLIENT_TAG).apply(resumeToken), - resume.getStreamTimeout(), - resume.isCleanupStoreOnKeepAlive()) - .continueWith(connectionMono) - .resumeToken(resumeToken); - keepAliveHandler = - new KeepAliveHandler.ResumableKeepAliveHandler( - session.resumableConnection()); - wrappedConnection = session.resumableConnection(); } else { resumeToken = Unpooled.EMPTY_BUFFER; - keepAliveHandler = - new KeepAliveHandler.DefaultKeepAliveHandler(connection); - wrappedConnection = connection; } - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(wrappedConnection, interceptors, true); - - boolean leaseEnabled = leasesSupplier != null; - Leases leases = leaseEnabled ? leasesSupplier.get() : null; - RequesterLeaseHandler requesterLeaseHandler = - leaseEnabled - ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - multiplexer.asClientConnection(), - payloadDecoder, - StreamIdSupplier.clientSupplier(), - mtu, - maxFrameLength, - maxInboundPayloadSize, - (int) keepAliveInterval.toMillis(), - (int) keepAliveMaxLifeTime.toMillis(), - keepAliveHandler, - requesterLeaseHandler); - - RSocket wrappedRSocketRequester = - interceptors.initRequester(rSocketRequester); - ByteBuf setupFrame = SetupFrameCodec.encode( - wrappedConnection.alloc(), + sourceConnection.alloc(), leaseEnabled, (int) keepAliveInterval.toMillis(), (int) keepAliveMaxLifeTime.toMillis(), @@ -597,46 +591,146 @@ public Mono connect(Supplier transportSupplier) { dataMimeType, setupPayload); - SocketAcceptor acceptor = - this.acceptor != null - ? this.acceptor - : SocketAcceptor.with(new RSocket() {}); - - ConnectionSetupPayload setup = - new DefaultConnectionSetupPayload(setupFrame); + sourceConnection.sendFrame(0, setupFrame.retainedSlice()); - return interceptors - .initSocketAcceptor(acceptor) - .accept(setup, wrappedRSocketRequester) + return clientSetup + .init(sourceConnection) .flatMap( - rSocketHandler -> { - RSocket wrappedRSocketHandler = - interceptors.initResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - leaseEnabled - ? new ResponderLeaseHandler.Impl<>( - CLIENT_TAG, - wrappedConnection.alloc(), - leases.sender(), - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - multiplexer.asServerConnection(), - wrappedRSocketHandler, + tuple -> { + // should be used if lease setup sequence; + // See: + // https://github.com/rsocket/rsocket/blob/master/Protocol.md#sequences-with-lease + final ByteBuf serverResponse = tuple.getT1(); + final DuplexConnection clientServerConnection = tuple.getT2(); + final KeepAliveHandler keepAliveHandler; + final DuplexConnection wrappedConnection; + final InitializingInterceptorRegistry interceptors = + this.interceptors; + + if (resumeEnabled) { + final ResumableFramesStore resumableFramesStore = + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + CLIENT_TAG, + resumeToken, + clientServerConnection, + resumableFramesStore); + final ResumableClientSetup resumableClientSetup = + new ResumableClientSetup(); + final ClientRSocketSession session = + new ClientRSocketSession( + resumeToken, + resumableDuplexConnection, + connectionMono, + resumableClientSetup::init, + resumableFramesStore, + resume.getSessionDuration(), + resume.getRetry(), + resume.isCleanupStoreOnKeepAlive()); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler( + resumableDuplexConnection, session, session); + wrappedConnection = resumableDuplexConnection; + } else { + keepAliveHandler = + new KeepAliveHandler.DefaultKeepAliveHandler(); + wrappedConnection = clientServerConnection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer( + wrappedConnection, interceptors, true); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker( + CLIENT_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = + Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = + Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), payloadDecoder, - responderLeaseHandler, + StreamIdSupplier.clientSupplier(), mtu, maxFrameLength, - maxInboundPayloadSize); - - return wrappedConnection - .sendOne(setupFrame.retain()) - .thenReturn(wrappedRSocketRequester); - }) - .doFinally(signalType -> setup.release()); + maxInboundPayloadSize, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), + requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = + interceptors.initRequester(rSocketRequester); + + SocketAcceptor acceptor = + this.acceptor != null + ? this.acceptor + : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = + new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .map( + rSocketHandler -> { + RSocket wrappedRSocketHandler = + interceptors.initResponder(rSocketHandler); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker( + CLIENT_TAG, + wrappedConnection, + leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled + && leases.sender + instanceof TrackingLeaseSender + ? rSocket -> + interceptors + .initResponderRequestInterceptor( + rSocket, + (RequestInterceptor) + leases.sender) + : interceptors + ::initResponderRequestInterceptor, + responderOnAllClosedSink); + + return wrappedRSocketRequester; + }) + .doFinally(signalType -> setup.release()); + }); }); }) .as( diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java index 66e2c60ec..b8a9c00ff 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -29,20 +30,22 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.keepalive.KeepAliveFramesAcceptor; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.keepalive.KeepAliveSupport; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; /** @@ -62,10 +65,11 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { AtomicReferenceFieldUpdater.newUpdater( RSocketRequester.class, Throwable.class, "terminationError"); - private final DuplexConnection connection; - private final RequesterLeaseHandler leaseHandler; + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + private final Sinks.Empty onThisSideClosedSink; + private final Mono onAllClosed; private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; - private final MonoProcessor onClose; RSocketRequester( DuplexConnection connection, @@ -77,24 +81,25 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { int keepAliveTickPeriod, int keepAliveAckTimeout, @Nullable KeepAliveHandler keepAliveHandler, - RequesterLeaseHandler leaseHandler) { + Function requestInterceptorFunction, + @Nullable RequesterLeaseTracker requesterLeaseTracker, + Sinks.Empty onThisSideClosedSink, + Mono onAllClosed) { super( mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, - connection.alloc(), - streamIdSupplier); - - this.connection = connection; - this.leaseHandler = leaseHandler; - this.onClose = MonoProcessor.create(); + connection, + streamIdSupplier, + requestInterceptorFunction); - UnboundedProcessor sendProcessor = super.getSendProcessor(); + this.requesterLeaseTracker = requesterLeaseTracker; + this.onThisSideClosedSink = onThisSideClosedSink; + this.onAllClosed = onAllClosed; // DO NOT Change the order here. The Send processor must be subscribed to before receiving - connection.onClose().subscribe(null, this::tryTerminateOnConnectionError, this::tryShutdown); - connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + connection.onClose().subscribe(null, this::tryShutdown, this::tryShutdown); connection.receive().subscribe(this::handleIncomingFrames, e -> {}); @@ -103,7 +108,9 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); this.keepAliveFramesAcceptor = keepAliveHandler.start( - keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); + keepAliveSupport, + (keepAliveFrame) -> connection.sendFrame(0, keepAliveFrame), + this::tryTerminateOnKeepAlive); } else { keepAliveFramesAcceptor = null; } @@ -111,7 +118,11 @@ class RSocketRequester extends RequesterResponderSupport implements RSocket { @Override public Mono fireAndForget(Payload payload) { - return new FireAndForgetRequesterMono(payload, this); + if (this.requesterLeaseTracker == null) { + return new FireAndForgetRequesterMono(payload, this); + } else { + return new SlowFireAndForgetRequesterMono(payload, this); + } } @Override @@ -141,12 +152,12 @@ public Mono metadataPush(Payload payload) { } @Override - public int getNextStreamId() { - RequesterLeaseHandler leaseHandler = this.leaseHandler; - if (!leaseHandler.useLease()) { - throw reactor.core.Exceptions.propagate(leaseHandler.leaseError()); - } + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } + @Override + public int getNextStreamId() { int nextStreamId = super.getNextStreamId(); Throwable terminationError = this.terminationError; @@ -159,11 +170,6 @@ public int getNextStreamId() { @Override public int addAndGetNextStreamId(FrameHandler frameHandler) { - RequesterLeaseHandler leaseHandler = this.leaseHandler; - if (!leaseHandler.useLease()) { - throw reactor.core.Exceptions.propagate(leaseHandler.leaseError()); - } - int nextStreamId = super.addAndGetNextStreamId(frameHandler); Throwable terminationError = this.terminationError; @@ -177,12 +183,21 @@ public int addAndGetNextStreamId(FrameHandler frameHandler) { @Override public double availability() { - return Math.min(connection.availability(), leaseHandler.availability()); + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + return Math.min(getDuplexConnection().availability(), requesterLeaseTracker.availability()); + } else { + return getDuplexConnection().availability(); + } } @Override public void dispose() { - tryShutdown(); + if (terminationError != null) { + return; + } + + getDuplexConnection().sendErrorAndClose(new ConnectionErrorException("Disposed")); } @Override @@ -192,7 +207,7 @@ public boolean isDisposed() { @Override public Mono onClose() { - return onClose; + return onAllClosed; } private void handleIncomingFrames(ByteBuf frame) { @@ -206,13 +221,9 @@ private void handleIncomingFrames(ByteBuf frame) { } } catch (Throwable t) { LOGGER.error("Unexpected error during frame handling", t); - super.getSendProcessor() - .onNext( - ErrorFrameCodec.encode( - super.getAllocator(), - 0, - new ConnectionErrorException("Unexpected error during frame handling", t))); - this.tryTerminateOnConnectionError(t); + final ConnectionErrorException error = + new ConnectionErrorException("Unexpected error during frame handling", t); + getDuplexConnection().sendErrorAndClose(error); } } @@ -222,7 +233,7 @@ private void handleStreamZero(FrameType type, ByteBuf frame) { tryTerminateOnZeroError(frame); break; case LEASE: - leaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); break; case KEEPALIVE: if (keepAliveFramesAcceptor != null) { @@ -301,10 +312,34 @@ private void tryTerminateOnKeepAlive(KeepAliveSupport.KeepAlive keepAlive) { () -> new ConnectionErrorException( String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + getDuplexConnection().dispose(); } - private void tryTerminateOnConnectionError(Throwable e) { - tryTerminate(() -> e); + private void tryShutdown(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.info( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } } private void tryTerminateOnZeroError(ByteBuf errorFrame) { @@ -312,50 +347,99 @@ private void tryTerminateOnZeroError(ByteBuf errorFrame) { } private void tryTerminate(Supplier errorSupplier) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { terminate(e); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); } } } private void tryShutdown() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } if (terminationError == null) { if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); } } } private void terminate(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing requester " + getDuplexConnection() + " due to " + e); + } if (keepAliveFramesAcceptor != null) { keepAliveFramesAcceptor.dispose(); } - connection.dispose(); - leaseHandler.dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + requesterLeaseTracker.dispose(e); + } + + final Collection activeStreamsCopy; synchronized (this) { - activeStreams - .values() - .forEach( - receiver -> { - try { - receiver.handleError(e); - } catch (Throwable ignored) { - } - }); + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + try { + handler.handleError(e); + } catch (Throwable ignored) { + } + } } - this.getSendProcessor().dispose(); if (e == CLOSED_CHANNEL_EXCEPTION) { - onClose.onComplete(); + onThisSideClosedSink.tryEmitEmpty(); } else { - onClose.onError(e); + onThisSideClosedSink.tryEmitError(e); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("requester closed " + getDuplexConnection()); } - } - - private void handleSendProcessorError(Throwable t) { - connection.dispose(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 2368445c9..50c5ba54c 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package io.rsocket.core; import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -25,21 +26,26 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; import java.util.function.Supplier; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ class RSocketResponder extends RequesterResponderSupport implements RSocket { @@ -48,11 +54,10 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); - private final DuplexConnection connection; private final RSocket requestHandler; + private final Sinks.Empty onThisSideClosedSink; - private final ResponderLeaseHandler leaseHandler; - private final Disposable leaseHandlerDisposable; + @Nullable private final ResponderLeaseTracker leaseHandler; private volatile Throwable terminationError; private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = @@ -63,42 +68,45 @@ class RSocketResponder extends RequesterResponderSupport implements RSocket { DuplexConnection connection, RSocket requestHandler, PayloadDecoder payloadDecoder, - ResponderLeaseHandler leaseHandler, + @Nullable ResponderLeaseTracker leaseHandler, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { - super(mtu, maxFrameLength, maxInboundPayloadSize, payloadDecoder, connection.alloc(), null); - this.connection = connection; + int maxInboundPayloadSize, + Function requestInterceptorFunction, + Sinks.Empty onThisSideClosedSink) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptorFunction); this.requestHandler = requestHandler; this.leaseHandler = leaseHandler; + this.onThisSideClosedSink = onThisSideClosedSink; - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - UnboundedProcessor sendProcessor = super.getSendProcessor(); - - connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); - - connection.receive().subscribe(this::handleFrame, e -> {}); - leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized); - - this.connection + connection .onClose() .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); - } - private void handleSendProcessorError(Throwable t) { - for (FrameHandler frameHandler : activeStreams.values()) { - frameHandler.handleError(t); - } + connection.receive().subscribe(this::handleFrame, e -> {}); } private void tryTerminateOnConnectionError(Throwable e) { + if (LOGGER.isDebugEnabled()) { + + LOGGER.debug("Try terminate connection on responder side"); + } tryTerminate(() -> e); } private void tryTerminateOnConnectionClose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.info("Try terminate connection on responder side"); + } tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); } @@ -106,7 +114,7 @@ private void tryTerminate(Supplier errorSupplier) { if (terminationError == null) { Throwable e = errorSupplier.get(); if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - cleanup(e); + doOnDispose(); } } } @@ -114,12 +122,7 @@ private void tryTerminate(Supplier errorSupplier) { @Override public Mono fireAndForget(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.fireAndForget(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.fireAndForget(payload); } catch (Throwable t) { return Mono.error(t); } @@ -128,12 +131,7 @@ public Mono fireAndForget(Payload payload) { @Override public Mono requestResponse(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestResponse(payload); - } else { - payload.release(); - return Mono.error(leaseHandler.leaseError()); - } + return requestHandler.requestResponse(payload); } catch (Throwable t) { return Mono.error(t); } @@ -142,12 +140,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestStream(payload); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestStream(payload); } catch (Throwable t) { return Flux.error(t); } @@ -156,24 +149,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - return Flux.error(leaseHandler.leaseError()); - } - } catch (Throwable t) { - return Flux.error(t); - } - } - - private Flux requestChannel(Payload payload, Publisher payloads) { - try { - if (leaseHandler.useLease()) { - return requestHandler.requestChannel(payloads); - } else { - payload.release(); - return Flux.error(leaseHandler.leaseError()); - } + return requestHandler.requestChannel(payloads); } catch (Throwable t) { return Flux.error(t); } @@ -195,29 +171,53 @@ public void dispose() { @Override public boolean isDisposed() { - return connection.isDisposed(); + return getDuplexConnection().isDisposed(); } @Override public Mono onClose() { - return connection.onClose(); + return getDuplexConnection().onClose(); } - private void cleanup(Throwable e) { + final void doOnDispose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing responder " + getDuplexConnection()); + } cleanUpSendingSubscriptions(); - connection.dispose(); - leaseHandlerDisposable.dispose(); + getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final ResponderLeaseTracker handler = leaseHandler; + if (handler != null) { + handler.dispose(); + } + requestHandler.dispose(); - super.getSendProcessor().dispose(); + onThisSideClosedSink.tryEmitEmpty(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("responder closed " + getDuplexConnection()); + } } - private synchronized void cleanUpSendingSubscriptions() { - activeStreams.values().forEach(FrameHandler::handleCancel); - activeStreams.clear(); + private void cleanUpSendingSubscriptions() { + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + handler.handleCancel(); + } + } } - private void handleFrame(ByteBuf frame) { + final void handleFrame(ByteBuf frame) { try { int streamId = FrameHeaderCodec.streamId(frame); FrameHandler receiver; @@ -235,7 +235,8 @@ private void handleFrame(ByteBuf frame) { break; case REQUEST_CHANNEL: long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); - handleChannel(streamId, frame, channelInitialRequestN); + handleChannel( + streamId, frame, channelInitialRequestN, FrameHeaderCodec.hasComplete(frame)); break; case METADATA_PUSH: handleMetadataPush(metadataPush(super.getPayloadDecoder().apply(frame))); @@ -282,8 +283,9 @@ private void handleFrame(ByteBuf frame) { } break; case SETUP: - super.getSendProcessor() - .onNext( + getDuplexConnection() + .sendFrame( + streamId, ErrorFrameCodec.encode( super.getAllocator(), streamId, @@ -291,8 +293,9 @@ private void handleFrame(ByteBuf frame) { break; case LEASE: default: - super.getSendProcessor() - .onNext( + getDuplexConnection() + .sendFrame( + streamId, ErrorFrameCodec.encode( super.getAllocator(), streamId, @@ -302,8 +305,9 @@ private void handleFrame(ByteBuf frame) { } } catch (Throwable t) { LOGGER.error("Unexpected error during frame handling", t); - super.getSendProcessor() - .onNext( + getDuplexConnection() + .sendFrame( + 0, ErrorFrameCodec.encode( super.getAllocator(), 0, @@ -312,78 +316,158 @@ private void handleFrame(ByteBuf frame) { } } - private void handleFireAndForget(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - FireAndForgetResponderSubscriber subscriber = - new FireAndForgetResponderSubscriber(streamId, frame, this, this); - - this.add(streamId, subscriber); + final void handleFireAndForget(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } } else { - fireAndForget(super.getPayloadDecoder().apply(frame)) - .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } } } - private void handleRequestResponse(int streamId, ByteBuf frame) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, frame, this, this); + final void handleRequestResponse(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestResponseResponderSubscriber subscriber = - new RequestResponseResponderSubscriber(streamId, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); - if (this.add(streamId, subscriber)) { - this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } - private void handleStream(int streamId, ByteBuf frame, long initialRequestN) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - RequestStreamResponderSubscriber subscriber = - new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); - if (this.add(streamId, subscriber)) { - this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); } + sendLeaseRejection(streamId, leaseError); } } - private void handleChannel(int streamId, ByteBuf frame, long initialRequestN) { - if (FrameHeaderCodec.hasFollows(frame)) { - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } - this.add(streamId, subscriber); - } else { - final Payload firstPayload = super.getPayloadDecoder().apply(frame); - RequestChannelResponderSubscriber subscriber = - new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); - if (this.add(streamId, subscriber)) { - this.requestChannel(firstPayload, subscriber).subscribe(subscriber); + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } + } } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); } } + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); + } + private void handleMetadataPush(Mono result) { result.subscribe(MetadataPushResponderSubscriber.INSTANCE); } - private boolean add(int streamId, FrameHandler frameHandler) { - FrameHandler existingHandler; - synchronized (this) { - existingHandler = super.activeStreams.putIfAbsent(streamId, frameHandler); - } - - if (existingHandler != null) { + @Override + public boolean add(int streamId, FrameHandler frameHandler) { + if (!super.add(streamId, frameHandler)) { frameHandler.handleCancel(); return false; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java index 5a411e464..e969c39d2 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,23 +27,26 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; import io.rsocket.SocketAcceptor; import io.rsocket.exceptions.InvalidSetupException; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; import io.rsocket.resume.SessionManager; import io.rsocket.transport.ServerTransport; +import java.time.Duration; import java.util.Objects; import java.util.function.Consumer; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; /** * The main class for starting an RSocket server. @@ -64,11 +67,12 @@ public final class RSocketServer { private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); private Resume resume; - private Supplier> leasesSupplier = null; + private Consumer leaseConfigurer = null; private int mtu = 0; private int maxInboundPayloadSize = Integer.MAX_VALUE; private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private Duration timeout = Duration.ofMinutes(1); private RSocketServer() {} @@ -182,21 +186,23 @@ public RSocketServer resume(Resume resume) { * *

{@code
    * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
-   *         .lease(Leases::new)
+   *         .lease(spec ->
+   *            spec.sender(() -> Flux.interval(ofSeconds(1))
+   *                                  .map(__ -> Lease.create(ofSeconds(1), 1)))
+   *         )
    *         .bind(TcpServerTransport.create("localhost", 7000))
    *         .subscribe();
    * }
* *

By default this is not enabled. * - * @param supplier supplier for a {@link Leases} - * @return the same instance for method chaining + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring * @return the same instance for method chaining * @see Lease * Semantics */ - public RSocketServer lease(Supplier> supplier) { - this.leasesSupplier = supplier; + public RSocketServer lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; return this; } @@ -220,6 +226,23 @@ public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { return this; } + /** + * Specify the max time to wait for the first frame (e.g. {@code SETUP}) on an accepted + * connection. + * + *

By default this is set to 1 minute. + * + * @param timeout duration + * @return the same instance for method chaining + */ + public RSocketServer maxTimeToFirstFrame(Duration timeout) { + if (timeout.isNegative() || timeout.isZero()) { + throw new IllegalArgumentException("Setup Handling Timeout should be greater than zero"); + } + this.timeout = timeout; + return this; + } + /** * When this is set, frames larger than the given maximum transmission unit (mtu) size value are * fragmented. @@ -284,7 +307,7 @@ public RSocketServer payloadDecoder(PayloadDecoder decoder) { public Mono bind(ServerTransport transport) { return Mono.defer( new Supplier>() { - final ServerSetup serverSetup = serverSetup(); + final ServerSetup serverSetup = serverSetup(timeout); @Override public Mono get() { @@ -323,7 +346,7 @@ public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { public ServerTransport.ConnectionAcceptor asConnectionAcceptor(int maxFrameLength) { assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); return new ServerTransport.ConnectionAcceptor() { - private final ServerSetup serverSetup = serverSetup(); + private final ServerSetup serverSetup = serverSetup(timeout); @Override public Mono apply(DuplexConnection connection) { @@ -333,85 +356,94 @@ public Mono apply(DuplexConnection connection) { } private Mono acceptor( - ServerSetup serverSetup, DuplexConnection connection, int maxFrameLength) { + ServerSetup serverSetup, DuplexConnection sourceConnection, int maxFrameLength) { + + final DuplexConnection interceptedConnection = + interceptors.initConnection(DuplexConnectionInterceptor.Type.SOURCE, sourceConnection); - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, interceptors, false); + return serverSetup + .init(LoggingDuplexConnection.wrapIfEnabled(interceptedConnection)) + .flatMap( + tuple2 -> { + final ByteBuf startFrame = tuple2.getT1(); + final DuplexConnection clientServerConnection = tuple2.getT2(); - return multiplexer - .asSetupConnection() - .receive() - .next() - .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer, maxFrameLength)); + return accept(serverSetup, startFrame, clientServerConnection, maxFrameLength); + }); } private Mono acceptResume( - ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { - return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); + ServerSetup serverSetup, ByteBuf resumeFrame, DuplexConnection clientServerConnection) { + return serverSetup.acceptRSocketResume(resumeFrame, clientServerConnection); } private Mono accept( ServerSetup serverSetup, ByteBuf startFrame, - ClientServerInputMultiplexer multiplexer, + DuplexConnection clientServerConnection, int maxFrameLength) { switch (FrameHeaderCodec.frameType(startFrame)) { case SETUP: - return acceptSetup(serverSetup, startFrame, multiplexer, maxFrameLength); + return acceptSetup(serverSetup, startFrame, clientServerConnection, maxFrameLength); case RESUME: - return acceptResume(serverSetup, startFrame, multiplexer); + return acceptResume(serverSetup, startFrame, clientServerConnection); default: - return serverSetup - .sendError( - multiplexer, - new InvalidSetupException( - "invalid setup frame: " + FrameHeaderCodec.frameType(startFrame))) - .doFinally( - signalType -> { - startFrame.release(); - multiplexer.dispose(); - }); + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException("SETUP or RESUME frame must be received before any others")); + return clientServerConnection.onClose(); } } private Mono acceptSetup( ServerSetup serverSetup, ByteBuf setupFrame, - ClientServerInputMultiplexer multiplexer, + DuplexConnection clientServerConnection, int maxFrameLength) { if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { - return serverSetup - .sendError( - multiplexer, - new InvalidSetupException( - "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))) - .doFinally(signalType -> multiplexer.dispose()); + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))); + return clientServerConnection.onClose(); } - boolean leaseEnabled = leasesSupplier != null; + boolean leaseEnabled = leaseConfigurer != null; if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { - return serverSetup - .sendError(multiplexer, new InvalidSetupException("lease is not supported")) - .doFinally(signalType -> multiplexer.dispose()); + serverSetup.sendError( + clientServerConnection, new InvalidSetupException("lease is not supported")); + return clientServerConnection.onClose(); } return serverSetup.acceptRSocketSetup( setupFrame, - multiplexer, - (keepAliveHandler, wrappedMultiplexer) -> { + clientServerConnection, + (keepAliveHandler, wrappedDuplexConnection) -> { ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(setupFrame.retain()); + final InitializingInterceptorRegistry interceptors = this.interceptors; + final ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedDuplexConnection, interceptors, false); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker(SERVER_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } - Leases leases = leaseEnabled ? leasesSupplier.get() : null; - RequesterLeaseHandler requesterLeaseHandler = - leaseEnabled - ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) - : RequesterLeaseHandler.None; + final Sinks.Empty requesterOnAllClosedSink = Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = Sinks.unsafe().empty(); RSocket rSocketRequester = new RSocketRequester( - wrappedMultiplexer.asServerConnection(), + multiplexer.asServerConnection(), payloadDecoder, StreamIdSupplier.serverSupplier(), mtu, @@ -420,7 +452,11 @@ private Mono acceptSetup( setupPayload.keepAliveInterval(), setupPayload.keepAliveMaxLifetime(), keepAliveHandler, - requesterLeaseHandler); + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono())); RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); @@ -429,41 +465,50 @@ private Mono acceptSetup( .accept(setupPayload, wrappedRSocketRequester) .onErrorResume( err -> - serverSetup - .sendError(multiplexer, rejectedSetupError(err)) + Mono.fromRunnable( + () -> + serverSetup.sendError( + wrappedDuplexConnection, rejectedSetupError(err))) + .then(wrappedDuplexConnection.onClose()) .then(Mono.error(err))) .doOnNext( rSocketHandler -> { RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); - DuplexConnection connection = wrappedMultiplexer.asClientConnection(); + DuplexConnection clientConnection = multiplexer.asClientConnection(); - ResponderLeaseHandler responderLeaseHandler = + ResponderLeaseTracker responderLeaseTracker = leaseEnabled - ? new ResponderLeaseHandler.Impl<>( - SERVER_TAG, connection.alloc(), leases.sender(), leases.stats()) - : ResponderLeaseHandler.None; + ? new ResponderLeaseTracker(SERVER_TAG, clientConnection, leases.sender) + : null; RSocket rSocketResponder = new RSocketResponder( - connection, + clientConnection, wrappedRSocketHandler, payloadDecoder, - responderLeaseHandler, + responderLeaseTracker, mtu, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + leaseEnabled && leases.sender instanceof TrackingLeaseSender + ? rSocket -> + interceptors.initResponderRequestInterceptor( + rSocket, (RequestInterceptor) leases.sender) + : interceptors::initResponderRequestInterceptor, + responderOnAllClosedSink); }) .doFinally(signalType -> setupPayload.release()) .then(); }); } - private ServerSetup serverSetup() { - return resume != null ? createSetup() : new ServerSetup.DefaultServerSetup(); + private ServerSetup serverSetup(Duration timeout) { + return resume != null ? createSetup(timeout) : new ServerSetup.DefaultServerSetup(timeout); } - ServerSetup createSetup() { + ServerSetup createSetup(Duration timeout) { return new ServerSetup.ResumableServerSetup( + timeout, new SessionManager(), resume.getSessionDuration(), resume.getStreamTimeout(), @@ -471,7 +516,7 @@ ServerSetup createSetup() { resume.isCleanupStoreOnKeepAlive()); } - private Exception rejectedSetupError(Throwable err) { + private RSocketErrorException rejectedSetupError(Throwable err) { String msg = err.getMessage(); return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java index 8355022c9..aab491793 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -26,6 +26,7 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.ErrorFrameCodec; @@ -33,46 +34,60 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; -import reactor.core.*; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; import reactor.util.annotation.NonNull; import reactor.util.annotation.Nullable; import reactor.util.context.Context; +import reactor.util.context.ContextView; final class RequestChannelRequesterFlux extends Flux - implements RequesterFrameHandler, CoreSubscriber, Subscription, Scannable { + implements RequesterFrameHandler, + LeasePermitHandler, + CoreSubscriber, + Subscription, + Scannable { final ByteBufAllocator allocator; final int mtu; final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; final PayloadDecoder payloadDecoder; final Publisher payloadsPublisher; + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); int streamId; - Context cachedContext; + boolean isFirstSignal = true; + Payload firstPayload; - boolean isFirstPayload = true; + Subscription outboundSubscription; + boolean outboundDone; + Throwable outboundError; + Context cachedContext; CoreSubscriber inboundSubscriber; - Subscription outboundSubscription; boolean inboundDone; - boolean outboundDone; + long requested; + long produced; CompositeByteBuf frames; @@ -84,8 +99,10 @@ final class RequestChannelRequesterFlux extends Flux this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -94,8 +111,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestChannelFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); return; } @@ -117,7 +140,9 @@ public final void request(long n) { return; } - long previousState = addRequestN(STATE, this, n); + this.requested = Operators.addCap(this.requested, n); + + long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null); if (isTerminated(previousState)) { return; } @@ -125,8 +150,9 @@ public final void request(long n) { if (hasRequested(previousState)) { if (isFirstFrameSent(previousState) && !isMaxAllowedRequestN(extractRequestN(previousState))) { - final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); - this.sendProcessor.onNext(requestNFrame); + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); } return; } @@ -142,39 +168,97 @@ public void onNext(Payload p) { return; } - if (this.isFirstPayload) { - this.isFirstPayload = false; + if (this.isFirstSignal) { + this.isFirstSignal = false; - long state = this.state; - if (isTerminated(state)) { - p.release(); - return; + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + + if (leaseEnabled) { + this.firstPayload = p; + + final long previousState = markFirstPayloadReceived(STATE, this); + if (isTerminated(previousState)) { + this.firstPayload = null; + p.release(); + return; + } + + requesterLeaseTracker.issue(this); + } else { + final long state = this.state; + if (isTerminated(state)) { + p.release(); + return; + } + // TODO: check if source is Scalar | Callable | Mono + sendFirstPayload(p, extractRequestN(state), false); } - sendFirstPayload(p, extractRequestN(state)); } else { sendFollowingPayload(p); } } - void sendFirstPayload(Payload firstPayload, long initialRequestN) { + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + sendFirstPayload( + firstPayload, extractRequestN(previousState), isOutboundTerminated(previousState)); + return true; + } + + void sendFirstPayload(Payload firstPayload, long initialRequestN, boolean completed) { int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { - lazyTerminate(STATE, this); + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } firstPayload.release(); - this.outboundSubscription.cancel(); this.inboundDone = true; - this.inboundSubscriber.onError( - new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + this.inboundSubscriber.onError(e); return; } } catch (IllegalReferenceCountException e) { - lazyTerminate(STATE, this); + final long previousState = markTerminated(STATE, this); - this.outboundSubscription.cancel(); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } this.inboundDone = true; this.inboundSubscriber.onError(e); @@ -182,7 +266,7 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { } final RequesterResponderSupport sm = this.requesterResponderSupport; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; final int streamId; @@ -190,18 +274,36 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { streamId = sm.addAndGetNextStreamId(this); this.streamId = streamId; } catch (Throwable t) { - this.inboundDone = true; final long previousState = markTerminated(STATE, this); firstPayload.release(); - this.outboundSubscription.cancel(); - if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); } + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(ut); + return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + try { sendReleasingPayload( streamId, @@ -209,38 +311,82 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { initialRequestN, mtu, firstPayload, - sender, + connection, allocator, - // TODO: Should be a different flag in case of the scalar - // source or if we know in advance upstream is mono - false); - } catch (Throwable e) { - lazyTerminate(STATE, this); + completed); + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } sm.remove(streamId, this); - this.outboundSubscription.cancel(); + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } this.inboundDone = true; - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); return; } long previousState = markFirstFrameSent(STATE, this); if (isTerminated(previousState)) { + // now, this can be terminated in case of the following scenarios: + // + // 1) SendFirst is called synchronously from onNext, thus we can have + // handleError called before we marked first frame sent, thus we may check if + // inboundDone flag is true and exit execution without any further actions: if (this.inboundDone) { return; } sm.remove(streamId, this); - ReassemblyUtils.synchronizedRelease(this, previousState); + // 2) SendFirst is called asynchronously on the connection event-loop. Thus, we + // need to check if outbound error is present. Note, we check outboundError since + // in the last scenario, cancellation may terminate the state and async + // onComplete may set outboundDone to true. Thus, we explicitly check for + // outboundError + final Throwable outboundError = this.outboundError; + if (outboundError != null) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, outboundError); + connection.sendFrame(streamId, errorFrame); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, outboundError); + } - final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); - sender.onNext(cancelFrame); + this.inboundDone = true; + this.inboundSubscriber.onError(outboundError); + } else { + // 3) SendFirst is interleaving with cancel. Thus, we need to generate cancel + // frame + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_CHANNEL); + } + } return; } + if (!completed && isOutboundTerminated(previousState)) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + connection.sendFrame(streamId, completeFrame); + } + if (isMaxAllowedRequestN(initialRequestN)) { return; } @@ -248,14 +394,14 @@ void sendFirstPayload(Payload firstPayload, long initialRequestN) { long requestN = extractRequestN(previousState); if (isMaxAllowedRequestN(requestN)) { final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); - sender.onNext(requestNFrame); + connection.sendFrame(streamId, requestNFrame); return; } if (requestN > initialRequestN) { final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); - sender.onNext(requestNFrame); + connection.sendFrame(streamId, requestNFrame); } } @@ -267,16 +413,22 @@ final void sendFollowingPayload(Payload followingPayload) { if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { followingPayload.release(); - this.cancel(); - final IllegalArgumentException e = new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + this.propagateErrorSafely(e); return; } } catch (IllegalReferenceCountException e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); @@ -292,54 +444,83 @@ final void sendFollowingPayload(Payload followingPayload) { FrameType.NEXT, mtu, followingPayload, - this.sendProcessor, + this.connection, allocator, true); } catch (Throwable e) { - this.cancel(); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } this.propagateErrorSafely(e); } } - void propagateErrorSafely(Throwable e) { + void propagateErrorSafely(Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber if (!this.inboundDone) { synchronized (this) { if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + this.inboundDone = true; - this.inboundSubscriber.onError(e); + this.inboundSubscriber.onError(t); } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } } else { - Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } @Override public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + boolean tryCancel() { long previousState = markTerminated(STATE, this); if (isTerminated(previousState)) { - return; + return false; } - this.outboundSubscription.cancel(); + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } - if (!isFirstFrameSent(previousState)) { + if (!isReadyToSendFirstFrame(previousState) && isFirstPayloadReceived(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); // no need to send anything, since we have not started a stream yet (no logical wire) - return; + return false; } - final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); - ReassemblyUtils.synchronizedRelease(this, previousState); - final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); - this.sendProcessor.onNext(cancelFrame); + final boolean firstFrameSent = isFirstFrameSent(previousState); + if (firstFrameSent) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + } + + return firstFrameSent; } @Override @@ -349,6 +530,7 @@ public void onError(Throwable t) { return; } + this.outboundError = t; this.outboundDone = true; long previousState = markTerminated(STATE, this); @@ -357,29 +539,49 @@ public void onError(Throwable t) { return; } - if (!isFirstFrameSent(previousState)) { - // first signal, thus, just propagates error to actual subscriber + if (this.isFirstSignal) { + this.inboundDone = true; this.inboundSubscriber.onError(t); + return; + } else if (!isReadyToSendFirstFrame(previousState)) { + // first signal is received but we are still waiting for lease permit to be issued, + // thus, just propagates error to actual subscriber + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; } ReassemblyUtils.synchronizedRelease(this, previousState); - final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); - // propagates error to remote responder - final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); - this.sendProcessor.onNext(errorFrame); + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + // propagates error to remote responder + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + if (!isInboundTerminated(previousState)) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } - if (!isInboundTerminated(previousState)) { - // FIXME: must be scheduled on the connection event-loop to achieve serial - // behaviour on the inbound subscriber - synchronized (this) { - this.inboundDone = true; - this.inboundSubscriber.onError(t); + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } - } else { - Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); } } @@ -397,19 +599,26 @@ public void onComplete() { } if (!isFirstFrameSent(previousState)) { - // first signal, thus, just propagates error to actual subscriber - this.inboundSubscriber.onError(new CancellationException("Empty Source")); + if (!isFirstPayloadReceived(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(new CancellationException("Empty Source")); + } return; } final int streamId = this.streamId; + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + + this.connection.sendFrame(streamId, completeFrame); if (isInboundTerminated(previousState)) { this.requesterResponderSupport.remove(streamId, this); - } - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); - this.sendProcessor.onNext(completeFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } } @Override @@ -427,11 +636,40 @@ public final void handleComplete() { if (isOutboundTerminated(previousState)) { this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } } this.inboundSubscriber.onComplete(); } + @Override + public final void handlePermitError(Throwable cause) { + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Payload p = this.firstPayload; + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onReject(cause, FrameType.REQUEST_CHANNEL, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + @Override public final void handleError(Throwable cause) { if (this.inboundDone) { @@ -447,13 +685,20 @@ public final void handleError(Throwable cause) { return; } + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + ReassemblyUtils.release(this, previousState); final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); - this.outboundSubscription.cancel(); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + this.inboundSubscriber.onError(cause); } @@ -465,6 +710,27 @@ public final void handlePayload(Payload value) { return; } + final long produced = this.produced; + if (this.requested == produced) { + value.release(); + if (!tryCancel()) { + return; + } + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + this.inboundSubscriber.onNext(value); } } @@ -485,11 +751,19 @@ public void handleCancel() { return; } - if (isInboundTerminated(previousState)) { + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } } @Override @@ -513,9 +787,13 @@ public Context currentContext() { long state = this.state; if (isSubscribedOrTerminated(state)) { - Context contextWithDiscard = this.inboundSubscriber.currentContext().putAll(DISCARD_CONTEXT); - cachedContext = contextWithDiscard; - return contextWithDiscard; + Context cachedContext = this.cachedContext; + if (cachedContext == null) { + cachedContext = + this.inboundSubscriber.currentContext().putAll((ContextView) DISCARD_CONTEXT); + this.cachedContext = cachedContext; + } + return cachedContext; } return Context.empty(); diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java index 0c2258950..32128fee4 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -26,6 +26,7 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.exceptions.CanceledException; @@ -35,7 +36,7 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; @@ -46,6 +47,7 @@ import reactor.core.Exceptions; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestChannelResponderSubscriber extends Flux @@ -60,9 +62,11 @@ final class RequestChannelResponderSubscriber extends Flux final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; final long firstRequest; + @Nullable final RequestInterceptor requestInterceptor; + final RSocket handler; volatile long state; @@ -84,6 +88,8 @@ final class RequestChannelResponderSubscriber extends Flux boolean inboundDone; boolean outboundDone; + long requested; + long produced; public RequestChannelResponderSubscriber( int streamId, @@ -97,8 +103,9 @@ public RequestChannelResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.firstRequest = firstRequestN; @@ -119,8 +126,9 @@ public RequestChannelResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.firstRequest = firstRequestN; this.firstPayload = firstPayload; @@ -173,6 +181,8 @@ public void request(long n) { return; } + this.requested = Operators.addCap(this.requested, n); + long previousState = StateUtils.addRequestN(STATE, this, n); if (isTerminated(previousState)) { // full termination can be the result of both sides completion / cancelFrame / remote or local @@ -190,6 +200,9 @@ public void request(long n) { Payload firstPayload = this.firstPayload; if (firstPayload != null) { this.firstPayload = null; + + this.produced++; + inboundSubscriber.onNext(firstPayload); } @@ -210,6 +223,8 @@ public void request(long n) { final Payload firstPayload = this.firstPayload; this.firstPayload = null; + this.produced++; + inboundSubscriber.onNext(firstPayload); inboundSubscriber.onComplete(); @@ -221,8 +236,9 @@ public void request(long n) { if (hasRequested(previousState)) { if (isFirstFrameSent(previousState) && !isMaxAllowedRequestN(StateUtils.extractRequestN(previousState))) { - final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); - this.sendProcessor.onNext(requestNFrame); + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); } return; } @@ -231,6 +247,9 @@ public void request(long n) { final Payload firstPayload = this.firstPayload; this.firstPayload = null; + + this.produced++; + inboundSubscriber.onNext(firstPayload); previousState = markFirstFrameSent(STATE, this); @@ -262,14 +281,16 @@ public void request(long n) { long requestN = StateUtils.extractRequestN(previousState); if (isMaxAllowedRequestN(requestN)) { + final int streamId = this.streamId; final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); - this.sendProcessor.onNext(requestNFrame); + this.connection.sendFrame(streamId, requestNFrame); } else { long firstRequestN = requestN - 1; if (firstRequestN > 0) { + final int streamId = this.streamId; final ByteBuf requestNFrame = - RequestNFrameCodec.encode(this.allocator, this.streamId, firstRequestN); - this.sendProcessor.onNext(requestNFrame); + RequestNFrameCodec.encode(this.allocator, streamId, firstRequestN); + this.connection.sendFrame(streamId, requestNFrame); } } } @@ -279,6 +300,7 @@ public void request(long n) { public void cancel() { long previousState = markInboundTerminated(STATE, this); if (isTerminated(previousState) || isInboundTerminated(previousState)) { + INBOUND_ERROR.lazySet(this, TERMINATED); return; } @@ -290,12 +312,20 @@ public void cancel() { final int streamId = this.streamId; - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); - this.sendProcessor.onNext(cancelFrame); + this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } } @Override @@ -317,10 +347,23 @@ public final void handleCancel() { this.firstPayload = null; firstPayload.release(); } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { return; } - this.tryTerminate(true); + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } } final long tryTerminate(boolean isFromInbound) { @@ -385,6 +428,58 @@ final void handlePayload(Payload p) { return; } + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + this.inboundDone = true; + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.connection.sendFrame( + streamId, + ErrorFrameCodec.encode( + this.allocator, streamId, new CanceledException(cause.getMessage()))); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause); + } + return; + } + + this.produced = produced + 1; + this.inboundSubscriber.onNext(p); } } @@ -431,6 +526,11 @@ public final void handleError(Throwable t) { // reached it // needs for disconnected upstream and downstream case this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } } @Override @@ -443,13 +543,21 @@ public void handleComplete() { long previousState = markInboundTerminated(STATE, this); - if (isOutboundTerminated(previousState)) { + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { this.requesterResponderSupport.remove(this.streamId, this); } if (isFirstFrameSent(previousState)) { this.inboundSubscriber.onComplete(); } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } } @Override @@ -465,18 +573,30 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) payload = this.payloadDecoder.apply(frame); } catch (Throwable t) { long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } this.outboundDone = true; // send error to terminate interaction + final int streamId = this.streamId; final ByteBuf errorFrame = - ErrorFrameCodec.encode( - this.allocator, this.streamId, new CanceledException(t.getMessage())); - this.sendProcessor.onNext(errorFrame); + ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } return; } @@ -511,19 +631,33 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } long previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, e); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } this.outboundDone = true; // send error to terminate interaction + final int streamId = this.streamId; final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } return; } @@ -545,18 +679,32 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) ReferenceCountUtil.safeRelease(frames); previousState = this.tryTerminate(true); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); return; } // send error to terminate interaction + final int streamId = this.streamId; final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } return; } @@ -583,15 +731,9 @@ public void onNext(Payload p) { } final int streamId = this.streamId; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sender.onNext(completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { @@ -600,21 +742,36 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { Operators.onErrorDropped( new IllegalArgumentException( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), this.inboundSubscriber.currentContext()); return; + } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; } - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); - sender.onNext(errorFrame); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } return; } } catch (IllegalReferenceCountException e) { @@ -622,7 +779,15 @@ public void onNext(Payload p) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber long previousState = this.tryTerminate(false); - if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); return; } @@ -632,16 +797,25 @@ public void onNext(Payload p) { allocator, streamId, new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); - sender.onNext(errorFrame); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } return; } try { - sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, connection, allocator, false); } catch (Throwable t) { // FIXME: must be scheduled on the connection event-loop to achieve serial // behaviour on the inbound subscriber - this.tryTerminate(false); + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } } } @@ -677,15 +851,13 @@ public void onError(Throwable t) { } } - if (!isFirstFrameSent(previousState)) { - if (!hasRequested(previousState)) { - final Payload firstPayload = this.firstPayload; - this.firstPayload = null; - firstPayload.release(); - } - } - - if (wasThrowableAdded && !isInboundTerminated(previousState)) { + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); if (inboundError != TERMINATED) { // FIXME: must be scheduled on the connection event-loop to achieve serial @@ -699,7 +871,12 @@ public void onError(Throwable t) { } final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } } @Override @@ -717,12 +894,20 @@ public void onComplete() { final int streamId = this.streamId; - if (isInboundTerminated(previousState)) { + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { this.requesterResponderSupport.remove(streamId, this); } final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); - this.sendProcessor.onNext(completeFrame); + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java index 0ce91725b..a13b105b5 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -25,11 +25,12 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -41,7 +42,7 @@ import reactor.util.annotation.Nullable; final class RequestResponseRequesterMono extends Mono - implements RequesterFrameHandler, Subscription, Scannable { + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { final ByteBufAllocator allocator; final Payload payload; @@ -49,9 +50,12 @@ final class RequestResponseRequesterMono extends Mono final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); @@ -70,8 +74,10 @@ final class RequestResponseRequesterMono extends Mono this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override @@ -79,8 +85,14 @@ public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestResponseMono allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); return; } @@ -88,15 +100,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + Operators.error(actual, e); return; } @@ -111,18 +136,38 @@ public final void request(long n) { return; } - long previousState = addRequestN(STATE, this, n); + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + if (isTerminated(previousState) || hasRequested(previousState)) { return; } - sendFirstPayload(this.payload, n); + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload); } - void sendFirstPayload(Payload payload, long initialRequestN) { + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload); + return true; + } + + void sendFirstPayload(Payload payload) { final RequesterResponderSupport sm = this.requesterResponderSupport; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; final int streamId; @@ -133,23 +178,38 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.actual.onError(Exceptions.unwrap(t)); + this.actual.onError(ut); } return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + try { sendReleasingPayload( - streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, sender, allocator, true); + streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); } catch (Throwable e) { this.done = true; lazyTerminate(STATE, this); sm.remove(streamId, this); + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + this.actual.onError(e); return; } @@ -163,7 +223,11 @@ void sendFirstPayload(Payload payload, long initialRequestN) { sm.remove(streamId, this); final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); - sender.onNext(cancelFrame); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } } } @@ -180,8 +244,13 @@ public final void cancel() { ReassemblyUtils.synchronizedRelease(this, previousState); - this.sendProcessor.onNext(CancelFrameCodec.encode(this.allocator, streamId)); - } else if (!hasRequested(previousState)) { + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } else if (!isReadyToSendFirstFrame(previousState)) { this.payload.release(); } } @@ -201,10 +270,15 @@ public final void handlePayload(Payload value) { return; } - final CoreSubscriber a = this.actual; + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); - this.requesterResponderSupport.remove(this.streamId, this); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + final CoreSubscriber a = this.actual; a.onNext(value); a.onComplete(); } @@ -222,11 +296,37 @@ public final void handleComplete() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } this.actual.onComplete(); } + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + this.actual.onError(cause); + } + @Override public final void handleError(Throwable cause) { if (this.done) { @@ -244,7 +344,13 @@ public final void handleError(Throwable cause) { ReassemblyUtils.synchronizedRelease(this, previousState); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, cause); + } this.actual.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java index 36177e217..3d9d020ff 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -24,6 +24,7 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.exceptions.CanceledException; @@ -31,7 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -54,10 +55,11 @@ final class RequestResponseResponderSubscriber final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; - + final DuplexConnection connection; final RSocket handler; + @Nullable final RequestInterceptor requestInterceptor; + boolean done; CompositeByteBuf frames; @@ -77,9 +79,11 @@ public RequestResponseResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; + this.frames = ReassemblyUtils.addFollowingFrame( allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); @@ -93,7 +97,8 @@ public RequestResponseResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -129,14 +134,19 @@ public void onNext(@Nullable Payload p) { this.done = true; final int streamId = this.streamId; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; this.requesterResponderSupport.remove(streamId, this); if (p == null) { final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sender.onNext(completeFrame); + connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } return; } @@ -147,13 +157,16 @@ public void onNext(@Nullable Payload p) { p.release(); - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); - sender.onNext(errorFrame); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } return; } } catch (IllegalReferenceCountException e) { @@ -164,14 +177,29 @@ public void onNext(@Nullable Payload p) { allocator, streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); - sender.onNext(errorFrame); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } return; } try { - sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, sender, allocator, false); - } catch (Throwable ignored) { + sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + } catch (Throwable t) { currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } } } @@ -196,7 +224,12 @@ public void onError(Throwable t) { this.requesterResponderSupport.remove(streamId, this); final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } } @Override @@ -216,7 +249,8 @@ public void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -224,6 +258,10 @@ public void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } return; } @@ -231,9 +269,15 @@ public void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } } @Override @@ -256,12 +300,18 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) logger.debug("Reassembly has failed", t); // sends error frame from the responder side to tell that something went wrong + final int streamId = this.streamId; final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } return; } @@ -274,7 +324,8 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) } catch (Throwable t) { S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); ReferenceCountUtil.safeRelease(frames); @@ -284,9 +335,14 @@ public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java index cf70109ea..6182ca506 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -25,12 +25,13 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestNFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -42,7 +43,7 @@ import reactor.util.annotation.Nullable; final class RequestStreamRequesterFlux extends Flux - implements RequesterFrameHandler, Subscription, Scannable { + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { final ByteBufAllocator allocator; final Payload payload; @@ -50,9 +51,12 @@ final class RequestStreamRequesterFlux extends Flux final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; final PayloadDecoder payloadDecoder; + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + volatile long state; static final AtomicLongFieldUpdater STATE = AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); @@ -61,6 +65,8 @@ final class RequestStreamRequesterFlux extends Flux CoreSubscriber inboundSubscriber; CompositeByteBuf frames; boolean done; + long requested; + long produced; RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) { this.allocator = requesterResponderSupport.getAllocator(); @@ -69,16 +75,24 @@ final class RequestStreamRequesterFlux extends Flux this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); } @Override public void subscribe(CoreSubscriber actual) { long previousState = markSubscribed(STATE, this); if (isSubscribedOrTerminated(previousState)) { - Operators.error( - actual, new IllegalStateException("RequestStreamFlux allows only a single Subscriber")); + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); return; } @@ -86,15 +100,28 @@ public void subscribe(CoreSubscriber actual) { try { if (!isValid(this.mtu, this.maxFrameLength, p, false)) { lazyTerminate(STATE, this); - Operators.error( - actual, + + final IllegalArgumentException e = new IllegalArgumentException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + Operators.error(actual, e); return; } } catch (IllegalReferenceCountException e) { lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + Operators.error(actual, e); return; } @@ -109,7 +136,11 @@ public final void request(long n) { return; } - long previousState = addRequestN(STATE, this, n); + this.requested = Operators.addCap(this.requested, n); + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); if (isTerminated(previousState)) { return; } @@ -117,19 +148,37 @@ public final void request(long n) { if (hasRequested(previousState)) { if (isFirstFrameSent(previousState) && !isMaxAllowedRequestN(extractRequestN(previousState))) { - final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, this.streamId, n); - this.sendProcessor.onNext(requestNFrame); + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); } return; } + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + sendFirstPayload(this.payload, n); } + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload, extractRequestN(previousState)); + return true; + } + void sendFirstPayload(Payload payload, long initialRequestN) { final RequesterResponderSupport sm = this.requesterResponderSupport; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection connection = this.connection; final ByteBufAllocator allocator = this.allocator; final int streamId; @@ -140,14 +189,25 @@ void sendFirstPayload(Payload payload, long initialRequestN) { this.done = true; final long previousState = markTerminated(STATE, this); + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + payload.release(); if (!isTerminated(previousState)) { - this.inboundSubscriber.onError(Exceptions.unwrap(t)); + this.inboundSubscriber.onError(ut); } return; } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); + } + try { sendReleasingPayload( streamId, @@ -155,16 +215,20 @@ void sendFirstPayload(Payload payload, long initialRequestN) { initialRequestN, this.mtu, payload, - sender, + connection, allocator, false); - } catch (Throwable e) { + } catch (Throwable t) { this.done = true; lazyTerminate(STATE, this); sm.remove(streamId, this); - this.inboundSubscriber.onError(e); + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + this.inboundSubscriber.onError(t); return; } @@ -174,11 +238,14 @@ void sendFirstPayload(Payload payload, long initialRequestN) { return; } - sm.remove(streamId, this); - final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); - sender.onNext(cancelFrame); + connection.sendFrame(streamId, cancelFrame); + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } return; } @@ -189,14 +256,14 @@ void sendFirstPayload(Payload payload, long initialRequestN) { long requestN = extractRequestN(previousState); if (isMaxAllowedRequestN(requestN)) { final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); - sender.onNext(requestNFrame); + connection.sendFrame(streamId, requestNFrame); return; } if (requestN > initialRequestN) { final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); - sender.onNext(requestNFrame); + connection.sendFrame(streamId, requestNFrame); } } @@ -209,12 +276,18 @@ public final void cancel() { if (isFirstFrameSent(previousState)) { final int streamId = this.streamId; - this.requesterResponderSupport.remove(streamId, this); ReassemblyUtils.synchronizedRelease(this, previousState); - this.sendProcessor.onNext(CancelFrameCodec.encode(this.allocator, streamId)); - } else if (!hasRequested(previousState)) { + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } else if (!isReadyToSendFirstFrame(previousState)) { // no need to send anything, since the first request has not happened this.payload.release(); } @@ -227,6 +300,35 @@ public final void handlePayload(Payload p) { return; } + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final IllegalStateException cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + this.inboundSubscriber.onNext(p); } @@ -243,11 +345,37 @@ public final void handleComplete() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } this.inboundSubscriber.onComplete(); } + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + @Override public final void handleError(Throwable cause) { if (this.done) { @@ -263,10 +391,16 @@ public final void handleError(Throwable cause) { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); ReassemblyUtils.synchronizedRelease(this, previousState); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + this.inboundSubscriber.onError(cause); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java index 8486d9b24..48903ae38 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -24,6 +24,7 @@ import io.netty.buffer.CompositeByteBuf; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.exceptions.CanceledException; @@ -31,7 +32,7 @@ import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; @@ -39,6 +40,7 @@ import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; import reactor.util.context.Context; final class RequestStreamResponderSubscriber @@ -54,7 +56,9 @@ final class RequestStreamResponderSubscriber final int maxFrameLength; final int maxInboundPayloadSize; final RequesterResponderSupport requesterResponderSupport; - final UnboundedProcessor sendProcessor; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; final RSocket handler; @@ -79,8 +83,9 @@ public RequestStreamResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.handler = handler; this.frames = ReassemblyUtils.addFollowingFrame( @@ -96,7 +101,8 @@ public RequestStreamResponderSubscriber( this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); this.requesterResponderSupport = requesterResponderSupport; - this.sendProcessor = requesterResponderSupport.getSendProcessor(); + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); this.payloadDecoder = null; this.handler = null; @@ -120,50 +126,84 @@ public void onNext(Payload p) { } final int streamId = this.streamId; - final UnboundedProcessor sender = this.sendProcessor; + final DuplexConnection sender = this.connection; final ByteBufAllocator allocator = this.allocator; - if (p == null) { - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); - sender.onNext(completeFrame); - return; - } - final int mtu = this.mtu; try { if (!isValid(mtu, this.maxFrameLength, p, false)) { p.release(); - this.handleCancel(); + if (!this.tryTerminateOnError()) { + return; + } - this.done = true; - final ByteBuf errorFrame = - ErrorFrameCodec.encode( - allocator, - streamId, - new CanceledException( - String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); - sender.onNext(errorFrame); + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } return; } } catch (IllegalReferenceCountException e) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + final ByteBuf errorFrame = ErrorFrameCodec.encode( allocator, streamId, new CanceledException("Failed to validate payload. Cause" + e.getMessage())); - sender.onNext(errorFrame); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } return; } try { sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); } catch (Throwable t) { - this.handleCancel(); - this.done = true; + if (!this.tryTerminateOnError()) { + return; + } + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + } + + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; } + + currentSubscription.cancel(); + + return true; } @Override @@ -187,10 +227,15 @@ public void onError(Throwable t) { final int streamId = this.streamId; + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + this.requesterResponderSupport.remove(streamId, this); - final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); - this.sendProcessor.onNext(errorFrame); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } } @Override @@ -207,10 +252,15 @@ public void onComplete() { final int streamId = this.streamId; + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + this.requesterResponderSupport.remove(streamId, this); - final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); - this.sendProcessor.onNext(completeFrame); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } } @Override @@ -230,7 +280,8 @@ public final void handleCancel() { // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); final CompositeByteBuf frames = this.frames; if (frames != null) { @@ -238,6 +289,10 @@ public final void handleCancel() { frames.release(); } + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } return; } @@ -245,9 +300,15 @@ public final void handleCancel() { return; } - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } } @Override @@ -260,25 +321,32 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas try { ReassemblyUtils.addFollowingFrame( frames, followingFrame, hasFollows, this.maxInboundPayloadSize); - } catch (IllegalStateException t) { + } catch (IllegalStateException e) { // if subscription is null, it means that streams has not yet reassembled all the fragments // and fragmentation of the first frame was cancelled before S.lazySet(this, Operators.cancelledSubscription()); - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; this.frames = null; frames.release(); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, - new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); - this.sendProcessor.onNext(errorFrame); + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + + logger.debug("Reassembly has failed", e); return; } @@ -292,19 +360,26 @@ public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLas S.lazySet(this, Operators.cancelledSubscription()); this.done = true; - this.requesterResponderSupport.remove(this.streamId, this); + final int streamId = this.streamId; ReferenceCountUtil.safeRelease(frames); - logger.debug("Reassembly has failed", t); - // sends error frame from the responder side to tell that something went wrong final ByteBuf errorFrame = ErrorFrameCodec.encode( this.allocator, - this.streamId, + streamId, new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); - this.sendProcessor.onNext(errorFrame); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + logger.debug("Reassembly has failed", t); return; } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java new file mode 100644 index 000000000..50da83b8f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java @@ -0,0 +1,135 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; + +final class RequesterLeaseTracker implements Availability { + + final String tag; + final int maximumAllowedAwaitingPermitHandlersNumber; + final Queue awaitingPermitHandlersQueue; + + Lease currentLease = null; + int availableRequests; + + boolean isDisposed; + Throwable t; + + RequesterLeaseTracker(String tag, int maximumAllowedAwaitingPermitHandlersNumber) { + this.tag = tag; + this.maximumAllowedAwaitingPermitHandlersNumber = maximumAllowedAwaitingPermitHandlersNumber; + this.awaitingPermitHandlersQueue = new ArrayDeque<>(); + } + + synchronized void issue(LeasePermitHandler leasePermitHandler) { + if (this.isDisposed) { + leasePermitHandler.handlePermitError(this.t); + return; + } + + final int availableRequests = this.availableRequests; + final Lease l = this.currentLease; + final boolean leaseReceived = l != null; + final boolean isExpired = leaseReceived && isExpired(l); + + if (leaseReceived && availableRequests > 0 && !isExpired) { + if (leasePermitHandler.handlePermit()) { + this.availableRequests = availableRequests - 1; + } + } else { + final Queue queue = this.awaitingPermitHandlersQueue; + if (this.maximumAllowedAwaitingPermitHandlersNumber > queue.size()) { + queue.offer(leasePermitHandler); + } else { + final String tag = this.tag; + final String message; + if (!leaseReceived) { + message = String.format("[%s] Lease was not received yet", tag); + } else if (isExpired) { + message = String.format("[%s] Missing leases. Lease is expired", tag); + } else { + message = + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", + tag, availableRequests); + } + + final Throwable t = new MissingLeaseException(message); + leasePermitHandler.handlePermitError(t); + } + } + } + + void handleLeaseFrame(ByteBuf leaseFrame) { + final int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + final int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + final ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); + + synchronized (this) { + final Lease lease = + Lease.create(Duration.ofMillis(timeToLiveMillis), numberOfRequests, metadata); + final Queue queue = this.awaitingPermitHandlersQueue; + + int availableRequests = lease.numberOfRequests(); + + this.currentLease = lease; + if (queue.size() > 0) { + do { + final LeasePermitHandler handler = queue.poll(); + if (handler.handlePermit()) { + availableRequests--; + } + } while (availableRequests > 0 && queue.size() > 0); + } + + this.availableRequests = availableRequests; + } + } + + public synchronized void dispose(Throwable t) { + this.isDisposed = true; + this.t = t; + + final Queue queue = this.awaitingPermitHandlersQueue; + final int size = queue.size(); + + for (int i = 0; i < size; i++) { + final LeasePermitHandler leasePermitHandler = queue.poll(); + + //noinspection ConstantConditions + leasePermitHandler.handlePermitError(t); + } + } + + @Override + public synchronized double availability() { + final Lease lease = this.currentLease; + return lease != null ? this.availableRequests / (double) lease.numberOfRequests() : 0.0d; + } + + static boolean isExpired(Lease currentLease) { + return System.currentTimeMillis() >= currentLease.expirationTime(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java index f5ddb199c..bea7dc1aa 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -1,11 +1,14 @@ package io.rsocket.core; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.function.Function; import reactor.util.annotation.Nullable; class RequesterResponderSupport { @@ -15,28 +18,30 @@ class RequesterResponderSupport { private final int maxInboundPayloadSize; private final PayloadDecoder payloadDecoder; private final ByteBufAllocator allocator; + private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; @Nullable final StreamIdSupplier streamIdSupplier; final IntObjectMap activeStreams; - private final UnboundedProcessor sendProcessor; - public RequesterResponderSupport( int mtu, int maxFrameLength, int maxInboundPayloadSize, PayloadDecoder payloadDecoder, - ByteBufAllocator allocator, - @Nullable StreamIdSupplier streamIdSupplier) { + DuplexConnection connection, + @Nullable StreamIdSupplier streamIdSupplier, + Function requestInterceptorFunction) { this.activeStreams = new IntObjectHashMap<>(); this.mtu = mtu; this.maxFrameLength = maxFrameLength; this.maxInboundPayloadSize = maxInboundPayloadSize; this.payloadDecoder = payloadDecoder; - this.allocator = allocator; + this.allocator = connection.alloc(); this.streamIdSupplier = streamIdSupplier; - this.sendProcessor = new UnboundedProcessor<>(); + this.connection = connection; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); } public int getMtu() { @@ -59,8 +64,18 @@ public ByteBufAllocator getAllocator() { return allocator; } - public UnboundedProcessor getSendProcessor() { - return sendProcessor; + public DuplexConnection getDuplexConnection() { + return connection; + } + + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return null; + } + + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; } /** @@ -103,6 +118,17 @@ public int addAndGetNextStreamId(FrameHandler frameHandler) { } } + public synchronized boolean add(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.putIfAbsent(key, value) without `streamId` boxing + final FrameHandler previousHandler = activeStreams.get(streamId); + if (previousHandler == null) { + activeStreams.put(streamId, frameHandler); + return true; + } + return false; + } + /** * Resolves {@link FrameHandler} by {@code streamId} * @@ -123,6 +149,13 @@ public synchronized FrameHandler get(int streamId) { * instance equals to the passed one */ public synchronized boolean remove(int streamId, FrameHandler frameHandler) { - return this.activeStreams.remove(streamId, frameHandler); + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.remove(key, value) without `streamId` boxing + final FrameHandler curValue = activeStreams.get(streamId); + if (!Objects.equals(curValue, frameHandler)) { + return false; + } + activeStreams.remove(streamId); + return true; } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java index c431b3f3f..50bef5b70 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import java.time.Duration; @@ -15,6 +30,8 @@ import reactor.util.annotation.Nullable; import reactor.util.context.Context; +// A copy of this class exists in io.rsocket.loadbalance + class ResolvingOperator implements Disposable { static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); @@ -153,19 +170,19 @@ public T block(@Nullable Duration timeout) { delay = System.nanoTime() + timeout.toNanos(); } for (; ; ) { - BiConsumer[] inners = this.subscribers; + subscribers = this.subscribers; - if (inners == READY) { + if (subscribers == READY) { final T value = this.value; if (value != null) { return value; } else { // value == null means racing between invalidate and this block // thus, we have to update the state again and see what happened - inners = this.subscribers; + subscribers = this.subscribers; } } - if (inners == TERMINATED) { + if (subscribers == TERMINATED) { RuntimeException re = Exceptions.propagate(this.t); re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); throw re; @@ -174,6 +191,12 @@ public T block(@Nullable Duration timeout) { throw new IllegalStateException("Timeout on Mono blocking read"); } + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + Thread.sleep(1); } } catch (InterruptedException ie) { @@ -186,6 +209,7 @@ public T block(@Nullable Duration timeout) { @SuppressWarnings("unchecked") final void terminate(Throwable t) { if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); return; } @@ -307,6 +331,30 @@ protected void doOnDispose() { // no ops } + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + final int add(BiConsumer ps) { for (; ; ) { BiConsumer[] a = this.subscribers; diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java new file mode 100644 index 000000000..fc7442f4a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Availability; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.lease.MissingLeaseException; +import reactor.core.Disposable; +import reactor.core.publisher.BaseSubscriber; +import reactor.util.annotation.Nullable; + +final class ResponderLeaseTracker extends BaseSubscriber + implements Disposable, Availability { + + final String tag; + final ByteBufAllocator allocator; + final DuplexConnection connection; + + @Nullable volatile MutableLease currentLease; + + ResponderLeaseTracker(String tag, DuplexConnection connection, LeaseSender leaseSender) { + this.tag = tag; + this.connection = connection; + this.allocator = connection.alloc(); + + leaseSender.send().subscribe(this); + } + + @Nullable + Throwable use() { + final MutableLease lease = this.currentLease; + final String tag = this.tag; + + if (lease == null) { + return new MissingLeaseException(String.format("[%s] Lease was not issued yet", tag)); + } + + if (isExpired(lease)) { + return new MissingLeaseException(String.format("[%s] Missing leases. Lease is expired", tag)); + } + + final int allowedRequests = lease.allowedRequests; + final int remainingRequests = lease.remainingRequests; + if (remainingRequests <= 0) { + return new MissingLeaseException( + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", tag, allowedRequests)); + } + + lease.remainingRequests = remainingRequests - 1; + + return null; + } + + @Override + protected void hookOnNext(Lease lease) { + final int allowedRequests = lease.numberOfRequests(); + final int ttl = lease.timeToLiveInMillis(); + final long expireAt = lease.expirationTime(); + + this.currentLease = new MutableLease(allowedRequests, expireAt); + this.connection.sendFrame( + 0, LeaseFrameCodec.encode(this.allocator, ttl, allowedRequests, lease.metadata())); + } + + @Override + public double availability() { + final MutableLease lease = this.currentLease; + + if (lease == null || isExpired(lease)) { + return 0; + } + + return lease.remainingRequests / (double) lease.allowedRequests; + } + + static boolean isExpired(MutableLease currentLease) { + return System.currentTimeMillis() >= currentLease.expireAt; + } + + static final class MutableLease { + final int allowedRequests; + final long expireAt; + + int remainingRequests; + + MutableLease(int allowedRequests, long expireAt) { + this.allowedRequests = allowedRequests; + this.expireAt = expireAt; + + this.remainingRequests = allowedRequests; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java index 48133af98..fa0eedbfa 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/Resume.java +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -160,7 +160,7 @@ boolean isCleanupStoreOnKeepAlive() { Function getStoreFactory(String tag) { return storeFactory != null ? storeFactory - : token -> new InMemoryResumableFramesStore(tag, 100_000); + : token -> new InMemoryResumableFramesStore(tag, token, 100_000); } Duration getStreamTimeout() { diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java index 3e8cf34f5..568dada2e 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -22,6 +22,7 @@ import io.netty.buffer.Unpooled; import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCounted; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.exceptions.CanceledException; import io.rsocket.frame.CancelFrameCodec; @@ -32,7 +33,6 @@ import io.rsocket.frame.RequestFireAndForgetFrameCodec; import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.frame.RequestStreamFrameCodec; -import io.rsocket.internal.UnboundedProcessor; import java.util.function.Consumer; import reactor.core.publisher.Operators; import reactor.util.context.Context; @@ -40,11 +40,13 @@ final class SendUtils { private static final Consumer DROPPED_ELEMENTS_CONSUMER = data -> { - try { - ReferenceCounted referenceCounted = (ReferenceCounted) data; - referenceCounted.release(); - } catch (Throwable e) { - // ignored + if (data instanceof ReferenceCounted) { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } } }; @@ -55,7 +57,7 @@ static void sendReleasingPayload( FrameType frameType, int mtu, Payload payload, - UnboundedProcessor sendProcessor, + DuplexConnection connection, ByteBufAllocator allocator, boolean requester) { @@ -67,7 +69,7 @@ static void sendReleasingPayload( try { fragmentable = isFragmentable(mtu, data, metadata, false); } catch (IllegalReferenceCountException | NullPointerException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); throw e; } @@ -81,11 +83,11 @@ static void sendReleasingPayload( FragmentationUtils.encodeFirstFragment( allocator, mtu, frameType, streamId, hasMetadata, slicedMetadata, slicedData); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); throw e; } - sendProcessor.onNext(first); + connection.sendFrame(streamId, first); boolean complete = frameType == FrameType.NEXT_COMPLETE; while (slicedData.isReadable() || slicedMetadata.isReadable()) { @@ -95,16 +97,16 @@ static void sendReleasingPayload( FragmentationUtils.encodeFollowsFragment( allocator, mtu, streamId, complete, slicedMetadata, slicedData); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, true, e); + sendTerminalFrame(streamId, frameType, connection, allocator, requester, true, e); throw e; } - sendProcessor.onNext(following); + connection.sendFrame(streamId, following); } try { payload.release(); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); throw e; } } else { @@ -116,7 +118,7 @@ static void sendReleasingPayload( } catch (IllegalReferenceCountException e) { dataRetainedSlice.release(); - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); throw e; } @@ -128,7 +130,7 @@ static void sendReleasingPayload( metadataRetainedSlice.release(); } - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, requester, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); throw e; } @@ -161,7 +163,7 @@ static void sendReleasingPayload( throw new IllegalArgumentException("Unsupported frame type " + frameType); } - sendProcessor.onNext(requestFrame); + connection.sendFrame(streamId, requestFrame); } } @@ -171,7 +173,7 @@ static void sendReleasingPayload( long initialRequestN, int mtu, Payload payload, - UnboundedProcessor sendProcessor, + DuplexConnection connection, ByteBufAllocator allocator, boolean complete) { @@ -183,7 +185,7 @@ static void sendReleasingPayload( try { fragmentable = isFragmentable(mtu, data, metadata, true); } catch (IllegalReferenceCountException | NullPointerException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); throw e; } @@ -204,11 +206,11 @@ static void sendReleasingPayload( slicedMetadata, slicedData); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); throw e; } - sendProcessor.onNext(first); + connection.sendFrame(streamId, first); while (slicedData.isReadable() || slicedMetadata.isReadable()) { final ByteBuf following; @@ -217,16 +219,16 @@ static void sendReleasingPayload( FragmentationUtils.encodeFollowsFragment( allocator, mtu, streamId, complete, slicedMetadata, slicedData); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); throw e; } - sendProcessor.onNext(following); + connection.sendFrame(streamId, following); } try { payload.release(); } catch (IllegalReferenceCountException e) { - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, true, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); throw e; } } else { @@ -238,7 +240,7 @@ static void sendReleasingPayload( } catch (IllegalReferenceCountException e) { dataRetainedSlice.release(); - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); throw e; } @@ -250,7 +252,7 @@ static void sendReleasingPayload( metadataRetainedSlice.release(); } - sendTerminalFrame(streamId, frameType, sendProcessor, allocator, true, false, e); + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); throw e; } @@ -281,14 +283,14 @@ static void sendReleasingPayload( throw new IllegalArgumentException("Unsupported frame type " + frameType); } - sendProcessor.onNext(requestFrame); + connection.sendFrame(streamId, requestFrame); } } static void sendTerminalFrame( int streamId, FrameType frameType, - UnboundedProcessor sendProcessor, + DuplexConnection connection, ByteBufAllocator allocator, boolean requester, boolean onFollowingFrame, @@ -297,7 +299,7 @@ static void sendTerminalFrame( if (onFollowingFrame) { if (requester) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); - sendProcessor.onNext(cancelFrame); + connection.sendFrame(streamId, cancelFrame); } else { final ByteBuf errorFrame = ErrorFrameCodec.encode( @@ -308,7 +310,7 @@ static void sendTerminalFrame( + frameType + " frame. Cause: " + t.getMessage())); - sendProcessor.onNext(errorFrame); + connection.sendFrame(streamId, errorFrame); } } else { switch (frameType) { @@ -317,7 +319,7 @@ static void sendTerminalFrame( case PAYLOAD: if (requester) { final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); - sendProcessor.onNext(cancelFrame); + connection.sendFrame(streamId, cancelFrame); } else { final ByteBuf errorFrame = ErrorFrameCodec.encode( @@ -325,7 +327,7 @@ static void sendTerminalFrame( streamId, new CanceledException( "Failed to encode " + frameType + " frame. Cause: " + t.getMessage())); - sendProcessor.onNext(errorFrame); + connection.sendFrame(streamId, errorFrame); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java index eb86c6734..5aae22e89 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -20,65 +20,73 @@ import io.netty.buffer.ByteBuf; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; -import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.resume.*; +import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.function.BiFunction; import java.util.function.Function; import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; abstract class ServerSetup { + final Duration timeout; + + protected ServerSetup(Duration timeout) { + this.timeout = timeout; + } + + Mono> init(DuplexConnection connection) { + return Mono.>create( + sink -> sink.onRequest(__ -> new SetupHandlingDuplexConnection(connection, sink))) + .timeout(this.timeout) + .or(connection.onClose().then(Mono.error(ClosedChannelException::new))); + } + abstract Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then); + DuplexConnection clientServerConnection, + BiFunction> then); - abstract Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); + abstract Mono acceptRSocketResume(ByteBuf frame, DuplexConnection connection); void dispose() {} - Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - DuplexConnection duplexConnection = multiplexer.asSetupConnection(); - return duplexConnection - .sendOne(ErrorFrameCodec.encode(duplexConnection.alloc(), 0, exception)) - .onErrorResume(err -> Mono.empty()); + void sendError(DuplexConnection duplexConnection, RSocketErrorException exception) { + duplexConnection.sendErrorAndClose(exception); + duplexConnection.receive().subscribe(); } static class DefaultServerSetup extends ServerSetup { + DefaultServerSetup(Duration timeout) { + super(timeout); + } + @Override public Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then) { + DuplexConnection duplexConnection, + BiFunction> then) { if (SetupFrameCodec.resumeEnabled(frame)) { - return sendError(multiplexer, new UnsupportedSetupException("resume not supported")) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); + sendError(duplexConnection, new UnsupportedSetupException("resume not supported")); + return duplexConnection.onClose(); } else { - return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); } } @Override - public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { - - return sendError(multiplexer, new RejectedResumeException("resume not supported")) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + sendError(duplexConnection, new RejectedResumeException("resume not supported")); + return duplexConnection.onClose(); } } @@ -90,11 +98,13 @@ static class ResumableServerSetup extends ServerSetup { private final boolean cleanupStoreOnKeepAlive; ResumableServerSetup( + Duration timeout, SessionManager sessionManager, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, boolean cleanupStoreOnKeepAlive) { + super(timeout); this.sessionManager = sessionManager; this.resumeSessionDuration = resumeSessionDuration; this.resumeStreamTimeout = resumeStreamTimeout; @@ -105,47 +115,45 @@ static class ResumableServerSetup extends ServerSetup { @Override public Mono acceptRSocketSetup( ByteBuf frame, - ClientServerInputMultiplexer multiplexer, - BiFunction> then) { + DuplexConnection duplexConnection, + BiFunction> then) { if (SetupFrameCodec.resumeEnabled(frame)) { ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); - ResumableDuplexConnection connection = - sessionManager - .save( - new ServerRSocketSession( - multiplexer.asClientServerConnection(), - resumeSessionDuration, - resumeStreamTimeout, - resumeStoreFactory, - resumeToken, - cleanupStoreOnKeepAlive)) - .resumableConnection(); + final ResumableFramesStore resumableFramesStore = resumeStoreFactory.apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "server", resumeToken, duplexConnection, resumableFramesStore); + final ServerRSocketSession serverRSocketSession = + new ServerRSocketSession( + resumeToken, + resumableDuplexConnection, + duplexConnection, + resumableFramesStore, + resumeSessionDuration, + cleanupStoreOnKeepAlive); + + sessionManager.save(serverRSocketSession, resumeToken); + return then.apply( - new ResumableKeepAliveHandler(connection), - new ClientServerInputMultiplexer(connection)); + new ResumableKeepAliveHandler( + resumableDuplexConnection, serverRSocketSession, serverRSocketSession), + resumableDuplexConnection); } else { - return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); } } @Override - public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); if (session != null) { - return session - .continueWith(multiplexer.asClientServerConnection()) - .resumeWith(frame) - .onClose() - .then(); + session.resumeWith(frame, duplexConnection); + return duplexConnection.onClose(); } else { - return sendError(multiplexer, new RejectedResumeException("unknown resume token")) - .doFinally( - s -> { - frame.release(); - multiplexer.dispose(); - }); + sendError(duplexConnection, new RejectedResumeException("unknown resume token")); + return duplexConnection.onClose(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java new file mode 100644 index 000000000..3beedf97f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java @@ -0,0 +1,176 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +class SetupHandlingDuplexConnection extends Flux + implements DuplexConnection, CoreSubscriber, Subscription { + + final DuplexConnection source; + final MonoSink> sink; + + Subscription s; + boolean firstFrameReceived = false; + + CoreSubscriber actual; + + boolean done; + Throwable t; + + SetupHandlingDuplexConnection( + DuplexConnection source, MonoSink> sink) { + this.source = source; + this.sink = sink; + + source.receive().subscribe(this); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (done) { + final Throwable t = this.t; + if (t == null) { + Operators.complete(actual); + } else { + Operators.error(actual, t); + } + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public void request(long n) { + if (n != Long.MAX_VALUE) { + actual.onError(new IllegalArgumentException("Only unbounded request is allowed")); + return; + } + + s.request(Long.MAX_VALUE); + } + + @Override + public void cancel() { + source.dispose(); + s.cancel(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(1); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (!firstFrameReceived) { + firstFrameReceived = true; + sink.success(Tuples.of(frame, this)); + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.done = true; + this.t = t; + + if (!firstFrameReceived) { + sink.error(t); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onError(t); + } + } + + @Override + public void onComplete() { + if (done) { + return; + } + + this.done = true; + + if (!firstFrameReceived) { + sink.error(new ClosedChannelException()); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onComplete(); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public String toString() { + return "SetupHandlingDuplexConnection{" + "source=" + source + ", done=" + done + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java new file mode 100644 index 000000000..3035696b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java @@ -0,0 +1,255 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.isReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.isSubscribedOrTerminated; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.lazyTerminate; +import static io.rsocket.core.StateUtils.markReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.markSubscribed; +import static io.rsocket.core.StateUtils.markTerminated; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class SlowFireAndForgetRequesterMono extends Mono + implements LeasePermitHandler, Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(SlowFireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + CoreSubscriber actual; + + SlowFireAndForgetRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + long previousState = markSubscribed(STATE, this, !leaseEnabled); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstFrame(p); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstFrame(this.payload); + return true; + } + + void sendFirstFrame(Payload p) { + final CoreSubscriber actual = this.actual; + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePermitError(Throwable cause) { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + this.actual.onError(cause); + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java index b3857bc12..2b6a0e09a 100644 --- a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java @@ -13,27 +13,36 @@ final class StateUtils { /** Bit Flag that indicates Requester Producer has been subscribed once */ static final long SUBSCRIBED_FLAG = 0b000000000000000000000000000000001_0000000000000000000000000000000L; + /** Bit Flag that indicates that the first payload in RequestChannel scenario is received */ + static final long FIRST_PAYLOAD_RECEIVED_FLAG = + 0b000000000000000000000000000000010_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that the logical stream is ready to send the first initial frame + * (applicable for requester only) + */ + static final long READY_TO_SEND_FIRST_FRAME_FLAG = + 0b000000000000000000000000000000100_0000000000000000000000000000000L; /** * Bit Flag that indicates that sent first initial frame was sent (in case of requester) or * consumed (if responder) */ static final long FIRST_FRAME_SENT_FLAG = - 0b000000000000000000000000000000010_0000000000000000000000000000000L; + 0b000000000000000000000000000001000_0000000000000000000000000000000L; /** Bit Flag that indicates that there is a frame being reassembled */ static final long REASSEMBLING_FLAG = - 0b000000000000000000000000000000100_0000000000000000000000000000000L; + 0b000000000000000000000000000010000_0000000000000000000000000000000L; /** * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates * that the inbound is terminated */ static final long INBOUND_TERMINATED_FLAG = - 0b000000000000000000000000000001000_0000000000000000000000000000000L; + 0b000000000000000000000000000100000_0000000000000000000000000000000L; /** * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates * that the outbound is terminated */ static final long OUTBOUND_TERMINATED_FLAG = - 0b000000000000000000000000000010000_0000000000000000000000000000000L; + 0b000000000000000000000000001000000_0000000000000000000000000000000L; /** Initial state for any request operator */ static final long UNSUBSCRIBED_STATE = 0b000000000000000000000000000000000_0000000000000000000000000000000L; @@ -54,6 +63,24 @@ final class StateUtils { * @return return previous state before setting the new one */ static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { + return markSubscribed(updater, instance, false); + } + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param markPrepared indicates whether the given instance should be marked as prepared + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed( + AtomicLongFieldUpdater updater, T instance, boolean markPrepared) { for (; ; ) { long state = updater.get(instance); @@ -65,7 +92,10 @@ static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { return state; } - if (updater.compareAndSet(instance, state, state | SUBSCRIBED_FLAG)) { + if (updater.compareAndSet( + instance, + state, + state | SUBSCRIBED_FLAG | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0))) { return state; } } @@ -121,6 +151,86 @@ static boolean isFirstFrameSent(long state) { return (state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG; } + /** + * Adds (if possible) to the given state the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReadyToSendFirstFrame(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | READY_TO_SEND_FIRST_FRAME_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is prepared for initial frame sending + * @return true if the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag is set + */ + static boolean isReadyToSendFirstFrame(long state) { + return (state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstPayloadReceived(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_PAYLOAD_RECEIVED_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag is set + */ + static boolean isFirstPayloadReceived(long state) { + return (state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG; + } + /** * Adds (if possible) to the given state the {@link #REASSEMBLING_FLAG} flag which indicates that * there is a payload reassembling in progress. @@ -327,14 +437,12 @@ static boolean isSubscribedOrTerminated(long state) { return state == TERMINATED_STATE || (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; } - /** - * @param updater - * @param instance - * @param toAdd - * @param - * @return - */ static long addRequestN(AtomicLongFieldUpdater updater, T instance, long toAdd) { + return addRequestN(updater, instance, toAdd, false); + } + + static long addRequestN( + AtomicLongFieldUpdater updater, T instance, long toAdd, boolean markPrepared) { long currentState, flags, requestN, nextRequestN; for (; ; ) { currentState = updater.get(instance); @@ -348,7 +456,7 @@ static long addRequestN(AtomicLongFieldUpdater updater, T instance, long return currentState; } - flags = currentState & FLAGS_MASK; + flags = (currentState & FLAGS_MASK) | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0); nextRequestN = addRequestN(requestN, toAdd); if (updater.compareAndSet(instance, currentState, nextRequestN | flags)) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index 28f39459d..fc146c935 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -60,6 +60,10 @@ public static boolean hasFollows(ByteBuf byteBuf) { return (flags(byteBuf) & FLAGS_F) == FLAGS_F; } + public static boolean hasComplete(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_C) == FLAGS_C; + } + public static int streamId(ByteBuf byteBuf) { byteBuf.markReaderIndex(); int streamId = byteBuf.readInt(); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java index 66d18c8a7..d581731a3 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.frame; import io.netty.buffer.ByteBuf; @@ -99,8 +114,9 @@ private static ByteBuf getData(ByteBuf frame, FrameType frameType) { case REQUEST_CHANNEL: data = RequestChannelFrameCodec.data(frame); break; - // Payload and synthetic types + // Payload, KeepAlive and synthetic types case PAYLOAD: + case KEEPALIVE: case NEXT: case NEXT_COMPLETE: case COMPLETE: diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index e6874c097..0d8063e0b 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -52,12 +52,12 @@ public Payload apply(ByteBuf byteBuf) { throw new IllegalArgumentException("unsupported frame type: " + type); } - ByteBuffer data = ByteBuffer.allocateDirect(d.readableBytes()); + ByteBuffer data = ByteBuffer.allocate(d.readableBytes()); data.put(d.nioBuffer()); data.flip(); if (m != null) { - ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); + ByteBuffer metadata = ByteBuffer.allocate(m.readableBytes()); metadata.put(m.nioBuffer()); metadata.flip(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java index 9668e5e18..0296b0a07 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java @@ -1,30 +1,56 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.internal; +import io.netty.buffer.ByteBuf; import io.rsocket.DuplexConnection; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; public abstract class BaseDuplexConnection implements DuplexConnection { - private MonoProcessor onClose = MonoProcessor.create(); + protected final Sinks.Empty onClose = Sinks.empty(); + protected final UnboundedProcessor sender = new UnboundedProcessor(onClose::tryEmitEmpty); - public BaseDuplexConnection() { - onClose.doFinally(s -> doOnClose()).subscribe(); + public BaseDuplexConnection() {} + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + sender.tryEmitPrioritized(frame); + } else { + sender.tryEmitNormal(frame); + } } protected abstract void doOnClose(); @Override - public final Mono onClose() { - return onClose; + public Mono onClose() { + return onClose.asMono(); } @Override public final void dispose() { - onClose.onComplete(); + doOnClose(); } @Override + @SuppressWarnings("ConstantConditions") public final boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 179a7a757..8b1378917 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -1,244 +1 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Closeable; -import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderCodec; -import io.rsocket.frame.FrameUtil; -import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.InitializingInterceptorRegistry; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * {@link DuplexConnection#receive()} is a single stream on which the following type of frames - * arrive: - * - *

    - *
  • Frames for streams initiated by the initiator of the connection (client). - *
  • Frames for streams initiated by the acceptor of the connection (server). - *
- * - *

The only way to differentiate these two frames is determining whether the stream Id is odd or - * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the - * client. - * - * @deprecated since 1.1.0-M1 in favor of package-private {@link - * io.rsocket.core.ClientServerInputMultiplexer} - */ -@Deprecated -public class ClientServerInputMultiplexer implements Closeable { - private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - private static final InitializingInterceptorRegistry emptyInterceptorRegistry = - new InitializingInterceptorRegistry(); - - private final DuplexConnection setupConnection; - private final DuplexConnection serverConnection; - private final DuplexConnection clientConnection; - private final DuplexConnection source; - private final DuplexConnection clientServerConnection; - - private boolean setupReceived; - - public ClientServerInputMultiplexer(DuplexConnection source) { - this(source, emptyInterceptorRegistry, false); - } - - public ClientServerInputMultiplexer( - DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { - this.source = source; - final MonoProcessor> setup = MonoProcessor.create(); - final MonoProcessor> server = MonoProcessor.create(); - final MonoProcessor> client = MonoProcessor.create(); - - source = registry.initConnection(Type.SOURCE, source); - setupConnection = - registry.initConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); - serverConnection = - registry.initConnection(Type.SERVER, new InternalDuplexConnection(source, server)); - clientConnection = - registry.initConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); - clientServerConnection = new InternalDuplexConnection(source, client, server); - - source - .receive() - .groupBy( - frame -> { - int streamId = FrameHeaderCodec.streamId(frame); - final Type type; - if (streamId == 0) { - switch (FrameHeaderCodec.frameType(frame)) { - case SETUP: - case RESUME: - case RESUME_OK: - type = Type.SETUP; - setupReceived = true; - break; - case LEASE: - case KEEPALIVE: - case ERROR: - type = isClient ? Type.CLIENT : Type.SERVER; - break; - default: - type = isClient ? Type.SERVER : Type.CLIENT; - } - } else if ((streamId & 0b1) == 0) { - type = Type.SERVER; - } else { - type = Type.CLIENT; - } - if (!isClient && type != Type.SETUP && !setupReceived) { - frame.release(); - throw new IllegalStateException( - "SETUP or LEASE frame must be received before any others."); - } - return type; - }) - .subscribe( - group -> { - switch (group.key()) { - case SETUP: - setup.onNext(group); - break; - - case SERVER: - server.onNext(group); - break; - - case CLIENT: - client.onNext(group); - break; - } - }, - ex -> { - setup.onError(ex); - server.onError(ex); - client.onError(ex); - }); - } - - public DuplexConnection asClientServerConnection() { - return clientServerConnection; - } - - public DuplexConnection asServerConnection() { - return serverConnection; - } - - public DuplexConnection asClientConnection() { - return clientConnection; - } - - public DuplexConnection asSetupConnection() { - return setupConnection; - } - - @Override - public void dispose() { - source.dispose(); - } - - @Override - public boolean isDisposed() { - return source.isDisposed(); - } - - @Override - public Mono onClose() { - return source.onClose(); - } - - private static class InternalDuplexConnection implements DuplexConnection { - private final DuplexConnection source; - private final MonoProcessor>[] processors; - private final boolean debugEnabled; - - @SafeVarargs - public InternalDuplexConnection( - DuplexConnection source, MonoProcessor>... processors) { - this.source = source; - this.processors = processors; - this.debugEnabled = LOGGER.isDebugEnabled(); - } - - @Override - public Mono send(Publisher frame) { - if (debugEnabled) { - frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))); - } - - return source.send(frame); - } - - @Override - public Mono sendOne(ByteBuf frame) { - if (debugEnabled) { - LOGGER.debug("sending -> " + FrameUtil.toString(frame)); - } - - return source.sendOne(frame); - } - - @Override - public Flux receive() { - return Flux.fromArray(processors) - .flatMap( - p -> - p.flatMapMany( - f -> { - if (debugEnabled) { - return f.doOnNext( - frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); - } else { - return f; - } - })); - } - - @Override - public ByteBufAllocator alloc() { - return source.alloc(); - } - - @Override - public void dispose() { - source.dispose(); - } - - @Override - public boolean isDisposed() { - return source.isDisposed(); - } - - @Override - public Mono onClose() { - return source.onClose(); - } - - @Override - public double availability() { - return source.availability(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index cb8b5d63d..c96a7aed2 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -16,19 +16,23 @@ package io.rsocket.internal; -import io.netty.util.ReferenceCounted; +import io.netty.buffer.ByteBuf; import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; import java.util.Objects; import java.util.Queue; +import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import org.reactivestreams.Subscriber; +import java.util.stream.Stream; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; +import reactor.core.Disposable; import reactor.core.Exceptions; import reactor.core.Fuseable; -import reactor.core.publisher.FluxProcessor; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; import reactor.core.publisher.Operators; +import reactor.util.Logger; import reactor.util.annotation.Nullable; import reactor.util.concurrent.Queues; import reactor.util.context.Context; @@ -37,194 +41,399 @@ * A Processor implementation that takes a custom queue and allows only a single subscriber. * *

The implementation keeps the order of signals. - * - * @param the input and output type */ -public final class UnboundedProcessor extends FluxProcessor - implements Fuseable.QueueSubscription, Fuseable { - - final Queue queue; - final Queue priorityQueue; - - volatile boolean done; +public final class UnboundedProcessor extends Flux + implements Scannable, + Disposable, + CoreSubscriber, + Fuseable.QueueSubscription, + Fuseable { + + final Queue queue; + final Queue priorityQueue; + final Runnable onFinalizedHook; + @Nullable final Logger logger; + + boolean cancelled; + boolean done; Throwable error; - // important to not loose the downstream too early and miss discard hook, while - // having relevant hasDownstreams() - boolean hasDownstream; - volatile CoreSubscriber actual; - - volatile boolean cancelled; - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + CoreSubscriber actual; + + static final long FLAG_FINALIZED = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_DISPOSED = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_TERMINATED = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_CANCELLED = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_VALUE = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_REQUEST = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBER_READY = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBED_ONCE = + 0b0000_0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long MAX_WIP_VALUE = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "state"); volatile int discardGuard; - @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater DISCARD_GUARD = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); volatile long requested; - @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); + ByteBuf last; + boolean outputFused; public UnboundedProcessor() { + this(() -> {}); + } + + UnboundedProcessor(Logger logger) { + this(() -> {}, logger); + } + + public UnboundedProcessor(Runnable onFinalizedHook) { + this(onFinalizedHook, null); + } + + UnboundedProcessor(Runnable onFinalizedHook, @Nullable Logger logger) { + this.onFinalizedHook = onFinalizedHook; this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.logger = logger; } @Override - public int getBufferSize() { - return Integer.MAX_VALUE; + public Stream inners() { + return hasDownstreams() ? Stream.of(Scannable.from(this.actual)) : Stream.empty(); } @Override public Object scanUnsafe(Attr key) { - if (Attr.BUFFERED == key) return queue.size(); + if (Attr.ACTUAL == key) return isSubscriberReady(this.state) ? this.actual : null; + if (Attr.BUFFERED == key) return this.queue.size() + this.priorityQueue.size(); if (Attr.PREFETCH == key) return Integer.MAX_VALUE; - return super.scanUnsafe(key); + if (Attr.CANCELLED == key) { + final long state = this.state; + return isCancelled(state) || isDisposed(state); + } + + return null; } - void drainRegular(Subscriber a) { - int missed = 1; + public boolean tryEmitPrioritized(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } - final Queue q = queue; - final Queue pq = priorityQueue; + if (!this.priorityQueue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } - for (; ; ) { + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } - long r = requested; - long e = 0L; + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } - while (r != e) { - boolean d = done; + if (isWorkInProgress(previousState)) { + return true; + } - T t; - boolean empty; + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + return true; + } - if (!pq.isEmpty()) { - t = pq.poll(); - empty = false; - } else { - t = q.poll(); - empty = t == null; - } + public boolean tryEmitNormal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } - if (checkTerminated(d, empty, a)) { - return; - } + if (!this.queue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } - if (empty) { - break; - } + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } - a.onNext(t); + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } - e++; + if (isWorkInProgress(previousState)) { + return true; } - if (r == e) { - if (checkTerminated(done, q.isEmpty() && pq.isEmpty(), a)) { - return; - } + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); } + } - if (e != 0 && r != Long.MAX_VALUE) { - REQUESTED.addAndGet(this, -e); + return true; + } + + public boolean tryEmitFinal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + this.last = t; + this.done = true; + + final long previousState = markValueAddedAndTerminated(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + this.actual.onComplete(); + return true; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - break; + if (isWorkInProgress(previousState)) { + return true; } + + drainRegular((previousState | FLAG_TERMINATED | FLAG_HAS_VALUE) + 1); } + + return true; } - void drainFused(Subscriber a) { - int missed = 1; + @Deprecated + public void onNextPrioritized(ByteBuf t) { + tryEmitPrioritized(t); + } - for (; ; ) { + @Override + @Deprecated + public void onNext(ByteBuf t) { + tryEmitNormal(t); + } - if (cancelled) { - this.clear(); - hasDownstream = false; - return; - } + @Override + @Deprecated + public void onError(Throwable t) { + if (this.done || this.cancelled) { + Operators.onErrorDropped(t, currentContext()); + return; + } - boolean d = done; + this.error = t; + this.done = true; - a.onNext(null); + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + Operators.onErrorDropped(t, currentContext()); + return; + } - if (d) { - hasDownstream = false; + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onError(t); + return; + } - Throwable ex = error; - if (ex != null) { - a.onError(ex); - } else { - a.onComplete(); - } + if (isWorkInProgress(previousState)) { return; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { - break; + if (!hasValue(previousState)) { + // fast path no-values scenario + this.actual.onError(t); + return; } + + drainRegular((previousState | FLAG_TERMINATED) + 1); } } - public void drain() { - if (WIP.getAndIncrement(this) != 0) { - if (cancelled) { - this.clear(); - } + @Override + @Deprecated + public void onComplete() { + if (this.done || this.cancelled) { return; } - int missed = 1; + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onComplete(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + this.actual.onComplete(); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + void drainRegular(long expectedState) { + final CoreSubscriber a = this.actual; + final Queue q = this.queue; + final Queue pq = this.priorityQueue; for (; ; ) { - Subscriber a = actual; - if (a != null) { - if (outputFused) { - drainFused(a); - } else { - drainRegular(a); + long r = this.requested; + long e = 0L; + + boolean empty = false; + boolean done; + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + done = this.done; + + ByteBuf t = pq.poll(); + empty = t == null; + + if (empty) { + t = q.poll(); + empty = t == null; } + + if (checkTerminated(done, empty, true, a)) { + if (!empty) { + release(t); + } + return; + } + + if (empty) { + break; + } + + a.onNext(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + done = this.done; + empty = q.isEmpty() && pq.isEmpty(); + + if (checkTerminated(done, empty, false, a)) { + return; + } + } + + if (e != 0 && r != Long.MAX_VALUE) { + r = REQUESTED.addAndGet(this, -e); + } + + expectedState = markWorkDone(this, expectedState, r > 0, !empty); + if (isCancelled(expectedState)) { + clearAndFinalize(this); return; } - missed = WIP.addAndGet(this, -missed); - if (missed == 0) { + if (isDisposed(expectedState)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return; + } + + if (!isWorkInProgress(expectedState)) { break; } } } - boolean checkTerminated(boolean d, boolean empty, Subscriber a) { - if (cancelled) { - this.clear(); - hasDownstream = false; + boolean checkTerminated( + boolean done, boolean empty, boolean hasDemand, CoreSubscriber a) { + final long state = this.state; + if (isCancelled(state)) { + clearAndFinalize(this); + return true; + } + + if (isDisposed(state)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); return true; } - if (d && empty) { - Throwable e = error; - hasDownstream = false; + + if (done && empty) { + if (!isTerminated(state)) { + // proactively return if volatile field is not yet set to needed state + return false; + } + final ByteBuf last = this.last; + if (last != null) { + if (!hasDemand) { + return false; + } + this.last = null; + a.onNext(last); + } + clearAndFinalize(this); + Throwable e = this.error; if (e != null) { a.onError(e); } else { @@ -238,7 +447,8 @@ boolean checkTerminated(boolean d, boolean empty, Subscriber a) { @Override public void onSubscribe(Subscription s) { - if (done || cancelled) { + final long state = this.state; + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { s.cancel(); } else { s.request(Long.MAX_VALUE); @@ -252,150 +462,216 @@ public int getPrefetch() { @Override public Context currentContext() { - CoreSubscriber actual = this.actual; - return actual != null ? actual.currentContext() : Context.empty(); + return isSubscriberReady(this.state) ? this.actual.currentContext() : Context.empty(); } - public void onNextPrioritized(T t) { - if (done || cancelled) { - Operators.onNextDropped(t, currentContext()); - release(t); + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + long previousState = markSubscribedOnce(this); + if (isSubscribedOnce(previousState)) { + Operators.error( + actual, new IllegalStateException("UnboundedProcessor allows only a single Subscriber")); return; } - if (!priorityQueue.offer(t)) { - Throwable ex = - Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); - onError(Operators.onOperatorError(null, ex, t, currentContext())); - release(t); + if (isDisposed(previousState)) { + Operators.error(actual, new CancellationException("Disposed")); return; } - drain(); - } - @Override - public void onNext(T t) { - if (done || cancelled) { - Operators.onNextDropped(t, currentContext()); - release(t); + actual.onSubscribe(this); + this.actual = actual; + + previousState = markSubscriberReady(this); + + if (isSubscriberReady(previousState)) { return; } - if (!queue.offer(t)) { - Throwable ex = - Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); - onError(Operators.onOperatorError(null, ex, t, currentContext())); - release(t); + if (this.outputFused) { + if (isCancelled(previousState)) { + return; + } + + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (hasValue(previousState)) { + actual.onNext(null); + } + + if (isTerminated(previousState)) { + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } return; } - drain(); - } - @Override - public void onError(Throwable t) { - if (done || cancelled) { - Operators.onErrorDropped(t, currentContext()); + if (isCancelled(previousState)) { + clearAndFinalize(this); return; } - error = t; - done = true; - - drain(); - } - - @Override - public void onComplete() { - if (done || cancelled) { + if (isDisposed(previousState)) { + clearAndFinalize(this); + actual.onError(new CancellationException("Disposed")); return; } - done = true; + if (!hasValue(previousState)) { + if (isTerminated(previousState)) { + clearAndFinalize(this); + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } - drain(); + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_SUBSCRIBER_READY) + 1); + } } @Override - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { + public void request(long n) { + if (Operators.validate(n)) { + if (this.outputFused) { + final long state = this.state; + if (isSubscriberReady(state)) { + this.actual.onNext(null); + } + return; + } - actual.onSubscribe(this); - this.actual = actual; - if (cancelled) { - this.hasDownstream = false; - } else { - drain(); + Operators.addCap(REQUESTED, this, n); + + final long previousState = markRequestAdded(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (isSubscriberReady(previousState) && hasValue(previousState)) { + drainRegular((previousState | FLAG_HAS_REQUEST) + 1); } - } else { - Operators.error( - actual, - new IllegalStateException("UnboundedProcessor " + "allows only a single Subscriber")); } } @Override - public void request(long n) { - if (Operators.validate(n)) { - Operators.addCap(REQUESTED, this, n); - drain(); + public void cancel() { + this.cancelled = true; + + final long previousState = markCancelled(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState) || !this.outputFused) { + clearAndFinalize(this); } } @Override - public void cancel() { - if (cancelled) { + @Deprecated + public void dispose() { + this.cancelled = true; + + final long previousState = markDisposed(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState)) { + clearAndFinalize(this); + return; + } + + if (!isSubscriberReady(previousState)) { + return; + } + + if (!this.outputFused) { + clearAndFinalize(this); + this.actual.onError(new CancellationException("Disposed")); return; } - cancelled = true; - if (WIP.getAndIncrement(this) == 0) { - this.clear(); - hasDownstream = false; + if (!isTerminated(previousState)) { + this.actual.onError(new CancellationException("Disposed")); } } @Override @Nullable - public T poll() { - Queue pq = this.priorityQueue; - if (!pq.isEmpty()) { - return pq.poll(); + public ByteBuf poll() { + ByteBuf t = this.priorityQueue.poll(); + if (t != null) { + return t; } - return queue.poll(); + + t = this.queue.poll(); + if (t != null) { + return t; + } + + t = this.last; + if (t != null) { + this.last = null; + return t; + } + + return null; } @Override public int size() { - return priorityQueue.size() + queue.size(); + return this.priorityQueue.size() + this.queue.size(); } @Override public boolean isEmpty() { - return priorityQueue.isEmpty() && queue.isEmpty(); + return this.priorityQueue.isEmpty() && this.queue.isEmpty(); } + /** + * Clears all elements from queues and set state to terminate. This method MUST be called only by + * the downstream subscriber which has enabled {@link Fuseable#ASYNC} fusion with the given {@link + * UnboundedProcessor} and is and indicator that the downstream is done with draining, it has + * observed any terminal signal (ON_COMPLETE or ON_ERROR or CANCEL) and will never be interacting + * with SingleConsumer queue anymore. + */ @Override public void clear() { + clearAndFinalize(this); + } + + void clearSafely() { if (DISCARD_GUARD.getAndIncrement(this) != 0) { return; } int missed = 1; - for (; ; ) { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - release(t); - } - } - while (!priorityQueue.isEmpty()) { - T t = priorityQueue.poll(); - if (t != null) { - release(t); - } - } + clearUnsafely(); missed = DISCARD_GUARD.addAndGet(this, -missed); if (missed == 0) { @@ -404,56 +680,488 @@ public void clear() { } } + void clearUnsafely() { + final Queue queue = this.queue; + final Queue priorityQueue = this.priorityQueue; + + final ByteBuf last = this.last; + + if (last != null) { + release(last); + } + + ByteBuf byteBuf; + while ((byteBuf = queue.poll()) != null) { + release(byteBuf); + } + + while ((byteBuf = priorityQueue.poll()) != null) { + release(byteBuf); + } + } + @Override public int requestFusion(int requestedMode) { if ((requestedMode & Fuseable.ASYNC) != 0) { - outputFused = true; + this.outputFused = true; return Fuseable.ASYNC; } return Fuseable.NONE; } @Override - public void dispose() { - cancel(); + public boolean isDisposed() { + return isFinalized(this.state); } - @Override - public boolean isDisposed() { - return cancelled || done; + boolean hasDownstreams() { + final long state = this.state; + return !isTerminated(state) && isSubscriberReady(state); } - @Override - public boolean isTerminated() { - return done; + static void release(ByteBuf byteBuf) { + if (byteBuf.refCnt() > 0) { + try { + byteBuf.release(); + } catch (Throwable ex) { + // no ops + } + } } - @Override - @Nullable - public Throwable getError() { - return error; + /** + * Sets {@link #FLAG_SUBSCRIBED_ONCE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return {@code true} if {@link #FLAG_SUBSCRIBED_ONCE} was successfully set + */ + static long markSubscribedOnce(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isSubscribedOnce(state)) { + return state; + } + + final long nextState = state | FLAG_SUBSCRIBED_ONCE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mso", state, nextState); + return state; + } + } } - @Override - public long downstreamCount() { - return hasDownstreams() ? 1L : 0L; + /** + * Sets {@link #FLAG_SUBSCRIBER_READY} flag if flags {@link #FLAG_FINALIZED}, {@link + * #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return previous state + */ + static long markSubscriberReady(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) + || isCancelled(state) + || isDisposed(state) + || isSubscriberReady(state)) { + return state; + } + + long nextState = state; + if (!instance.outputFused) { + if ((!hasValue(state) && isTerminated(state)) || (hasRequest(state) && hasValue(state))) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_SUBSCRIBER_READY; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " msr", state, nextState); + return state; + } + } } - @Override - public boolean hasDownstreams() { - return hasDownstream; - } - - void release(T t) { - if (t instanceof ReferenceCounted) { - ReferenceCounted refCounted = (ReferenceCounted) t; - if (refCounted.refCnt() > 0) { - try { - refCounted.release(); - } catch (Throwable ex) { - // no ops + /** + * Sets {@link #FLAG_HAS_REQUEST} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markRequestAdded(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state) || (isSubscriberReady(state) && hasValue(state))) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_REQUEST; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mra", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAdded(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state)) { + if (instance.outputFused) { + // fast path for fusion scenario + return state; + } + + if (hasRequest(state)) { + nextState = addWork(state); } } + + nextState = nextState | FLAG_HAS_VALUE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mva", state, nextState); + return state; + } } } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAddedAndTerminated(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_VALUE | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, "mva&t", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_TERMINATED} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markTerminatedOrFinalized(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + if (!hasValue(state)) { + // fast path for no values and no work in progress + nextState = FLAG_FINALIZED; + } else { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mt|f", state, nextState); + if (isFinalized(nextState)) { + instance.onFinalizedHook.run(); + } + return state; + } + } + } + + /** + * Sets {@link #FLAG_CANCELLED} flag if it was not set before and if flag {@link #FLAG_FINALIZED} + * is unset. Also, this method increments number of work in progress (WIP) + * + * @return previous state + */ + static long markCancelled(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_CANCELLED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mc", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_DISPOSED} flag if it was not set before and if flags {@link #FLAG_FINALIZED}, + * {@link #FLAG_CANCELLED} are unset. Also, this method increments number of work in progress + * (WIP) + * + * @return previous state + */ + static long markDisposed(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_DISPOSED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " md", state, nextState); + return state; + } + } + } + + static long addWork(long state) { + return (state & MAX_WIP_VALUE) == MAX_WIP_VALUE ? state : state + 1; + } + + /** + * Decrements the amount of work in progress by the given amount on the given state. Fails if flag + * is {@link #FLAG_FINALIZED} is set or if fusion disabled and flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set. + * + *

Note, if fusion is enabled, the decrement should work if flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set, since, while the operator was not terminate by the downstream, + * we still have to propagate notifications that new elements are enqueued + * + * @return state after changing WIP or current state if update failed + */ + static long markWorkDone( + UnboundedProcessor instance, long expectedState, boolean hasRequest, boolean hasValue) { + for (; ; ) { + final long state = instance.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = + (state - (expectedState & MAX_WIP_VALUE)) + ^ (hasRequest ? 0 : FLAG_HAS_REQUEST) + ^ (hasValue ? 0 : FLAG_HAS_VALUE); + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mwd", state, nextState); + return nextState; + } + } + } + + /** + * Set flag {@link #FLAG_FINALIZED} and {@link #release(ByteBuf)} all the elements from {@link + * #queue} and {@link #priorityQueue}. + * + *

This method may be called concurrently only if the given {@link UnboundedProcessor} has no + * output fusion ({@link #outputFused} {@code == true}). Otherwise this method MUST only by the + * downstream calling method {@link #clear()} + */ + static void clearAndFinalize(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + instance.clearSafely(); + return; + } + + if (!isSubscriberReady(state) || !instance.outputFused) { + instance.clearSafely(); + } else { + instance.clearUnsafely(); + } + + long nextState = (state & ~MAX_WIP_VALUE & ~FLAG_HAS_VALUE) | FLAG_FINALIZED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " c&f", state, nextState); + instance.onFinalizedHook.run(); + break; + } + } + } + + static boolean hasValue(long state) { + return (state & FLAG_HAS_VALUE) == FLAG_HAS_VALUE; + } + + static boolean hasRequest(long state) { + return (state & FLAG_HAS_REQUEST) == FLAG_HAS_REQUEST; + } + + static boolean isCancelled(long state) { + return (state & FLAG_CANCELLED) == FLAG_CANCELLED; + } + + static boolean isDisposed(long state) { + return (state & FLAG_DISPOSED) == FLAG_DISPOSED; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WIP_VALUE) != 0; + } + + static boolean isTerminated(long state) { + return (state & FLAG_TERMINATED) == FLAG_TERMINATED; + } + + static boolean isFinalized(long state) { + return (state & FLAG_FINALIZED) == FLAG_FINALIZED; + } + + static boolean isSubscriberReady(long state) { + return (state & FLAG_SUBSCRIBER_READY) == FLAG_SUBSCRIBER_READY; + } + + static boolean isSubscribedOnce(long state) { + return (state & FLAG_SUBSCRIBED_ONCE) == FLAG_SUBSCRIBED_ONCE; + } + + static void log( + UnboundedProcessor instance, String action, long initialState, long committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + long initialState, + long committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64))); + } + } + + static void log( + UnboundedProcessor instance, String action, int initialState, int committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + int initialState, + int committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32))); + } + } + + static String formatState(long state, int size) { + final String defaultFormat = Long.toBinaryString(state); + final StringBuilder formatted = new StringBuilder(); + final int toPrepend = size - defaultFormat.length(); + for (int i = 0; i < size; i++) { + if (i != 0 && i % 4 == 0) { + formatted.append("_"); + } + if (i < toPrepend) { + formatted.append("0"); + } else { + formatted.append(defaultFormat.charAt(i - toPrepend)); + } + } + + formatted.insert(0, "0b"); + return formatted.toString(); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java index 6939b0f7a..a99ef8a49 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java @@ -13,15 +13,30 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; import java.util.AbstractQueue; import java.util.Iterator; abstract class BaseLinkedQueuePad0 extends AbstractQueue implements MessagePassingQueue { - long p00, p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 8b as object header acts as padding and is >= 8b * } // $gen:ordered-fields @@ -29,18 +44,20 @@ abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 static final long P_NODE_OFFSET = fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); - private LinkedQueueNode producerNode; + private volatile LinkedQueueNode producerNode; final void spProducerNode(LinkedQueueNode newValue) { - producerNode = newValue; + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); } - @SuppressWarnings("unchecked") final LinkedQueueNode lvProducerNode() { - return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, P_NODE_OFFSET); + return producerNode; } - @SuppressWarnings("unchecked") final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); } @@ -51,8 +68,22 @@ final LinkedQueueNode lpProducerNode() { } abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -77,16 +108,27 @@ final LinkedQueueNode lpConsumerNode() { } abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } /** * A base data structure for concurrent linked queues. For convenience also pulled in common single * consumer methods since at this time there's no plan to implement MC. - * - * @param - * @author nitsanw */ abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 { @@ -158,8 +200,10 @@ public final int size() { * @see MessagePassingQueue#isEmpty() */ @Override - public final boolean isEmpty() { - return lvConsumerNode() == lvProducerNode(); + public boolean isEmpty() { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; } protected E getSingleConsumerNodeValue( diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java index 635779df3..cfad5ef71 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -13,25 +13,38 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.queues.CircularArrayOffsetCalculator.allocate; import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; -import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcElementOffset; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.calcElementOffset; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.lvElement; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.soElement; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.allocateRefArray; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.lvRefElement; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.soRefElement; import io.rsocket.internal.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; -import io.rsocket.internal.jctools.util.PortableJvmInfo; -import io.rsocket.internal.jctools.util.Pow2; -import io.rsocket.internal.jctools.util.RangeUtil; import java.util.AbstractQueue; import java.util.Iterator; +import java.util.NoSuchElementException; abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -56,8 +69,22 @@ final boolean casProducerIndex(long expect, long newValue) { } abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields { - long p01, p02, p03, p04, p05, p06, p07; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -84,8 +111,22 @@ final void soConsumerIndex(long newValue) { } abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields { - long p0, p1, p2, p3, p4, p5, p6, p7; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b } // $gen:ordered-fields @@ -115,12 +156,9 @@ final void soProducerLimit(long newValue) { * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in * linked chunks of the initial size. The queue grows only when the current buffer is full and * elements are not copied on resize, instead a link to the new buffer is stored in the old buffer - * for the consumer to follow.
- * - * @param + * for the consumer to follow. */ -public abstract class BaseMpscLinkedArrayQueue - extends BaseMpscLinkedArrayQueueColdProducerFields +abstract class BaseMpscLinkedArrayQueue extends BaseMpscLinkedArrayQueueColdProducerFields implements MessagePassingQueue, QueueProgressIndicators { // No post padding here, subclasses must add private static final Object JUMP = new Object(); @@ -141,7 +179,7 @@ public BaseMpscLinkedArrayQueue(final int initialCapacity) { // leave lower bit of mask clear long mask = (p2capacity - 1) << 1; // need extra element to point at next array - E[] buffer = allocate(p2capacity + 1); + E[] buffer = allocateRefArray(p2capacity + 1); producerBuffer = buffer; producerMask = mask; consumerBuffer = buffer; @@ -150,7 +188,7 @@ public BaseMpscLinkedArrayQueue(final int initialCapacity) { } @Override - public final int size() { + public int size() { // NOTE: because indices are on even numbers we cannot use the size util. /* @@ -181,7 +219,7 @@ public final int size() { } @Override - public final boolean isEmpty() { + public boolean isEmpty() { // Order matters! // Loading consumer before producer allows for producer increments after consumer index is read. // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there @@ -240,8 +278,8 @@ public boolean offer(final E e) { } } // INDEX visible before ELEMENT - final long offset = modifiedCalcElementOffset(pIndex, mask); - soElement(buffer, offset, e); // release element e + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); // release element e return true; } @@ -257,8 +295,8 @@ public E poll() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null) { if (index != lvProducerIndex()) { // poll() == null iff queue is empty, null element is not strong enough indicator, so we @@ -266,7 +304,7 @@ public E poll() { // check the producer index. If the queue is indeed not empty we spin until element is // visible. do { - e = lvElement(buffer, offset); + e = lvRefElement(buffer, offset); } while (e == null); } else { return null; @@ -278,7 +316,7 @@ public E poll() { return newBufferPoll(nextBuffer, index); } - soElement(buffer, offset, null); // release element null + soRefElement(buffer, offset, null); // release element null soConsumerIndex(index + 2); // release cIndex return (E) e; } @@ -295,14 +333,14 @@ public E peek() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null && index != lvProducerIndex()) { // peek() == null iff queue is empty, null element is not strong enough indicator, so we must // check the producer index. If the queue is indeed not empty we spin until element is // visible. do { - e = lvElement(buffer, offset); + e = lvRefElement(buffer, offset); } while (e == null); } if (e == JUMP) { @@ -346,31 +384,31 @@ else if (casProducerIndex(pIndex, pIndex + 1)) { @SuppressWarnings("unchecked") private E[] nextBuffer(final E[] buffer, final long mask) { final long offset = nextArrayOffset(mask); - final E[] nextBuffer = (E[]) lvElement(buffer, offset); + final E[] nextBuffer = (E[]) lvRefElement(buffer, offset); consumerBuffer = nextBuffer; consumerMask = (length(nextBuffer) - 2) << 1; - soElement(buffer, offset, BUFFER_CONSUMED); + soRefElement(buffer, offset, BUFFER_CONSUMED); return nextBuffer; } - private long nextArrayOffset(long mask) { - return modifiedCalcElementOffset(mask + 2, Long.MAX_VALUE); + private static long nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); } private E newBufferPoll(E[] nextBuffer, long index) { - final long offset = modifiedCalcElementOffset(index, consumerMask); - final E n = lvElement(nextBuffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); if (n == null) { throw new IllegalStateException("new buffer must have at least one element"); } - soElement(nextBuffer, offset, null); // StoreStore + soRefElement(nextBuffer, offset, null); soConsumerIndex(index + 2); return n; } private E newBufferPeek(E[] nextBuffer, long index) { - final long offset = modifiedCalcElementOffset(index, consumerMask); - final E n = lvElement(nextBuffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); if (null == n) { throw new IllegalStateException("new buffer must have at least one element"); } @@ -402,8 +440,8 @@ public E relaxedPoll() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == null) { return null; } @@ -411,7 +449,7 @@ public E relaxedPoll() { final E[] nextBuffer = nextBuffer(buffer, mask); return newBufferPoll(nextBuffer, index); } - soElement(buffer, offset, null); + soRefElement(buffer, offset, null); soConsumerIndex(index + 2); return (E) e; } @@ -423,8 +461,8 @@ public E relaxedPeek() { final long index = lpConsumerIndex(); final long mask = consumerMask; - final long offset = modifiedCalcElementOffset(index, mask); - Object e = lvElement(buffer, offset); // LoadLoad + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); if (e == JUMP) { return newBufferPeek(nextBuffer(buffer, mask), index); } @@ -447,7 +485,11 @@ public int fill(Supplier s) { } @Override - public int fill(Supplier s, int batchSize) { + public int fill(Supplier s, int limit) { + if (null == s) throw new IllegalArgumentException("supplier is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) return 0; + long mask; E[] buffer; long pIndex; @@ -471,9 +513,10 @@ public int fill(Supplier s, int batchSize) { // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' - long batchIndex = Math.min(producerLimit, pIndex + 2 * batchSize); + long batchIndex = + Math.min(producerLimit, pIndex + 2l * limit); // -> producerLimit >= batchIndex - if (pIndex >= producerLimit || producerLimit < batchIndex) { + if (pIndex >= producerLimit) { int result = offerSlowPath(mask, pIndex, producerLimit); switch (result) { case CONTINUE_TO_P_INDEX_CAS: @@ -496,23 +539,15 @@ public int fill(Supplier s, int batchSize) { } for (int i = 0; i < claimedSlots; i++) { - final long offset = modifiedCalcElementOffset(pIndex + 2 * i, mask); - soElement(buffer, offset, s.get()); + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); } return claimedSlots; } @Override - public void fill(Supplier s, WaitStrategy w, ExitCondition exit) { - - while (exit.keepRunning()) { - if (fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { - int idleCounter = 0; - while (exit.keepRunning() && fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { - idleCounter = w.idle(idleCounter); - } - } - } + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); } @Override @@ -521,30 +556,13 @@ public int drain(Consumer c) { } @Override - public int drain(final Consumer c, final int limit) { - // Impl note: there are potentially some small gains to be had by manually inlining - // relaxedPoll() and hoisting - // reused fields out to reduce redundant reads. - int i = 0; - E m; - for (; i < limit && (m = relaxedPoll()) != null; i++) { - c.accept(m); - } - return i; + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); } @Override - public void drain(Consumer c, WaitStrategy w, ExitCondition exit) { - int idleCounter = 0; - while (exit.keepRunning()) { - E e = relaxedPoll(); - if (e == null) { - idleCounter = w.idle(idleCounter); - continue; - } - idleCounter = 0; - c.accept(e); - } + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); } /** @@ -559,21 +577,28 @@ public void drain(Consumer c, WaitStrategy w, ExitCondition exit) { */ @Override public Iterator iterator() { - return new WeakIterator(); + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); } - private final class WeakIterator implements Iterator { - + private static class WeakIterator implements Iterator { + private final long pIndex; private long nextIndex; private E nextElement; private E[] currentBuffer; - private int currentBufferLength; + private int mask; - WeakIterator() { - setBuffer(consumerBuffer); + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); nextElement = getNext(); } + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + @Override public boolean hasNext() { return nextElement != null; @@ -581,37 +606,54 @@ public boolean hasNext() { @Override public E next() { - E e = nextElement; + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } nextElement = getNext(); return e; } private void setBuffer(E[] buffer) { this.currentBuffer = buffer; - this.currentBufferLength = length(buffer); - this.nextIndex = 0; + this.mask = length(buffer) - 2; } private E getNext() { - while (true) { - while (nextIndex < currentBufferLength - 1) { - long offset = calcElementOffset(nextIndex++); - E e = lvElement(currentBuffer, offset); - if (e != null && e != JUMP) { - return e; - } + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; } - long offset = calcElementOffset(currentBufferLength - 1); - Object nextArray = lvElement(currentBuffer, offset); - if (nextArray == BUFFER_CONSUMED) { - // Consumer may have passed us, just jump to the current consumer buffer - setBuffer(consumerBuffer); - } else if (nextArray != null) { - setBuffer((E[]) nextArray); - } else { + + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early return null; } + + setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same + // index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } } + return null; } } @@ -620,7 +662,7 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s int newBufferLength = getNextBufferSize(oldBuffer); final E[] newBuffer; try { - newBuffer = allocate(newBufferLength); + newBuffer = allocateRefArray(newBufferLength); } catch (OutOfMemoryError oom) { assert lvProducerIndex() == pIndex + 1; soProducerIndex(pIndex); @@ -631,11 +673,11 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s final int newMask = (newBufferLength - 2) << 1; producerMask = newMask; - final long offsetInOld = modifiedCalcElementOffset(pIndex, oldMask); - final long offsetInNew = modifiedCalcElementOffset(pIndex, newMask); + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); - soElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array - soElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked + soRefElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked // ASSERT code final long cIndex = lvConsumerIndex(); @@ -652,7 +694,7 @@ private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s // INDEX visible before ELEMENT, consistent with consumer expectation // make resize visible to consumer - soElement(oldBuffer, offsetInOld, JUMP); + soRefElement(oldBuffer, offsetInOld, JUMP); } /** @return next buffer size(inclusive of next array pointer) */ diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java deleted file mode 100644 index d746fccbb..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal.jctools.queues; - -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; - -import io.rsocket.internal.jctools.util.InternalAPI; - -@InternalAPI -public final class CircularArrayOffsetCalculator { - @SuppressWarnings("unchecked") - public static E[] allocate(int capacity) { - return (E[]) new Object[capacity]; - } - - /** - * @param index desirable element index - * @param mask (length - 1) - * @return the offset in bytes within the array for a given index. - */ - public static long calcElementOffset(long index, long mask) { - return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java index 1b7d43166..40116bbe1 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java @@ -13,10 +13,7 @@ */ package io.rsocket.internal.jctools.queues; -import io.rsocket.internal.jctools.util.InternalAPI; - -@InternalAPI -public final class IndexedQueueSizeUtil { +final class IndexedQueueSizeUtil { public static int size(IndexedQueue iq) { /* * It is possible for a thread to be interrupted or reschedule between the read of the producer and @@ -54,7 +51,6 @@ public static boolean isEmpty(IndexedQueue iq) { return (iq.lvConsumerIndex() == iq.lvProducerIndex()); } - @InternalAPI public interface IndexedQueue { long lvConsumerIndex(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java index 5e7831128..37651f351 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java @@ -13,13 +13,11 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; -import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; /** This is used for method substitution in the LinkedArray classes code generation. */ final class LinkedArrayQueueUtil { - private LinkedArrayQueueUtil() {} - static int length(Object[] buf) { return buf.length; } @@ -29,7 +27,7 @@ static int length(Object[] buf) { * is compensated for by reducing the element shift. The computation is constant folded, so * there's no cost. */ - static long modifiedCalcElementOffset(long index, long mask) { + static long modifiedCalcCircularRefElementOffset(long index, long mask) { return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java index 6ea69e330..72e78bb92 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java @@ -13,8 +13,8 @@ */ package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; -import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; final class LinkedQueueNode { private static final long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class, "next"); @@ -53,6 +53,10 @@ public void soNext(LinkedQueueNode n) { UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); } + public void spNext(LinkedQueueNode n) { + UNSAFE.putObject(this, NEXT_OFFSET, n); + } + public LinkedQueueNode lvNext() { return next; } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java index e0c3d0ee1..7a0fa901f 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java @@ -13,55 +13,327 @@ */ package io.rsocket.internal.jctools.queues; +import java.util.Queue; + +/** + * Message passing queues are intended for concurrent method passing. A subset of {@link Queue} + * methods are provided with the same semantics, while further functionality which accomodates the + * concurrent usecase is also on offer. + * + *

Message passing queues provide happens before semantics to messages passed through, namely + * that writes made by the producer before offering the message are visible to the consuming thread + * after the message has been polled out of the queue. + * + * @param the event/message type + */ public interface MessagePassingQueue { int UNBOUNDED_CAPACITY = -1; interface Supplier { + /** + * This method will return the next value to be written to the queue. As such the queue + * implementations are commited to insert the value once the call is made. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the supplier + * method. In particular size and any offer methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + *

WARNING: this method is assumed to never return {@code null}. Breaking this + * assumption can lead to a broken queue. + * + * @return new element, NEVER {@code null} + */ T get(); } interface Consumer { + /** + * This method will process an element already removed from the queue. This method is expected + * to never throw an exception. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the accept method. + * In particular size and any poll/peek methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + * @param e not {@code null} + */ void accept(T e); } interface WaitStrategy { + /** + * This method can implement static or dynamic backoff. Dynamic backoff will rely on the counter + * for estimating how long the caller has been idling. The expected usage is: + * + *

+ * + *

+     * 
+     * int ic = 0;
+     * while(true) {
+     *   if(!isGodotArrived()) {
+     *     ic = w.idle(ic);
+     *     continue;
+     *   }
+     *   ic = 0;
+     *   // party with Godot until he goes again
+     * }
+     * 
+     * 
+ * + * @param idleCounter idle calls counter, managed by the idle method until reset + * @return new counter value to be used on subsequent idle cycle + */ int idle(int idleCounter); } interface ExitCondition { + /** + * This method should be implemented such that the flag read or determination cannot be hoisted + * out of a loop which notmally means a volatile load, but with JDK9 VarHandles may mean + * getOpaque. + * + * @return true as long as we should keep running + */ boolean keepRunning(); } + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#offer(Object)} interface. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false iff full + */ boolean offer(T e); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#poll()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ T poll(); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#peek()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ T peek(); + /** + * This method's accuracy is subject to concurrent modifications happening as the size is + * estimated and as such is a best effort rather than absolute value. For some implementations + * this method may be O(n) rather than O(1). + * + * @return number of messages in the queue, between 0 and {@link Integer#MAX_VALUE} but less or + * equals to capacity (if bounded). + */ int size(); + /** + * Removes all items from the queue. Called from the consumer thread subject to the restrictions + * appropriate to the implementation and according to the {@link Queue#clear()} interface. + */ void clear(); + /** + * This method's accuracy is subject to concurrent modifications happening as the observation is + * carried out. + * + * @return true if empty, false otherwise + */ boolean isEmpty(); + /** + * @return the capacity of this queue or {@link MessagePassingQueue#UNBOUNDED_CAPACITY} if not + * bounded + */ int capacity(); + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#offer(Object)} this method may return false without the queue being + * full. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false if unable to offer + */ boolean relaxedOffer(T e); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#poll()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to poll + */ T relaxedPoll(); + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#peek()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to peek + */ T relaxedPeek(); - int drain(Consumer c); - - int fill(Supplier s); - + /** + * Remove up to limit elements from the queue and hand to consume. This should be + * semantically similar to: + * + *

+ * + *

{@code
+   * M m;
+   * int i = 0;
+   * for(;i < limit && (m = relaxedPoll()) != null; i++){
+   *   c.accept(m);
+   * }
+   * return i;
+   * }
+ * + *

There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + * @throws IllegalArgumentException if limit is negative + */ int drain(Consumer c, int limit); + /** + * Stuff the queue with up to limit elements from the supplier. Semantically similar to: + * + *

+ * + *

{@code
+   * for(int i=0; i < limit && relaxedOffer(s.get()); i++);
+   * }
+ * + *

There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + * @throws IllegalArgumentException if limit is negative + */ int fill(Supplier s, int limit); + /** + * Remove all available item from the queue and hand to consume. This should be semantically + * similar to: + * + *

+   * M m;
+   * while((m = relaxedPoll()) != null){
+   * c.accept(m);
+   * }
+   * 
+ * + * There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + */ + int drain(Consumer c); + + /** + * Stuff the queue with elements from the supplier. Semantically similar to: + * + *

+   * while(relaxedOffer(s.get());
+   * 
+ * + * There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

Unbounded queues will fill up the queue with a fixed amount rather than fill up to oblivion. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + */ + int fill(Supplier s); + + /** + * Remove elements from the queue and hand to consume forever. Semantically similar to: + * + *

+ * + *

+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = relaxedPoll();
+   *      if(e==null){
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *      c.accept(e);
+   *  }
+   * 
+ * + *

Called from a consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @throws IllegalArgumentException c OR wait OR exit are {@code null} + */ void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + /** + * Stuff the queue with elements from the supplier forever. Semantically similar to: + * + *

+ * + *

+   * 
+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = s.get();
+   *      while (!relaxedOffer(e)) {
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *  }
+   * 
+   * 
+ * + *

Called from a producer thread subject to the restrictions appropriate to the implementation. + * The main difference being that implementors MUST assure room in the queue is available BEFORE + * calling {@link Supplier#get}. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @throws IllegalArgumentException s OR wait OR exit are {@code null} + */ void fill(Supplier s, WaitStrategy wait, ExitCondition exit); } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java new file mode 100644 index 000000000..cb03364d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.jctools.queues; + +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Consumer; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.ExitCondition; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Supplier; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.WaitStrategy; + +final class MessagePassingQueueUtil { + public static int drain(MessagePassingQueue queue, Consumer c, int limit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) return 0; + E e; + int i = 0; + for (; i < limit && (e = queue.relaxedPoll()) != null; i++) { + c.accept(e); + } + return i; + } + + public static int drain(MessagePassingQueue queue, Consumer c) { + if (null == c) throw new IllegalArgumentException("c is null"); + E e; + int i = 0; + while ((e = queue.relaxedPoll()) != null) { + i++; + c.accept(e); + } + return i; + } + + public static void drain( + MessagePassingQueue queue, Consumer c, WaitStrategy wait, ExitCondition exit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (null == wait) throw new IllegalArgumentException("wait is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + final E e = queue.relaxedPoll(); + if (e == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + c.accept(e); + } + } + + public static void fill( + MessagePassingQueue q, Supplier s, WaitStrategy wait, ExitCondition exit) { + if (null == wait) throw new IllegalArgumentException("waiter is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + if (q.fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + } + } + + public static int fillBounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, q.capacity()); + } + + public static int fillInBatchesToLimit( + MessagePassingQueue q, Supplier s, int batch, int limit) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + do { + final int filled = q.fill(s, batch); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= limit); + return (int) result; + } + + public static int fillUnbounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, 4096); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java index 59eab33a1..179070be4 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java @@ -14,20 +14,31 @@ package io.rsocket.internal.jctools.queues; import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; - -import io.rsocket.internal.jctools.util.PortableJvmInfo; +import static io.rsocket.internal.jctools.queues.MessagePassingQueueUtil.fillUnbounded; /** * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked * chunks of the initial size. The queue grows only when the current chunk is full and elements are * not copied on resize, instead a link to the new chunk is stored in the old chunk for the consumer - * to follow.
- * - * @param + * to follow. */ public class MpscUnboundedArrayQueue extends BaseMpscLinkedArrayQueue { - long p0, p1, p2, p3, p4, p5, p6, p7; - long p10, p11, p12, p13, p14, p15, p16, p17; + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b public MpscUnboundedArrayQueue(int chunkSize) { super(chunkSize); @@ -50,17 +61,7 @@ public int drain(Consumer c) { @Override public int fill(Supplier s) { - long result = - 0; // result is a long because we want to have a safepoint check at regular intervals - final int capacity = 4096; - do { - final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); - if (filled == 0) { - return (int) result; - } - result += filled; - } while (result <= capacity); - return (int) result; + return fillUnbounded(this, s); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java similarity index 90% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java index 2d567d60d..f037857e8 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; /** JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) */ -@InternalAPI -public interface PortableJvmInfo { +interface PortableJvmInfo { int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); int CPUs = Runtime.getRuntime().availableProcessors(); int RECOMENDED_OFFER_BATCH = CPUs * 4; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java similarity index 96% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java index d8c66d89e..282a22f02 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java @@ -11,11 +11,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; /** Power of 2 utility functions. */ -@InternalAPI -public final class Pow2 { +final class Pow2 { public static final int MAX_POW2 = 1 << 30; /** diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java similarity index 94% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java index 77a0582ca..3adcb2f3c 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java @@ -11,10 +11,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; -@InternalAPI -public final class RangeUtil { +final class RangeUtil { public static long checkPositive(long n, String name) { if (n <= 0) { throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java old mode 100755 new mode 100644 similarity index 71% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java index 793e64505..c99aeb689 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java @@ -11,7 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; import java.lang.reflect.Constructor; import java.lang.reflect.Field; @@ -33,41 +33,56 @@ * * @author nitsanw */ -@InternalAPI -public class UnsafeAccess { - public static final boolean SUPPORTS_GET_AND_SET; +class UnsafeAccess { + public static final boolean SUPPORTS_GET_AND_SET_REF; + public static final boolean SUPPORTS_GET_AND_ADD_LONG; public static final Unsafe UNSAFE; static { + UNSAFE = getUnsafe(); + SUPPORTS_GET_AND_SET_REF = hasGetAndSetSupport(); + SUPPORTS_GET_AND_ADD_LONG = hasGetAndAddLongSupport(); + } + + private static Unsafe getUnsafe() { Unsafe instance; try { final Field field = Unsafe.class.getDeclaredField("theUnsafe"); field.setAccessible(true); instance = (Unsafe) field.get(null); } catch (Exception ignored) { - // Some platforms, notably Android, might not have a sun.misc.Unsafe - // implementation with a private `theUnsafe` static instance. In this - // case we can try and call the default constructor, which proves - // sufficient for Android usage. + // Some platforms, notably Android, might not have a sun.misc.Unsafe implementation with a + // private + // `theUnsafe` static instance. In this case we can try to call the default constructor, which + // is sufficient + // for Android usage. try { Constructor c = Unsafe.class.getDeclaredConstructor(); c.setAccessible(true); instance = c.newInstance(); } catch (Exception e) { - SUPPORTS_GET_AND_SET = false; throw new RuntimeException(e); } } + return instance; + } - boolean getAndSetSupport = false; + private static boolean hasGetAndSetSupport() { try { Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); - getAndSetSupport = true; + return true; } catch (Exception ignored) { } + return false; + } - UNSAFE = instance; - SUPPORTS_GET_AND_SET = getAndSetSupport; + private static boolean hasGetAndAddLongSupport() { + try { + Unsafe.class.getMethod("getAndAddLong", Object.class, Long.TYPE, Long.TYPE); + return true; + } catch (Exception ignored) { + } + return false; } public static long fieldOffset(Class clz, String fieldName) throws RuntimeException { diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java similarity index 57% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java index d8309c5c5..c734a9914 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java @@ -11,32 +11,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.util; +package io.rsocket.internal.jctools.queues; -import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; -/** - * A concurrent access enabling class used by circular array based queues this class exposes an - * offset computation method along with differently memory fenced load/store methods into the - * underlying array. The class is pre-padded and the array is padded on either side to help with - * False sharing prvention. It is expected theat subclasses handle post padding. - * - *

Offset calculation is separate from access to enable the reuse of a give compute offset. - * - *

Load/Store methods using a buffer parameter are provided to allow the prevention of - * final field reload after a LoadLoad barrier. - * - *

- * - * @author nitsanw - */ -@InternalAPI -public final class UnsafeRefArrayAccess { +final class UnsafeRefArrayAccess { public static final long REF_ARRAY_BASE; public static final int REF_ELEMENT_SHIFT; static { - final int scale = UnsafeAccess.UNSAFE.arrayIndexScale(Object[].class); + final int scale = UNSAFE.arrayIndexScale(Object[].class); if (4 == scale) { REF_ELEMENT_SHIFT = 2; } else if (8 == scale) { @@ -44,28 +28,28 @@ public final class UnsafeRefArrayAccess { } else { throw new IllegalStateException("Unknown pointer size: " + scale); } - REF_ARRAY_BASE = UnsafeAccess.UNSAFE.arrayBaseOffset(Object[].class); + REF_ARRAY_BASE = UNSAFE.arrayBaseOffset(Object[].class); } /** * A plain store (no ordering/fences) of an element to a given offset * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @param e an orderly kitty */ - public static void spElement(E[] buffer, long offset, E e) { + public static void spRefElement(E[] buffer, long offset, E e) { UNSAFE.putObject(buffer, offset, e); } /** - * An ordered store(store + StoreStore barrier) of an element to a given offset + * An ordered store of an element to a given offset * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset} + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} * @param e an orderly kitty */ - public static void soElement(E[] buffer, long offset, E e) { + public static void soRefElement(E[] buffer, long offset, E e) { UNSAFE.putOrderedObject(buffer, offset, e); } @@ -73,31 +57,48 @@ public static void soElement(E[] buffer, long offset, E e) { * A plain load (no ordering/fences) of an element from a given offset. * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @return the element at the offset */ @SuppressWarnings("unchecked") - public static E lpElement(E[] buffer, long offset) { + public static E lpRefElement(E[] buffer, long offset) { return (E) UNSAFE.getObject(buffer, offset); } /** - * A volatile load (load + LoadLoad barrier) of an element from a given offset. + * A volatile load of an element from a given offset. * * @param buffer this.buffer - * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} * @return the element at the offset */ @SuppressWarnings("unchecked") - public static E lvElement(E[] buffer, long offset) { + public static E lvRefElement(E[] buffer, long offset) { return (E) UNSAFE.getObjectVolatile(buffer, offset); } /** * @param index desirable element index - * @return the offset in bytes within the array for a given index. + * @return the offset in bytes within the array for a given index */ - public static long calcElementOffset(long index) { + public static long calcRefElementOffset(long index) { return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } + + /** This makes for an easier time generating the atomic queues, and removes some warnings. */ + @SuppressWarnings("unchecked") + public static E[] allocateRefArray(int capacity) { + return (E[]) new Object[capacity]; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java deleted file mode 100644 index f233e9597..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal.jctools.util; - -import java.lang.annotation.ElementType; -import java.lang.annotation.Retention; -import java.lang.annotation.RetentionPolicy; -import java.lang.annotation.Target; - -/** - * This annotation marks classes and methods which may be public for any reason (to support better - * testing or reduce code duplication) but are not intended as public API and may change between - * releases without the change being considered a breaking API change (a major release). - */ -@Target({ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR}) -@Retention(RetentionPolicy.SOURCE) -public @interface InternalAPI {} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java index 2535c342b..4fd7a772d 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java @@ -1,9 +1,10 @@ package io.rsocket.keepalive; import io.netty.buffer.ByteBuf; -import io.rsocket.Closeable; import io.rsocket.keepalive.KeepAliveSupport.KeepAlive; +import io.rsocket.resume.RSocketSession; import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; import java.util.function.Consumer; public interface KeepAliveHandler { @@ -14,18 +15,11 @@ KeepAliveFramesAcceptor start( Consumer onTimeout); class DefaultKeepAliveHandler implements KeepAliveHandler { - private final Closeable duplexConnection; - - public DefaultKeepAliveHandler(Closeable duplexConnection) { - this.duplexConnection = duplexConnection; - } - @Override public KeepAliveFramesAcceptor start( KeepAliveSupport keepAliveSupport, Consumer onSendKeepAliveFrame, Consumer onTimeout) { - duplexConnection.onClose().doFinally(s -> keepAliveSupport.stop()).subscribe(); return keepAliveSupport .onSendKeepAliveFrame(onSendKeepAliveFrame) .onTimeout(onTimeout) @@ -34,10 +28,18 @@ public KeepAliveFramesAcceptor start( } class ResumableKeepAliveHandler implements KeepAliveHandler { + private final ResumableDuplexConnection resumableDuplexConnection; + private final RSocketSession rSocketSession; + private final ResumeStateHolder resumeStateHolder; - public ResumableKeepAliveHandler(ResumableDuplexConnection resumableDuplexConnection) { + public ResumableKeepAliveHandler( + ResumableDuplexConnection resumableDuplexConnection, + RSocketSession rSocketSession, + ResumeStateHolder resumeStateHolder) { this.resumableDuplexConnection = resumableDuplexConnection; + this.rSocketSession = rSocketSession; + this.resumeStateHolder = resumeStateHolder; } @Override @@ -45,10 +47,11 @@ public KeepAliveFramesAcceptor start( KeepAliveSupport keepAliveSupport, Consumer onSendKeepAliveFrame, Consumer onTimeout) { - resumableDuplexConnection.onResume(keepAliveSupport::start); - resumableDuplexConnection.onDisconnect(keepAliveSupport::stop); + + rSocketSession.setKeepAliveSupport(keepAliveSupport); + return keepAliveSupport - .resumeState(resumableDuplexConnection) + .resumeState(resumeStateHolder) .onSendKeepAliveFrame(onSendKeepAliveFrame) .onTimeout(keepAlive -> resumableDuplexConnection.disconnect()) .start(); diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java index 6a3ab40d3..4fd18d041 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -23,7 +23,7 @@ import io.rsocket.resume.ResumeStateHolder; import java.time.Duration; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.function.Consumer; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -38,11 +38,19 @@ public abstract class KeepAliveSupport implements KeepAliveFramesAcceptor { final Duration keepAliveTimeout; final long keepAliveTimeoutMillis; - final AtomicBoolean started = new AtomicBoolean(); + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(KeepAliveSupport.class, "state"); + + static final int STOPPED_STATE = 0; + static final int STARTING_STATE = 1; + static final int STARTED_STATE = 2; + static final int DISPOSED_STATE = -1; volatile Consumer onTimeout; volatile Consumer onFrameSent; - volatile Disposable ticksDisposable; + + Disposable ticksDisposable; volatile ResumeStateHolder resumeStateHolder; volatile long lastReceivedMillis; @@ -57,25 +65,30 @@ private KeepAliveSupport( } public KeepAliveSupport start() { - this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); - if (started.compareAndSet(false, true)) { - ticksDisposable = + if (this.state == STOPPED_STATE && STATE.compareAndSet(this, STOPPED_STATE, STARTING_STATE)) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + + final Disposable disposable = Flux.interval(keepAliveInterval, scheduler).subscribe(v -> onIntervalTick()); + this.ticksDisposable = disposable; + + if (this.state != STARTING_STATE + || !STATE.compareAndSet(this, STARTING_STATE, STARTED_STATE)) { + disposable.dispose(); + } } return this; } public void stop() { - if (started.compareAndSet(true, false)) { - ticksDisposable.dispose(); - } + terminate(STOPPED_STATE); } @Override public void receive(ByteBuf keepAliveFrame) { this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); if (resumeStateHolder != null) { - long remoteLastReceivedPos = remoteLastReceivedPosition(keepAliveFrame); + final long remoteLastReceivedPos = KeepAliveFrameCodec.lastPosition(keepAliveFrame); resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); } if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { @@ -104,6 +117,16 @@ public KeepAliveSupport onTimeout(Consumer onTimeout) { return this; } + @Override + public void dispose() { + terminate(DISPOSED_STATE); + } + + @Override + public boolean isDisposed() { + return ticksDisposable.isDisposed(); + } + abstract void onIntervalTick(); void send(ByteBuf frame) { @@ -122,22 +145,24 @@ void tryTimeout() { } } - long localLastReceivedPosition() { - return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; - } + void terminate(int terminationState) { + for (; ; ) { + final int state = this.state; - long remoteLastReceivedPosition(ByteBuf keepAliveFrame) { - return KeepAliveFrameCodec.lastPosition(keepAliveFrame); - } + if (state == STOPPED_STATE || state == DISPOSED_STATE) { + return; + } - @Override - public void dispose() { - stop(); + final Disposable disposable = this.ticksDisposable; + if (STATE.compareAndSet(this, state, terminationState)) { + disposable.dispose(); + return; + } + } } - @Override - public boolean isDisposed() { - return ticksDisposable.isDisposed(); + long localLastReceivedPosition() { + return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; } public static final class ClientKeepAliveSupport extends KeepAliveSupport { diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java index 673b4a480..9e76d176d 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -18,51 +18,62 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; -import io.rsocket.Availability; +import java.time.Duration; import reactor.util.annotation.Nullable; /** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ -public interface Lease extends Availability { +public final class Lease { - static Lease create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { - return LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); + public static Lease create( + Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + return new Lease(timeToLive, numberOfRequests, metadata); } - static Lease create(int timeToLiveMillis, int numberOfRequests) { - return create(timeToLiveMillis, numberOfRequests, Unpooled.EMPTY_BUFFER); + public static Lease create(Duration timeToLive, int numberOfRequests) { + return create(timeToLive, numberOfRequests, Unpooled.EMPTY_BUFFER); } - /** - * Number of requests allowed by this lease. - * - * @return The number of requests allowed by this lease. - */ - int getAllowedRequests(); + public static Lease unbounded() { + return unbounded(null); + } - /** - * Initial number of requests allowed by this lease. - * - * @return initial number of requests allowed by this lease. - */ - default int getStartingAllowedRequests() { - throw new UnsupportedOperationException("Not implemented"); + public static Lease unbounded(@Nullable ByteBuf metadata) { + return create(Duration.ofMillis(Integer.MAX_VALUE), Integer.MAX_VALUE, metadata); + } + + public static Lease empty() { + return create(Duration.ZERO, 0); + } + + final int timeToLiveMillis; + final int numberOfRequests; + final ByteBuf metadata; + final long expirationTime; + + Lease(Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + this.numberOfRequests = numberOfRequests; + this.timeToLiveMillis = (int) Math.min(timeToLive.toMillis(), Integer.MAX_VALUE); + this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + this.expirationTime = + timeToLive.isZero() ? 0 : System.currentTimeMillis() + timeToLive.toMillis(); } /** - * Number of milliseconds that this lease is valid from the time it is received. + * Number of requests allowed by this lease. * - * @return Number of milliseconds that this lease is valid from the time it is received. + * @return The number of requests allowed by this lease. */ - int getTimeToLiveMillis(); + public int numberOfRequests() { + return numberOfRequests; + } /** - * Number of milliseconds that this lease is still valid from now. + * Time to live for the given lease * - * @param now millis since epoch - * @return Number of milliseconds that this lease is still valid from now, or 0 if expired. + * @return relative duration in milliseconds */ - default int getRemainingTimeToLiveMillis(long now) { - return isEmpty() ? 0 : (int) Math.max(0, expiry() - now); + public int timeToLiveInMillis() { + return this.timeToLiveMillis; } /** @@ -70,41 +81,29 @@ default int getRemainingTimeToLiveMillis(long now) { * * @return Absolute time since epoch at which this lease will expire. */ - long expiry(); + public long expirationTime() { + return expirationTime; + } /** * Metadata for the lease. * * @return Metadata for the lease. */ - ByteBuf getMetadata(); - - /** - * Checks if the lease is expired now. - * - * @return {@code true} if the lease has expired. - */ - default boolean isExpired() { - return isExpired(System.currentTimeMillis()); - } - - /** - * Checks if the lease is expired for the passed {@code now}. - * - * @param now current time in millis. - * @return {@code true} if the lease has expired. - */ - default boolean isExpired(long now) { - return now > expiry(); - } - - /** Checks if the lease has not expired and there are allowed requests available */ - default boolean isValid() { - return !isExpired() && getAllowedRequests() > 0; + @Nullable + public ByteBuf metadata() { + return metadata; } - /** Checks if the lease is empty(default value if no lease was received yet) */ - default boolean isEmpty() { - return getAllowedRequests() == 0 && getTimeToLiveMillis() == 0; + @Override + public String toString() { + return "Lease{" + + "timeToLiveMillis=" + + timeToLiveMillis + + ", numberOfRequests=" + + numberOfRequests + + ", expirationTime=" + + expirationTime + + '}'; } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java deleted file mode 100644 index 7abb8aab9..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import java.util.concurrent.atomic.AtomicInteger; -import reactor.util.annotation.Nullable; - -public class LeaseImpl implements Lease { - private final int timeToLiveMillis; - private final AtomicInteger allowedRequests; - private final int startingAllowedRequests; - private final ByteBuf metadata; - private final long expiry; - - static LeaseImpl create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { - assertLease(timeToLiveMillis, numberOfRequests); - return new LeaseImpl(timeToLiveMillis, numberOfRequests, metadata); - } - - static LeaseImpl empty() { - return new LeaseImpl(0, 0, null); - } - - private LeaseImpl(int timeToLiveMillis, int allowedRequests, @Nullable ByteBuf metadata) { - this.allowedRequests = new AtomicInteger(allowedRequests); - this.startingAllowedRequests = allowedRequests; - this.timeToLiveMillis = timeToLiveMillis; - this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; - this.expiry = timeToLiveMillis == 0 ? 0 : now() + timeToLiveMillis; - } - - public int getTimeToLiveMillis() { - return timeToLiveMillis; - } - - @Override - public int getAllowedRequests() { - return Math.max(0, allowedRequests.get()); - } - - @Override - public int getStartingAllowedRequests() { - return startingAllowedRequests; - } - - @Override - public ByteBuf getMetadata() { - return metadata; - } - - @Override - public long expiry() { - return expiry; - } - - @Override - public boolean isValid() { - return !isEmpty() && getAllowedRequests() > 0 && !isExpired(); - } - - /** - * try use 1 allowed request of Lease - * - * @return true if used successfully, false if Lease is expired or no allowed requests available - */ - public boolean use() { - if (isExpired()) { - return false; - } - int remaining = - allowedRequests.accumulateAndGet(1, (cur, update) -> Math.max(-1, cur - update)); - return remaining >= 0; - } - - @Override - public double availability() { - return isValid() ? getAllowedRequests() / (double) getStartingAllowedRequests() : 0.0; - } - - @Override - public String toString() { - long now = now(); - return "LeaseImpl{" - + "timeToLiveMillis=" - + timeToLiveMillis - + ", allowedRequests=" - + getAllowedRequests() - + ", startingAllowedRequests=" - + startingAllowedRequests - + ", expired=" - + isExpired(now) - + ", remainingTimeToLiveMillis=" - + getRemainingTimeToLiveMillis(now) - + '}'; - } - - private static long now() { - return System.currentTimeMillis(); - } - - private static void assertLease(int timeToLiveMillis, int numberOfRequests) { - if (numberOfRequests <= 0) { - throw new IllegalArgumentException("Number of requests must be positive"); - } - if (timeToLiveMillis <= 0) { - throw new IllegalArgumentException("Time-to-live must be positive"); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java new file mode 100644 index 000000000..48bd38494 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java @@ -0,0 +1,8 @@ +package io.rsocket.lease; + +import reactor.core.publisher.Flux; + +public interface LeaseSender { + + Flux send(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java deleted file mode 100644 index 791f5a023..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -public interface LeaseStats { - - void onEvent(EventType eventType); - - enum EventType { - ACCEPT, - REJECT, - TERMINATE - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Leases.java b/rsocket-core/src/main/java/io/rsocket/lease/Leases.java deleted file mode 100644 index 4c90e38ce..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/Leases.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import java.util.Objects; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Function; -import reactor.core.publisher.Flux; - -public class Leases { - private static final Function> noopLeaseSender = leaseStats -> Flux.never(); - private static final Consumer> noopLeaseReceiver = leases -> {}; - - private Function> leaseSender = noopLeaseSender; - private Consumer> leaseReceiver = noopLeaseReceiver; - private Optional stats = Optional.empty(); - - public static Leases create() { - return new Leases<>(); - } - - public Leases sender(Function, Flux> leaseSender) { - this.leaseSender = leaseSender; - return this; - } - - public Leases receiver(Consumer> leaseReceiver) { - this.leaseReceiver = leaseReceiver; - return this; - } - - public Leases stats(T stats) { - this.stats = Optional.of(Objects.requireNonNull(stats)); - return this; - } - - @SuppressWarnings("unchecked") - public Function, Flux> sender() { - return (Function, Flux>) leaseSender; - } - - public Consumer> receiver() { - return leaseReceiver; - } - - @SuppressWarnings("unchecked") - public Optional stats() { - return (Optional) stats; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java index 3b6cec62c..84af91b1b 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -16,35 +16,16 @@ package io.rsocket.lease; import io.rsocket.exceptions.RejectedException; -import java.util.Objects; -import reactor.util.annotation.Nullable; public class MissingLeaseException extends RejectedException { private static final long serialVersionUID = -6169748673403858959L; - public MissingLeaseException(Lease lease, String tag) { - super(leaseMessage(Objects.requireNonNull(lease), Objects.requireNonNull(tag))); - } - - public MissingLeaseException(String tag) { - super(leaseMessage(null, Objects.requireNonNull(tag))); + public MissingLeaseException(String message) { + super(message); } @Override public synchronized Throwable fillInStackTrace() { return this; } - - static String leaseMessage(@Nullable Lease lease, String tag) { - if (lease == null) { - return String.format("[%s] Missing leases", tag); - } - if (lease.isEmpty()) { - return String.format("[%s] Lease was not received yet", tag); - } - boolean expired = lease.isExpired(); - int allowedRequests = lease.getAllowedRequests(); - return String.format( - "[%s] Missing leases. Expired: %b, allowedRequests: %d", tag, expired, allowedRequests); - } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java deleted file mode 100644 index fd569a2c8..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.rsocket.Availability; -import io.rsocket.frame.LeaseFrameCodec; -import java.util.function.Consumer; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.ReplayProcessor; - -public interface RequesterLeaseHandler extends Availability, Disposable { - - boolean useLease(); - - Exception leaseError(); - - void receive(ByteBuf leaseFrame); - - void dispose(); - - final class Impl implements RequesterLeaseHandler { - private final String tag; - private final ReplayProcessor receivedLease; - private volatile LeaseImpl currentLease = LeaseImpl.empty(); - - public Impl(String tag, Consumer> leaseReceiver) { - this.tag = tag; - receivedLease = ReplayProcessor.create(1); - leaseReceiver.accept(receivedLease); - } - - @Override - public boolean useLease() { - return currentLease.use(); - } - - @Override - public Exception leaseError() { - LeaseImpl l = this.currentLease; - String t = this.tag; - if (!l.isValid()) { - return new MissingLeaseException(l, t); - } else { - return new MissingLeaseException(t); - } - } - - @Override - public void receive(ByteBuf leaseFrame) { - int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); - int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); - ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); - LeaseImpl lease = LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); - currentLease = lease; - receivedLease.onNext(lease); - } - - @Override - public void dispose() { - receivedLease.onComplete(); - } - - @Override - public boolean isDisposed() { - return receivedLease.isTerminated(); - } - - @Override - public double availability() { - return currentLease.availability(); - } - } - - RequesterLeaseHandler None = - new RequesterLeaseHandler() { - @Override - public boolean useLease() { - return true; - } - - @Override - public Exception leaseError() { - throw new AssertionError("Error not possible with NOOP leases handler"); - } - - @Override - public void receive(ByteBuf leaseFrame) {} - - @Override - public void dispose() {} - - @Override - public double availability() { - return 1.0; - } - }; -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java deleted file mode 100644 index df8787cb7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Availability; -import io.rsocket.frame.LeaseFrameCodec; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Function; -import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.publisher.Flux; -import reactor.util.annotation.Nullable; - -public interface ResponderLeaseHandler extends Availability { - - boolean useLease(); - - Exception leaseError(); - - Disposable send(Consumer leaseFrameSender); - - final class Impl implements ResponderLeaseHandler { - private volatile LeaseImpl currentLease = LeaseImpl.empty(); - private final String tag; - private final ByteBufAllocator allocator; - private final Function, Flux> leaseSender; - private final Optional leaseStatsOption; - private final T leaseStats; - - public Impl( - String tag, - ByteBufAllocator allocator, - Function, Flux> leaseSender, - Optional leaseStatsOption) { - this.tag = tag; - this.allocator = allocator; - this.leaseSender = leaseSender; - this.leaseStatsOption = leaseStatsOption; - this.leaseStats = leaseStatsOption.orElse(null); - } - - @Override - public boolean useLease() { - boolean success = currentLease.use(); - onUseEvent(success, leaseStats); - return success; - } - - @Override - public Exception leaseError() { - LeaseImpl l = currentLease; - String t = tag; - if (!l.isValid()) { - return new MissingLeaseException(l, t); - } else { - return new MissingLeaseException(t); - } - } - - @Override - public Disposable send(Consumer leaseFrameSender) { - return leaseSender - .apply(leaseStatsOption) - .doOnTerminate(this::onTerminateEvent) - .subscribe( - lease -> { - currentLease = create(lease); - leaseFrameSender.accept(createLeaseFrame(lease)); - }); - } - - @Override - public double availability() { - return currentLease.availability(); - } - - private ByteBuf createLeaseFrame(Lease lease) { - return LeaseFrameCodec.encode( - allocator, lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); - } - - private void onTerminateEvent() { - T ls = leaseStats; - if (ls != null) { - ls.onEvent(LeaseStats.EventType.TERMINATE); - } - } - - private void onUseEvent(boolean success, @Nullable T ls) { - if (ls != null) { - LeaseStats.EventType eventType = - success ? LeaseStats.EventType.ACCEPT : LeaseStats.EventType.REJECT; - ls.onEvent(eventType); - } - } - - private static LeaseImpl create(Lease lease) { - if (lease instanceof LeaseImpl) { - return (LeaseImpl) lease; - } else { - return LeaseImpl.create( - lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); - } - } - } - - ResponderLeaseHandler None = - new ResponderLeaseHandler() { - @Override - public boolean useLease() { - return true; - } - - @Override - public Exception leaseError() { - throw new AssertionError("Error not possible with NOOP leases handler"); - } - - @Override - public Disposable send(Consumer leaseFrameSender) { - return Disposables.disposed(); - } - - @Override - public double availability() { - return 1.0; - } - }; -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java new file mode 100644 index 000000000..3e6f68321 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java @@ -0,0 +1,5 @@ +package io.rsocket.lease; + +import io.rsocket.plugins.RequestInterceptor; + +public interface TrackingLeaseSender extends LeaseSender, RequestInterceptor {} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java new file mode 100644 index 000000000..fdbbeb25d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java @@ -0,0 +1,235 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Implementation of {@link WeightedStats} that manages tracking state and exposes the required + * stats. + * + *

A sub-class or a different class (delegation) needs to call {@link #startStream()}, {@link + * #stopStream()}, {@link #startRequest()}, and {@link #stopRequest(long)} to drive state tracking. + * + * @since 1.1 + * @see WeightedStatsRequestInterceptor + */ +public class BaseWeightedStats implements WeightedStats { + + private static final double DEFAULT_LOWER_QUANTILE = 0.5; + private static final double DEFAULT_HIGHER_QUANTILE = 0.8; + private static final int INACTIVITY_FACTOR = 500; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final Ewma availabilityPercentage; + private final Median median; + private final Ewma interArrivalTime; + + private final long tau; + private final long inactivityFactor; + + private long errorStamp; // last we got an error + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private volatile int pendingRequests; // instantaneous rate + private static final AtomicIntegerFieldUpdater PENDING_REQUESTS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingRequests"); + private volatile int pendingStreams; // number of active streams + private static final AtomicIntegerFieldUpdater PENDING_STREAMS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingStreams"); + + protected BaseWeightedStats() { + this( + new FrugalQuantile(DEFAULT_LOWER_QUANTILE), + new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), + INACTIVITY_FACTOR); + } + + private BaseWeightedStats( + Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + + long now = Clock.now(); + this.stamp = now; + this.errorStamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pendingRequests = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.availabilityPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); + this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public int pending() { + return pendingRequests + pendingStreams; + } + + @Override + public double weightedAvailability() { + if (Clock.now() - stamp > tau) { + updateAvailability(1.0); + } + return availabilityPercentage.value(); + } + + @Override + public double predictedLatency() { + final long now = Clock.now(); + final long elapsed; + + synchronized (this) { + elapsed = Math.max(now - stamp, 1L); + } + + final double latency; + final double prediction = median.estimation(); + + final int pending = this.pending(); + if (prediction == 0.0) { + if (pending == 0) { + latency = 0.0; // first request + } else { + // subsequent requests while we don't have any history + latency = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + latency = median.estimation(); + } else { + final double predicted = prediction * pending; + final double instant = instantaneous(now, pending); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + latency = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + latency = prediction; + } + } + + return latency; + } + + long instantaneous(long now, int pending) { + return duration + (now - stamp0) * pending; + } + + void startStream() { + PENDING_STREAMS.incrementAndGet(this); + } + + void stopStream() { + PENDING_STREAMS.decrementAndGet(this); + } + + synchronized long startRequest() { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pendingRequests; + PENDING_REQUESTS.lazySet(this, pendingRequests + 1); + stamp = now; + stamp0 = now; + + return now; + } + + synchronized long stopRequest(long timestamp) { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + duration += Math.max(0, now - stamp0) * pendingRequests - (now - timestamp); + PENDING_REQUESTS.lazySet(this, pendingRequests - 1); + stamp0 = now; + + return now; + } + + synchronized void record(double roundTripTime) { + median.insert(roundTripTime); + lowerQuantile.insert(roundTripTime); + higherQuantile.insert(roundTripTime); + } + + void updateAvailability(double value) { + availabilityPercentage.insert(value); + if (value == 0.0d) { + synchronized (this) { + errorStamp = Clock.now(); + } + } + } + + @Override + public String toString() { + return "Stats{" + + "lowerQuantile=" + + lowerQuantile.estimation() + + ", higherQuantile=" + + higherQuantile.estimation() + + ", inactivityFactor=" + + inactivityFactor + + ", tau=" + + tau + + ", errorPercentage=" + + availabilityPercentage.value() + + ", pending=" + + pendingRequests + + ", errorStamp=" + + errorStamp + + ", stamp=" + + stamp + + ", stamp0=" + + stamp0 + + ", duration=" + + duration + + ", median=" + + median.estimation() + + ", interArrivalTime=" + + interArrivalTime.value() + + ", pendingStreams=" + + pendingStreams + + ", availability=" + + availabilityPercentage.value() + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java new file mode 100644 index 000000000..528f4f896 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.InterceptorRegistry; + +/** + * A {@link LoadbalanceStrategy} with an interest in configuring the {@link RSocketConnector} for + * connecting to load-balance targets in order to hook into request lifecycle and track usage + * statistics. + * + *

Currently this callback interface is supported for strategies configured in {@link + * LoadbalanceRSocketClient}. + * + * @since 1.1 + */ +public interface ClientLoadbalanceStrategy extends LoadbalanceStrategy { + + /** + * Initialize the connector, for example using the {@link InterceptorRegistry}, to intercept + * requests. + * + * @param connector the connector to configure + */ + void initialize(RSocketConnector connector); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java index 4812114dd..0f87f6510 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java @@ -18,6 +18,7 @@ import io.rsocket.util.Clock; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; /** * Compute the exponential weighted moving average of a series of values. The time at which you @@ -28,20 +29,27 @@ * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) */ class Ewma { - private final long tau; - private volatile long stamp; - private volatile double ewma; + + final long tau; + + volatile long stamp; + static final AtomicLongFieldUpdater STAMP = + AtomicLongFieldUpdater.newUpdater(Ewma.class, "stamp"); + volatile double ewma; public Ewma(long halfLife, TimeUnit unit, double initialValue) { this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); - stamp = 0L; - ewma = initialValue; + + this.ewma = initialValue; + + STAMP.lazySet(this, 0L); } public synchronized void insert(double x) { - long now = Clock.now(); - double elapsed = Math.max(0, now - stamp); - stamp = now; + final long now = Clock.now(); + final double elapsed = Math.max(0, now - stamp); + + STAMP.lazySet(this, now); double w = Math.exp(-elapsed / tau); ewma = w * ewma + (1.0 - w) * x; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java index 337edc530..6c2b9c3ea 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java @@ -83,7 +83,7 @@ public final Context currentContext() { @Nullable @Override - public Object scanUnsafe(Attr key) { + public final Object scanUnsafe(Attr key) { long state = this.requested; if (key == Attr.PARENT) { @@ -145,7 +145,7 @@ public final void onNext(Payload payload) { } @Override - public void onError(Throwable t) { + public final void onError(Throwable t) { if (this.done) { Operators.onErrorDropped(t, this.actual.currentContext()); return; @@ -156,7 +156,7 @@ public void onError(Throwable t) { } @Override - public void onComplete() { + public final void onComplete() { if (this.done) { return; } @@ -206,7 +206,7 @@ public final void request(long n) { } } - public void cancel() { + public final void cancel() { long state = REQUESTED.getAndSet(this, STATE_TERMINATED); if (state == STATE_TERMINATED) { return; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java index efa32ff83..cdbdc19b3 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java @@ -26,12 +26,14 @@ *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ */ class FrugalQuantile implements Quantile { - private final double increment; - volatile double estimate; + final double increment; + final SplittableRandom rnd; + int step; int sign; - private double quantile; - private SplittableRandom rnd; + double quantile; + + volatile double estimate; public FrugalQuantile(double quantile, double increment) { this.increment = increment; @@ -63,7 +65,8 @@ public synchronized void insert(double x) { estimate = x; sign = 1; } else { - double v = rnd.nextDouble(); + final double v = rnd.nextDouble(); + final double estimate = this.estimate; if (x > estimate && v > (1 - quantile)) { higher(x); @@ -74,6 +77,8 @@ public synchronized void insert(double x) { } private void higher(double x) { + double estimate = this.estimate; + step += sign * increment; if (step > 0) { @@ -92,9 +97,13 @@ private void higher(double x) { } sign = 1; + + this.estimate = estimate; } private void lower(double x) { + double estimate = this.estimate; + step -= sign * increment; if (step > 0) { @@ -113,6 +122,8 @@ private void lower(double x) { } sign = -1; + + this.estimate = estimate; } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java new file mode 100644 index 000000000..eebf82fe9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java @@ -0,0 +1,1005 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntToLongFunction; +import reactor.util.annotation.Nullable; + +/** A open addressing with linear probing hash map specialised for primitive key and value pairs. */ +class Int2LongHashMap implements Map, Serializable { + static final float DEFAULT_LOAD_FACTOR = 0.55f; + static final int MIN_CAPACITY = 8; + private static final long serialVersionUID = -690554872053575793L; + + private final float loadFactor; + private final long missingValue; + private int resizeThreshold; + private int size = 0; + private final boolean shouldAvoidAllocation; + + private long[] entries; + private KeySet keySet; + private ValueCollection values; + private EntrySet entrySet; + + /** @param missingValue for the map that represents null. */ + public Int2LongHashMap(final long missingValue) { + this(MIN_CAPACITY, DEFAULT_LOAD_FACTOR, missingValue); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + */ + public Int2LongHashMap( + final int initialCapacity, final float loadFactor, final long missingValue) { + this(initialCapacity, loadFactor, missingValue, true); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + * @param shouldAvoidAllocation should allocation be avoided by caching iterators and map entries. + */ + public Int2LongHashMap( + final int initialCapacity, + final float loadFactor, + final long missingValue, + final boolean shouldAvoidAllocation) { + validateLoadFactor(loadFactor); + + this.loadFactor = loadFactor; + this.missingValue = missingValue; + this.shouldAvoidAllocation = shouldAvoidAllocation; + + capacity(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, initialCapacity))); + } + + /** + * The value to be used as a null marker in the map. + * + * @return value to be used as a null marker in the map. + */ + public long missingValue() { + return missingValue; + } + + /** + * Get the load factor applied for resize operations. + * + * @return the load factor applied for resize operations. + */ + public float loadFactor() { + return loadFactor; + } + + /** + * Get the total capacity for the map to which the load factor will be a fraction of. + * + * @return the total capacity for the map. + */ + public int capacity() { + return entries.length >> 1; + } + + /** + * Get the actual threshold which when reached the map will resize. This is a function of the + * current capacity and load factor. + * + * @return the threshold when the map will resize. + */ + public int resizeThreshold() { + return resizeThreshold; + } + + /** {@inheritDoc} */ + public int size() { + return size; + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return size == 0; + } + + /** + * Get a value using provided key avoiding boxing. + * + * @param key lookup key. + * @return value associated with the key or {@link #missingValue()} if key is not found in the + * map. + */ + public long get(final int key) { + final int mask = entries.length - 1; + int index = evenHash(key, mask); + + long value = missingValue; + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + value = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + return value; + } + + /** + * Put a key value pair in the map. + * + * @param key lookup key + * @param value new value, must not be {@link #missingValue()} + * @return previous value associated with the key, or {@link #missingValue()} if none found + * @throws IllegalArgumentException if value is {@link #missingValue()} + */ + public long put(final int key, final long value) { + if (value == missingValue) { + throw new IllegalArgumentException("cannot accept missingValue"); + } + + final int mask = entries.length - 1; + int index = evenHash(key, mask); + long oldValue = missingValue; + + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + oldValue = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + if (oldValue == missingValue) { + ++size; + entries[index] = key; + } + + entries[index + 1] = value; + + increaseCapacity(); + + return oldValue; + } + + private void increaseCapacity() { + if (size > resizeThreshold) { + // entries.length = 2 * capacity + final int newCapacity = entries.length; + rehash(newCapacity); + } + } + + private void rehash(final int newCapacity) { + final long[] oldEntries = entries; + final int length = entries.length; + + capacity(newCapacity); + + final long[] newEntries = entries; + final int mask = entries.length - 1; + + for (int keyIndex = 0; keyIndex < length; keyIndex += 2) { + final long value = oldEntries[keyIndex + 1]; + if (value != missingValue) { + final int key = (int) oldEntries[keyIndex]; + int index = evenHash(key, mask); + + while (newEntries[index + 1] != missingValue) { + index = next(index, mask); + } + + newEntries[index] = key; + newEntries[index + 1] = value; + } + } + } + + /** + * Int primitive specialised containsKey. + * + * @param key the key to check. + * @return true if the map contains key as a key, false otherwise. + */ + public boolean containsKey(final int key) { + return get(key) != missingValue; + } + + /** + * Does the map contain the value. + * + * @param value to be tested against contained values. + * @return true if contained otherwise value. + */ + public boolean containsValue(final long value) { + boolean found = false; + if (value != missingValue) { + final int length = entries.length; + int remaining = size; + + for (int valueIndex = 1; remaining > 0 && valueIndex < length; valueIndex += 2) { + if (missingValue != entries[valueIndex]) { + if (value == entries[valueIndex]) { + found = true; + break; + } + --remaining; + } + } + } + + return found; + } + + /** {@inheritDoc} */ + public void clear() { + if (size > 0) { + Arrays.fill(entries, missingValue); + size = 0; + } + } + + /** + * Compact the backing arrays by rehashing with a capacity just larger than current size and + * giving consideration to the load factor. + */ + public void compact() { + final int idealCapacity = (int) Math.round(size() * (1.0d / loadFactor)); + rehash(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, idealCapacity))); + } + + /** + * Primitive specialised version of {@link #computeIfAbsent(Object, Function)} + * + * @param key to search on. + * @param mappingFunction to provide a value if the get returns null. + * @return the value if found otherwise the missing value. + */ + public long computeIfAbsent(final int key, final IntToLongFunction mappingFunction) { + long value = get(key); + if (value == missingValue) { + value = mappingFunction.applyAsLong(key); + if (value != missingValue) { + put(key, value); + } + } + + return value; + } + + // ---------------- Boxed Versions Below ---------------- + + /** {@inheritDoc} */ + @Nullable + public Long get(final Object key) { + return valOrNull(get((int) key)); + } + + /** {@inheritDoc} */ + public Long put(final Integer key, final Long value) { + return valOrNull(put((int) key, (long) value)); + } + + /** {@inheritDoc} */ + public boolean containsKey(final Object key) { + return containsKey((int) key); + } + + /** {@inheritDoc} */ + public boolean containsValue(final Object value) { + return containsValue((long) value); + } + + /** {@inheritDoc} */ + public void putAll(final Map map) { + for (final Map.Entry entry : map.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + /** {@inheritDoc} */ + public KeySet keySet() { + if (null == keySet) { + keySet = new KeySet(); + } + + return keySet; + } + + /** {@inheritDoc} */ + public ValueCollection values() { + if (null == values) { + values = new ValueCollection(); + } + + return values; + } + + /** {@inheritDoc} */ + public EntrySet entrySet() { + if (null == entrySet) { + entrySet = new EntrySet(); + } + + return entrySet; + } + + /** {@inheritDoc} */ + @Nullable + public Long remove(final Object key) { + return valOrNull(remove((int) key)); + } + + /** + * Remove value from the map using given key avoiding boxing. + * + * @param key whose mapping is to be removed from the map. + * @return removed value or {@link #missingValue()} if key was not found in the map. + */ + public long remove(final int key) { + final int mask = entries.length - 1; + int keyIndex = evenHash(key, mask); + + long oldValue = missingValue; + while (entries[keyIndex + 1] != missingValue) { + if (entries[keyIndex] == key) { + oldValue = entries[keyIndex + 1]; + entries[keyIndex + 1] = missingValue; + size--; + + compactChain(keyIndex); + + break; + } + + keyIndex = next(keyIndex, mask); + } + + return oldValue; + } + + @SuppressWarnings("FinalParameters") + private void compactChain(int deleteKeyIndex) { + final int mask = entries.length - 1; + int keyIndex = deleteKeyIndex; + + while (true) { + keyIndex = next(keyIndex, mask); + if (entries[keyIndex + 1] == missingValue) { + break; + } + + final int hash = evenHash((int) entries[keyIndex], mask); + + if ((keyIndex < hash && (hash <= deleteKeyIndex || deleteKeyIndex <= keyIndex)) + || (hash <= deleteKeyIndex && deleteKeyIndex <= keyIndex)) { + entries[deleteKeyIndex] = entries[keyIndex]; + entries[deleteKeyIndex + 1] = entries[keyIndex + 1]; + + entries[keyIndex + 1] = missingValue; + deleteKeyIndex = keyIndex; + } + } + } + + /** + * Get the minimum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the minimum value stored in the map. + */ + public long minValue() { + final long missingValue = this.missingValue; + long min = size == 0 ? missingValue : Long.MAX_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + min = Math.min(min, value); + } + } + + return min; + } + + /** + * Get the maximum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the maximum value stored in the map. + */ + public long maxValue() { + final long missingValue = this.missingValue; + long max = size == 0 ? missingValue : Long.MIN_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + max = Math.max(max, value); + } + } + + return max; + } + + /** {@inheritDoc} */ + public String toString() { + if (isEmpty()) { + return "{}"; + } + + final EntryIterator entryIterator = new EntryIterator(); + entryIterator.reset(); + + final StringBuilder sb = new StringBuilder().append('{'); + while (true) { + entryIterator.next(); + sb.append(entryIterator.getIntKey()).append('=').append(entryIterator.getLongValue()); + if (!entryIterator.hasNext()) { + return sb.append('}').toString(); + } + sb.append(',').append(' '); + } + } + + /** + * Primitive specialised version of {@link #replace(Object, Object)} + * + * @param key key with which the specified value is associated + * @param value value to be associated with the specified key + * @return the previous value associated with the specified key, or {@link #missingValue()} if + * there was no mapping for the key. + */ + public long replace(final int key, final long value) { + long currentValue = get(key); + if (currentValue != missingValue) { + currentValue = put(key, value); + } + + return currentValue; + } + + /** + * Primitive specialised version of {@link #replace(Object, Object, Object)} + * + * @param key key with which the specified value is associated + * @param oldValue value expected to be associated with the specified key + * @param newValue value to be associated with the specified key + * @return {@code true} if the value was replaced + */ + public boolean replace(final int key, final long oldValue, final long newValue) { + final long curValue = get(key); + if (curValue != oldValue || curValue == missingValue) { + return false; + } + + put(key, newValue); + + return true; + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Map)) { + return false; + } + + final Map that = (Map) o; + + return size == that.size() && entrySet().equals(that.entrySet()); + } + + public int hashCode() { + return entrySet().hashCode(); + } + + private static int next(final int index, final int mask) { + return (index + 2) & mask; + } + + private void capacity(final int newCapacity) { + final int entriesLength = newCapacity * 2; + if (entriesLength < 0) { + throw new IllegalStateException("max capacity reached at size=" + size); + } + + /*@DoNotSub*/ resizeThreshold = (int) (newCapacity * loadFactor); + entries = new long[entriesLength]; + Arrays.fill(entries, missingValue); + } + + @Nullable + private Long valOrNull(final long value) { + return value == missingValue ? null : value; + } + + // ---------------- Utility Classes ---------------- + + /** Base iterator implementation. */ + abstract class AbstractIterator implements Serializable { + private static final long serialVersionUID = 5262459454112462433L; + /** Is current position valid. */ + protected boolean isPositionValid = false; + + private int remaining; + private int positionCounter; + private int stopCounter; + + final void reset() { + isPositionValid = false; + remaining = Int2LongHashMap.this.size; + final long missingValue = Int2LongHashMap.this.missingValue; + final long[] entries = Int2LongHashMap.this.entries; + final int capacity = entries.length; + + int keyIndex = capacity; + if (entries[capacity - 1] != missingValue) { + for (int i = 1; i < capacity; i += 2) { + if (entries[i] == missingValue) { + keyIndex = i - 1; + break; + } + } + } + + stopCounter = keyIndex; + positionCounter = keyIndex + capacity; + } + + /** + * Returns position of the key of the current entry. + * + * @return key position. + */ + protected final int keyPosition() { + return positionCounter & entries.length - 1; + } + + /** + * Number of remaining elements. + * + * @return number of remaining elements. + */ + public int remaining() { + return remaining; + } + + /** + * Check if there are more elements remaining. + * + * @return {@code true} if {@code remaining > 0}. + */ + public boolean hasNext() { + return remaining > 0; + } + + /** + * Advance to the next entry. + * + * @throws NoSuchElementException if no more entries available. + */ + protected final void findNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final long[] entries = Int2LongHashMap.this.entries; + final long missingValue = Int2LongHashMap.this.missingValue; + final int mask = entries.length - 1; + + for (int keyIndex = positionCounter - 2; keyIndex >= stopCounter; keyIndex -= 2) { + final int index = keyIndex & mask; + if (entries[index + 1] != missingValue) { + isPositionValid = true; + positionCounter = keyIndex; + --remaining; + return; + } + } + + isPositionValid = false; + throw new IllegalStateException(); + } + + /** {@inheritDoc} */ + public void remove() { + if (isPositionValid) { + final int position = keyPosition(); + entries[position + 1] = missingValue; + --size; + + compactChain(position); + + isPositionValid = false; + } else { + throw new IllegalStateException(); + } + } + } + + /** Iterator over keys which supports access to unboxed keys via {@link #nextValue()}. */ + public final class KeyIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = 9151493609653852972L; + + public Integer next() { + return nextValue(); + } + + /** + * Return next key. + * + * @return next key. + */ + public int nextValue() { + findNext(); + return (int) entries[keyPosition()]; + } + } + + /** Iterator over values which supports access to unboxed values. */ + public final class ValueIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = -5670291734793552927L; + + public Long next() { + return nextValue(); + } + + /** + * Return next value. + * + * @return next value. + */ + public long nextValue() { + findNext(); + return entries[keyPosition() + 1]; + } + } + + /** Iterator over entries which supports access to unboxed keys and values. */ + public final class EntryIterator extends AbstractIterator + implements Iterator>, Entry, Serializable { + private static final long serialVersionUID = 1744408438593481051L; + + public Integer getKey() { + return getIntKey(); + } + + /** + * Returns the key of the current entry. + * + * @return the key. + */ + public int getIntKey() { + return (int) entries[keyPosition()]; + } + + public Long getValue() { + return getLongValue(); + } + + /** + * Returns the value of the current entry. + * + * @return the value. + */ + public long getLongValue() { + return entries[keyPosition() + 1]; + } + + public Long setValue(final Long value) { + return setValue(value.longValue()); + } + + /** + * Sets the value of the current entry. + * + * @param value to be set. + * @return previous value of the entry. + */ + public long setValue(final long value) { + if (!isPositionValid) { + throw new IllegalStateException(); + } + + if (missingValue == value) { + throw new IllegalArgumentException(); + } + + final int keyPosition = keyPosition(); + final long prevValue = entries[keyPosition + 1]; + entries[keyPosition + 1] = value; + return prevValue; + } + + public Entry next() { + findNext(); + + if (shouldAvoidAllocation) { + return this; + } + + return allocateDuplicateEntry(); + } + + private Entry allocateDuplicateEntry() { + return new MapEntry(getIntKey(), getLongValue()); + } + + /** {@inheritDoc} */ + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Entry)) { + return false; + } + + final Entry that = (Entry) o; + + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + /** An {@link java.util.Map.Entry} implementation. */ + public final class MapEntry implements Entry { + private final int k; + private final long v; + + /** + * Constructs entry with given key and value. + * + * @param k key. + * @param v value. + */ + public MapEntry(final int k, final long v) { + this.k = k; + this.v = v; + } + + public Integer getKey() { + return k; + } + + public Long getValue() { + return v; + } + + public Long setValue(final Long value) { + return Int2LongHashMap.this.put(k, value.longValue()); + } + + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + public boolean equals(final Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + + final Entry e = (Entry) o; + + return (e.getKey() != null && e.getValue() != null) + && (e.getKey().equals(k) && e.getValue().equals(v)); + } + + public String toString() { + return k + "=" + v; + } + } + } + + /** Set of keys which supports optional cached iterators to avoid allocation. */ + public final class KeySet extends AbstractSet implements Serializable { + private static final long serialVersionUID = -7645453993079742625L; + private final KeyIterator keyIterator = shouldAvoidAllocation ? new KeyIterator() : null; + + /** {@inheritDoc} */ + public KeyIterator iterator() { + KeyIterator keyIterator = this.keyIterator; + if (null == keyIterator) { + keyIterator = new KeyIterator(); + } + + keyIterator.reset(); + + return keyIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((int) o); + } + + /** + * Checks if key is contained in the map without boxing. + * + * @param key to check. + * @return {@code true} if key is contained in this map. + */ + public boolean contains(final int key) { + return containsKey(key); + } + } + + /** Collection of values which supports optionally cached iterators to avoid allocation. */ + public final class ValueCollection extends AbstractCollection implements Serializable { + private static final long serialVersionUID = -8925598924781601919L; + private final ValueIterator valueIterator = shouldAvoidAllocation ? new ValueIterator() : null; + + /** {@inheritDoc} */ + public ValueIterator iterator() { + ValueIterator valueIterator = this.valueIterator; + if (null == valueIterator) { + valueIterator = new ValueIterator(); + } + + valueIterator.reset(); + + return valueIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((long) o); + } + + /** + * Checks if the value is contained in the map. + * + * @param value to be checked. + * @return {@code true} if value is contained in this map. + */ + public boolean contains(final long value) { + return containsValue(value); + } + } + + /** Set of entries which supports optionally cached iterators to avoid allocation. */ + public final class EntrySet extends AbstractSet> + implements Serializable { + private static final long serialVersionUID = 63641283589916174L; + private final EntryIterator entryIterator = shouldAvoidAllocation ? new EntryIterator() : null; + + /** {@inheritDoc} */ + public EntryIterator iterator() { + EntryIterator entryIterator = this.entryIterator; + if (null == entryIterator) { + entryIterator = new EntryIterator(); + } + + entryIterator.reset(); + + return entryIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + if (!(o instanceof Entry)) { + return false; + } + final Entry entry = (Entry) o; + final Long value = get(entry.getKey()); + + return value != null && value.equals(entry.getValue()); + } + + /** {@inheritDoc} */ + public Object[] toArray() { + return toArray(new Object[size()]); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + public T[] toArray(final T[] a) { + final T[] array = + a.length >= size + ? a + : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); + final EntryIterator it = iterator(); + + for (int i = 0; i < array.length; i++) { + if (it.hasNext()) { + it.next(); + array[i] = (T) it.allocateDuplicateEntry(); + } else { + array[i] = null; + break; + } + } + + return array; + } + } + + private static int evenHash(final int value, final int mask) { + final int hash = (value << 1) - (value << 8); + + return hash & mask; + } + + private static void validateLoadFactor(final float loadFactor) { + if (loadFactor < 0.1f || loadFactor > 0.9f) { + throw new IllegalArgumentException( + "load factor must be in the range of 0.1 to 0.9: " + loadFactor); + } + } + + private static int findNextPositivePowerOfTwo(final int value) { + return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java index 89ae01f18..d59cbb86e 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import io.rsocket.RSocket; import io.rsocket.core.RSocketClient; import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; import java.util.List; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -26,8 +27,8 @@ import reactor.util.annotation.Nullable; /** - * {@link RSocketClient} implementation that uses a {@link LoadbalanceStrategy} to select the {@code - * RSocket} to use for a given request from a pool of possible targets. + * An implementation of {@link RSocketClient} backed by a pool of {@code RSocket} instances and + * using a {@link LoadbalanceStrategy} to select the {@code RSocket} to use for a given request. * * @since 1.1 */ @@ -39,6 +40,17 @@ private LoadbalanceRSocketClient(RSocketPool rSocketPool) { this.rSocketPool = rSocketPool; } + @Override + public Mono onClose() { + return rSocketPool.onClose(); + } + + @Override + public boolean connect() { + return rSocketPool.connect(); + } + + /** Return {@code Mono} that selects an RSocket from the underlying pool. */ @Override public Mono source() { return Mono.fromSupplier(rSocketPool::select); @@ -61,7 +73,7 @@ public Flux requestStream(Mono payloadMono) { @Override public Flux requestChannel(Publisher payloads) { - return rSocketPool.select().requestChannel(payloads); + return source().flatMapMany(rSocket -> rSocket.requestChannel(payloads)); } @Override @@ -75,7 +87,7 @@ public void dispose() { } /** - * Shortcut to create an {@link LoadbalanceRSocketClient} with round robin loadalancing. + * Shortcut to create an {@link LoadbalanceRSocketClient} with round-robin load balancing. * Effectively a shortcut for: * *

@@ -84,8 +96,8 @@ public void dispose() {
    *    .build();
    * 
* - * @param connector the {@link Builder#connector(RSocketConnector) to use - * @param targetPublisher publisher that periodically refreshes the list of targets to loadbalance across. + * @param connector a "template" for connecting to load balance targets + * @param targetPublisher refreshes the list of load balance targets periodically * @return the created client instance */ public static LoadbalanceRSocketClient create( @@ -94,11 +106,10 @@ public static LoadbalanceRSocketClient create( } /** - * Return a builder to create an {@link LoadbalanceRSocketClient} with. + * Return a builder for a {@link LoadbalanceRSocketClient}. * - * @param targetPublisher publisher that periodically refreshes the list of targets to loadbalance - * across. - * @return the builder instance + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created builder */ public static Builder builder(Publisher> targetPublisher) { return new Builder(targetPublisher); @@ -118,10 +129,11 @@ public static class Builder { } /** - * The given {@link RSocketConnector} is used as a template to produce the {@code Mono} - * source for each {@link LoadbalanceTarget}. This is done by passing the {@code - * ClientTransport} contained in every target to the {@code connect} method of the given - * connector instance. + * Configure the "template" connector to use for connecting to load balance targets. To + * establish a connection, the {@link LoadbalanceTarget#getTransport() ClientTransport} + * contained in each target is passed to the connector's {@link + * RSocketConnector#connect(ClientTransport) connect} method and thus the same connector with + * the same settings applies to all targets. * *

By default this is initialized with {@link RSocketConnector#create()}. * @@ -133,7 +145,7 @@ public Builder connector(RSocketConnector connector) { } /** - * Switch to using a round-robin strategy for selecting a target. + * Configure {@link RoundRobinLoadbalanceStrategy} as the strategy to use to select targets. * *

This is the strategy used by default. */ @@ -143,18 +155,17 @@ public Builder roundRobinLoadbalanceStrategy() { } /** - * Switch to using a strategy that assigns a weight to each pooled {@code RSocket} based on - * actual usage stats, and uses that to make a choice. + * Configure {@link WeightedLoadbalanceStrategy} as the strategy to use to select targets. * *

By default, {@link RoundRobinLoadbalanceStrategy} is used. */ public Builder weightedLoadbalanceStrategy() { - this.loadbalanceStrategy = new WeightedLoadbalanceStrategy(); + this.loadbalanceStrategy = WeightedLoadbalanceStrategy.create(); return this; } /** - * Provide the {@link LoadbalanceStrategy} to use. + * Configure the {@link LoadbalanceStrategy} to use. * *

By default, {@link RoundRobinLoadbalanceStrategy} is used. */ @@ -165,18 +176,20 @@ public Builder loadbalanceStrategy(LoadbalanceStrategy strategy) { /** Build the {@link LoadbalanceRSocketClient} instance. */ public LoadbalanceRSocketClient build() { - return new LoadbalanceRSocketClient( - new RSocketPool(initConnector(), this.targetPublisher, initLoadbalanceStrategy())); - } + final RSocketConnector connector = + (this.connector != null ? this.connector : RSocketConnector.create()); - private RSocketConnector initConnector() { - return (this.connector != null ? this.connector : RSocketConnector.create()); - } + final LoadbalanceStrategy strategy = + (this.loadbalanceStrategy != null + ? this.loadbalanceStrategy + : new RoundRobinLoadbalanceStrategy()); - private LoadbalanceStrategy initLoadbalanceStrategy() { - return (this.loadbalanceStrategy != null - ? this.loadbalanceStrategy - : new RoundRobinLoadbalanceStrategy()); + if (strategy instanceof ClientLoadbalanceStrategy) { + ((ClientLoadbalanceStrategy) strategy).initialize(connector); + } + + return new LoadbalanceRSocketClient( + new RSocketPool(connector, this.targetPublisher, strategy)); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java index 2a333959b..5662448e7 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,8 +18,21 @@ import io.rsocket.RSocket; import java.util.List; +/** + * Strategy to select an {@link RSocket} given a list of instances for load-balancing purposes. A + * simple implementation might go in round-robin fashion while a more sophisticated strategy might + * check availability, track usage stats, and so on. + * + * @since 1.1 + */ @FunctionalInterface public interface LoadbalanceStrategy { - RSocket select(List availableRSockets); + /** + * Select an {@link RSocket} from the given non-empty list. + * + * @param sockets the list to choose from + * @return the selected instance + */ + RSocket select(List sockets); } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java index e99914caa..3b5d71e4e 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2020 the original author or authors. + * Copyright 2002-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,14 +15,18 @@ */ package io.rsocket.loadbalance; +import io.rsocket.core.RSocketConnector; import io.rsocket.transport.ClientTransport; +import org.reactivestreams.Publisher; /** - * Simple container for a key and a {@link ClientTransport}, representing a specific target for - * loadbalancing purposes. The key is used to compare previous and new targets when refreshing the - * list of target to use. The transport is used to connect to the target. + * Representation for a load-balance target used as input to {@link LoadbalanceRSocketClient} that + * in turn maintains and peridodically updates a list of current load-balance targets. The {@link + * #getKey()} is used to identify a target uniquely while the {@link #getTransport() transport} is + * used to connect to the target server. * * @since 1.1 + * @see LoadbalanceRSocketClient#create(RSocketConnector, Publisher) */ public class LoadbalanceTarget { @@ -34,23 +38,22 @@ private LoadbalanceTarget(String key, ClientTransport transport) { this.transport = transport; } - /** Return the key for this target. */ + /** Return the key that identifies this target uniquely. */ public String getKey() { return key; } - /** Return the transport to use to connect to the target. */ + /** Return the transport to use to connect to the target server. */ public ClientTransport getTransport() { return transport; } /** - * Create a an instance of {@link LoadbalanceTarget} with the given key and {@link - * ClientTransport}. The key can be anything that can be used to identify identical targets, e.g. - * a SocketAddress, URL, etc. + * Create a new {@link LoadbalanceTarget} with the given key and {@link ClientTransport}. The key + * can be anything that identifies the target uniquely, e.g. SocketAddress, URL, and so on. * - * @param key the key to use to identify identical targets - * @param transport the transport to use for connecting to the target + * @param key identifies the load-balance target uniquely + * @param transport for connecting to the target * @return the created instance */ public static LoadbalanceTarget from(String key, ClientTransport transport) { diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java index 833bd5380..5319706f9 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java @@ -18,6 +18,7 @@ /** This implementation gives better results because it considers more data-point. */ class Median extends FrugalQuantile { + public Median() { super(0.5, 1.0); } @@ -32,6 +33,7 @@ public synchronized void insert(double x) { estimate = x; sign = 1; } else { + final double estimate = this.estimate; if (x > estimate) { greaterThanZero(x); } else if (x < estimate) { @@ -41,6 +43,8 @@ public synchronized void insert(double x) { } private void greaterThanZero(double x) { + double estimate = this.estimate; + step += sign; if (step > 0) { @@ -59,9 +63,13 @@ private void greaterThanZero(double x) { } sign = 1; + + this.estimate = estimate; } private void lessThanZero(double x) { + double estimate = this.estimate; + step -= sign; if (step > 0) { @@ -80,6 +88,8 @@ private void lessThanZero(double x) { } sign = -1; + + this.estimate = estimate; } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java index b37ec4b47..69838f1b6 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java @@ -59,7 +59,7 @@ abstract class MonoDeferredResolution extends Mono } @Override - public void subscribe(CoreSubscriber actual) { + public final void subscribe(CoreSubscriber actual) { if (this.requested == STATE_UNSUBSCRIBED && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { @@ -145,7 +145,7 @@ public final void onNext(RESULT payload) { } @Override - public void onError(Throwable t) { + public final void onError(Throwable t) { if (this.done) { Operators.onErrorDropped(t, this.actual.currentContext()); return; @@ -156,7 +156,7 @@ public void onError(Throwable t) { } @Override - public void onComplete() { + public final void onComplete() { if (this.done) { return; } @@ -206,7 +206,7 @@ public final void request(long n) { } } - public void cancel() { + public final void cancel() { long state = REQUESTED.getAndSet(this, STATE_TERMINATED); if (state == STATE_TERMINATED) { return; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java similarity index 54% rename from rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java rename to rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java index ad681087e..a77329d31 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledWeightedRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,31 +26,29 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.util.context.Context; -/** Default implementation of {@link WeightedRSocket} stored in {@link RSocketPool} */ -final class PooledWeightedRSocket extends ResolvingOperator - implements CoreSubscriber, WeightedRSocket { +/** Default implementation of {@link RSocket} stored in {@link RSocketPool} */ +final class PooledRSocket extends ResolvingOperator + implements CoreSubscriber, RSocket { final RSocketPool parent; final Mono rSocketSource; final LoadbalanceTarget loadbalanceTarget; - final Stats stats; + final Sinks.Empty onCloseSink; volatile Subscription s; - static final AtomicReferenceFieldUpdater S = - AtomicReferenceFieldUpdater.newUpdater(PooledWeightedRSocket.class, Subscription.class, "s"); + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(PooledRSocket.class, Subscription.class, "s"); - PooledWeightedRSocket( - RSocketPool parent, - Mono rSocketSource, - LoadbalanceTarget loadbalanceTarget, - Stats stats) { + PooledRSocket( + RSocketPool parent, Mono rSocketSource, LoadbalanceTarget loadbalanceTarget) { this.parent = parent; this.rSocketSource = rSocketSource; this.loadbalanceTarget = loadbalanceTarget; - this.stats = stats; + this.onCloseSink = Sinks.unsafe().empty(); } @Override @@ -90,8 +88,8 @@ public void onError(Throwable t) { } this.doFinally(); - // terminate upstream which means retryBackoff has exhausted - this.terminate(t); + // terminate upstream (retryBackoff has exhausted) and remove from the parent target list + this.doCleanup(t); } @Override @@ -113,27 +111,19 @@ protected void doSubscribe() { @Override protected void doOnValueResolved(RSocket value) { - stats.setAvailability(1.0); - value.onClose().subscribe(null, t -> this.invalidate(), this::invalidate); + value.onClose().subscribe(null, this::doCleanup, () -> doCleanup(ON_DISPOSE)); } - @Override - protected void doOnValueExpired(RSocket value) { - stats.setAvailability(0.0); - value.dispose(); - this.dispose(); - } + void doCleanup(Throwable t) { + if (isDisposed()) { + return; + } - @Override - public void dispose() { - super.dispose(); - } + this.terminate(t); - @Override - protected void doOnDispose() { final RSocketPool parent = this.parent; for (; ; ) { - final PooledWeightedRSocket[] sockets = parent.activeSockets; + final PooledRSocket[] sockets = parent.activeSockets; final int activeSocketsCount = sockets.length; int index = -1; @@ -148,74 +138,94 @@ protected void doOnDispose() { break; } - final int lastIndex = activeSocketsCount - 1; - final PooledWeightedRSocket[] newSockets = new PooledWeightedRSocket[lastIndex]; - if (index != 0) { - System.arraycopy(sockets, 0, newSockets, 0, index); - } + final PooledRSocket[] newSockets; + if (activeSocketsCount == 1) { + newSockets = RSocketPool.EMPTY; + } else { + final int lastIndex = activeSocketsCount - 1; - if (index != lastIndex) { - System.arraycopy(sockets, index + 1, newSockets, index, lastIndex - index); + newSockets = new PooledRSocket[lastIndex]; + if (index != 0) { + System.arraycopy(sockets, 0, newSockets, 0, index); + } + + if (index != lastIndex) { + System.arraycopy(sockets, index + 1, newSockets, index, lastIndex - index); + } } if (RSocketPool.ACTIVE_SOCKETS.compareAndSet(parent, sockets, newSockets)) { break; } } - stats.setAvailability(0.0); + + if (t == ON_DISPOSE) { + this.onCloseSink.tryEmitEmpty(); + } else { + this.onCloseSink.tryEmitError(t); + } + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { Operators.terminate(S, this); + + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onCloseSink::tryEmitError, onCloseSink::tryEmitEmpty); + } else { + onCloseSink.tryEmitEmpty(); + } } @Override public Mono fireAndForget(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.REQUEST_FNF); + return new MonoInner<>(this, payload, FrameType.REQUEST_FNF); } @Override public Mono requestResponse(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); + return new MonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); } @Override public Flux requestStream(Payload payload) { - return new RequestTrackingFluxInner<>(this, payload, FrameType.REQUEST_STREAM); + return new FluxInner<>(this, payload, FrameType.REQUEST_STREAM); } @Override public Flux requestChannel(Publisher payloads) { - return new RequestTrackingFluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); + return new FluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); } @Override public Mono metadataPush(Payload payload) { - return new RequestTrackingMonoInner<>(this, payload, FrameType.METADATA_PUSH); + return new MonoInner<>(this, payload, FrameType.METADATA_PUSH); } - /** - * Indicates number of active requests - * - * @return number of requests in progress - */ - @Override - public Stats stats() { - return stats; + LoadbalanceTarget target() { + return this.loadbalanceTarget; } - LoadbalanceTarget target() { - return loadbalanceTarget; + @Override + public Mono onClose() { + return this.onCloseSink.asMono(); } @Override public double availability() { - return stats.availability(); + final RSocket socket = valueIfResolved(); + return socket != null ? socket.availability() : 0.0d; } - static final class RequestTrackingMonoInner - extends MonoDeferredResolution { - - long startTime; + static final class MonoInner extends MonoDeferredResolution { - RequestTrackingMonoInner(PooledWeightedRSocket parent, Payload payload, FrameType requestType) { + MonoInner(PooledRSocket parent, Payload payload, FrameType requestType) { super(parent, payload, requestType); } @@ -249,58 +259,16 @@ public void accept(RSocket rSocket, Throwable t) { return; } - startTime = ((PooledWeightedRSocket) parent).stats.startRequest(); - source.subscribe((CoreSubscriber) this); } else { - parent.add(this); - } - } - - @Override - public void onComplete() { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - final Stats stats = ((PooledWeightedRSocket) parent).stats; - final long now = stats.stopRequest(startTime); - stats.record(now - startTime); - super.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - Stats stats = ((PooledWeightedRSocket) parent).stats; - stats.stopRequest(startTime); - stats.recordError(0.0); - super.onError(t); - } - } - - @Override - public void cancel() { - long state = REQUESTED.getAndSet(this, STATE_TERMINATED); - if (state == STATE_TERMINATED) { - return; - } - - if (state == STATE_SUBSCRIBED) { - this.s.cancel(); - ((PooledWeightedRSocket) parent).stats.stopRequest(startTime); - } else { - this.parent.remove(this); - ReferenceCountUtil.safeRelease(this.payload); + parent.observe(this); } } } - static final class RequestTrackingFluxInner - extends FluxDeferredResolution { + static final class FluxInner extends FluxDeferredResolution { - RequestTrackingFluxInner( - PooledWeightedRSocket parent, INPUT fluxOrPayload, FrameType requestType) { + FluxInner(PooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) { super(parent, fluxOrPayload, requestType); } @@ -333,47 +301,9 @@ public void accept(RSocket rSocket, Throwable t) { return; } - ((PooledWeightedRSocket) parent).stats.startStream(); - source.subscribe(this); } else { - parent.add(this); - } - } - - @Override - public void onComplete() { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - ((PooledWeightedRSocket) parent).stats.stopStream(); - super.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - final long state = this.requested; - if (state != TERMINATED_STATE && REQUESTED.compareAndSet(this, state, TERMINATED_STATE)) { - ((PooledWeightedRSocket) parent).stats.stopStream(); - super.onError(t); - } - } - - @Override - public void cancel() { - long state = REQUESTED.getAndSet(this, STATE_TERMINATED); - if (state == STATE_TERMINATED) { - return; - } - - if (state == STATE_SUBSCRIBED) { - this.s.cancel(); - ((PooledWeightedRSocket) parent).stats.stopStream(); - } else { - this.parent.remove(this); - if (requestType == FrameType.REQUEST_STREAM) { - ReferenceCountUtil.safeRelease(this.fluxOrPayload); - } + parent.observe(this); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java index dbd05abcb..59d9678d0 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package io.rsocket.loadbalance; import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketConnector; @@ -28,35 +29,32 @@ import java.util.ListIterator; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Supplier; +import java.util.stream.Collectors; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; -class RSocketPool extends ResolvingOperator - implements CoreSubscriber>, List { +class RSocketPool extends ResolvingOperator + implements CoreSubscriber>, Closeable { + static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = + AtomicReferenceFieldUpdater.newUpdater( + RSocketPool.class, PooledRSocket[].class, "activeSockets"); + static final PooledRSocket[] EMPTY = new PooledRSocket[0]; + static final PooledRSocket[] TERMINATED = new PooledRSocket[0]; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(RSocketPool.class, Subscription.class, "s"); final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this); final RSocketConnector connector; final LoadbalanceStrategy loadbalanceStrategy; - final Supplier statsSupplier; - - volatile PooledWeightedRSocket[] activeSockets; - - static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = - AtomicReferenceFieldUpdater.newUpdater( - RSocketPool.class, PooledWeightedRSocket[].class, "activeSockets"); - - static final PooledWeightedRSocket[] EMPTY = new PooledWeightedRSocket[0]; - static final PooledWeightedRSocket[] TERMINATED = new PooledWeightedRSocket[0]; - + final Sinks.Empty onAllClosedSink = Sinks.unsafe().empty(); + volatile PooledRSocket[] activeSockets; volatile Subscription s; - static final AtomicReferenceFieldUpdater S = - AtomicReferenceFieldUpdater.newUpdater(RSocketPool.class, Subscription.class, "s"); public RSocketPool( RSocketConnector connector, @@ -64,17 +62,17 @@ public RSocketPool( LoadbalanceStrategy loadbalanceStrategy) { this.connector = connector; this.loadbalanceStrategy = loadbalanceStrategy; - if (loadbalanceStrategy instanceof WeightedLoadbalanceStrategy) { - this.statsSupplier = Stats::create; - } else { - this.statsSupplier = Stats::noOps; - } ACTIVE_SOCKETS.lazySet(this, EMPTY); targetPublisher.subscribe(this); } + @Override + public Mono onClose() { + return onAllClosedSink.asMono(); + } + @Override protected void doOnDispose() { Operators.terminate(S, this); @@ -83,6 +81,14 @@ protected void doOnDispose() { for (RSocket rSocket : activeSockets) { rSocket.dispose(); } + + if (activeSockets.length > 0) { + Mono.whenDelayError( + Arrays.stream(activeSockets).map(RSocket::onClose).collect(Collectors.toList())) + .subscribe(null, onAllClosedSink::tryEmitError, onAllClosedSink::tryEmitEmpty); + } else { + onAllClosedSink.tryEmitEmpty(); + } } @Override @@ -92,90 +98,87 @@ public void onSubscribe(Subscription s) { } } - /** - * This operation should happen rarely relatively compares the number of the {@link #select()} - * method invocations, therefore it is acceptable to have it algorithmically inefficient. The - * algorithmic complexity of this method is - * - * @param targets set which represents RSocket targets to balance on - */ @Override public void onNext(List targets) { if (isDisposed()) { return; } - PooledWeightedRSocket[] previouslyActiveSockets; - PooledWeightedRSocket[] activeSockets; + // This operation should happen less frequently than calls to select() (which are per request) + // and therefore it is acceptable somewhat less efficient. + + PooledRSocket[] previouslyActiveSockets; + PooledRSocket[] inactiveSockets; + PooledRSocket[] socketsToUse; for (; ; ) { - HashMap rSocketSuppliersCopy = new HashMap<>(); + HashMap rSocketSuppliersCopy = new HashMap<>(targets.size()); int j = 0; for (LoadbalanceTarget target : targets) { rSocketSuppliersCopy.put(target, j++); } - // checking intersection of active RSocket with the newly received set + // Intersect current and new list of targets and find the ones to keep vs dispose previouslyActiveSockets = this.activeSockets; - PooledWeightedRSocket[] nextActiveSockets = - new PooledWeightedRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; - int position = 0; + inactiveSockets = new PooledRSocket[previouslyActiveSockets.length]; + PooledRSocket[] nextActiveSockets = + new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; + int activeSocketsPosition = 0; + int inactiveSocketsPosition = 0; for (int i = 0; i < previouslyActiveSockets.length; i++) { - PooledWeightedRSocket rSocket = previouslyActiveSockets[i]; + PooledRSocket rSocket = previouslyActiveSockets[i]; Integer index = rSocketSuppliersCopy.remove(rSocket.target()); if (index == null) { // if one of the active rSockets is not included, we remove it and put in the // pending removal if (!rSocket.isDisposed()) { - rSocket.dispose(); + inactiveSockets[inactiveSocketsPosition++] = rSocket; // TODO: provide a meaningful algo for keeping removed rsocket in the list // nextActiveSockets[position++] = rSocket; } } else { if (!rSocket.isDisposed()) { // keep old RSocket instance - nextActiveSockets[position++] = rSocket; + nextActiveSockets[activeSocketsPosition++] = rSocket; } else { // put newly create RSocket instance LoadbalanceTarget target = targets.get(index); - nextActiveSockets[position++] = - new PooledWeightedRSocket( - this, - this.connector.connect(target.getTransport()), - target, - this.statsSupplier.get()); + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); } } } - // going though brightly new rsocket + // The remainder are the brand new targets for (LoadbalanceTarget target : rSocketSuppliersCopy.keySet()) { - nextActiveSockets[position++] = - new PooledWeightedRSocket( - this, - this.connector.connect(target.getTransport()), - target, - this.statsSupplier.get()); + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); } - // shrank to actual length - if (position == 0) { - activeSockets = EMPTY; + if (activeSocketsPosition == 0) { + socketsToUse = EMPTY; } else { - activeSockets = Arrays.copyOf(nextActiveSockets, position); + socketsToUse = Arrays.copyOf(nextActiveSockets, activeSocketsPosition); } + if (ACTIVE_SOCKETS.compareAndSet(this, previouslyActiveSockets, socketsToUse)) { + break; + } + } - if (ACTIVE_SOCKETS.compareAndSet(this, previouslyActiveSockets, activeSockets)) { + for (PooledRSocket inactiveSocket : inactiveSockets) { + if (inactiveSocket == null) { break; } + + inactiveSocket.dispose(); } if (isPending()) { // notifies that upstream is resolved - if (activeSockets != EMPTY) { + if (socketsToUse != EMPTY) { //noinspection ConstantConditions - complete(null); + complete(this); } } } @@ -206,6 +209,13 @@ RSocket select() { terminate(new CancellationException("Pool is exhausted")); } else { invalidate(); + + // check since it is possible that between doSelect() and invalidate() we might + // have received new sockets + selected = doSelect(); + if (selected != null) { + return selected; + } } return this.deferredResolutionRSocket; } @@ -215,38 +225,13 @@ RSocket select() { @Nullable RSocket doSelect() { - WeightedRSocket[] sockets = this.activeSockets; - if (sockets == EMPTY) { + PooledRSocket[] sockets = this.activeSockets; + + if (sockets == EMPTY || sockets == TERMINATED) { return null; } - return this.loadbalanceStrategy.select(this); - } - - @Override - public WeightedRSocket get(int index) { - return activeSockets[index]; - } - - @Override - public int size() { - return activeSockets.length; - } - - @Override - public boolean isEmpty() { - return activeSockets.length == 0; - } - - @Override - public Object[] toArray() { - return activeSockets; - } - - @Override - @SuppressWarnings("unchecked") - public T[] toArray(T[] a) { - return (T[]) activeSockets; + return this.loadbalanceStrategy.select(WrappingList.wrap(sockets)); } static class DeferredResolutionRSocket implements RSocket { @@ -274,7 +259,7 @@ public Flux requestStream(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { - return new FluxInner<>(this.parent, payloads, FrameType.REQUEST_STREAM); + return new FluxInner<>(this.parent, payloads, FrameType.REQUEST_CHANNEL); } @Override @@ -283,7 +268,7 @@ public Mono metadataPush(Payload payload) { } } - static final class MonoInner extends MonoDeferredResolution { + static final class MonoInner extends MonoDeferredResolution { MonoInner(RSocketPool parent, Payload payload, FrameType requestType) { super(parent, payload, requestType); @@ -291,7 +276,7 @@ static final class MonoInner extends MonoDeferredResolution { @Override @SuppressWarnings({"unchecked", "rawtypes"}) - public void accept(Void aVoid, Throwable t) { + public void accept(Object aVoid, Throwable t) { if (isTerminated()) { return; } @@ -303,32 +288,47 @@ public void accept(Void aVoid, Throwable t) { } RSocketPool parent = (RSocketPool) this.parent; - RSocket rSocket = parent.doSelect(); - if (rSocket != null) { - Mono source; - switch (this.requestType) { - case REQUEST_FNF: - source = rSocket.fireAndForget(this.payload); - break; - case REQUEST_RESPONSE: - source = rSocket.requestResponse(this.payload); - break; - case METADATA_PUSH: - source = rSocket.metadataPush(this.payload); - break; - default: - Operators.error(this.actual, new IllegalStateException("Should never happen")); - return; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + + return; } - source.subscribe((CoreSubscriber) this); - } else { - parent.add(this); + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + ReferenceCountUtil.safeRelease(this.payload); + onError(error); + return; + } } } } - static final class FluxInner extends FluxDeferredResolution { + static final class FluxInner extends FluxDeferredResolution { FluxInner(RSocketPool parent, INPUT fluxOrPayload, FrameType requestType) { super(parent, fluxOrPayload, requestType); @@ -336,7 +336,7 @@ static final class FluxInner extends FluxDeferredResolution @Override @SuppressWarnings("unchecked") - public void accept(Void aVoid, Throwable t) { + public void accept(Object aVoid, Throwable t) { if (isTerminated()) { return; } @@ -350,115 +350,183 @@ public void accept(Void aVoid, Throwable t) { } RSocketPool parent = (RSocketPool) this.parent; - RSocket rSocket = parent.doSelect(); - if (rSocket != null) { - Flux source; - switch (this.requestType) { - case REQUEST_STREAM: - source = rSocket.requestStream((Payload) this.fluxOrPayload); - break; - case REQUEST_CHANNEL: - source = rSocket.requestChannel((Flux) this.fluxOrPayload); - break; - default: - Operators.error(this.actual, new IllegalStateException("Should never happen")); - return; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + + return; } - source.subscribe(this); - } else { - parent.add(this); + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(error); + return; + } } } } - @Override - public boolean contains(Object o) { - throw new UnsupportedOperationException(); - } + static final class WrappingList implements List { - @Override - public Iterator iterator() { - throw new UnsupportedOperationException(); - } + static final ThreadLocal INSTANCE = ThreadLocal.withInitial(WrappingList::new); - @Override - public boolean add(RSocket weightedRSocket) { - throw new UnsupportedOperationException(); - } + private PooledRSocket[] activeSockets; - @Override - public boolean remove(Object o) { - throw new UnsupportedOperationException(); - } + static List wrap(PooledRSocket[] activeSockets) { + final WrappingList sockets = INSTANCE.get(); + sockets.activeSockets = activeSockets; + return sockets; + } - @Override - public boolean containsAll(Collection c) { - throw new UnsupportedOperationException(); - } + @Override + public RSocket get(int index) { + final PooledRSocket socket = activeSockets[index]; - @Override - public boolean addAll(Collection c) { - throw new UnsupportedOperationException(); - } + RSocket realValue = socket.value; + if (realValue != null) { + return realValue; + } - @Override - public boolean addAll(int index, Collection c) { - throw new UnsupportedOperationException(); - } + realValue = socket.valueIfResolved(); + if (realValue != null) { + return realValue; + } - @Override - public boolean removeAll(Collection c) { - throw new UnsupportedOperationException(); - } + return socket; + } - @Override - public boolean retainAll(Collection c) { - throw new UnsupportedOperationException(); - } + @Override + public int size() { + return activeSockets.length; + } - @Override - public void clear() { - throw new UnsupportedOperationException(); - } + @Override + public boolean isEmpty() { + return activeSockets.length == 0; + } - @Override - public WeightedRSocket set(int index, RSocket element) { - throw new UnsupportedOperationException(); - } + @Override + public Object[] toArray() { + return activeSockets; + } - @Override - public void add(int index, RSocket element) { - throw new UnsupportedOperationException(); - } + @Override + @SuppressWarnings("unchecked") + public T[] toArray(T[] a) { + return (T[]) activeSockets; + } - @Override - public WeightedRSocket remove(int index) { - throw new UnsupportedOperationException(); - } + @Override + public boolean contains(Object o) { + throw new UnsupportedOperationException(); + } - @Override - public int indexOf(Object o) { - throw new UnsupportedOperationException(); - } + @Override + public Iterator iterator() { + throw new UnsupportedOperationException(); + } - @Override - public int lastIndexOf(Object o) { - throw new UnsupportedOperationException(); - } + @Override + public boolean add(RSocket weightedRSocket) { + throw new UnsupportedOperationException(); + } - @Override - public ListIterator listIterator() { - throw new UnsupportedOperationException(); - } + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } - @Override - public ListIterator listIterator(int index) { - throw new UnsupportedOperationException(); - } + @Override + public boolean containsAll(Collection c) { + throw new UnsupportedOperationException(); + } - @Override - public List subList(int fromIndex, int toIndex) { - throw new UnsupportedOperationException(); + @Override + public boolean addAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket set(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public void add(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket remove(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int indexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public int lastIndexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator() { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public List subList(int fromIndex, int toIndex) { + throw new UnsupportedOperationException(); + } } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java index e03088b7f..52f16e166 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,18 +18,16 @@ import java.time.Duration; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BiConsumer; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; import reactor.core.Disposable; import reactor.core.Exceptions; -import reactor.core.Scannable; import reactor.core.publisher.Operators; import reactor.util.annotation.Nullable; import reactor.util.context.Context; +// This class is a copy of the same class in io.rsocket.core + class ResolvingOperator implements Disposable { static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); @@ -72,7 +70,7 @@ public ResolvingOperator() { } @Override - public void dispose() { + public final void dispose() { this.terminate(ON_DISPOSE); } @@ -168,19 +166,19 @@ public T block(@Nullable Duration timeout) { delay = System.nanoTime() + timeout.toNanos(); } for (; ; ) { - BiConsumer[] inners = this.subscribers; + subscribers = this.subscribers; - if (inners == READY) { + if (subscribers == READY) { final T value = this.value; if (value != null) { return value; } else { // value == null means racing between invalidate and this block // thus, we have to update the state again and see what happened - inners = this.subscribers; + subscribers = this.subscribers; } } - if (inners == TERMINATED) { + if (subscribers == TERMINATED) { RuntimeException re = Exceptions.propagate(this.t); re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); throw re; @@ -189,6 +187,12 @@ public T block(@Nullable Duration timeout) { throw new IllegalStateException("Timeout on Mono blocking read"); } + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + Thread.sleep(1); } } catch (InterruptedException ie) { @@ -201,6 +205,7 @@ public T block(@Nullable Duration timeout) { @SuppressWarnings("unchecked") final void terminate(Throwable t) { if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); return; } @@ -322,6 +327,30 @@ protected void doOnDispose() { // no ops } + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + final int add(BiConsumer ps) { for (; ; ) { BiConsumer[] a = this.subscribers; @@ -388,226 +417,4 @@ final void remove(BiConsumer ps) { } } } - - abstract static class DeferredResolution - implements CoreSubscriber, Subscription, Scannable, BiConsumer { - - final ResolvingOperator parent; - final CoreSubscriber actual; - - volatile long requested; - - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater REQUESTED = - AtomicLongFieldUpdater.newUpdater(DeferredResolution.class, "requested"); - - static final long STATE_SUBSCRIBED = -1; - static final long STATE_CANCELLED = Long.MIN_VALUE; - - Subscription s; - boolean done; - - DeferredResolution(ResolvingOperator parent, CoreSubscriber actual) { - this.parent = parent; - this.actual = actual; - } - - @Override - public final Context currentContext() { - return this.actual.currentContext(); - } - - @Nullable - @Override - public Object scanUnsafe(Attr key) { - long state = this.requested; - - if (key == Attr.PARENT) { - return this.s; - } - if (key == Attr.ACTUAL) { - return this.parent; - } - if (key == Attr.TERMINATED) { - return this.done; - } - if (key == Attr.CANCELLED) { - return state == STATE_CANCELLED; - } - - return null; - } - - @Override - public final void onSubscribe(Subscription s) { - final long state = this.requested; - Subscription a = this.s; - if (state == STATE_CANCELLED) { - s.cancel(); - return; - } - if (a != null) { - s.cancel(); - return; - } - - long r; - long accumulated = 0; - for (; ; ) { - r = this.requested; - - if (r == STATE_CANCELLED || r == STATE_SUBSCRIBED) { - s.cancel(); - return; - } - - this.s = s; - - long toRequest = r - accumulated; - if (toRequest > 0) { // if there is something, - s.request(toRequest); // then we do a request on the given subscription - } - accumulated = r; - - if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { - return; - } - } - } - - @Override - public final void onNext(T payload) { - this.actual.onNext(payload); - } - - @Override - public final void onError(Throwable t) { - if (this.done) { - Operators.onErrorDropped(t, this.actual.currentContext()); - return; - } - - this.done = true; - this.actual.onError(t); - } - - @Override - public final void onComplete() { - if (this.done) { - return; - } - - this.done = true; - this.actual.onComplete(); - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - long r = this.requested; // volatile read beforehand - - if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened - long u; - for (; ; ) { // normal CAS loop with overflow protection - if (r == Long.MAX_VALUE) { - // if r == Long.MAX_VALUE then we dont care and we can loose this - // request just in case of racing - return; - } - u = Operators.addCap(r, n); - if (REQUESTED.compareAndSet(this, r, u)) { - // Means increment happened before onSubscribe - return; - } else { - // Means increment happened after onSubscribe - - // update new state to see what exactly happened (onSubscribe |cancel | requestN) - r = this.requested; - - // check state (expect -1 | -2 to exit, otherwise repeat) - if (r < 0) { - break; - } - } - } - } - - if (r == STATE_CANCELLED) { // if canceled, just exit - return; - } - - // if onSubscribe -> subscription exists (and we sure of that because volatile read - // after volatile write) so we can execute requestN on the subscription - this.s.request(n); - } - } - - public boolean isCancelled() { - return this.requested == STATE_CANCELLED; - } - - public void cancel() { - long state = REQUESTED.getAndSet(this, STATE_CANCELLED); - if (state == STATE_CANCELLED) { - return; - } - - if (state == STATE_SUBSCRIBED) { - this.s.cancel(); - } else { - this.parent.remove(this); - } - } - } - - static class MonoDeferredResolutionOperator extends Operators.MonoSubscriber - implements BiConsumer { - - final ResolvingOperator parent; - - MonoDeferredResolutionOperator(ResolvingOperator parent, CoreSubscriber actual) { - super(actual); - this.parent = parent; - } - - @Override - public void accept(T t, Throwable throwable) { - if (throwable != null) { - onError(throwable); - return; - } - - complete(t); - } - - @Override - public void cancel() { - if (!isCancelled()) { - super.cancel(); - this.parent.remove(this); - } - } - - @Override - public void onComplete() { - if (!isCancelled()) { - this.actual.onComplete(); - } - } - - @Override - public void onError(Throwable t) { - if (isCancelled()) { - Operators.onErrorDropped(t, currentContext()); - } else { - this.actual.onError(t); - } - } - - @Override - public Object scanUnsafe(Attr key) { - if (key == Attr.PARENT) return this.parent; - return super.scanUnsafe(key); - } - } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java index 98c86d565..f1a9f8c55 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ public class RoundRobinLoadbalanceStrategy implements LoadbalanceStrategy { volatile int nextIndex; - static final AtomicIntegerFieldUpdater NEXT_INDEX = + private static final AtomicIntegerFieldUpdater NEXT_INDEX = AtomicIntegerFieldUpdater.newUpdater(RoundRobinLoadbalanceStrategy.class, "nextIndex"); @Override diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java deleted file mode 100644 index 2e9828938..000000000 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/Stats.java +++ /dev/null @@ -1,308 +0,0 @@ -package io.rsocket.loadbalance; - -import io.rsocket.Availability; -import io.rsocket.util.Clock; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; - -class Stats implements Availability { - - private static final double DEFAULT_LOWER_QUANTILE = 0.5; - private static final double DEFAULT_HIGHER_QUANTILE = 0.8; - private static final int INACTIVITY_FACTOR = 500; - private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = - Clock.unit().convert(1L, TimeUnit.SECONDS); - - private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; - - private final Quantile lowerQuantile; - private final Quantile higherQuantile; - private final Ewma errorPercentage; - private final Median median; - private final Ewma interArrivalTime; - - private final long tau; - private final long inactivityFactor; - - private long errorStamp; // last we got an error - private long stamp; // last timestamp we sent a request - private long stamp0; // last timestamp we sent a request or receive a response - private long duration; // instantaneous cumulative duration - - private double availability = 1.0; - - private volatile int pending; // instantaneous rate - private volatile long pendingStreams; // number of active streams - private static final AtomicLongFieldUpdater PENDING_STREAMS = - AtomicLongFieldUpdater.newUpdater(Stats.class, "pendingStreams"); - - private Stats() { - this( - new FrugalQuantile(DEFAULT_LOWER_QUANTILE), - new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), - INACTIVITY_FACTOR); - } - - private Stats(Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { - this.lowerQuantile = lowerQuantile; - this.higherQuantile = higherQuantile; - this.inactivityFactor = inactivityFactor; - - long now = Clock.now(); - this.stamp = now; - this.errorStamp = now; - this.stamp0 = now; - this.duration = 0L; - this.pending = 0; - this.median = new Median(); - this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); - this.errorPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); - this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); - } - - public double errorPercentage() { - return errorPercentage.value(); - } - - public double medianLatency() { - return median.estimation(); - } - - public double lowerQuantileLatency() { - return lowerQuantile.estimation(); - } - - public double higherQuantileLatency() { - return higherQuantile.estimation(); - } - - public double interArrivalTime() { - return interArrivalTime.value(); - } - - public int pending() { - return pending; - } - - public long lastTimeUsedMillis() { - return stamp0; - } - - @Override - public double availability() { - if (Clock.now() - stamp > tau) { - recordError(1.0); - } - return availability * errorPercentage.value(); - } - - public synchronized double predictedLatency() { - long now = Clock.now(); - long elapsed = Math.max(now - stamp, 1L); - - double weight; - double prediction = median.estimation(); - - if (prediction == 0.0) { - if (pending == 0) { - weight = 0.0; // first request - } else { - // subsequent requests while we don't have any history - weight = STARTUP_PENALTY + pending; - } - } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { - // if we did't see any data for a while, we decay the prediction by inserting - // artificial 0.0 into the median - median.insert(0.0); - weight = median.estimation(); - } else { - double predicted = prediction * pending; - double instant = instantaneous(now); - - if (predicted < instant) { // NB: (0.0 < 0.0) == false - weight = instant / pending; // NB: pending never equal 0 here - } else { - // we are under the predictions - weight = prediction; - } - } - - return weight; - } - - synchronized long instantaneous(long now) { - return duration + (now - stamp0) * pending; - } - - public void startStream() { - PENDING_STREAMS.incrementAndGet(this); - } - - public void stopStream() { - PENDING_STREAMS.decrementAndGet(this); - } - - public synchronized long startRequest() { - long now = Clock.now(); - interArrivalTime.insert(now - stamp); - duration += Math.max(0, now - stamp0) * pending; - pending += 1; - stamp = now; - stamp0 = now; - return now; - } - - public synchronized long stopRequest(long timestamp) { - long now = Clock.now(); - duration += Math.max(0, now - stamp0) * pending - (now - timestamp); - pending -= 1; - stamp0 = now; - return now; - } - - public synchronized void record(double roundTripTime) { - median.insert(roundTripTime); - lowerQuantile.insert(roundTripTime); - higherQuantile.insert(roundTripTime); - } - - public synchronized void recordError(double value) { - errorPercentage.insert(value); - errorStamp = Clock.now(); - } - - public void setAvailability(double availability) { - this.availability = availability; - } - - @Override - public String toString() { - return "Stats{" - + "lowerQuantile=" - + lowerQuantile.estimation() - + ", higherQuantile=" - + higherQuantile.estimation() - + ", inactivityFactor=" - + inactivityFactor - + ", tau=" - + tau - + ", errorPercentage=" - + errorPercentage.value() - + ", pending=" - + pending - + ", errorStamp=" - + errorStamp - + ", stamp=" - + stamp - + ", stamp0=" - + stamp0 - + ", duration=" - + duration - + ", median=" - + median.estimation() - + ", interArrivalTime=" - + interArrivalTime.value() - + ", pendingStreams=" - + pendingStreams - + ", availability=" - + availability - + '}'; - } - - private static final class NoOpsStats extends Stats { - - static final Stats INSTANCE = new NoOpsStats(); - - private NoOpsStats() {} - - @Override - public double errorPercentage() { - return 0.0d; - } - - @Override - public double medianLatency() { - return 0.0d; - } - - @Override - public double lowerQuantileLatency() { - return 0.0d; - } - - @Override - public double higherQuantileLatency() { - return 0.0d; - } - - @Override - public double interArrivalTime() { - return 0; - } - - @Override - public int pending() { - return 0; - } - - @Override - public long lastTimeUsedMillis() { - return 0; - } - - @Override - public double availability() { - return 1.0d; - } - - @Override - public double predictedLatency() { - return 0.0d; - } - - @Override - long instantaneous(long now) { - return 0; - } - - @Override - public void startStream() {} - - @Override - public void stopStream() {} - - @Override - public long startRequest() { - return 0; - } - - @Override - public long stopRequest(long timestamp) { - return 0; - } - - @Override - public void record(double roundTripTime) {} - - @Override - public void recordError(double value) {} - - @Override - public String toString() { - return "NoOpsStats{}"; - } - } - - public static Stats noOps() { - return NoOpsStats.INSTANCE; - } - - public static Stats create() { - return new Stats(); - } - - public static Stats create( - Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { - return new Stats(lowerQuantile, higherQuantile, inactivityFactor); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java index 03bc0530d..c30c8ad6b 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,56 +17,65 @@ package io.rsocket.loadbalance; import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.RequestInterceptor; import java.util.List; -import java.util.SplittableRandom; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; import reactor.util.annotation.Nullable; /** - * {@link LoadbalanceStrategy} that assigns a weight to each {@code RSocket} based on usage - * statistics, and uses this weight to select the {@code RSocket} to use. + * {@link LoadbalanceStrategy} that assigns a weight to each {@code RSocket} based on {@link + * RSocket#availability() availability} and usage statistics. The weight is used to decide which + * {@code RSocket} to select. + * + *

Use {@link #create()} or a {@link #builder() Builder} to create an instance. * * @since 1.1 + * @see Predictive Load-Balancing: Unfair but + * Faster & more Robust + * @see WeightedStatsRequestInterceptor */ -public class WeightedLoadbalanceStrategy implements LoadbalanceStrategy { +public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy { private static final double EXP_FACTOR = 4.0; - private static final int EFFORT = 5; - - final int effort; - final SplittableRandom splittableRandom; + final int maxPairSelectionAttempts; + final Function weightedStatsResolver; - public WeightedLoadbalanceStrategy() { - this(EFFORT); + private WeightedLoadbalanceStrategy( + int numberOfAttempts, @Nullable Function resolver) { + this.maxPairSelectionAttempts = numberOfAttempts; + this.weightedStatsResolver = (resolver != null ? resolver : new DefaultWeightedStatsResolver()); } - public WeightedLoadbalanceStrategy(int effort) { - this(effort, new SplittableRandom(System.nanoTime())); - } - - public WeightedLoadbalanceStrategy(int effort, SplittableRandom splittableRandom) { - this.effort = effort; - this.splittableRandom = splittableRandom; + @Override + public void initialize(RSocketConnector connector) { + final Function resolver = weightedStatsResolver; + if (resolver instanceof DefaultWeightedStatsResolver) { + ((DefaultWeightedStatsResolver) resolver).init(connector); + } } @Override public RSocket select(List sockets) { - final int effort = this.effort; final int size = sockets.size(); - WeightedRSocket weightedRSocket; + RSocket weightedRSocket; + final Function weightedStatsResolver = this.weightedStatsResolver; switch (size) { case 1: - weightedRSocket = (WeightedRSocket) sockets.get(0); + weightedRSocket = sockets.get(0); break; case 2: { - WeightedRSocket rsc1 = (WeightedRSocket) sockets.get(0); - WeightedRSocket rsc2 = (WeightedRSocket) sockets.get(1); + RSocket rsc1 = sockets.get(0); + RSocket rsc2 = sockets.get(1); - double w1 = algorithmicWeight(rsc1); - double w2 = algorithmicWeight(rsc2); + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); if (w1 < w2) { weightedRSocket = rsc2; } else { @@ -76,29 +85,36 @@ public RSocket select(List sockets) { break; default: { - WeightedRSocket rsc1 = null; - WeightedRSocket rsc2 = null; + RSocket rsc1 = null; + RSocket rsc2 = null; - for (int i = 0; i < effort; i++) { + for (int i = 0; i < this.maxPairSelectionAttempts; i++) { int i1 = ThreadLocalRandom.current().nextInt(size); int i2 = ThreadLocalRandom.current().nextInt(size - 1); if (i2 >= i1) { i2++; } - rsc1 = (WeightedRSocket) sockets.get(i1); - rsc2 = (WeightedRSocket) sockets.get(i2); + rsc1 = sockets.get(i1); + rsc2 = sockets.get(i2); if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { break; } } - double w1 = algorithmicWeight(rsc1); - double w2 = algorithmicWeight(rsc2); - if (w1 < w2) { - weightedRSocket = rsc2; - } else { + if (rsc1 != null & rsc2 != null) { + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } else if (rsc1 != null) { weightedRSocket = rsc1; + } else { + weightedRSocket = rsc2; } } } @@ -106,20 +122,22 @@ public RSocket select(List sockets) { return weightedRSocket; } - private static double algorithmicWeight(@Nullable final WeightedRSocket weightedRSocket) { - if (weightedRSocket == null - || weightedRSocket.isDisposed() - || weightedRSocket.availability() == 0.0) { + private static double algorithmicWeight( + RSocket rSocket, @Nullable final WeightedStats weightedStats) { + if (weightedStats == null) { + return 1.0; + } + if (rSocket.isDisposed() || rSocket.availability() == 0.0) { return 0.0; } - final Stats stats = weightedRSocket.stats(); - final int pending = stats.pending(); - double latency = stats.predictedLatency(); + final int pending = weightedStats.pending(); + + double latency = weightedStats.predictedLatency(); - final double low = stats.lowerQuantileLatency(); + final double low = weightedStats.lowerQuantileLatency(); final double high = Math.max( - stats.higherQuantileLatency(), + weightedStats.higherQuantileLatency(), low * 1.001); // ensure higherQuantile > lowerQuantile + .1% final double bandWidth = Math.max(high - low, 1); @@ -129,11 +147,103 @@ private static double algorithmicWeight(@Nullable final WeightedRSocket weighted latency *= calculateFactor(latency, high, bandWidth); } - return weightedRSocket.availability() * 1.0 / (1.0 + latency * (pending + 1)); + return (rSocket.availability() * weightedStats.weightedAvailability()) + / (1.0d + latency * (pending + 1)); } private static double calculateFactor(final double u, final double l, final double bandWidth) { final double alpha = (u - l) / bandWidth; return Math.pow(1 + alpha, EXP_FACTOR); } + + /** + * Create an instance of {@link WeightedLoadbalanceStrategy} with default settings, which include + * round-robin load-balancing and 5 {@link #maxPairSelectionAttempts}. + */ + public static WeightedLoadbalanceStrategy create() { + return new Builder().build(); + } + + /** Return a builder to create a {@link WeightedLoadbalanceStrategy} with. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link WeightedLoadbalanceStrategy}. */ + public static class Builder { + + private int maxPairSelectionAttempts = 5; + + @Nullable private Function weightedStatsResolver; + + private Builder() {} + + /** + * How many times to try to randomly select a pair of RSocket connections with non-zero + * availability. This is applicable when there are more than two connections in the pool. If the + * number of attempts is exceeded, the last selected pair is used. + * + *

By default this is set to 5. + * + * @param numberOfAttempts the iteration count + */ + public Builder maxPairSelectionAttempts(int numberOfAttempts) { + this.maxPairSelectionAttempts = numberOfAttempts; + return this; + } + + /** + * Configure how the created {@link WeightedLoadbalanceStrategy} should find the stats for a + * given RSocket. + * + *

By default this resolver is not set. + * + *

When {@code WeightedLoadbalanceStrategy} is used through the {@link + * LoadbalanceRSocketClient}, the resolver does not need to be set because a {@link + * WeightedStatsRequestInterceptor} is automatically installed through the {@link + * ClientLoadbalanceStrategy} callback. If this strategy is used in any other context however, a + * resolver here must be provided. + * + * @param resolver to find the stats for an RSocket with + */ + public Builder weightedStatsResolver(Function resolver) { + this.weightedStatsResolver = resolver; + return this; + } + + /** Build the {@code WeightedLoadbalanceStrategy} instance. */ + public WeightedLoadbalanceStrategy build() { + return new WeightedLoadbalanceStrategy( + this.maxPairSelectionAttempts, this.weightedStatsResolver); + } + } + + private static class DefaultWeightedStatsResolver implements Function { + + final Map statsMap = new ConcurrentHashMap<>(); + + @Override + public WeightedStats apply(RSocket rSocket) { + return statsMap.get(rSocket); + } + + void init(RSocketConnector connector) { + connector.interceptors( + registry -> + registry.forRequestsInRequester( + (Function) + rSocket -> { + final WeightedStatsRequestInterceptor interceptor = + new WeightedStatsRequestInterceptor() { + @Override + public void dispose() { + statsMap.remove(rSocket); + } + }; + statsMap.put(rSocket, interceptor); + + return interceptor; + })); + } + } } diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java deleted file mode 100644 index 488a7134d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedRSocket.java +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.loadbalance; - -import io.rsocket.RSocket; - -interface WeightedRSocket extends RSocket { - - Stats stats(); -} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java new file mode 100644 index 000000000..5ebe668ce --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; + +/** + * Contract to expose the stats required in {@link WeightedLoadbalanceStrategy} to calculate an + * algorithmic weight for an {@code RSocket}. The weight helps to select an {@code RSocket} for + * load-balancing. + * + * @since 1.1 + */ +public interface WeightedStats { + + double higherQuantileLatency(); + + double lowerQuantileLatency(); + + int pending(); + + double predictedLatency(); + + double weightedAvailability(); + + /** + * Create a proxy for the given {@code RSocket} that attaches the stats contained in this instance + * and exposes them as {@link WeightedStats}. + * + * @param rsocket the RSocket to wrap + * @return the wrapped RSocket + * @since 1.1.1 + */ + default RSocket wrap(RSocket rsocket) { + return new WeightedStatsRSocketProxy(rsocket, this); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java new file mode 100644 index 000000000..f2cf3fbd0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; + +/** + * Package private {@code RSocketProxy} used from {@link WeightedStats#wrap(RSocket)} to attach a + * {@link WeightedStats} instance to an {@code RSocket}. + */ +class WeightedStatsRSocketProxy extends RSocketProxy implements WeightedStats { + + private final WeightedStats weightedStats; + + public WeightedStatsRSocketProxy(RSocket source, WeightedStats weightedStats) { + super(source); + this.weightedStats = weightedStats; + } + + @Override + public double higherQuantileLatency() { + return this.weightedStats.higherQuantileLatency(); + } + + @Override + public double lowerQuantileLatency() { + return this.weightedStats.lowerQuantileLatency(); + } + + @Override + public int pending() { + return this.weightedStats.pending(); + } + + @Override + public double predictedLatency() { + return this.weightedStats.predictedLatency(); + } + + @Override + public double weightedAvailability() { + return this.weightedStats.weightedAvailability(); + } + + public WeightedStats getDelegate() { + return this.weightedStats; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java new file mode 100644 index 000000000..ec2c88b19 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import reactor.util.annotation.Nullable; + +/** + * {@link RequestInterceptor} that hooks into request lifecycle and calls methods of the parent + * class to manage tracking state and expose {@link WeightedStats}. + * + *

This interceptor the default mechanism for gathering stats when {@link + * WeightedLoadbalanceStrategy} is used with {@link LoadbalanceRSocketClient}. + * + * @since 1.1 + * @see LoadbalanceRSocketClient + * @see WeightedLoadbalanceStrategy + */ +public class WeightedStatsRequestInterceptor extends BaseWeightedStats + implements RequestInterceptor { + + final Int2LongHashMap requestsStartTime = new Int2LongHashMap(-1); + + public WeightedStatsRequestInterceptor() { + super(); + } + + @Override + public final void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + final long startTime = startRequest(); + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + requestsStartTime.put(streamId, startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + this.startStream(); + } + } + + @Override + public final void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + long endTime = stopRequest(startTime); + if (t == null) { + record(endTime - startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + + if (t != null) { + updateAvailability(0.0d); + } else { + updateAvailability(1.0d); + } + } + + @Override + public final void onCancel(int streamId, FrameType requestType) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + stopRequest(startTime); + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + } + + @Override + public final void onReject(Throwable rejectionReason, FrameType requestType, ByteBuf metadata) {} + + @Override + public void dispose() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java index f5fd00a52..19668e99c 100644 --- a/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java @@ -14,6 +14,7 @@ * limitations under the License. */ +/** Support client load-balancing in RSocket Java. */ @NonNullApi package io.rsocket.loadbalance; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java index d908abb3c..c16c4dc52 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -12,7 +12,7 @@ public class AuthMetadataCodec { static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 - static final int USERNAME_BYTES_LENGTH = 1; + static final int USERNAME_BYTES_LENGTH = 2; static final int AUTH_TYPE_ID_LENGTH = 1; static final char[] EMPTY_CHARS_ARRAY = new char[0]; @@ -81,7 +81,7 @@ public static ByteBuf encodeMetadata( /** * Encode a Authentication CompositeMetadata payload using Simple Authentication format * - * @throws IllegalArgumentException if the username length is greater than 255 + * @throws IllegalArgumentException if the username length is greater than 65535 * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. * @param username the char sequence which represents user name. * @param password the char sequence which represents user password. @@ -90,9 +90,9 @@ public static ByteBuf encodeSimpleMetadata( ByteBufAllocator allocator, char[] username, char[] password) { int usernameLength = CharByteBufUtil.utf8Bytes(username); - if (usernameLength > 255) { + if (usernameLength > 65535) { throw new IllegalArgumentException( - "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); } int passwordLength = CharByteBufUtil.utf8Bytes(password); @@ -101,7 +101,7 @@ public static ByteBuf encodeSimpleMetadata( allocator .buffer(capacity, capacity) .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) - .writeByte(usernameLength); + .writeShort(usernameLength); CharByteBufUtil.writeUtf8(buffer, username); CharByteBufUtil.writeUtf8(buffer, password); @@ -235,15 +235,15 @@ public static ByteBuf readPayload(ByteBuf metadata) { } /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length * * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the username length byte + * simpleAuthMetadata#readIndex} should be set to the username length position * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero */ public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { - short usernameLength = readUsernameLength(simpleAuthMetadata); + int usernameLength = readUsernameLength(simpleAuthMetadata); if (usernameLength == 0) { return Unpooled.EMPTY_BUFFER; @@ -268,15 +268,15 @@ public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); } /** - * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username - * length and the subsequent number of bytes equal to decoded length + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length * * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code * simpleAuthMetadata#readIndex} should be set to the username length byte * @return {@code char[]} which represents UTF-8 username */ public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { - short usernameLength = readUsernameLength(simpleAuthMetadata); + int usernameLength = readUsernameLength(simpleAuthMetadata); if (usernameLength == 0) { return EMPTY_CHARS_ARRAY; @@ -302,11 +302,10 @@ public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { } /** - * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is - * username length and the subsequent number of bytes equal to decoded length + * Read all the remaining {@code bytes} from the given {@link ByteBuf} * * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code - * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * bearerAuthMetadata#readIndex} should be set to the beginning of the password bytes * @return {@code char[]} which represents UTF-8 password */ public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { @@ -317,13 +316,13 @@ public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); } - private static short readUsernameLength(ByteBuf simpleAuthMetadata) { - if (simpleAuthMetadata.readableBytes() < 1) { + private static int readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 2) { throw new IllegalStateException( "Unable to decode custom username. Not enough readable bytes"); } - short usernameLength = simpleAuthMetadata.readUnsignedByte(); + int usernameLength = simpleAuthMetadata.readUnsignedShort(); if (simpleAuthMetadata.readableBytes() < usernameLength) { throw new IllegalArgumentException( diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java new file mode 100644 index 000000000..2e03bd754 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayList; +import java.util.List; + +/** + * Provides support for encoding and decoding the per-stream MIME type to use for payload data. + * + *

For more on the format of the metadata, see the + * Stream Data MIME Types extension specification. + * + * @since 1.1.1 + */ +public class MimeTypeMetadataCodec { + + private static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private MimeTypeMetadataCodec() {} + + /** + * Encode a {@link WellKnownMimeType} into a newly allocated single byte {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType well-known MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, WellKnownMimeType mimeType) { + return allocator.buffer(1, 1).writeByte(mimeType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + } + + /** + * Encode the given MIME type into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, String mimeType) { + if (mimeType == null || mimeType.length() == 0) { + throw new IllegalArgumentException("MIME type is required"); + } + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + return encodeCustomMimeType(allocator, mimeType); + } else { + return encode(allocator, wkn); + } + } + + /** + * Encode multiple MIME types into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeTypes MIME types to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, List mimeTypes) { + if (mimeTypes == null || mimeTypes.size() == 0) { + throw new IllegalArgumentException("No MIME types provided"); + } + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + for (String mimeType : mimeTypes) { + ByteBuf byteBuf = encode(allocator, mimeType); + compositeByteBuf.addComponents(true, byteBuf); + } + return compositeByteBuf; + } + + private static ByteBuf encodeCustomMimeType(ByteBufAllocator allocator, String customMimeType) { + ByteBuf byteBuf = allocator.buffer(1 + customMimeType.length()); + + byteBuf.writerIndex(1); + int length = ByteBufUtil.writeUtf8(byteBuf, customMimeType); + + if (!ByteBufUtil.isText(byteBuf, 1, length, CharsetUtil.US_ASCII)) { + byteBuf.release(); + throw new IllegalArgumentException("MIME type must be ASCII characters only"); + } + + if (length < 1 || length > 128) { + byteBuf.release(); + throw new IllegalArgumentException( + "MIME type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + byteBuf.markWriterIndex(); + byteBuf.writerIndex(0); + byteBuf.writeByte(length - 1); + byteBuf.resetWriterIndex(); + + return byteBuf; + } + + /** + * Decode the per-stream MIME type metadata encoded in the given {@link ByteBuf}. + * + * @return the decoded MIME types + */ + public static List decode(ByteBuf byteBuf) { + List mimeTypes = new ArrayList<>(); + while (byteBuf.isReadable()) { + byte idOrLength = byteBuf.readByte(); + if ((idOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + byte id = (byte) (idOrLength & STREAM_METADATA_LENGTH_MASK); + WellKnownMimeType wellKnownMimeType = WellKnownMimeType.fromIdentifier(id); + mimeTypes.add(wellKnownMimeType.toString()); + } else { + int length = Byte.toUnsignedInt(idOrLength) + 1; + mimeTypes.add(byteBuf.readCharSequence(length, CharsetUtil.US_ASCII).toString()); + } + } + return mimeTypes; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..9a134153d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,147 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import java.util.List; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create(List interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0)); + default: + return new CompositeRequestInterceptor(interceptors.toArray(new RequestInterceptor[0])); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java index 6b2a7a71b..5d3a43b03 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -27,6 +27,8 @@ extends BiFunction { enum Type { + /** @deprecated since 1.1.0-M2. Will be removed in 1.2 */ + @Deprecated SETUP, CLIENT, SERVER, diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java index fc032847c..7c9a90f54 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -18,6 +18,9 @@ import io.rsocket.DuplexConnection; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import reactor.util.annotation.Nullable; /** * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. @@ -25,6 +28,27 @@ */ public class InitializingInterceptorRegistry extends InterceptorRegistry { + @Nullable + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create( + getRequestInterceptorsForRequester() + .stream() + .map(factory -> factory.apply(rSocketRequester)) + .collect(Collectors.toList())); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor( + RSocket rSocketResponder, RequestInterceptor... perConnectionInterceptors) { + return CompositeRequestInterceptor.create( + Stream.concat( + Stream.of(perConnectionInterceptors), + getRequestInterceptorsForResponder() + .stream() + .map(inteptorFactory -> inteptorFactory.apply(rSocketResponder))) + .collect(Collectors.toList())); + } + public DuplexConnection initConnection( DuplexConnectionInterceptor.Type type, DuplexConnection connection) { for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { @@ -34,7 +58,7 @@ public DuplexConnection initConnection( } public RSocket initRequester(RSocket rsocket) { - for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { rsocket = interceptor.apply(rsocket); } return rsocket; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java index 427fa15ae..680fb514f 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -15,9 +15,11 @@ */ package io.rsocket.plugins; +import io.rsocket.RSocket; import java.util.ArrayList; import java.util.List; import java.util.function.Consumer; +import java.util.function.Function; /** * Provides support for registering interceptors at the following levels: @@ -30,16 +32,46 @@ * */ public class InterceptorRegistry { - private List requesterInteceptors = new ArrayList<>(); - private List responderInterceptors = new ArrayList<>(); + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); private List socketAcceptorInterceptors = new ArrayList<>(); private List connectionInterceptors = new ArrayList<>(); + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + /** * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. */ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { - requesterInteceptors.add(interceptor); + requesterRSocketInterceptors.add(interceptor); return this; } @@ -48,7 +80,7 @@ public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forRequester(Consumer> consumer) { - consumer.accept(requesterInteceptors); + consumer.accept(requesterRSocketInterceptors); return this; } @@ -57,7 +89,7 @@ public InterceptorRegistry forRequester(Consumer> consu * requests. */ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { - responderInterceptors.add(interceptor); + responderRSocketInterceptors.add(interceptor); return this; } @@ -66,7 +98,7 @@ public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { * registrations. */ public InterceptorRegistry forResponder(Consumer> consumer) { - consumer.accept(responderInterceptors); + consumer.accept(responderRSocketInterceptors); return this; } @@ -102,12 +134,20 @@ public InterceptorRegistry forConnection(Consumer getRequesterInteceptors() { - return requesterInteceptors; + List> getRequestInterceptorsForRequester() { + return requesterRequestInterceptors; + } + + List> getRequestInterceptorsForResponder() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; } List getResponderInterceptors() { - return responderInterceptors; + return responderRSocketInterceptors; } List getConnectionInterceptors() { diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..08131b39d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,79 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int, FrameType)}. + * + * @param streamId used by this request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param t with which this finished has terminated. Must be one of the following signals + */ + void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t); + + /** + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, FrameType, Throwable)}. + * + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param streamId used by this request + */ + void onCancel(int streamId, FrameType requestType); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

+ * + *

    + *
  • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
  • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
  • A second subscription on the ongoing Request + *
+ * + * @param rejectionReason exception which causes rejection of a particular request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java index ed9450357..ca4f5dcb4 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -18,177 +18,366 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.ResumeOkFrameCodec; -import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveSupport; import java.time.Duration; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.function.Tuple2; import reactor.util.retry.Retry; -public class ClientRSocketSession implements RSocketSession> { +public class ClientRSocketSession + implements RSocketSession, + ResumeStateHolder, + CoreSubscriber> { + private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); - private final ResumableDuplexConnection resumableConnection; - private volatile Mono newConnection; - private volatile ByteBuf resumeToken; - private final ByteBufAllocator allocator; + final ResumableDuplexConnection resumableConnection; + final Mono> connectionFactory; + final ResumableFramesStore resumableFramesStore; + + final ByteBufAllocator allocator; + final Duration resumeSessionDuration; + final Retry retry; + final boolean cleanupStoreOnKeepAlive; + final ByteBuf resumeToken; + final String session; + final Disposable reconnectDisposable; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ClientRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; public ClientRSocketSession( - DuplexConnection duplexConnection, + ByteBuf resumeToken, + ResumableDuplexConnection resumableDuplexConnection, + Mono connectionFactory, + Function>> connectionTransformer, + ResumableFramesStore resumableFramesStore, Duration resumeSessionDuration, Retry retry, - ResumableFramesStore resumableFramesStore, - Duration resumeStreamTimeout, boolean cleanupStoreOnKeepAlive) { - this.allocator = duplexConnection.alloc(); - this.resumableConnection = - new ResumableDuplexConnection( - "client", - duplexConnection, - resumableFramesStore, - resumeStreamTimeout, - cleanupStoreOnKeepAlive); - - /*session completed: release token initially retained in resumeToken(ByteBuf)*/ - onClose().doFinally(s -> resumeToken.release()).subscribe(); - - resumableConnection - .connectionErrors() - .flatMap( - err -> { - logger.debug("Client session connection error. Starting new connection"); - AtomicBoolean once = new AtomicBoolean(); - return newConnection - .delaySubscription( - once.compareAndSet(false, true) - ? retry.generateCompanion(Flux.just(new RetrySignal(err))) - : Mono.empty()) - .retryWhen(retry) - .timeout(resumeSessionDuration); - }) - .map(ClientServerInputMultiplexer::new) - .subscribe( - multiplexer -> { - /*reconnect resumable connection*/ - reconnect(multiplexer.asClientServerConnection()); - long impliedPosition = resumableConnection.impliedPosition(); - long position = resumableConnection.position(); - logger.debug( - "Client ResumableConnection reconnected. Sending RESUME frame with state: [impliedPos: {}, pos: {}]", - impliedPosition, - position); - /*Connection is established again: send RESUME frame to server, listen for RESUME_OK*/ - sendFrame( + this.resumeToken = resumeToken; + this.session = resumeToken.toString(CharsetUtil.UTF_8); + this.connectionFactory = + connectionFactory + .doOnDiscard( + DuplexConnection.class, + c -> { + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + c.sendErrorAndClose(connectionErrorException); + c.receive().subscribe(); + }) + .flatMap( + dc -> { + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + final long position = resumableFramesStore.framePosition(); + dc.sendFrame( + 0, ResumeFrameCodec.encode( - allocator, - /*retain so token is not released once sent as part of resume frame*/ + dc.alloc(), resumeToken.retain(), - impliedPosition, - position)) - .then(multiplexer.asSetupConnection().receive().next()) - .subscribe(this::resumeWith); - }, - err -> { - logger.debug("Client ResumableConnection reconnect timeout"); - resumableConnection.dispose(); - }); - } + // server uses this to release its cache + impliedPosition, // observed on the client side + // server uses this to check whether there is no mismatch + position // sent from the client sent + )); - @Override - public ClientRSocketSession continueWith(Mono connectionFactory) { - this.newConnection = connectionFactory; - return this; - } + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeFrame[impliedPosition[{}], position[{}]] has been sent.", + session, + impliedPosition, + position); + } - @Override - public ClientRSocketSession resumeWith(ByteBuf resumeOkFrame) { - logger.debug("ResumeOK FRAME received"); - long remotePos = remotePos(resumeOkFrame); - long remoteImpliedPos = remoteImpliedPos(resumeOkFrame); - resumeOkFrame.release(); - - resumableConnection.resume( - remotePos, - remoteImpliedPos, - pos -> - pos.then() - /*Resumption is impossible: send CONNECTION_ERROR*/ - .onErrorResume( - err -> - sendFrame( - ErrorFrameCodec.encode( - allocator, 0, errorFrameThrowable(remoteImpliedPos))) - .then(Mono.fromRunnable(resumableConnection::dispose)) - /*Resumption is impossible: no need to return control to ResumableConnection*/ - .then(Mono.never()))); - return this; + return connectionTransformer.apply(dc); + }) + .doOnDiscard(Tuple2.class, this::tryReestablishSession); + this.resumableFramesStore = resumableFramesStore; + this.allocator = resumableDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.retry = retry; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + + this.reconnectDisposable = + resumableDuplexConnection.onActiveConnectionClosed().subscribe(this::reconnect); } - public ClientRSocketSession resumeToken(ByteBuf resumeToken) { - /*retain so token is not released once sent as part of setup frame*/ - this.resumeToken = resumeToken.retain(); - return this; + void reconnect(int index) { + if (this.s == Operators.cancelledSubscription()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting rejected since session is closed", + session, + index); + } + return; + } + + keepAliveSupport.stop(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting to resume...", + session, + index); + } + connectionFactory + .doOnNext(this::tryReestablishSession) + .retryWhen(retry) + .timeout(resumeSessionDuration) + .subscribe(this); } @Override - public void reconnect(DuplexConnection connection) { - resumableConnection.reconnect(connection); + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); } @Override - public ResumableDuplexConnection resumableConnection() { - return resumableConnection; + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } } @Override - public ByteBuf token() { - return resumeToken; - } + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Disposing", session); + } - private Mono sendFrame(ByteBuf frame) { - return resumableConnection.sendOne(frame).onErrorResume(err -> Mono.empty()); - } + boolean result = Operators.terminate(S, this); - private static long remoteImpliedPos(ByteBuf resumeOkFrame) { - return ResumeOkFrameCodec.lastReceivedClientPos(resumeOkFrame); - } + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Sessions[isDisposed={}]", session, result); + } - private static long remotePos(ByteBuf resumeOkFrame) { - return -1; + reconnectDisposable.dispose(); + resumableConnection.dispose(); + // frame store is disposed by resumable connection + // resumableFramesStore.dispose(); + + if (resumeToken.refCnt() > 0) { + resumeToken.release(); + } } - private static ConnectionErrorException errorFrameThrowable(long impliedPos) { - return new ConnectionErrorException("resumption_server_pos=[" + impliedPos + "]"); + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); } - private static class RetrySignal implements Retry.RetrySignal { + void tryReestablishSession(Tuple2 tuple2) { + if (logger.isDebugEnabled()) { + logger.debug("Active subscription is canceled {}", s == Operators.cancelledSubscription()); + } + ByteBuf shouldBeResumeOKFrame = tuple2.getT1(); + DuplexConnection nextDuplexConnection = tuple2.getT2(); + + final int streamId = FrameHeaderCodec.streamId(shouldBeResumeOKFrame); + if (streamId != 0) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + throw connectionErrorException; // throw to retry connection again + } + + final FrameType frameType = FrameHeaderCodec.nativeFrameType(shouldBeResumeOKFrame); + if (frameType == FrameType.RESUME_OK) { + // how many frames the server has received from the client + // so the client can release cached frames by this point + long remoteImpliedPos = ResumeOkFrameCodec.lastReceivedClientPos(shouldBeResumeOKFrame); + // what was the last notification from the server about number of frames being + // observed + final long position = resumableFramesStore.framePosition(); + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeOK FRAME received. ServerResumeState[remoteImpliedPosition[{}]]. ClientResumeState[impliedPosition[{}], position[{}]]", + session, + remoteImpliedPos, + impliedPosition, + position); + } + if (position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + } catch (IllegalStateException e) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + e); + } + final ConnectionErrorException t = new ConnectionErrorException(e.getMessage(), e); + + resumableConnection.dispose(nextDuplexConnection, t); + + nextDuplexConnection.sendErrorAndClose(t); + nextDuplexConnection.receive().subscribe(); + + return; + } + + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + return; + } + + keepAliveSupport.start(); - private final Throwable ex; + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Session has been resumed successfully", session); + } - RetrySignal(Throwable ex) { - this.ex = ex; + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + // no need to do anything since connection resumable connection is liklly to + // be disposed + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[" + remoteImpliedPos + "]"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + } + } else if (frameType == FrameType.ERROR) { + final RuntimeException exception = Exceptions.from(0, shouldBeResumeOKFrame); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Received error frame. Terminating received connection", + session, + exception); + } + if (exception instanceof RejectedResumeException) { + resumableConnection.dispose(nextDuplexConnection, exception); + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + return; + } + + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + throw exception; // assume retryable exception + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + // no need to do anything since remote server rejected our connection completely } + } - @Override - public long totalRetries() { - return 0; + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } } + } - @Override - public long totalRetriesInARow() { - return 0; + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); } + } + + @Override + public void onNext(Tuple2 objects) {} - @Override - public Throwable failure() { - return ex; + @Override + public void onError(Throwable t) { + if (!Operators.terminate(S, this)) { + Operators.onErrorDropped(t, currentContext()); } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java index 1875b7eac..e23bc154b 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,225 +16,839 @@ package io.rsocket.resume; +import static io.rsocket.resume.ResumableDuplexConnection.isResumableFrame; + import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayDeque; import java.util.Queue; -import org.reactivestreams.Subscriber; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.util.concurrent.Queues; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * writes - n (where n is frequent, primary operation) reads - m (where m == KeepAliveFrequency) + * skip - k -> 0 (where k is the rare operation which happens after disconnection + */ +public class InMemoryResumableFramesStore extends Flux + implements ResumableFramesStore, Subscription { -public class InMemoryResumableFramesStore implements ResumableFramesStore { + private FramesSubscriber framesSubscriber; private static final Logger logger = LoggerFactory.getLogger(InMemoryResumableFramesStore.class); - private static final long SAVE_REQUEST_SIZE = Long.MAX_VALUE; - private final MonoProcessor disposed = MonoProcessor.create(); - volatile long position; - volatile long impliedPosition; - volatile int cacheSize; + final Sinks.Empty disposed = Sinks.empty(); final Queue cachedFrames; - private final String tag; - private final int cacheLimit; - private volatile int upstreamFrameRefCnt; + final String side; + final String session; + final int cacheLimit; + + volatile long impliedPosition; + static final AtomicLongFieldUpdater IMPLIED_POSITION = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "impliedPosition"); + + volatile long firstAvailableFramePosition; + static final AtomicLongFieldUpdater FIRST_AVAILABLE_FRAME_POSITION = + AtomicLongFieldUpdater.newUpdater( + InMemoryResumableFramesStore.class, "firstAvailableFramePosition"); + + long remoteImpliedPosition; + + int cacheSize; - public InMemoryResumableFramesStore(String tag, int cacheSizeBytes) { - this.tag = tag; + Throwable terminal; + + CoreSubscriber actual; + CoreSubscriber pendingActual; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "state"); + + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is finalized and all related + * stores are cleaned + */ + static final long FINALIZED_FLAG = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * InMemoryResumableFramesStore#dispose()} method + */ + static final long DISPOSED_FLAG = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * FramesSubscriber#onComplete()} or {@link FramesSubscriber#onError(Throwable)} ()} methods + */ + static final long TERMINATED_FLAG = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** Flag which indicates that {@link InMemoryResumableFramesStore} has active frames consumer */ + static final long CONNECTED_FLAG = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} has no active frames consumer + * but there is a one pending + */ + static final long PENDING_CONNECTION_FLAG = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some received implied position changes from the remote + * party + */ + static final long REMOTE_IMPLIED_POSITION_CHANGED_FLAG = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some frames stored in the {@link + * io.rsocket.internal.UnboundedProcessor} which has to be cached and sent to the remote party + */ + static final long HAS_FRAME_FLAG = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore#drain(long)} has an actor which + * is currently progressing on the work. This flag should work as a guard to enter|exist into|from + * the {@link InMemoryResumableFramesStore#drain(long)} method. + */ + static final long MAX_WORK_IN_PROGRESS = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + public InMemoryResumableFramesStore(String side, ByteBuf session, int cacheSizeBytes) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); this.cacheLimit = cacheSizeBytes; - this.cachedFrames = cachedFramesQueue(cacheSizeBytes); + this.cachedFrames = new ArrayDeque<>(); } public Mono saveFrames(Flux frames) { - MonoProcessor completed = MonoProcessor.create(); - frames - .doFinally(s -> completed.onComplete()) - .subscribe(new FramesSubscriber(SAVE_REQUEST_SIZE)); - return completed; + return frames + .transform( + Operators.lift( + (__, actual) -> this.framesSubscriber = new FramesSubscriber(actual, this))) + .then(); } @Override public void releaseFrames(long remoteImpliedPos) { - long pos = position; - logger.debug( - "{} Removing frames for local: {}, remote implied: {}", tag, pos, remoteImpliedPos); - long removeSize = Math.max(0, remoteImpliedPos - pos); - while (removeSize > 0) { - ByteBuf cachedFrame = cachedFrames.poll(); - if (cachedFrame != null) { - removeSize -= releaseTailFrame(cachedFrame); - } else { + long lastReceivedRemoteImpliedPosition = this.remoteImpliedPosition; + if (lastReceivedRemoteImpliedPosition > remoteImpliedPos) { + throw new IllegalStateException( + "Given Remote Implied Position is behind the last received Remote Implied Position"); + } + + this.remoteImpliedPosition = remoteImpliedPos; + + final long previousState = markRemoteImpliedPositionChanged(this); + if (isFinalized(previousState) || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | REMOTE_IMPLIED_POSITION_CHANGED_FLAG); + } + + void drain(long expectedState) { + final Fuseable.QueueSubscription qs = this.framesSubscriber.qs; + final Queue cachedFrames = this.cachedFrames; + + for (; ; ) { + if (hasRemoteImpliedPositionChanged(expectedState)) { + expectedState = handlePendingRemoteImpliedPositionChanges(expectedState, cachedFrames); + } + + if (hasPendingConnection(expectedState)) { + expectedState = handlePendingConnection(expectedState, cachedFrames); + } + + if (isConnected(expectedState)) { + if (isTerminated(expectedState)) { + handleTerminated(qs, this.terminal); + } else if (isDisposed()) { + handleDisposed(); + } else if (hasFrames(expectedState)) { + handlePendingFrames(qs); + } + } + + if (isDisposed(expectedState) || isTerminated(expectedState)) { + clearAndFinalize(this); + return; + } + + expectedState = markWorkDone(this, expectedState); + if (isFinalized(expectedState)) { + return; + } + + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + long handlePendingRemoteImpliedPositionChanges(long expectedState, Queue cachedFrames) { + final long remoteImpliedPosition = this.remoteImpliedPosition; + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toDropFromCache = Math.max(0, remoteImpliedPosition - firstAvailableFramePosition); + + if (toDropFromCache > 0) { + final int droppedFromCache = dropFramesFromCache(toDropFromCache, cachedFrames); + + if (toDropFromCache > droppedFromCache) { + this.terminal = + new IllegalStateException( + String.format( + "Local and remote state disagreement: " + + "need to remove additional %d bytes, but cache is empty", + toDropFromCache)); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + if (toDropFromCache < droppedFromCache) { + this.terminal = + new IllegalStateException( + "Local and remote state disagreement: local and remote frame sizes are not equal"); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + droppedFromCache); + if (this.cacheLimit != Integer.MAX_VALUE) { + this.cacheSize -= droppedFromCache; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Removed frames from cache to position[{}]. CacheSize[{}]", + this.side, + this.session, + this.remoteImpliedPosition, + this.cacheSize); + } + } + } + + return expectedState; + } + + void handlePendingFrames(Fuseable.QueueSubscription qs) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + + if (!isConnected(this.state)) { break; } } - if (removeSize > 0) { - throw new IllegalStateException( - String.format( - "Local and remote state disagreement: " - + "need to remove additional %d bytes, but cache is empty", - removeSize)); - } else if (removeSize < 0) { - throw new IllegalStateException( - "Local and remote state disagreement: " + "local and remote frame sizes are not equal"); - } else { - logger.debug("{} Removed frames. Current cache size: {}", tag, cacheSize); + } + + long handlePendingConnection(long expectedState, Queue cachedFrames) { + CoreSubscriber lastActual = null; + for (; ; ) { + final CoreSubscriber nextActual = this.pendingActual; + + if (nextActual != lastActual) { + for (final ByteBuf frame : cachedFrames) { + nextActual.onNext(frame.retainedSlice()); + } + } + + expectedState = markConnected(this, expectedState); + if (isConnected(expectedState)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Connected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + impliedPosition); + } + + this.actual = nextActual; + break; + } + + if (!hasPendingConnection(expectedState)) { + break; + } + + lastActual = nextActual; } + return expectedState; + } + + static int dropFramesFromCache(long toRemoveBytes, Queue cache) { + int removedBytes = 0; + while (toRemoveBytes > removedBytes && cache.size() > 0) { + final ByteBuf cachedFrame = cache.poll(); + final int frameSize = cachedFrame.readableBytes(); + + cachedFrame.release(); + + removedBytes += frameSize; + } + + return removedBytes; } @Override public Flux resumeStream() { - return Flux.generate( - () -> new ResumeStreamState(cachedFrames.size(), upstreamFrameRefCnt), - (state, sink) -> { - if (state.next()) { - /*spsc queue has no iterator - iterating by consuming*/ - ByteBuf frame = cachedFrames.poll(); - if (state.shouldRetain(frame)) { - frame.retain(); - } - cachedFrames.offer(frame); - sink.next(frame); - } else { - sink.complete(); - logger.debug("{} Resuming stream completed", tag); - } - return state; - }); + return this; } @Override public long framePosition() { - return position; + return this.firstAvailableFramePosition; } @Override public long frameImpliedPosition() { - return impliedPosition; + return this.impliedPosition & Long.MAX_VALUE; } @Override - public void resumableFrameReceived(ByteBuf frame) { - /*called on transport thread so non-atomic on volatile is safe*/ - impliedPosition += frame.readableBytes(); + public boolean resumableFrameReceived(ByteBuf frame) { + final int frameSize = frame.readableBytes(); + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (impliedPosition < 0) { + return false; + } + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition + frameSize)) { + return true; + } + } + } + + void pauseImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition | Long.MIN_VALUE)) { + logger.debug( + "Side[{}]|Session[{}]. Paused at position[{}]", side, session, impliedPosition); + return; + } + } + } + + void resumeImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + final long restoredImpliedPosition = impliedPosition & Long.MAX_VALUE; + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, restoredImpliedPosition)) { + logger.debug( + "Side[{}]|Session[{}]. Resumed at position[{}]", + side, + session, + restoredImpliedPosition); + return; + } + } } @Override public Mono onClose() { - return disposed; + return disposed.asMono(); } @Override public void dispose() { - cacheSize = 0; - ByteBuf frame = cachedFrames.poll(); - while (frame != null) { + final long previousState = markDisposed(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | DISPOSED_FLAG); + } + + void clearCache() { + final Queue frames = this.cachedFrames; + this.cacheSize = 0; + + ByteBuf frame; + while ((frame = frames.poll()) != null) { frame.release(); - frame = cachedFrames.poll(); } - disposed.onComplete(); } @Override public boolean isDisposed() { - return disposed.isTerminated(); + return isDisposed(this.state); } - /* this method and saveFrame() won't be called concurrently, - * so non-atomic on volatile is safe*/ - private int releaseTailFrame(ByteBuf content) { - int frameSize = content.readableBytes(); - cacheSize -= frameSize; - position += frameSize; - content.release(); - return frameSize; + void handleFrame(ByteBuf frame) { + final boolean isResumable = isResumableFrame(frame); + if (isResumable) { + handleResumableFrame(frame); + return; + } + + handleConnectionFrame(frame); } - /*this method and releaseTailFrame() won't be called concurrently, - * so non-atomic on volatile is safe*/ - void saveFrame(ByteBuf frame) { - if (upstreamFrameRefCnt == 0) { - upstreamFrameRefCnt = frame.refCnt(); - } + void handleTerminated(Fuseable.QueueSubscription qs, @Nullable Throwable t) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; - int frameSize = frame.readableBytes(); - long availableSize = cacheLimit - cacheSize; - while (availableSize < frameSize) { - ByteBuf cachedFrame = cachedFrames.poll(); - if (cachedFrame != null) { - availableSize += releaseTailFrame(cachedFrame); - } else { + if (empty) { break; } + + handleFrame(frame); } - if (availableSize >= frameSize) { - cachedFrames.offer(frame.retain()); - cacheSize += frameSize; + if (t != null) { + this.actual.onError(t); } else { - position += frameSize; + this.actual.onComplete(); } } - static class ResumeStreamState { - private final int cacheSize; - private final int expectedRefCnt; - private int cacheCounter; + void handleDisposed() { + this.actual.onError(new CancellationException("Disposed")); + } - public ResumeStreamState(int cacheSize, int expectedRefCnt) { - this.cacheSize = cacheSize; - this.expectedRefCnt = expectedRefCnt; - } + void handleConnectionFrame(ByteBuf frame) { + this.actual.onNext(frame); + } - public boolean next() { - if (cacheCounter < cacheSize) { - cacheCounter++; - return true; + void handleResumableFrame(ByteBuf frame) { + final Queue frames = this.cachedFrames; + final int incomingFrameSize = frame.readableBytes(); + final int cacheLimit = this.cacheLimit; + + final boolean canBeStore; + int cacheSize = this.cacheSize; + if (cacheLimit != Integer.MAX_VALUE) { + final long availableSize = cacheLimit - cacheSize; + + if (availableSize < incomingFrameSize) { + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toRemoveBytes = incomingFrameSize - availableSize; + final int removedBytes = dropFramesFromCache(toRemoveBytes, frames); + + cacheSize = cacheSize - removedBytes; + canBeStore = removedBytes >= toRemoveBytes; + + if (canBeStore) { + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + removedBytes); + } else { + this.cacheSize = cacheSize; + FIRST_AVAILABLE_FRAME_POSITION.lazySet( + this, firstAvailableFramePosition + removedBytes + incomingFrameSize); + } } else { - return false; + canBeStore = true; + } + } else { + canBeStore = true; + } + + if (canBeStore) { + frames.offer(frame); + + if (cacheLimit != Integer.MAX_VALUE) { + this.cacheSize = cacheSize + incomingFrameSize; } } - public boolean shouldRetain(ByteBuf frame) { - return frame.refCnt() == expectedRefCnt; + this.actual.onNext(canBeStore ? frame.retainedSlice() : frame); + } + + @Override + public void request(long n) {} + + @Override + public void cancel() { + pauseImplied(); + markDisconnected(this); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Disconnected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + frameImpliedPosition()); } } - static Queue cachedFramesQueue(int size) { - return Queues.get(size).get(); + @Override + public void subscribe(CoreSubscriber actual) { + resumeImplied(); + actual.onSubscribe(this); + this.pendingActual = actual; + + final long previousState = markPendingConnection(this); + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isTerminated(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | PENDING_CONNECTION_FLAG); } - class FramesSubscriber implements Subscriber { - private final long firstRequestSize; - private final long refillSize; - private int received; - private Subscription s; + static class FramesSubscriber + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final InMemoryResumableFramesStore parent; + + Fuseable.QueueSubscription qs; + + boolean done; - public FramesSubscriber(long requestSize) { - this.firstRequestSize = requestSize; - this.refillSize = firstRequestSize / 2; + FramesSubscriber(CoreSubscriber actual, InMemoryResumableFramesStore parent) { + this.actual = actual; + this.parent = parent; } @Override + @SuppressWarnings("unchecked") public void onSubscribe(Subscription s) { - this.s = s; - s.request(firstRequestSize); + if (Operators.validate(this.qs, s)) { + final Fuseable.QueueSubscription qs = (Fuseable.QueueSubscription) s; + this.qs = qs; + + final int m = qs.requestFusion(Fuseable.ANY); + + if (m != Fuseable.ASYNC) { + s.cancel(); + this.actual.onSubscribe(this); + this.actual.onError(new IllegalStateException("Source has to be ASYNC fuseable")); + return; + } + + this.actual.onSubscribe(this); + } } @Override public void onNext(ByteBuf byteBuf) { - saveFrame(byteBuf); - if (firstRequestSize != Long.MAX_VALUE && ++received == refillSize) { - received = 0; - s.request(refillSize); + final InMemoryResumableFramesStore parent = this.parent; + long previousState = InMemoryResumableFramesStore.markFrameAdded(parent); + + if (isFinalized(previousState)) { + this.qs.clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (isConnected(previousState) || hasPendingConnection(previousState)) { + parent.drain((previousState + 1) | HAS_FRAME_FLAG); } } @Override public void onError(Throwable t) { - logger.info("unexpected onError signal: {}, {}", t.getClass(), t.getMessage()); + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + parent.terminal = t; + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void cancel() { + if (this.done) { + return; + } + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain(previousState | TERMINATED_FLAG); } @Override - public void onComplete() {} + public void request(long n) {} + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public Void poll() { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public void clear() {} + } + + static long markFrameAdded(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isConnected(state) || hasPendingConnection(state) || isWorkInProgress(state)) { + nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? nextState : nextState + 1; + } + + if (STATE.compareAndSet(store, state, nextState | HAS_FRAME_FLAG)) { + return state; + } + } + } + + static long markPendingConnection(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state) || isDisposed(state) || isTerminated(state)) { + return state; + } + + if (isConnected(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : state + 1; + if (STATE.compareAndSet(store, state, nextState | PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markRemoteImpliedPositionChanged(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | REMOTE_IMPLIED_POSITION_CHANGED_FLAG)) { + return state; + } + } + } + + static long markDisconnected(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + if (STATE.compareAndSet(store, state, state & ~CONNECTED_FLAG & ~PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markWorkDone(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state & ~MAX_WORK_IN_PROGRESS & ~REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markConnected(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state ^ PENDING_CONNECTION_FLAG | CONNECTED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markTerminated(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | TERMINATED_FLAG)) { + return state; + } + } + } + + static long markDisposed(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | DISPOSED_FLAG)) { + return state; + } + } + } + + static void clearAndFinalize(InMemoryResumableFramesStore store) { + final Fuseable.QueueSubscription qs = store.framesSubscriber.qs; + for (; ; ) { + final long state = store.state; + + qs.clear(); + store.clearCache(); + + if (isFinalized(state)) { + return; + } + + if (STATE.compareAndSet(store, state, state | FINALIZED_FLAG & ~MAX_WORK_IN_PROGRESS)) { + store.disposed.tryEmitEmpty(); + store.framesSubscriber.onComplete(); + return; + } + } + } + + static boolean isConnected(long state) { + return (state & CONNECTED_FLAG) == CONNECTED_FLAG; + } + + static boolean hasRemoteImpliedPositionChanged(long state) { + return (state & REMOTE_IMPLIED_POSITION_CHANGED_FLAG) == REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + } + + static boolean hasPendingConnection(long state) { + return (state & PENDING_CONNECTION_FLAG) == PENDING_CONNECTION_FLAG; + } + + static boolean hasFrames(long state) { + return (state & HAS_FRAME_FLAG) == HAS_FRAME_FLAG; + } + + static boolean isTerminated(long state) { + return (state & TERMINATED_FLAG) == TERMINATED_FLAG; + } + + static boolean isDisposed(long state) { + return (state & DISPOSED_FLAG) == DISPOSED_FLAG; + } + + static boolean isFinalized(long state) { + return (state & FINALIZED_FLAG) == FINALIZED_FLAG; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WORK_IN_PROGRESS) > 0; } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java index 7ec0abaee..6dd3d5f4d 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java @@ -16,35 +16,10 @@ package io.rsocket.resume; -import io.netty.buffer.ByteBuf; -import io.rsocket.Closeable; -import io.rsocket.DuplexConnection; -import reactor.core.publisher.Mono; +import io.rsocket.keepalive.KeepAliveSupport; +import reactor.core.Disposable; -public interface RSocketSession extends Closeable { +public interface RSocketSession extends Disposable { - ByteBuf token(); - - ResumableDuplexConnection resumableConnection(); - - RSocketSession continueWith(T ConnectionFactory); - - RSocketSession resumeWith(ByteBuf resumeFrame); - - void reconnect(DuplexConnection connection); - - @Override - default Mono onClose() { - return resumableConnection().onClose(); - } - - @Override - default void dispose() { - resumableConnection().dispose(); - } - - @Override - default boolean isDisposed() { - return resumableConnection().isDisposed(); - } + void setKeepAliveSupport(KeepAliveSupport keepAliveSupport); } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java b/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java deleted file mode 100644 index 6553e5ec5..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import reactor.core.publisher.Flux; -import reactor.core.publisher.ReplayProcessor; - -class RequestListener { - private final ReplayProcessor requests = ReplayProcessor.create(1); - - public Flux apply(Flux flux) { - return flux.doOnRequest(requests::onNext); - } - - public Flux requests() { - return requests; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java index 461d71228..c8811b9b3 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,433 +18,430 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Closeable; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; -import java.nio.channels.ClosedChannelException; -import java.time.Duration; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; -import org.reactivestreams.Publisher; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; import reactor.core.Disposable; -import reactor.core.Disposables; -import reactor.core.publisher.*; -import reactor.util.concurrent.Queues; - -public class ResumableDuplexConnection implements DuplexConnection, ResumeStateHolder { - private static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); - private static final Throwable closedChannelException = new ClosedChannelException(); - - private final String tag; - private final ResumableFramesStore resumableFramesStore; - private final Duration resumeStreamTimeout; - private final boolean cleanupOnKeepAlive; - - private final ReplayProcessor connections = ReplayProcessor.create(1); - private final EmitterProcessor connectionErrors = EmitterProcessor.create(); - private volatile DuplexConnection curConnection; - /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ - private final FluxProcessor downStreamFrames = ReplayProcessor.create(0); - private final FluxProcessor resumeSaveFrames = EmitterProcessor.create(); - private final MonoProcessor resumeSaveCompleted = MonoProcessor.create(); - private final Queue actions = Queues.unboundedMultiproducer().get(); - private final AtomicInteger actionsWip = new AtomicInteger(); - private final AtomicBoolean disposed = new AtomicBoolean(); - - private final Mono framesSent; - private final RequestListener downStreamRequestListener = new RequestListener(); - private final RequestListener resumeSaveStreamRequestListener = new RequestListener(); - private final UnicastProcessor> upstreams = UnicastProcessor.create(); - private final UpstreamFramesSubscriber upstreamSubscriber = - new UpstreamFramesSubscriber( - Queues.SMALL_BUFFER_SIZE, - downStreamRequestListener.requests(), - resumeSaveStreamRequestListener.requests(), - this::dispatch); - - private volatile Runnable onResume; - private volatile Runnable onDisconnect; - private volatile int state; - private volatile Disposable resumedStreamDisposable = Disposables.disposed(); +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class ResumableDuplexConnection extends Flux + implements DuplexConnection, Subscription { + + static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); + + final String side; + final String session; + final ResumableFramesStore resumableFramesStore; + + final UnboundedProcessor savableFramesSender; + final Sinks.Empty onQueueClose; + final Sinks.Empty onLastConnectionClose; + final SocketAddress remoteAddress; + final Sinks.Many onConnectionClosedSink; + + CoreSubscriber receiveSubscriber; + FrameReceivingSubscriber activeReceivingSubscriber; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ResumableDuplexConnection.class, "state"); + + volatile DuplexConnection activeConnection; + static final AtomicReferenceFieldUpdater + ACTIVE_CONNECTION = + AtomicReferenceFieldUpdater.newUpdater( + ResumableDuplexConnection.class, DuplexConnection.class, "activeConnection"); + + int connectionIndex = 0; public ResumableDuplexConnection( - String tag, - DuplexConnection duplexConnection, - ResumableFramesStore resumableFramesStore, - Duration resumeStreamTimeout, - boolean cleanupOnKeepAlive) { - this.tag = tag; + String side, + ByteBuf session, + DuplexConnection initialConnection, + ResumableFramesStore resumableFramesStore) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.onConnectionClosedSink = Sinks.unsafe().many().unicast().onBackpressureBuffer(); this.resumableFramesStore = resumableFramesStore; - this.resumeStreamTimeout = resumeStreamTimeout; - this.cleanupOnKeepAlive = cleanupOnKeepAlive; - - resumableFramesStore - .saveFrames(resumeSaveStreamRequestListener.apply(resumeSaveFrames)) - .subscribe(resumeSaveCompleted); - - upstreams.flatMap(Function.identity()).subscribe(upstreamSubscriber); - - framesSent = - connections - .switchMap( - c -> { - logger.debug("Switching transport: {}", tag); - return c.send(downStreamRequestListener.apply(downStreamFrames)) - .doFinally( - s -> - logger.debug( - "{} Transport send completed: {}, {}", tag, s, c.toString())) - .onErrorResume(err -> Mono.never()); - }) - .then() - .cache(); - - reconnect(duplexConnection); - } + this.onQueueClose = Sinks.unsafe().empty(); + this.onLastConnectionClose = Sinks.unsafe().empty(); + this.savableFramesSender = new UnboundedProcessor(onQueueClose::tryEmitEmpty); + this.remoteAddress = initialConnection.remoteAddress(); - @Override - public ByteBufAllocator alloc() { - return curConnection.alloc(); + resumableFramesStore.saveFrames(savableFramesSender).subscribe(); + + ACTIVE_CONNECTION.lazySet(this, initialConnection); } - public void disconnect() { - DuplexConnection c = this.curConnection; - if (c != null) { - disconnect(c); + public boolean connect(DuplexConnection nextConnection) { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE + && ACTIVE_CONNECTION.compareAndSet(this, activeConnection, nextConnection)) { + + if (!activeConnection.isDisposed()) { + activeConnection.sendErrorAndClose( + new ConnectionErrorException("Connection unexpectedly replaced")); + } + + initConnection(nextConnection); + + return true; + } else { + return false; } } - public void onDisconnect(Runnable onDisconnectAction) { - this.onDisconnect = onDisconnectAction; - } + void initConnection(DuplexConnection nextConnection) { + final int nextConnectionIndex = this.connectionIndex + 1; + final FrameReceivingSubscriber frameReceivingSubscriber = + new FrameReceivingSubscriber(side, resumableFramesStore, receiveSubscriber); - public void onResume(Runnable onResumeAction) { - this.onResume = onResumeAction; - } + this.connectionIndex = nextConnectionIndex; + this.activeReceivingSubscriber = frameReceivingSubscriber; - /*reconnected by session after error. After this downstream can receive frames, - * but sending in suppressed until resume() is called*/ - public void reconnect(DuplexConnection connection) { - if (curConnection == null) { - logger.debug("{} Resumable duplex connection started with connection: {}", tag, connection); - state = State.CONNECTED; - onNewConnection(connection); - } else { + if (logger.isDebugEnabled()) { logger.debug( - "{} Resumable duplex connection reconnected with connection: {}", tag, connection); - /*race between sendFrame and doResumeStart may lead to ongoing upstream frames - written before resume complete*/ - dispatch(new ResumeStart(connection)); + "Side[{}]|Session[{}]|DuplexConnection[{}]. Connecting", side, session, connectionIndex); } + + final Disposable resumeStreamSubscription = + resumableFramesStore + .resumeStream() + .subscribe( + f -> nextConnection.sendFrame(FrameHeaderCodec.streamId(f), f), + t -> { + dispose(nextConnection, t); + nextConnection.sendErrorAndClose(new ConnectionErrorException(t.getMessage(), t)); + }, + () -> { + final ConnectionErrorException e = + new ConnectionErrorException("Connection Closed Unexpectedly"); + dispose(nextConnection, e); + nextConnection.sendErrorAndClose(e); + }); + nextConnection.receive().subscribe(frameReceivingSubscriber); + nextConnection + .onClose() + .doFinally( + __ -> { + frameReceivingSubscriber.dispose(); + resumeStreamSubscription.dispose(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Disconnected", + side, + session, + connectionIndex); + } + Sinks.EmitResult result = onConnectionClosedSink.tryEmitNext(nextConnectionIndex); + if (!result.equals(Sinks.EmitResult.OK)) { + logger.error( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Failed to notify session of closed connection: {}", + side, + session, + connectionIndex, + result); + } + }) + .subscribe(); } - /*after receiving RESUME (Server) or RESUME_OK (Client) - calculate and send resume frames */ - public void resume( - long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { - /*race between sendFrame and doResume may lead to duplicate frames on resume store*/ - dispatch(new Resume(remotePos, remoteImpliedPos, resumeFrameSent)); + public void disconnect() { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE && !activeConnection.isDisposed()) { + activeConnection.dispose(); + } } @Override - public Mono sendOne(ByteBuf frame) { - return curConnection.sendOne(frame); + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + savableFramesSender.tryEmitPrioritized(frame); + } else { + savableFramesSender.tryEmitNormal(frame); + } } - @Override - public Mono send(Publisher frames) { - upstreams.onNext(Flux.from(frames)); - return framesSent; + /** + * Publisher for a sequence of integers starting at 1, with each next number emitted when the + * currently active connection is closed and should be resumed. The Publisher never emits an error + * and completes when the connection is disposed and not resumed. + */ + Flux onActiveConnectionClosed() { + return onConnectionClosedSink.asFlux(); } @Override - public Flux receive() { - return connections.switchMap( - c -> - c.receive() - .doOnNext( - f -> { - if (isResumableFrame(f)) { - resumableFramesStore.resumableFrameReceived(f); - } - }) - .onErrorResume(err -> Mono.never())); - } + public void sendErrorAndClose(RSocketErrorException rSocketErrorException) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } - public long position() { - return resumableFramesStore.framePosition(); + savableFramesSender.tryEmitFinal( + ErrorFrameCodec.encode(activeConnection.alloc(), 0, rSocketErrorException)); + + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + + final Throwable cause = rSocketErrorException.getCause(); + if (cause == null) { + onLastConnectionClose.tryEmitEmpty(); + } else { + onLastConnectionClose.tryEmitError(cause); + } + }); } @Override - public long impliedPosition() { - return resumableFramesStore.frameImpliedPosition(); + public Flux receive() { + return this; } @Override - public void onImpliedPosition(long remoteImpliedPos) { - logger.debug("Got remote position from keep-alive: {}", remoteImpliedPos); - if (cleanupOnKeepAlive) { - dispatch(new ReleaseFrames(remoteImpliedPos)); - } + public ByteBufAllocator alloc() { + return activeConnection.alloc(); } @Override public Mono onClose() { - return Flux.merge(connections.last().flatMap(Closeable::onClose), resumeSaveCompleted).then(); + return Mono.whenDelayError( + onQueueClose.asMono(), resumableFramesStore.onClose(), onLastConnectionClose.asMono()); } @Override public void dispose() { - if (disposed.compareAndSet(false, true)) { - logger.debug("Resumable connection disposed: {}, {}", tag, this); - upstreams.onComplete(); - connections.onComplete(); - connectionErrors.onComplete(); - resumeSaveFrames.onComplete(); - curConnection.dispose(); - upstreamSubscriber.dispose(); - resumedStreamDisposable.dispose(); - resumableFramesStore.dispose(); + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }); + } + + void dispose(DuplexConnection nextConnection, @Nullable Throwable e) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; } + savableFramesSender.onComplete(); + nextConnection + .onClose() + .subscribe( + null, + t -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }, + () -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }); } @Override - public double availability() { - return curConnection.availability(); + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onQueueClose.scan(Scannable.Attr.TERMINATED) + || onQueueClose.scan(Scannable.Attr.CANCELLED); } @Override - public boolean isDisposed() { - return disposed.get(); + public SocketAddress remoteAddress() { + return remoteAddress; } - private void sendFrame(ByteBuf f) { - if (disposed.get()) { - f.release(); - return; - } - /*resuming from store so no need to save again*/ - if (state != State.RESUME && isResumableFrame(f)) { - resumeSaveFrames.onNext(f); - } - /*filter frames coming from upstream before actual resumption began, - * to preserve frames ordering*/ - if (state != State.RESUME_STARTED) { - downStreamFrames.onNext(f); + @Override + public void request(long n) { + if (state == 1 && STATE.compareAndSet(this, 1, 2)) { + // happens for the very first time with the initial connection + initConnection(this.activeConnection); } } - Flux connectionErrors() { - return connectionErrors; + @Override + public void cancel() { + dispose(); } - private void dispatch(Object action) { - actions.offer(action); - if (actionsWip.getAndIncrement() == 0) { - do { - Object a = actions.poll(); - if (a instanceof ByteBuf) { - sendFrame((ByteBuf) a); - } else { - ((Runnable) a).run(); - } - } while (actionsWip.decrementAndGet() != 0); + @Override + public void subscribe(CoreSubscriber receiverSubscriber) { + if (state == 0 && STATE.compareAndSet(this, 0, 1)) { + receiveSubscriber = receiverSubscriber; + receiverSubscriber.onSubscribe(this); } } - private void doResumeStart(DuplexConnection connection) { - state = State.RESUME_STARTED; - resumedStreamDisposable.dispose(); - upstreamSubscriber.resumeStart(); - onNewConnection(connection); + static boolean isResumableFrame(ByteBuf frame) { + return FrameHeaderCodec.streamId(frame) != 0; } - private void doResume( - long remotePosition, - long remoteImpliedPosition, - Function, Mono> sendResumeFrame) { - long localPosition = position(); - long localImpliedPosition = impliedPosition(); - - logger.debug("Resumption start"); - logger.debug( - "Resumption states. local: [pos: {}, impliedPos: {}], remote: [pos: {}, impliedPos: {}]", - localPosition, - localImpliedPosition, - remotePosition, - remoteImpliedPosition); - - long remoteImpliedPos = - calculateRemoteImpliedPos( - localPosition, localImpliedPosition, - remotePosition, remoteImpliedPosition); - - Mono impliedPositionOrError; - if (remoteImpliedPos >= 0) { - state = State.RESUME; - releaseFramesToPosition(remoteImpliedPos); - impliedPositionOrError = Mono.just(localImpliedPosition); - } else { - impliedPositionOrError = - Mono.error( - new ResumeStateException( - localPosition, localImpliedPosition, - remotePosition, remoteImpliedPosition)); - } + @Override + public String toString() { + return "ResumableDuplexConnection{" + + "side='" + + side + + '\'' + + ", session='" + + session + + '\'' + + ", remoteAddress=" + + remoteAddress + + ", state=" + + state + + ", activeConnection=" + + activeConnection + + ", connectionIndex=" + + connectionIndex + + '}'; + } + + private static final class DisposedConnection implements DuplexConnection { + + static final DisposedConnection INSTANCE = new DisposedConnection(); + + private DisposedConnection() {} - sendResumeFrame - .apply(impliedPositionOrError) - .doOnSuccess( - v -> { - Runnable r = this.onResume; - if (r != null) { - r.run(); - } - }) - .then( - streamResumedFrames( - resumableFramesStore - .resumeStream() - .timeout(resumeStreamTimeout) - .doFinally(s -> dispatch(new ResumeComplete()))) - .doOnError(err -> dispose())) - .onErrorResume(err -> Mono.empty()) - .subscribe(); - } + @Override + public void dispose() {} - static long calculateRemoteImpliedPos( - long pos, long impliedPos, long remotePos, long remoteImpliedPos) { - if (remotePos <= impliedPos && pos <= remoteImpliedPos) { - return remoteImpliedPos; - } else { - return -1L; + @Override + public Mono onClose() { + return Mono.never(); } - } - private void doResumeComplete() { - logger.debug("Completing resumption"); - state = State.RESUME_COMPLETED; - upstreamSubscriber.resumeComplete(); - } + @Override + public void sendFrame(int streamId, ByteBuf frame) {} - private Mono streamResumedFrames(Flux frames) { - return Mono.create( - s -> { - ResumeFramesSubscriber subscriber = - new ResumeFramesSubscriber( - downStreamRequestListener.requests(), this::dispatch, s::error, s::success); - s.onDispose(subscriber); - resumedStreamDisposable = subscriber; - frames.subscribe(subscriber); - }); - } + @Override + public Flux receive() { + return Flux.never(); + } - private void onNewConnection(DuplexConnection connection) { - curConnection = connection; - connection.onClose().doFinally(v -> disconnect(connection)).subscribe(); - connections.onNext(connection); - } + @Override + public void sendErrorAndClose(RSocketErrorException e) {} - private void disconnect(DuplexConnection connection) { - /*do not report late disconnects on old connection if new one is available*/ - if (curConnection == connection && state != State.DISCONNECTED) { - connection.dispose(); - state = State.DISCONNECTED; - logger.debug( - "{} Inner connection disconnected: {}", - tag, - closedChannelException.getClass().getSimpleName()); - connectionErrors.onNext(closedChannelException); - Runnable r = this.onDisconnect; - if (r != null) { - r.run(); - } + @Override + public ByteBufAllocator alloc() { + return ByteBufAllocator.DEFAULT; } - } - - /*remove frames confirmed by implied pos, - set current pos accordingly*/ - private void releaseFramesToPosition(long remoteImpliedPos) { - resumableFramesStore.releaseFrames(remoteImpliedPos); - } - static boolean isResumableFrame(ByteBuf frame) { - switch (FrameHeaderCodec.nativeFrameType(frame)) { - case REQUEST_CHANNEL: - case REQUEST_STREAM: - case REQUEST_RESPONSE: - case REQUEST_FNF: - case REQUEST_N: - case CANCEL: - case ERROR: - case PAYLOAD: - return true; - default: - return false; + @Override + @SuppressWarnings("ConstantConditions") + public SocketAddress remoteAddress() { + return null; } } - static class State { - static int CONNECTED = 0; - static int RESUME_STARTED = 1; - static int RESUME = 2; - static int RESUME_COMPLETED = 3; - static int DISCONNECTED = 4; - } + private static final class FrameReceivingSubscriber + implements CoreSubscriber, Disposable { - class ResumeStart implements Runnable { - private final DuplexConnection connection; + final ResumableFramesStore resumableFramesStore; + final CoreSubscriber actual; + final String tag; - public ResumeStart(DuplexConnection connection) { - this.connection = connection; + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + FrameReceivingSubscriber.class, Subscription.class, "s"); + + boolean cancelled; + + private FrameReceivingSubscriber( + String tag, ResumableFramesStore store, CoreSubscriber actual) { + this.tag = tag; + this.resumableFramesStore = store; + this.actual = actual; } @Override - public void run() { - doResumeStart(connection); + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } } - } - class Resume implements Runnable { - private final long remotePos; - private final long remoteImpliedPos; - private final Function, Mono> resumeFrameSent; + @Override + public void onNext(ByteBuf frame) { + if (cancelled || s == Operators.cancelledSubscription()) { + return; + } + + if (isResumableFrame(frame)) { + if (resumableFramesStore.resumableFrameReceived(frame)) { + actual.onNext(frame); + } + return; + } - public Resume( - long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { - this.remotePos = remotePos; - this.remoteImpliedPos = remoteImpliedPos; - this.resumeFrameSent = resumeFrameSent; + actual.onNext(frame); } @Override - public void run() { - doResume(remotePos, remoteImpliedPos, resumeFrameSent); + public void onError(Throwable t) { + Operators.set(S, this, Operators.cancelledSubscription()); } - } - - private class ResumeComplete implements Runnable { @Override - public void run() { - doResumeComplete(); + public void onComplete() { + Operators.set(S, this, Operators.cancelledSubscription()); } - } - private class ReleaseFrames implements Runnable { - private final long remoteImpliedPos; - - public ReleaseFrames(long remoteImpliedPos) { - this.remoteImpliedPos = remoteImpliedPos; + @Override + public void dispose() { + cancelled = true; + Operators.terminate(S, this); } @Override - public void run() { - releaseFramesToPosition(remoteImpliedPos); + public boolean isDisposed() { + return cancelled || s == Operators.cancelledSubscription(); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java index 3a30544b6..80d9a36dd 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java @@ -50,6 +50,8 @@ public interface ResumableFramesStore extends Closeable { /** * Received resumable frame as defined by RSocket protocol. Implementation must increment frame * implied position + * + * @return {@code true} if information about the frame has been stored */ - void resumableFrameReceived(ByteBuf frame); + boolean resumableFrameReceived(ByteBuf frame); } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java deleted file mode 100644 index 4facdd3c1..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; - -class ResumeFramesSubscriber implements Subscriber, Disposable { - private final Flux requests; - private final Consumer onNext; - private final Consumer onError; - private final Runnable onComplete; - private final AtomicBoolean disposed = new AtomicBoolean(); - private volatile Disposable requestsDisposable; - private volatile Subscription subscription; - - public ResumeFramesSubscriber( - Flux requests, - Consumer onNext, - Consumer onError, - Runnable onComplete) { - this.requests = requests; - this.onNext = onNext; - this.onError = onError; - this.onComplete = onComplete; - } - - @Override - public void onSubscribe(Subscription s) { - if (isDisposed()) { - s.cancel(); - } else { - this.subscription = s; - this.requestsDisposable = requests.subscribe(s::request); - } - } - - @Override - public void onNext(ByteBuf frame) { - this.onNext.accept(frame); - } - - @Override - public void onError(Throwable t) { - this.onError.accept(t); - requestsDisposable.dispose(); - } - - @Override - public void onComplete() { - this.onComplete.run(); - requestsDisposable.dispose(); - } - - @Override - public void dispose() { - if (disposed.compareAndSet(false, true)) { - if (subscription != null) { - subscription.cancel(); - requestsDisposable.dispose(); - } - } - } - - @Override - public boolean isDisposed() { - return disposed.get(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java index b54ce644f..ad1b38375 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -18,143 +18,284 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.exceptions.RejectedResumeException; -import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.ResumeFrameCodec; import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; import java.time.Duration; -import java.util.function.Function; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.FluxProcessor; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; +import reactor.core.publisher.Operators; +import reactor.util.concurrent.Queues; -public class ServerRSocketSession implements RSocketSession { +public class ServerRSocketSession + implements RSocketSession, ResumeStateHolder, CoreSubscriber { private static final Logger logger = LoggerFactory.getLogger(ServerRSocketSession.class); - private final ResumableDuplexConnection resumableConnection; - /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ - private final FluxProcessor newConnections = - ReplayProcessor.create(0); - private final ByteBufAllocator allocator; - private final ByteBuf resumeToken; + final ResumableDuplexConnection resumableConnection; + final Duration resumeSessionDuration; + final ResumableFramesStore resumableFramesStore; + final String session; + final ByteBufAllocator allocator; + final boolean cleanupStoreOnKeepAlive; + + /** + * All incoming connections with the Resume intent are enqueued in this queue. Such an approach + * ensure that the new connection will affect the resumption state anyhow until the previous + * (active) connection is finally closed + */ + final Queue connectionsQueue; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ServerRSocketSession.class, "wip"); + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ServerRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; public ServerRSocketSession( - DuplexConnection duplexConnection, + ByteBuf session, + ResumableDuplexConnection resumableDuplexConnection, + DuplexConnection initialDuplexConnection, + ResumableFramesStore resumableFramesStore, Duration resumeSessionDuration, - Duration resumeStreamTimeout, - Function resumeStoreFactory, - ByteBuf resumeToken, boolean cleanupStoreOnKeepAlive) { - this.allocator = duplexConnection.alloc(); - this.resumeToken = resumeToken; - this.resumableConnection = - new ResumableDuplexConnection( - "server", - duplexConnection, - resumeStoreFactory.apply(resumeToken), - resumeStreamTimeout, - cleanupStoreOnKeepAlive); - - Mono timeout = - resumableConnection - .connectionErrors() - .flatMap( - err -> { - logger.debug("Starting session timeout due to error", err); - return newConnections - .next() - .doOnNext(c -> logger.debug("Connection after error: {}", c)) - .timeout(resumeSessionDuration); - }) - .then() - .cast(DuplexConnection.class); - - newConnections - .mergeWith(timeout) - .subscribe( - connection -> { - reconnect(connection); - logger.debug("Server ResumableConnection reconnected: {}", connection); - }, - err -> { - logger.debug("Server ResumableConnection reconnect timeout"); - resumableConnection.dispose(); - }); + this.session = session.toString(CharsetUtil.UTF_8); + this.allocator = initialDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.resumableFramesStore = resumableFramesStore; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + this.connectionsQueue = Queues.unboundedMultiproducer().get(); + + WIP.lazySet(this, 1); + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + resumableDuplexConnection.onActiveConnectionClosed().subscribe(__ -> tryTimeoutSession()); } - @Override - public ServerRSocketSession continueWith(DuplexConnection connectionFactory) { - logger.debug("Server continued with connection: {}", connectionFactory); - newConnections.onNext(connectionFactory); - return this; + void tryTimeoutSession() { + keepAliveSupport.stop(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Connection is lost. Trying to timeout the active session", + session); + } + + Mono.delay(resumeSessionDuration).subscribe(this); + + if (WIP.decrementAndGet(this) == 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } } - @Override - public ServerRSocketSession resumeWith(ByteBuf resumeFrame) { - logger.debug("Resume FRAME received"); - long remotePos = remotePos(resumeFrame); - long remoteImpliedPos = remoteImpliedPos(resumeFrame); - resumeFrame.release(); - - resumableConnection.resume( - remotePos, - remoteImpliedPos, - pos -> - pos.flatMap(impliedPos -> sendFrame(ResumeOkFrameCodec.encode(allocator, impliedPos))) - .onErrorResume( - err -> - sendFrame(ErrorFrameCodec.encode(allocator, 0, errorFrameThrowable(err))) - .then(Mono.fromRunnable(resumableConnection::dispose)) - /*Resumption is impossible: no need to return control to ResumableConnection*/ - .then(Mono.never()))); - return this; + public void resumeWith(ByteBuf resumeFrame, DuplexConnection nextDuplexConnection) { + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. New DuplexConnection received.", session); + } + + long remotePos = ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + long remoteImpliedPos = ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + + connectionsQueue.offer(() -> doResume(remotePos, remoteImpliedPos, nextDuplexConnection)); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } } - @Override - public void reconnect(DuplexConnection connection) { - resumableConnection.reconnect(connection); + void doResume(long remotePos, long remoteImpliedPos, DuplexConnection nextDuplexConnection) { + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + return; + } + + long impliedPosition = resumableFramesStore.frameImpliedPosition(); + long position = resumableFramesStore.framePosition(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Resume FRAME received. ServerResumeState[impliedPosition[{}], position[{}]]. ClientResumeState[remoteImpliedPosition[{}], remotePosition[{}]]", + session, + impliedPosition, + position, + remoteImpliedPos, + remotePos); + } + + if (remotePos <= impliedPosition && position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + nextDuplexConnection.sendFrame(0, ResumeOkFrameCodec.encode(allocator, impliedPosition)); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. ResumeOKFrame[impliedPosition[{}]] has been sent", + session, + impliedPosition); + } + } catch (Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + t); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException(t.getMessage(), t); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + // resumableConnection is likely to be disposed at this stage. Thus we have + // nothing to do + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}] and RemotePosition[{}] to be less or equal to LocalImpliedPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position, + remotePos, + impliedPosition); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException( + String.format( + "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", + remotePos, remoteImpliedPos, position, impliedPosition)); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } } @Override - public ResumableDuplexConnection resumableConnection() { - return resumableConnection; + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); } @Override - public ByteBuf token() { - return resumeToken; + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } } - private Mono sendFrame(ByteBuf frame) { - logger.debug("Sending Resume frame: {}", frame); - return resumableConnection.sendOne(frame).onErrorResume(e -> Mono.empty()); + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } } - private static long remotePos(ByteBuf resumeFrame) { - return ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + @Override + public void onNext(Long aLong) { + if (!Operators.terminate(S, this)) { + return; + } + + resumableConnection.dispose(); } - private static long remoteImpliedPos(ByteBuf resumeFrame) { - return ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + @Override + public void onComplete() {} + + @Override + public void onError(Throwable t) {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; } - private static RejectedResumeException errorFrameThrowable(Throwable err) { - String msg; - if (err instanceof ResumeStateException) { - ResumeStateException resumeException = ((ResumeStateException) err); - msg = - String.format( - "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", - resumeException.getRemotePos(), - resumeException.getRemoteImpliedPos(), - resumeException.getLocalPos(), - resumeException.getLocalImpliedPos()); - } else { - msg = String.format("resume_internal_error: %s", err.getMessage()); + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Disposing session", session); } - return new RejectedResumeException(msg); + Operators.terminate(S, this); + resumableConnection.dispose(); + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java index 1d5c23bd6..736d7c77c 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -17,27 +17,36 @@ package io.rsocket.resume; import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.util.annotation.Nullable; public class SessionManager { + static final Logger logger = LoggerFactory.getLogger(SessionManager.class); + private volatile boolean isDisposed; - private final Map sessions = new ConcurrentHashMap<>(); + private final Map sessions = new ConcurrentHashMap<>(); - public ServerRSocketSession save(ServerRSocketSession session) { + public ServerRSocketSession save(ServerRSocketSession session, ByteBuf resumeToken) { if (isDisposed) { session.dispose(); } else { - ByteBuf token = session.token().retain(); + final String token = resumeToken.toString(CharsetUtil.UTF_8); session + .resumableConnection .onClose() - .doOnSuccess( - v -> { + .doFinally( + __ -> { + logger.debug( + "ResumableConnection has been closed. Removing associated session {" + + token + + "}"); if (isDisposed || sessions.get(token) == session) { sessions.remove(token); } - token.release(); }) .subscribe(); ServerRSocketSession prevSession = sessions.remove(token); @@ -51,7 +60,7 @@ public ServerRSocketSession save(ServerRSocketSession session) { @Nullable public ServerRSocketSession get(ByteBuf resumeToken) { - return sessions.get(resumeToken); + return sessions.get(resumeToken.toString(CharsetUtil.UTF_8)); } public void dispose() { diff --git a/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java deleted file mode 100644 index f010a05bd..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; -import reactor.util.concurrent.Queues; - -class UpstreamFramesSubscriber implements Subscriber, Disposable { - private static final Logger logger = LoggerFactory.getLogger(UpstreamFramesSubscriber.class); - - private final AtomicBoolean disposed = new AtomicBoolean(); - private final Consumer itemConsumer; - private final Disposable downstreamRequestDisposable; - private final Disposable resumeSaveStreamDisposable; - - private volatile Subscription subs; - private volatile boolean resumeStarted; - private final Queue framesCache; - private long request; - private long downStreamRequestN; - private long resumeSaveStreamRequestN; - - UpstreamFramesSubscriber( - int estimatedDownstreamRequest, - Flux downstreamRequests, - Flux resumeSaveStreamRequests, - Consumer itemConsumer) { - this.itemConsumer = itemConsumer; - this.framesCache = Queues.unbounded(estimatedDownstreamRequest).get(); - - downstreamRequestDisposable = downstreamRequests.subscribe(requestN -> requestN(0, requestN)); - - resumeSaveStreamDisposable = - resumeSaveStreamRequests.subscribe(requestN -> requestN(requestN, 0)); - } - - @Override - public void onSubscribe(Subscription s) { - this.subs = s; - if (!isDisposed()) { - doRequest(); - } else { - s.cancel(); - } - } - - @Override - public void onNext(ByteBuf item) { - processFrame(item); - } - - @Override - public void onError(Throwable t) { - dispose(); - } - - @Override - public void onComplete() { - dispose(); - } - - public void resumeStart() { - resumeStarted = true; - } - - public void resumeComplete() { - ByteBuf frame = framesCache.poll(); - while (frame != null) { - itemConsumer.accept(frame); - frame = framesCache.poll(); - } - resumeStarted = false; - doRequest(); - } - - @Override - public void dispose() { - if (disposed.compareAndSet(false, true)) { - releaseCache(); - if (subs != null) { - subs.cancel(); - } - resumeSaveStreamDisposable.dispose(); - downstreamRequestDisposable.dispose(); - } - } - - @Override - public boolean isDisposed() { - return disposed.get(); - } - - private void requestN(long resumeStreamRequest, long downStreamRequest) { - synchronized (this) { - downStreamRequestN = Operators.addCap(downStreamRequestN, downStreamRequest); - resumeSaveStreamRequestN = Operators.addCap(resumeSaveStreamRequestN, resumeStreamRequest); - - long requests = Math.min(downStreamRequestN, resumeSaveStreamRequestN); - if (requests > 0) { - downStreamRequestN -= requests; - resumeSaveStreamRequestN -= requests; - logger.debug("Upstream subscriber requestN: {}", requests); - request = Operators.addCap(request, requests); - } - } - doRequest(); - } - - private void doRequest() { - if (subs != null && !resumeStarted) { - synchronized (this) { - long r = request; - if (r > 0) { - subs.request(r); - request = 0; - } - } - } - } - - private void releaseCache() { - ByteBuf frame = framesCache.poll(); - while (frame != null && frame.refCnt() > 0) { - frame.release(); - } - } - - private void processFrame(ByteBuf item) { - if (resumeStarted) { - framesCache.offer(item); - } else { - itemConsumer.accept(item); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index d59b9fe97..08b8b2fb7 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -100,7 +100,7 @@ public static Payload create(ByteBuf data) { public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { try { - return create(data.nioBuffer(), metadata == null ? null : metadata.nioBuffer()); + return create(toBytes(data), metadata != null ? toBytes(metadata) : null); } finally { data.release(); if (metadata != null) { @@ -110,7 +110,16 @@ public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { } public static Payload create(Payload payload) { - return create(payload.getData(), payload.hasMetadata() ? payload.getMetadata() : null); + return create( + toBytes(payload.data()), payload.hasMetadata() ? toBytes(payload.metadata()) : null); + } + + private static byte[] toBytes(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.markReaderIndex(); + byteBuf.readBytes(bytes); + byteBuf.resetReaderIndex(); + return bytes; } @Override diff --git a/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json new file mode 100644 index 000000000..0a3844451 --- /dev/null +++ b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json @@ -0,0 +1,130 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef", + "fields": [ + { + "name": "consumerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef", + "fields": [ + { + "name": "producerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.LinkedQueueNode" + }, + "name": "io.rsocket.internal.jctools.queues.LinkedQueueNode", + "fields": [ + { + "name": "next" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.UnsafeAccess" + }, + "name": "sun.misc.Unsafe", + "fields": [ + { + "name": "theUnsafe" + } + ], + "queriedMethods": [ + { + "name": "getAndAddLong", + "parameterTypes": [ + "java.lang.Object", + "long", + "long" + ] + }, + { + "name": "getAndSetObject", + "parameterTypes": [ + "java.lang.Object", + "long", + "java.lang.Object" + ] + } + ] + } +] \ No newline at end of file diff --git a/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java new file mode 100644 index 000000000..d30f1415e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java @@ -0,0 +1,6 @@ +package io.rsocket; + +public class RaceTestConstants { + public static final int REPEATS = + Integer.parseInt(System.getProperty("rsocket.test.race.repeats", "1000")); +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java index 800e5d678..1db708ab5 100644 --- a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -1,16 +1,29 @@ package io.rsocket.buffer; +import static java.util.concurrent.locks.LockSupport.parkNanos; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created * ByteBuffs */ public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); /** * Allows to instrument any given the instance of ByteBufAllocator @@ -19,24 +32,96 @@ public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { * @return */ public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { - return new LeaksTrackingByteBufAllocator(allocator); + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); } final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); final ByteBufAllocator delegate; - private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; } public LeaksTrackingByteBufAllocator assertHasNoLeaks() { try { - Assertions.assertThat(tracker) - .allSatisfy( - buf -> - Assertions.assertThat(buf) - .matches(bb -> bb.refCnt() == 0, "buffer should be released")); + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); } finally { tracker.clear(); } @@ -150,4 +235,60 @@ T track(T buffer) { return buffer; } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 84b46ea69..310e15b3e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -23,12 +23,10 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; +import java.time.Duration; import org.reactivestreams.Subscriber; -public abstract class AbstractSocketRule extends ExternalResource { +public abstract class AbstractSocketRule { protected TestDuplexConnection connection; protected Subscriber connectSub; @@ -37,43 +35,38 @@ public abstract class AbstractSocketRule extends ExternalReso protected int maxFrameLength = FRAME_LENGTH_MASK; protected int maxInboundPayloadSize = Integer.MAX_VALUE; - @Override - public Statement apply(final Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - connectSub = TestSubscriber.create(); - init(); - base.evaluate(); - } - }; + public void init() { + allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(5), ""); + connectSub = TestSubscriber.create(); + doInit(); } - protected void init() { - if (socket != null) { - socket.dispose(); - } + protected void doInit() { if (connection != null) { connection.dispose(); } + if (socket != null) { + socket.dispose(); + } connection = new TestDuplexConnection(allocator); socket = newRSocket(); } public void setMaxInboundPayloadSize(int maxInboundPayloadSize) { this.maxInboundPayloadSize = maxInboundPayloadSize; - init(); + doInit(); } public void setMaxFrameLength(int maxFrameLength) { this.maxFrameLength = maxFrameLength; - init(); + doInit(); } protected abstract T newRSocket(); - public ByteBufAllocator alloc() { + public LeaksTrackingByteBufAllocator alloc() { return allocator; } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java new file mode 100644 index 000000000..195df9434 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ClientServerInputMultiplexerTest { + private TestDuplexConnection source; + private ClientServerInputMultiplexer clientMultiplexer; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private ClientServerInputMultiplexer serverMultiplexer; + + @BeforeEach + public void setup() { + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); + } + + @Test + public void clientSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + clientMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + clientMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isOne(); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isEqualTo(2); + } + + @Test + public void serverSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + serverMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + serverMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(1); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(2); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(3); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(4); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isEqualTo(4); + } + + private ByteBuf leaseFrame() { + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf errorFrame(int i) { + return ErrorFrameCodec.encode(allocator, i, new Exception()); + } + + private ByteBuf keepAliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf metadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java index d080b166d..84576e6ce 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -1,6 +1,6 @@ package io.rsocket.core; /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,33 +15,31 @@ * limitations under the License. */ -import static io.rsocket.frame.FrameHeaderCodec.frameType; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; - import io.netty.buffer.ByteBuf; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.RSocketProxy; import java.time.Duration; import java.util.ArrayList; import java.util.Collection; -import java.util.List; import java.util.Map; import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.Stream; import org.assertj.core.api.Assertions; @@ -52,18 +50,19 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.runners.model.Statement; +import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.SignalType; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; import reactor.util.context.Context; +import reactor.util.context.ContextView; import reactor.util.retry.Retry; public class DefaultRSocketClientTests { @@ -75,19 +74,33 @@ public void setUp() throws Throwable { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped((t) -> {}); rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.allocator.assertHasNoLeaks(); + } + + @Test + @SuppressWarnings("unchecked") + void discardElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = Mockito.mock(ReferenceCounted.class); + Mockito.when(referenceCounted.release()).thenReturn(true); + Mockito.when(referenceCounted.refCnt()).thenReturn(1); + + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(referenceCounted); + + Mockito.verify(referenceCounted).release(); } static Stream interactions() { @@ -179,19 +192,12 @@ public void shouldSentFrameOnResolution( @MethodSource("interactions") @SuppressWarnings({"unchecked", "rawtypes"}) public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfOnNextAndCancel( - BiFunction, Publisher> request, FrameType requestType) - throws Throwable { + BiFunction, Publisher> request, FrameType requestType) { Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); Payload payload = ByteBufPayload.create("test", "testMetadata"); TestPublisher testPublisher = TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); @@ -241,19 +247,12 @@ public void evaluate() {} @MethodSource("interactions") @SuppressWarnings({"unchecked", "rawtypes"}) public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfRequestAndCancel( - BiFunction, Publisher> request, FrameType requestType) - throws Throwable { + BiFunction, Publisher> request, FrameType requestType) { Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); ByteBuf dataBuffer = rule.allocator.buffer(); dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -311,14 +310,17 @@ public void shouldPropagateDownstreamContext( Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); AssertSubscriber assertSubscriber = new AssertSubscriber(Context.of("test", "test")); - Context[] receivedContext = new Context[1]; + ContextView[] receivedContext = new Context[1]; Publisher publisher = request.apply( rule.client, Mono.just(payload) .mergeWith( - Mono.subscriberContext() - .doOnNext(c -> receivedContext[0] = c) + Mono.deferContextual( + c -> { + receivedContext[0] = c; + return Mono.empty(); + }) .then(Mono.empty()))); publisher.subscribe(assertSubscriber); @@ -455,6 +457,8 @@ public void shouldBeAbleToResolveOriginalSource() { assertSubscriber1.assertTerminated().assertValueCount(1); Assertions.assertThat(assertSubscriber1.values()).isEqualTo(assertSubscriber.values()); + + rule.allocator.assertHasNoLeaks(); } @Test @@ -478,19 +482,173 @@ public void shouldDisposeOriginalSource() { .assertErrorMessage("Disposed"); Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldReceiveOnCloseNotificationOnDisposeOriginalSource() { + Sinks.Empty onCloseDelayer = Sinks.empty(); + ClientSocketRule rule = + new ClientSocketRule() { + @Override + protected RSocket newRSocket() { + return new RSocketProxy(super.newRSocket()) { + @Override + public Mono onClose() { + return super.onClose().and(onCloseDelayer.asMono()); + } + }; + } + }; + rule.init(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber onCloseSubscriber = AssertSubscriber.create(); + + rule.client.onClose().subscribe(onCloseSubscriber); + onCloseSubscriber.assertNotTerminated(); + + onCloseDelayer.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); } @Test - public void shouldDisposeOriginalSourceIfRacing() throws Throwable { - for (int i = 0; i < 10000; i++) { + public void shouldResolveOnStartSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldNotStartIfAlreadyDisposed() { + Assertions.assertThat(rule.client.connect()).isTrue(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.delayer.run(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldBeRestartedIfSourceWasClosed() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + AssertSubscriber terminateSubscriber = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.client.onClose().subscribe(terminateSubscriber); + + rule.delayer.run(); + + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.socket.dispose(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + terminateSubscriber.assertNotTerminated(); + Assertions.assertThat(rule.client.isDisposed()).isFalse(); + + rule.connection = new TestDuplexConnection(rule.allocator); + rule.socket = rule.newRSocket(); + rule.producer = Sinks.one(); + + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber2); + + rule.delayer.run(); + + assertSubscriber2.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + terminateSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSourceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + + rule.init(); AssertSubscriber assertSubscriber = AssertSubscriber.create(); rule.client.source().subscribe(assertSubscriber); @@ -510,29 +668,79 @@ public void evaluate() {} .assertTerminated() .assertError(CancellationException.class) .assertErrorMessage("Disposed"); + + ByteBuf buf; + while ((buf = rule.connection.pollFrame()) != null) { + FrameAssert.assertThat(buf).hasStreamIdZero().hasData("Disposed").hasNoLeaks(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @Test + public void shouldStartOriginalSourceOnceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> rule.client.source().subscribe(assertSubscriber), () -> rule.client.connect()); + + Assertions.assertThat(rule.producer.currentSubscriberCount()).isOne(); + + rule.delayer.run(); + + assertSubscriber.assertTerminated(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + assertSubscriber1.assertTerminated().assertComplete(); + + rule.allocator.assertHasNoLeaks(); } } - public static class ClientSocketRule extends AbstractSocketRule { + public static class ClientSocketRule extends AbstractSocketRule { protected RSocketClient client; protected Runnable delayer; - protected MonoProcessor producer; + protected Sinks.One producer; + + protected Sinks.Empty thisClosedSink; @Override - protected void init() { - super.init(); - delayer = () -> producer.onNext(socket); - producer = MonoProcessor.create(); + protected void doInit() { + super.doInit(); + delayer = () -> producer.tryEmitValue(socket); + producer = Sinks.one(); client = new DefaultRSocketClient( - producer - .doOnCancel(() -> socket.dispose()) - .doOnDiscard(Disposable.class, Disposable::dispose)); + Mono.defer( + () -> + producer + .asMono() + .doOnCancel(() -> socket.dispose()) + .doOnDiscard(Disposable.class, Disposable::dispose))); } @Override - protected RSocketRequester newRSocket() { + protected RSocket newRSocket() { + this.thisClosedSink = Sinks.empty(); return new RSocketRequester( connection, PayloadDecoder.ZERO_COPY, @@ -543,24 +751,10 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, - RequesterLeaseHandler.None); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (ByteBuf frame : connection.getSent()) { - FrameType frameType = frameType(frame); - if (frameType == expectedFrameType) { - return FrameHeaderCodec.streamId(frame); - } - framesFound.add(frameType); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); + __ -> null, + null, + thisClosedSink, + thisClosedSink.asMono()); } } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java index cb5044e17..f5422a4bf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -14,7 +14,8 @@ import io.rsocket.Payload; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.Arrays; @@ -46,7 +47,9 @@ public static void setUp() { @ParameterizedTest @MethodSource("frameSent") public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Payload payload = genericPayload(activeStreams.getAllocator()); final FireAndForgetRequesterMono fireAndForgetRequesterMono = new FireAndForgetRequesterMono(payload, activeStreams); @@ -62,8 +65,7 @@ public void frameShouldBeSentOnSubscription(Consumer // should not add anything to map stateAssert.isTerminated(); activeStreams.assertNoActiveStreams(); - - final ByteBuf frame = activeStreams.getSendProcessor().poll(); + final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -77,8 +79,12 @@ public void frameShouldBeSentOnSubscription(Consumer .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); } /** @@ -93,7 +99,7 @@ public void frameFragmentsShouldBeSentOnSubscription( final int mtu = 64; final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(mtu); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); - final UnboundedProcessor sender = streamManager.getSendProcessor(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); final byte[] metadata = new byte[65]; final byte[] data = new byte[129]; @@ -118,7 +124,7 @@ public void frameFragmentsShouldBeSentOnSubscription( Assertions.assertThat(payload.refCnt()).isZero(); - final ByteBuf frameFragment1 = sender.poll(); + final ByteBuf frameFragment1 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment1) .isNotNull() .hasPayloadSize( @@ -132,7 +138,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment2 = sender.poll(); + final ByteBuf frameFragment2 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment2) .isNotNull() .hasPayloadSize( @@ -146,7 +152,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment3 = sender.poll(); + final ByteBuf frameFragment3 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment3) .isNotNull() .hasPayloadSize( @@ -159,7 +165,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment4 = sender.poll(); + final ByteBuf frameFragment4 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment4) .isNotNull() .hasPayloadSize(35) @@ -189,9 +195,11 @@ static Stream> frameSent() { @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") public void shouldErrorOnIncorrectRefCntInGivenPayload( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); - final UnboundedProcessor sender = streamManager.getSendProcessor(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = ByteBufPayload.create(""); payload.release(); @@ -210,6 +218,9 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); } static Stream> @@ -233,9 +244,11 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Consumer monoConsumer) { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); - final UnboundedProcessor sender = streamManager.getSendProcessor(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); final byte[] metadata = new byte[FRAME_LENGTH_MASK]; final byte[] data = new byte[FRAME_LENGTH_MASK]; @@ -260,6 +273,12 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); } static Stream> @@ -289,10 +308,12 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( @ParameterizedTest @MethodSource("shouldErrorIfNoAvailabilitySource") public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); final TestRequesterResponderSupport streamManager = - TestRequesterResponderSupport.client(new RuntimeException("test")); + TestRequesterResponderSupport.client(exception, testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); - final UnboundedProcessor sender = streamManager.getSendProcessor(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); final Payload payload = genericPayload(allocator); final FireAndForgetRequesterMono fireAndForgetRequesterMono = @@ -311,6 +332,7 @@ public void shouldErrorIfNoAvailability(Consumer mon streamManager.assertNoActiveStreams(); Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); } static Stream> shouldErrorIfNoAvailabilitySource() { @@ -333,9 +355,11 @@ static Stream> shouldErrorIfNoAvailabilityS /** Ensures single subscription happens in case of racing */ @Test public void shouldSubscribeExactlyOnce1() { - final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); - final UnboundedProcessor sender = streamManager.getSendProcessor(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); for (int i = 1; i < 50000; i += 2) { final Payload payload = ByteBufPayload.create("testData", "testMetadata"); @@ -349,7 +373,7 @@ public void shouldSubscribeExactlyOnce1() { () -> RaceTestUtils.race( () -> { - AtomicReference atomicReference = new AtomicReference(); + AtomicReference atomicReference = new AtomicReference<>(); fireAndForgetRequesterMono.subscribe(null, atomicReference::set); Throwable throwable = atomicReference.get(); if (throwable != null) { @@ -364,7 +388,7 @@ public void shouldSubscribeExactlyOnce1() { return true; }); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -380,6 +404,27 @@ public void shouldSubscribeExactlyOnce1() { stateAssert.isTerminated(); streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); } Assertions.assertThat(sender.isEmpty()).isTrue(); @@ -391,7 +436,6 @@ public void checkName() { final TestRequesterResponderSupport testRequesterResponderSupport = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); - final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); final Payload payload = ByteBufPayload.create("testData", "testMetadata"); final FireAndForgetRequesterMono fireAndForgetRequesterMono = diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index 78f7bff66..5be59235c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -23,25 +23,29 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.KeepAliveFrameCodec; -import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.RSocketSession; import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.util.DefaultPayload; import java.time.Duration; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import reactor.core.Disposable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.scheduler.VirtualTimeScheduler; @@ -66,19 +70,23 @@ static RSocketState requester(int tickPeriod, int timeout) { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection connection = new TestDuplexConnection(allocator); + Sinks.Empty empty = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( connection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, tickPeriod, timeout, - new DefaultKeepAliveHandler(connection), - RequesterLeaseHandler.None); - return new RSocketState(rSocket, allocator, connection); + new DefaultKeepAliveHandler(), + r -> null, + null, + empty, + empty.asMono()); + return new RSocketState(rSocket, allocator, connection, empty); } static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { @@ -88,24 +96,30 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { ResumableDuplexConnection resumableConnection = new ResumableDuplexConnection( "test", + Unpooled.EMPTY_BUFFER, connection, - new InMemoryResumableFramesStore("test", 10_000), - Duration.ofSeconds(10), - false); + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 10_000)); + Sinks.Empty onClose = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( resumableConnection, - DefaultPayload::create, + PayloadDecoder.ZERO_COPY, StreamIdSupplier.clientSupplier(), 0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, tickPeriod, timeout, - new ResumableKeepAliveHandler(resumableConnection), - RequesterLeaseHandler.None); - return new ResumableRSocketState(rSocket, connection, resumableConnection, allocator); + new ResumableKeepAliveHandler( + resumableConnection, + Mockito.mock(RSocketSession.class), + Mockito.mock(ResumeStateHolder.class)), + __ -> null, + null, + onClose, + onClose.asMono()); + return new ResumableRSocketState(rSocket, connection, resumableConnection, onClose, allocator); } @Test @@ -145,11 +159,14 @@ void noKeepAlivesSentAfterRSocketDispose() { requesterState.rSocket().dispose(); Duration duration = Duration.ofMillis(500); - StepVerifier.create(Flux.from(requesterState.connection().getSentAsPublisher()).take(duration)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectComplete() - .verify(Duration.ofSeconds(1)); + virtualTimeScheduler.advanceTimeBy(duration); + + FrameAssert.assertThat(requesterState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(requesterState.connection.pollFrame()).isNull(); requesterState.allocator.assertHasNoLeaks(); } @@ -178,17 +195,18 @@ void clientRequesterSendsKeepAlives() { RSocketState RSocketState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); TestDuplexConnection connection = RSocketState.connection(); - StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(3)) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(this::keepAliveFrameWithRespondFlag) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); RSocketState.rSocket.dispose(); + FrameAssert.assertThat(connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); RSocketState.connection.dispose(); RSocketState.allocator.assertHasNoLeaks(); @@ -206,13 +224,17 @@ void requesterRespondsToKeepAlives() { KeepAliveFrameCodec.encode( rSocketState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); - StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(1)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectNextMatches(this::keepAliveFrameWithoutRespondFlag) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(duration); + FrameAssert.assertThat(connection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .matches(this::keepAliveFrameWithoutRespondFlag); rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); rSocketState.connection.dispose(); rSocketState.allocator.assertHasNoLeaks(); @@ -227,11 +249,9 @@ void resumableRequesterNoKeepAlivesAfterDisconnect() { resumableDuplexConnection.disconnect(); - Duration duration = Duration.ofMillis(500); - StepVerifier.create(Flux.from(testConnection.getSentAsPublisher()).take(duration)) - .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) - .expectComplete() - .verify(Duration.ofSeconds(5)); + Duration duration = Duration.ofMillis(KEEP_ALIVE_INTERVAL * 5); + virtualTimeScheduler.advanceTimeBy(duration); + Assertions.assertThat(testConnection.pollFrame()).isNull(); rSocketState.rSocket.dispose(); rSocketState.connection.dispose(); @@ -246,17 +266,28 @@ void resumableRequesterKeepAlivesAfterReconnect() { ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); resumableDuplexConnection.disconnect(); TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); - resumableDuplexConnection.reconnect(newTestConnection); - resumableDuplexConnection.resume(0, 0, ignored -> Mono.empty()); + resumableDuplexConnection.connect(newTestConnection); + // resumableDuplexConnection.(0, 0, ignored -> Mono.empty()); - StepVerifier.create(Flux.from(newTestConnection.getSentAsPublisher()).take(1)) - .then(() -> virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL))) - .expectNextMatches(frame -> keepAliveFrame(frame) && frame.release()) - .expectComplete() - .verify(Duration.ofSeconds(5)); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + + FrameAssert.assertThat(newTestConnection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .hasStreamIdZero() + .hasNoLeaks(); rSocketState.rSocket.dispose(); - rSocketState.connection.dispose(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") // API limitations + .hasNoLeaks(); + newTestConnection.dispose(); rSocketState.allocator.assertHasNoLeaks(); } @@ -273,7 +304,17 @@ void resumableRequesterNoKeepAlivesAfterDispose() { .verify(Duration.ofSeconds(5)); rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); rSocketState.connection.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") + .hasNoLeaks(); rSocketState.allocator.assertHasNoLeaks(); } @@ -314,12 +355,17 @@ static class RSocketState { private final RSocket rSocket; private final TestDuplexConnection connection; private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; public RSocketState( - RSocket rSocket, LeaksTrackingByteBufAllocator allocator, TestDuplexConnection connection) { + RSocket rSocket, + LeaksTrackingByteBufAllocator allocator, + TestDuplexConnection connection, + Sinks.Empty onClose) { this.rSocket = rSocket; this.connection = connection; this.allocator = allocator; + this.onClose = onClose; } public TestDuplexConnection connection() { @@ -340,15 +386,18 @@ static class ResumableRSocketState { private final TestDuplexConnection connection; private final ResumableDuplexConnection resumableDuplexConnection; private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; public ResumableRSocketState( RSocket rSocket, TestDuplexConnection connection, ResumableDuplexConnection resumableDuplexConnection, + Sinks.Empty onClose, LeaksTrackingByteBufAllocator allocator) { this.rSocket = rSocket; this.connection = connection; this.resumableDuplexConnection = resumableDuplexConnection; + this.onClose = onClose; this.allocator = allocator; } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java index 468a13505..7cf12a81e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java @@ -1,29 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; import io.netty.util.ReferenceCounted; import io.rsocket.ConnectionSetupPayload; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.ArrayList; import java.util.List; -import org.assertj.core.api.Assertions; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.test.StepVerifier; +import reactor.util.retry.Retry; public class RSocketConnectorTest { + @ParameterizedTest + @ValueSource(strings = {"KEEPALIVE", "REQUEST_RESPONSE"}) + public void unexpectedFramesBeforeResumeOKFrame(String frameType) { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create() + .resume(new Resume().retry(Retry.indefinitely())) + .connect(transport) + .block(); + + final TestDuplexConnection duplexConnection = transport.testConnection(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + FrameAssert.assertThat(duplexConnection.pollFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + + duplexConnection.dispose(); + + final TestDuplexConnection duplexConnection2 = transport.testConnection(); + + final ByteBuf frame; + switch (frameType) { + case "KEEPALIVE": + frame = + KeepAliveFrameCodec.encode(duplexConnection2.alloc(), false, 1, Unpooled.EMPTY_BUFFER); + break; + case "REQUEST_RESPONSE": + default: + frame = + RequestResponseFrameCodec.encode( + duplexConnection2.alloc(), 2, false, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + duplexConnection2.addToReceivedBuffer(frame); + + StepVerifier.create(duplexConnection2.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .typeOf(FrameType.RESUME) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("RESUME_OK frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + + transport.alloc().assertHasNoLeaks(); + } + @Test public void ensuresThatSetupPayloadCanBeRetained() { - MonoProcessor retainedSetupPayload = MonoProcessor.create(); + AtomicReference retainedSetupPayload = new AtomicReference<>(); TestClientTransport transport = new TestClientTransport(); ByteBuf data = transport.alloc().buffer(); @@ -34,13 +117,13 @@ public void ensuresThatSetupPayloadCanBeRetained() { .setupPayload(ByteBufPayload.create(data)) .acceptor( (setup, sendingSocket) -> { - retainedSetupPayload.onNext(setup.retain()); + retainedSetupPayload.set(setup.retain()); return Mono.just(new RSocket() {}); }) .connect(transport) .block(); - Assertions.assertThat(transport.testConnection().getSent()) + assertThat(transport.testConnection().getSent()) .hasSize(1) .first() .matches( @@ -55,17 +138,10 @@ public void ensuresThatSetupPayloadCanBeRetained() { return buf.refCnt() == 1; }); - retainedSetupPayload - .as(StepVerifier::create) - .expectNextMatches( - setup -> { - String dataUtf8 = setup.getDataUtf8(); - return "data".equals(dataUtf8) && setup.release(); - }) - .expectComplete() - .verify(Duration.ofSeconds(5)); - - Assertions.assertThat(retainedSetupPayload.peek().refCnt()).isZero(); + ConnectionSetupPayload setup = retainedSetupPayload.get(); + String dataUtf8 = setup.getDataUtf8(); + assertThat("data".equals(dataUtf8) && setup.release()).isTrue(); + assertThat(setup.refCnt()).isZero(); transport.alloc().assertHasNoLeaks(); } @@ -73,8 +149,13 @@ public void ensuresThatSetupPayloadCanBeRetained() { @Test public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions() { Payload setupPayload = ByteBufPayload.create("TestData", "TestMetadata"); + assertThat(setupPayload.refCnt()).isOne(); - Assertions.assertThat(setupPayload.refCnt()).isOne(); + // Keep the data and metadata around so we can try changing them independently + ByteBuf dataBuf = setupPayload.data(); + ByteBuf metadataBuf = setupPayload.metadata(); + dataBuf.retain(); + metadataBuf.retain(); TestClientTransport testClientTransport = new TestClientTransport(); Mono connectionMono = @@ -86,31 +167,59 @@ public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions .expectComplete() .verify(Duration.ofMillis(100)); + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch(ReferenceCounted::release); + connectionMono .as(StepVerifier::create) .expectNextCount(1) .expectComplete() .verify(Duration.ofMillis(100)); - Assertions.assertThat(testClientTransport.testConnection().getSent()) - .hasSize(2) + // Changing the original data and metadata should not impact the SetupPayload + dataBuf.writerIndex(dataBuf.readerIndex()); + dataBuf.writeChar('d'); + dataBuf.release(); + + metadataBuf.writerIndex(metadataBuf.readerIndex()); + metadataBuf.writeChar('m'); + metadataBuf.release(); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) .allMatch( bb -> { DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); return payload.getDataUtf8().equals("TestData") && payload.getMetadataUtf8().equals("TestMetadata"); }) - .allMatch(ReferenceCounted::release); - Assertions.assertThat(setupPayload.refCnt()).isZero(); + .allMatch( + byteBuf -> { + System.out.println("calling release " + byteBuf.refCnt()); + return byteBuf.release(); + }); + assertThat(setupPayload.refCnt()).isZero(); + + testClientTransport.alloc().assertHasNoLeaks(); } @Test public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { List saved = new ArrayList<>(); + AtomicLong subscriptions = new AtomicLong(); Mono setupPayloadMono = Mono.create( sink -> { - Payload payload = ByteBufPayload.create("TestData", "TestMetadata"); + final long subscriptionN = subscriptions.getAndIncrement(); + Payload payload = + ByteBufPayload.create("TestData" + subscriptionN, "TestMetadata" + subscriptionN); saved.add(payload); sink.success(payload); }); @@ -125,29 +234,41 @@ public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { .expectComplete() .verify(Duration.ofMillis(100)); + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData0") + && payload.getMetadataUtf8().equals("TestMetadata0"); + }) + .allMatch(ReferenceCounted::release); + connectionMono .as(StepVerifier::create) .expectNextCount(1) .expectComplete() .verify(Duration.ofMillis(100)); - Assertions.assertThat(testClientTransport.testConnection().getSent()) - .hasSize(2) + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) .allMatch( bb -> { DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); - return payload.getDataUtf8().equals("TestData") - && payload.getMetadataUtf8().equals("TestMetadata"); + return payload.getDataUtf8().equals("TestData1") + && payload.getMetadataUtf8().equals("TestMetadata1"); }) .allMatch(ReferenceCounted::release); - Assertions.assertThat(saved) + assertThat(saved) .as("Metadata and data were consumed and released as slices") .allMatch( payload -> payload.refCnt() == 1 && payload.data().refCnt() == 0 && payload.metadata().refCnt() == 0); + + testClientTransport.alloc().assertHasNoLeaks(); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 32bae9270..a461833d3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,12 @@ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameType.*; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.LEASE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.SETUP; import static org.assertj.core.data.Offset.offset; import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; @@ -27,21 +32,24 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; -import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCounted; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.LeaseFrameCodec; import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.*; +import io.rsocket.lease.Lease; import io.rsocket.lease.MissingLeaseException; import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestClientTransport; @@ -52,12 +60,11 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.ArrayList; import java.util.Collection; -import java.util.Optional; import java.util.function.BiFunction; import java.util.stream.Stream; import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -65,24 +72,28 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; import org.reactivestreams.Publisher; -import reactor.core.publisher.EmitterProcessor; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; class RSocketLeaseTest { private static final String TAG = "test"; private RSocket rSocketRequester; - private ResponderLeaseHandler responderLeaseHandler; + private ResponderLeaseTracker responderLeaseTracker; private LeaksTrackingByteBufAllocator byteBufAllocator; private TestDuplexConnection connection; private RSocketResponder rSocketResponder; private RSocket mockRSocketHandler; - private EmitterProcessor leaseSender = EmitterProcessor.create(); - private Flux leaseReceiver; - private RequesterLeaseHandler requesterLeaseHandler; + private Sinks.Many leaseSender = Sinks.many().multicast().onBackpressureBuffer(); + private RequesterLeaseTracker requesterLeaseTracker; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; @BeforeEach void setUp() { @@ -90,10 +101,10 @@ void setUp() { byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); connection = new TestDuplexConnection(byteBufAllocator); - requesterLeaseHandler = new RequesterLeaseHandler.Impl(TAG, leases -> leaseReceiver = leases); - responderLeaseHandler = - new ResponderLeaseHandler.Impl<>( - TAG, byteBufAllocator, stats -> leaseSender, Optional.empty()); + requesterLeaseTracker = new RequesterLeaseTracker(TAG, 0); + responderLeaseTracker = new ResponderLeaseTracker(TAG, connection, () -> leaseSender.asFlux()); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); ClientServerInputMultiplexer multiplexer = new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); @@ -108,7 +119,10 @@ void setUp() { 0, 0, null, - requesterLeaseHandler); + __ -> null, + requesterLeaseTracker, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); mockRSocketHandler = mock(RSocket.class); when(mockRSocketHandler.metadataPush(any())) @@ -145,7 +159,25 @@ void setUp() { Publisher payloadPublisher = a.getArgument(0); return Flux.from(payloadPublisher) .doOnNext(ReferenceCounted::release) - .thenMany(Flux.empty()); + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); }); rSocketResponder = @@ -153,10 +185,17 @@ void setUp() { multiplexer.asServerConnection(), mockRSocketHandler, payloadDecoder, - responderLeaseHandler, + responderLeaseTracker, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + } + + @AfterEach + void tearDownAndCheckForLeaks() { + byteBufAllocator.assertHasNoLeaks(); } @Test @@ -184,18 +223,26 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); Assertions.assertThat(Exceptions.from(0, error).getMessage()) .isEqualTo("lease is not supported"); + error.release(); + connection.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void clientRSocketFactorySetsLeaseFlag() { TestClientTransport clientTransport = new TestClientTransport(); - RSocketConnector.create().lease(Leases::new).connect(clientTransport).block(); - - Collection sent = clientTransport.testConnection().getSent(); - Assertions.assertThat(sent).hasSize(1); - ByteBuf setup = sent.iterator().next(); - Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); - Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + try { + RSocketConnector.create().lease().connect(clientTransport).block(); + Collection sent = clientTransport.testConnection().getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf setup = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + setup.release(); + } finally { + clientTransport.testConnection().dispose(); + clientTransport.alloc().assertHasNoLeaks(); + } } @ParameterizedTest @@ -218,7 +265,7 @@ void requesterMissingLeaseRequestsAreRejected( void requesterPresentLeaseRequestsAreAccepted( BiFunction> interaction, FrameType frameType) { ByteBuf frame = leaseFrame(5_000, 2, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); Assertions.assertThat(rSocketRequester.availability()).isCloseTo(1.0, offset(1e-2)); ByteBuf buffer = byteBufAllocator.buffer(); @@ -270,13 +317,13 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); ByteBuf leaseFrame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(leaseFrame); + requesterLeaseTracker.handleLeaseFrame(leaseFrame); - double initialAvailability = requesterLeaseHandler.availability(); + double initialAvailability = requesterLeaseTracker.availability(); Publisher request = interaction.apply(rSocketRequester, payload1); // ensures that lease is not used until the frame is sent - Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseHandler.availability()); + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); Assertions.assertThat(connection.getSent()).hasSize(0); AssertSubscriber assertSubscriber = AssertSubscriber.create(0); @@ -285,7 +332,7 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( // if request is FNF, then request frame is sent on subscribe // otherwise we need to make request(1) if (interactionType != REQUEST_FNF) { - Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseHandler.availability()); + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); Assertions.assertThat(connection.getSent()).hasSize(0); assertSubscriber.request(1); @@ -330,7 +377,7 @@ void requesterDepletedAllowedLeaseRequestsAreRejected( void requesterExpiredLeaseRequestsAreRejected( BiFunction> interaction) { ByteBuf frame = leaseFrame(50, 1, Unpooled.EMPTY_BUFFER); - requesterLeaseHandler.receive(frame); + requesterLeaseTracker.handleLeaseFrame(frame); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -349,39 +396,99 @@ void requesterExpiredLeaseRequestsAreRejected( @Test void requesterAvailabilityRespectsTransport() { - requesterLeaseHandler.receive(leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER)); - double unavailable = 0.0; - connection.setAvailability(unavailable); - Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + ByteBuf frame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + try { + + requesterLeaseTracker.handleLeaseFrame(frame); + double unavailable = 0.0; + connection.setAvailability(unavailable); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + } finally { + frame.release(); + } } @ParameterizedTest - @MethodSource("interactions") - void responderMissingLeaseRequestsAreRejected( - BiFunction> interaction) { + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - StepVerifier.create(interaction.apply(rSocketResponder, payload1)) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderPresentLeaseRequestsAreAccepted( - BiFunction> interaction, FrameType frameType) { - leaseSender.onNext(Lease.create(5_000, 2)); + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux.from(interaction.apply(rSocketResponder, payload1)) - .as(StepVerifier::create) - .expectComplete() - .verify(Duration.ofSeconds(5)); + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } switch (frameType) { case REQUEST_FNF: @@ -399,47 +506,119 @@ void responderPresentLeaseRequestsAreAccepted( } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest - @MethodSource("interactions") - void responderDepletedAllowedLeaseRequestsAreRejected( - BiFunction> interaction) { - leaseSender.onNext(Lease.create(5_000, 1)); + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); Payload payload1 = ByteBufPayload.create(buffer); - Flux responder = Flux.from(interaction.apply(rSocketResponder, payload1)); - responder.subscribe(); + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } Assertions.assertThat(connection.getSent()) - .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) .matches(ReferenceCounted::release); - ByteBuf buffer2 = byteBufAllocator.buffer(); - buffer2.writeCharSequence("test", CharsetUtil.UTF_8); - Payload payload2 = ByteBufPayload.create(buffer2); + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); - Flux.from(interaction.apply(rSocketResponder, payload2)) - .as(StepVerifier::create) - .expectError(MissingLeaseException.class) - .verify(Duration.ofSeconds(5)); + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); } @ParameterizedTest @MethodSource("interactions") void expiredLeaseRequestsAreRejected(BiFunction> interaction) { - leaseSender.onNext(Lease.create(50, 1)); + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(50), 1)); ByteBuf buffer = byteBufAllocator.buffer(); buffer.writeCharSequence("test", CharsetUtil.UTF_8); @@ -468,7 +647,7 @@ void sendLease() { metadata.writeCharSequence(metadataContent, utf8); int ttl = 5_000; int numberOfRequests = 2; - leaseSender.onNext(Lease.create(5_000, 2, metadata)); + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2, metadata)); ByteBuf leaseFrame = connection @@ -478,36 +657,41 @@ void sendLease() { .findFirst() .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); - Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); - Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); - Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) - .isEqualTo(metadataContent); + try { + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) + .isEqualTo(metadataContent); + } finally { + leaseFrame.release(); + } } - @Test - void receiveLease() { - Collection receivedLeases = new ArrayList<>(); - leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); - - ByteBuf metadata = byteBufAllocator.buffer(); - Charset utf8 = StandardCharsets.UTF_8; - String metadataContent = "test"; - metadata.writeCharSequence(metadataContent, utf8); - int ttl = 5_000; - int numberOfRequests = 2; - - ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); - - connection.addToReceivedBuffer(leaseFrame); - - Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); - Lease receivedLease = receivedLeases.iterator().next(); - Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); - Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); - Assertions.assertThat(receivedLease.getMetadata().toString(utf8)).isEqualTo(metadataContent); - - ReferenceCountUtil.safeRelease(leaseFrame); - } + // @Test + // void receiveLease() { + // Collection receivedLeases = new ArrayList<>(); + // leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); + // + // ByteBuf metadata = byteBufAllocator.buffer(); + // Charset utf8 = StandardCharsets.UTF_8; + // String metadataContent = "test"; + // metadata.writeCharSequence(metadataContent, utf8); + // int ttl = 5_000; + // int numberOfRequests = 2; + // + // ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); + // + // connection.addToReceivedBuffer(leaseFrame); + // + // Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); + // Lease receivedLease = receivedLeases.iterator().next(); + // Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); + // + // Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); + // Assertions.assertThat(receivedLease.metadata().toString(utf8)).isEqualTo(metadataContent); + // + // ReferenceCountUtil.safeRelease(leaseFrame); + // } ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); @@ -529,4 +713,12 @@ static Stream interactions() { (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), FrameType.REQUEST_CHANNEL)); } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java index 34810b6bd..966fd65f2 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -15,10 +15,13 @@ */ package io.rsocket.core; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ClientTransport; import java.io.UncheckedIOException; import java.time.Duration; @@ -49,27 +52,44 @@ public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws Interrupt RSocket rSocket1 = rSocketMono.block(); RSocket rSocket2 = rSocketMono.block(); - Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket1).isEqualTo(rSocket2); testClientTransport[0].testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); testClientTransport[0] = new TestClientTransport(); RSocket rSocket3 = rSocketMono.block(); RSocket rSocket4 = rSocketMono.block(); - Assertions.assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket3.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); } @Test - @SuppressWarnings({"rawtype", "unchecked"}) + @SuppressWarnings({"rawtype"}) public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); Mockito.when(transport.connect()) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) - .thenReturn(new TestClientTransport().connect()); + .thenReturn(transport1.connect()); Mono rSocketMono = RSocketConnector.create() .reconnect( @@ -81,25 +101,35 @@ public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono RSocket rSocket1 = rSocketMono.block(); RSocket rSocket2 = rSocketMono.block(); - Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + assertThat(rSocket1).isEqualTo(rSocket2); assertRetries( UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class); + + FrameAssert.assertThat(transport1.testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + transport1.testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + transport1.alloc().assertHasNoLeaks(); } @Test - @SuppressWarnings({"rawtype", "unchecked"}) + @SuppressWarnings({"rawtype"}) public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); Mockito.when(transport.connect()) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) .thenThrow(UncheckedIOException.class) - .thenReturn(new TestClientTransport().connect()); + .thenReturn(transport1.connect()); Mono rSocketMono = RSocketConnector.create() .reconnect( @@ -121,27 +151,48 @@ public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSo UncheckedIOException.class, UncheckedIOException.class, UncheckedIOException.class); + + transport1.alloc().assertHasNoLeaks(); } @Test public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { - - Mono rSocketMono = RSocketConnector.connectWith(new TestClientTransport()); + TestClientTransport transport = new TestClientTransport(); + Mono rSocketMono = RSocketConnector.connectWith(transport); RSocket rSocket1 = rSocketMono.block(); + TestDuplexConnection connection1 = transport.testConnection(); + + FrameAssert.assertThat(connection1.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + RSocket rSocket2 = rSocketMono.block(); + TestDuplexConnection connection2 = transport.testConnection(); + + assertThat(rSocket1).isNotEqualTo(rSocket2); + + FrameAssert.assertThat(connection2.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); - Assertions.assertThat(rSocket1).isNotEqualTo(rSocket2); + connection1.dispose(); + connection2.dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + rSocket2.onClose().block(Duration.ofSeconds(1)); + transport.alloc().assertHasNoLeaks(); } @SafeVarargs private final void assertRetries(Class... exceptions) { - assertEquals(exceptions.length, retries.size()); + assertThat(retries.size()).isEqualTo(exceptions.length); int index = 0; for (Iterator it = retries.iterator(); it.hasNext(); ) { Retry.RetrySignal retryContext = it.next(); - assertEquals(index, retryContext.totalRetries()); - assertEquals(exceptions[index], retryContext.failure().getClass()); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); index++; } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index fda6b61ee..01eb998c7 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameHeaderCodec; @@ -28,7 +29,6 @@ import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; import java.util.Arrays; @@ -38,6 +38,7 @@ import java.util.function.Function; import java.util.stream.Stream; import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -45,6 +46,7 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.util.RaceTestUtils; class RSocketRequesterSubscribersTest { @@ -61,11 +63,20 @@ class RSocketRequesterSubscribersTest { private LeaksTrackingByteBufAllocator allocator; private RSocket rSocketRequester; private TestDuplexConnection connection; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @AfterEach + void tearDownAndCheckNoLeaks() { + allocator.assertHasNoLeaks(); + } @BeforeEach void setUp() { allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); connection = new TestDuplexConnection(allocator); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); rSocketRequester = new RSocketRequester( connection, @@ -77,11 +88,15 @@ void setUp() { 0, 0, null, - RequesterLeaseHandler.None); + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } @ParameterizedTest @MethodSource("allInteractions") + @SuppressWarnings({"rawtypes", "unchecked"}) void singleSubscriber(Function> interaction, FrameType requestType) { Flux response = Flux.from(interaction.apply(rSocketRequester)); @@ -98,7 +113,11 @@ void singleSubscriber(Function> interaction, FrameType req assertSubscriberA.assertTerminated(); assertSubscriberB.assertTerminated(); - Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); + FrameAssert.assertThat(connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + FrameAssert.assertThat(connection.pollFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } } @ParameterizedTest diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java index de6f86c57..5cfa76a1c 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -1,45 +1,59 @@ package io.rsocket.core; +import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.frame.FrameType; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; import java.util.function.Function; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -@RunWith(Parameterized.class) public class RSocketRequesterTerminationTest { - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - private Function> interaction; + public final ClientSocketRule rule = new ClientSocketRule(); - public RSocketRequesterTerminationTest(Function> interaction) { - this.interaction = interaction; + @BeforeEach + public void setup() { + rule.init(); } - @Test - public void testCurrentStreamIsTerminatedOnConnectionClose() { - RSocketRequester rSocket = rule.socket; + @AfterEach + public void tearDownAndCheckNoLeaks() { + rule.assertHasNoLeaks(); + } - Mono.delay(Duration.ofSeconds(1)).doOnNext(v -> rule.connection.dispose()).subscribe(); + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testCurrentStreamIsTerminatedOnConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; StepVerifier.create(interaction.apply(rSocket)) + .then( + () -> { + FrameAssert.assertThat(rule.connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + }) + .then(() -> rule.connection.dispose()) .expectError(ClosedChannelException.class) .verify(Duration.ofSeconds(5)); } - @Test - public void testSubsequentStreamIsTerminatedAfterConnectionClose() { + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testSubsequentStreamIsTerminatedAfterConnectionClose( + FrameType requestType, Function> interaction) { RSocketRequester rSocket = rule.socket; rule.connection.dispose(); @@ -48,14 +62,51 @@ public void testSubsequentStreamIsTerminatedAfterConnectionClose() { .verify(Duration.ofSeconds(5)); } - @Parameterized.Parameters - public static Iterable>> rsocketInteractions() { + public static Iterable rsocketInteractions() { EmptyPayload payload = EmptyPayload.INSTANCE; - Publisher payloadStream = Flux.just(payload); - Function> resp = rSocket -> rSocket.requestResponse(payload); - Function> stream = rSocket -> rSocket.requestStream(payload); - Function> channel = rSocket -> rSocket.requestChannel(payloadStream); + Arguments resp = + Arguments.of( + FrameType.REQUEST_RESPONSE, + new Function>() { + @Override + public Mono apply(RSocket rSocket) { + return rSocket.requestResponse(payload); + } + + @Override + public String toString() { + return "Request Response"; + } + }); + Arguments stream = + Arguments.of( + FrameType.REQUEST_STREAM, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestStream(payload); + } + + @Override + public String toString() { + return "Request Stream"; + } + }); + Arguments channel = + Arguments.of( + FrameType.REQUEST_CHANNEL, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestChannel(Flux.never().startWith(payload)); + } + + @Override + public String toString() { + return "Request Channel"; + } + }); return Arrays.asList(resp, stream, channel); } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java index a0b3ef3f2..a1199f698 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,11 +25,15 @@ import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; import static io.rsocket.frame.FrameHeaderCodec.frameType; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameType.*; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -45,6 +49,7 @@ import io.rsocket.Payload; import io.rsocket.PayloadAssert; import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.exceptions.CustomRSocketException; @@ -62,7 +67,6 @@ import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; @@ -78,7 +82,6 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Stream; -import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -90,17 +93,15 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.Scannable; import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; @@ -114,26 +115,21 @@ public void setUp() throws Throwable { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped((t) -> {}); rule = new ClientSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); } @Test @Timeout(2_000) public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); rule.assertHasNoLeaks(); } @@ -151,19 +147,21 @@ protected void hookOnSubscribe(Subscription subscription) { }; stream.subscribe(subscriber); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); subscriber.request(5); List sent = new ArrayList<>(rule.connection.getSent()); - assertThat("sent frame count", sent.size(), is(1)); + assertThat(sent.size()).describedAs("sent frame count").isEqualTo(1); ByteBuf f = sent.get(0); - assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); - assertThat("initial request n", RequestStreamFrameCodec.initialRequestN(f), is(5L)); - assertThat("should be released", f.release(), is(true)); + assertThat(frameType(f)).describedAs("initial frame").isEqualTo(REQUEST_STREAM); + assertThat(RequestStreamFrameCodec.initialRequestN(f)) + .describedAs("initial request n") + .isEqualTo(5L); + assertThat(f.release()).describedAs("should be released").isEqualTo(true); rule.assertHasNoLeaks(); } @@ -172,7 +170,7 @@ protected void hookOnSubscribe(Subscription subscription) { public void testHandleSetupException() { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); - Assertions.assertThatThrownBy(() -> rule.socket.onClose().block()) + assertThatThrownBy(() -> rule.socket.onClose().block()) .isInstanceOf(RejectedSetupException.class); rule.assertHasNoLeaks(); } @@ -191,7 +189,7 @@ public void testHandleApplicationException() { verify(responseSub).onError(any(ApplicationErrorException.class)); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) // requestResponseFrame .hasSize(1) .allMatch(ReferenceCounted::release); @@ -212,7 +210,7 @@ public void testHandleValidFrame() { rule.alloc(), streamId, EmptyPayload.INSTANCE)); verify(sub).onComplete(); - Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -228,10 +226,13 @@ public void testRequestReplyWithCancel() { List sent = new ArrayList<>(rule.connection.getSent()); - assertThat( - "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); - Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + assertThat(frameType(sent.get(0))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(REQUEST_RESPONSE); + assertThat(frameType(sent.get(1))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(CANCEL); + assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -261,11 +262,11 @@ public void testRequestReplyErrorOnSend() { @Test @Timeout(2_000) public void testChannelRequestCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - Flux request = Flux.never().doOnCancel(cancelled::onComplete); + Sinks.Empty cancelled = Sinks.empty(); + Flux request = Flux.never().doOnCancel(cancelled::tryEmitEmpty); rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); @@ -275,36 +276,39 @@ public void testChannelRequestCancellation() { @Test @Timeout(2_000) public void testChannelRequestCancellation2() { - MonoProcessor cancelled = MonoProcessor.create(); + Sinks.Empty cancelled = Sinks.empty(); Flux request = - Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete); + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::tryEmitEmpty); rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @Test public void testChannelRequestServerSideCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - UnicastProcessor request = UnicastProcessor.create(); - request.onNext(EmptyPayload.INSTANCE); - rule.socket.requestChannel(request).subscribe(cancelled); + Sinks.One cancelled = Sinks.one(); + Sinks.Many request = Sinks.many().unicast().onBackpressureBuffer(); + request.tryEmitNext(EmptyPayload.INSTANCE); + rule.socket + .requestChannel(request.asFlux()) + .subscribe(cancelled::tryEmitValue, cancelled::tryEmitError, cancelled::tryEmitEmpty); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); - Flux.first( - cancelled, + Flux.firstWithSignal( + cancelled.asMono(), Flux.error(new IllegalStateException("Channel request not cancelled")) .delaySubscription(Duration.ofSeconds(1))) .blockFirst(); - Assertions.assertThat(request.isDisposed()).isTrue(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(request.scan(Scannable.Attr.TERMINATED) || request.scan(Scannable.Attr.CANCELLED)) + .isTrue(); + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_CHANNEL) @@ -314,7 +318,7 @@ public void testChannelRequestServerSideCancellation() { @Test public void testCorrectFrameOrder() { - MonoProcessor delayer = MonoProcessor.create(); + Sinks.One delayer = Sinks.one(); BaseSubscriber subscriber = new BaseSubscriber() { @Override @@ -322,26 +326,25 @@ protected void hookOnSubscribe(Subscription subscription) {} }; rule.socket .requestChannel( - Flux.concat(Flux.just(0).delayUntil(i -> delayer), Flux.range(1, 999)) + Flux.concat(Flux.just(0).delayUntil(i -> delayer.asMono()), Flux.range(1, 999)) .map(i -> DefaultPayload.create(i + ""))) .subscribe(subscriber); subscriber.request(1); subscriber.request(Long.MAX_VALUE); - delayer.onComplete(); + delayer.tryEmitEmpty(); Iterator iterator = rule.connection.getSent().iterator(); ByteBuf initialFrame = iterator.next(); - Assertions.assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); - Assertions.assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)) - .isEqualTo(Long.MAX_VALUE); - Assertions.assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)).isEqualTo(Long.MAX_VALUE); + assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) .isEqualTo("0"); - Assertions.assertThat(initialFrame.release()).isTrue(); + assertThat(initialFrame.release()).isTrue(); - Assertions.assertThat(iterator.hasNext()).isFalse(); + assertThat(iterator.hasNext()).isFalse(); rule.assertHasNoLeaks(); } @@ -362,7 +365,7 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen .expectSubscription() .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .isInstanceOf(IllegalArgumentException.class) .hasMessage( String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) @@ -371,6 +374,65 @@ public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmen }); } + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation1( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + assertThatThrownBy( + () -> { + final Publisher source = + generator.apply(rule.socket, DefaultPayload.create(data, metadata)); + + if (source instanceof Mono) { + ((Mono) source).block(); + } else { + ((Flux) source).blockLast(); + } + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength)); + + rule.assertHasNoLeaks(); + }); + } + + @Test + public void shouldRejectCallOfNoMetadataPayload() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + StepVerifier.create(rule.socket.metadataPush(payload)) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present")) + .verify(); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldRejectCallOfNoMetadataPayloadBlocking() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + + assertThatThrownBy(() -> rule.socket.metadataPush(payload).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present"); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + static Stream>> prepareCalls() { return Stream.of( RSocket::fireAndForget, @@ -404,11 +466,11 @@ static Stream>> prepareCalls() { }) .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .isInstanceOf(IllegalArgumentException.class) .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .verify(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) // expect to be sent RequestChannelFrame // expect to be sent CancelFrame .hasSize(2) @@ -421,20 +483,10 @@ static Stream>> prepareCalls() { public void checkNoLeaksOnRacing( Function> initiator, BiConsumer, ClientSocketRule> runner) { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ClientSocketRule clientSocketRule = new ClientSocketRule(); - try { - clientSocketRule - .apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); - } catch (Throwable throwable) { - throwable.printStackTrace(); - } + + clientSocketRule.init(); Publisher payloadP = initiator.apply(clientSocketRule); AssertSubscriber assertSubscriber = AssertSubscriber.create(0); @@ -447,8 +499,7 @@ public void evaluate() {} runner.accept(assertSubscriber, clientSocketRule); - Assertions.assertThat(clientSocketRule.connection.getSent()) - .allMatch(ReferenceCounted::release); + assertThat(clientSocketRule.connection.getSent()).allMatch(ReferenceCounted::release); clientSocketRule.assertHasNoLeaks(); } @@ -509,8 +560,8 @@ private static Stream racingCases() { RaceTestUtils.race(() -> as.request(1), as::cancel); // ensures proper frames order if (rule.connection.getSent().size() > 0) { - Assertions.assertThat(rule.connection.getSent()).hasSize(2); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()).hasSize(2); + assertThat(rule.connection.getSent()) .element(0) .matches( bb -> frameType(bb) == REQUEST_STREAM, @@ -519,7 +570,7 @@ private static Stream racingCases() { + "} but was {" + frameType(rule.connection.getSent().stream().findFirst().get()) + "}"); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches( bb -> frameType(bb) == CANCEL, @@ -556,8 +607,8 @@ private static Stream racingCases() { int size = rule.connection.getSent().size(); if (size > 0) { - Assertions.assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); - Assertions.assertThat(rule.connection.getSent()) + assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); + assertThat(rule.connection.getSent()) .element(0) .matches( bb -> frameType(bb) == REQUEST_CHANNEL, @@ -567,7 +618,7 @@ private static Stream racingCases() { + frameType(rule.connection.getSent().stream().findFirst().get()) + "}"); if (size == 2) { - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches( bb -> frameType(bb) == CANCEL, @@ -578,7 +629,7 @@ private static Stream racingCases() { rule.connection.getSent().stream().skip(1).findFirst().get()) + "}"); } else { - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches( bb -> frameType(bb) == COMPLETE || frameType(bb) == CANCEL, @@ -590,7 +641,7 @@ private static Stream racingCases() { + frameType( rule.connection.getSent().stream().skip(1).findFirst().get()) + "}"); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(2) .matches( bb -> frameType(bb) == CANCEL || frameType(bb) == COMPLETE, @@ -715,20 +766,25 @@ private static Stream racingCases() { @Test public void simpleOnDiscardRequestChannelTest() { AssertSubscriber assertSubscriber = AssertSubscriber.create(1); - TestPublisher testPublisher = TestPublisher.create(); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); - Flux payloadFlux = rule.socket.requestChannel(testPublisher); + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); payloadFlux.subscribe(assertSubscriber); - testPublisher.next( - ByteBufPayload.create("d", "m"), - ByteBufPayload.create("d1", "m1"), - ByteBufPayload.create("d2", "m2")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); assertSubscriber.cancel(); - Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); rule.assertHasNoLeaks(); } @@ -737,22 +793,29 @@ public void simpleOnDiscardRequestChannelTest() { public void simpleOnDiscardRequestChannelTest2() { ByteBufAllocator allocator = rule.alloc(); AssertSubscriber assertSubscriber = AssertSubscriber.create(1); - TestPublisher testPublisher = TestPublisher.create(); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); - Flux payloadFlux = rule.socket.requestChannel(testPublisher); + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); payloadFlux.subscribe(assertSubscriber); - testPublisher.next(ByteBufPayload.create("d", "m")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); - testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode( allocator, streamId, new CustomRSocketException(0x00000404, "test"))); - Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); rule.assertHasNoLeaks(); } @@ -770,7 +833,7 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( switch (frameType) { case REQUEST_FNF: response = - testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty())); + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p)).then(Mono.empty()); break; case REQUEST_RESPONSE: response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); @@ -786,7 +849,7 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( } response.subscribe(assertSubscriber); - testPublisher.next(ByteBufPayload.create("d")); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d"))); int streamId = rule.getStreamIdForRequestType(frameType); @@ -820,21 +883,21 @@ public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( } for (int i = 1; i < framesCnt; i++) { - testPublisher.next(ByteBufPayload.create("d" + i)); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d" + i))); } - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) .hasSize(framesCnt) .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) .allMatch(ByteBuf::release); - Assertions.assertThat(assertSubscriber.isTerminated()) + assertThat(assertSubscriber.isTerminated()) .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) .isTrue(); - Assertions.assertThat(assertSubscriber.values()) + assertThat(assertSubscriber.values()) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames received", frameType, responsesCnt) @@ -858,7 +921,10 @@ static Stream encodeDecodePayloadCases() { @MethodSource("refCntCases") public void ensureSendsErrorOnIllegalRefCntPayload( BiFunction> sourceProducer) { - Payload invalidPayload = ByteBufPayload.create("test", "test"); + Payload invalidPayload = + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "test"), + ByteBufUtil.writeUtf8(rule.alloc(), "test")); invalidPayload.release(); Publisher source = sourceProducer.apply(invalidPayload, rule); @@ -876,7 +942,8 @@ private static Stream>> refCn (p, clientSocketRule) -> clientSocketRule.socket.requestChannel(Mono.just(p)), (p, clientSocketRule) -> { Flux.from(clientSocketRule.connection.getSentAsPublisher()) - .filter(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_CHANNEL) + .filter(bb -> frameType(bb) == REQUEST_CHANNEL) + .doOnDiscard(ByteBuf.class, ReferenceCounted::release) .subscribe( bb -> { clientSocketRule.connection.addToReceivedBuffer( @@ -897,12 +964,12 @@ public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { Payload payload2 = ByteBufPayload.create("abc2"); Mono fnf2 = rule.socket.fireAndForget(payload2); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); // checks that fnf2 should have id 1 even though it was generated later than fnf1 AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); voidAssertSubscriber2.assertTerminated().assertNoError(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_FNF) @@ -920,7 +987,7 @@ public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { // checks that fnf1 should have id 3 even though it was generated earlier AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); voidAssertSubscriber1.assertTerminated().assertNoError(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> frameType(bb) == REQUEST_FNF) @@ -944,7 +1011,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( Payload payload2 = ByteBufPayload.create("abc2"); Publisher interaction2 = interaction.apply(rule, payload2); - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); interaction1.subscribe(assertSubscriber1); @@ -953,12 +1020,12 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( assertSubscriber1.assertNotTerminated().assertNoError(); assertSubscriber2.assertNotTerminated().assertNoError(); // even though we subscribed, nothing should happen until the first requestN - Assertions.assertThat(rule.connection.getSent()).isEmpty(); + assertThat(rule.connection.getSent()).isEmpty(); // first request on the second interaction to ensure that stream id issuing on the first request assertSubscriber2.request(1); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) @@ -987,7 +1054,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( .matches(ReferenceCounted::release); if (frameType == REQUEST_CHANNEL) { - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches(bb -> frameType(bb) == COMPLETE) .matches( @@ -1001,7 +1068,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( rule.connection.clearSendReceiveBuffers(); assertSubscriber1.request(1); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) .first() .matches(bb -> frameType(bb) == frameType) @@ -1030,7 +1097,7 @@ public void ensuresThatNoOpsMustHappenUntilFirstRequestN( .matches(ReferenceCounted::release); if (frameType == REQUEST_CHANNEL) { - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .element(1) .matches(bb -> frameType(bb) == COMPLETE) .matches( @@ -1068,7 +1135,7 @@ public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( FrameType interactionType2) { Assumptions.assumeThat(interactionType1).isNotEqualTo(METADATA_PUSH); Assumptions.assumeThat(interactionType2).isNotEqualTo(METADATA_PUSH); - for (int i = 1; i < 10000; i += 4) { + for (int i = 1; i < RaceTestConstants.REPEATS; i += 4) { Payload payload = DefaultPayload.create("test", "test"); Publisher publisher1 = interaction1.apply(rule, payload); Publisher publisher2 = interaction2.apply(rule, payload); @@ -1076,7 +1143,7 @@ public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( () -> publisher1.subscribe(AssertSubscriber.create()), () -> publisher2.subscribe(AssertSubscriber.create())); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .extracting(FrameHeaderCodec::streamId) .containsExactly(i, i + 2); rule.connection.getSent().forEach(bb -> bb.release()); @@ -1153,7 +1220,7 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( BiFunction> interaction2, FrameType interactionType1, FrameType interactionType2) { - for (int i = 1; i < 10000; i++) { + for (int i = 1; i < RaceTestConstants.REPEATS; i++) { Payload payload1 = ByteBufPayload.create("test", "test"); Payload payload2 = ByteBufPayload.create("test", "test"); AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); @@ -1162,12 +1229,8 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( Publisher publisher2 = interaction2.apply(rule, payload2); RaceTestUtils.race( () -> rule.socket.dispose(), - () -> - RaceTestUtils.race( - () -> publisher1.subscribe(assertSubscriber1), - () -> publisher2.subscribe(assertSubscriber2), - Schedulers.parallel()), - Schedulers.parallel()); + () -> publisher1.subscribe(assertSubscriber1), + () -> publisher2.subscribe(assertSubscriber2)); assertSubscriber1.await().assertTerminated(); if (interactionType1 != REQUEST_FNF && interactionType1 != METADATA_PUSH) { @@ -1192,16 +1255,15 @@ public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( } } - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.connection.getSent().clear(); - Assertions.assertThat(payload1.refCnt()).isZero(); - Assertions.assertThat(payload2.refCnt()).isZero(); + assertThat(payload1.refCnt()).isZero(); + assertThat(payload2.refCnt()).isZero(); } } @Test - @Disabled("Reactor 3.4.0 should fix that. No need to do anything on our side") // see https://github.com/rsocket/rsocket-java/issues/858 public void testWorkaround858() { ByteBuf buffer = rule.alloc().buffer(); @@ -1212,13 +1274,13 @@ public void testWorkaround858() { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_RESPONSE) .matches(ByteBuf::release); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); rule.assertHasNoLeaks(); } @@ -1279,7 +1341,7 @@ void reassembleMetadata( .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) .assertNext( responsePayload -> { - PayloadAssert.assertThat(requestPayload).isEqualTo(metadataOnlyPayload).hasNoLeaks(); + PayloadAssert.assertThat(responsePayload).isEqualTo(metadataOnlyPayload).hasNoLeaks(); metadataOnlyPayload.release(); }) .thenCancel() @@ -1362,9 +1424,60 @@ public void errorFragmentTooSmall( leaksTrackingByteBufAllocator.assertHasNoLeaks(); } + @ParameterizedTest + @ValueSource(strings = {"stream", "channel"}) + // see https://github.com/rsocket/rsocket-java/issues/959 + public void testWorkaround959(String type) { + for (int i = 1; i < 20000; i += 2) { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(3); + if (type.equals("stream")) { + rule.socket.requestStream(ByteBufPayload.create(buffer)).subscribe(assertSubscriber); + } else if (type.equals("channel")) { + rule.socket + .requestChannel(Flux.just(ByteBufPayload.create(buffer))) + .subscribe(assertSubscriber); + } + + final ByteBuf payloadFrame = + PayloadFrameCodec.encode( + rule.alloc(), i, false, false, true, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame); + }, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + assertSubscriber.values().forEach(ReferenceCountUtil::safeRelease); + assertSubscriber.assertNoError(); + + rule.connection.clearSendReceiveBuffers(); + rule.assertHasNoLeaks(); + } + } + public static class ClientSocketRule extends AbstractSocketRule { + + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + @Override protected RSocketRequester newRSocket() { + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); return new RSocketRequester( connection, PayloadDecoder.ZERO_COPY, @@ -1375,11 +1488,16 @@ protected RSocketRequester newRSocket() { Integer.MAX_VALUE, Integer.MAX_VALUE, null, - RequesterLeaseHandler.None); + (__) -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); + assertThat(connection.getSent().size()) + .describedAs("Unexpected frames sent.") + .isGreaterThanOrEqualTo(1); List framesFound = new ArrayList<>(); for (ByteBuf frame : connection.getSent()) { FrameType frameType = frameType(frame); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java index de7e48d64..4f689e396 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,10 +34,7 @@ import static io.rsocket.frame.FrameType.REQUEST_N; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; import static io.rsocket.frame.FrameType.REQUEST_STREAM; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.anyOf; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; +import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -49,6 +46,7 @@ import io.rsocket.Payload; import io.rsocket.PayloadAssert; import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.CancelFrameCodec; import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameHeaderCodec; @@ -62,7 +60,8 @@ import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; import io.rsocket.util.ByteBufPayload; @@ -74,7 +73,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Stream; -import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -86,7 +84,6 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; -import org.junit.runners.model.Statement; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -95,10 +92,8 @@ import reactor.core.publisher.FluxSink; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; -import reactor.core.scheduler.Scheduler; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Sinks; import reactor.test.publisher.TestPublisher; import reactor.test.util.RaceTestUtils; @@ -107,43 +102,39 @@ public class RSocketResponderTest { ServerSocketRule rule; @BeforeEach - public void setUp() throws Throwable { + public void setUp() { Hooks.onNextDropped(ReferenceCountUtil::safeRelease); Hooks.onErrorDropped(t -> {}); rule = new ServerSocketRule(); - rule.apply( - new Statement() { - @Override - public void evaluate() {} - }, - null) - .evaluate(); + rule.init(); } @AfterEach public void tearDown() { Hooks.resetOnErrorDropped(); Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); } @Test @Timeout(2_000) @Disabled - public void testHandleKeepAlive() throws Exception { + public void testHandleKeepAlive() { rule.connection.addToReceivedBuffer( KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); - ByteBuf sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); + ByteBuf sent = rule.connection.awaitFrame(); + assertThat(frameType(sent)) + .describedAs("Unexpected frame sent.") + .isEqualTo(FrameType.KEEPALIVE); /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - KeepAliveFrameCodec.respondFlag(sent), - is(false)); + assertThat(KeepAliveFrameCodec.respondFlag(sent)) + .describedAs("Unexpected keep-alive frame respond flag.") + .isEqualTo(false); } @Test @Timeout(2_000) - public void testHandleResponseFrameNoError() throws Exception { + public void testHandleResponseFrameNoError() { final int streamId = 4; rule.connection.clearSendReceiveBuffers(); final TestPublisher testPublisher = TestPublisher.create(); @@ -156,21 +147,20 @@ public Mono requestResponse(Payload payload) { }); rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); testPublisher.complete(); - assertThat( - "Unexpected frame sent.", - frameType(rule.connection.awaitSend()), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); + FrameAssert.assertThat(rule.connection.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); testPublisher.assertWasNotCancelled(); } @Test @Timeout(2_000) - @Disabled - public void testHandlerEmitsError() throws Exception { + public void testHandlerEmitsError() { final int streamId = 4; + rule.prefetch = 1; rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat( - "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("Request-Stream not implemented.") + .hasNoLeaks(); } @Test @@ -189,12 +179,12 @@ public Mono requestResponse(Payload payload) { }); rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); rule.assertHasNoLeaks(); } @@ -250,7 +240,7 @@ protected void hookOnSubscribe(Subscription subscription) { for (Runnable runnable : runnables) { rule.connection.clearSendReceiveBuffers(); runnable.run(); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) @@ -260,7 +250,7 @@ protected void hookOnSubscribe(Subscription subscription) { .contains(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) .matches(ReferenceCounted::release); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); } rule.assertHasNoLeaks(); @@ -269,15 +259,18 @@ protected void hookOnSubscribe(Subscription subscription) { @Test public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final Sinks.One sink = Sinks.one(); rule.setAcceptingSocket( new RSocket() { @Override public Flux requestChannel(Publisher payloads) { payloads.subscribe(assertSubscriber); - return Flux.never(); + return sink.asMono().flux(); } }, Integer.MAX_VALUE); @@ -303,19 +296,21 @@ public Flux requestChannel(Publisher payloads) { ByteBuf data3 = allocator.buffer(); data3.writeCharSequence("def3", CharsetUtil.UTF_8); ByteBuf nextFrame3 = - PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), () -> { - rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); - }, - assertSubscriber::cancel); + assertSubscriber.cancel(); + sink.tryEmitEmpty(); + }); - Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); } } @@ -323,7 +318,9 @@ public Flux requestChannel(Publisher payloads) { public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); FluxSink[] sinks = new FluxSink[1]; @@ -350,20 +347,23 @@ public Flux requestChannel(Publisher payloads) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { AssertSubscriber assertSubscriber = AssertSubscriber.create(); FluxSink[] sinks = new FluxSink[1]; @@ -386,20 +386,17 @@ public Flux requestChannel(Publisher payloads) { ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); FluxSink sink = sinks[0]; RaceTestUtils.race( - () -> - RaceTestUtils.race( - () -> rule.connection.addToReceivedBuffer(requestNFrame), - () -> rule.connection.addToReceivedBuffer(cancelFrame), - parallel), + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), () -> { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); - }, - parallel); - - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + sink.complete(); + }); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); rule.assertHasNoLeaks(); } } @@ -407,10 +404,11 @@ public Flux requestChannel(Publisher payloads) { @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { FluxSink[] sinks = new FluxSink[1]; AssertSubscriber assertSubscriber = AssertSubscriber.create(); rule.setAcceptingSocket( @@ -473,41 +471,39 @@ public Flux requestChannel(Publisher payloads) { FluxSink sink = sinks[0]; RaceTestUtils.race( - () -> - RaceTestUtils.race( - () -> rule.connection.addToReceivedBuffer(requestNFrame), - () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), - parallel), + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), () -> { sink.next(np1); sink.next(np2); sink.next(np3); sink.error(new RuntimeException()); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); assertSubscriber .assertTerminated() .assertError(CancellationException.class) .assertErrorMessage("Outbound has terminated with an error"); - Assertions.assertThat(assertSubscriber.values()) + assertThat(assertSubscriber.values()) .allMatch( msg -> { ReferenceCountUtil.safeRelease(msg); return msg.refCnt() == 0; }); rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { FluxSink[] sinks = new FluxSink[1]; rule.setAcceptingSocket( @@ -530,21 +526,23 @@ public Flux requestStream(Payload payload) { sink.next(ByteBufPayload.create("d1", "m1")); sink.next(ByteBufPayload.create("d2", "m2")); sink.next(ByteBufPayload.create("d3", "m3")); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); } } @Test public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { - Scheduler parallel = Schedulers.parallel(); Hooks.onErrorDropped((e) -> {}); ByteBufAllocator allocator = rule.alloc(); - for (int i = 0; i < 10000; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; rule.setAcceptingSocket( @@ -570,12 +568,21 @@ public void subscribe(CoreSubscriber actual) { () -> rule.connection.addToReceivedBuffer(cancelFrame), () -> { sources[0].complete(ByteBufPayload.create("d1", "m1")); - }, - parallel); + }); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); } } @@ -604,7 +611,7 @@ public Flux requestStream(Payload payload) { sink.next(ByteBufPayload.create("d3", "m3")); rule.connection.addToReceivedBuffer(cancelFrame); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -650,7 +657,7 @@ public Flux requestChannel(Publisher payloads) { rule.connection.addToReceivedBuffer(cancelFrame); - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); rule.assertHasNoLeaks(); } @@ -720,8 +727,7 @@ public Flux requestChannel(Publisher payloads) { } if (responsesCnt > 0) { - Assertions.assertThat( - rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) .describedAs( "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) .hasSize(responsesCnt) @@ -729,8 +735,7 @@ public Flux requestChannel(Publisher payloads) { } if (framesCnt > 1) { - Assertions.assertThat( - rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) .describedAs( "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", frameType, framesCnt - 1) @@ -739,9 +744,9 @@ public Flux requestChannel(Publisher payloads) { .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); } - Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); - Assertions.assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) .hasSize(framesCnt) .allMatch(p -> !p.hasMetadata()) .allMatch(ReferenceCounted::release); @@ -786,7 +791,7 @@ public Flux requestChannel(Publisher payloads) { rule.sendRequest(1, frameType); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches( @@ -795,7 +800,8 @@ public Flux requestChannel(Publisher payloads) { + ERROR + "} but was {" + frameType(rule.connection.getSent().iterator().next()) - + "}"); + + "}") + .matches(ByteBuf::release); } private static Stream refCntCases() { @@ -803,7 +809,6 @@ private static Stream refCntCases() { } @Test - @Disabled("Reactor 3.4.0 should fix that. No need to do anything on our side") // see https://github.com/rsocket/rsocket-java/issues/858 public void testWorkaround858() { ByteBuf buffer = rule.alloc().buffer(); @@ -827,13 +832,13 @@ public Flux requestChannel(Publisher payloads) { rule.connection.addToReceivedBuffer( ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); - Assertions.assertThat(rule.connection.getSent()) + assertThat(rule.connection.getSent()) .hasSize(1) .first() .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_N) .matches(ReferenceCounted::release); - Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + assertThat(rule.socket.isDisposed()).isFalse(); testPublisher.assertWasCancelled(); rule.assertHasNoLeaks(); @@ -1063,32 +1068,32 @@ public Flux requestChannel(Publisher payloads) { void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored_ReassemblyCase( FrameType requestType) { AtomicReference receivedPayload = new AtomicReference<>(); - final MonoProcessor delayer = MonoProcessor.create(); + final Sinks.Empty delayer = Sinks.empty(); rule.setAcceptingSocket( new RSocket() { @Override public Mono fireAndForget(Payload payload) { receivedPayload.set(payload); - return delayer; + return delayer.asMono(); } @Override public Mono requestResponse(Payload payload) { receivedPayload.set(payload); - return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } @Override public Flux requestStream(Payload payload) { receivedPayload.set(payload); - return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } @Override public Flux requestChannel(Publisher payloads) { Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); - return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } }); final Payload randomPayload1 = fixedSizePayload(rule.allocator, 128); @@ -1104,9 +1109,9 @@ public Flux requestChannel(Publisher payloads) { rule.connection.addToReceivedBuffer(fragments1.toArray(new ByteBuf[0])); if (requestType != REQUEST_CHANNEL) { rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); - delayer.onComplete(); + delayer.tryEmitEmpty(); } else { - delayer.onComplete(); + delayer.tryEmitEmpty(); rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.allocator, 1)); rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); } @@ -1132,25 +1137,25 @@ public Flux requestChannel(Publisher payloads) { void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored(FrameType requestType) { Assumptions.assumeThat(requestType).isNotEqualTo(REQUEST_FNF); AtomicReference receivedPayload = new AtomicReference<>(); - final MonoProcessor delayer = MonoProcessor.create(); + final Sinks.One delayer = Sinks.one(); rule.setAcceptingSocket( new RSocket() { @Override public Mono requestResponse(Payload payload) { receivedPayload.set(payload); - return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } @Override public Flux requestStream(Payload payload) { receivedPayload.set(payload); - return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } @Override public Flux requestChannel(Publisher payloads) { Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); - return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); } }); final Payload randomPayload1 = fixedSizePayload(rule.allocator, 64); @@ -1158,7 +1163,7 @@ public Flux requestChannel(Publisher payloads) { rule.sendRequest(1, requestType, randomPayload1.retain()); rule.sendRequest(1, requestType, randomPayload2); - delayer.onComplete(); + delayer.tryEmitEmpty(); PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); randomPayload1.release(); @@ -1178,9 +1183,11 @@ public static class ServerSocketRule extends AbstractSocketRule onCloseSink; @Override - protected void init() { + protected void doInit() { acceptingSocket = new RSocket() { @Override @@ -1188,7 +1195,7 @@ public Mono requestResponse(Payload payload) { return Mono.just(payload); } }; - super.init(); + super.doInit(); } public void setAcceptingSocket(RSocket acceptingSocket) { @@ -1196,7 +1203,12 @@ public void setAcceptingSocket(RSocket acceptingSocket) { connection = new TestDuplexConnection(alloc()); connectSub = TestSubscriber.create(); this.prefetch = Integer.MAX_VALUE; - super.init(); + super.doInit(); + } + + public void setRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + super.doInit(); } public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { @@ -1204,19 +1216,22 @@ public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { connection = new TestDuplexConnection(alloc()); connectSub = TestSubscriber.create(); this.prefetch = prefetch; - super.init(); + super.doInit(); } @Override protected RSocketResponder newRSocket() { + onCloseSink = Sinks.empty(); return new RSocketResponder( connection, acceptingSocket, PayloadDecoder.ZERO_COPY, - ResponderLeaseHandler.None, + null, 0, maxFrameLength, - maxInboundPayloadSize); + maxInboundPayloadSize, + __ -> requestInterceptor, + onCloseSink); } private void sendRequest(int streamId, FrameType frameType) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java index 073ebfd06..90e881257 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -1,9 +1,12 @@ package io.rsocket.core; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; import io.rsocket.test.util.TestClientTransport; import io.rsocket.test.util.TestServerTransport; import org.assertj.core.api.Assertions; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class RSocketServerFragmentationTest { @@ -16,12 +19,18 @@ public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { @Test public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { - RSocketServer.create().fragment(100).bind(new TestServerTransport()).block(); + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().fragment(100).bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void serverSucceedsWithDisabledFragmentation() { - RSocketServer.create().bind(new TestServerTransport()).block(); + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); } @Test @@ -33,11 +42,23 @@ public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { @Test public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { - RSocketConnector.create().fragment(100).connect(new TestClientTransport()).block(); + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create().fragment(100).connect(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); } @Test public void clientSucceedsWithDisabledFragmentation() { - RSocketConnector.connectWith(new TestClientTransport()).block(); + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.connectWith(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java index a6103a2ba..a335ac1f3 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; @@ -5,19 +20,80 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; import io.rsocket.RSocket; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.EmptyPayload; import java.time.Duration; import java.util.Random; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; public class RSocketServerTest { + @Test + public void unexpectedFramesBeforeSetupFrame() { + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("SETUP or RESUME frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + duplexConnection.alloc().assertHasNoLeaks(); + } + + @Test + public void timeoutOnNoFirstFrame() { + final VirtualTimeScheduler scheduler = VirtualTimeScheduler.getOrSet(); + TestServerTransport transport = new TestServerTransport(); + try { + RSocketServer.create().maxTimeToFirstFrame(Duration.ofMinutes(2)).bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + Assertions.assertThat(duplexConnection.isDisposed()).isFalse(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + } finally { + transport.alloc().assertHasNoLeaks(); + VirtualTimeScheduler.reset(); + } + } + @Test public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { RSocketServer.create() @@ -55,17 +131,18 @@ public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { @Test public void unexpectedFramesBeforeSetup() { - MonoProcessor connectedMono = MonoProcessor.create(); + Sinks.Empty connectedSink = Sinks.empty(); TestServerTransport transport = new TestServerTransport(); - RSocketServer.create() - .acceptor( - (setup, sendingSocket) -> { - connectedMono.onComplete(); - return Mono.just(new RSocket() {}); - }) - .bind(transport) - .block(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> { + connectedSink.tryEmitEmpty(); + return Mono.just(new RSocket() {}); + }) + .bind(transport) + .block(); byte[] bytes = new byte[16_000_000]; new Random().nextBytes(bytes); @@ -80,6 +157,45 @@ public void unexpectedFramesBeforeSetup() { ByteBufAllocator.DEFAULT.buffer(bytes.length).writeBytes(bytes))); StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); - assertThat(connectedMono.isTerminated()).as("Connection should not succeed").isFalse(); + assertThat(connectedSink.scan(Scannable.Attr.TERMINATED)) + .as("Connection should not succeed") + .isFalse(); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("SETUP or RESUME frame must be received before any others") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresErrorFrameDeliveredPriorConnectionDisposal() { + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> Mono.error(new RejectedSetupException("ACCESS_DENIED"))) + .bind(transport) + .block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + false, + 0, + 1, + Unpooled.EMPTY_BUFFER, + "metadata_type", + "data_type", + EmptyPayload.INSTANCE)); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("ACCESS_DENIED") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java index 38745327e..e01e6ebdc 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,8 +27,6 @@ import io.rsocket.exceptions.CustomRSocketException; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.subscriber.AssertSubscriber; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; import io.rsocket.test.util.LocalDuplexConnection; import io.rsocket.util.DefaultPayload; import io.rsocket.util.EmptyPayload; @@ -37,23 +35,32 @@ import java.util.concurrent.CancellationException; import java.util.concurrent.atomic.AtomicReference; import org.assertj.core.api.Assertions; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.reactivestreams.Publisher; import reactor.core.Disposable; import reactor.core.Disposables; -import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; public class RSocketTest { - @Rule public final SocketRule rule = new SocketRule(); + public final SocketRule rule = new SocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckOnLeaks() { + rule.alloc().assertHasNoLeaks(); + } @Test public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { @@ -83,7 +90,8 @@ public boolean isDisposed() { Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); } - @Test(timeout = 2_000) + @Test + @Timeout(2_000) public void testRequestReplyNoError() { StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) .expectNextCount(1) @@ -91,7 +99,8 @@ public void testRequestReplyNoError() { .verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testHandlerEmitsError() { rule.setRequestAcceptor( new RSocket() { @@ -111,7 +120,8 @@ public Mono requestResponse(Payload payload) { .verify(Duration.ofMillis(100)); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testHandlerEmitsCustomError() { rule.setRequestAcceptor( new RSocket() { @@ -133,7 +143,8 @@ public Mono requestResponse(Payload payload) { .verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testRequestPropagatesCorrectlyForRequestChannel() { rule.setRequestAcceptor( new RSocket() { @@ -142,7 +153,7 @@ public Flux requestChannel(Publisher payloads) { return Flux.from(payloads) // specifically limits request to 3 in order to prevent 256 request from limitRate // hidden on the responder side - .limitRequest(3); + .take(3, true); } }); @@ -156,21 +167,24 @@ public Flux requestChannel(Publisher payloads) { .verify(Duration.ofMillis(5000)); } - @Test(timeout = 2000) - public void testStream() throws Exception { + @Test + @Timeout(2000) + public void testStream() { Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } - @Test(timeout = 200000) - public void testChannel() throws Exception { + @Test + @Timeout(200000) + public void testChannel() { Flux requests = Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); Flux responses = rule.crs.requestChannel(requests); StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); } - @Test(timeout = 2000) + @Test + @Timeout(2000) public void testErrorPropagatesCorrectly() { AtomicReference error = new AtomicReference<>(); rule.setRequestAcceptor( @@ -489,10 +503,10 @@ void errorFromRequesterPublisher( responderPublisher.assertNoSubscribers(); } - public static class SocketRule extends ExternalResource { + public static class SocketRule { - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; + Sinks.Many serverProcessor; + Sinks.Many clientProcessor; private RSocketRequester crs; @SuppressWarnings("unused") @@ -501,26 +515,20 @@ public static class SocketRule extends ExternalResource { private RSocket requestAcceptor; private LeaksTrackingByteBufAllocator allocator; - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; public LeaksTrackingByteBufAllocator alloc() { return allocator; } - protected void init() { + public void init() { allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); + serverProcessor = Sinks.many().multicast().directBestEffort(); + clientProcessor = Sinks.many().multicast().directBestEffort(); + + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); LocalDuplexConnection serverConnection = new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); @@ -542,8 +550,7 @@ public Mono requestResponse(Payload payload) { @Override public Flux requestStream(Payload payload) { return Flux.range(1, 10) - .map( - i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); + .map(i -> DefaultPayload.create("server got -> [" + payload + "]")); } @Override @@ -566,10 +573,12 @@ public Flux requestChannel(Publisher payloads) { serverConnection, requestAcceptor, PayloadDecoder.DEFAULT, - ResponderLeaseHandler.None, + null, 0, FRAME_LENGTH_MASK, - Integer.MAX_VALUE); + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); crs = new RSocketRequester( @@ -582,7 +591,10 @@ public Flux requestChannel(Publisher payloads) { 0, 0, null, - RequesterLeaseHandler.None); + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); } public void setRequestAcceptor(RSocket requestAcceptor) { diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java index 8d96222df..3112a0943 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,11 @@ package io.rsocket.core; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; import java.io.IOException; import java.time.Duration; import java.util.ArrayList; @@ -29,9 +32,10 @@ import java.util.concurrent.TimeoutException; import java.util.function.BiConsumer; import java.util.function.Consumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.assertj.core.api.Assertions; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; @@ -39,7 +43,6 @@ import reactor.core.Scannable; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import reactor.test.publisher.TestPublisher; @@ -58,7 +61,7 @@ public class ReconnectMonoTests { public void shouldExpireValueOnRacingDisposeAndNext() { Hooks.onErrorDropped(t -> {}); Hooks.onNextDropped(System.out::println); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; Subscription mockSubscription = Mockito.mock(Subscription.class); @@ -76,26 +79,27 @@ public void subscribe(CoreSubscriber actual) { .doOnDiscard(Object.class, System.out::println) .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); monoSubscribers[0].onComplete(); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); Mockito.verify(mockSubscription).cancel(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) - .isInstanceOf(CancellationException.class) - .hasMessage("ReconnectMono has already been disposed"); + if (!subscriber.errors().isEmpty()) { + subscriber + .assertError(CancellationException.class) + .assertErrorMessage("ReconnectMono has already been disposed"); - Assertions.assertThat(expired).containsOnly("value" + i); + assertThat(expired).containsOnly("value" + i); } else { - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + subscriber.assertValues("value" + i); } expired.clear(); @@ -106,41 +110,38 @@ public void subscribe(CoreSubscriber actual) { @Test public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); - RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(racerProcessor)); - - Assertions.assertThat(processor.isTerminated()).isTrue(); + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(raceSubscriber)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + subscriber.assertTerminated(); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -149,7 +150,7 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -157,11 +158,12 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); reconnectMono.resolvingInner.mainSubscriber.onComplete(); @@ -169,38 +171,33 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() RaceTestUtils.race( reconnectMono::invalidate, () -> { - reconnectMono.subscribe(racerProcessor); - if (!racerProcessor.isTerminated()) { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_not_expire" + index); reconnectMono.resolvingInner.mainSubscriber.onComplete(); } - }, - Schedulers.parallel()); - - Assertions.assertThat(processor.isTerminated()).isTrue(); - - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - StepVerifier.create(racerProcessor) - .expectNextMatches( - (v) -> { - if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - return v.equals("value_to_not_expire" + index); - } else { - return v.equals("value_to_expire" + index); - } - }) - .expectComplete() - .verify(Duration.ofMillis(100)); + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + String v = raceSubscriber.values().get(0); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(v).isEqualTo("value_to_not_expire" + index); + } else { + assertThat(v).isEqualTo("value_to_expire" + index); + } - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - Assertions.assertThat(received) + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_not_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -213,7 +210,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -221,55 +218,50 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); reconnectMono.resolvingInner.mainSubscriber.onComplete(); RaceTestUtils.race( - () -> - RaceTestUtils.race( - reconnectMono::invalidate, reconnectMono::invalidate, Schedulers.parallel()), + reconnectMono::invalidate, + reconnectMono::invalidate, () -> { - reconnectMono.subscribe(racerProcessor); - if (!racerProcessor.isTerminated()) { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { reconnectMono.resolvingInner.mainSubscriber.onNext( "value_to_possibly_expire" + index); reconnectMono.resolvingInner.mainSubscriber.onComplete(); } - }, - Schedulers.parallel()); + }); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - StepVerifier.create(racerProcessor) - .expectNextMatches( - (v) -> - v.equals("value_to_possibly_expire" + index) - || v.equals("value_to_expire" + index)) - .expectComplete() - .verify(Duration.ofMillis(100)); + raceSubscriber.assertComplete(); + assertThat(raceSubscriber.values().get(0)) + .isIn("value_to_possibly_expire" + index, "value_to_expire" + index); if (expired.size() == 2) { - Assertions.assertThat(expired) + assertThat(expired) .hasSize(2) .containsExactly("value_to_expire" + i, "value_to_possibly_expire" + i); } else { - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); } if (received.size() == 2) { - Assertions.assertThat(received) + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_possibly_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -282,58 +274,62 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( @Test public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; - final TestPublisher cold = - TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + final Mono source = + Mono.fromSupplier( + new Supplier() { + boolean once = false; + + @Override + public String get() { + + if (!once) { + once = true; + return "value_to_expire" + index; + } + + return "value_to_not_expire" + index; + } + }); final ReconnectMono reconnectMono = - cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + new ReconnectMono<>( + source.subscribeOn(Schedulers.boundedElastic()), onExpire(), onValue()); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); - reconnectMono.resolvingInner.mainSubscriber.onComplete(); + subscriber.await().assertComplete(); + + assertThat(expired).isEmpty(); RaceTestUtils.race( () -> - Assertions.assertThat(reconnectMono.block()) + assertThat(reconnectMono.block()) .matches( (v) -> v.equals("value_to_not_expire" + index) || v.equals("value_to_expire" + index)), - () -> - RaceTestUtils.race( - reconnectMono::invalidate, - () -> { - for (; ; ) { - if (reconnectMono.resolvingInner.subscribers != ResolvingOperator.READY) { - reconnectMono.resolvingInner.mainSubscriber.onNext( - "value_to_not_expire" + index); - reconnectMono.resolvingInner.mainSubscriber.onComplete(); - break; - } - } - }, - Schedulers.parallel()), - Schedulers.parallel()); - - Assertions.assertThat(processor.isTerminated()).isTrue(); - - Assertions.assertThat(processor.peek()).isEqualTo("value_to_expire" + i); - - Assertions.assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + reconnectMono::invalidate); + + subscriber.assertTerminated(); + + subscriber.assertValues("value_to_expire" + i); + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { - Assertions.assertThat(received) + await().atMost(Duration.ofSeconds(5)).until(() -> received.size() == 2); + assertThat(received) .hasSize(2) .containsExactly( Tuples.of("value_to_expire" + i, reconnectMono), Tuples.of("value_to_not_expire" + i, reconnectMono)); } else { - Assertions.assertThat(received) + assertThat(received) .hasSize(1) .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); } @@ -345,45 +341,42 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = MonoProcessor.create(); - final MonoProcessor racerProcessor = MonoProcessor.create(); + final AssertSubscriber subscriber = new AssertSubscriber<>(); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); RaceTestUtils.race( - () -> reconnectMono.subscribe(processor), () -> reconnectMono.subscribe(racerProcessor)); + () -> reconnectMono.subscribe(subscriber), () -> reconnectMono.subscribe(raceSubscriber)); - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(racerProcessor.isTerminated()).isTrue(); + subscriber.assertTerminated(); + assertThat(raceSubscriber.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -392,45 +385,43 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = MonoProcessor.create(); + final AssertSubscriber subscriber = new AssertSubscriber<>(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); String[] values = new String[1]; RaceTestUtils.race( - () -> values[0] = reconnectMono.block(timeout), () -> reconnectMono.subscribe(processor)); + () -> values[0] = reconnectMono.block(timeout), + () -> reconnectMono.subscribe(subscriber)); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); - Assertions.assertThat(values).containsExactly("value" + i); + subscriber.assertValues("value" + i); + assertThat(values).containsExactly("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, processor))) + reconnectMono.resolvingInner, subscriber))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -439,17 +430,17 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value" + i); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); - Assertions.assertThat(cold.subscribeCount()).isZero(); + assertThat(cold.subscribeCount()).isZero(); String[] values1 = new String[1]; String[] values2 = new String[1]; @@ -458,24 +449,21 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { () -> values1[0] = reconnectMono.block(timeout), () -> values2[0] = reconnectMono.block(timeout)); - Assertions.assertThat(values2).containsExactly("value" + i); - Assertions.assertThat(values1).containsExactly("value" + i); + assertThat(values2).containsExactly("value" + i); + assertThat(values1).containsExactly("value" + i); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) - .isEqualTo(ResolvingOperator.READY); + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); - Assertions.assertThat(cold.subscribeCount()).isOne(); + assertThat(cold.subscribeCount()).isOne(); - Assertions.assertThat( + assertThat( reconnectMono.resolvingInner.add( new ResolvingOperator.MonoDeferredResolutionOperator<>( - reconnectMono.resolvingInner, MonoProcessor.create()))) + reconnectMono.resolvingInner, new AssertSubscriber<>()))) .isEqualTo(ResolvingOperator.READY_STATE); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); received.clear(); } @@ -484,35 +472,36 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { @Test public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); RaceTestUtils.race(cold::complete, reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - Throwable error = processor.getError(); + Throwable error = subscriber.errors().get(0); if (error instanceof CancellationException) { - Assertions.assertThat(error) + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(error) + assertThat(error) .isInstanceOf(IllegalStateException.class) .hasMessage("Source completed empty"); } - Assertions.assertThat(expired).isEmpty(); + assertThat(expired).isEmpty(); expired.clear(); received.clear(); @@ -522,36 +511,35 @@ public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { @Test public void shouldExpireValueOnRacingDisposeAndComplete() { Hooks.onErrorDropped(t -> {}); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(cold::complete, reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + assertThat(subscriber.errors().get(0)) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); expired.clear(); received.clear(); @@ -562,42 +550,40 @@ public void shouldExpireValueOnRacingDisposeAndComplete() { public void shouldExpireValueOnRacingDisposeAndError() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - if (processor.getError() instanceof CancellationException) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(processor.getError()) - .isInstanceOf(RuntimeException.class) - .hasMessage("test"); + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("test"); } } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); expired.clear(); received.clear(); @@ -608,7 +594,7 @@ public void shouldExpireValueOnRacingDisposeAndError() { public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); @@ -617,35 +603,32 @@ public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); cold.next("value" + i); RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); - Assertions.assertThat(processor.isTerminated()).isTrue(); + subscriber.assertTerminated(); - if (processor.isError()) { - - if (processor.getError() instanceof CancellationException) { - Assertions.assertThat(processor.getError()) + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) .isInstanceOf(CancellationException.class) .hasMessage("ReconnectMono has already been disposed"); } else { - Assertions.assertThat(processor.getError()) - .matches(t -> Exceptions.isRetryExhausted(t)) - .hasCause(runtimeException); + assertThat(error).matches(Exceptions::isRetryExhausted).hasCause(runtimeException); } - Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + assertThat(expired).hasSize(1).containsOnly("value" + i); } else { - Assertions.assertThat(received) - .hasSize(1) - .containsOnly(Tuples.of("value" + i, reconnectMono)); - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); } expired.clear(); @@ -693,20 +676,20 @@ public void shouldBeScannable() { final Scannable scannableOfReconnect = Scannable.from(reconnectMono); - Assertions.assertThat( + assertThat( (List) scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) .hasSize(1) .containsExactly(publisher.mono().getClass()); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) - .isEqualTo(false); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(false); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); - final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); - final Scannable scannableOfMonoProcessor = Scannable.from(processor); + final Scannable scannableOfMonoProcessor = Scannable.from(subscriber); - Assertions.assertThat( + assertThat( (List) scannableOfMonoProcessor .parents() @@ -721,9 +704,8 @@ public void shouldBeScannable() { reconnectMono.dispose(); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) - .isEqualTo(true); - Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(true); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) .isInstanceOf(CancellationException.class); } @@ -735,36 +717,36 @@ public void shouldNotExpiredIfNotCompleted() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); reconnectMono.invalidate(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.assertSubscribers(1); - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isTrue(); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); publisher.assertSubscribers(0); - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); } @Test @@ -775,26 +757,26 @@ public void shouldNotEmitUntilCompletion() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo("test"); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + subscriber.assertValues("test"); } @Test @@ -805,31 +787,30 @@ public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = new AssertSubscriber<>(); - reconnectMono.subscribe(processor); + reconnectMono.subscribe(subscriber); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); publisher.next("test"); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(processor.isTerminated()).isFalse(); + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); - processor.cancel(); + subscriber.cancel(); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers) + assertThat(reconnectMono.resolvingInner.subscribers) .isEqualTo(ResolvingOperator.EMPTY_SUBSCRIBED); publisher.complete(); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(processor.isTerminated()).isFalse(); - Assertions.assertThat(processor.peek()).isNull(); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + assertThat(subscriber.values()).isEmpty(); } @Test @@ -848,16 +829,16 @@ public void shouldExpireValueOnDispose() { .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); reconnectMono.dispose(); - Assertions.assertThat(expired).hasSize(1); - Assertions.assertThat(received).hasSize(1); - Assertions.assertThat(reconnectMono.isDisposed()).isTrue(); + assertThat(expired).hasSize(1); + assertThat(received).hasSize(1); + assertThat(reconnectMono.isDisposed()).isTrue(); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectError(CancellationException.class) .verify(Duration.ofSeconds(timeout)); @@ -870,85 +851,91 @@ public void shouldNotifyAllTheSubscribers() { final ReconnectMono reconnectMono = publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - final MonoProcessor sub1 = MonoProcessor.create(); - final MonoProcessor sub2 = MonoProcessor.create(); - final MonoProcessor sub3 = MonoProcessor.create(); - final MonoProcessor sub4 = MonoProcessor.create(); + final AssertSubscriber sub1 = new AssertSubscriber<>(); + final AssertSubscriber sub2 = new AssertSubscriber<>(); + final AssertSubscriber sub3 = new AssertSubscriber<>(); + final AssertSubscriber sub4 = new AssertSubscriber<>(); reconnectMono.subscribe(sub1); reconnectMono.subscribe(sub2); reconnectMono.subscribe(sub3); reconnectMono.subscribe(sub4); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); - final ArrayList> processors = new ArrayList<>(200); + final ArrayList> subscribers = new ArrayList<>(200); - for (int i = 0; i < 100; i++) { - final MonoProcessor subA = MonoProcessor.create(); - final MonoProcessor subB = MonoProcessor.create(); - processors.add(subA); - processors.add(subB); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final AssertSubscriber subA = new AssertSubscriber<>(); + final AssertSubscriber subB = new AssertSubscriber<>(); + subscribers.add(subA); + subscribers.add(subB); RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); } - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(204); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 4); - sub1.dispose(); + sub1.cancel(); - Assertions.assertThat(reconnectMono.resolvingInner.subscribers).hasSize(203); + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 3); publisher.next("value"); - Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); - Assertions.assertThat(sub2.peek()).isEqualTo("value"); - Assertions.assertThat(sub3.peek()).isEqualTo("value"); - Assertions.assertThat(sub4.peek()).isEqualTo("value"); + assertThat(sub1.scan(Scannable.Attr.CANCELLED)).isTrue(); + assertThat(sub2.values().get(0)).isEqualTo("value"); + assertThat(sub3.values().get(0)).isEqualTo("value"); + assertThat(sub4.values().get(0)).isEqualTo("value"); - for (MonoProcessor sub : processors) { - Assertions.assertThat(sub.peek()).isEqualTo("value"); - Assertions.assertThat(sub.isTerminated()).isTrue(); + for (AssertSubscriber sub : subscribers) { + assertThat(sub.values().get(0)).isEqualTo("value"); + assertThat(sub.isTerminated()).isTrue(); } - Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); } @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value"); + cold.complete(); final int timeout = 10; final ReconnectMono reconnectMono = - cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + cold.flux() + .takeLast(1) + .next() + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectNext("value") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + cold.next("value2"); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() - .expectNext("value") + .expectNext("value2") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received) + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received) .hasSize(2) - .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value", reconnectMono)); + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value2", reconnectMono)); - Assertions.assertThat(cold.subscribeCount()).isEqualTo(2); + assertThat(cold.subscribeCount()).isEqualTo(2); expired.clear(); received.clear(); @@ -957,7 +944,7 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestPublisher cold = TestPublisher.createCold(); cold.next("value"); final int timeout = 10000; @@ -965,29 +952,29 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { final ReconnectMono reconnectMono = cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectNext("value") .expectComplete() .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).isEmpty(); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::dispose); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectError(CancellationException.class) .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(expired).hasSize(1).containsOnly("value"); - Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); - Assertions.assertThat(cold.subscribeCount()).isEqualTo(1); + assertThat(cold.subscribeCount()).isEqualTo(1); expired.clear(); received.clear(); @@ -1011,14 +998,14 @@ public void shouldTimeoutRetryWithVirtualTime() { .maxBackoff(Duration.ofSeconds(maxBackoff))) .timeout(Duration.ofSeconds(timeout)) .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) - .subscribeOn(Schedulers.elastic())) + .subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .thenAwait(Duration.ofSeconds(timeout)) .expectError(TimeoutException.class) .verify(Duration.ofSeconds(timeout)); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1027,11 +1014,11 @@ public void ensuresThatMainSubscriberAllowsOnlyTerminationWithValue() { final ReconnectMono reconnectMono = new ReconnectMono<>(Mono.empty(), onExpire(), onValue()); - StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) .expectSubscription() .expectErrorSatisfies( t -> - Assertions.assertThat(t) + assertThat(t) .hasMessage("Source completed empty") .isInstanceOf(IllegalStateException.class)) .verify(Duration.ofSeconds(timeout)); @@ -1047,8 +1034,8 @@ public void monoRetryNoBackoff() { StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); assertRetries(IOException.class, IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1066,8 +1053,8 @@ public void monoRetryFixedBackoff() { assertRetries(IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } @Test @@ -1091,8 +1078,8 @@ public void monoRetryExponentialBackoff() { assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); - Assertions.assertThat(received).isEmpty(); - Assertions.assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); } Consumer onRetry() { @@ -1109,12 +1096,12 @@ Consumer onExpire() { @SafeVarargs private final void assertRetries(Class... exceptions) { - assertEquals(exceptions.length, retries.size()); + assertThat(retries.size()).isEqualTo(exceptions.length); int index = 0; for (Iterator it = retries.iterator(); it.hasNext(); ) { Retry.RetrySignal retryContext = it.next(); - assertEquals(index, retryContext.totalRetries()); - assertEquals(exceptions[index], retryContext.failure().getClass()); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); index++; } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java index 4fc06fdc2..c1e0a6876 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java @@ -16,6 +16,7 @@ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -24,11 +25,12 @@ import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import java.time.Duration; @@ -39,6 +41,7 @@ import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -66,7 +69,7 @@ public static void setUp() { public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); final TestPublisher publisher = TestPublisher.create(); @@ -103,7 +106,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp // state machine check stateAssert.hasSubscribedFlag().hasRequestN(10).hasFirstFrameSentFlag(); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -121,7 +124,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp Assertions.assertThat(sender.isEmpty()).isTrue(); assertSubscriber.request(1); - final ByteBuf requestNFrame = sender.poll(); + final ByteBuf requestNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestNFrame) .isNotNull() .hasRequestN(1) @@ -137,7 +140,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp stateAssert.hasSubscribedFlag().hasRequestN(11).hasFirstFrameSentFlag(); assertSubscriber.request(Long.MAX_VALUE); - final ByteBuf requestMaxNFrame = sender.poll(); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestMaxNFrame) .isNotNull() .hasRequestN(Integer.MAX_VALUE) @@ -211,10 +214,10 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp .hasInboundTerminated(); publisher.complete(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); } else if (completionCase.equals("outbound")) { publisher.complete(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); // state machine check stateAssert @@ -247,7 +250,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp public void streamShouldErrorWithoutInitializingRemoteStreamIfSourceIsEmpty(boolean doRequest) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.create(); final RequestChannelRequesterFlux requestChannelRequesterFlux = @@ -292,7 +295,7 @@ public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstS boolean doRequest) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.create(); final RequestChannelRequesterFlux requestChannelRequesterFlux = @@ -336,7 +339,7 @@ public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstS public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String terminationMode) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.create(); final RequestChannelRequesterFlux requestChannelRequesterFlux = @@ -366,7 +369,7 @@ public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String termi publisher.next(payload1.retain()); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_CHANNEL) .hasPayload(payload1) .hasRequestN(Integer.MAX_VALUE) @@ -386,10 +389,16 @@ public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String termi publisher.next(payload2.retain(), payload3.retain()); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload2).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); payload2.release(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload3).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); payload3.release(); if (terminationMode.equals("outbound")) { @@ -428,7 +437,7 @@ public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String termi public void errorShouldTerminateExecution(String terminationMode) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.create(); final RequestChannelRequesterFlux requestChannelRequesterFlux = @@ -458,7 +467,7 @@ public void errorShouldTerminateExecution(String terminationMode) { publisher.next(payload1.retain()); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_CHANNEL) .hasPayload(payload1) .hasRequestN(Integer.MAX_VALUE) @@ -478,15 +487,24 @@ public void errorShouldTerminateExecution(String terminationMode) { publisher.next(payload2.retain(), payload3.retain()); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload2).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); payload2.release(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.NEXT).hasPayload(payload3).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); payload3.release(); if (terminationMode.equals("outbound")) { publisher.error(new ApplicationErrorException("test")); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.ERROR).hasData("test").hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("test") + .hasNoLeaks(); } else if (terminationMode.equals("inbound")) { requestChannelRequesterFlux.handleError(new ApplicationErrorException("test")); publisher.assertWasCancelled(); @@ -497,6 +515,77 @@ public void errorShouldTerminateExecution(String terminationMode) { stateAssert.isTerminated(); } + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(1) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + publisher.assertWasCancelled(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + } + /* * +--------------------------------+ * | Racing Test Cases | @@ -530,10 +619,10 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS Hooks.onErrorDropped(droppedErrors::add); try { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); @@ -561,7 +650,7 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); activeStreams.assertHasStream(1, requestChannelRequesterFlux); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_CHANNEL) .hasRequestN(Integer.MAX_VALUE) .hasNoLeaks(); @@ -599,7 +688,7 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS } }); - ByteBuf errorFrameOrEmpty = sender.poll(); + ByteBuf errorFrameOrEmpty = sender.pollFrame(); if (errorFrameOrEmpty != null) { if (outboundTerminationMode.equals("onError")) { FrameAssert.assertThat(errorFrameOrEmpty) @@ -691,10 +780,10 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS @ValueSource(strings = {"complete", "cancel"}) public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminated( String outboundTerminationMode) { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); @@ -723,7 +812,7 @@ public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminat stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); activeStreams.assertHasStream(1, requestChannelRequesterFlux); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_CHANNEL) .hasRequestN(Integer.MAX_VALUE) .hasNoLeaks(); @@ -740,7 +829,7 @@ public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminat }, requestChannelRequesterFlux::handleComplete); - ByteBuf completeFrameOrNull = sender.poll(); + ByteBuf completeFrameOrNull = sender.pollFrame(); if (completeFrameOrNull != null) { FrameAssert.assertThat(completeFrameOrNull) .hasStreamId(1) diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java index b1c1e8cf9..890458caf 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -28,11 +28,12 @@ import io.rsocket.FrameAssert; import io.rsocket.Payload; import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.DefaultPayload; import java.time.Duration; @@ -72,8 +73,8 @@ public static void setUp() { @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); final TestPublisher publisher = TestPublisher.create(); @@ -112,7 +113,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); // should not send requestN since 1 is remaining - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(REQUEST_N) .hasStreamId(1) .hasRequestN(1) @@ -120,7 +121,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp publisher.next(TestRequesterResponderSupport.genericPayload(allocator)); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -135,7 +136,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp .hasNoLeaks(); assertSubscriber.request(Long.MAX_VALUE); - final ByteBuf requestMaxNFrame = sender.poll(); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestMaxNFrame) .isNotNull() .hasRequestN(Integer.MAX_VALUE) @@ -204,7 +205,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp .hasInboundTerminated(); publisher.complete(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); } else if (completionCase.equals("inboundCancel")) { assertSubscriber.cancel(); assertSubscriber.assertValuesWith( @@ -215,7 +216,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp randomPayload.release(); }); - FrameAssert.assertThat(sender.poll()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); // state machine check stateAssert @@ -226,10 +227,13 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp .hasInboundTerminated(); publisher.complete(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); } else if (completionCase.equals("outbound")) { publisher.complete(); - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); // state machine check stateAssert @@ -259,6 +263,143 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp allocator.assertHasNoLeaks(); } + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(nextPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflowBeforeFirstPayloadIsSent() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(1); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + /* * +--------------------------------+ * | Racing Test Cases | @@ -267,10 +408,11 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String comp @Test public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); final TestPublisher publisher = TestPublisher.create(); @@ -308,14 +450,14 @@ public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() publisher.complete(); - if (sender.size() > 1) { - FrameAssert.assertThat(sender.poll()) + if (sender.getSent().size() > 1) { + FrameAssert.assertThat(sender.awaitFrame()) .hasStreamId(1) .typeOf(REQUEST_N) .hasRequestN(1) .hasNoLeaks(); } - FrameAssert.assertThat(sender.poll()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); // state machine check stateAssert.isTerminated(); @@ -327,7 +469,7 @@ public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { ApplicationErrorException applicationErrorException = new ApplicationErrorException("test"); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); @@ -369,9 +511,59 @@ public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { } } + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + @Test public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); @@ -439,10 +631,10 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS Hooks.onErrorDropped(droppedErrors::add); try { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final TestPublisher publisher = TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); @@ -460,7 +652,7 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS assertSubscriber.request(Integer.MAX_VALUE); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_N) .hasRequestN(Integer.MAX_VALUE) .hasNoLeaks(); @@ -498,7 +690,7 @@ public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundS } }); - ByteBuf errorFrameOrEmpty = sender.poll(); + ByteBuf errorFrameOrEmpty = sender.pollFrame(); if (errorFrameOrEmpty != null) { String message; if (outboundTerminationMode.equals("onError")) { @@ -602,13 +794,14 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) final Payload oversizePayload = DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final TestPublisher publisher = TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); - final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(2); Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); final RequestChannelResponderSubscriber requestOperator = @@ -669,8 +862,17 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) assertSubscriber.assertTerminated().assertError(); } - final ByteBuf frame = sender.poll(); - FrameAssert.assertThat(frame) + final ByteBuf requstFrame = sender.awaitFrame(); + FrameAssert.assertThat(requstFrame) + .isNotNull() + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf terminalFrame = sender.awaitFrame(); + FrameAssert.assertThat(terminalFrame) .isNotNull() .typeOf(terminationMode.equals("cancel") ? CANCEL : ERROR) .hasClientSideStreamId() diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java index 86babe671..b39ac62d9 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java @@ -31,7 +31,7 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.time.Duration; @@ -75,7 +75,7 @@ public void frameShouldBeSentOnSubscription( transformer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = genericPayload(allocator); final RequestResponseRequesterMono requestResponseRequesterMono = @@ -105,7 +105,7 @@ public void frameShouldBeSentOnSubscription( // should not add anything to map activeStreams.assertNoActiveStreams(); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -122,7 +122,7 @@ public void frameShouldBeSentOnSubscription( stateAssert.isTerminated(); if (!sender.isEmpty()) { - ByteBuf cancelFrame = sender.poll(); + ByteBuf cancelFrame = sender.awaitFrame(); FrameAssert.assertThat(cancelFrame) .isNotNull() .typeOf(FrameType.CANCEL) @@ -320,7 +320,7 @@ public void frameFragmentsShouldBeSentOnSubscription( final int mtu = 64; final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final byte[] metadata = new byte[65]; final byte[] data = new byte[129]; @@ -356,7 +356,7 @@ public void frameFragmentsShouldBeSentOnSubscription( Assertions.assertThat(payload.refCnt()).isZero(); - final ByteBuf frameFragment1 = sender.poll(); + final ByteBuf frameFragment1 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment1) .isNotNull() .hasPayloadSize( @@ -370,7 +370,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment2 = sender.poll(); + final ByteBuf frameFragment2 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment2) .isNotNull() .hasPayloadSize( @@ -384,7 +384,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment3 = sender.poll(); + final ByteBuf frameFragment3 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment3) .isNotNull() .hasPayloadSize( @@ -397,7 +397,7 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment4 = sender.poll(); + final ByteBuf frameFragment4 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment4) .isNotNull() .hasPayloadSize(35) @@ -410,14 +410,14 @@ public void frameFragmentsShouldBeSentOnSubscription( .hasNoLeaks(); if (!sender.isEmpty()) { - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .isNotNull() .typeOf(FrameType.CANCEL) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); } - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); stateAssert.isTerminated(); allocator.assertHasNoLeaks(); } @@ -430,7 +430,7 @@ public void frameFragmentsShouldBeSentOnSubscription( public void shouldBeNoOpsOnCancel() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = ByteBufPayload.create("testData", "testMetadata"); final RequestResponseRequesterMono requestResponseRequesterMono = @@ -466,7 +466,8 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( Consumer monoConsumer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final Payload payload = ByteBufPayload.create(""); payload.release(); @@ -483,7 +484,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( stateAssert.isTerminated(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); allocator.assertHasNoLeaks(); } @@ -509,7 +510,8 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final Payload payload = ByteBufPayload.create(""); final RequestResponseRequesterMono requestResponseRequesterMono = @@ -543,7 +545,8 @@ public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation final int mtu = 64; final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final byte[] metadata = new byte[65]; final byte[] data = new byte[129]; ThreadLocalRandom.current().nextBytes(metadata); @@ -582,7 +585,8 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Consumer monoConsumer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; final byte[] metadata = new byte[FRAME_LENGTH_MASK]; final byte[] data = new byte[FRAME_LENGTH_MASK]; @@ -604,7 +608,7 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Assertions.assertThat(payload.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); stateAssert.isTerminated(); allocator.assertHasNoLeaks(); } @@ -639,7 +643,6 @@ public void shouldErrorIfNoAvailability(Consumer m final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(new RuntimeException("test")); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); final Payload payload = genericPayload(allocator); final RequestResponseRequesterMono requestResponseRequesterMono = diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java index 9791b0786..8702d1a80 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java @@ -31,8 +31,8 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.time.Duration; @@ -80,7 +80,7 @@ public static void setUp() { public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); final RequestStreamRequesterFlux requestStreamRequesterFlux = @@ -108,7 +108,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { // state machine check stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -126,7 +126,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { Assertions.assertThat(sender.isEmpty()).isTrue(); assertSubscriber.request(1); - final ByteBuf requestNFrame = sender.poll(); + final ByteBuf requestNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestNFrame) .isNotNull() .hasRequestN(1) @@ -142,7 +142,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); assertSubscriber.request(Long.MAX_VALUE); - final ByteBuf requestMaxNFrame = sender.poll(); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestMaxNFrame) .isNotNull() .hasRequestN(Integer.MAX_VALUE) @@ -227,7 +227,7 @@ public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); final RequestStreamRequesterFlux requestStreamRequesterFlux = @@ -257,7 +257,7 @@ public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { Assertions.assertThat(payload.refCnt()).isZero(); activeStreams.assertHasStream(1, requestStreamRequesterFlux); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -332,7 +332,7 @@ public void frameShouldBeSentOnFirstRequest( transformer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); final RequestStreamRequesterFlux requestStreamRequesterFlux = @@ -369,7 +369,7 @@ public void frameShouldBeSentOnFirstRequest( // should not add anything to map activeStreams.assertNoActiveStreams(); - final ByteBuf frame = sender.poll(); + final ByteBuf frame = sender.awaitFrame(); FrameAssert.assertThat(frame) .isNotNull() .hasPayloadSize( @@ -384,7 +384,7 @@ public void frameShouldBeSentOnFirstRequest( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf requestNFrame = sender.poll(); + final ByteBuf requestNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestNFrame) .isNotNull() .typeOf(FrameType.REQUEST_N) @@ -394,7 +394,7 @@ public void frameShouldBeSentOnFirstRequest( .hasNoLeaks(); if (!sender.isEmpty()) { - final ByteBuf cancelFrame = sender.poll(); + final ByteBuf cancelFrame = sender.awaitFrame(); FrameAssert.assertThat(cancelFrame) .isNotNull() .typeOf(FrameType.CANCEL) @@ -764,7 +764,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( final int mtu = 64; final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final byte[] metadata = new byte[65]; final byte[] data = new byte[129]; @@ -799,7 +799,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( Assertions.assertThat(payload.refCnt()).isZero(); - final ByteBuf frameFragment1 = sender.poll(); + final ByteBuf frameFragment1 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment1) .isNotNull() .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N) @@ -812,7 +812,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment2 = sender.poll(); + final ByteBuf frameFragment2 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment2) .isNotNull() .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA) @@ -825,7 +825,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment3 = sender.poll(); + final ByteBuf frameFragment3 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment3) .isNotNull() .hasPayloadSize(64 - FRAME_OFFSET) @@ -837,7 +837,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf frameFragment4 = sender.poll(); + final ByteBuf frameFragment4 = sender.awaitFrame(); FrameAssert.assertThat(frameFragment4) .isNotNull() .hasPayloadSize(39) @@ -849,7 +849,7 @@ public void frameFragmentsShouldBeSentOnFirstRequest( .hasStreamId(1) .hasNoLeaks(); - final ByteBuf requestNFrame = sender.poll(); + final ByteBuf requestNFrame = sender.awaitFrame(); FrameAssert.assertThat(requestNFrame) .isNotNull() .typeOf(FrameType.REQUEST_N) @@ -859,14 +859,14 @@ public void frameFragmentsShouldBeSentOnFirstRequest( .hasNoLeaks(); if (!sender.isEmpty()) { - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .isNotNull() .typeOf(FrameType.CANCEL) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); } - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); // state machine check stateAssert.isTerminated(); allocator.assertHasNoLeaks(); @@ -882,7 +882,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( Consumer monoConsumer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final Payload payload = ByteBufPayload.create(""); payload.release(); @@ -898,7 +898,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( monoConsumer.accept(requestStreamRequesterFlux); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); // state machine check stateAssert.isTerminated(); allocator.assertHasNoLeaks(); @@ -925,7 +925,8 @@ public void shouldErrorOnIncorrectRefCntInGivenPayload( public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); final RequestStreamRequesterFlux requestStreamRequesterFlux = @@ -953,7 +954,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { .verify(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); // state machine check stateAssert.isTerminated(); @@ -969,7 +970,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation final int mtu = 64; final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final byte[] metadata = new byte[65]; final byte[] data = new byte[129]; @@ -1003,7 +1004,7 @@ public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation .verify(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); // state machine check stateAssert.isTerminated(); allocator.assertHasNoLeaks(); @@ -1019,7 +1020,7 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Consumer monoConsumer) { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); final byte[] metadata = new byte[FRAME_LENGTH_MASK]; final byte[] data = new byte[FRAME_LENGTH_MASK]; @@ -1043,7 +1044,7 @@ public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( Assertions.assertThat(payload.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(sender).isEmpty(); + Assertions.assertThat(sender.isEmpty()).isTrue(); // state machine check stateAssert.isTerminated(); allocator.assertHasNoLeaks(); @@ -1083,7 +1084,6 @@ public void shouldErrorIfNoAvailability(Consumer mon final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(new RuntimeException("test")); final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); - final UnboundedProcessor sender = activeStreams.getSendProcessor(); final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); final RequestStreamRequesterFlux requestStreamRequesterFlux = @@ -1129,6 +1129,87 @@ static Stream> shouldErrorIfNoAvailabilityS .isInstanceOf(RuntimeException.class)); } + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + Payload requestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(requestedPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isEqualTo(requestedPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + PayloadAssert.assertThat(requestedPayload).isReleased(); + PayloadAssert.assertThat(unrequestedPayload).isReleased(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + @Test public void checkName() { final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java index 8aee36467..06d050f6f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -27,9 +27,11 @@ import io.netty.util.CharsetUtil; import io.rsocket.FrameAssert; import io.rsocket.Payload; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.FrameType; import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.ArrayList; @@ -39,11 +41,10 @@ import java.util.stream.Stream; import org.assertj.core.api.Assertions; import org.assertj.core.api.Assumptions; -import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; @@ -168,9 +169,10 @@ public String toString() { @ParameterizedTest(name = "Should subscribe exactly once to {0}") @MethodSource("scenarios") public void shouldSubscribeExactlyOnce(Scenario scenario) { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); final TestRequesterResponderSupport requesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload( @@ -180,7 +182,7 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { scenario.requestOperator(payloadSupplier, requesterResponderSupport); StepVerifier stepVerifier = - StepVerifier.create(requesterResponderSupport.getSendProcessor()) + StepVerifier.create(requesterResponderSupport.getDuplexConnection().getSentAsPublisher()) .assertNext( frame -> { FrameAssert frameAssert = @@ -214,6 +216,9 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { if (requestOperator instanceof FrameHandler) { ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } } }) .thenCancel() @@ -239,8 +244,30 @@ public void shouldSubscribeExactlyOnce(Scenario scenario) { }); stepVerifier.verify(Duration.ofSeconds(1)); - Assertions.assertThat(requesterResponderSupport.getSendProcessor().isEmpty()).isTrue(); requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } } } @@ -251,8 +278,10 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()) .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -266,11 +295,11 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { RaceTestUtils.race(() -> assertSubscriber.request(1), () -> assertSubscriber.request(1)); - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); if (scenario.requestType().hasInitialRequestN()) { if (RequestStreamFrameCodec.initialRequestN(sentFrame) == 1) { - FrameAssert.assertThat(activeStreams.getSendProcessor().poll()) + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) .isNotNull() .hasStreamId(1) .hasRequestN(1) @@ -300,7 +329,7 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { if (scenario.requestType() == REQUEST_CHANNEL) { ((CoreSubscriber) requestOperator).onComplete(); - FrameAssert.assertThat(activeStreams.getSendProcessor().poll()) + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) .typeOf(COMPLETE) .hasStreamId(1) .hasNoLeaks(); @@ -315,8 +344,14 @@ public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { }); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } } } @@ -330,8 +365,10 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()) .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); final Supplier payloadSupplier = () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); @@ -342,7 +379,7 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { requestOperator.subscribe(assertSubscriber); - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() .hasPayloadSize( @@ -393,25 +430,35 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { }); } - if (!activeStreams.getSendProcessor().isEmpty()) { + if (!activeStreams.getDuplexConnection().isEmpty()) { if (scenario.requestType() != REQUEST_CHANNEL) { assertSubscriber.assertNotTerminated(); } - final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancellationFrame) .isNotNull() .typeOf(FrameType.CANCEL) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); } Assertions.assertThat(responsePayload.release()).isTrue(); Assertions.assertThat(responsePayload.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); } } @@ -420,24 +467,26 @@ public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { * Ensures that in case of racing between next element and cancel we will not have any memory * leaks */ - @Test - public void shouldHaveNoLeaksOnNextAndCancelRacing() { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - Payload response = ByteBufPayload.create("test", "test"); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); - StepVerifier.create(requestResponseRequesterMono.doOnNext(Payload::release)) - .expectSubscription() - .expectComplete() - .verifyLater(); + Payload response = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() .hasPayloadSize( @@ -447,32 +496,41 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handlePayload(response)); + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); - Assertions.assertThat(payload.refCnt()).isZero(); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); activeStreams.assertNoActiveStreams(); - final boolean isEmpty = activeStreams.getSendProcessor().isEmpty(); + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); if (!isEmpty) { - final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancellationFrame) .isNotNull() .typeOf(FrameType.CANCEL) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - } - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); } } @@ -483,84 +541,106 @@ public void shouldHaveNoLeaksOnNextAndCancelRacing() { * cancel we will not have any memory leaks */ @ParameterizedTest - @ValueSource(booleans = {false, true}) - public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean withReassembly) { + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; final ArrayList droppedErrors = new ArrayList<>(); Hooks.onErrorDropped(droppedErrors::add); - try { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); - - final StateAssert stateAssert = - StateAssert.assertThat(requestResponseRequesterMono); - - stateAssert.isUnsubscribed(); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(AssertSubscriber.create(0)); - stateAssert.hasSubscribedFlagOnly(); - - assertSubscriber.request(1); + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } - stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); - FrameAssert.assertThat(sentFrame) - .isNotNull() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) - .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) - .hasData(TestRequesterResponderSupport.DATA_CONTENT) - .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasClientSideStreamId() - .hasStreamId(1) - .hasNoLeaks(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); - if (withReassembly) { - final ByteBuf fragmentBuf = - activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); - requestResponseRequesterMono.handleNext(fragmentBuf, true, false); - // mimic frameHandler behaviour - fragmentBuf.release(); - } + stateAssert.hasSubscribedFlagOnly(); - final RuntimeException testException = new RuntimeException("test"); - RaceTestUtils.race( - requestResponseRequesterMono::cancel, - () -> requestResponseRequesterMono.handleError(testException)); + assertSubscriber.request(1); - Assertions.assertThat(payload.refCnt()).isZero(); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); - activeStreams.assertNoActiveStreams(); - stateAssert.isTerminated(); - - final boolean isEmpty = activeStreams.getSendProcessor().isEmpty(); - if (!isEmpty) { - final ByteBuf cancellationFrame = activeStreams.getSendProcessor().poll(); - FrameAssert.assertThat(cancellationFrame) + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.CANCEL) + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(droppedErrors).containsExactly(testException); - } else { - assertSubscriber.assertTerminated().assertErrorMessage("test"); - } - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } - stateAssert.isTerminated(); - droppedErrors.clear(); - activeStreams.getAllocator().assertHasNoLeaks(); + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } } } finally { Hooks.resetOnErrorDropped(); @@ -584,32 +664,33 @@ public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(boolean with * *

Ensures full serialization of outgoing signal (frames) */ - @Test - public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); - if (!activeStreams.getSendProcessor().isEmpty()) { - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + if (!activeStreams.getDuplexConnection().isEmpty()) { + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() - .typeOf(FrameType.REQUEST_RESPONSE) - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) + .typeOf(scenario.requestType()) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasNoFragmentsFollow() @@ -617,65 +698,69 @@ public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest() { .hasStreamId(1) .hasNoLeaks(); - final ByteBuf cancelFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancelFrame) .isNotNull() .typeOf(FrameType.CANCEL) .hasClientSideStreamId() .hasStreamId(1) .hasNoLeaks(); - } - Assertions.assertThat(payload.refCnt()).isZero(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); - requestResponseRequesterMono.handlePayload(response); Assertions.assertThat(response.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); } } /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ - @Test - public void shouldSentCancelFrameExactlyOnce() { - for (int i = 0; i < 10000; i++) { - final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); - final Payload payload = - TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); - final RequestResponseRequesterMono requestResponseRequesterMono = - new RequestResponseRequesterMono(payload, activeStreams); + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); Payload response = ByteBufPayload.create("test", "test"); - final AssertSubscriber assertSubscriber = - requestResponseRequesterMono.subscribeWith(new AssertSubscriber<>(0)); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); assertSubscriber.request(1); - final ByteBuf sentFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(sentFrame) .isNotNull() .hasNoFragmentsFollow() - .typeOf(FrameType.REQUEST_RESPONSE) + .typeOf(scenario.requestType()) .hasClientSideStreamId() - .hasPayloadSize( - TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length - + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) - .length) .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) .hasData(TestRequesterResponderSupport.DATA_CONTENT) .hasStreamId(1) .hasNoLeaks(); RaceTestUtils.race( - requestResponseRequesterMono::cancel, requestResponseRequesterMono::cancel); + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); - final ByteBuf cancelFrame = activeStreams.getSendProcessor().poll(); + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); FrameAssert.assertThat(cancelFrame) .isNotNull() .typeOf(FrameType.CANCEL) @@ -683,19 +768,22 @@ public void shouldSentCancelFrameExactlyOnce() { .hasStreamId(1) .hasNoLeaks(); - Assertions.assertThat(payload.refCnt()).isZero(); - activeStreams.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); - StateAssert.assertThat(requestResponseRequesterMono).isTerminated(); + activeStreams.assertNoActiveStreams(); - requestResponseRequesterMono.handlePayload(response); + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); Assertions.assertThat(response.refCnt()).isZero(); - requestResponseRequesterMono.handleComplete(); + ((RequesterFrameHandler) requesterOperator).handleComplete(); assertSubscriber.assertNotTerminated(); activeStreams.assertNoActiveStreams(); - Assertions.assertThat(activeStreams.getSendProcessor().isEmpty()).isTrue(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); activeStreams.getAllocator().assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java index 29748abbe..382240c4a 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java @@ -16,6 +16,7 @@ package io.rsocket.core; +import io.rsocket.RaceTestConstants; import io.rsocket.internal.subscriber.AssertSubscriber; import java.time.Duration; import java.util.ArrayList; @@ -36,31 +37,31 @@ import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; import reactor.core.publisher.Hooks; -import reactor.core.publisher.MonoProcessor; -import reactor.core.scheduler.Schedulers; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; import reactor.test.util.RaceTestUtils; -import reactor.util.retry.Retry; public class ResolvingOperatorTests { - private Queue retries = new ConcurrentLinkedQueue<>(); - @Test public void shouldExpireValueOnRacingDisposeAndComplete() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final int index = i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; ResolvingTest.create() @@ -76,42 +77,48 @@ public void shouldExpireValueOnRacingDisposeAndComplete() { .ifResolvedAssertEqual("value" + index) .assertIsDisposed(); - if (processor.isError()) { - Assertions.assertThat(processor.getError()) + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Assertions.assertThat(subscriber.errors().get(0)) .isInstanceOf(CancellationException.class) .hasMessage("Disposed"); } else { - Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(subscriber.values()).containsExactly("value" + i); } } } @Test public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -123,10 +130,7 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( self -> { RaceTestUtils.race(() -> self.complete(valueToSend), () -> self.observe(consumer)); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertDisposeCalled(0) .assertReceivedExactly(valueToSend) @@ -134,39 +138,40 @@ public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete( .thenAddObserver(consumer2) .assertPendingSubscribers(0); - StepVerifier.create(processor2) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber2.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); } } @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value2" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -179,10 +184,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( @@ -191,11 +193,10 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() self::invalidate, () -> { self.observe(consumer2); - if (!processor2.isTerminated()) { + if (!subscriber2.isTerminated()) { self.complete(valueToSend2); } - }, - Schedulers.parallel())) + })) .then( self -> { if (self.isPending()) { @@ -209,46 +210,51 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() .assertDisposeCalled(0) .then( self -> - StepVerifier.create(processor2) - .expectNextMatches( - (v) -> { + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { if (self.subscribers == ResolvingOperator.READY) { - return v.equals(valueToSend2); + Assertions.assertThat(v).isEqualTo(valueToSend2); } else { - return v.equals(valueToSend); + Assertions.assertThat(v).isEqualTo(valueToSend); } }) - .expectComplete() - .verify(Duration.ofMillis(100))); + .assertComplete()); } } @Test public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value_to_possibly_expire" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -261,25 +267,20 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(100)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( self -> RaceTestUtils.race( - () -> - RaceTestUtils.race( - self::invalidate, self::invalidate, Schedulers.parallel()), + self::invalidate, + self::invalidate, () -> { self.observe(consumer2); - if (!processor2.isTerminated()) { + if (!subscriber2.isTerminated()) { self.complete(valueToSend2); } - }, - Schedulers.parallel())) + })) .then( self -> { if (!self.isPending()) { @@ -296,7 +297,7 @@ public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates( .haveAtMost( 2, new Condition<>( - new Predicate() { + new Predicate() { int time = 0; @Override @@ -313,40 +314,39 @@ public boolean test(Object s) { .assertPendingSubscribers(0) .assertDisposeCalled(0) .then( - new Consumer>() { - @Override - public void accept(ResolvingTest self) { - StepVerifier.create(processor2) - .expectNextMatches( - (v) -> { + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { if (self.subscribers == ResolvingOperator.READY) { - return v.equals(valueToSend2); + Assertions.assertThat(v).isEqualTo(valueToSend2); } else { - return v.equals(valueToSend) || v.equals(valueToSend2); + Assertions.assertThat(v).isIn(valueToSend, valueToSend2); } }) - .expectComplete() - .verify(Duration.ofMillis(100)); - } - }); + .assertComplete()); } } @Test public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; final String valueToSend2 = "value2" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; ResolvingTest.create() @@ -359,10 +359,7 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { self -> { self.complete(valueToSend); - StepVerifier.create(processor) - .expectNext(valueToSend) - .expectComplete() - .verify(Duration.ofMillis(10)); + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); }) .assertReceivedExactly(valueToSend) .then( @@ -371,19 +368,15 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { () -> Assertions.assertThat(self.block(null)) .matches((v) -> v.equals(valueToSend) || v.equals(valueToSend2)), - () -> - RaceTestUtils.race( - self::invalidate, - () -> { - for (; ; ) { - if (self.subscribers != ResolvingOperator.READY) { - self.complete(valueToSend2); - break; - } - } - }, - Schedulers.parallel()), - Schedulers.parallel())) + self::invalidate, + () -> { + for (; ; ) { + if (self.subscribers != ResolvingOperator.READY) { + self.complete(valueToSend2); + break; + } + } + })) .then( self -> { if (self.isPending()) { @@ -400,29 +393,33 @@ public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -442,11 +439,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -457,20 +454,23 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -482,7 +482,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { .then( self -> RaceTestUtils.race( - () -> processor.onNext(self.block(null)), () -> self.observe(consumer2))) + () -> { + subscriber.onNext(self.block(null)); + subscriber.onComplete(); + }, + () -> self.observe(consumer2))) .assertSubscribeCalled(1) .assertPendingSubscribers(0) .assertReceivedExactly(valueToSend) @@ -490,11 +494,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -506,11 +510,14 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { @Test public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { Duration timeout = Duration.ofMillis(100); - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { final String valueToSend = "value" + i; - MonoProcessor processor = MonoProcessor.create(); - MonoProcessor processor2 = MonoProcessor.create(); + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); ResolvingTest.create() .assertNothingExpired() @@ -521,8 +528,14 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { .then( self -> RaceTestUtils.race( - () -> processor.onNext(self.block(timeout)), - () -> processor2.onNext(self.block(timeout)))) + () -> { + subscriber.onNext(self.block(timeout)); + subscriber.onComplete(); + }, + () -> { + subscriber2.onNext(self.block(timeout)); + subscriber2.onComplete(); + })) .assertSubscribeCalled(1) .assertPendingSubscribers(0) .assertReceivedExactly(valueToSend) @@ -530,11 +543,11 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { .assertDisposeCalled(0) .then( self -> { - Assertions.assertThat(processor.isTerminated()).isTrue(); - Assertions.assertThat(processor2.isTerminated()).isTrue(); + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); - Assertions.assertThat(processor.peek()).isEqualTo(valueToSend); - Assertions.assertThat(processor2.peek()).isEqualTo(valueToSend); + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); @@ -548,26 +561,31 @@ public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { public void shouldExpireValueOnRacingDisposeAndError() { Hooks.onErrorDropped(t -> {}); RuntimeException runtimeException = new RuntimeException("test"); - for (int i = 0; i < 10000; i++) { - MonoProcessor processor = MonoProcessor.create(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); BiConsumer consumer = (v, t) -> { if (t != null) { - processor.onError(t); + subscriber.onError(t); return; } - processor.onNext(v); + subscriber.onNext(v); + subscriber.onComplete(); }; - MonoProcessor processor2 = MonoProcessor.create(); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); BiConsumer consumer2 = (v, t) -> { if (t != null) { - processor2.onError(t); + subscriber2.onError(t); return; } - processor2.onNext(v); + subscriber2.onNext(v); + subscriber2.onComplete(); }; ResolvingTest.create() @@ -591,8 +609,9 @@ public void shouldExpireValueOnRacingDisposeAndError() { }) .thenAddObserver(consumer2); - StepVerifier.create(processor) - .expectErrorSatisfies( + subscriber + .await(Duration.ofMillis(10)) + .assertErrorWith( t -> { if (t instanceof CancellationException) { Assertions.assertThat(t) @@ -601,11 +620,11 @@ public void shouldExpireValueOnRacingDisposeAndError() { } else { Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); } - }) - .verify(Duration.ofMillis(10)); + }); - StepVerifier.create(processor2) - .expectErrorSatisfies( + subscriber2 + .await(Duration.ofMillis(10)) + .assertErrorWith( t -> { if (t instanceof CancellationException) { Assertions.assertThat(t) @@ -614,8 +633,7 @@ public void shouldExpireValueOnRacingDisposeAndError() { } else { Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); } - }) - .verify(Duration.ofMillis(10)); + }); // no way to guarantee equality because of racing // Assertions.assertThat(processor.getError()) @@ -664,9 +682,10 @@ public void shouldThrowOnBlockingIfHasAlreadyTerminated() { static Stream, Publisher>> innerCases() { return Stream.of( (self) -> { - final MonoProcessor processor = MonoProcessor.create(); + final Sinks.One processor = Sinks.unsafe().one(); final ResolvingOperator.DeferredResolution operator = - new ResolvingOperator.DeferredResolution(self, processor) { + new ResolvingOperator.DeferredResolution( + self, new SinkOneSubscriber(processor)) { @Override public void accept(String v, Throwable t) { if (t != null) { @@ -677,14 +696,21 @@ public void accept(String v, Throwable t) { onNext(v); } }; - return processor.doOnSubscribe(s -> self.observe(operator)).doOnCancel(operator::cancel); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); }, (self) -> { - final MonoProcessor processor = MonoProcessor.create(); + final Sinks.One processor = Sinks.unsafe().one(); + final SinkOneSubscriber subscriber = new SinkOneSubscriber(processor); final ResolvingOperator.MonoDeferredResolutionOperator operator = - new ResolvingOperator.MonoDeferredResolutionOperator<>(self, processor); - processor.onSubscribe(operator); - return processor.doOnSubscribe(s -> self.observe(operator)).doOnCancel(operator::cancel); + new ResolvingOperator.MonoDeferredResolutionOperator<>(self, subscriber); + subscriber.onSubscribe(operator); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); }); } @@ -737,12 +763,13 @@ public void shouldExpireValueOnDispose( public void shouldNotifyAllTheSubscribers( Function, Publisher> caseProducer) { - final MonoProcessor sub1 = MonoProcessor.create(); - final MonoProcessor sub2 = MonoProcessor.create(); - final MonoProcessor sub3 = MonoProcessor.create(); - final MonoProcessor sub4 = MonoProcessor.create(); + AssertSubscriber sub1 = AssertSubscriber.create(); + AssertSubscriber sub2 = AssertSubscriber.create(); + AssertSubscriber sub3 = AssertSubscriber.create(); + AssertSubscriber sub4 = AssertSubscriber.create(); - final ArrayList> processors = new ArrayList<>(200); + final ArrayList> processors = + new ArrayList<>(RaceTestConstants.REPEATS * 2); ResolvingTest.create() .assertDisposeCalled(0) @@ -761,9 +788,9 @@ public void shouldNotifyAllTheSubscribers( .assertPendingSubscribers(4) .then( self -> { - for (int i = 0; i < 100; i++) { - final MonoProcessor subA = MonoProcessor.create(); - final MonoProcessor subB = MonoProcessor.create(); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subA = AssertSubscriber.create(); + AssertSubscriber subB = AssertSubscriber.create(); processors.add(subA); processors.add(subB); RaceTestUtils.race( @@ -772,21 +799,21 @@ public void shouldNotifyAllTheSubscribers( } }) .assertSubscribeCalled(1) - .assertPendingSubscribers(204) - .then(self -> sub1.dispose()) - .assertPendingSubscribers(203) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 4) + .then(self -> sub1.cancel()) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 3) .then( self -> { String valueToSend = "value"; self.complete(valueToSend); - Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); - Assertions.assertThat(sub2.peek()).isEqualTo(valueToSend); - Assertions.assertThat(sub3.peek()).isEqualTo(valueToSend); - Assertions.assertThat(sub4.peek()).isEqualTo(valueToSend); + Assertions.assertThat(sub1.isTerminated()).isFalse(); + Assertions.assertThat(sub2.values()).containsExactly(valueToSend); + Assertions.assertThat(sub3.values()).containsExactly(valueToSend); + Assertions.assertThat(sub4.values()).containsExactly(valueToSend); - for (MonoProcessor sub : processors) { - Assertions.assertThat(sub.peek()).isEqualTo(valueToSend); + for (AssertSubscriber sub : processors) { + Assertions.assertThat(sub.values()).containsExactly(valueToSend); Assertions.assertThat(sub.isTerminated()).isTrue(); } }) @@ -797,7 +824,7 @@ public void shouldNotifyAllTheSubscribers( @Test public void shouldBeSerialIfRacyMonoInner() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { long[] requested = new long[] {0}; Subscription mockSubscription = Mockito.mock(Subscription.class); Mockito.doAnswer( @@ -833,7 +860,7 @@ public void accept(Object o, Object o2) {} @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ResolvingTest.create() .assertNothingExpired() .assertNothingReceived() @@ -847,7 +874,7 @@ public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { @Test public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { - for (int i = 0; i < 10000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { ResolvingTest.create() .assertNothingExpired() .assertNothingReceived() @@ -967,4 +994,37 @@ protected void doOnDispose() { onDisposeCalls.incrementAndGet(); } } + + private static class SinkOneSubscriber implements CoreSubscriber { + + private final Sinks.One processor; + private boolean valueReceived; + + public SinkOneSubscriber(Sinks.One processor) { + this.processor = processor; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String s) { + valueReceived = true; + processor.tryEmitValue(s); + } + + @Override + public void onError(Throwable t) { + processor.tryEmitError(t); + } + + @Override + public void onComplete() { + if (!valueReceived) { + processor.tryEmitEmpty(); + } + } + } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java index 2872d8d78..4f7821e4a 100755 --- a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -20,6 +20,7 @@ import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; import static io.rsocket.frame.FrameType.REQUEST_FNF; import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; import io.netty.buffer.ByteBuf; import io.rsocket.FrameAssert; @@ -28,8 +29,10 @@ import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; -import io.rsocket.internal.UnboundedProcessor; import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; import java.util.ArrayList; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Stream; @@ -86,6 +89,12 @@ public ResponderFrameHandler responseOperator( new RequestResponseResponderSubscriber( streamId, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return subscriber; } @@ -99,6 +108,12 @@ public ResponderFrameHandler responseOperator( RequestResponseResponderSubscriber subscriber = new RequestResponseResponderSubscriber(streamId, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + return handler.requestResponse(firstPayload).subscribeWith(subscriber); } @@ -128,6 +143,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + streamManager.activeStreams.put(streamId, subscriber); return subscriber; } @@ -142,6 +163,12 @@ public ResponderFrameHandler responseOperator( RequestStreamResponderSubscriber subscriber = new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + return handler.requestStream(firstPayload).subscribeWith(subscriber); } @@ -172,6 +199,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstFragment, streamManager, handler); streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return subscriber; } @@ -186,6 +219,12 @@ public ResponderFrameHandler responseOperator( new RequestChannelResponderSubscriber( streamId, initialRequestN, firstPayload, streamManager); streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); } @@ -242,10 +281,11 @@ public Flux requestChannel(Publisher payloads) { void shouldHandleRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); - final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); @@ -261,7 +301,7 @@ void shouldHandleRequest(Scenario scenario) { testPublisher.next(randomPayload.retain()); testPublisher.complete(); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .isNotNull() .hasStreamId(1) .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) @@ -274,15 +314,21 @@ void shouldHandleRequest(Scenario scenario) { if (scenario.requestType() != REQUEST_RESPONSE) { - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); if (scenario.requestType() == REQUEST_CHANNEL) { testHandler.consumer.request(2); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .typeOf(FrameType.REQUEST_N) .hasStreamId(1) .hasRequestN(1) .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); } } @@ -291,6 +337,10 @@ void shouldHandleRequest(Scenario scenario) { .assertValueCount(1) .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); allocator.assertHasNoLeaks(); } @@ -299,10 +349,11 @@ void shouldHandleRequest(Scenario scenario) { void shouldHandleFragmentedRequest(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); - final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); TestPublisher testPublisher = TestPublisher.create(); TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); @@ -332,7 +383,7 @@ void shouldHandleFragmentedRequest(Scenario scenario) { testPublisher.next(randomPayload.retain()); testPublisher.complete(); - FrameAssert.assertThat(sender.poll()) + FrameAssert.assertThat(sender.awaitFrame()) .isNotNull() .hasStreamId(1) .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) @@ -345,11 +396,14 @@ void shouldHandleFragmentedRequest(Scenario scenario) { if (scenario.requestType() != REQUEST_RESPONSE) { - FrameAssert.assertThat(sender.poll()).typeOf(FrameType.COMPLETE).hasStreamId(1).hasNoLeaks(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); if (scenario.requestType() == REQUEST_CHANNEL) { testHandler.consumer.request(2); - FrameAssert.assertThat(sender.poll()).isNull(); + FrameAssert.assertThat(sender.pollFrame()).isNull(); } } @@ -364,6 +418,11 @@ void shouldHandleFragmentedRequest(Scenario scenario) { firstPayload.release(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } @@ -372,10 +431,10 @@ void shouldHandleFragmentedRequest(Scenario scenario) { void shouldHandleInterruptedFragmentation(Scenario scenario) { Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); TestRequesterResponderSupport testRequesterResponderSupport = - TestRequesterResponderSupport.client(); + TestRequesterResponderSupport.client(testRequestInterceptor); final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); - final UnboundedProcessor sender = testRequesterResponderSupport.getSendProcessor(); TestPublisher testPublisher = TestPublisher.create(); TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); @@ -408,6 +467,11 @@ void shouldHandleInterruptedFragmentation(Scenario scenario) { testPublisher.assertWasNotSubscribed(); testRequesterResponderSupport.assertNoActiveStreams(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java new file mode 100644 index 000000000..9a51b9419 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java @@ -0,0 +1,31 @@ +package io.rsocket.core; + +import static org.mockito.Mockito.*; + +import io.netty.util.ReferenceCounted; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; + +public class SendUtilsTest { + + @Test + void droppedElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer value = extractDroppedElementConsumer(); + value.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = mock(ReferenceCounted.class); + when(referenceCounted.release()).thenReturn(true); + + Consumer value = extractDroppedElementConsumer(); + value.accept(referenceCounted); + + verify(referenceCounted).release(); + } + + private static Consumer extractDroppedElementConsumer() { + return (Consumer) SendUtils.DISCARD_CONTEXT.stream().findAny().get().getValue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index a64bf9b81..87c3a865f 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.core; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; @@ -6,7 +21,12 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.*; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.exceptions.RejectedSetupException; @@ -14,15 +34,13 @@ import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.frame.SetupFrameCodec; -import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; public class SetupRejectionTest { @@ -40,18 +58,21 @@ void responderRejectSetup() { ByteBuf sentFrame = transport.awaitSent(); assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); RuntimeException error = Exceptions.from(0, sentFrame); + sentFrame.release(); assertThat(errorMsg).isEqualTo(error.getMessage()); assertThat(error).isInstanceOf(RejectedSetupException.class); RSocket acceptorSender = acceptor.senderRSocket().block(); assertThat(acceptorSender.isDisposed()).isTrue(); + transport.allocator.assertHasNoLeaks(); } @Test - @Disabled("FIXME: needs to be revised") void requesterStreamsTerminatedOnZeroErrorFrame() { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + RSocketRequester rSocket = new RSocketRequester( conn, @@ -63,7 +84,10 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { 0, 0, null, - RequesterLeaseHandler.None); + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); String errorMsg = "error"; @@ -82,6 +106,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { .verify(Duration.ofSeconds(5)); assertThat(rSocket.isDisposed()).isTrue(); + allocator.assertHasNoLeaks(); } @Test @@ -89,6 +114,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); RSocketRequester rSocket = new RSocketRequester( conn, @@ -100,7 +126,10 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { 0, 0, null, - RequesterLeaseHandler.None); + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); conn.addToReceivedBuffer( ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); @@ -112,11 +141,13 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() { .expectErrorMatches( err -> err instanceof RejectedSetupException && "error".equals(err.getMessage())) .verify(Duration.ofSeconds(5)); + allocator.assertHasNoLeaks(); } private static class RejectingAcceptor implements SocketAcceptor { private final String errorMessage; - private final UnicastProcessor senderRSockets = UnicastProcessor.create(); + private final Sinks.Many senderRSockets = + Sinks.many().unicast().onBackpressureBuffer(); public RejectingAcceptor(String errorMessage) { this.errorMessage = errorMessage; @@ -124,12 +155,12 @@ public RejectingAcceptor(String errorMessage) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - senderRSockets.onNext(sendingSocket); + senderRSockets.tryEmitNext(sendingSocket); return Mono.error(new RuntimeException(errorMessage)); } public Mono senderRSocket() { - return senderRSockets.next(); + return senderRSockets.asFlux().next(); } } @@ -145,11 +176,7 @@ public Mono start(ConnectionAcceptor acceptor) { } public ByteBuf awaitSent() { - try { - return conn.awaitSend(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } + return conn.awaitFrame(); } public void connect() { diff --git a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 98fde97f7..16bd9f16e 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -16,22 +16,23 @@ package io.rsocket.core; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class StreamIdSupplierTest { @Test public void testClientSequence() { IntObjectMap map = new IntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); - assertEquals(1, s.nextStreamId(map)); - assertEquals(3, s.nextStreamId(map)); - assertEquals(5, s.nextStreamId(map)); + assertThat(s.nextStreamId(map)).isEqualTo(1); + assertThat(s.nextStreamId(map)).isEqualTo(3); + assertThat(s.nextStreamId(map)).isEqualTo(5); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java index 8069e7362..e282d72d5 100644 --- a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -21,10 +21,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.FrameType; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.ByteBufPayload; import java.util.ArrayList; import java.util.concurrent.ThreadLocalRandom; @@ -32,7 +36,7 @@ import reactor.core.Exceptions; import reactor.util.annotation.Nullable; -final class TestRequesterResponderSupport extends RequesterResponderSupport { +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { static final String DATA_CONTENT = "testData"; static final String METADATA_CONTENT = "testMetadata"; @@ -42,19 +46,27 @@ final class TestRequesterResponderSupport extends RequesterResponderSupport { TestRequesterResponderSupport( @Nullable Throwable error, StreamIdSupplier streamIdSupplier, + DuplexConnection connection, int mtu, int maxFrameLength, - int maxInboundPayloadSize) { + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { super( mtu, maxFrameLength, maxInboundPayloadSize, PayloadDecoder.ZERO_COPY, - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT), - streamIdSupplier); + connection, + streamIdSupplier, + (__) -> requestInterceptor); this.error = error; } + @Override + public TestDuplexConnection getDuplexConnection() { + return (TestDuplexConnection) super.getDuplexConnection(); + } + static Payload genericPayload(LeaksTrackingByteBufAllocator allocator) { ByteBuf data = allocator.buffer(); data.writeCharSequence(DATA_CONTENT, CharsetUtil.UTF_8); @@ -162,14 +174,67 @@ public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { return nextStreamId; } + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + public static TestRequesterResponderSupport client(@Nullable Throwable e) { return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); } public static TestRequesterResponderSupport client( int mtu, int maxFrameLength, int maxInboundPayloadSize, @Nullable Throwable e) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + mtu, + maxFrameLength, + maxInboundPayloadSize, + null, + e); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, + @Nullable Throwable e) { return new TestRequesterResponderSupport( - e, StreamIdSupplier.clientSupplier(), mtu, maxFrameLength, maxInboundPayloadSize); + e, + StreamIdSupplier.clientSupplier(), + duplexConnection, + mtu, + maxFrameLength, + maxInboundPayloadSize, + requestInterceptor); } public static TestRequesterResponderSupport client( @@ -189,6 +254,16 @@ public static TestRequesterResponderSupport client() { return client(0); } + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + public TestRequesterResponderSupport assertNoActiveStreams() { Assertions.assertThat(activeStreams).isEmpty(); return this; diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java deleted file mode 100644 index e0ebf5064..000000000 --- a/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.rsocket.core; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import org.junit.Test; - -public class TestingStuff { - private String f = "00000001110000000068656c6c6f"; - private String f1 = - "00000001286004232e667127bb590fb097cf657776761dcdfe84863f67da47e9de9ac72197424116aae3aadf9e4d8347d8f7923107f7eacb56f741c82327e9c4cbe23e92b5afc306aa24a2153e27082ba0bb1e707bed43100ea84a87539a23442e10431584a42f6eb78fdf140fb14ee71cf4a23ad644dcd3ebceb86e4d0617aa0d2cfee56ce1afea5062842580275b5fdc96cae1bbe9f3a129148dbe1dfc44c8e11aa6a7ec8dafacbbdbd0b68731c16bd16c21eb857c9eb1bb6c359415674b6d41d14e99b1dd56a40fc836d723dd534e83c44d6283745332c627e13bcfc2cd483bccec67232fff0b2ccb7388f0d37da27562d7c3635fef061767400e45729bdef57ca8c041e33074ea1a42004c1b8cb02eb3afeaf5f6d82162d4174c549f840bdb88632faf2578393194f67538bf581a22f31850f88831af632bdaf32c80aa6d96a7afc20b8067c4f9d17859776e4c40beafff18a848df45927ca1c9b024ef278a9fb60bdf965b5822b64bebc74a8b7d95a4bd9d1c1fc82b4fbacc29e36458a878079ddd402788a462528d6c79df797218563cc70811c09b154588a3edd2e948bb61db7b3a36774e0bd5ab67fec4bf1e70811733213f292a728389473b9f68d288ac481529e10cfd428b14ad3f4592d1cc6dd08b1a7842bb492b51057c4d88ac5d538174560f94b49dce6d20ef71671d2e80c2b92ead6d4a26ed8f4187a563cb53eb0c558fe9f77b2133e835e2d2e671978e82a6f60ed61a6a945e39fe0dedcf73d7cb80253a5eb9f311c78ef2587649436f4ab42bcc882faba7bfd57d451407a07ce1d5ac7b5f0cf1ef84047c92e3fbdb64128925ef6e87def450ae8a1643e9906b7dc1f672bd98e012df3039f2ee412909f4b03db39f45b83955f31986b6fd3b5e4f26b6ec2284dcf55ff5fbbfbfb31cd6b22753c6435dbd3ec5558132c6ede9babd7945ac6e697d28b9697f9b2450db2b643a1abc4c9ad5bfa4529d0e1f261df1da5ee035738a5d8c536466fa741e9190c58cf1cacc819838a6b20d85f901f026c66dbaf23cde3a12ce4b443ef15cc247ba48cc0812c6f2c834c5773f3d4042219727404f0f2640cab486e298ae9f1c2f7a7e6f0619f130895d9f41d343fbdb05d68d6e0308d8d046314811066a13300b1346b8762919d833de7f55fea919ad55500ba4ec7e100b32bbabbf9d378eab61532fd91d4d1977db72b828e8d700062b045459d7729f140d889a67472a035d564384844ff16697743e4017e2bf21511ebb4c939bbab202bf6ef59e2be557027272f1bb21c325cf3e0432120bccba17bea52a7621031466e7973415437cd50cc950e63e6e2d17aad36f7a943892901e763e19082260b88f8971b35b4d9cc8725d6e4137b4648427ae68255e076dfb511871de0f7100d2ece6c8be88a0326ba8d73b5c9883f83c0dccd362e61cb16c7a0cc5ff00f7"; - private String f2 = "00000003110000000068656c6c6f"; - - @Test - public void testStuff() { - ByteBuf byteBuf = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(f1)); - System.out.println(ByteBufUtil.prettyHexDump(byteBuf)); - - new DefaultConnectionSetupPayload(byteBuf); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index fd05cb7da..a316aed8b 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -31,6 +31,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.RaceTestConstants; import io.rsocket.frame.ErrorFrameCodec; import java.util.concurrent.ThreadLocalRandom; import org.junit.jupiter.api.DisplayName; @@ -42,14 +43,18 @@ final class ExceptionsTest { void fromApplicationException() { ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(ApplicationErrorException.class) - .hasMessage("test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns CanceledException") @@ -57,28 +62,37 @@ void fromApplicationException() { void fromCanceledException() { ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(CanceledException.class) - .hasMessage("test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CanceledException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns ConnectionCloseException") @Test void fromConnectionCloseException() { ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(ConnectionCloseException.class) - .hasMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionCloseException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns ConnectionErrorException") @@ -86,122 +100,152 @@ void fromConnectionCloseException() { void fromConnectionErrorException() { ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(ConnectionErrorException.class) - .hasMessage("test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionErrorException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } } @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") @Test void fromIllegalErrorFrame() { ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns InvalidException") @Test void fromInvalidException() { ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(InvalidException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(InvalidException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns InvalidSetupException") @Test void fromInvalidSetupException() { ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(InvalidSetupException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(InvalidSetupException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedException") @Test void fromRejectedException() { ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); + try { - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(RejectedException.class) - .withFailMessage("test-message"); + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(RejectedException.class) + .withFailMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedResumeException") @Test void fromRejectedResumeException() { ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(RejectedResumeException.class) - .hasMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedResumeException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns RejectedSetupException") @Test void fromRejectedSetupException() { ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); + try { - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(RejectedSetupException.class) - .withFailMessage("test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedSetupException.class) + .withFailMessage("test-message"); - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns UnsupportedSetupException") @Test void fromUnsupportedSetupException() { ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(UnsupportedSetupException.class) + .hasMessage("test-message"); - assertThat(Exceptions.from(0, byteBuf)) - .isInstanceOf(UnsupportedSetupException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .hasMessage( - "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") - .isInstanceOf(IllegalArgumentException.class); + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } @DisplayName("from returns CustomRSocketException") @Test void fromCustomRSocketException() { - for (int i = 0; i < 1000; i++) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { int randomCode = ThreadLocalRandom.current().nextBoolean() ? ThreadLocalRandom.current() @@ -209,15 +253,18 @@ void fromCustomRSocketException() { : ThreadLocalRandom.current() .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); - - assertThat(Exceptions.from(1, byteBuf)) - .isInstanceOf(CustomRSocketException.class) - .hasMessage("test-message"); - - assertThat(Exceptions.from(0, byteBuf)) - .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") - .isInstanceOf(IllegalArgumentException.class); - byteBuf.release(); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index 75aa2a5b2..b12d72b51 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -18,9 +18,18 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.api.Assertions; import org.assertj.core.presentation.StandardRepresentation; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; -public final class ByteBufRepresentation extends StandardRepresentation { +public final class ByteBufRepresentation extends StandardRepresentation + implements BeforeAllCallback { + + @Override + public void beforeAll(ExtensionContext context) { + Assertions.useRepresentation(this); + } @Override protected String fallbackToStringOf(Object object) { diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java index fe05335d2..4815bfb8e 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -16,11 +16,12 @@ package io.rsocket.frame; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import java.util.Arrays; -import org.junit.Assert; import org.junit.jupiter.api.Test; public class ResumeFrameCodecTest { @@ -31,10 +32,10 @@ void testEncoding() { Arrays.fill(tokenBytes, (byte) 1); ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); - Assert.assertEquals(ResumeFrameCodec.CURRENT_VERSION, ResumeFrameCodec.version(byteBuf)); - Assert.assertEquals(token, ResumeFrameCodec.token(byteBuf)); - Assert.assertEquals(21, ResumeFrameCodec.lastReceivedServerPos(byteBuf)); - Assert.assertEquals(12, ResumeFrameCodec.firstAvailableClientPos(byteBuf)); + assertThat(ResumeFrameCodec.version(byteBuf)).isEqualTo(ResumeFrameCodec.CURRENT_VERSION); + assertThat(ResumeFrameCodec.token(byteBuf)).isEqualTo(token); + assertThat(ResumeFrameCodec.lastReceivedServerPos(byteBuf)).isEqualTo(21); + assertThat(ResumeFrameCodec.firstAvailableClientPos(byteBuf)).isEqualTo(12); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java index 33dd8eb70..b818d579d 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -1,16 +1,17 @@ package io.rsocket.frame; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class ResumeOkFrameCodecTest { @Test public void testEncoding() { ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); - Assert.assertEquals(42, ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)); + assertThat(ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)).isEqualTo(42); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java deleted file mode 100644 index fb951eb8a..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ /dev/null @@ -1,274 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertEquals; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.buffer.LeaksTrackingByteBufAllocator; -import io.rsocket.frame.ErrorFrameCodec; -import io.rsocket.frame.KeepAliveFrameCodec; -import io.rsocket.frame.LeaseFrameCodec; -import io.rsocket.frame.MetadataPushFrameCodec; -import io.rsocket.frame.ResumeFrameCodec; -import io.rsocket.frame.ResumeOkFrameCodec; -import io.rsocket.frame.SetupFrameCodec; -import io.rsocket.plugins.InitializingInterceptorRegistry; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.util.DefaultPayload; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import org.junit.Before; -import org.junit.Test; - -public class ClientServerInputMultiplexerTest { - private TestDuplexConnection source; - private ClientServerInputMultiplexer clientMultiplexer; - private LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - private ClientServerInputMultiplexer serverMultiplexer; - - @Before - public void setup() { - source = new TestDuplexConnection(allocator); - clientMultiplexer = - new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); - serverMultiplexer = - new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); - } - - @Test - public void clientSplits() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger setupFrames = new AtomicInteger(); - - clientMultiplexer - .asClientConnection() - .receive() - .doOnNext(f -> clientFrames.incrementAndGet()) - .subscribe(); - clientMultiplexer - .asServerConnection() - .receive() - .doOnNext(f -> serverFrames.incrementAndGet()) - .subscribe(); - clientMultiplexer - .asSetupConnection() - .receive() - .doOnNext(f -> setupFrames.incrementAndGet()) - .subscribe(); - - source.addToReceivedBuffer(setupFrame().retain()); - assertEquals(0, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1).retain()); - assertEquals(1, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1).retain()); - assertEquals(2, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(leaseFrame().retain()); - assertEquals(3, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(keepAliveFrame().retain()); - assertEquals(4, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(2).retain()); - assertEquals(4, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(0).retain()); - assertEquals(5, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(metadataPushFrame().retain()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(resumeFrame().retain()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(2, setupFrames.get()); - - source.addToReceivedBuffer(resumeOkFrame().retain()); - assertEquals(5, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(3, setupFrames.get()); - } - - @Test - public void serverSplits() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger setupFrames = new AtomicInteger(); - - serverMultiplexer - .asClientConnection() - .receive() - .doOnNext(f -> clientFrames.incrementAndGet()) - .subscribe(); - serverMultiplexer - .asServerConnection() - .receive() - .doOnNext(f -> serverFrames.incrementAndGet()) - .subscribe(); - serverMultiplexer - .asSetupConnection() - .receive() - .doOnNext(f -> setupFrames.incrementAndGet()) - .subscribe(); - - source.addToReceivedBuffer(setupFrame().retain()); - assertEquals(0, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1).retain()); - assertEquals(1, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(1).retain()); - assertEquals(2, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(leaseFrame().retain()); - assertEquals(2, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(keepAliveFrame().retain()); - assertEquals(2, clientFrames.get()); - assertEquals(2, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(2).retain()); - assertEquals(2, clientFrames.get()); - assertEquals(3, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(errorFrame(0).retain()); - assertEquals(2, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(metadataPushFrame().retain()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(1, setupFrames.get()); - - source.addToReceivedBuffer(resumeFrame().retain()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(2, setupFrames.get()); - - source.addToReceivedBuffer(resumeOkFrame().retain()); - assertEquals(3, clientFrames.get()); - assertEquals(4, serverFrames.get()); - assertEquals(3, setupFrames.get()); - } - - @Test - public void unexpectedFramesBeforeSetupFrame() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger setupFrames = new AtomicInteger(); - - AtomicReference clientError = new AtomicReference<>(); - AtomicReference serverError = new AtomicReference<>(); - AtomicReference setupError = new AtomicReference<>(); - - serverMultiplexer - .asClientConnection() - .receive() - .subscribe(bb -> clientFrames.incrementAndGet(), clientError::set); - serverMultiplexer - .asServerConnection() - .receive() - .subscribe(bb -> serverFrames.incrementAndGet(), serverError::set); - serverMultiplexer - .asSetupConnection() - .receive() - .subscribe(bb -> setupFrames.incrementAndGet(), setupError::set); - - source.addToReceivedBuffer(keepAliveFrame().retain()); - - assertThat(clientError.get().getMessage()) - .isEqualTo("SETUP or LEASE frame must be received before any others."); - assertThat(serverError.get().getMessage()) - .isEqualTo("SETUP or LEASE frame must be received before any others."); - assertThat(setupError.get().getMessage()) - .isEqualTo("SETUP or LEASE frame must be received before any others."); - - assertEquals(0, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, setupFrames.get()); - } - - private ByteBuf resumeFrame() { - return ResumeFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER, 0, 0); - } - - private ByteBuf setupFrame() { - return SetupFrameCodec.encode( - ByteBufAllocator.DEFAULT, - false, - 0, - 42, - "application/octet-stream", - "application/octet-stream", - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER)); - } - - private ByteBuf leaseFrame() { - return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); - } - - private ByteBuf errorFrame(int i) { - return ErrorFrameCodec.encode(allocator, i, new Exception()); - } - - private ByteBuf resumeOkFrame() { - return ResumeOkFrameCodec.encode(allocator, 0); - } - - private ByteBuf keepAliveFrame() { - return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); - } - - private ByteBuf metadataPushFrame() { - return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 7bf975543..343a93beb 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-present the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,115 +16,351 @@ package io.rsocket.internal; -import io.rsocket.Payload; -import io.rsocket.util.ByteBufPayload; -import io.rsocket.util.EmptyPayload; -import java.util.concurrent.CountDownLatch; -import org.junit.Assert; -import org.junit.Test; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; public class UnboundedProcessorTest { - @Test - public void testOnNextBeforeSubscribe_10() { - testOnNextBeforeSubscribeN(10); - } - @Test - public void testOnNextBeforeSubscribe_100() { - testOnNextBeforeSubscribeN(100); + @BeforeAll + public static void setup() { + Hooks.onErrorDropped(__ -> {}); } - @Test - public void testOnNextBeforeSubscribe_10_000() { - testOnNextBeforeSubscribeN(10_000); + @AfterAll + public static void teardown() { + Hooks.resetOnErrorDropped(); } - @Test - public void testOnNextBeforeSubscribe_100_000() { - testOnNextBeforeSubscribeN(100_000); - } + @ParameterizedTest( + name = + "Test that emitting {0} onNext before subscribe and requestN should deliver all the signals once the subscriber is available") + @ValueSource(ints = {10, 100, 10_000, 100_000, 1_000_000, 10_000_000}) + public void testOnNextBeforeSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); - @Test - public void testOnNextBeforeSubscribe_1_000_000() { - testOnNextBeforeSubscribeN(1_000_000); - } + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onComplete(); - @Test - public void testOnNextBeforeSubscribe_10_000_000() { - testOnNextBeforeSubscribeN(10_000_000); + StepVerifier.create(processor.count()).expectNext(Long.valueOf(n)).verifyComplete(); } - public void testOnNextBeforeSubscribeN(int n) { - UnboundedProcessor processor = new UnboundedProcessor<>(); + @ParameterizedTest( + name = + "Test that emitting {0} onNext after subscribe and requestN should deliver all the signals") + @ValueSource(ints = {10, 100, 10_000}) + public void testOnNextAfterSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + processor.subscribe(assertSubscriber); for (int i = 0; i < n; i++) { - processor.onNext(EmptyPayload.INSTANCE); + processor.onNext(Unpooled.EMPTY_BUFFER); } - processor.onComplete(); + assertSubscriber.awaitAndAssertNextValueCount(n); + } - long count = processor.count().block(); + @ParameterizedTest( + name = + "Test that prioritized value sending deliver prioritized signals before the others mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void testPrioritizedSending(boolean fusedCase) { + UnboundedProcessor processor = new UnboundedProcessor(); - Assert.assertEquals(n, count); - } + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } - @Test - public void testOnNextAfterSubscribe_10() throws Exception { - testOnNextAfterSubscribeN(10); - } + processor.onNextPrioritized(Unpooled.copiedBuffer("test", CharsetUtil.UTF_8)); - @Test - public void testOnNextAfterSubscribe_100() throws Exception { - testOnNextAfterSubscribeN(100); + assertThat(fusedCase ? processor.poll() : processor.next().block()) + .isNotNull() + .extracting(bb -> bb.toString(CharsetUtil.UTF_8)) + .isEqualTo("test"); } - @Test - public void testOnNextAfterSubscribe_1000() throws Exception { - testOnNextAfterSubscribeN(1000); - } + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensureUnboundedProcessorDisposesQueueProperly(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); - @Test - public void testPrioritizedSending() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); - for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + }, + unboundedProcessor::dispose, + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest1(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); - processor.onNextPrioritized(ByteBufPayload.create("test")); + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); - Payload closestPayload = processor.next().block(); + unboundedProcessor.subscribe(assertSubscriber); - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } } - @Test - public void testPrioritizedFused() { - UnboundedProcessor processor = new UnboundedProcessor<>(); + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest2(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> { + unboundedProcessor.subscribe(assertSubscriber); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); - for (int i = 0; i < 1000; i++) { - processor.onNext(EmptyPayload.INSTANCE); + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest3(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + assertSubscriber.cancel(); - processor.onNextPrioritized(ByteBufPayload.create("test")); + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber)); - Payload closestPayload = processor.poll(); + assertSubscriber.values().forEach(ReferenceCountUtil::release); - Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + allocator.assertHasNoLeaks(); + } } - public void testOnNextAfterSubscribeN(int n) throws Exception { - CountDownLatch latch = new CountDownLatch(n); - UnboundedProcessor processor = new UnboundedProcessor<>(); - processor.log().doOnNext(integer -> latch.countDown()).subscribe(); + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest31(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); - for (int i = 0; i < n; i++) { - System.out.println("onNexting -> " + i); - processor.onNext(EmptyPayload.INSTANCE); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber), + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }, + assertSubscriber::cancel); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + allocator.assertHasNoLeaks(); } + } - processor.drain(); + @ParameterizedTest( + name = + "Ensures that racing between onNext + dispose | downstream async drain should not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensuresAsyncFusionAndDisposureHasNoDeadlock(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + final ByteBuf buffer5 = allocator.buffer(5); + final ByteBuf buffer6 = allocator.buffer(6); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber() + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + unboundedProcessor.onNext(buffer3); + unboundedProcessor.onNext(buffer4); + unboundedProcessor.onNext(buffer5); + unboundedProcessor.onNext(buffer6); + unboundedProcessor.dispose(); + }, + unboundedProcessor::dispose); + + assertSubscriber.await(Duration.ofSeconds(50)).values().forEach(ReferenceCountUtil::release); + } - latch.await(); + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java index 76b2c366b..b6eac9835 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; import java.util.concurrent.atomic.AtomicLongFieldUpdater; import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BooleanSupplier; @@ -36,6 +37,7 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.Fuseable; +import reactor.core.Scannable; import reactor.core.publisher.Operators; import reactor.util.annotation.NonNull; import reactor.util.context.Context; @@ -77,7 +79,7 @@ * @author Stephane Maldini * @author Brian Clozel */ -public class AssertSubscriber implements CoreSubscriber, Subscription { +public class AssertSubscriber implements CoreSubscriber, Subscription, Scannable { /** Default timeout for waiting next values to be received */ public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); @@ -86,6 +88,10 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(AssertSubscriber.class, "wip"); + @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater NEXT_VALUES = AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); @@ -100,10 +106,14 @@ public class AssertSubscriber implements CoreSubscriber, Subscription { private final CountDownLatch cdl = new CountDownLatch(1); + volatile boolean done; + volatile Subscription s; volatile long requested; + volatile int wip; + volatile List values = new LinkedList<>(); /** The fusion mode to request. */ @@ -854,6 +864,13 @@ public void cancel() { a = S.getAndSet(this, Operators.cancelledSubscription()); if (a != null && a != Operators.cancelledSubscription()) { a.cancel(); + + if (establishedFusionMode == Fuseable.ASYNC) { + final int previousState = markWorkAdded(); + if (!isWorkInProgress(previousState)) { + clearAndFinalize(); + } + } } } } @@ -868,37 +885,121 @@ public final boolean isTerminated() { @Override public void onComplete() { + done = true; completionCount++; + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onError(Throwable t) { + done = true; errors.add(t); + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + cdl.countDown(); } @Override public void onNext(T t) { if (establishedFusionMode == Fuseable.ASYNC) { - for (; ; ) { - t = qs.poll(); - if (t == null) { - break; - } - valueCount++; - if (valuesStorage) { - List nextValuesSnapshot; - for (; ; ) { - nextValuesSnapshot = values; - nextValuesSnapshot.add(t); - if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { - break; - } + drain(); + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; } } } - } else { + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return state > 0; + } + + int markWorkAdded() { + for (; ; ) { + int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + if ((state & Integer.MAX_VALUE) == Integer.MAX_VALUE) { + return state; + } + int nextState = state + 1; + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + void clearAndFinalize() { + final Fuseable.QueueSubscription qs = this.qs; + for (; ; ) { + int state = this.wip; + + qs.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + void drain() { + final int previousState = markWorkAdded(); + if (isWorkInProgress(previousState)) { + return; + } + + if (isFinalized(previousState)) { + qs.clear(); + return; + } + + T t; + int m = 1; + for (; ; ) { + if (isCancelled()) { + clearAndFinalize(); + break; + } + boolean done = this.done; + t = qs.poll(); + if (t == null) { + if (done) { + clearAndFinalize(); + cdl.countDown(); + return; + } + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + continue; + } valueCount++; if (valuesStorage) { List nextValuesSnapshot; @@ -919,39 +1020,41 @@ public void onSubscribe(Subscription s) { subscriptionCount++; int requestMode = requestedFusionMode; if (requestMode >= 0) { - if (!setWithoutRequesting(s)) { - if (!isCancelled()) { - errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); - } - } else { - if (s instanceof Fuseable.QueueSubscription) { - this.qs = (Fuseable.QueueSubscription) s; + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; - int m = qs.requestFusion(requestMode); - establishedFusionMode = m; + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; - if (m == Fuseable.SYNC) { - for (; ; ) { - T v = qs.poll(); - if (v == null) { - onComplete(); - break; - } + if (!setWithoutRequesting(s)) { + qs.clear(); + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + return; + } - onNext(v); + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; } - } else { - requestDeferred(); + + onNext(v); } } else { requestDeferred(); } + + return; } - } else { - if (!set(s)) { - if (!isCancelled()) { - errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); - } + } + + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); } } } @@ -1143,6 +1246,10 @@ public List values() { return values; } + public List errors() { + return errors; + } + public final AssertSubscriber assertNoEvents() { return assertNoValues().assertNoError().assertNotComplete(); } @@ -1151,4 +1258,20 @@ public final AssertSubscriber assertNoEvents() { public final AssertSubscriber assertIncomplete(T... values) { return assertValues(values).assertNotComplete().assertNoError(); } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) { + return upstream(); + } + + boolean t = isTerminated(); + if (key == Attr.TERMINATED) return t; + if (key == Attr.ERROR) return (!errors.isEmpty() ? errors.get(0) : null); + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + if (key == Attr.CANCELLED) return isCancelled(); + if (key == Attr.RUN_STYLE) return Attr.RunStyle.SYNC; + + return null; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java index d5b2eeb41..9ebca34f7 100644 --- a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java +++ b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java @@ -16,71 +16,63 @@ package io.rsocket.lease; -import static org.junit.Assert.*; - -import io.netty.buffer.Unpooled; -import java.time.Duration; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; - public class LeaseImplTest { - - @Test - public void emptyLeaseNoAvailability() { - LeaseImpl empty = LeaseImpl.empty(); - Assertions.assertTrue(empty.isEmpty()); - Assertions.assertFalse(empty.isValid()); - Assertions.assertEquals(0.0, empty.availability(), 1e-5); - } - - @Test - public void emptyLeaseUseNoAvailability() { - LeaseImpl empty = LeaseImpl.empty(); - boolean success = empty.use(); - assertFalse(success); - Assertions.assertEquals(0.0, empty.availability(), 1e-5); - } - - @Test - public void leaseAvailability() { - LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); - Assertions.assertEquals(1.0, lease.availability(), 1e-5); - } - - @Test - public void leaseUseDecreasesAvailability() { - LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); - boolean success = lease.use(); - Assertions.assertTrue(success); - Assertions.assertEquals(0.5, lease.availability(), 1e-5); - Assertions.assertTrue(lease.isValid()); - success = lease.use(); - Assertions.assertTrue(success); - Assertions.assertEquals(0.0, lease.availability(), 1e-5); - Assertions.assertFalse(lease.isValid()); - Assertions.assertEquals(0, lease.getAllowedRequests()); - success = lease.use(); - Assertions.assertFalse(success); - } - - @Test - public void leaseTimeout() { - int numberOfRequests = 1; - LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); - Mono.delay(Duration.ofMillis(100)).block(); - boolean success = lease.use(); - Assertions.assertFalse(success); - Assertions.assertTrue(lease.isExpired()); - Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); - Assertions.assertFalse(lease.isValid()); - } - - @Test - public void useLeaseChangesAllowedRequests() { - int numberOfRequests = 2; - LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); - lease.use(); - assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); - } + // + // @Test + // public void emptyLeaseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // Assertions.assertTrue(empty.isEmpty()); + // Assertions.assertFalse(empty.isValid()); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void emptyLeaseUseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // boolean success = empty.use(); + // assertFalse(success); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void leaseAvailability() { + // LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); + // Assertions.assertEquals(1.0, lease.availability(), 1e-5); + // } + // + // @Test + // public void leaseUseDecreasesAvailability() { + // LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); + // boolean success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.5, lease.availability(), 1e-5); + // Assertions.assertTrue(lease.isValid()); + // success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.0, lease.availability(), 1e-5); + // Assertions.assertFalse(lease.isValid()); + // Assertions.assertEquals(0, lease.getAllowedRequests()); + // success = lease.use(); + // Assertions.assertFalse(success); + // } + // + // @Test + // public void leaseTimeout() { + // int numberOfRequests = 1; + // LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); + // Mono.delay(Duration.ofMillis(100)).block(); + // boolean success = lease.use(); + // Assertions.assertFalse(success); + // Assertions.assertTrue(lease.isExpired()); + // Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); + // Assertions.assertFalse(lease.isValid()); + // } + // + // @Test + // public void useLeaseChangesAllowedRequests() { + // int numberOfRequests = 2; + // LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); + // lease.use(); + // assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); + // } } diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java new file mode 100644 index 000000000..a35e89391 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java @@ -0,0 +1,94 @@ +package io.rsocket.loadbalance; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +class LoadbalanceRSocketClientTest { + + @Mock private ClientTransport clientTransport; + @Mock private RSocketConnector rSocketConnector; + + public static final Duration SHORT_DURATION = Duration.ofMillis(25); + public static final Duration LONG_DURATION = Duration.ofMillis(75); + + private static final Publisher SOURCE = + Flux.interval(SHORT_DURATION) + .onBackpressureBuffer() + .map(String::valueOf) + .map(DefaultPayload::create); + + private static final Mono PROGRESSING_HANDLER = + Mono.just( + new RSocket() { + private final AtomicInteger i = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .delayElements(SHORT_DURATION) + .map(Payload::getDataUtf8) + .map(DefaultPayload::create) + .take(i.incrementAndGet()); + } + }); + + @Test + void testChannelReconnection() { + when(rSocketConnector.connect(clientTransport)).thenReturn(PROGRESSING_HANDLER); + + RSocketClient client = + LoadbalanceRSocketClient.create( + rSocketConnector, + Mono.just(singletonList(LoadbalanceTarget.from("key", clientTransport)))); + + Publisher result = + client + .requestChannel(SOURCE) + .repeatWhen(longFlux -> longFlux.delayElements(LONG_DURATION).take(5)) + .map(Payload::getDataUtf8) + .log(); + + StepVerifier.create(result) + .expectSubscription() + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("4")) + .verifyComplete(); + + verify(rSocketConnector).connect(clientTransport); + verifyNoMoreInteractions(rSocketConnector, clientTransport); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java new file mode 100644 index 000000000..c1b509297 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java @@ -0,0 +1,470 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; + +public class LoadbalanceTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverAllTheRequestsWithRoundRobinStrategy() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = new TestClientTransport(); + final RSocket rSocket = + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }; + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(new TestRSocket(rSocket))); + + final List collectionOfDestination1 = + Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)); + final List collectionOfDestination2 = + Collections.singletonList(LoadbalanceTarget.from("2", mockTransport)); + final List collectionOfDestinations1And2 = + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), LoadbalanceTarget.from("2", mockTransport)); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, source.asFlux(), new RoundRobinLoadbalanceStrategy()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void shouldDeliverAllTheRequestsWithWeightedStrategy() throws InterruptedException { + final AtomicInteger counter = new AtomicInteger(); + + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + + final LoadbalanceTarget target1 = LoadbalanceTarget.from("1", mockTransport1); + final LoadbalanceTarget target2 = LoadbalanceTarget.from("2", mockTransport2); + + final WeightedRSocket weightedRSocket1 = new WeightedRSocket(counter); + final WeightedRSocket weightedRSocket2 = new WeightedRSocket(counter); + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + Mockito.when(rSocketConnectorMock.connect(mockTransport1)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket1))); + Mockito.when(rSocketConnectorMock.connect(mockTransport2)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket2))); + final List collectionOfDestination1 = Collections.singletonList(target1); + final List collectionOfDestination2 = Collections.singletonList(target2); + final List collectionOfDestinations1And2 = Arrays.asList(target1, target2); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source.asFlux(), + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver( + rsocket -> { + if (rsocket instanceof TestRSocket) { + return (WeightedRSocket) ((TestRSocket) rsocket).source(); + } + return ((PooledRSocket) rsocket).target() == target1 + ? weightedRSocket1 + : weightedRSocket2; + }) + .build()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void ensureRSocketIsCleanedFromThePoolIfSourceRSocketIsDisposed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + final TestRSocket testRSocket = + new TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.delay(Duration.ofMillis(200)).map(__ -> testRSocket)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + testRSocket.dispose(); + + Assertions.assertThatThrownBy( + () -> + rSocketPool + .select() + .fireAndForget(EmptyPayload.INSTANCE) + .block(Duration.ofSeconds(2))) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on blocking read for 2000000000 NANOSECONDS"); + + Assertions.assertThat(counter.get()).isOne(); + } + + @Test + public void ensureContextIsPropagatedCorrectlyForRequestChannel() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.delay(Duration.ofMillis(200)) + .map( + __ -> + new TestRSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher source) { + counter.incrementAndGet(); + return Flux.from(source); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + // check that context is propagated when there is no rsocket + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .then( + () -> + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)))) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport))); + // check that context is propagated when there is an RSocket but it is unresolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + // check that context is propagated when there is an RSocket and it is resolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(counter.get()).isEqualTo(3); + } + + @Test + public void shouldNotifyOnCloseWhenAllTheActiveSubscribersAreClosed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Sinks.Empty onCloseSocket1 = Sinks.empty(); + Sinks.Empty onCloseSocket2 = Sinks.empty(); + + RSocket socket1 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket1.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + RSocket socket2 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket2.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(socket1)) + .then(im -> Mono.just(socket2)) + .then(im -> Mono.never().doOnCancel(() -> counter.incrementAndGet())); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport), + LoadbalanceTarget.from("3", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + + rSocketPool.dispose(); + + AssertSubscriber onCloseSubscriber = + rSocketPool.onClose().subscribeWith(AssertSubscriber.create()); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket1.tryEmitEmpty(); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket2.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(counter.get()).isOne(); + } + + static class TestRSocket extends RSocketProxy { + + final Sinks.Empty sink = Sinks.empty(); + + public TestRSocket(RSocket rSocket) { + super(rSocket); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return source; + } + } + + private static class WeightedRSocket extends BaseWeightedStats implements RSocket { + + private final AtomicInteger counter; + + public WeightedRSocket(AtomicInteger counter) { + this.counter = counter; + } + + @Override + public Mono fireAndForget(Payload payload) { + final long startTime = startRequest(); + counter.incrementAndGet(); + return Mono.empty() + .doFinally( + (__) -> { + final long stopTime = stopRequest(startTime); + record(stopTime - startTime); + }); + } + } + + static class RetrySubscriber implements CoreSubscriber { + + final Publisher source; + + private RetrySubscriber(Publisher source) { + this.source = source; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void unused) {} + + @Override + public void onError(Throwable t) { + source.subscribe(this); + } + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java new file mode 100644 index 000000000..e43068dbd --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java @@ -0,0 +1,170 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; + +public class RoundRobinLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverValuesProportionally() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()).isCloseTo(500, Offset.offset(1)); + Assertions.assertThat(counter2.get()).isCloseTo(500, Offset.offset(1)); + } + + @Test + public void shouldDeliverValuesToNewlyConnectedSockets() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + if (im.getArgument(0) == mockTransport1) { + counter1.incrementAndGet(); + } else { + counter2.incrementAndGet(); + } + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 3, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java new file mode 100644 index 000000000..8cc254cbb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java @@ -0,0 +1,254 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.Clock; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; + +public class WeightedLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void allRequestsShouldGoToTheSocketWithHigherWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + Assertions.assertThat(counter2.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + } + + @Test + public void shouldDeliverValuesToTheSocketWithTheHighestCalculatedWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket3 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)) + .then(im -> Mono.just(rSocket3)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + rSocket1.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + + rSocket2.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 2, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + } + + static class WeightedTestRSocket extends BaseWeightedStats implements RSocket { + + final Sinks.Empty sink = Sinks.empty(); + + final RSocket rSocket; + + public WeightedTestRSocket(RSocket rSocket) { + this.rSocket = rSocket; + } + + @Override + public Mono fireAndForget(Payload payload) { + startRequest(); + final long startTime = Clock.now(); + return this.rSocket + .fireAndForget(payload) + .doFinally( + __ -> { + stopRequest(startTime); + record(Clock.now() - startTime); + updateAvailability(1.0); + }); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return rSocket; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java index a6ef8ea37..58ab30021 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java @@ -11,7 +11,7 @@ public class AuthMetadataCodecTest { public static final int AUTH_TYPE_ID_LENGTH = 1; - public static final int USER_NAME_BYTES_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 2; public static final String TEST_BEARER_TOKEN = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; @@ -82,7 +82,7 @@ private static void checkSimpleAuthMetadataEncoding( Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); - Assertions.assertThat(byteBuf.readUnsignedByte()).isEqualTo((short) usernameLength); + Assertions.assertThat(byteBuf.readUnsignedShort()).isEqualTo((short) usernameLength); Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) .isEqualTo(username); @@ -116,16 +116,22 @@ private static void checkSimpleAuthMetadataEncodingUsingDecoders( @Test void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { - String username = + StringBuilder usernameBuilder = new StringBuilder(); + String usernamePart = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + for (int i = 0; i < 65535 / usernamePart.length(); i++) { + usernameBuilder.append(usernamePart); + } String password = "tset1234"; Assertions.assertThatThrownBy( () -> AuthMetadataCodec.encodeSimpleMetadata( - ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray())) + ByteBufAllocator.DEFAULT, + usernameBuilder.toString().toCharArray(), + password.toCharArray())) .hasMessage( - "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); } @Test @@ -242,7 +248,7 @@ void shouldEncodeUsingWellKnownAuthType() { AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.SIMPLE, - ByteBufAllocator.DEFAULT.buffer(3, 3).writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } @@ -253,7 +259,7 @@ void shouldEncodeUsingWellKnownAuthType1() { AuthMetadataCodec.encodeMetadata( ByteBufAllocator.DEFAULT, WellKnownAuthType.SIMPLE, - ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } @@ -297,7 +303,7 @@ void shouldCompressMetadata() { AuthMetadataCodec.encodeMetadataWithCompression( ByteBufAllocator.DEFAULT, "simple", - ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); } diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java index 3ce07729d..a4e8fb2d8 100644 --- a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java @@ -23,12 +23,22 @@ import io.netty.buffer.*; import io.netty.util.CharsetUtil; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.ByteBufUtils; import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; class CompositeMetadataCodecTest { + final LeaksTrackingByteBufAllocator testAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @AfterEach + void tearDownAndCheckForLeaks() { + testAllocator.assertHasNoLeaks(); + } + static String byteToBitsString(byte b) { return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); } @@ -48,17 +58,14 @@ void customMimeHeaderLatin1_encodingFails() { assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) .withMessage("custom mime type must be US_ASCII characters only"); } @Test void customMimeHeaderLength0_encodingFails() { assertThatIllegalArgumentException() - .isThrownBy( - () -> CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "", 0)) + .isThrownBy(() -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "", 0)) .withMessage( "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); } @@ -70,8 +77,7 @@ void customMimeHeaderLength127() { builder.append('a'); } String mimeString = builder.toString(); - ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); @@ -99,6 +105,7 @@ void customMimeHeaderLength127() { .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -108,8 +115,7 @@ void customMimeHeaderLength128() { builder.append('a'); } String mimeString = builder.toString(); - ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); @@ -137,6 +143,7 @@ void customMimeHeaderLength128() { .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -148,9 +155,7 @@ void customMimeHeaderLength129_encodingFails() { assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, builder.toString(), 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, builder.toString(), 0)) .withMessage( "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); } @@ -158,8 +163,7 @@ void customMimeHeaderLength129_encodingFails() { @Test void customMimeHeaderLengthOne() { String mimeString = "w"; - ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); @@ -185,13 +189,13 @@ void customMimeHeaderLengthOne() { .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test void customMimeHeaderLengthTwo() { String mimeString = "ww"; - ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); // remember actual length = encoded length + 1 assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); @@ -219,6 +223,7 @@ void customMimeHeaderLengthTwo() { .hasToString(mimeString); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -227,9 +232,7 @@ void customMimeHeaderUtf8_encodingFails() { "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes assertThatIllegalArgumentException() .isThrownBy( - () -> - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) .withMessage("custom mime type must be US_ASCII characters only"); } @@ -317,72 +320,73 @@ void decodeTypeSkipsFirstByte() { @Test void encodeMetadataCustomTypeDelegates() { - ByteBuf expected = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "foo", 2); + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "foo", 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); CompositeMetadataCodec.encodeAndAddMetadata( - test, ByteBufAllocator.DEFAULT, "foo", ByteBufUtils.getRandomByteBuf(2)); + test, testAllocator, "foo", ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeMetadataKnownTypeDelegates() { ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, - WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), - 2); + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); CompositeMetadataCodec.encodeAndAddMetadata( test, - ByteBufAllocator.DEFAULT, + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM, ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeMetadataReservedTypeDelegates() { - ByteBuf expected = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, (byte) 120, 2); + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, (byte) 120, 2); - CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf test = testAllocator.compositeBuffer(); CompositeMetadataCodec.encodeAndAddMetadata( - test, ByteBufAllocator.DEFAULT, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + test, testAllocator, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); } @Test void encodeTryCompressWithCompressableType() { ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); - CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf target = testAllocator.compositeBuffer(); CompositeMetadataCodec.encodeAndAddMetadataWithCompression( - target, - UnpooledByteBufAllocator.DEFAULT, - WellKnownMimeType.APPLICATION_AVRO.getString(), - metadata); + target, testAllocator, WellKnownMimeType.APPLICATION_AVRO.getString(), metadata); assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + target.release(); } @Test void encodeTryCompressWithCustomType() { ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); - CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeByteBuf target = testAllocator.compositeBuffer(); CompositeMetadataCodec.encodeAndAddMetadataWithCompression( - target, UnpooledByteBufAllocator.DEFAULT, "custom/example", metadata); + target, testAllocator, "custom/example", metadata); assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + target.release(); } @Test @@ -390,19 +394,20 @@ void hasEntry() { WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; CompositeByteBuf buffer = - Unpooled.compositeBuffer() + testAllocator + .compositeBuffer() .addComponent( true, - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)) + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0)) .addComponent( true, CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)); + testAllocator, mime.getIdentifier(), 0)); assertThat(CompositeMetadataCodec.hasEntry(buffer, 0)).isTrue(); assertThat(CompositeMetadataCodec.hasEntry(buffer, 4)).isTrue(); assertThat(CompositeMetadataCodec.hasEntry(buffer, 8)).isFalse(); + buffer.release(); } @Test @@ -417,8 +422,7 @@ void isWellKnownMimeType() { @Test void knownMimeHeader120_reserved() { byte mime = (byte) 120; - ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mime, 0); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime, 0); assertThat(mime) .as("smoke test RESERVED_120 unsigned 7 bits representation") @@ -443,6 +447,7 @@ void knownMimeHeader120_reserved() { assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -453,8 +458,7 @@ void knownMimeHeader127_compositeMetadata() { .isEqualTo((byte) 127) .isEqualTo((byte) 0b01111111); ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); assertThat(toHeaderBits(encoded)) .startsWith("1") @@ -480,6 +484,7 @@ void knownMimeHeader127_compositeMetadata() { .isEqualTo(mime.getIdentifier()); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test @@ -490,8 +495,7 @@ void knownMimeHeaderZero_avro() { .isEqualTo((byte) 0) .isEqualTo((byte) 0b00000000); ByteBuf encoded = - CompositeMetadataCodec.encodeMetadataHeader( - ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); assertThat(toHeaderBits(encoded)) .startsWith("1") @@ -517,6 +521,7 @@ void knownMimeHeaderZero_avro() { .isEqualTo(mime.getIdentifier()); assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); } @Test diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java new file mode 100644 index 000000000..5c8d40306 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.util.List; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link MimeTypeMetadataCodec}. */ +public class MimeTypeMetadataCodecTest { + + @Test + public void wellKnownMimeType() { + WellKnownMimeType mimeType = WellKnownMimeType.APPLICATION_HESSIAN; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(WellKnownMimeType.fromString(mimeTypes.get(0))).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void customMimeType() { + String mimeType = "aaa/bb"; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(mimeTypes.get(0)).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void multipleMimeTypes() { + List mimeTypes = Lists.newArrayList("aaa/bbb", "application/x-hessian"); + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeTypes); + + try { + assertThat(MimeTypeMetadataCodec.decode(byteBuf)).isEqualTo(mimeTypes); + } finally { + byteBuf.release(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..9a19050f9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,790 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..8261b3f49 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,142 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + events.add( + new Event( + t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, requestType, t)); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + events.add(new Event(EventType.ON_CANCEL, streamId, requestType, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnComplete(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; + + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.error = error; + } + } + + public enum EventType { + ON_START, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, + ON_REJECT + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java new file mode 100644 index 000000000..8229bf42b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java @@ -0,0 +1,470 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.exceptions.ConnectionCloseException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ClientRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new ConnectionCloseException("some message"))); + // connection should be closed because of the wrong first frame + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout is still in progress + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void sessionTerminationOnWrongFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // Send KEEPALIVE frame as a first frame + transport + .testConnection() + .addToReceivedBuffer( + KeepAliveFrameCodec.encode(transport.alloc(), false, 0, Unpooled.EMPTY_BUFFER)); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectErrorMessage("RESUME_OK frame must be received before any others") + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldErrorWithNoRetriesOnErrorFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send REJECTED_RESUME_ERROR frame + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new RejectedResumeException("failed resumption"))); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectError(RejectedResumeException.class) + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java index 9da66d424..bba40d674 100644 --- a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java +++ b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java @@ -1,80 +1,528 @@ package io.rsocket.resume; +import static org.assertj.core.api.Assertions.assertThat; + import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; import java.util.Arrays; -import org.junit.Assert; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Disposable; +import reactor.core.publisher.Hooks; +import reactor.test.util.RaceTestUtils; public class InMemoryResumeStoreTest { + @Test + void saveNonResumableFrame() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeConnectionFrame(10); + final ByteBuf frame2 = fakeConnectionFrame(35); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + } + @Test void saveWithoutTailRemoval() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame = frameMock(10); - store.saveFrames(Flux.just(frame)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame.readableBytes(), store.cacheSize); - Assert.assertEquals(0, store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame = fakeResumableFrame(10); + + sender.onNext(frame); + + assertThat(store.cachedFrames.size()).isEqualTo(1); + assertThat(store.cacheSize).isEqualTo(frame.readableBytes()); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(1).values().forEach(ByteBuf::release); + + assertThat(frame.refCnt()).isOne(); } @Test void saveRemoveOneFromTail() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(20); - ByteBuf frame2 = frameMock(10); - store.saveFrames(Flux.just(frame1, frame2)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame2.readableBytes(), store.cacheSize); - Assert.assertEquals(frame1.readableBytes(), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(20); + final ByteBuf frame2 = fakeResumableFrame(10); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame2.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(frame1.readableBytes()); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isOne(); } @Test void saveRemoveTwoFromTail() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(20); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame3.readableBytes(), store.cacheSize); - Assert.assertEquals(size(frame1, frame2), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(20); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); } @Test void saveBiggerThanStore() { - InMemoryResumableFramesStore store = inMemoryStore(25); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(30); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); - Assert.assertEquals(0, store.cachedFrames.size()); - Assert.assertEquals(0, store.cacheSize); - Assert.assertEquals(size(frame1, frame2, frame3), store.position); + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2, frame3)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); } @Test void releaseFrames() { - InMemoryResumableFramesStore store = inMemoryStore(100); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(10); - ByteBuf frame3 = frameMock(30); - store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + store.releaseFrames(20); - Assert.assertEquals(1, store.cachedFrames.size()); - Assert.assertEquals(frame3.readableBytes(), store.cacheSize); - Assert.assertEquals(size(frame1, frame2), store.position); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); } @Test void receiveImpliedPosition() { - InMemoryResumableFramesStore store = inMemoryStore(100); - ByteBuf frame1 = frameMock(10); - ByteBuf frame2 = frameMock(30); + final InMemoryResumableFramesStore store = inMemoryStore(100); + + ByteBuf frame1 = fakeResumableFrame(10); + ByteBuf frame2 = fakeResumableFrame(30); + store.resumableFrameReceived(frame1); store.resumableFrameReceived(frame2); - Assert.assertEquals(size(frame1, frame2), store.frameImpliedPosition()); + + assertThat(store.frameImpliedPosition()).isEqualTo(size(frame1, frame2)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ensuresCleansOnTerminal(boolean hasSubscriber) { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + hasSubscriber ? store.resumeStream().subscribeWith(AssertSubscriber.create()) : null; + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + if (hasSubscriber) { + assertSubscriber.assertValueCount(3).assertTerminated().values().forEach(ByteBuf::release); + } + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void ensuresCleansOnTerminalLateSubscriber() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + assertSubscriber.assertTerminated(); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @ParameterizedTest(name = "Sending vs Reconnect Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void sendingVsReconnectRaceTest(boolean withLateSubscriber) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + final BlockingQueue receivedFrames = new ArrayBlockingQueue<>(10); + final AtomicInteger receivedPosition = new AtomicInteger(); + + store.saveFrames(frames).subscribe(); + + final Consumer consumer = + f -> { + if (ResumableDuplexConnection.isResumableFrame(f)) { + receivedPosition.addAndGet(f.readableBytes()); + receivedFrames.offer(f); + return; + } + f.release(); + }; + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber ? null : store.resumeStream().subscribe(consumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer)); + } + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + store.releaseFrames(receivedFrames.stream().mapToInt(ByteBuf::readableBytes).sum()); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + + assertThat(receivedFrames) + .hasSize(5) + .containsSequence(byteBuf1, byteBuf2, byteBuf3, byteBuf4, byteBuf5); + receivedFrames.forEach(ReferenceCounted::release); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } + + @ParameterizedTest( + name = "Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValue(1); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest( + name = + "Dispose vs Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextWithDisposeRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }, + store::dispose); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValueGreaterThanOrEqualTo(1).hasValueLessThanOrEqualTo(2); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } } private int size(ByteBuf... byteBufs) { @@ -82,12 +530,18 @@ private int size(ByteBuf... byteBufs) { } private static InMemoryResumableFramesStore inMemoryStore(int size) { - return new InMemoryResumableFramesStore("test", size); + return new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, size); } - private static ByteBuf frameMock(int size) { + private static ByteBuf fakeResumableFrame(int size) { byte[] bytes = new byte[size]; Arrays.fill(bytes, (byte) 7); return Unpooled.wrappedBuffer(bytes); } + + private static ByteBuf fakeConnectionFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 0); + return Unpooled.wrappedBuffer(bytes); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java deleted file mode 100644 index 7d2a7bcc8..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -public class ResumeCalculatorTest { - - @BeforeEach - void setUp() {} - - @Test - void clientResumeSuccess() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, -1, 3); - Assertions.assertEquals(3, position); - } - - @Test - void clientResumeError() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, -1, 3); - Assertions.assertEquals(-1, position); - } - - @Test - void serverResumeSuccess() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, 4, 23); - Assertions.assertEquals(23, position); - } - - @Test - void serverResumeErrorClientState() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 3, 4, 23); - Assertions.assertEquals(-1, position); - } - - @Test - void serverResumeErrorServerState() { - long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, 4, 1); - Assertions.assertEquals(-1, position); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java new file mode 100644 index 000000000..b5625bf8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java @@ -0,0 +1,190 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class ServerRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // resubscribe so a new connection is generated + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME frame + final ByteBuf resumeFrame = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame, transport.testConnection()); + resumeFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME_OK) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(61)); + + final ByteBuf resumeFrame1 = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame1, transport.testConnection()); + resumeFrame1.release(); + + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index a2957c5a1..cdfcefdc8 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,46 +19,54 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import org.reactivestreams.Publisher; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; -import reactor.core.publisher.DirectProcessor; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; public class LocalDuplexConnection implements DuplexConnection { private final ByteBufAllocator allocator; - private final DirectProcessor send; - private final DirectProcessor receive; - private final MonoProcessor onClose; + private final Sinks.Many send; + private final Sinks.Many receive; + private final Sinks.Empty onClose; private final String name; public LocalDuplexConnection( String name, ByteBufAllocator allocator, - DirectProcessor send, - DirectProcessor receive) { + Sinks.Many send, + Sinks.Many receive) { this.name = name; this.allocator = allocator; this.send = send; this.receive = receive; - this.onClose = MonoProcessor.create(); + this.onClose = Sinks.empty(); } @Override - public Mono send(Publisher frame) { - return Flux.from(frame) - .doOnNext(f -> System.out.println(name + " - " + f.toString())) - .doOnNext(send::onNext) - .doOnError(send::onError) - .then(); + public void sendFrame(int streamId, ByteBuf frame) { + System.out.println(name + " - " + frame.toString()); + send.tryEmitNext(frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + System.out.println(name + " - " + errorFrame.toString()); + send.tryEmitNext(errorFrame); + onClose.tryEmitEmpty(); } @Override public Flux receive() { return receive + .asFlux() .doOnNext(f -> System.out.println(name + " - " + f.toString())) .transform( Operators.lift( @@ -93,18 +101,24 @@ public ByteBufAllocator alloc() { return allocator; } + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress(name); + } + @Override public void dispose() { - onClose.onComplete(); + onClose.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java index 179afff58..a33c4c4b3 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java @@ -16,8 +16,7 @@ package io.rsocket.test.util; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.is; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -116,6 +115,8 @@ public void assertMetadataPushCount(int expected) { } private static void assertCount(int expected, String type, AtomicInteger counter) { - assertThat("Unexpected invocations for " + type + '.', counter.get(), is(expected)); + assertThat(counter.get()) + .describedAs("Unexpected invocations for " + type + '.') + .isEqualTo(expected); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java index 88694d209..f02bc99a4 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -6,18 +6,21 @@ import io.rsocket.DuplexConnection; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ClientTransport; +import java.time.Duration; import reactor.core.publisher.Mono; public class TestClientTransport implements ClientTransport { private final LeaksTrackingByteBufAllocator allocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); - private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(allocator); + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "client"); + + private volatile TestDuplexConnection testDuplexConnection; int maxFrameLength = FRAME_LENGTH_MASK; @Override public Mono connect() { - return Mono.just(testDuplexConnection); + return Mono.fromSupplier(() -> testDuplexConnection = new TestDuplexConnection(allocator)); } public TestDuplexConnection testConnection() { diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index abc15509b..8793d6ca4 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,11 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import org.reactivestreams.Publisher; @@ -32,6 +35,7 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; /** * An implementation of {@link DuplexConnection} that provides functionality to modify the behavior @@ -42,16 +46,17 @@ public class TestDuplexConnection implements DuplexConnection { private static final Logger logger = LoggerFactory.getLogger(TestDuplexConnection.class); private final LinkedBlockingQueue sent; + private final DirectProcessor sentPublisher; private final FluxSink sendSink; private final DirectProcessor received; private final FluxSink receivedSink; private final MonoProcessor onClose; - private final ByteBufAllocator allocator; + private final LeaksTrackingByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection(ByteBufAllocator allocator) { + public TestDuplexConnection(LeaksTrackingByteBufAllocator allocator) { this.allocator = allocator; this.sent = new LinkedBlockingQueue<>(); this.received = DirectProcessor.create(); @@ -62,19 +67,13 @@ public TestDuplexConnection(ByteBufAllocator allocator) { } @Override - public Mono send(Publisher frames) { + public void sendFrame(int streamId, ByteBuf frame) { if (availability <= 0) { - return Mono.error( - new IllegalStateException("RSocket not available. Availability: " + availability)); + throw new IllegalStateException("RSocket not available. Availability: " + availability); } - return Flux.from(frames) - .doOnNext( - frame -> { - sendSink.next(frame); - sent.offer(frame); - }) - .doOnError(throwable -> logger.error("Error in send stream on test connection.", throwable)) - .then(); + + sendSink.next(frame); + sent.offer(frame); } @Override @@ -107,10 +106,29 @@ public void onComplete() { } @Override - public ByteBufAllocator alloc() { + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + sendSink.next(errorFrame); + sent.offer(errorFrame); + + final Throwable cause = e.getCause(); + if (cause == null) { + onClose.onComplete(); + } else { + onClose.onError(cause); + } + } + + @Override + public LeaksTrackingByteBufAllocator alloc() { return allocator; } + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress("TestDuplexConnection"); + } + @Override public double availability() { return availability; @@ -131,8 +149,21 @@ public Mono onClose() { return onClose; } - public ByteBuf awaitSend() throws InterruptedException { - return sent.take(); + public boolean isEmpty() { + return sent.isEmpty(); + } + + @NonNull + public ByteBuf awaitFrame() { + try { + return sent.take(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public ByteBuf pollFrame() { + return sent.poll(); } public void setAvailability(double availability) { diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java new file mode 100644 index 000000000..2dad2cc1f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java @@ -0,0 +1,46 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import java.net.SocketAddress; +import java.util.Objects; + +public final class TestLocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = 2608695156052100164L; + + private final String name; + + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ + public TestLocalSocketAddress(String name) { + this.name = Objects.requireNonNull(name, "name must not be null"); + } + + /** Return the name for this connection. */ + public String getName() { + return name; + } + + @Override + public String toString() { + return "[local address] " + name; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java index 0f9ea8e48..fa9331d3b 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.test.util; import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; @@ -6,11 +21,13 @@ import io.rsocket.Closeable; import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ServerTransport; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; public class TestServerTransport implements ServerTransport { - private final MonoProcessor conn = MonoProcessor.create(); + private final Sinks.One connSink = Sinks.one(); + private TestDuplexConnection connection; private final LeaksTrackingByteBufAllocator allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @@ -18,29 +35,33 @@ public class TestServerTransport implements ServerTransport { @Override public Mono start(ConnectionAcceptor acceptor) { - conn.flatMap(acceptor::apply) + connSink + .asMono() + .flatMap(duplexConnection -> acceptor.apply(duplexConnection)) .subscribe(ignored -> {}, err -> disposeConnection(), this::disposeConnection); return Mono.just( new Closeable() { @Override public Mono onClose() { - return conn.then(); + return connSink.asMono().then(); } @Override public void dispose() { - conn.onComplete(); + connSink.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return conn.isTerminated(); + return connSink.scan(Scannable.Attr.TERMINATED) + || connSink.scan(Scannable.Attr.CANCELLED); } }); } private void disposeConnection() { - TestDuplexConnection c = conn.peek(); + TestDuplexConnection c = connection; if (c != null) { c.dispose(); } @@ -48,7 +69,8 @@ private void disposeConnection() { public TestDuplexConnection connect() { TestDuplexConnection c = new TestDuplexConnection(allocator); - conn.onNext(c); + connection = c; + connSink.tryEmitValue(c); return c; } diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java index 3f97ab9dc..f04de78b6 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -16,8 +16,7 @@ package io.rsocket.util; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; +import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; @@ -26,8 +25,7 @@ import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import java.nio.ByteBuffer; import java.util.concurrent.ThreadLocalRandom; -import org.assertj.core.api.Assertions; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class DefaultPayloadTest { public static final String DATA_VAL = "data"; @@ -41,12 +39,12 @@ public void testReuse() { } public void assertDataAndMetadata(Payload p, String dataVal, String metadataVal) { - assertThat("Unexpected data.", p.getDataUtf8(), equalTo(dataVal)); + assertThat(p.getDataUtf8()).describedAs("Unexpected data.").isEqualTo(dataVal); if (metadataVal == null) { - assertThat("Non-null metadata", p.hasMetadata(), equalTo(false)); + assertThat(p.hasMetadata()).describedAs("Non-null metadata").isEqualTo(false); } else { - assertThat("Null metadata", p.hasMetadata(), equalTo(true)); - assertThat("Unexpected metadata.", p.getMetadataUtf8(), equalTo(metadataVal)); + assertThat(p.hasMetadata()).describedAs("Null metadata").isEqualTo(true); + assertThat(p.getMetadataUtf8()).describedAs("Unexpected metadata.").isEqualTo(metadataVal); } } @@ -60,7 +58,7 @@ public void staticMethods() { public void shouldIndicateThatItHasNotMetadata() { Payload payload = DefaultPayload.create("data"); - Assertions.assertThat(payload.hasMetadata()).isFalse(); + assertThat(payload.hasMetadata()).isFalse(); } @Test @@ -68,7 +66,7 @@ public void shouldIndicateThatItHasMetadata1() { Payload payload = DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); - Assertions.assertThat(payload.hasMetadata()).isTrue(); + assertThat(payload.hasMetadata()).isTrue(); } @Test @@ -76,7 +74,7 @@ public void shouldIndicateThatItHasMetadata2() { Payload payload = DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); - Assertions.assertThat(payload.hasMetadata()).isTrue(); + assertThat(payload.hasMetadata()).isTrue(); } @Test @@ -96,9 +94,9 @@ public void shouldReleaseGivenByteBufDataAndMetadataUpOnPayloadCreation() { Payload payload = DefaultPayload.create(data, metadata); - Assertions.assertThat(payload.getData()).isEqualTo(ByteBuffer.wrap(new byte[] {i})); + assertThat(payload.getData()).isEqualTo(ByteBuffer.wrap(new byte[] {i})); - Assertions.assertThat(payload.getMetadata()) + assertThat(payload.getMetadata()) .isEqualTo( metadataPresent ? ByteBuffer.wrap(new byte[] {(byte) (i + 1)}) diff --git a/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension new file mode 100644 index 000000000..2b51ba0de --- /dev/null +++ b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension @@ -0,0 +1 @@ +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index c86befc92..4059eb957 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -23,6 +23,16 @@ dependencies { implementation project(':rsocket-load-balancer') implementation project(':rsocket-transport-local') implementation project(':rsocket-transport-netty') + + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + + implementation 'com.netflix.concurrency-limits:concurrency-limits-core' + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + runtimeOnly 'ch.qos.logback:logback-classic' testImplementation project(':rsocket-test') @@ -30,11 +40,11 @@ dependencies { testImplementation 'org.mockito:mockito-core' testImplementation 'org.assertj:assertj-core' testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.awaitility:awaitility' + testImplementation "io.micrometer:micrometer-test" + testImplementation "io.micrometer:micrometer-tracing-integration-test" - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } description = 'Example usage of the RSocket library' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index b532c0140..463043020 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -43,9 +43,7 @@ public static void main(String[] args) { .map(s -> "Echo: " + s) .map(DefaultPayload::create)); - RSocketServer.create(echoAcceptor) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + RSocketServer.create(echoAcceptor).bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java index 2d19b9ce4..dfbbcde53 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java @@ -34,7 +34,7 @@ public static void main(String[] args) { .bind(TcpServerTransport.create("localhost", 7000)) .delaySubscription(Duration.ofSeconds(5)) .doOnNext(cc -> logger.info("Server started on the address : {}", cc.address())) - .subscribe(); + .block(); Mono source = RSocketConnector.create() diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java index cf68dcdde..89b22749f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java @@ -35,7 +35,7 @@ import reactor.core.publisher.BaseSubscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.util.concurrent.Queues; /** @@ -48,12 +48,12 @@ public class TaskProcessingWithServerSideNotificationsExample { public static void main(String[] args) throws InterruptedException { - UnicastProcessor tasksProcessor = - UnicastProcessor.create(Queues.unboundedMultiproducer().get()); + Sinks.Many tasksProcessor = + Sinks.many().unicast().onBackpressureBuffer(Queues.unboundedMultiproducer().get()); ConcurrentMap> idToCompletedTasksMap = new ConcurrentHashMap<>(); ConcurrentMap idToRSocketMap = new ConcurrentHashMap<>(); BackgroundWorker backgroundWorker = - new BackgroundWorker(tasksProcessor, idToCompletedTasksMap, idToRSocketMap); + new BackgroundWorker(tasksProcessor.asFlux(), idToCompletedTasksMap, idToRSocketMap); RSocketServer.create(new TasksAcceptor(tasksProcessor, idToCompletedTasksMap, idToRSocketMap)) .bindNow(TcpServerTransport.create(9991)); @@ -132,12 +132,12 @@ static class TasksAcceptor implements SocketAcceptor { static final Logger logger = LoggerFactory.getLogger(TasksAcceptor.class); - final UnicastProcessor tasksToProcess; + final Sinks.Many tasksToProcess; final ConcurrentMap> idToCompletedTasksMap; final ConcurrentMap idToRSocketMap; TasksAcceptor( - UnicastProcessor tasksToProcess, + Sinks.Many tasksToProcess, ConcurrentMap> idToCompletedTasksMap, ConcurrentMap idToRSocketMap) { this.tasksToProcess = tasksToProcess; @@ -197,11 +197,11 @@ private static class RSocketTaskHandler implements RSocket { private final String id; private final RSocket sendingSocket; private ConcurrentMap idToRSocketMap; - private UnicastProcessor tasksToProcess; + private Sinks.Many tasksToProcess; public RSocketTaskHandler( ConcurrentMap idToRSocketMap, - UnicastProcessor tasksToProcess, + Sinks.Many tasksToProcess, String id, RSocket sendingSocket) { this.id = id; @@ -213,9 +213,9 @@ public RSocketTaskHandler( @Override public Mono fireAndForget(Payload payload) { logger.info("Received a Task[{}] from Client.ID[{}]", payload.getDataUtf8(), id); - tasksToProcess.onNext(new Task(id, payload.getDataUtf8())); + Sinks.EmitResult result = tasksToProcess.tryEmitNext(new Task(id, payload.getDataUtf8())); payload.release(); - return Mono.empty(); + return result.isFailure() ? Mono.error(new Sinks.EmissionException(result)) : Mono.empty(); } @Override diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java new file mode 100644 index 000000000..272caf7a0 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java @@ -0,0 +1,144 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class LeaseManager implements Runnable { + + static final Logger logger = LoggerFactory.getLogger(LeaseManager.class); + + volatile int activeConnectionsCount; + static final AtomicIntegerFieldUpdater ACTIVE_CONNECTIONS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "activeConnectionsCount"); + + volatile int stateAndInFlight; + static final AtomicIntegerFieldUpdater STATE_AND_IN_FLIGHT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "stateAndInFlight"); + + static final int MASK_PAUSED = 0b1_000_0000_0000_0000_0000_0000_0000_0000; + static final int MASK_IN_FLIGHT = 0b0_111_1111_1111_1111_1111_1111_1111_1111; + + final BlockingDeque sendersQueue = new LinkedBlockingDeque<>(); + final Scheduler worker = Schedulers.newSingle(LeaseManager.class.getName()); + + final int capacity; + final int ttl; + + public LeaseManager(int capacity, int ttl) { + this.capacity = capacity; + this.ttl = ttl; + } + + @Override + public void run() { + try { + LimitBasedLeaseSender leaseSender = sendersQueue.poll(); + + if (leaseSender == null) { + return; + } + + if (leaseSender.isDisposed()) { + logger.debug("Connection[" + leaseSender.connectionId + "]: LeaseSender is Disposed"); + worker.schedule(this); + return; + } + + int limit = leaseSender.limitAlgorithm.getLimit(); + + if (limit == 0) { + throw new IllegalStateException("Limit is 0"); + } + + if (pauseIfNoCapacity()) { + sendersQueue.addFirst(leaseSender); + logger.debug("Pause execution. Not enough capacity"); + return; + } + + leaseSender.sendLease(ttl, limit); + sendersQueue.offer(leaseSender); + + int activeConnections = activeConnectionsCount; + int nextDelay = activeConnections == 0 ? ttl : (ttl / activeConnections); + + logger.debug("Next check happens in " + nextDelay + "ms"); + + worker.schedule(this, nextDelay, TimeUnit.MILLISECONDS); + } catch (Throwable e) { + logger.error("LeaseSender failed to send lease", e); + } + } + + int incrementInFlightAndGet() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight + 1; + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return nextInFlight; + } + } + } + + void decrementInFlight() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight - 1; + + if (inFlight == capacity && paused == MASK_PAUSED) { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight)) { + logger.debug("Resume execution"); + worker.schedule(this); + return; + } + } else { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return; + } + } + } + } + + boolean pauseIfNoCapacity() { + int capacity = this.capacity; + for (; ; ) { + int inFlight = stateAndInFlight; + + if (inFlight < capacity) { + return false; + } + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, inFlight, inFlight | MASK_PAUSED)) { + return true; + } + } + } + + void unregister() { + ACTIVE_CONNECTIONS_COUNT.decrementAndGet(this); + } + + void register(LimitBasedLeaseSender sender) { + sendersQueue.offer(sender); + final int activeCount = ACTIVE_CONNECTIONS_COUNT.getAndIncrement(this); + + if (activeCount == 0) { + worker.schedule(this); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java new file mode 100644 index 000000000..8e1b27823 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java @@ -0,0 +1,54 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.rsocket.lease.Lease; +import io.rsocket.lease.TrackingLeaseSender; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +public class LimitBasedLeaseSender extends LimitBasedStatsCollector implements TrackingLeaseSender { + + static final Logger logger = LoggerFactory.getLogger(LimitBasedLeaseSender.class); + + final String connectionId; + final Sinks.Many sink = + Sinks.many().unicast().onBackpressureBuffer(Queues.one().get()); + + public LimitBasedLeaseSender( + String connectionId, LeaseManager leaseManager, Limit limitAlgorithm) { + super(leaseManager, limitAlgorithm); + this.connectionId = connectionId; + } + + @Override + public Flux send() { + logger.info("Received new leased Connection[" + connectionId + "]"); + + leaseManager.register(this); + + return sink.asFlux(); + } + + public void sendLease(int ttl, int amount) { + final Lease nextLease = Lease.create(Duration.ofMillis(ttl), amount); + final Sinks.EmitResult result = sink.tryEmitNext(nextLease); + + if (result.isFailure()) { + logger.warn( + "Connection[" + + connectionId + + "]. Issued Lease: [" + + nextLease + + "] was not sent due to " + + result); + } else { + if (logger.isDebugEnabled()) { + logger.debug("To Connection[" + connectionId + "]: Issued Lease: [" + nextLease + "]"); + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java new file mode 100644 index 000000000..7f639ab87 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java @@ -0,0 +1,73 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.LongSupplier; +import reactor.util.annotation.Nullable; + +public class LimitBasedStatsCollector extends AtomicBoolean implements RequestInterceptor { + + final LeaseManager leaseManager; + final Limit limitAlgorithm; + + final ConcurrentMap inFlightMap = new ConcurrentHashMap<>(); + final ConcurrentMap timeMap = new ConcurrentHashMap<>(); + + final LongSupplier clock = System::nanoTime; + + public LimitBasedStatsCollector(LeaseManager leaseManager, Limit limitAlgorithm) { + this.leaseManager = leaseManager; + this.limitAlgorithm = limitAlgorithm; + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + long startTime = clock.getAsLong(); + + int currentInFlight = leaseManager.incrementInFlightAndGet(); + + inFlightMap.put(streamId, currentInFlight); + timeMap.put(streamId, startTime); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) {} + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, t != null); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, true); + } + + @Override + public boolean isDisposed() { + return get(); + } + + @Override + public void dispose() { + if (!getAndSet(true)) { + leaseManager.unregister(); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java new file mode 100644 index 000000000..a18dd9484 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java @@ -0,0 +1,27 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// emulating a worker that process data from the queue +public class Task implements Runnable { + private static final Logger logger = LoggerFactory.getLogger(Task.class); + + final String message; + final int processingTime; + + Task(String message, int processingTime) { + this.message = message; + this.processingTime = processingTime; + } + + @Override + public void run() { + logger.info("Processing Task[{}]", message); + try { + Thread.sleep(processingTime); // emulating processing + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java new file mode 100644 index 000000000..cbecadfc3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java @@ -0,0 +1,44 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; + +public class TasksHandlingRSocket implements RSocket { + + private static final Logger logger = LoggerFactory.getLogger(TasksHandlingRSocket.class); + + final Disposable terminatable; + final Scheduler workScheduler; + final int processingTime; + + public TasksHandlingRSocket(Disposable terminatable, Scheduler scheduler, int processingTime) { + this.terminatable = terminatable; + this.workScheduler = scheduler; + this.processingTime = processingTime; + } + + @Override + public Mono fireAndForget(Payload payload) { + + // specifically to show that lease can limit rate of fnf requests in + // that example + String message = payload.getDataUtf8(); + payload.release(); + + return Mono.fromRunnable(new Task(message, processingTime)) + // schedule task on specific, limited in size scheduler + .subscribeOn(workScheduler) + // if errors - terminates server + .doOnError( + t -> { + logger.error("Queue has been overflowed. Terminating server"); + terminatable.dispose(); + System.exit(9); + }); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java new file mode 100644 index 000000000..30eb0c0e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RequestingServer { + + private static final Logger logger = LoggerFactory.getLogger(RequestingServer.class); + + public static void main(String[] args) { + PriorityBlockingQueue rSockets = + new PriorityBlockingQueue<>( + 16, Comparator.comparingDouble(RSocket::availability).reversed()); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> { + logger.info("Received new connection"); + return Mono.just(new RSocket() {}) + .doAfterTerminate(() -> rSockets.put(sendingSocket)); + }) + .lease(spec -> spec.maxPendingRequests(Integer.MAX_VALUE)) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + logger.info("Server started on port {}", server.address().getPort()); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .flatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + + return Mono.fromCallable( + () -> { + RSocket rSocket = rSockets.take(); + rSockets.offer(rSocket); + return rSocket; + }) + .flatMap( + clientRSocket -> + clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) + .retry(); + }) + .blockLast(); + + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java new file mode 100644 index 000000000..4a06855b2 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java @@ -0,0 +1,67 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.client.TcpClientTransport; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingClient { + private static final Logger logger = LoggerFactory.getLogger(RespondingClient.class); + + public static final int PROCESSING_TASK_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager periodicLeaseSender = + new LeaseManager(CONCURRENT_WORKERS_COUNT, PROCESSING_TASK_TIME); + + Disposable.Composite disposable = Disposables.composite(); + RSocket clientRSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, PROCESSING_TASK_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + periodicLeaseSender, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + disposable.add(clientRSocket); + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java new file mode 100644 index 000000000..c2fde38e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java @@ -0,0 +1,41 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class RequestingClient { + private static final Logger logger = LoggerFactory.getLogger(RequestingClient.class); + + public static void main(String[] args) { + + RSocket clientRSocket = + RSocketConnector.create() + .lease() + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java new file mode 100644 index 000000000..b54330450 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketServer; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingServer { + + private static final Logger logger = LoggerFactory.getLogger(RespondingServer.class); + + public static final int TASK_PROCESSING_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager leaseManager = new LeaseManager(CONCURRENT_WORKERS_COUNT, TASK_PROCESSING_TIME); + + Disposable.Composite disposable = Disposables.composite(); + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, TASK_PROCESSING_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + leaseManager, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + disposable.add(server); + + logger.info("Server started on port {}", server.address().getPort()); + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java similarity index 66% rename from rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java rename to rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java index 49f683204..c54335ccc 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java @@ -14,33 +14,26 @@ * limitations under the License. */ -package io.rsocket.examples.transport.tcp.lease; +package io.rsocket.examples.transport.tcp.lease.simple; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; import io.rsocket.lease.Lease; -import io.rsocket.lease.LeaseStats; -import io.rsocket.lease.Leases; -import io.rsocket.lease.MissingLeaseException; +import io.rsocket.lease.LeaseSender; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.ByteBufPayload; import java.time.Duration; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; -import java.util.function.Consumer; -import java.util.function.Function; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; -import reactor.util.retry.Retry; public class LeaseExample { @@ -95,13 +88,12 @@ public Mono fireAndForget(Payload payload) { return Mono.empty(); } })) - .lease(() -> Leases.create().sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) + .lease(leases -> leases.sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) .bindNow(TcpServerTransport.create("localhost", 7000)); - LeaseReceiver receiver = new LeaseReceiver(CLIENT_TAG); RSocket clientRSocket = RSocketConnector.create() - .lease(() -> Leases.create().receiver(receiver)) + .lease((config) -> config.maxPendingRequests(1)) .connect(TcpClientTransport.create(server.address())) .block(); @@ -116,22 +108,10 @@ public Mono fireAndForget(Payload payload) { }) // here we wait for the first lease for the responder side and start execution // on if there is allowance - .delaySubscription(receiver.notifyWhenNewLease().then()) .concatMap( tick -> { logger.info("Requesting FireAndForget({})", tick); - return Mono.defer(() -> clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) - .retryWhen( - Retry.indefinitely() - // ensures that error is the result of missed lease - .filter(t -> t instanceof MissingLeaseException) - .doBeforeRetryAsync( - rs -> { - // here we create a mechanism to delay the retry until - // the new lease allowance comes in. - logger.info("Ran out of leases {}", rs); - return receiver.notifyWhenNewLease().then(); - })); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); }) .blockLast(); @@ -146,7 +126,7 @@ public Mono fireAndForget(Payload payload) { * connection.
* In real-world projects this class has to issue leases based on real metrics */ - private static class LeaseCalculator implements Function, Flux> { + private static class LeaseCalculator implements LeaseSender { final String tag; final BlockingQueue queue; @@ -156,8 +136,7 @@ public LeaseCalculator(String tag, BlockingQueue queue) { } @Override - public Flux apply(Optional leaseStats) { - logger.info("{} stats are {}", tag, leaseStats.isPresent() ? "present" : "absent"); + public Flux send() { Duration ttlDuration = Duration.ofSeconds(5); // The interval function is used only for the demo purpose and should not be // considered as the way to issue leases. @@ -173,45 +152,9 @@ public Flux apply(Optional leaseStats) { // reissue new lease only if queue has remaining capacity to // accept more requests if (requests > 0) { - long ttl = ttlDuration.toMillis(); - sink.next(Lease.create((int) ttl, requests)); + sink.next(Lease.create(ttlDuration, requests)); } }); } } - - /** - * Requester-side Lease listener.
- * In the nutshell this class implements mechanism to listen (and do appropriate actions as - * needed) to incoming leases issued by the Responder - */ - private static class LeaseReceiver implements Consumer> { - final String tag; - final ReplayProcessor lastLeaseReplay = ReplayProcessor.cacheLast(); - - public LeaseReceiver(String tag) { - this.tag = tag; - } - - @Override - public void accept(Flux receivedLeases) { - receivedLeases.subscribe( - l -> { - logger.info( - "{} received leases - ttl: {}, requests: {}", - tag, - l.getTimeToLiveMillis(), - l.getAllowedRequests()); - lastLeaseReplay.onNext(l); - }); - } - - /** - * This method allows to listen to new incoming leases and delay some action (e.g . retry) until - * new valid lease has come in - */ - public Mono notifyWhenNewLease() { - return lastLeaseReplay.filter(l -> l.isValid()).next(); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java index 27d10b472..abed4a52d 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java @@ -101,7 +101,6 @@ public static void main(String[] args) { for (int i = 0; i < 10000; i++) { try { - rSocketClient.requestResponse(Mono.just(DefaultPayload.create("test" + i))).block(); } catch (Throwable t) { // no ops diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java new file mode 100644 index 000000000..a0a02a946 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java @@ -0,0 +1,102 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class CompositeMetadataExample { + static final Logger logger = LoggerFactory.getLogger(CompositeMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that every metadata payload will be encoded using + // CompositeMetadata layout as specified in the following subspec + // https://github.com/rsocket/rsocket/blob/master/Extensions/CompositeMetadata.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + final CompositeByteBuf compositeMetadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadata, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, + routeMetadata); + + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), compositeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + + for (CompositeMetadata.Entry metadatum : compositeMetadata) { + if (Objects.requireNonNull(metadatum.getMimeType()) + .equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { + return new RoutingMetadata(metadatum.getContent()).iterator().next(); + } + } + + return null; + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java new file mode 100644 index 000000000..2aee18bf9 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class RoutingMetadataExample { + static final Logger logger = LoggerFactory.getLogger(RoutingMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that route will be encoded using + // Routing&Tagging Metadata layout specified at this + // subspec https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), routeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final RoutingMetadata routingMetadata = new RoutingMetadata(metadata); + + return routingMetadata.iterator().next(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java index 67a85b67f..5491a1aab 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -39,8 +39,7 @@ public Flux requestChannel(Publisher payloads) { } })) .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.create() diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 85faeee82..0c372d2d8 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -50,8 +50,7 @@ public Mono requestResponse(Payload p) { }; RSocketServer.create(SocketAcceptor.with(rsocket)) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index 93b54e146..ba82c7c93 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -62,11 +62,11 @@ public static void main(String[] args) { return Files.fileSource(fileName, chunkSize) .map(DefaultPayload::create) - .zipWith(ticks, (p, tick) -> p); + .zipWith(ticks, (p, tick) -> p) + .log("server"); })) .resume(resume) - .bind(TcpServerTransport.create("localhost", 8000)) - .block(); + .bindNow(TcpServerTransport.create("localhost", 8000)); RSocket client = RSocketConnector.create() @@ -76,8 +76,9 @@ public static void main(String[] args) { client .requestStream(codec.encode(new Request(16, "lorem.txt"))) + .log("client") .doFinally(s -> server.dispose()) - .subscribe(Files.fileSink("rsocket-examples/out/lorem_output.txt", PREFETCH_WINDOW_SIZE)); + .subscribe(Files.fileSink("rsocket-examples/build/lorem_output.txt", PREFETCH_WINDOW_SIZE)); server.onClose().block(); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java index 8116ad4ae..af0df3be1 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java @@ -33,20 +33,23 @@ public final class ClientStreamingToServer { private static final Logger logger = LoggerFactory.getLogger(ClientStreamingToServer.class); - public static void main(String[] args) { + public static void main(String[] args) throws InterruptedException { RSocketServer.create( SocketAcceptor.forRequestStream( payload -> Flux.interval(Duration.ofMillis(100)) .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = - RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); + RSocketConnector.create() + .setupPayload(DefaultPayload.create("test", "test")) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + final Payload payload = DefaultPayload.create("Hello"); socket - .requestStream(DefaultPayload.create("Hello")) + .requestStream(payload) .map(Payload::getDataUtf8) .doOnNext(logger::debug) .take(10) @@ -54,5 +57,7 @@ public static void main(String[] args) { .doFinally(signalType -> socket.dispose()) .then() .block(); + + Thread.sleep(1000000); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java index f5b1e00e5..10ed34553 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java @@ -43,8 +43,7 @@ public static void main(String[] args) { return Mono.just(new RSocket() {}); }) - .bind(TcpServerTransport.create("localhost", 7000)) - .subscribe(); + .bindNow(TcpServerTransport.create("localhost", 7000)); RSocket rsocket = RSocketConnector.create() diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java new file mode 100644 index 000000000..89304853c --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketAggregationSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketAggregationSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .handle( + (req, res) -> + res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply( + new WebsocketDuplexConnection( + (Connection) in.aggregateFrames())) + .then(out.neverComplete()))) + .bindNow(); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + server.dispose(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index e2471f2fc..ac311a231 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -16,9 +16,8 @@ package io.rsocket.integration; -import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -38,9 +37,10 @@ import io.rsocket.util.RSocketProxy; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; @@ -108,7 +108,7 @@ public Mono requestResponse(Payload payload) { private CountDownLatch disconnectionCounter; private AtomicInteger errorCount; - @Before + @BeforeEach public void startup() { errorCount = new AtomicInteger(); requestCount = new AtomicInteger(); @@ -163,23 +163,26 @@ public Flux requestChannel(Publisher payloads) { .block(); } - @After + @AfterEach public void teardown() { server.dispose(); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testRequest() { client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); - assertThat("Server did not see the request.", requestCount.get(), is(1)); - assertTrue(calledRequester); - assertTrue(calledResponder); - assertTrue(calledClientAcceptor); - assertTrue(calledServerAcceptor); - assertTrue(calledFrame); + assertThat(requestCount).as("Server did not see the request.").hasValue(1); + + assertThat(calledRequester).isTrue(); + assertThat(calledResponder).isTrue(); + assertThat(calledClientAcceptor).isTrue(); + assertThat(calledServerAcceptor).isTrue(); + assertThat(calledFrame).isTrue(); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testStream() { Subscriber subscriber = TestSubscriber.createCancelling(); client.requestStream(DefaultPayload.create("start")).subscribe(subscriber); @@ -188,7 +191,8 @@ public void testStream() { verifyNoMoreInteractions(subscriber); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testClose() throws InterruptedException { client.dispose(); disconnectionCounter.await(); @@ -196,10 +200,8 @@ public void testClose() throws InterruptedException { @Test // (timeout = 5_000L) public void testCallRequestWithErrorAndThenRequest() { - try { - client.requestChannel(Mono.error(new Throwable())).blockLast(); - } catch (Throwable t) { - } + assertThatThrownBy(client.requestChannel(Mono.error(new Throwable("test")))::blockLast) + .hasMessage("java.lang.Throwable: test"); testRequest(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index de27bcb9b..1924668fb 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,8 +16,7 @@ package io.rsocket.integration; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -31,12 +30,13 @@ import io.rsocket.util.RSocketProxy; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.core.scheduler.Schedulers; public class TcpIntegrationTest { @@ -44,7 +44,7 @@ public class TcpIntegrationTest { private CloseableChannel server; - @Before + @BeforeEach public void startup() { server = RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) @@ -56,12 +56,13 @@ private RSocket buildClient() { return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } - @After + @AfterEach public void cleanup() { server.dispose(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testCompleteWithoutNext() { handler = new RSocket() { @@ -74,10 +75,11 @@ public Flux requestStream(Payload payload) { Boolean hasElements = client.requestStream(DefaultPayload.create("REQUEST", "META")).log().hasElements().block(); - assertFalse(hasElements); + assertThat(hasElements).isFalse(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testSingleStream() { handler = new RSocket() { @@ -91,10 +93,11 @@ public Flux requestStream(Payload payload) { Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockLast(); - assertEquals("RESPONSE", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEqualTo("RESPONSE"); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testZeroPayload() { handler = new RSocket() { @@ -108,10 +111,11 @@ public Flux requestStream(Payload payload) { Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockFirst(); - assertEquals("", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEmpty(); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testRequestResponseErrors() { handler = new RSocket() { @@ -141,23 +145,24 @@ public Mono requestResponse(Payload payload) { .onErrorReturn(DefaultPayload.create("ERROR")) .block(); - assertEquals("ERROR", response1.getDataUtf8()); - assertEquals("SUCCESS", response2.getDataUtf8()); + assertThat(response1.getDataUtf8()).isEqualTo("ERROR"); + assertThat(response2.getDataUtf8()).isEqualTo("SUCCESS"); } - @Test(timeout = 15_000L) + @Test + @Timeout(15_000L) public void testTwoConcurrentStreams() throws InterruptedException { - ConcurrentHashMap> map = new ConcurrentHashMap<>(); - UnicastProcessor processor1 = UnicastProcessor.create(); + ConcurrentHashMap> map = new ConcurrentHashMap<>(); + Sinks.Many processor1 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST1", processor1); - UnicastProcessor processor2 = UnicastProcessor.create(); + Sinks.Many processor2 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST2", processor2); handler = new RSocket() { @Override public Flux requestStream(Payload payload) { - return map.get(payload.getDataUtf8()); + return map.get(payload.getDataUtf8()).asFlux(); } }; @@ -177,13 +182,13 @@ public Flux requestStream(Payload payload) { .subscribeOn(Schedulers.newSingle("2")) .subscribe(c -> nextCountdown.countDown(), t -> {}, completeCountdown::countDown); - processor1.onNext(DefaultPayload.create("RESPONSE1A")); - processor2.onNext(DefaultPayload.create("RESPONSE2A")); + processor1.tryEmitNext(DefaultPayload.create("RESPONSE1A")); + processor2.tryEmitNext(DefaultPayload.create("RESPONSE2A")); nextCountdown.await(); - processor1.onComplete(); - processor2.onComplete(); + processor1.tryEmitComplete(); + processor2.tryEmitComplete(); completeCountdown.await(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index 7d34ba478..cd96584ed 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -27,13 +27,14 @@ import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; public class TestingStreaming { LocalServerTransport serverTransport = LocalServerTransport.create("test"); - @Test(expected = ApplicationErrorException.class) + @Test public void testRangeButThrowException() { Closeable server = null; try { @@ -53,8 +54,9 @@ public void testRangeButThrowException() { .bind(serverTransport) .block(); - Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); + Assertions.assertThatThrownBy( + Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i))::blockLast) + .isInstanceOf(ApplicationErrorException.class); } finally { server.dispose(); @@ -76,8 +78,6 @@ public void testRangeOfConsumers() { .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); - } finally { server.dispose(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java new file mode 100644 index 000000000..870ecf0cd --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java @@ -0,0 +1,246 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.core.tck.MeterRegistryAssert; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.tracing.test.SampleTestRunner; +import io.micrometer.tracing.test.reporter.BuildingBlocks; +import io.micrometer.tracing.test.simple.SpansAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.micrometer.observation.ByteBufGetter; +import io.rsocket.micrometer.observation.ByteBufSetter; +import io.rsocket.micrometer.observation.ObservationRequesterRSocketProxy; +import io.rsocket.micrometer.observation.ObservationResponderRSocketProxy; +import io.rsocket.micrometer.observation.RSocketRequesterTracingObservationHandler; +import io.rsocket.micrometer.observation.RSocketResponderTracingObservationHandler; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ObservationIntegrationTest extends SampleTestRunner { + private static final MeterRegistry registry = new SimpleMeterRegistry(); + private static final ObservationRegistry observationRegistry = ObservationRegistry.create(); + + static { + observationRegistry + .observationConfig() + .observationHandler(new DefaultMeterObservationHandler(registry)); + } + + private final RSocketInterceptor requesterInterceptor; + private final RSocketInterceptor responderInterceptor; + + ObservationIntegrationTest() { + super(SampleRunnerConfig.builder().build()); + requesterInterceptor = + reactiveSocket -> new ObservationRequesterRSocketProxy(reactiveSocket, observationRegistry); + + responderInterceptor = + reactiveSocket -> new ObservationResponderRSocketProxy(reactiveSocket, observationRegistry); + } + + private CloseableChannel server; + private RSocket client; + private AtomicInteger counter; + + @Override + public BiConsumer>> + customizeObservationHandlers() { + return (buildingBlocks, observationHandlers) -> { + observationHandlers.addFirst( + new RSocketRequesterTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufSetter(), + false)); + observationHandlers.addFirst( + new RSocketResponderTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufGetter(), + false)); + }; + } + + @AfterEach + public void teardown() { + if (server != null) { + server.dispose(); + } + } + + private void testRequest() { + counter.set(0); + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testStream() { + counter.set(0); + client.requestStream(DefaultPayload.create("start")).blockLast(); + + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testRequestChannel() { + counter.set(0); + client.requestChannel(Mono.just(DefaultPayload.create("start"))).blockFirst(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testFireAndForget() { + counter.set(0); + client.fireAndForget(DefaultPayload.create("start")).subscribe(); + Awaitility.await().atMost(Duration.ofSeconds(50)).until(() -> counter.get() == 1); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + @Override + public SampleTestRunnerConsumer yourCode() { + return (bb, meterRegistry) -> { + counter = new AtomicInteger(); + server = + RSocketServer.create( + (setup, sendingSocket) -> { + sendingSocket.onClose().subscribe(); + + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + counter.incrementAndGet(); + return Flux.from(payloads); + } + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.empty(); + } + }); + }) + .interceptors(registry -> registry.forResponder(responderInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + + client = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(requesterInterceptor)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + testRequest(); + + testStream(); + + testRequestChannel(); + + testFireAndForget(); + + // @formatter:off + SpansAssert.assertThat(bb.getFinishedSpans()) + .haveSameTraceId() + // "request_*" + "handle" x 4 + .hasNumberOfSpansEqualTo(8) + .hasNumberOfSpansWithNameEqualTo("handle", 4) + .forAllSpansWithNameEqualTo("handle", span -> span.hasTagWithKey("rsocket.request-type")) + .hasASpanWithNameIgnoreCase("request_stream") + .thenASpanWithNameEqualToIgnoreCase("request_stream") + .hasTag("rsocket.request-type", "REQUEST_STREAM") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_channel") + .thenASpanWithNameEqualToIgnoreCase("request_channel") + .hasTag("rsocket.request-type", "REQUEST_CHANNEL") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_fnf") + .thenASpanWithNameEqualToIgnoreCase("request_fnf") + .hasTag("rsocket.request-type", "REQUEST_FNF") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_response") + .thenASpanWithNameEqualToIgnoreCase("request_response") + .hasTag("rsocket.request-type", "REQUEST_RESPONSE"); + + MeterRegistryAssert.assertThat(registry) + .hasTimerWithNameAndTags( + "rsocket.response", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.fnf", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_FNF"))) + .hasTimerWithNameAndTags( + "rsocket.request", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.channel", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_CHANNEL"))) + .hasTimerWithNameAndTags( + "rsocket.stream", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_STREAM"))); + // @formatter:on + }; + } + + @Override + protected MeterRegistry getMeterRegistry() { + return registry; + } + + @Override + protected ObservationRegistry getObservationRegistry() { + return observationRegistry; + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java index b2dad0022..5eb78fabe 100644 --- a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -182,7 +182,7 @@ private static Mono newClientRSocket( .resume( new Resume() .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .storeFactory(t -> new InMemoryResumableFramesStore("client", 500_000)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", t, 500_000)) .cleanupStoreOnKeepAlive() .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) @@ -199,7 +199,7 @@ private static Mono newServerRSocket(int sessionDurationSecond new Resume() .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) .cleanupStoreOnKeepAlive() - .storeFactory(t -> new InMemoryResumableFramesStore("server", 500_000))) + .storeFactory(t -> new InMemoryResumableFramesStore("server", t, 500_000))) .bind(serverTransport(SERVER_HOST, SERVER_PORT)); } @@ -212,7 +212,7 @@ public Flux requestChannel(Publisher payloads) { return duplicate( Flux.interval(Duration.ofMillis(1)) .onBackpressureLatest() - .publishOn(Schedulers.elastic()), + .publishOn(Schedulers.boundedElastic()), 20) .map(v -> DefaultPayload.create(String.valueOf(counter.getAndIncrement()))) .takeUntilOther(Flux.from(payloads).then()); diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index 29003feaf..6d91324ae 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -33,10 +32,7 @@ dependencies { testImplementation 'org.assertj:assertj-core' testImplementation 'io.projectreactor:reactor-test' - // TODO: Remove after JUnit5 migration - testCompileOnly 'junit:junit' - testImplementation 'org.hamcrest:hamcrest-library' - testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' testRuntimeOnly 'ch.qos.logback:logback-classic' } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index c7f64674c..6329da826 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -34,7 +34,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,6 +41,8 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; import reactor.util.retry.Retry; /** @@ -667,26 +668,28 @@ private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { @Override public Mono requestResponse(Payload payload) { return rSocketMono.flatMap( - source -> { - return Mono.from( - subscriber -> - source - .requestResponse(payload) - .subscribe(new LatencySubscriber<>(subscriber, this))); - }); + source -> + Mono.from( + subscriber -> + source + .requestResponse(payload) + .subscribe( + new LatencySubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override public Flux requestStream(Payload payload) { return rSocketMono.flatMapMany( - source -> { - return Flux.from( - subscriber -> - source - .requestStream(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); - }); + source -> + Flux.from( + subscriber -> + source + .requestStream(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override @@ -698,7 +701,9 @@ public Mono fireAndForget(Payload payload) { subscriber -> source .fireAndForget(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); }); } @@ -710,7 +715,9 @@ public Mono metadataPush(Payload payload) { subscriber -> source .metadataPush(payload) - .subscribe(new CountingSubscriber<>(subscriber, this))); + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); }); } @@ -718,13 +725,14 @@ public Mono metadataPush(Payload payload) { public Flux requestChannel(Publisher payloads) { return rSocketMono.flatMapMany( - source -> { - return Flux.from( - subscriber -> - source - .requestChannel(payloads) - .subscribe(new CountingSubscriber<>(subscriber, this))); - }); + source -> + Flux.from( + subscriber -> + source + .requestChannel(payloads) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } synchronized double getPredictedLatency() { @@ -867,18 +875,23 @@ public long lastTimeUsedMillis() { * Subscriber wrapper used for request/response interaction model, measure and collect latency * information. */ - private class LatencySubscriber implements Subscriber { - private final Subscriber child; + private class LatencySubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; private final AtomicBoolean done; private long start; - LatencySubscriber(Subscriber child, WeightedSocket socket) { + LatencySubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; this.done = new AtomicBoolean(false); } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { start = incr(); @@ -931,15 +944,20 @@ public void onComplete() { * Subscriber wrapper used for stream like interaction model, it only counts the number of * active streams */ - private class CountingSubscriber implements Subscriber { - private final Subscriber child; + private class CountingSubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; - CountingSubscriber(Subscriber child, WeightedSocket socket) { + CountingSubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { socket.pendingStreams.incrementAndGet(); diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java index 0589cc346..52bf89558 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -16,6 +16,9 @@ package io.rsocket.client; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSocketSupplier; @@ -24,9 +27,9 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -34,7 +37,8 @@ public class LoadBalancedRSocketMonoTest { - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingFactories() throws InterruptedException { TestingRSocket socket = new TestingRSocket(Function.identity()); RSocketSupplier failing = failingClient(); @@ -44,7 +48,8 @@ public void testNeverSelectFailingFactories() throws InterruptedException { testBalancer(factories); } - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingSocket() throws InterruptedException { TestingRSocket socket = new TestingRSocket(Function.identity()); TestingRSocket failingSocket = @@ -67,8 +72,9 @@ public double availability() { testBalancer(clients); } - @Test(timeout = 10_000L) - @Ignore + @Test + @Timeout(10_000L) + @Disabled public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDelivered() { TestingRSocket socket = new TestingRSocket(Function.identity()); @@ -87,12 +93,12 @@ public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDe LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(factories); - Assert.assertEquals(0.0, balancer.availability(), 0); + assertThat(balancer.availability()).isZero(); laterSupplier.complete(succeedingFactory(socket)); balancer.rSocketMono.block(); - Assert.assertEquals(1.0, balancer.availability(), 0); + assertThat(balancer.availability()).isEqualTo(1.0); } private void testBalancer(List factories) throws InterruptedException { @@ -128,7 +134,7 @@ private static RSocketSupplier failingClient() { Mockito.when(mock.get()) .thenAnswer( a -> { - Assert.fail(); + fail(); return null; }); diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java index 887132f99..9e1982465 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java @@ -16,9 +16,8 @@ package io.rsocket.client; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -31,7 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -44,7 +43,7 @@ public class RSocketSupplierTest { public void testError() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); + assertThat(socket.availability()).isEqualTo(1.0); Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); @@ -64,7 +63,7 @@ public void testError() throws InterruptedException { payloadPublisher.subscribe(subscriber); verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); latch.countDown(); }); } @@ -73,7 +72,7 @@ public void testError() throws InterruptedException { public void testWidowReset() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); + assertThat(socket.availability()).isEqualTo(1.0); Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); @@ -87,7 +86,7 @@ public void testWidowReset() throws InterruptedException { verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); try { Thread.sleep(200); @@ -96,7 +95,7 @@ public void testWidowReset() throws InterruptedException { } double reset = socket.availability(); - assertTrue(reset > bad); + assertThat(reset > bad).isTrue(); latch.countDown(); }); } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java index 96982121b..2827c8ed4 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,12 +24,13 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.Scannable; import reactor.core.publisher.*; public class TestingRSocket implements RSocket { private final AtomicInteger count; - private final MonoProcessor onClose = MonoProcessor.create(); + private final Sinks.Empty onClose = Sinks.empty(); private final BiFunction, Payload, Boolean> eachPayloadHandler; public TestingRSocket(Function responder) { @@ -128,16 +129,17 @@ public double availability() { @Override public void dispose() { - onClose.onComplete(); + onClose.tryEmitEmpty(); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java index 9a5ac644b..b8866b1f6 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java @@ -16,15 +16,14 @@ package io.rsocket.client; -import static org.hamcrest.Matchers.instanceOf; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSockets; import io.rsocket.util.EmptyPayload; import java.time.Duration; -import org.hamcrest.MatcherAssert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -50,8 +49,9 @@ public void onNext(Payload payload) { @Override public void onError(Throwable t) { - MatcherAssert.assertThat( - "Unexpected exception in onError", t, instanceOf(TimeoutException.class)); + assertThat(t) + .describedAs("Unexpected exception in onError") + .isInstanceOf(TimeoutException.class); } @Override diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java index 0aab4afd7..b214a725e 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java @@ -18,8 +18,8 @@ import java.util.Arrays; import java.util.Random; -import org.junit.Assert; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; public class MedianTest { private double errorSum = 0; @@ -59,7 +59,8 @@ private void testMedian(Random rng) { maxError = Math.max(maxError, error); minError = Math.min(minError, error); - Assert.assertTrue( - "p50=" + estimation + ", real=" + expected + ", error=" + error, error < 0.02); + Assertions.assertThat(error < 0.02) + .describedAs("p50=" + estimation + ", real=" + expected + ", error=" + error) + .isTrue(); } } diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle index 4be616623..debf02f34 100644 --- a/rsocket-micrometer/build.gradle +++ b/rsocket-micrometer/build.gradle @@ -17,13 +17,14 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { api project(':rsocket-core') + api 'io.micrometer:micrometer-observation' api 'io.micrometer:micrometer-core' + api 'io.micrometer:micrometer-tracing' implementation 'org.slf4j:slf4j-api' @@ -37,4 +38,10 @@ dependencies { testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.micrometer") + } +} + description = 'Transparent Metrics exposure to Micrometer' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java index c8b22382a..7c7ac37b9 100644 --- a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,13 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import java.net.SocketAddress; import java.util.Objects; import java.util.function.Consumer; -import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -88,6 +89,11 @@ public ByteBufAllocator alloc() { return delegate.alloc(); } + @Override + public SocketAddress remoteAddress() { + return delegate.remoteAddress(); + } + @Override public void dispose() { delegate.dispose(); @@ -105,10 +111,14 @@ public Flux receive() { } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); + public void sendFrame(int streamId, ByteBuf frame) { + frameCounters.accept(frame); + delegate.sendFrame(streamId, frame); + } - return delegate.send(Flux.from(frames).doOnNext(frameCounters)); + @Override + public void sendErrorAndClose(RSocketErrorException e) { + delegate.sendErrorAndClose(e); } private static final class FrameCounters implements Consumer { diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java new file mode 100644 index 000000000..09c8ba316 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata; + +public class ByteBufGetter implements Propagator.Getter { + + @Override + public String get(ByteBuf carrier, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(carrier, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + if (key.equals(entry.getMimeType())) { + return entry.getContent().toString(CharsetUtil.UTF_8); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java new file mode 100644 index 000000000..678bdb1ed --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadataCodec; + +public class ByteBufSetter implements Propagator.Setter { + + @Override + public void set(CompositeByteBuf carrier, String key, String value) { + final ByteBufAllocator alloc = carrier.alloc(); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + carrier, alloc, key, ByteBufUtil.writeUtf8(alloc, value)); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java new file mode 100644 index 000000000..357be8f15 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.core.lang.Nullable; +import io.netty.buffer.ByteBuf; +import io.rsocket.metadata.CompositeMetadata; + +final class CompositeMetadataUtils { + + private CompositeMetadataUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + @Nullable + static ByteBuf extract(ByteBuf metadata, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + final String entryKey = entry.getMimeType(); + if (key.equals(entryKey)) { + return entry.getContent(); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java new file mode 100644 index 000000000..2c10fc78d --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +class DefaultRSocketObservationConvention { + + private final RSocketContext rSocketContext; + + public DefaultRSocketObservationConvention(RSocketContext rSocketContext) { + this.rSocketContext = rSocketContext; + } + + String getName() { + if (this.rSocketContext.frameType == FrameType.REQUEST_FNF) { + return "rsocket.fnf"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_STREAM) { + return "rsocket.stream"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_CHANNEL) { + return "rsocket.channel"; + } + return "%s"; + } + + protected RSocketContext getRSocketContext() { + return this.rSocketContext; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java new file mode 100644 index 000000000..73e04b749 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java @@ -0,0 +1,62 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketRequesterObservationConvention + extends DefaultRSocketObservationConvention implements RSocketRequesterObservationConvention { + + public DefaultRSocketRequesterObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues values = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + values = + values.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return values; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.request"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java new file mode 100644 index 000000000..5318c1b37 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketResponderObservationConvention + extends DefaultRSocketObservationConvention implements RSocketResponderObservationConvention { + + public DefaultRSocketResponderObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues tags = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + tags = tags.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return tags; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.response"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java new file mode 100644 index 000000000..fb80ea317 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java @@ -0,0 +1,208 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.docs.ObservationDocumentation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; +import reactor.util.context.ContextView; + +/** + * Tracing representation of a {@link RSocketProxy} for the requester. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationRequesterRSocketProxy extends RSocketProxy { + + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketRequesterObservationConvention observationConvention; + + public ObservationRequesterRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationRequesterRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketRequesterObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + return setObservation( + super::fireAndForget, + payload, + FrameType.REQUEST_FNF, + RSocketObservationDocumentation.RSOCKET_REQUESTER_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return setObservation( + super::requestResponse, + payload, + FrameType.REQUEST_RESPONSE, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_RESPONSE); + } + + Mono setObservation( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation observation) { + return Mono.deferContextual( + contextView -> observe(input, payload, frameType, observation, contextView)); + } + + private String route(Payload payload) { + if (payload.hasMetadata()) { + try { + ByteBuf extracted = + CompositeMetadataUtils.extract( + payload.sliceMetadata(), WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + final RoutingMetadata routingMetadata = new RoutingMetadata(extracted); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } catch (Exception e) { + + } + } + return null; + } + + private Mono observe( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs, + ContextView contextView) { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, payload.sliceMetadata(), frameType, route, RSocketContext.Side.REQUESTER); + Observation parentObservation = contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation observation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, observation); + observation.start(); + Payload newPayload = payload; + if (rSocketContext.modifiedPayload != null) { + newPayload = rSocketContext.modifiedPayload; + } + return input + .apply(newPayload) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, observation)); + } + + @Override + public Flux requestStream(Payload payload) { + return observationFlux( + super::requestStream, + payload, + FrameType.REQUEST_STREAM, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher inbound) { + return Flux.from(inbound) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + return observationFlux( + p -> super.requestChannel(flux.skip(1).startWith(p)), + firstPayload, + FrameType.REQUEST_CHANNEL, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_CHANNEL); + } + return flux; + }); + } + + private Flux observationFlux( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs) { + return Flux.deferContextual( + contextView -> { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + frameType, + route, + RSocketContext.Side.REQUESTER); + Observation parentObservation = + contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation newObservation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, newObservation); + newObservation.start(); + return input + .apply(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + }); + } + + private void setContextualName(FrameType frameType, String route, Observation newObservation) { + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(frameType.name() + " " + route); + } else { + newObservation.contextualName(frameType.name()); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java new file mode 100644 index 000000000..9ed27adf3 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Tracing representation of a {@link RSocketProxy} for the responder. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationResponderRSocketProxy extends RSocketProxy { + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketResponderObservationConvention observationConvention; + + public ObservationResponderRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationResponderRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketResponderObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + // called on Netty EventLoop + // there can't be observation in thread local here + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_FNF, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation(RSocketObservationDocumentation.RSOCKET_RESPONDER_FNF, rSocketContext); + return super.fireAndForget(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + private Observation startObservation( + RSocketObservationDocumentation observation, RSocketContext rSocketContext) { + return observation.start( + this.observationConvention, + new DefaultRSocketResponderObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry); + } + + @Override + public Mono requestResponse(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_RESPONSE, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_RESPONSE, rSocketContext); + return super.requestResponse(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestStream(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, sliceMetadata, FrameType.REQUEST_STREAM, route, RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_STREAM, rSocketContext); + return super.requestStream(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + ByteBuf sliceMetadata = firstPayload.sliceMetadata(); + String route = route(firstPayload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + firstPayload, + firstPayload.sliceMetadata(), + FrameType.REQUEST_CHANNEL, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_CHANNEL, + rSocketContext); + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(rSocketContext.frameType.name() + " " + route); + } + return super.requestChannel(flux.skip(1).startWith(rSocketContext.modifiedPayload)) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite( + context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + return flux; + }); + } + + private String route(Payload payload, ByteBuf headers) { + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } + } catch (Exception e) { + + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java new file mode 100644 index 000000000..e5286a53f --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java @@ -0,0 +1,73 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.util.HashSet; +import java.util.Set; + +final class PayloadUtils { + + private PayloadUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + static CompositeByteBuf cleanTracingMetadata(Payload payload, Set fields) { + Set fieldsWithDefaultZipkin = new HashSet<>(fields); + fieldsWithDefaultZipkin.add(WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + final CompositeByteBuf metadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + if (payload.hasMetadata()) { + try { + final CompositeMetadata entries = new CompositeMetadata(payload.metadata(), false); + for (Entry entry : entries) { + if (!fieldsWithDefaultZipkin.contains(entry.getMimeType())) { + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + metadata, + ByteBufAllocator.DEFAULT, + entry.getMimeType(), + entry.getContent().retain()); + } + } + } catch (Exception e) { + + } + } + return metadata; + } + + static Payload payload(Payload payload, CompositeByteBuf metadata) { + final Payload newPayload; + try { + if (payload instanceof ByteBufPayload) { + newPayload = ByteBufPayload.create(payload.data().retain(), metadata); + } else { + newPayload = DefaultPayload.create(payload.data().retain(), metadata); + } + } finally { + payload.release(); + } + return newPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java new file mode 100644 index 000000000..8622cdfa5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java @@ -0,0 +1,76 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.lang.Nullable; +import io.micrometer.observation.Observation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; + +public class RSocketContext extends Observation.Context { + + final Payload payload; + + final ByteBuf metadata; + + final FrameType frameType; + + final String route; + + final Side side; + + Payload modifiedPayload; + + RSocketContext( + Payload payload, ByteBuf metadata, FrameType frameType, @Nullable String route, Side side) { + this.payload = payload; + this.metadata = metadata; + this.frameType = frameType; + this.route = route; + this.side = side; + } + + public enum Side { + REQUESTER, + RESPONDER + } + + public Payload getPayload() { + return payload; + } + + public ByteBuf getMetadata() { + return metadata; + } + + public FrameType getFrameType() { + return frameType; + } + + public String getRoute() { + return route; + } + + public Side getSide() { + return side; + } + + public Payload getModifiedPayload() { + return modifiedPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java new file mode 100644 index 000000000..1be6b4599 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java @@ -0,0 +1,232 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.docs.KeyName; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.docs.ObservationDocumentation; + +enum RSocketObservationDocumentation implements ObservationDocumentation { + + /** Observation created on the RSocket responder side. */ + RSOCKET_RESPONDER { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + }, + + /** Observation created on the RSocket requester side for Fire and Forget frame type. */ + RSOCKET_REQUESTER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Fire and Forget frame type. */ + RSOCKET_RESPONDER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Response frame type. */ + RSOCKET_REQUESTER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Response frame type. */ + RSOCKET_RESPONDER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Stream frame type. */ + RSOCKET_REQUESTER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Stream frame type. */ + RSOCKET_RESPONDER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Channel frame type. */ + RSOCKET_REQUESTER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Channel frame type. */ + RSOCKET_RESPONDER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }; + + enum RequesterTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + }, + + /** Name of the RSocket content type. */ + CONTENT_TYPE { + @Override + public String asString() { + return "rsocket.content-type"; + } + } + } + + enum ResponderTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java new file mode 100644 index 000000000..d795f81b5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket requester {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketRequesterObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java new file mode 100644 index 000000000..996267d4a --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java @@ -0,0 +1,131 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.TracingMetadataCodec; +import java.util.HashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketRequesterTracingObservationHandler + implements TracingObservationHandler { + private static final Logger log = + LoggerFactory.getLogger(RSocketRequesterTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Setter setter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketRequesterTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Setter setter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.setter = setter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + @Override + public void onStart(RSocketContext context) { + Payload payload = context.payload; + Span.Builder spanBuilder = this.tracer.spanBuilder(); + Span parentSpan = getParentSpan(context); + if (parentSpan != null) { + spanBuilder.setParent(parentSpan.context()); + } + Span span = spanBuilder.kind(Span.Kind.PRODUCER).start(); + log.debug("Extracted result from context or thread local {}", span); + // TODO: newmetadata returns an empty composite byte buf + final CompositeByteBuf newMetadata = + PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields())); + TraceContext traceContext = span.context(); + if (this.isZipkinPropagationEnabled) { + injectDefaultZipkinRSocketHeaders(newMetadata, traceContext); + } + this.propagator.inject(traceContext, newMetadata, this.setter); + context.modifiedPayload = PayloadUtils.payload(payload, newMetadata); + getTracingContext(context).setSpan(span); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.name(context.getContextualName()).end(); + } + + private void injectDefaultZipkinRSocketHeaders( + CompositeByteBuf newMetadata, TraceContext traceContext) { + TracingMetadataCodec.Flags flags = + traceContext.sampled() == null + ? TracingMetadataCodec.Flags.UNDECIDED + : traceContext.sampled() + ? TracingMetadataCodec.Flags.SAMPLE + : TracingMetadataCodec.Flags.NOT_SAMPLE; + String traceId = traceContext.traceId(); + long[] traceIds = EncodingUtils.fromString(traceId); + long[] spanId = EncodingUtils.fromString(traceContext.spanId()); + long[] parentSpanId = EncodingUtils.fromString(traceContext.parentId()); + boolean isTraceId128Bit = traceIds.length == 2; + if (isTraceId128Bit) { + TracingMetadataCodec.encode128( + newMetadata.alloc(), + traceIds[0], + traceIds[1], + spanId[0], + EncodingUtils.fromString(traceContext.parentId())[0], + flags); + } else { + TracingMetadataCodec.encode64( + newMetadata.alloc(), traceIds[0], spanId[0], parentSpanId[0], flags); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java new file mode 100644 index 000000000..a5d6808bd --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket responder {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketResponderObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java new file mode 100644 index 000000000..e3975b577 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java @@ -0,0 +1,152 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TracingMetadata; +import io.rsocket.metadata.TracingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import java.util.HashSet; +import java.util.Iterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketResponderTracingObservationHandler + implements TracingObservationHandler { + + private static final Logger log = + LoggerFactory.getLogger(RSocketResponderTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Getter getter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketResponderTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Getter getter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.getter = getter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public void onStart(RSocketContext context) { + Span handle = consumerSpanBuilder(context.payload, context.metadata, context.frameType); + CompositeByteBuf bufs = + PayloadUtils.cleanTracingMetadata(context.payload, new HashSet<>(propagator.fields())); + context.modifiedPayload = PayloadUtils.payload(context.payload, bufs); + getTracingContext(context).setSpan(handle); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.end(); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + private Span consumerSpanBuilder(Payload payload, ByteBuf headers, FrameType requestType) { + Span.Builder consumerSpanBuilder = consumerSpanBuilder(payload, headers); + log.debug("Extracted result from headers {}", consumerSpanBuilder); + String name = "handle"; + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + name = requestType.name() + " " + iterator.next(); + } + } catch (Exception e) { + + } + } + return consumerSpanBuilder.kind(Span.Kind.CONSUMER).name(name).start(); + } + + private Span.Builder consumerSpanBuilder(Payload payload, ByteBuf headers) { + if (this.isZipkinPropagationEnabled && payload.hasMetadata()) { + try { + ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + if (extract != null) { + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(extract); + Span.Builder builder = this.tracer.spanBuilder(); + String traceId = EncodingUtils.fromLong(tracingMetadata.traceId()); + long traceIdHigh = tracingMetadata.traceIdHigh(); + if (traceIdHigh != 0L) { + // ExtendedTraceId + traceId = EncodingUtils.fromLong(traceIdHigh) + traceId; + } + TraceContext.Builder parentBuilder = + this.tracer + .traceContextBuilder() + .sampled(tracingMetadata.isDebug() || tracingMetadata.isSampled()) + .traceId(traceId) + .spanId(EncodingUtils.fromLong(tracingMetadata.spanId())) + .parentId(EncodingUtils.fromLong(tracingMetadata.parentId())); + return builder.setParent(parentBuilder.build()); + } else { + return this.propagator.extract(headers, this.getter); + } + } catch (Exception e) { + + } + } + return this.propagator.extract(headers, this.getter); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java index 03abd2084..7806200dd 100644 --- a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java @@ -34,7 +34,7 @@ import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import org.reactivestreams.Publisher; +import org.mockito.Mockito; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; @@ -153,32 +153,29 @@ void receive() { @SuppressWarnings("unchecked") @Test void send() { - ArgumentCaptor> captor = ArgumentCaptor.forClass(Publisher.class); - when(delegate.send(captor.capture())).thenReturn(Mono.empty()); - - Flux frames = - Flux.just( - createTestCancelFrame(), - createTestErrorFrame(), - createTestKeepaliveFrame(), - createTestLeaseFrame(), - createTestMetadataPushFrame(), - createTestPayloadFrame(), - createTestRequestChannelFrame(), - createTestRequestFireAndForgetFrame(), - createTestRequestNFrame(), - createTestRequestResponseFrame(), - createTestRequestStreamFrame(), - createTestSetupFrame()); - - new MicrometerDuplexConnection( - SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")) - .send(frames) - .as(StepVerifier::create) + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + doNothing().when(delegate).sendFrame(Mockito.anyInt(), captor.capture()); + + final MicrometerDuplexConnection micrometerDuplexConnection = + new MicrometerDuplexConnection( + SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")); + micrometerDuplexConnection.sendFrame(1, createTestCancelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestErrorFrame()); + micrometerDuplexConnection.sendFrame(1, createTestKeepaliveFrame()); + micrometerDuplexConnection.sendFrame(1, createTestLeaseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestMetadataPushFrame()); + micrometerDuplexConnection.sendFrame(1, createTestPayloadFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestChannelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestFireAndForgetFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestNFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestResponseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestStreamFrame()); + micrometerDuplexConnection.sendFrame(1, createTestSetupFrame()); + + StepVerifier.create(Flux.fromIterable(captor.getAllValues())) + .expectNextCount(12) .verifyComplete(); - StepVerifier.create(captor.getValue()).expectNextCount(12).verifyComplete(); - assertThat(findCounter(SERVER, CANCEL).count()).isEqualTo(1); assertThat(findCounter(SERVER, COMPLETE).count()).isEqualTo(1); assertThat(findCounter(SERVER, ERROR).count()).isEqualTo(1); @@ -193,15 +190,6 @@ void send() { assertThat(findCounter(SERVER, SETUP).count()).isEqualTo(1); } - @DisplayName("send throws NullPointerException with null frames") - @Test - void sendNullFrames() { - assertThatNullPointerException() - .isThrownBy( - () -> new MicrometerDuplexConnection(CLIENT, delegate, meterRegistry).send(null)) - .withMessage("frames must not be null"); - } - private Counter findCounter(Type connectionType, FrameType frameType) { return meterRegistry .get("rsocket.frame") diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle index 5ec1a8061..bcdf88f28 100644 --- a/rsocket-test/build.gradle +++ b/rsocket-test/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -29,9 +28,14 @@ dependencies { implementation 'io.projectreactor:reactor-test' implementation 'org.assertj:assertj-core' implementation 'org.mockito:mockito-core' + implementation 'org.awaitility:awaitility' + implementation 'org.slf4j:slf4j-api' +} - // TODO: Remove after JUnit5 migration - implementation 'junit:junit' +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.test") + } } description = 'Test utilities for RSocket projects' diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java index d74f59fd8..e773b4a0d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -16,22 +16,35 @@ package io.rsocket.test; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.util.DefaultPayload; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; public abstract class BaseClientServerTest> { - @Rule public final T setup = createClientServer(); + public final T setup = createClientServer(); protected abstract T createClientServer(); - @Test(timeout = 10000) + @BeforeEach + public void init() { + setup.init(); + } + + @AfterEach + public void teardown() { + setup.tearDown(); + } + + @Test + @Timeout(10000) public void testFireNForget10() { long outputCount = Flux.range(1, 10) @@ -40,10 +53,11 @@ public void testFireNForget10() { .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testPushMetadata10() { long outputCount = Flux.range(1, 10) @@ -52,7 +66,7 @@ public void testPushMetadata10() { .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } @Test // (timeout = 10000) @@ -65,10 +79,11 @@ public void testRequestResponse1() { .count() .block(); - assertEquals(1, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse10() { long outputCount = Flux.range(1, 10) @@ -78,7 +93,7 @@ public void testRequestResponse10() { .count() .block(); - assertEquals(10, outputCount); + assertThat(outputCount).isEqualTo(10); } private Payload testPayload(int metadataPresent) { @@ -97,7 +112,8 @@ private Payload testPayload(int metadataPresent) { return DefaultPayload.create("hello", metadata); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse100() { long outputCount = Flux.range(1, 100) @@ -107,10 +123,11 @@ public void testRequestResponse100() { .count() .block(); - assertEquals(100, outputCount); + assertThat(outputCount).isEqualTo(100); } - @Test(timeout = 20000) + @Test + @Timeout(20000) public void testRequestResponse10_000() { long outputCount = Flux.range(1, 10_000) @@ -120,28 +137,31 @@ public void testRequestResponse10_000() { .count() .block(); - assertEquals(10_000, outputCount); + assertThat(outputCount).isEqualTo(10_000); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStream() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); long count = publisher.take(5).count().block(); - assertEquals(5, count); + assertThat(count).isEqualTo(5); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamAll() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); long count = publisher.count().block(); - assertEquals(10000, count); + assertThat(count).isEqualTo(10000); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamWithRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); ts.expect(5); @@ -149,16 +169,17 @@ public void testRequestStreamWithRequestN() { setup.getRSocket().requestStream(testPayload(3)).subscribe(ts); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamWithDelayedRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); @@ -167,34 +188,37 @@ public void testRequestStreamWithDelayedRequestN() { ts.expect(5); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel0() { Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); long count = publisher.count().block(); - assertEquals(0, count); + assertThat(count).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel1() { Flux publisher = setup.getRSocket().requestChannel(Flux.just(testPayload(0))); long count = publisher.count().block(); - assertEquals(1, count); + assertThat(count).isOne(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel3() { Flux publisher = setup @@ -203,44 +227,48 @@ public void testChannel3() { long count = publisher.count().block(); - assertEquals(3, count); + assertThat(count).isEqualTo(3); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel512() { Flux payloads = Flux.range(1, 512).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(512, count); + assertThat(count).isEqualTo(512); } - @Test(timeout = 30000) + @Test + @Timeout(30000) public void testChannel20_000() { Flux payloads = Flux.range(1, 20_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(20_000, count); + assertThat(count).isEqualTo(20_000); } - @Test(timeout = 60_000) + @Test + @Timeout(60_000) public void testChannel200_000() { Flux payloads = Flux.range(1, 200_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(200_000, count); + assertThat(count).isEqualTo(200_000); } - @Test(timeout = 60_000) - @Ignore + @Test + @Timeout(60_000) + @Disabled public void testChannel2_000_000() { AtomicInteger counter = new AtomicInteger(0); Flux payloads = Flux.range(1, 2_000_000).map(i -> DefaultPayload.create("hello " + i)); long count = setup.getRSocket().requestChannel(payloads).count().block(); - assertEquals(2_000_000, count); + assertThat(count).isEqualTo(2_000_000); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index 6f562875f..1d6b7f69e 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -25,12 +25,9 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; import reactor.core.publisher.Mono; -public class ClientSetupRule extends ExternalResource { +public class ClientSetupRule { private static final String data = "hello world"; private static final String metadata = "metadata"; @@ -39,6 +36,7 @@ public class ClientSetupRule extends ExternalResource { private Function serverInit; private RSocket client; + private S server; public ClientSetupRule( Supplier addressSupplier, @@ -59,18 +57,14 @@ public ClientSetupRule( .block(); } - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - T address = addressSupplier.get(); - S server = serverInit.apply(address); - client = clientConnector.apply(address, server); - base.evaluate(); - server.dispose(); - } - }; + public void init() { + T address = addressSupplier.get(); + S server = serverInit.apply(address); + client = clientConnector.apply(address, server); + } + + public void tearDown() { + server.dispose(); } public RSocket getRSocket() { diff --git a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java index 141ed4385..46e807b09 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java +++ b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java @@ -5,15 +5,25 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created * ByteBuffs */ -class LeaksTrackingByteBufAllocator implements ByteBufAllocator { +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); /** * Allows to instrument any given the instance of ByteBufAllocator @@ -22,7 +32,7 @@ class LeaksTrackingByteBufAllocator implements ByteBufAllocator { * @return */ public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { - return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO); + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); } /** @@ -32,8 +42,8 @@ public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocato * @return */ public static LeaksTrackingByteBufAllocator instrument( - ByteBufAllocator allocator, Duration awaitZeroRefCntDuration) { - return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration); + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); } final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); @@ -42,34 +52,76 @@ public static LeaksTrackingByteBufAllocator instrument( final Duration awaitZeroRefCntDuration; + final String tag; + private LeaksTrackingByteBufAllocator( - ByteBufAllocator delegate, Duration awaitZeroRefCntDuration) { + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { this.delegate = delegate; this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; } public LeaksTrackingByteBufAllocator assertHasNoLeaks() { try { - Assertions.assertThat(tracker) - .allSatisfy( - buf -> - Assertions.assertThat(buf) - .matches( - bb -> { - final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; - if (!awaitZeroRefCntDuration.isZero()) { - long end = - awaitZeroRefCntDuration.plusNanos(System.nanoTime()).toNanos(); - while (bb.refCnt() != 0) { - if (System.nanoTime() >= end) { - break; - } - parkNanos(100); - } - } - return bb.refCnt() == 0; - }, - "buffer should be released")); + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); } finally { tracker.clear(); } @@ -183,4 +235,60 @@ T track(T buffer) { return buffer; } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java index 9017e854b..14740950a 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -63,8 +63,8 @@ Flux pingPong( BiFunction> interaction, int count, final Recorder histogram) { - return client - .flatMapMany( + return Flux.usingWhen( + client, rsocket -> Flux.range(1, count) .flatMap( @@ -78,7 +78,11 @@ Flux pingPong( histogram.recordValue(diff); }); }, - 64)) + 64), + rsocket -> { + rsocket.dispose(); + return rsocket.onClose(); + }) .doOnError(Throwable::printStackTrace); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java new file mode 100644 index 000000000..57a00e229 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java @@ -0,0 +1,166 @@ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.PayloadFrameCodec; +import java.net.SocketAddress; +import java.util.function.BiFunction; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class TestDuplexConnection implements DuplexConnection { + + final ByteBufAllocator allocator; + final Sinks.Many inbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.Many outbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.One close = Sinks.one(); + + public TestDuplexConnection( + CoreSubscriber outboundSubscriber, boolean trackLeaks) { + this.outbound.asFlux().subscribe(outboundSubscriber); + this.allocator = + trackLeaks + ? LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT) + : ByteBufAllocator.DEFAULT; + } + + @Override + public void dispose() { + this.inbound.tryEmitComplete(); + this.outbound.tryEmitComplete(); + this.close.tryEmitEmpty(); + } + + @Override + public Mono onClose() { + return this.close.asMono(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) {} + + @Override + public Flux receive() { + return this.inbound + .asFlux() + .transform( + Operators.lift( + (BiFunction< + Scannable, + CoreSubscriber, + CoreSubscriber>) + ByteBufReleaserOperator::create)); + } + + @Override + public ByteBufAllocator alloc() { + return this.allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new SocketAddress() { + @Override + public String toString() { + return "Test"; + } + }; + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + this.outbound.tryEmitNext(frame); + } + + public void sendPayloadFrame( + int streamId, ByteBuf data, @Nullable ByteBuf metadata, boolean complete) { + sendFrame( + streamId, + PayloadFrameCodec.encode(this.allocator, streamId, false, complete, true, metadata, data)); + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + static CoreSubscriber create( + Scannable scannable, CoreSubscriber actual) { + return new ByteBufReleaserOperator(actual); + } + + final CoreSubscriber actual; + + Subscription s; + + public ByteBufReleaserOperator(CoreSubscriber actual) { + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + this.actual.onNext(buf); + buf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index e322ad292..1b294e394 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -95,7 +95,7 @@ public boolean awaitAllInteractionTermination(Duration duration) { } public boolean awaitUntilObserved(int interactions, Duration duration) { - long end = duration.plusNanos(System.nanoTime()).toNanos(); + long end = System.nanoTime() + duration.toNanos(); long observed; while ((observed = observedInteractions.get()) < interactions) { if (System.nanoTime() >= end) { diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index 436550130..1fcca97db 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,46 +25,57 @@ import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; import io.rsocket.core.RSocketConnector; import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.resume.InMemoryResumableFramesStore; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.util.ByteBufPayload; import java.io.BufferedReader; import java.io.InputStreamReader; +import java.net.SocketAddress; import java.time.Duration; +import java.util.Arrays; import java.util.concurrent.CancellationException; import java.util.concurrent.Executors; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; +import java.util.function.Predicate; import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.zip.GZIPInputStream; import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import org.junit.platform.commons.logging.Logger; -import org.junit.platform.commons.logging.LoggerFactory; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; import reactor.core.scheduler.Scheduler; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import reactor.util.Logger; +import reactor.util.Loggers; public interface TransportTest { - Logger logger = LoggerFactory.getLogger(TransportTest.class); + Logger logger = Loggers.getLogger(TransportTest.class); String MOCK_DATA = "test-data"; String MOCK_METADATA = "metadata"; @@ -72,7 +83,6 @@ public interface TransportTest { Payload LARGE_PAYLOAD = ByteBufPayload.create(LARGE_DATA, LARGE_DATA); static String read(String resourceName) { - try (BufferedReader br = new BufferedReader( new InputStreamReader( @@ -86,16 +96,55 @@ static String read(String resourceName) { } @BeforeEach - default void setUp() { + default void setup() { Hooks.onOperatorDebug(); } @AfterEach default void close() { - getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); - getTransportPair().dispose(); - getTransportPair().byteBufAllocator.assertHasNoLeaks(); - Hooks.resetOnOperatorDebug(); + try { + logger.debug("------------------Awaiting communication to finish------------------"); + getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); + logger.debug("---------------------Disposing Client And Server--------------------"); + getTransportPair().dispose(); + getTransportPair().awaitClosed(getTimeout()); + logger.debug("------------------------Disposing Schedulers-------------------------"); + Schedulers.parallel().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.boundedElastic().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.single().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + logger.debug("---------------------------Leaks Checking----------------------------"); + RuntimeException throwable = + new RuntimeException() { + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + + @Override + public String getMessage() { + return Arrays.toString(getSuppressed()); + } + }; + + try { + getTransportPair().byteBufAllocator2.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + try { + getTransportPair().byteBufAllocator1.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + if (throwable.getSuppressed().length > 0) { + throw throwable; + } + } finally { + Hooks.resetOnOperatorDebug(); + Schedulers.resetOnHandleError(); + } } default Payload createTestPayload(int metadataPresent) { @@ -152,6 +201,7 @@ default RSocket getClient() { @DisplayName("makes 10 metadataPush requests") @Test default void metadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); Flux.range(1, 10) .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) .as(StepVerifier::create) @@ -164,6 +214,7 @@ default void metadataPush10() { @DisplayName("makes 10 metadataPush with Large Metadata in requests") @Test default void largePayloadMetadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); Flux.range(1, 10) .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) .as(StepVerifier::create) @@ -194,7 +245,7 @@ default void requestChannel1() { .requestChannel(Mono.just(createTestPayload(0))) .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(1) + .thenConsumeWhile(new PayloadPredicate(1)) .expectComplete() .verify(getTimeout()); } @@ -207,8 +258,9 @@ default void requestChannel200_000() { getClient() .requestChannel(payloads) .doOnNext(Payload::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(200_000) + .thenConsumeWhile(new PayloadPredicate(200_000)) .expectComplete() .verify(getTimeout()); } @@ -222,7 +274,7 @@ default void largePayloadRequestChannel50() { .requestChannel(payloads) .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(50) + .thenConsumeWhile(new PayloadPredicate(50)) .expectComplete() .verify(getTimeout()); } @@ -237,7 +289,7 @@ default void requestChannel20_000() { .doOnNext(this::assertChannelPayload) .doOnNext(Payload::release) .as(StepVerifier::create) - .expectNextCount(20_000) + .thenConsumeWhile(new PayloadPredicate(20_000)) .expectComplete() .verify(getTimeout()); } @@ -250,8 +302,9 @@ default void requestChannel2_000_000() { getClient() .requestChannel(payloads) .doOnNext(Payload::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(2_000_000) + .thenConsumeWhile(new PayloadPredicate(2_000_000)) .expectComplete() .verify(getTimeout()); } @@ -267,31 +320,44 @@ default void requestChannel3() { .requestChannel(payloads) .doOnNext(Payload::release) .as(publisher -> StepVerifier.create(publisher, 3)) - .expectNextCount(3) + .thenConsumeWhile(new PayloadPredicate(3)) .expectComplete() .verify(getTimeout()); Assertions.assertThat(requested.get()).isEqualTo(3L); } - @DisplayName("makes 1 requestChannel request with 512 payloads") + @DisplayName("makes 1 requestChannel request with 256 payloads") @Test - default void requestChannel512() { - Flux payloads = Flux.range(0, 512).map(this::createTestPayload); - final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(13)); - - Flux.range(0, 1024) - .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) - .blockLast(); + default void requestChannel256() { + AtomicInteger counter = new AtomicInteger(); + Flux payloads = + Flux.defer( + () -> { + final int subscription = counter.getAndIncrement(); + return Flux.range(0, 256) + .map(i -> "S{" + subscription + "}: Data{" + i + "}") + .map(data -> ByteBufPayload.create(data)); + }); + final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(12)); + + try { + Flux.range(0, 1024) + .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) + .blockLast(); + } finally { + scheduler.disposeGracefully().block(); + } } default void check(Flux payloads) { getClient() .requestChannel(payloads) - .doOnNext(Payload::release) + .doOnNext(ReferenceCounted::release) + .limitRate(8) .as(StepVerifier::create) - .expectNextCount(512) - .as("expected 512 items") + .thenConsumeWhile(new PayloadPredicate(256)) + .as("expected 256 items") .expectComplete() .verify(getTimeout()); } @@ -417,11 +483,20 @@ default void assertChannelPayload(Payload p) { } class TransportPair implements Disposable { + private static final String data = "hello world"; private static final String metadata = "metadata"; - private final LeaksTrackingByteBufAllocator byteBufAllocator = - LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT, Duration.ofMinutes(1)); + private final boolean withResumability; + private final boolean runClientWithAsyncInterceptors; + private final boolean runServerWithAsyncInterceptors; + + private final LeaksTrackingByteBufAllocator byteBufAllocator1 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Client"); + private final LeaksTrackingByteBufAllocator byteBufAllocator2 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Server"); private final TestRSocket responder; @@ -441,19 +516,40 @@ public TransportPair( TriFunction clientTransportSupplier, BiFunction> serverTransportSupplier, boolean withRandomFragmentation) { + this( + addressSupplier, + clientTransportSupplier, + serverTransportSupplier, + withRandomFragmentation, + false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation, + boolean withResumability) { + Schedulers.onHandleError((t, e) -> e.printStackTrace()); + Schedulers.resetFactory(); + + this.withResumability = withResumability; T address = addressSupplier.get(); - final boolean runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); - final boolean runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + this.runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + this.runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); - ByteBufAllocator allocatorToSupply; + ByteBufAllocator allocatorToSupply1; + ByteBufAllocator allocatorToSupply2; if (ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.ADVANCED || ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.PARANOID) { - logger.info(() -> "Using LeakTrackingByteBufAllocator"); - allocatorToSupply = byteBufAllocator; + logger.info("Using LeakTrackingByteBufAllocator"); + allocatorToSupply1 = byteBufAllocator1; + allocatorToSupply2 = byteBufAllocator2; } else { - allocatorToSupply = ByteBufAllocator.DEFAULT; + allocatorToSupply1 = ByteBufAllocator.DEFAULT; + allocatorToSupply2 = ByteBufAllocator.DEFAULT; } responder = new TestRSocket(TransportPair.data, metadata); final RSocketServer rSocketServer = @@ -461,14 +557,13 @@ public TransportPair( .payloadDecoder(PayloadDecoder.ZERO_COPY) .interceptors( registry -> { - if (runServerWithAsyncInterceptors) { + if (runServerWithAsyncInterceptors && !withResumability) { logger.info( - () -> - "Perform Integration Test with Async Interceptors Enabled For Server"); + "Perform Integration Test with Async Interceptors Enabled For Server"); registry .forConnection( (type, duplexConnection) -> - new AsyncDuplexConnection(duplexConnection)) + new AsyncDuplexConnection(duplexConnection, "server")) .forSocketAcceptor( delegate -> (connectionSetupPayload, sendingSocket) -> @@ -476,29 +571,47 @@ public TransportPair( .accept(connectionSetupPayload, sendingSocket) .subscribeOn(Schedulers.parallel())); } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Server", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(100, 1000))) + : duplexConnection); + } }); + if (withResumability) { + rSocketServer.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("server", token, Integer.MAX_VALUE))); + } + if (withRandomFragmentation) { rSocketServer.fragment(ThreadLocalRandom.current().nextInt(256, 512)); } server = - rSocketServer.bind(serverTransportSupplier.apply(address, allocatorToSupply)).block(); + rSocketServer.bind(serverTransportSupplier.apply(address, allocatorToSupply2)).block(); final RSocketConnector rSocketConnector = RSocketConnector.create() .payloadDecoder(PayloadDecoder.ZERO_COPY) - .keepAlive(Duration.ofMillis(Integer.MAX_VALUE), Duration.ofMillis(Integer.MAX_VALUE)) + .keepAlive(Duration.ofMillis(10), Duration.ofHours(1)) .interceptors( registry -> { - if (runClientWithAsyncInterceptors) { + if (runClientWithAsyncInterceptors && !withResumability) { logger.info( - () -> - "Perform Integration Test with Async Interceptors Enabled For Client"); + "Perform Integration Test with Async Interceptors Enabled For Client"); registry .forConnection( (type, duplexConnection) -> - new AsyncDuplexConnection(duplexConnection)) + new AsyncDuplexConnection(duplexConnection, "client")) .forSocketAcceptor( delegate -> (connectionSetupPayload, sendingSocket) -> @@ -506,22 +619,41 @@ public TransportPair( .accept(connectionSetupPayload, sendingSocket) .subscribeOn(Schedulers.parallel())); } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Client", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(10, 1500))) + : duplexConnection); + } }); + if (withResumability) { + rSocketConnector.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("client", token, Integer.MAX_VALUE))); + } + if (withRandomFragmentation) { rSocketConnector.fragment(ThreadLocalRandom.current().nextInt(256, 512)); } client = rSocketConnector - .connect(clientTransportSupplier.apply(address, server, allocatorToSupply)) + .connect(clientTransportSupplier.apply(address, server, allocatorToSupply1)) .doOnError(Throwable::printStackTrace) .block(); } @Override public void dispose() { - server.dispose(); + logger.info("terminating transport pair"); client.dispose(); } @@ -537,30 +669,75 @@ public String expectedPayloadMetadata() { return metadata; } + public void awaitClosed(Duration timeout) { + logger.info("awaiting termination of transport pair"); + logger.info( + "wrappers combination: client{async=" + + runClientWithAsyncInterceptors + + "; resume=" + + withResumability + + "} server{async=" + + runServerWithAsyncInterceptors + + "; resume=" + + withResumability + + "}"); + client + .onClose() + .doOnSubscribe(s -> logger.info("Client termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Client termination stage=" + s)) + .onErrorResume(t -> Mono.empty()) + .doOnTerminate(() -> logger.info("Client terminated. Terminating Server")) + .then(Mono.fromRunnable(server::dispose)) + .then( + server + .onClose() + .doOnSubscribe( + s -> logger.info("Server termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Server termination stage=" + s))) + .onErrorResume(t -> Mono.empty()) + .block(timeout); + + logger.info("TransportPair has been terminated"); + } + private static class AsyncDuplexConnection implements DuplexConnection { private final DuplexConnection duplexConnection; + private String tag; + private final ByteBufReleaserOperator bufReleaserOperator; - public AsyncDuplexConnection(DuplexConnection duplexConnection) { + public AsyncDuplexConnection(DuplexConnection duplexConnection, String tag) { this.duplexConnection = duplexConnection; + this.tag = tag; + this.bufReleaserOperator = new ByteBufReleaserOperator(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + duplexConnection.sendFrame(streamId, frame); } @Override - public Mono send(Publisher frames) { - return duplexConnection.send(frames); + public void sendErrorAndClose(RSocketErrorException e) { + duplexConnection.sendErrorAndClose(e); } @Override public Flux receive() { return duplexConnection .receive() - .subscribeOn(Schedulers.parallel()) + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done before PO")) + .subscribeOn(Schedulers.boundedElastic()) .doOnNext(ByteBuf::retain) - .publishOn(Schedulers.parallel(), Integer.MAX_VALUE) + .publishOn(Schedulers.boundedElastic(), Integer.MAX_VALUE) + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done after PO")) .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::safeRelease) .transform( Operators.lift( - (__, actual) -> new ByteBufReleaserOperator(actual))); + (__, actual) -> { + bufReleaserOperator.actual = actual; + return bufReleaserOperator; + })); } @Override @@ -568,26 +745,136 @@ public ByteBufAllocator alloc() { return duplexConnection.alloc(); } + @Override + public SocketAddress remoteAddress() { + return duplexConnection.remoteAddress(); + } + @Override public Mono onClose() { - return duplexConnection.onClose(); + return Mono.whenDelayError( + duplexConnection + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")), + bufReleaserOperator + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] BufferReleaser is done"))); } @Override public void dispose() { duplexConnection.dispose(); } + + @Override + public String toString() { + return "AsyncDuplexConnection{" + + "duplexConnection=" + + duplexConnection + + ", tag='" + + tag + + '\'' + + ", bufReleaserOperator=" + + bufReleaserOperator + + '}'; + } + } + + private static class DisconnectingDuplexConnection implements DuplexConnection { + + private final String tag; + final DuplexConnection source; + final Duration delay; + final Disposable.Swap disposables = Disposables.swap(); + + DisconnectingDuplexConnection(String tag, DuplexConnection source, Duration delay) { + this.tag = tag; + this.source = source; + this.delay = delay; + } + + @Override + public void dispose() { + disposables.dispose(); + source.dispose(); + } + + @Override + public Mono onClose() { + return source + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) { + source.sendErrorAndClose(errorException); + } + + boolean receivedFirst; + + @Override + public Flux receive() { + return source + .receive() + .doOnSubscribe( + __ -> logger.warn("Tag {}. Subscribing Connection[{}]", tag, source.hashCode())) + .doOnNext( + bb -> { + if (!receivedFirst) { + receivedFirst = true; + disposables.replace( + Mono.delay(delay) + .takeUntilOther(source.onClose()) + .subscribe( + __ -> { + logger.warn( + "Tag {}. Disposing Connection[{}]", tag, source.hashCode()); + source.dispose(); + })); + } + }); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public String toString() { + return "DisconnectingDuplexConnection{" + + "tag='" + + tag + + '\'' + + ", source=" + + source + + ", disposables=" + + disposables + + '}'; + } } private static class ByteBufReleaserOperator implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { - final CoreSubscriber actual; + CoreSubscriber actual; + final Sinks.Empty closeableMonoSink; Subscription s; - public ByteBufReleaserOperator(CoreSubscriber actual) { - this.actual = actual; + public ByteBufReleaserOperator() { + this.closeableMonoSink = Sinks.unsafe().empty(); } @Override @@ -600,18 +887,27 @@ public void onSubscribe(Subscription s) { @Override public void onNext(ByteBuf buf) { - actual.onNext(buf); - buf.release(); + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + Mono onClose() { + return closeableMonoSink.asMono(); } @Override public void onError(Throwable t) { actual.onError(t); + closeableMonoSink.tryEmitError(t); } @Override public void onComplete() { actual.onComplete(); + closeableMonoSink.tryEmitEmpty(); } @Override @@ -622,6 +918,7 @@ public void request(long n) { @Override public void cancel() { s.cancel(); + closeableMonoSink.tryEmitEmpty(); } @Override @@ -648,6 +945,40 @@ public boolean isEmpty() { public void clear() { throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); } + + @Override + public String toString() { + return "ByteBufReleaserOperator{" + + "isActualPresent=" + + (actual != null) + + ", " + + "isSubscriptionPresent=" + + (s != null) + + '}'; + } + } + } + + class PayloadPredicate implements Predicate { + final int expectedCnt; + int cnt; + + public PayloadPredicate(int expectedCnt) { + this.expectedCnt = expectedCnt; + } + + @Override + public boolean test(Payload p) { + boolean shouldConsume = cnt++ < expectedCnt; + if (!shouldConsume) { + logger.info( + "Metadata: \n\r{}\n\rData:{}", + p.hasMetadata() + ? new ByteBufRepresentation().fallbackToStringOf(p.sliceMetadata()) + : "Empty", + new ByteBufRepresentation().fallbackToStringOf(p.sliceData())); + } + return shouldConsume; } } } diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle index a5ba84d5c..fc32125e2 100644 --- a/rsocket-transport-local/build.gradle +++ b/rsocket-transport-local/build.gradle @@ -17,8 +17,7 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' } dependencies { @@ -33,4 +32,10 @@ dependencies { testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.local") + } +} + description = 'Local RSocket transport implementation' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index b80fc2337..1b3779e85 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package io.rsocket.transport.local; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.internal.UnboundedProcessor; @@ -24,7 +23,7 @@ import io.rsocket.transport.ServerTransport; import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; /** * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} in the @@ -78,14 +77,17 @@ public Mono connect() { return Mono.error(new IllegalArgumentException("Could not find server: " + name)); } - UnboundedProcessor in = new UnboundedProcessor<>(); - UnboundedProcessor out = new UnboundedProcessor<>(); - MonoProcessor closeNotifier = MonoProcessor.create(); + Sinks.One inSink = Sinks.one(); + Sinks.One outSink = Sinks.one(); + UnboundedProcessor in = new UnboundedProcessor(inSink::tryEmitEmpty); + UnboundedProcessor out = new UnboundedProcessor(outSink::tryEmitEmpty); - server.apply(new LocalDuplexConnection(allocator, out, in, closeNotifier)).subscribe(); + Mono onClose = inSink.asMono().and(outSink.asMono()); - return Mono.just( - (DuplexConnection) new LocalDuplexConnection(allocator, in, out, closeNotifier)); + server.apply(new LocalDuplexConnection(name, allocator, out, in, onClose)).subscribe(); + + return Mono.just( + new LocalDuplexConnection(name, allocator, in, out, onClose)); }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index 026f30ced..c1d0fd2a3 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,40 +19,45 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; /** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ final class LocalDuplexConnection implements DuplexConnection { + private final LocalSocketAddress address; private final ByteBufAllocator allocator; - private final Flux in; + private final UnboundedProcessor in; - private final MonoProcessor onClose; + private final Mono onClose; - private final Subscriber out; + private final UnboundedProcessor out; /** * Creates a new instance. * + * @param name the name assigned to this local connection * @param in the inbound {@link ByteBuf}s * @param out the outbound {@link ByteBuf}s * @param onClose the closing notifier * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} */ LocalDuplexConnection( + String name, ByteBufAllocator allocator, - Flux in, - Subscriber out, - MonoProcessor onClose) { + UnboundedProcessor in, + UnboundedProcessor out, + Mono onClose) { + this.address = new LocalSocketAddress(name); this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); this.in = Objects.requireNonNull(in, "in must not be null"); this.out = Objects.requireNonNull(out, "out must not be null"); @@ -62,12 +67,11 @@ final class LocalDuplexConnection implements DuplexConnection { @Override public void dispose() { out.onComplete(); - onClose.onComplete(); } @Override public boolean isDisposed() { - return onClose.isDisposed(); + return out.isDisposed(); } @Override @@ -78,21 +82,23 @@ public Mono onClose() { @Override public Flux receive() { return in.transform( - Operators.lift((__, actual) -> new ByteBufReleaserOperator(actual))); + Operators.lift( + (__, actual) -> new ByteBufReleaserOperator(actual, this))); } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); - - return Flux.from(frames).doOnNext(out::onNext).then(); + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + out.tryEmitPrioritized(frame); + } else { + out.tryEmitNormal(frame); + } } @Override - public Mono sendOne(ByteBuf frame) { - Objects.requireNonNull(frame, "frame must not be null"); - out.onNext(frame); - return Mono.empty(); + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + out.tryEmitFinal(errorFrame); } @Override @@ -100,15 +106,28 @@ public ByteBufAllocator alloc() { return allocator; } + @Override + public SocketAddress remoteAddress() { + return address; + } + + @Override + public String toString() { + return "LocalDuplexConnection{" + "address=" + address + "hash=" + hashCode() + '}'; + } + static class ByteBufReleaserOperator implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { final CoreSubscriber actual; + final LocalDuplexConnection parent; Subscription s; - public ByteBufReleaserOperator(CoreSubscriber actual) { + public ByteBufReleaserOperator( + CoreSubscriber actual, LocalDuplexConnection parent) { this.actual = actual; + this.parent = parent; } @Override @@ -121,17 +140,22 @@ public void onSubscribe(Subscription s) { @Override public void onNext(ByteBuf buf) { - actual.onNext(buf); - buf.release(); + try { + actual.onNext(buf); + } finally { + buf.release(); + } } @Override public void onError(Throwable t) { + parent.out.onError(t); actual.onError(t); } @Override public void onComplete() { + parent.out.onComplete(); actual.onComplete(); } @@ -143,6 +167,7 @@ public void request(long n) { @Override public void cancel() { s.cancel(); + parent.out.onComplete(); } @Override diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index c07713cb3..975cb6793 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,18 @@ package io.rsocket.transport.local; import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.util.annotation.Nullable; /** @@ -33,7 +37,7 @@ */ public final class LocalServerTransport implements ServerTransport { - private static final ConcurrentMap registry = + private static final ConcurrentMap registry = new ConcurrentHashMap<>(); private final String name; @@ -71,7 +75,10 @@ public static LocalServerTransport createEphemeral() { */ public static void dispose(String name) { Objects.requireNonNull(name, "name must not be null"); - registry.remove(name); + ServerCloseableAcceptor sca = registry.remove(name); + if (sca != null) { + sca.dispose(); + } } /** @@ -106,44 +113,66 @@ public Mono start(ConnectionAcceptor acceptor) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return Mono.create( sink -> { - ServerCloseable closeable = new ServerCloseable(name, acceptor); - if (registry.putIfAbsent(name, acceptor) != null) { - throw new IllegalStateException("name already registered: " + name); + ServerCloseableAcceptor closeable = new ServerCloseableAcceptor(name, acceptor); + if (registry.putIfAbsent(name, closeable) != null) { + sink.error(new IllegalStateException("name already registered: " + name)); } sink.success(closeable); }); } - static class ServerCloseable implements Closeable { + @SuppressWarnings({"ReactorTransformationOnMonoVoid", "CallingSubscribeInNonBlockingScope"}) + static class ServerCloseableAcceptor implements ConnectionAcceptor, Closeable { private final LocalSocketAddress address; private final ConnectionAcceptor acceptor; - private final MonoProcessor onClose = MonoProcessor.create(); + private final Set activeConnections = ConcurrentHashMap.newKeySet(); + + private final Sinks.Empty onClose = Sinks.unsafe().empty(); - ServerCloseable(String name, ConnectionAcceptor acceptor) { + ServerCloseableAcceptor(String name, ConnectionAcceptor acceptor) { Objects.requireNonNull(name, "name must not be null"); this.address = new LocalSocketAddress(name); this.acceptor = acceptor; } + @Override + public Mono apply(DuplexConnection duplexConnection) { + activeConnections.add(duplexConnection); + duplexConnection + .onClose() + .doFinally(__ -> activeConnections.remove(duplexConnection)) + .subscribe(); + return acceptor.apply(duplexConnection); + } + @Override public void dispose() { - if (!registry.remove(address.getName(), acceptor)) { - throw new AssertionError(); + if (!registry.remove(address.getName(), this)) { + // already disposed + return; } - onClose.onComplete(); + + Mono.whenDelayError( + activeConnections + .stream() + .peek(DuplexConnection::dispose) + .map(DuplexConnection::onClose) + .collect(Collectors.toList())) + .subscribe(null, onClose::tryEmitError, onClose::tryEmitEmpty); } @Override + @SuppressWarnings("ConstantConditions") public boolean isDisposed() { - return onClose.isDisposed(); + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return onClose; + return onClose.asMono(); } } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java index d04fd482e..4d0da126a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import java.util.Objects; /** An implementation of {@link SocketAddress} representing a local connection. */ -final class LocalSocketAddress extends SocketAddress { +public final class LocalSocketAddress extends SocketAddress { private static final long serialVersionUID = -7513338854585475473L; @@ -32,16 +32,17 @@ final class LocalSocketAddress extends SocketAddress { * @param name the name representing the address * @throws NullPointerException if {@code name} is {@code null} */ - LocalSocketAddress(String name) { + public LocalSocketAddress(String name) { this.name = Objects.requireNonNull(name, "name must not be null"); } - @Override - public String toString() { - return "[local server] " + name; + /** Return the name for this connection. */ + public String getName() { + return name; } - String getName() { - return name; + @Override + public String toString() { + return "[local address] " + name; } } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java index ac4c13efe..095de3f0e 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -19,9 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import io.rsocket.Closeable; +import java.time.Duration; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; -import reactor.core.publisher.Mono; import reactor.test.StepVerifier; final class LocalClientTransportTest { @@ -31,12 +32,20 @@ final class LocalClientTransportTest { void connect() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); - serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect()) - .as(StepVerifier::create) - .expectNextCount(1) - .verifyComplete(); + Closeable closeable = + serverTransport.start(duplexConnection -> duplexConnection.receive().then()).block(); + + try { + LocalClientTransport.create(serverTransport.getName()) + .connect() + .doOnNext(d -> d.receive().subscribe()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + closeable.dispose(); + closeable.onClose().block(Duration.ofSeconds(5)); + } } @DisplayName("generates error if server not started") diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java new file mode 100644 index 000000000..28c1dacac --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..8ae16a0a5 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableWithFragmentationTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableWithFragmentationTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java index ed906f65b..e4edafc39 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -96,11 +96,16 @@ void named() { @DisplayName("starts local server transport") @Test void start() { - LocalServerTransport.createEphemeral() - .start(duplexConnection -> Mono.empty()) - .as(StepVerifier::create) - .expectNextCount(1) - .verifyComplete(); + LocalServerTransport ephemeral = LocalServerTransport.createEphemeral(); + try { + ephemeral + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + LocalServerTransport.dispose(ephemeral.getName()); + } } @DisplayName("start throws NullPointerException with null acceptor") diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java index e9c137255..87ad2105b 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,18 +19,25 @@ import io.rsocket.test.TransportTest; import java.time.Duration; import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; final class LocalTransportTest implements TransportTest { - private final TransportPair transportPair = - new TransportPair<>( - () -> "test-" + UUID.randomUUID(), - (address, server, allocator) -> LocalClientTransport.create(address, allocator), - (address, allocator) -> LocalServerTransport.create(address)); + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> "LocalTransportTest-" + testInfo.getDisplayName() + "-" + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address)); + } @Override public Duration getTimeout() { - return Duration.ofSeconds(10); + return Duration.ofMinutes(1); } @Override diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java index 4c2f47771..3ca5f5911 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,19 +19,30 @@ import io.rsocket.test.TransportTest; import java.time.Duration; import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; final class LocalTransportWithFragmentationTest implements TransportTest { - private final TransportPair transportPair = - new TransportPair<>( - () -> "test-" + UUID.randomUUID(), - (address, server, allocator) -> LocalClientTransport.create(address, allocator), - (address, allocator) -> LocalServerTransport.create(address), - true); + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalTransportWithFragmentationTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true); + } @Override public Duration getTimeout() { - return Duration.ofSeconds(10); + return Duration.ofMinutes(1); } @Override diff --git a/rsocket-transport-local/src/test/resources/logback-test.xml b/rsocket-transport-local/src/test/resources/logback-test.xml index 01a7fa4cd..5c92235c2 100644 --- a/rsocket-transport-local/src/test/resources/logback-test.xml +++ b/rsocket-transport-local/src/test/resources/logback-test.xml @@ -23,10 +23,27 @@ + + ./test-out.log + false + + %-5relative %-5level %logger{35} - %msg%n + + + + + + + + + + + - + + diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle index 70db87b3a..39a5ceac5 100644 --- a/rsocket-transport-netty/build.gradle +++ b/rsocket-transport-netty/build.gradle @@ -17,13 +17,12 @@ plugins { id 'java-library' id 'maven-publish' - id 'com.jfrog.artifactory' - id 'com.jfrog.bintray' + id 'signing' id "com.google.osdetector" version "1.4.0" } def os_suffix = "" -if (osdetector.classifier in ["linux-x86_64"] || ["osx-x86_64"] || ["windows-x86_64"]) { +if (osdetector.classifier in ["linux-x86_64", "linux-aarch_64", "osx-x86_64", "osx-aarch_64", "windows-x86_64"]) { os_suffix = "::" + osdetector.classifier } @@ -41,9 +40,21 @@ dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' + testRuntimeOnly 'org.bouncycastle:bcpkix-jdk15on' testRuntimeOnly 'ch.qos.logback:logback-classic' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' testRuntimeOnly 'io.netty:netty-tcnative-boringssl-static' + os_suffix } +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.netty") + } +} + +test { + minHeapSize = "512m" + maxHeapSize = "4096m" +} + description = 'Reactor Netty RSocket transport implementations (TCP, Websocket)' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index 80c8b8256..f5d36269c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,17 +19,19 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.frame.FrameLengthCodec; import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; /** An implementation of {@link DuplexConnection} that connects via TCP. */ public final class TcpDuplexConnection extends BaseDuplexConnection { - + private final String side; private final Connection connection; /** @@ -38,15 +40,19 @@ public final class TcpDuplexConnection extends BaseDuplexConnection { * @param connection the {@link Connection} for managing the server */ public TcpDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} for managing the server + */ + public TcpDuplexConnection(String side, Connection connection) { this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; - connection - .channel() - .closeFuture() - .addListener( - future -> { - if (!isDisposed()) dispose(); - }); + connection.outbound().send(sender).then().doFinally(__ -> connection.dispose()).subscribe(); } @Override @@ -54,11 +60,25 @@ public ByteBufAllocator alloc() { return connection.channel().alloc(); } + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + @Override protected void doOnClose() { - if (!connection.isDisposed()) { - connection.dispose(); - } + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(FrameLengthCodec.encode(alloc(), errorFrame.readableBytes(), errorFrame)); } @Override @@ -67,14 +87,12 @@ public Flux receive() { } @Override - public Mono send(Publisher frames) { - if (frames instanceof Mono) { - return connection.outbound().sendObject(((Mono) frames).map(this::encode)).then(); - } - return connection.outbound().send(Flux.from(frames).map(this::encode)).then(); + public void sendFrame(int streamId, ByteBuf frame) { + super.sendFrame(streamId, FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame)); } - private ByteBuf encode(ByteBuf frame) { - return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); + @Override + public String toString() { + return "TcpDuplexConnection{" + "side='" + side + '\'' + ", connection=" + connection + '}'; } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index a3745bd1f..8f1170c5b 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,11 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; import java.util.Objects; -import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; @@ -34,7 +36,7 @@ * stitched back on for frames received. */ public final class WebsocketDuplexConnection extends BaseDuplexConnection { - + private final String side; private final Connection connection; /** @@ -43,15 +45,24 @@ public final class WebsocketDuplexConnection extends BaseDuplexConnection { * @param connection the {@link Connection} to for managing the server */ public WebsocketDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(String side, Connection connection) { this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; connection - .channel() - .closeFuture() - .addListener( - future -> { - if (!isDisposed()) dispose(); - }); + .outbound() + .sendObject(sender.map(BinaryWebSocketFrame::new)) + .then() + .doFinally(__ -> connection.dispose()) + .subscribe(); } @Override @@ -59,11 +70,19 @@ public ByteBufAllocator alloc() { return connection.channel().alloc(); } + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + @Override protected void doOnClose() { - if (!connection.isDisposed()) { - connection.dispose(); - } + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); } @Override @@ -72,16 +91,19 @@ public Flux receive() { } @Override - public Mono send(Publisher frames) { - if (frames instanceof Mono) { - return connection - .outbound() - .sendObject(((Mono) frames).map(BinaryWebSocketFrame::new)) - .then(); - } - return connection - .outbound() - .sendObject(Flux.from(frames).map(BinaryWebSocketFrame::new)) - .then(); + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(errorFrame); + } + + @Override + public String toString() { + return "WebsocketDuplexConnection{" + + "side='" + + side + + '\'' + + ", connection=" + + connection + + '}'; } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index f64c6063c..84214b98c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -116,6 +116,6 @@ public Mono connect() { return client .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec(maxFrameLength))) .connect() - .map(TcpDuplexConnection::new); + .map(connection -> new TcpDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index fe66da50a..86be47893 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -172,6 +172,6 @@ public Mono connect() { .websocket(specBuilder.build()) .uri(path) .connect() - .map(WebsocketDuplexConnection::new); + .map(connection -> new WebsocketDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java index 5f04eb575..33cff28b4 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -24,10 +24,7 @@ abstract class BaseWebsocketServerTransport< private static final ChannelHandler pongHandler = new PongHandler(); static Function serverConfigurer = - server -> - server.tcpConfiguration( - tcpServer -> - tcpServer.doOnConnection(connection -> connection.addHandlerLast(pongHandler))); + server -> server.doOnConnection(connection -> connection.addHandlerLast(pongHandler)); final WebsocketServerSpec.Builder specBuilder = WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java index 3c8192eb3..7e98905ff 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -29,7 +29,7 @@ */ public final class CloseableChannel implements Closeable { - /** For 1.0 and 1.1 compatibility: remove when RSocket requires Reactor Netty 1.0+. */ + /** For forward compatibility: remove when RSocket compiles against Reactor 1.0. */ private static final Method channelAddressMethod; static { @@ -61,7 +61,7 @@ public final class CloseableChannel implements Closeable { public InetSocketAddress address() { try { return (InetSocketAddress) channel.address(); - } catch (NoSuchMethodError e) { + } catch (ClassCastException | NoSuchMethodError e) { try { return (InetSocketAddress) channelAddressMethod.invoke(this.channel); } catch (Exception ex) { diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index effc7bed5..32562c4a4 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -114,7 +114,7 @@ public Mono start(ConnectionAcceptor acceptor) { c -> { c.addHandlerLast(new RSocketLengthCodec(maxFrameLength)); acceptor - .apply(new TcpDuplexConnection(c)) + .apply(new TcpDuplexConnection("server", c)) .then(Mono.never()) .subscribe(c.disposeSubscriber()); }) diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 38344c472..db13720e7 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -80,6 +80,8 @@ public Mono start(ConnectionAcceptor acceptor) { public static BiFunction> newHandler( ConnectionAcceptor acceptor) { return (in, out) -> - acceptor.apply(new WebsocketDuplexConnection((Connection) in)).then(out.neverComplete()); + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 81ac8dcb6..4fe736fad 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -117,7 +117,7 @@ public Mono start(ConnectionAcceptor acceptor) { return response.sendWebsocket( (in, out) -> acceptor - .apply(new WebsocketDuplexConnection((Connection) in)) + .apply(new WebsocketDuplexConnection("server", (Connection) in)) .then(out.neverComplete()), specBuilder.build()); }) diff --git a/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json new file mode 100644 index 000000000..3a2baa440 --- /dev/null +++ b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json @@ -0,0 +1,16 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.RSocketLengthCodec" + }, + "name": "io.rsocket.transport.netty.RSocketLengthCodec", + "queryAllPublicMethods": true + }, + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler" + }, + "name": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler", + "queryAllPublicMethods": true + } +] \ No newline at end of file diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java new file mode 100644 index 000000000..f05713215 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java @@ -0,0 +1,190 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; +import reactor.util.retry.RetryBackoffSpec; + +/** + * Test case that reproduces the following GitHub Issue + */ +public class KeepaliveTest { + + private static final Logger LOG = LoggerFactory.getLogger(KeepaliveTest.class); + private static final int PORT = 23200; + + private CloseableChannel server; + + @BeforeEach + void setUp() { + server = createServer().block(); + } + + @AfterEach + void tearDown() { + server.dispose(); + server.onClose().block(); + } + + @Test + void keepAliveTest() { + RSocketClient rsocketClient = createClient(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketClient + .requestResponse(Mono.just(DefaultPayload.create(""))) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete())) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + @Test + void keepAliveTestLazy() { + Mono rsocketMono = createClientLazy(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketMono.flatMap( + rsocket -> + rsocket + .requestResponse(DefaultPayload.create("")) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete()))) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + private static Mono createServer() { + LOG.info("Starting server at port {}", PORT); + + TcpServer tcpServer = TcpServer.create().host("localhost").port(PORT); + + return RSocketServer.create( + (setupPayload, rSocket) -> { + rSocket + .onClose() + .doFirst(() -> LOG.info("Connected on server side.")) + .doOnTerminate(() -> LOG.info("Connection closed on server side.")) + .subscribe(); + + return Mono.just(new MyServerRsocket()); + }) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(tcpServer)) + .doOnNext(closeableChannel -> LOG.info("RSocket server started.")); + } + + private static RSocketClient createClient() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + Mono rsocketMono = + RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + + RSocketClient client = RSocketClient.from(rsocketMono); + + client + .source() + .doOnNext(r -> LOG.info("Got RSocket")) + .flatMap(RSocket::onClose) + .doOnError(err -> LOG.error("Error during onClose.", err)) + .retryWhen(reconnectSpec.apply("client-close")) + .doFirst(() -> LOG.info("Connected on client side.")) + .doOnTerminate(() -> LOG.info("Connection closed on client side.")) + .repeat() + .subscribe(); + + return client; + } + + private static Mono createClientLazy() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + return RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + } + + public static class MyServerRsocket implements RSocket { + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just("Pong").map(DefaultPayload::create); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java index 6fd3de791..76c352768 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.transport.netty; import io.rsocket.ConnectionSetupPayload; @@ -20,9 +35,9 @@ import java.util.function.Function; import java.util.stream.Stream; import org.junit.jupiter.params.provider.Arguments; -import reactor.core.publisher.EmitterProcessor; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; import reactor.test.StepVerifier; public class SetupRejectionTest { @@ -85,21 +100,21 @@ static Stream transports() { } static class ErrorConsumer implements Consumer { - private final EmitterProcessor errors = EmitterProcessor.create(); + private final Sinks.Many errors = Sinks.many().multicast().onBackpressureBuffer(); @Override public void accept(Throwable t) { - errors.onNext(t); + errors.tryEmitNext(t); } Flux errors() { - return errors; + return errors.asFlux(); } } private static class RejectingAcceptor implements SocketAcceptor { private final String msg; - private final EmitterProcessor requesters = EmitterProcessor.create(); + private final Sinks.Many requesters = Sinks.many().multicast().onBackpressureBuffer(); public RejectingAcceptor(String msg) { this.msg = msg; @@ -107,12 +122,12 @@ public RejectingAcceptor(String msg) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { - requesters.onNext(sendingSocket); + requesters.tryEmitNext(sendingSocket); return Mono.error(new RuntimeException(msg)); } public Mono requesterRSocket() { - return requesters.next(); + return requesters.asFlux().next(); } } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java index 299ea96c0..b17da654f 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,25 +22,31 @@ import io.rsocket.transport.netty.server.TcpServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.netty.tcp.TcpClient; import reactor.netty.tcp.TcpServer; final class TcpFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server, allocator) -> - TcpClientTransport.create( - TcpClient.create() - .remoteAddress(server::address) - .option(ChannelOption.ALLOCATOR, allocator)), - (address, allocator) -> - TcpServerTransport.create( + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( TcpServer.create() .bindAddress(() -> address) - .option(ChannelOption.ALLOCATOR, allocator)), - true); + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java new file mode 100644 index 000000000..7be1c1c54 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..39b3cec67 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java index 85481924a..ee49b83cd 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -1,3 +1,19 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.transport.netty; import io.netty.channel.ChannelOption; @@ -10,41 +26,47 @@ import java.net.InetSocketAddress; import java.security.cert.CertificateException; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.core.Exceptions; import reactor.netty.tcp.TcpClient; import reactor.netty.tcp.TcpServer; public class TcpSecureTransportTest implements TransportTest { - private final TransportPair transportPair = - new TransportPair<>( - () -> new InetSocketAddress("localhost", 0), - (address, server, allocator) -> - TcpClientTransport.create( - TcpClient.create() - .option(ChannelOption.ALLOCATOR, allocator) - .remoteAddress(server::address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE)))), - (address, allocator) -> { - try { - SelfSignedCertificate ssc = new SelfSignedCertificate(); - TcpServer server = - TcpServer.create() - .option(ChannelOption.ALLOCATOR, allocator) - .bindAddress(() -> address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forServer( - ssc.certificate(), ssc.privateKey()))); - return TcpServerTransport.create(server); - } catch (CertificateException e) { - throw Exceptions.propagate(e); - } - }); + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java index c474f9b0b..428681f3e 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,24 +22,30 @@ import io.rsocket.transport.netty.server.TcpServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.netty.tcp.TcpClient; import reactor.netty.tcp.TcpServer; final class TcpTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server, allocator) -> - TcpClientTransport.create( - TcpClient.create() - .remoteAddress(server::address) - .option(ChannelOption.ALLOCATOR, allocator)), - (address, allocator) -> - TcpServerTransport.create( + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( TcpServer.create() .bindAddress(() -> address) - .option(ChannelOption.ALLOCATOR, allocator))); + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java index e2ee9e521..ff0fa75b4 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package io.rsocket.transport.netty; import io.netty.buffer.Unpooled; @@ -25,8 +40,9 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -100,13 +116,13 @@ private static Stream provideServerTransport() { } private static class PingSender extends ChannelInboundHandlerAdapter { - private final MonoProcessor channel = MonoProcessor.create(); - private final MonoProcessor pong = MonoProcessor.create(); + private final Sinks.One channel = Sinks.one(); + private final Sinks.One pong = Sinks.one(); @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { if (msg instanceof PongWebSocketFrame) { - pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + pong.tryEmitValue(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); ReferenceCountUtil.safeRelease(msg); ctx.read(); } else { @@ -117,8 +133,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception @Override public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { Channel ch = ctx.channel(); - if (!channel.isTerminated() && ch.isWritable()) { - channel.onNext(ctx.channel()); + if (!(channel.scan(Scannable.Attr.TERMINATED)) && ch.isWritable()) { + channel.tryEmitValue(ctx.channel()); } super.channelWritabilityChanged(ctx); } @@ -127,7 +143,7 @@ public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exceptio public void handlerAdded(ChannelHandlerContext ctx) throws Exception { Channel ch = ctx.channel(); if (ch.isWritable()) { - channel.onNext(ch); + channel.tryEmitValue(ch); } super.handlerAdded(ctx); } @@ -142,11 +158,11 @@ public Mono sendPong() { } public Mono receivePong() { - return pong; + return pong.asMono(); } private Mono send(WebSocketFrame webSocketFrame) { - return channel.doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + return channel.asMono().doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); } } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java new file mode 100644 index 000000000..043f6bc64 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..b1ca65fcc --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index 9777c8bfa..81f7ffb95 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2020 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,45 +26,50 @@ import java.net.InetSocketAddress; import java.security.cert.CertificateException; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.core.Exceptions; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.HttpServer; final class WebsocketSecureTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> new InetSocketAddress("localhost", 0), - (address, server, allocator) -> - WebsocketClientTransport.create( - HttpClient.create() - .option(ChannelOption.ALLOCATOR, allocator) - .remoteAddress(server::address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forClient() - .trustManager(InsecureTrustManagerFactory.INSTANCE))), - String.format( - "https://%s:%d/", - server.address().getHostName(), server.address().getPort())), - (address, allocator) -> { - try { - SelfSignedCertificate ssc = new SelfSignedCertificate(); - HttpServer server = - HttpServer.create() - .option(ChannelOption.ALLOCATOR, allocator) - .bindAddress(() -> address) - .secure( - ssl -> - ssl.sslContext( - SslContextBuilder.forServer( - ssc.certificate(), ssc.privateKey()))); - return WebsocketServerTransport.create(server); - } catch (CertificateException e) { - throw Exceptions.propagate(e); - } - }); + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE))), + String.format( + "https://%s:%d/", + server.address().getHostName(), server.address().getPort())), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + HttpServer server = + HttpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return WebsocketServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java index 93d7bdb2f..cdd507456 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,27 +22,33 @@ import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.net.InetSocketAddress; import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; import reactor.netty.http.client.HttpClient; import reactor.netty.http.server.HttpServer; final class WebsocketTransportTest implements TransportTest { + private TransportPair transportPair; - private final TransportPair transportPair = - new TransportPair<>( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server, allocator) -> - WebsocketClientTransport.create( - HttpClient.create() - .host(server.address().getHostName()) - .port(server.address().getPort()) - .option(ChannelOption.ALLOCATOR, allocator), - ""), - (address, allocator) -> - WebsocketServerTransport.create( + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( HttpServer.create() .host(address.getHostName()) .port(address.getPort()) - .option(ChannelOption.ALLOCATOR, allocator))); + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } @Override public Duration getTimeout() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java index 308118955..bd53a9b3f 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java @@ -19,7 +19,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; @@ -57,9 +56,6 @@ void constructorNullContext() { .withMessage("channel must not be null"); } - @Disabled( - "NettyContext isDisposed() is not accurate\n" - + "https://github.com/reactor/reactor-netty/issues/360") @DisplayName("disposes context") @Test void dispose() { diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index b9b6201b8..540076704 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -41,7 +41,7 @@ public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { ArgumentCaptor httpHandlerCaptor = ArgumentCaptor.forClass(BiFunction.class); HttpServer server = Mockito.spy(HttpServer.create()); Mockito.doAnswer(a -> server).when(server).handle(httpHandlerCaptor.capture()); - Mockito.doAnswer(a -> server).when(server).tcpConfiguration(any()); + Mockito.doAnswer(a -> server).when(server).doOnConnection(any()); Mockito.doAnswer(a -> Mono.empty()).when(server).bind(); WebsocketServerTransport serverTransport = WebsocketServerTransport.create(server); diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index f9dec2bbe..981d6d0b6 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -26,6 +26,14 @@ + + + + + + + +