diff --git a/.travis.yml b/.travis.yml index 2743ae0bc..4722957c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,9 +16,15 @@ --- language: java -jdk: -- oraclejdk8 -# - oraclejdk9 +dist: trusty + +matrix: + include: + - jdk: openjdk8 + - jdk: openjdk11 + env: SKIP_RELEASE=true + - jdk: openjdk14 + env: SKIP_RELEASE=true env: global: @@ -29,12 +35,6 @@ env: script: ci/travis.sh -addons: - apt: - packages: - - oracle-java8-installer - # - oracle-java9-installer - before_cache: - rm -f $HOME/.gradle/caches/modules-2/modules-2.lock - rm -fr $HOME/.gradle/caches/*/plugin-resolution/ @@ -43,4 +43,3 @@ cache: directories: - $HOME/.gradle/caches/ - $HOME/.gradle/wrapper/ - diff --git a/AUTHORS b/AUTHORS index 89f6e3696..ef7dd9dda 100644 --- a/AUTHORS +++ b/AUTHORS @@ -18,3 +18,4 @@ somasun = somasun stevegury = Steve Gury tmontgomery = Todd L. Montgomery yschimke = Yuri Schimke +OlegDokuka = Oleh Dokuka diff --git a/README.md b/README.md index 8a9f85e11..f8110a31e 100644 --- a/README.md +++ b/README.md @@ -15,18 +15,33 @@ Learn more at http://rsocket.io ## Build and Binaries -[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=1.0.x)](https://travis-ci.org/rsocket/rsocket-java) +[![Build Status](https://travis-ci.org/rsocket/rsocket-java.svg?branch=develop)](https://travis-ci.org/rsocket/rsocket-java) Releases are available via Maven Central. Example: ```groovy +repositories { + mavenCentral() +} +dependencies { + implementation 'io.rsocket:rsocket-core:1.0.0' + implementation 'io.rsocket:rsocket-transport-netty:1.0.0' +} +``` + +Snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). + +Example: + +```groovy +repositories { + maven { url 'https://oss.jfrog.org/oss-snapshot-local' } +} dependencies { - implementation 'io.rsocket:rsocket-core:0.11.14' - implementation 'io.rsocket:rsocket-transport-netty:0.11.14' -// implementation 'io.rsocket:rsocket-core:0.11.15.BUILD-SNAPSHOT' -// implementation 'io.rsocket:rsocket-transport-netty:0.11.15.BUILD-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.0.1-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.0.1-SNAPSHOT' } ``` @@ -52,12 +67,12 @@ Frames can be printed out to help debugging. Set the logger `io.rsocket.FrameLog ## Trivial Client -``` +```java package io.rsocket.transport.netty; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import reactor.core.publisher.Flux; @@ -67,14 +82,14 @@ import java.net.URI; public class ExampleClient { public static void main(String[] args) { WebsocketClientTransport ws = WebsocketClientTransport.create(URI.create("ws://rsocket-demo.herokuapp.com/ws")); - RSocket client = RSocketFactory.connect().keepAlive().transport(ws).start().block(); + RSocket clientRSocket = RSocketConnector.connectWith(ws).block(); try { - Flux s = client.requestStream(DefaultPayload.create("peace")); + Flux s = clientRSocket.requestStream(DefaultPayload.create("peace")); s.take(10).doOnNext(p -> System.out.println(p.getDataUtf8())).blockLast(); } finally { - client.dispose(); + clientRSocket.dispose(); } } } @@ -89,12 +104,10 @@ or you will get a memory leak. Used correctly this will reduce latency and incre ### Example Server setup ```java -RSocketFactory.receive() +RSocketServer.create(new PingHandler()) // Enable Zero Copy - .payloadDecoder(Frame::retain) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(7878)) - .start() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(7878)) .block() .onClose() .block(); @@ -102,12 +115,12 @@ RSocketFactory.receive() ### Example Client setup ```java -Mono client = - RSocketFactory.connect() +RSocket clientRSocket = + RSocketConnector.create() // Enable Zero Copy - .payloadDecoder(Frame::retain) - .transport(TcpClientTransport.create(7878)) - .start(); + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(TcpClientTransport.create(7878)) + .block(); ``` ## Bugs and Feedback diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..6ba6755a6 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,47 @@ +## Usage of JMH tasks + +Only execute specific benchmark(s) (wildcards are added before and after): +``` +../gradlew jmh --include="(BenchmarkPrimary|OtherBench)" +``` +If you want to specify the wildcards yourself, you can pass the full regexp: +``` +../gradlew jmh --fullInclude=.*MyBenchmark.* +``` + +Specify extra profilers: +``` +../gradlew jmh --profilers="gc,stack" +``` + +Prominent profilers (for full list call `jmhProfilers` task): +- comp - JitCompilations, tune your iterations +- stack - which methods used most time +- gc - print garbage collection stats +- hs_thr - thread usage + +Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: +``` +./gradlew jmh --format=csv +``` + +Specify JVM arguments: +``` +../gradlew jmh --jvmArgs="-Dtest.cluster=local" +``` + +Run in verification mode (execute benchmarks with minimum of fork/warmup-/benchmark-iterations): +``` +../gradlew jmh --verify=true +``` + +## Comparing with the baseline +If you wish you run two sets of benchmarks, one for the current change and another one for the "baseline", +there is an additional task `jmhBaseline` that will use the latest release: +``` +../gradlew jmh jmhBaseline --include=MyBenchmark +``` + +## Resources +- http://tutorials.jenkov.com/java-performance/jmh.html (Introduction) +- http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ (Samples) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle new file mode 100644 index 000000000..f07f7c6f5 --- /dev/null +++ b/benchmarks/build.gradle @@ -0,0 +1,167 @@ +apply plugin: 'java' +apply plugin: 'idea' + +configurations { + current + baseline { + resolutionStrategy.cacheChangingModulesFor 0, 'seconds' + } +} + +dependencies { + // Use the baseline to avoid using new APIs in the benchmarks + compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + + implementation "org.openjdk.jmh:jmh-core:1.21" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.21" + + current project(':rsocket-core') + current project(':rsocket-transport-local') + baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { + changing = true + } + baseline "io.rsocket:rsocket-transport-local:${perfBaselineVersion}", { + changing = true + } +} + +task jmhProfilers(type: JavaExec, description:'Lists the available profilers for the jmh task', group: 'Development') { + classpath = sourceSets.main.runtimeClasspath + main = 'org.openjdk.jmh.Main' + args '-lprof' +} + +task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.current +} + +task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.baseline +} + +clean { + delete "${projectDir}/src/main/generated" +} + +class JmhExecTask extends JavaExec { + + private String include; + private String fullInclude; + private String exclude; + private String format = "json"; + private String profilers; + private String jmhJvmArgs; + private String verify; + + public JmhExecTask() { + super(); + } + + public String getInclude() { + return include; + } + + @Option(option = "include", description="configure bench inclusion using substring") + public void setInclude(String include) { + this.include = include; + } + + public String getFullInclude() { + return fullInclude; + } + + @Option(option = "fullInclude", description = "explicitly configure bench inclusion using full JMH style regexp") + public void setFullInclude(String fullInclude) { + this.fullInclude = fullInclude; + } + + public String getExclude() { + return exclude; + } + + @Option(option = "exclude", description = "explicitly configure bench exclusion using full JMH style regexp") + public void setExclude(String exclude) { + this.exclude = exclude; + } + + public String getFormat() { + return format; + } + + @Option(option = "format", description = "configure report format") + public void setFormat(String format) { + this.format = format; + } + + public String getProfilers() { + return profilers; + } + + @Option(option = "profilers", description = "configure jmh profiler(s) to use, comma separated") + public void setProfilers(String profilers) { + this.profilers = profilers; + } + + public String getJmhJvmArgs() { + return jmhJvmArgs; + } + + @Option(option = "jvmArgs", description = "configure additional JMH JVM arguments, comma separated") + public void setJmhJvmArgs(String jvmArgs) { + this.jmhJvmArgs = jvmArgs; + } + + public String getVerify() { + return verify; + } + + @Option(option = "verify", description = "run in verify mode") + public void setVerify(String verify) { + this.verify = verify; + } + + @TaskAction + public void exec() { + setMain("org.openjdk.jmh.Main"); + File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); + + if (include != null) { + args(".*" + include + ".*"); + } + else if (fullInclude != null) { + args(fullInclude); + } + + if(exclude != null) { + args("-e", exclude); + } + if(verify != null) { // execute benchmarks with the minimum amount of execution (only to check if they are working) + System.out.println("Running in verify mode"); + args("-f", 1); + args("-wi", 1); + args("-i", 1); + } + args("-foe", "true"); //fail-on-error + args("-v", "NORMAL"); //verbosity [SILENT, NORMAL, EXTRA] + if(profilers != null) { + for (String prof : profilers.split(",")) { + args("-prof", prof); + } + } + args("-jvmArgsPrepend", "-Xmx3072m"); + args("-jvmArgsPrepend", "-Xms3072m"); + if(jmhJvmArgs != null) { + for(String jvmArg : jmhJvmArgs.split(" ")) { + args("-jvmArgsPrepend", jvmArg); + } + } + args("-rf", format); + args("-rff", resultFile); + + System.out.println("\nExecuting JMH with: " + getArgs() + "\n"); + resultFile.getParentFile().mkdirs(); + + super.exec(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java new file mode 100644 index 000000000..2e6fa6acc --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java @@ -0,0 +1,37 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class MaxPerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + public MaxPerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java new file mode 100644 index 000000000..7a7a1fdd6 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsMaxPerfSubscriber extends MaxPerfSubscriber { + + public PayloadsMaxPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java new file mode 100644 index 000000000..efc116958 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsPerfSubscriber extends PerfSubscriber { + + public PayloadsPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java new file mode 100644 index 000000000..92577d95c --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java @@ -0,0 +1,41 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class PerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + Subscription s; + + public PerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + s.request(1); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + s.request(1); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java new file mode 100644 index 000000000..f78843f5b --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java @@ -0,0 +1,170 @@ +package io.rsocket.core; + +import io.rsocket.AbstractRSocket; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.PayloadsMaxPerfSubscriber; +import io.rsocket.PayloadsPerfSubscriber; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.EmptyPayload; +import java.lang.reflect.Field; +import java.util.Queue; +import java.util.concurrent.locks.LockSupport; +import java.util.stream.IntStream; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Publisher; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10, time = 20) +@State(Scope.Benchmark) +public class RSocketPerf { + + static final Payload PAYLOAD = EmptyPayload.INSTANCE; + static final Mono PAYLOAD_MONO = Mono.just(PAYLOAD); + static final Flux PAYLOAD_FLUX = + Flux.fromArray(IntStream.range(0, 100000).mapToObj(__ -> PAYLOAD).toArray(Payload[]::new)); + + RSocket client; + Closeable server; + Queue clientsQueue; + + @TearDown + public void tearDown() { + client.dispose(); + server.dispose(); + } + + @TearDown(Level.Iteration) + public void awaitToBeConsumed() { + while (!clientsQueue.isEmpty()) { + LockSupport.parkNanos(1000); + } + } + + @Setup + public void setUp() throws NoSuchFieldException, IllegalAccessException { + server = + RSocketServer.create( + (setup, sendingSocket) -> + Mono.just( + new AbstractRSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return PAYLOAD_MONO; + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return PAYLOAD_FLUX; + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("server")) + .block(); + + client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("server")) + .block(); + + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } + + @Benchmark + @SuppressWarnings("unchecked") + public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.fireAndForget(PAYLOAD).subscribe((CoreSubscriber) subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestResponse(PAYLOAD).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestStream(PAYLOAD).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestStream(PAYLOAD).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } +} diff --git a/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java new file mode 100644 index 000000000..6b4f3f624 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java @@ -0,0 +1,43 @@ +package io.rsocket.core; + +import io.netty.util.collection.IntObjectMap; +import io.rsocket.internal.SynchronizedIntObjectHashMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class StreamIdSupplierPerf { + @Benchmark + public void benchmarkStreamId(Input input) { + int i = input.supplier.nextStreamId(input.map); + input.bh.consume(i); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + IntObjectMap map; + StreamIdSupplier supplier; + + @Setup + public void setup(Blackhole bh) { + this.supplier = StreamIdSupplier.clientSupplier(); + this.bh = bh; + this.map = new SynchronizedIntObjectHashMap(); + } + } +} diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java similarity index 76% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java index 139114466..b4ac808d0 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java @@ -16,7 +16,7 @@ public class FrameHeaderFlyweightPerf { @Benchmark public void encode(Input input) { - ByteBuf byteBuf = FrameHeaderFlyweight.encodeStreamZero(input.allocator, FrameType.SETUP, 0); + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(input.allocator, FrameType.SETUP, 0); boolean release = byteBuf.release(); input.bh.consume(release); } @@ -24,9 +24,9 @@ public void encode(Input input) { @Benchmark public void decode(Input input) { ByteBuf frame = input.frame; - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - int streamId = FrameHeaderFlyweight.streamId(frame); - int flags = FrameHeaderFlyweight.flags(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + int flags = FrameHeaderCodec.flags(frame); input.bh.consume(streamId); input.bh.consume(flags); input.bh.consume(frameType); @@ -44,7 +44,7 @@ public void setup(Blackhole bh) { this.bh = bh; this.frameType = FrameType.REQUEST_RESPONSE; allocator = ByteBufAllocator.DEFAULT; - frame = FrameHeaderFlyweight.encode(allocator, 123, FrameType.SETUP, 0); + frame = FrameHeaderCodec.encode(allocator, 123, FrameType.SETUP, 0); } @TearDown diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java similarity index 90% rename from rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java index 89e50a6e9..01d82a08f 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java @@ -18,7 +18,7 @@ public class PayloadFlyweightPerf { @Benchmark public void encode(Input input) { ByteBuf encode = - PayloadFrameFlyweight.encode( + PayloadFrameCodec.encode( input.allocator, 100, false, @@ -33,8 +33,8 @@ public void encode(Input input) { @Benchmark public void decode(Input input) { ByteBuf frame = input.payload; - ByteBuf data = PayloadFrameFlyweight.data(frame); - ByteBuf metadata = PayloadFrameFlyweight.metadata(frame); + ByteBuf data = PayloadFrameCodec.data(frame); + ByteBuf metadata = PayloadFrameCodec.metadata(frame); input.bh.consume(data); input.bh.consume(metadata); } @@ -57,7 +57,7 @@ public void setup(Blackhole bh) { // Encode a payload and then copy it a single bytebuf payload = allocator.buffer(); ByteBuf encode = - PayloadFrameFlyweight.encode( + PayloadFrameCodec.encode( allocator, 100, false, diff --git a/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java new file mode 100644 index 000000000..8f429fc19 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java @@ -0,0 +1,96 @@ +package io.rsocket.metadata; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork(value = 1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class WellKnownMimeTypePerf { + + // this is the old values() looping implementation of fromIdentifier + private WellKnownMimeType fromIdValuesLoop(int id) { + if (id < 0 || id > 127) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (value.getIdentifier() == id) { + return value; + } + } + return WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE; + } + + // this is the core of the old values() looping implementation of fromString + private WellKnownMimeType fromStringValuesLoop(String mimeType) { + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (mimeType.equals(value.getString())) { + return value; + } + } + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + + @Benchmark + public void fromIdArrayLookup(final Blackhole bh) { + // negative lookup + bh.consume(WellKnownMimeType.fromIdentifier(-10)); + bh.consume(WellKnownMimeType.fromIdentifier(-1)); + // too large lookup + bh.consume(WellKnownMimeType.fromIdentifier(129)); + // first lookup + bh.consume(WellKnownMimeType.fromIdentifier(0)); + // middle lookup + bh.consume(WellKnownMimeType.fromIdentifier(37)); + // reserved lookup + bh.consume(WellKnownMimeType.fromIdentifier(63)); + // last lookup + bh.consume(WellKnownMimeType.fromIdentifier(127)); + } + + @Benchmark + public void fromIdValuesLoopLookup(final Blackhole bh) { + // negative lookup + bh.consume(fromIdValuesLoop(-10)); + bh.consume(fromIdValuesLoop(-1)); + // too large lookup + bh.consume(fromIdValuesLoop(129)); + // first lookup + bh.consume(fromIdValuesLoop(0)); + // middle lookup + bh.consume(fromIdValuesLoop(37)); + // reserved lookup + bh.consume(fromIdValuesLoop(63)); + // last lookup + bh.consume(fromIdValuesLoop(127)); + } + + @Benchmark + public void fromStringMapLookup(final Blackhole bh) { + // unknown lookup + bh.consume(WellKnownMimeType.fromString("foo/bar")); + // first lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + WellKnownMimeType.fromString( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } + + @Benchmark + public void fromStringValuesLoopLookup(final Blackhole bh) { + // unknown lookup + bh.consume(fromStringValuesLoop("foo/bar")); + // first lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + fromStringValuesLoop(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } +} diff --git a/build.gradle b/build.gradle index d8919e8ea..f579b3ae0 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,13 +15,11 @@ */ plugins { - id 'com.gradle.build-scan' version '1.16' - - id 'com.github.sherter.google-java-format' version '0.7.1' apply false - id 'com.jfrog.artifactory' version '4.7.3' apply false + id 'com.github.sherter.google-java-format' version '0.8' apply false + id 'com.jfrog.artifactory' version '4.11.0' apply false id 'com.jfrog.bintray' version '1.8.4' apply false - id 'me.champeau.gradle.jmh' version '0.4.7' apply false - id 'io.spring.dependency-management' version '1.0.6.RELEASE' apply false + id 'me.champeau.gradle.jmh' version '0.5.0' apply false + id 'io.spring.dependency-management' version '1.0.8.RELEASE' apply false id 'io.morethan.jmhreport' version '0.9.0' apply false } @@ -29,68 +27,78 @@ subprojects { apply plugin: 'io.spring.dependency-management' apply plugin: 'com.github.sherter.google-java-format' - ext['reactor-bom.version'] = 'Californium-SR5' + ext['reactor-bom.version'] = 'Dysprosium-SR7' ext['logback.version'] = '1.2.3' ext['findbugs.version'] = '3.0.2' - ext['netty.version'] = '4.1.31.Final' - ext['netty-boringssl.version'] = '2.0.18.Final' + ext['netty-bom.version'] = '4.1.48.Final' + ext['netty-boringssl.version'] = '2.0.30.Final' ext['hdrhistogram.version'] = '2.1.10' - ext['mockito.version'] = '2.23.0' + ext['mockito.version'] = '3.2.0' ext['slf4j.version'] = '1.7.25' ext['jmh.version'] = '1.21' - ext['junit.version'] = '5.1.0' + ext['junit.version'] = '5.5.2' ext['hamcrest.version'] = '1.3' ext['micrometer.version'] = '1.0.6' ext['assertj.version'] = '3.11.1' + group = "io.rsocket" + googleJavaFormat { toolVersion = '1.6' } - + + ext { + if (project.hasProperty('versionSuffix')) { + project.version += project.getProperty('versionSuffix') + } + } + dependencyManagement { imports { mavenBom "io.projectreactor:reactor-bom:${ext['reactor-bom.version']}" + mavenBom "io.netty:netty-bom:${ext['netty-bom.version']}" + mavenBom "org.junit:junit-bom:${ext['junit.version']}" } dependencies { dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" - dependency "com.google.code.findbugs:jsr305:${ext['findbugs.version']}" - dependency "io.netty:netty-buffer:${ext['netty.version']}" dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}" dependency "org.assertj:assertj-core:${ext['assertj.version']}" dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}" - dependency "org.mockito:mockito-core:${ ext['mockito.version']}" dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}" - - dependencySet(group: 'org.junit.jupiter', version: ext['junit.version']) { - entry 'junit-jupiter-api' - entry 'junit-jupiter-engine' - entry 'junit-jupiter-params' + dependencySet(group: 'org.mockito', version: ext['mockito.version']) { + entry 'mockito-junit-jupiter' + entry 'mockito-core' } - // TODO: Remove after JUnit5 migration dependency 'junit:junit:4.12' dependency "org.hamcrest:hamcrest-library:${ext['hamcrest.version']}" - dependencySet(group: 'org.junit.vintage', version: ext['junit.version']) { - entry 'junit-vintage-engine' - } - dependencySet(group: 'org.openjdk.jmh', version: ext['jmh.version']) { entry 'jmh-core' entry 'jmh-generator-annprocess' } } + generatedPomCustomization { + enabled = false + } } repositories { mavenCentral() - if (version.endsWith('BUILD-SNAPSHOT') || project.hasProperty('platformVersion')) { + if (version.endsWith('SNAPSHOT') || project.hasProperty('platformVersion')) { maven { url 'http://repo.spring.io/libs-snapshot' } + maven { + url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' + } } } + tasks.withType(GenerateModuleMetadata) { + enabled = false + } + plugins.withType(JavaPlugin) { compileJava { sourceCompatibility = 1.8 @@ -100,18 +108,62 @@ subprojects { } javadoc { + def jdk = JavaVersion.current().majorVersion + def jdkJavadoc = "https://docs.oracle.com/javase/$jdk/docs/api/" + if (JavaVersion.current().isJava11Compatible()) { + jdkJavadoc = "https://docs.oracle.com/en/java/javase/$jdk/docs/api/" + } options.with { - links 'https://docs.oracle.com/javase/8/docs/api/' + links jdkJavadoc links 'https://projectreactor.io/docs/core/release/api/' links 'https://netty.io/4.1/api/' } } + tasks.named("javadoc").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + test { useJUnitPlatform() systemProperty "io.netty.leakDetection.level", "ADVANCED" } + + //all test tasks will show FAILED for each test method, + // common exclusions, no scanning + project.tasks.withType(Test).all { + testLogging { + events "FAILED" + showExceptions true + exceptionFormat "FULL" + stackTraceFilters "ENTRY_POINT" + maxGranularity 3 + } + + if (JavaVersion.current().isJava9Compatible()) { + println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" + jvmArgs = ["-XX:MaxGCPauseMillis=20"] + } + + systemProperty("java.awt.headless", "true") + systemProperty("reactor.trace.cancel", "true") + systemProperty("reactor.trace.nocapacity", "true") + systemProperty("testGroups", project.properties.get("testGroups")) + scanForTestClasses = false + exclude '**/*Abstract*.*' + + //allow re-run of failed tests only without special test tasks failing + // because the filter is too restrictive + filter.setFailOnNoMatchingTests(false) + + //display intermediate results for special test tasks + afterSuite { desc, result -> + if (!desc.parent) { // will match the outermost suite + println('\n' + "${desc} Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") + } + } + } } plugins.withType(JavaLibraryPlugin) { @@ -124,58 +176,14 @@ subprojects { classifier 'javadoc' from javadoc.destinationDir } - } - plugins.withType(MavenPublishPlugin) { - publishing { - publications { - maven(MavenPublication) { - groupId 'io.rsocket' - - from components.java - - artifact sourcesJar - artifact javadocJar - - pom.withXml { - asNode().':version' + { - resolveStrategy = DELEGATE_FIRST - - name project.name - description project.description - url 'http://rsocket.io' - - licenses { - license { - name 'The Apache Software License, Version 2.0' - url 'http://www.apache.org/license/LICENSE-2.0.txt' - } - } - - developers { - developer { - id 'robertroeser' - name 'Robert Roeser' - email 'robert@netifi.com' - } - developer { - id 'rdegnan' - name 'Ryland Degnan' - email 'ryland@netifi.com' - } - developer { - id 'yschimke' - name 'Yuri Schimke' - email 'yuri@schimke.ee' - } - } - - scm { - connection 'scm:git:https://github.com/rsocket/rsocket-java.git' - developerConnection 'scm:git:https://github.com/rsocket/rsocket-java.git' - url 'https://github.com/rsocket/rsocket-java' - } - } + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + from components.java + artifact sourcesJar + artifact javadocJar } } } @@ -183,8 +191,7 @@ subprojects { } } -apply from: 'artifactory.gradle' -apply from: 'bintray.gradle' +apply from: "${rootDir}/gradle/publications.gradle" buildScan { termsOfServiceUrl = 'https://gradle.com/terms-of-service' diff --git a/ci/travis.sh b/ci/travis.sh index 372c01070..d190a59ec 100755 --- a/ci/travis.sh +++ b/ci/travis.sh @@ -5,12 +5,24 @@ if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then echo -e "Building PR #$TRAVIS_PULL_REQUEST [$TRAVIS_PULL_REQUEST_SLUG/$TRAVIS_PULL_REQUEST_BRANCH => $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH]" ./gradlew build +elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] && [ "$TRAVIS_BRANCH" == "develop" ] ; then + + echo -e "Building Develop Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" + ./gradlew \ + -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ + -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PversionSuffix="-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ + build artifactoryPublish --stacktrace + elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then - echo -e "Building Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH" + echo -e "Building Branch Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" ./gradlew \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PversionSuffix="-${TRAVIS_BRANCH//\//-}-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build artifactoryPublish --stacktrace elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bintrayUser" != "" ] ; then @@ -20,6 +32,7 @@ elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bin -Pversion="$TRAVIS_TAG" \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build bintrayUpload --stacktrace else diff --git a/gradle.properties b/gradle.properties index 27abb8cac..b0b107ec4 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,5 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -version=0.11.16-SNAPSHOT +version=1.0.1 +perfBaselineVersion=1.0.0 diff --git a/artifactory.gradle b/gradle/artifactory.gradle similarity index 77% rename from artifactory.gradle rename to gradle/artifactory.gradle index 6e622a610..cdffb2741 100644 --- a/artifactory.gradle +++ b/gradle/artifactory.gradle @@ -31,10 +31,17 @@ if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey')) { } defaults { - publications('maven') + publications(publishing.publications.maven) + } + + if (project.hasProperty('buildNumber')) { + clientConfig.info.setBuildNumber(project.property('buildNumber').toString()) } } } + tasks.named("artifactoryPublish").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } } } } diff --git a/bintray.gradle b/gradle/bintray.gradle similarity index 94% rename from bintray.gradle rename to gradle/bintray.gradle index 6fe0db84b..5015f94e4 100644 --- a/bintray.gradle +++ b/gradle/bintray.gradle @@ -55,6 +55,9 @@ if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey') && } } } + tasks.named("bintrayUpload").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } } } } diff --git a/gradle/publications.gradle b/gradle/publications.gradle new file mode 100644 index 000000000..b12d9e9c2 --- /dev/null +++ b/gradle/publications.gradle @@ -0,0 +1,68 @@ +apply from: "${rootDir}/gradle/artifactory.gradle" +apply from: "${rootDir}/gradle/bintray.gradle" + +subprojects { + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + pom { + name = project.name + afterEvaluate { + description = project.description + } + groupId = 'io.rsocket' + url = 'http://rsocket.io' + licenses { + license { + name = "The Apache Software License, Version 2.0" + url = "https://www.apache.org/licenses/LICENSE-2.0.txt" + distribution = "repo" + } + } + developers { + developer { + id = 'robertroeser' + name = 'Robert Roeser' + email = 'robert@netifi.com' + } + developer { + id = 'rdegnan' + name = 'Ryland Degnan' + email = 'ryland@netifi.com' + } + developer { + id = 'yschimke' + name = 'Yuri Schimke' + email = 'yuri@schimke.ee' + } + developer { + id = 'OlegDokuka' + name = 'Oleh Dokuka' + email = 'oleh@netifi.com' + } + developer { + id = 'mostroverkhov' + name = 'Maksym Ostroverkhov' + email = 'm.ostroverkhov@gmail.com' + } + } + scm { + connection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + developerConnection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + url = 'https://github.com/rsocket/rsocket-java' + } + versionMapping { + usage('java-api') { + fromResolutionResult() + } + usage('java-runtime') { + fromResolutionResult() + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index f6b961fd5..5c2d1cf01 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index e0b3fb8d7..a4b442974 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-4.10.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.3-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index cccdd3d51..83f2acfdc 100755 --- a/gradlew +++ b/gradlew @@ -1,5 +1,21 @@ #!/usr/bin/env sh +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + ############################################################################## ## ## Gradle start up script for UN*X @@ -28,7 +44,7 @@ APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" @@ -109,8 +125,8 @@ if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` diff --git a/gradlew.bat b/gradlew.bat index e95643d6a..24467a141 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,3 +1,19 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle new file mode 100755 index 000000000..2efc20a91 --- /dev/null +++ b/rsocket-bom/build.gradle @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java-platform' + id 'maven-publish' + id 'com.jfrog.artifactory' + id 'com.jfrog.bintray' +} + +description = 'RSocket Java Bill of materials.' + +def excluded = ["rsocket-examples", "benchmarks"] + +dependencies { + constraints { + parent.subprojects.findAll { it.name != project.name && !excluded.contains(it.name) } .sort { "$it.name" }.each { + api it + } + } +} + +publishing { + publications { + maven(MavenPublication) { + from components.javaPlatform + } + } +} \ No newline at end of file diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index d62452619..41adbd7a8 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -29,8 +29,6 @@ dependencies { implementation 'org.slf4j:slf4j-api' - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' @@ -46,6 +44,4 @@ dependencies { testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' } -description = "Core functionality for the RSocket library" - -apply from: 'jmh.gradle' +description = "Core functionality for the RSocket library" \ No newline at end of file diff --git a/rsocket-core/jmh.gradle b/rsocket-core/jmh.gradle deleted file mode 100644 index b7cd10a13..000000000 --- a/rsocket-core/jmh.gradle +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -dependencies { - jmh configurations.api - jmh configurations.implementation - jmh 'org.openjdk.jmh:jmh-core' - jmh 'org.openjdk.jmh:jmh-generator-annprocess' -} - -jmhCompileGeneratedClasses.enabled = false - -jmh { - includeTests = false - profilers = ['gc'] - resultFormat = 'JSON' - - jvmArgs = ['-XX:+UnlockCommercialFeatures', '-XX:+FlightRecorder'] - // jvmArgsAppend = ['-XX:+UseG1GC', '-Xms4g', '-Xmx4g'] -} - -jmhJar { - from project.configurations.jmh -} - -tasks.jmh.finalizedBy tasks.jmhReport - -jmhReport { - jmhResultPath = project.file('build/reports/jmh/results.json') - jmhReportOutput = project.file('build/reports/jmh') -} diff --git a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java index c099a3120..7f39956dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java @@ -16,48 +16,21 @@ package io.rsocket; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; /** * An abstract implementation of {@link RSocket}. All request handling methods emit {@link * UnsupportedOperationException} and hence must be overridden to provide a valid implementation. + * + * @deprecated as of 1.0 in favor of implementing {@link RSocket} directly which has default + * methods. */ +@Deprecated public abstract class AbstractRSocket implements RSocket { private final MonoProcessor onClose = MonoProcessor.create(); - @Override - public Mono fireAndForget(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Fire and forget not implemented.")); - } - - @Override - public Mono requestResponse(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); - } - - @Override - public Flux requestStream(Payload payload) { - payload.release(); - return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); - } - - @Override - public Mono metadataPush(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); - } - @Override public void dispose() { onClose.onComplete(); diff --git a/rsocket-core/src/main/java/io/rsocket/Closeable.java b/rsocket-core/src/main/java/io/rsocket/Closeable.java index 5eb871e18..2ea9a0371 100644 --- a/rsocket-core/src/main/java/io/rsocket/Closeable.java +++ b/rsocket-core/src/main/java/io/rsocket/Closeable.java @@ -16,17 +16,21 @@ package io.rsocket; +import org.reactivestreams.Subscriber; import reactor.core.Disposable; import reactor.core.publisher.Mono; -/** */ +/** An interface which allows listening to when a specific instance of this interface is closed */ public interface Closeable extends Disposable { /** - * Returns a {@code Publisher} that completes when this {@code RSocket} is closed. A {@code - * RSocket} can be closed by explicitly calling {@link RSocket#dispose()} or when the underlying - * transport connection is closed. + * Returns a {@link Mono} that terminates when the instance is terminated by any reason. Note, in + * case of error termination, the cause of error will be propagated as an error signal through + * {@link org.reactivestreams.Subscriber#onError(Throwable)}. Otherwise, {@link + * Subscriber#onComplete()} will be called. * - * @return A {@code Publisher} that completes when this {@code RSocket} close is complete. + * @return a {@link Mono} to track completion with success or error of the underlying resource. + * When the underlying resource is an `RSocket`, the {@code Mono} exposes stream 0 (i.e. + * connection level) errors. */ Mono onClose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index 9f1d7ea6b..ece2aa9fa 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,31 +18,31 @@ import io.netty.buffer.ByteBuf; import io.netty.util.AbstractReferenceCounted; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.core.DefaultConnectionSetupPayload; +import reactor.util.annotation.Nullable; /** - * Exposed to server for determination of ResponderRSocket based on mime types and SETUP - * metadata/data + * Exposes information from the {@code SETUP} frame to a server, as well as to client responders. */ public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { - public static ConnectionSetupPayload create(final ByteBuf setupFrame) { - return new DefaultConnectionSetupPayload(setupFrame); - } + public abstract String metadataMimeType(); + + public abstract String dataMimeType(); public abstract int keepAliveInterval(); public abstract int keepAliveMaxLifetime(); - public abstract String metadataMimeType(); - - public abstract String dataMimeType(); - public abstract int getFlags(); public abstract boolean willClientHonorLease(); + public abstract boolean isResumeEnabled(); + + @Nullable + public abstract ByteBuf resumeToken(); + @Override public ConnectionSetupPayload retain() { super.retain(); @@ -55,77 +55,18 @@ public ConnectionSetupPayload retain(int increment) { return this; } + @Override public abstract ConnectionSetupPayload touch(); - public abstract ConnectionSetupPayload touch(Object hint); - - private static final class DefaultConnectionSetupPayload extends ConnectionSetupPayload { - private final ByteBuf setupFrame; - - public DefaultConnectionSetupPayload(ByteBuf setupFrame) { - this.setupFrame = setupFrame; - } - - @Override - public boolean hasMetadata() { - return FrameHeaderFlyweight.hasMetadata(setupFrame); - } - - @Override - public int keepAliveInterval() { - return SetupFrameFlyweight.keepAliveInterval(setupFrame); - } - - @Override - public int keepAliveMaxLifetime() { - return SetupFrameFlyweight.keepAliveMaxLifetime(setupFrame); - } - - @Override - public String metadataMimeType() { - return SetupFrameFlyweight.metadataMimeType(setupFrame); - } - - @Override - public String dataMimeType() { - return SetupFrameFlyweight.dataMimeType(setupFrame); - } - - @Override - public int getFlags() { - return FrameHeaderFlyweight.flags(setupFrame); - } - - @Override - public boolean willClientHonorLease() { - return SetupFrameFlyweight.honorLease(setupFrame); - } - - @Override - public ConnectionSetupPayload touch() { - setupFrame.touch(); - return this; - } - - @Override - public ConnectionSetupPayload touch(Object hint) { - setupFrame.touch(hint); - return this; - } - - @Override - protected void deallocate() { - setupFrame.release(); - } - - @Override - public ByteBuf sliceMetadata() { - return SetupFrameFlyweight.metadata(setupFrame); - } - - @Override - public ByteBuf sliceData() { - return SetupFrameFlyweight.data(setupFrame); - } + /** + * Create a {@code ConnectionSetupPayload}. + * + * @deprecated as of 1.0 RC7. Please, use {@link + * DefaultConnectionSetupPayload#DefaultConnectionSetupPayload(ByteBuf) + * DefaultConnectionSetupPayload} constructor. + */ + @Deprecated + public static ConnectionSetupPayload create(final ByteBuf setupFrame) { + return new DefaultConnectionSetupPayload(setupFrame); } } diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 7739a34c0..6190d24e3 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.nio.channels.ClosedChannelException; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -30,9 +31,9 @@ public interface DuplexConnection extends Availability, Closeable { * Sends the source of Frames on this connection and returns the {@code Publisher} representing * the result of this send. * - *

Flow control

+ *

Flow control * - * The passed {@code Publisher} must + *

The passed {@code Publisher} must * * @param frames Stream of {@code Frame}s to send on the connection. * @return {@code Publisher} that completes when all the frames are written on the connection @@ -56,20 +57,20 @@ default Mono sendOne(ByteBuf frame) { /** * Returns a stream of all {@code Frame}s received on this connection. * - *

Completion

+ *

Completion * - * Returned {@code Publisher} MUST never emit a completion event ({@link + *

Returned {@code Publisher} MUST never emit a completion event ({@link * Subscriber#onComplete()}. * - *

Error

+ *

Error * - * Returned {@code Publisher} can error with various transport errors. If the underlying physical - * connection is closed by the peer, then the returned stream from here MUST emit an - * {@link ClosedChannelException}. + *

Returned {@code Publisher} can error with various transport errors. If the underlying + * physical connection is closed by the peer, then the returned stream from here MUST + * emit an {@link ClosedChannelException}. * - *

Multiple Subscriptions

+ *

Multiple Subscriptions * - * Returned {@code Publisher} is not required to support multiple concurrent subscriptions. + *

Returned {@code Publisher} is not required to support multiple concurrent subscriptions. * RSocket will never have multiple subscriptions to this source. Implementations MUST * emit an {@link IllegalStateException} for subsequent concurrent subscriptions, if they do not * support multiple concurrent subscriptions. @@ -78,6 +79,13 @@ default Mono sendOne(ByteBuf frame) { */ Flux receive(); + /** + * Returns the assigned {@link ByteBufAllocator}. + * + * @return the {@link ByteBufAllocator} + */ + ByteBufAllocator alloc(); + @Override default double availability() { return isDisposed() ? 0.0 : 1.0; diff --git a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java deleted file mode 100644 index 7eda01fbe..000000000 --- a/rsocket-core/src/main/java/io/rsocket/KeepAliveHandler.java +++ /dev/null @@ -1,125 +0,0 @@ -package io.rsocket; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.KeepAliveFrameFlyweight; -import java.time.Duration; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; - -abstract class KeepAliveHandler implements Disposable { - private final KeepAlive keepAlive; - private final UnicastProcessor sent = UnicastProcessor.create(); - private final MonoProcessor timeout = MonoProcessor.create(); - private Disposable intervalDisposable; - private volatile long lastReceivedMillis; - - private KeepAliveHandler(KeepAlive keepAlive) { - this.keepAlive = keepAlive; - this.lastReceivedMillis = System.currentTimeMillis(); - this.intervalDisposable = - Flux.interval(Duration.ofMillis(keepAlive.getTickPeriod())) - .subscribe(v -> onIntervalTick()); - } - - static KeepAliveHandler ofServer(KeepAlive keepAlive) { - return new KeepAliveHandler.Server(keepAlive); - } - - static KeepAliveHandler ofClient(KeepAlive keepAlive) { - return new KeepAliveHandler.Client(keepAlive); - } - - @Override - public void dispose() { - sent.onComplete(); - timeout.onComplete(); - intervalDisposable.dispose(); - } - - public void receive(ByteBuf keepAliveFrame) { - this.lastReceivedMillis = System.currentTimeMillis(); - if (KeepAliveFrameFlyweight.respondFlag(keepAliveFrame)) { - doSend( - KeepAliveFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - false, - 0, - KeepAliveFrameFlyweight.data(keepAliveFrame).retain())); - } - } - - public Flux send() { - return sent; - } - - public Mono timeout() { - return timeout; - } - - abstract void onIntervalTick(); - - void doSend(ByteBuf frame) { - sent.onNext(frame); - } - - void doCheckTimeout() { - long now = System.currentTimeMillis(); - if (now - lastReceivedMillis >= keepAlive.getTimeoutMillis()) { - timeout.onNext(keepAlive); - } - } - - private static class Server extends KeepAliveHandler { - - Server(KeepAlive keepAlive) { - super(keepAlive); - } - - @Override - void onIntervalTick() { - doCheckTimeout(); - } - } - - private static final class Client extends KeepAliveHandler { - - Client(KeepAlive keepAlive) { - super(keepAlive); - } - - @Override - void onIntervalTick() { - doCheckTimeout(); - doSend( - KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER)); - } - } - - static final class KeepAlive { - private final long tickPeriod; - private final long timeoutMillis; - - KeepAlive(Duration tickPeriod, Duration timeoutMillis, int maxTicks) { - this.tickPeriod = tickPeriod.toMillis(); - this.timeoutMillis = timeoutMillis.toMillis() + maxTicks * tickPeriod.toMillis(); - } - - KeepAlive(long tickPeriod, long timeoutMillis) { - this.tickPeriod = tickPeriod; - this.timeoutMillis = timeoutMillis; - } - - public long getTickPeriod() { - return tickPeriod; - } - - public long getTimeoutMillis() { - return timeoutMillis; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/Payload.java b/rsocket-core/src/main/java/io/rsocket/Payload.java index 58fab3382..fc130528e 100644 --- a/rsocket-core/src/main/java/io/rsocket/Payload.java +++ b/rsocket-core/src/main/java/io/rsocket/Payload.java @@ -32,8 +32,8 @@ public interface Payload extends ReferenceCounted { boolean hasMetadata(); /** - * Returns the Payload metadata. Always non-null, check {@link #hasMetadata()} to differentiate - * null from "". + * Returns a slice Payload metadata. Always non-null, check {@link #hasMetadata()} to + * differentiate null from "". * * @return payload metadata. */ @@ -46,6 +46,22 @@ public interface Payload extends ReferenceCounted { */ ByteBuf sliceData(); + /** + * Returns the Payloads' data without slicing if possible. This is not safe and editing this could + * effect the payload. It is recommended to call sliceData(). + * + * @return data as a bytebuf or slice of the data + */ + ByteBuf data(); + + /** + * Returns the Payloads' metadata without slicing if possible. This is not safe and editing this + * could effect the payload. It is recommended to call sliceMetadata(). + * + * @return metadata as a bytebuf or slice of the metadata + */ + ByteBuf metadata(); + /** Increases the reference count by {@code 1}. */ @Override Payload retain(); diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java index 5468b4de8..773c93dc2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -33,7 +33,10 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono fireAndForget(Payload payload); + default Mono fireAndForget(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Fire-and-Forget not implemented.")); + } /** * Request-Response interaction model of {@code RSocket}. @@ -42,7 +45,10 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} containing at most a single {@code Payload} representing the * response. */ - Mono requestResponse(Payload payload); + default Mono requestResponse(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); + } /** * Request-Stream interaction model of {@code RSocket}. @@ -50,7 +56,10 @@ public interface RSocket extends Availability, Closeable { * @param payload Request payload. * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. */ - Flux requestStream(Payload payload); + default Flux requestStream(Payload payload) { + payload.release(); + return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); + } /** * Request-Channel interaction model of {@code RSocket}. @@ -58,7 +67,9 @@ public interface RSocket extends Availability, Closeable { * @param payloads Stream of request payloads. * @return Stream of response payloads. */ - Flux requestChannel(Publisher payloads); + default Flux requestChannel(Publisher payloads) { + return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); + } /** * Metadata-Push interaction model of {@code RSocket}. @@ -67,10 +78,26 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono metadataPush(Payload payload); + default Mono metadataPush(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); + } @Override default double availability() { return isDisposed() ? 0.0 : 1.0; } + + @Override + default void dispose() {} + + @Override + default boolean isDisposed() { + return false; + } + + @Override + default Mono onClose() { + return Mono.never(); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java deleted file mode 100644 index 27a882d01..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ /dev/null @@ -1,600 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.exceptions.Exceptions; -import io.rsocket.frame.*; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.LimitableRequestPublisher; -import io.rsocket.internal.UnboundedProcessor; -import io.rsocket.internal.UnicastMonoProcessor; -import java.nio.channels.ClosedChannelException; -import java.time.Duration; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; -import org.reactivestreams.Processor; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.SignalType; -import reactor.core.publisher.UnicastProcessor; - -/** Client Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketServer} */ -class RSocketClient implements RSocket { - - private final DuplexConnection connection; - private final PayloadDecoder payloadDecoder; - private final Consumer errorConsumer; - private final StreamIdSupplier streamIdSupplier; - private final Map senders; - private final Map> receivers; - private final UnboundedProcessor sendProcessor; - private final Lifecycle lifecycle = new Lifecycle(); - private final ByteBufAllocator allocator; - private KeepAliveHandler keepAliveHandler; - - /*server requester*/ - RSocketClient( - ByteBufAllocator allocator, - DuplexConnection connection, - PayloadDecoder payloadDecoder, - Consumer errorConsumer, - StreamIdSupplier streamIdSupplier) { - this( - allocator, - connection, - payloadDecoder, - errorConsumer, - streamIdSupplier, - Duration.ZERO, - Duration.ZERO, - 0); - } - - /*client requester*/ - RSocketClient( - ByteBufAllocator allocator, - DuplexConnection connection, - PayloadDecoder payloadDecoder, - Consumer errorConsumer, - StreamIdSupplier streamIdSupplier, - Duration tickPeriod, - Duration ackTimeout, - int missedAcks) { - this.allocator = allocator; - this.connection = connection; - this.payloadDecoder = payloadDecoder; - this.errorConsumer = errorConsumer; - this.streamIdSupplier = streamIdSupplier; - this.senders = Collections.synchronizedMap(new IntObjectHashMap<>()); - this.receivers = Collections.synchronizedMap(new IntObjectHashMap<>()); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - this.sendProcessor = new UnboundedProcessor<>(); - - connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer); - - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); - - connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); - - if (!Duration.ZERO.equals(tickPeriod)) { - this.keepAliveHandler = - KeepAliveHandler.ofClient( - new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout, missedAcks)); - - keepAliveHandler - .timeout() - .subscribe( - keepAlive -> { - String message = - String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis()); - ConnectionErrorException err = new ConnectionErrorException(message); - lifecycle.setTerminationError(err); - errorConsumer.accept(err); - connection.dispose(); - }); - keepAliveHandler.send().subscribe(sendProcessor::onNext); - } else { - keepAliveHandler = null; - } - } - - private void handleSendProcessorError(Throwable t) { - Throwable terminationError = lifecycle.getTerminationError(); - Throwable err = terminationError != null ? terminationError : t; - receivers - .values() - .forEach( - subscriber -> { - try { - subscriber.onError(err); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - senders.values().forEach(LimitableRequestPublisher::cancel); - } - - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - - receivers - .values() - .forEach( - subscriber -> { - try { - subscriber.onError(new Throwable("closed connection")); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - senders.values().forEach(LimitableRequestPublisher::cancel); - } - - @Override - public Mono fireAndForget(Payload payload) { - return handleFireAndForget(payload); - } - - @Override - public Mono requestResponse(Payload payload) { - return handleRequestResponse(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return handleRequestStream(payload); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return handleChannel(Flux.from(payloads)); - } - - @Override - public Mono metadataPush(Payload payload) { - return handleMetadataPush(payload); - } - - @Override - public double availability() { - return connection.availability(); - } - - @Override - public void dispose() { - connection.dispose(); - } - - @Override - public boolean isDisposed() { - return connection.isDisposed(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - private Mono handleFireAndForget(Payload payload) { - return lifecycle - .active() - .then( - Mono.fromRunnable( - () -> { - final int streamId = streamIdSupplier.nextStreamId(); - ByteBuf requestFrame = - RequestFireAndForgetFrameFlyweight.encode( - allocator, - streamId, - false, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - - payload.release(); - sendProcessor.onNext(requestFrame); - })); - } - - private Flux handleRequestStream(final Payload payload) { - return lifecycle - .active() - .thenMany( - Flux.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - - UnicastProcessor receiver = UnicastProcessor.create(); - receivers.put(streamId, receiver); - - AtomicBoolean first = new AtomicBoolean(false); - - return receiver - .doOnRequest( - n -> { - if (first.compareAndSet(false, true) && !receiver.isDisposed()) { - sendProcessor.onNext( - RequestStreamFrameFlyweight.encode( - allocator, - streamId, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain())); - } else if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - RequestNFrameFlyweight.encode(allocator, streamId, n)); - } - sendProcessor.drain(); - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - CancelFrameFlyweight.encode(allocator, streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - }); - })); - } - - private Mono handleRequestResponse(final Payload payload) { - return lifecycle - .active() - .then( - Mono.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - ByteBuf requestFrame = - RequestResponseFrameFlyweight.encode( - allocator, - streamId, - false, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - - UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); - receivers.put(streamId, receiver); - - sendProcessor.onNext(requestFrame); - - return receiver - .doOnError( - t -> - sendProcessor.onNext( - ErrorFrameFlyweight.encode(allocator, streamId, t))) - .doFinally( - s -> { - if (s == SignalType.CANCEL) { - sendProcessor.onNext( - CancelFrameFlyweight.encode(allocator, streamId)); - } - - receivers.remove(streamId); - }); - })); - } - - private Flux handleChannel(Flux request) { - return lifecycle - .active() - .thenMany( - Flux.defer( - () -> { - final UnicastProcessor receiver = UnicastProcessor.create(); - final int streamId = streamIdSupplier.nextStreamId(); - final AtomicBoolean firstRequest = new AtomicBoolean(true); - - return receiver - .doOnRequest( - n -> { - if (firstRequest.compareAndSet(true, false)) { - final AtomicBoolean firstPayload = new AtomicBoolean(true); - final Flux requestFrames = - request - .transform( - f -> { - LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); - // Need to set this to one for first the frame - wrapped.increaseRequestLimit(1); - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - - return wrapped; - }) - .map( - payload -> { - final ByteBuf requestFrame; - if (firstPayload.compareAndSet(true, false)) { - requestFrame = - RequestChannelFrameFlyweight.encode( - allocator, - streamId, - false, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - } else { - requestFrame = - PayloadFrameFlyweight.encode( - allocator, streamId, false, false, true, - payload); - } - return requestFrame; - }) - .doOnComplete( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - PayloadFrameFlyweight.encodeComplete( - allocator, streamId)); - } - if (firstPayload.get()) { - receiver.onComplete(); - } - }); - - requestFrames.subscribe( - sendProcessor::onNext, - t -> { - errorConsumer.accept(t); - receiver.dispose(); - }); - } else { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - RequestNFrameFlyweight.encode(allocator, streamId, n)); - } - } - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - CancelFrameFlyweight.encode(allocator, streamId)); - } - }) - .doFinally( - s -> { - receivers.remove(streamId); - LimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - }); - })); - } - - private Mono handleMetadataPush(Payload payload) { - return lifecycle - .active() - .then( - Mono.fromRunnable( - () -> { - sendProcessor.onNext( - MetadataPushFrameFlyweight.encode( - allocator, payload.sliceMetadata().retain())); - })); - } - - private boolean contains(int streamId) { - return receivers.containsKey(streamId); - } - - protected void terminate() { - lifecycle.setTerminationError(new ClosedChannelException()); - - if (keepAliveHandler != null) { - keepAliveHandler.dispose(); - } - try { - receivers.values().forEach(this::cleanUpSubscriber); - senders.values().forEach(this::cleanUpLimitableRequestPublisher); - } finally { - senders.clear(); - receivers.clear(); - sendProcessor.dispose(); - } - } - - private synchronized void cleanUpLimitableRequestPublisher( - LimitableRequestPublisher limitableRequestPublisher) { - try { - limitableRequestPublisher.cancel(); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - - private synchronized void cleanUpSubscriber(Processor subscriber) { - try { - subscriber.onError(lifecycle.getTerminationError()); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - - private void handleIncomingFrames(ByteBuf frame) { - try { - int streamId = FrameHeaderFlyweight.streamId(frame); - FrameType type = FrameHeaderFlyweight.frameType(frame); - if (streamId == 0) { - handleStreamZero(type, frame); - } else { - handleFrame(streamId, type, frame); - } - } finally { - frame.release(); - } - } - - private void handleStreamZero(FrameType type, ByteBuf frame) { - switch (type) { - case ERROR: - RuntimeException error = Exceptions.from(frame); - lifecycle.setTerminationError(error); - errorConsumer.accept(error); - connection.dispose(); - break; - case LEASE: - break; - case KEEPALIVE: - if (keepAliveHandler != null) { - keepAliveHandler.receive(frame); - } - break; - default: - // Ignore unknown frames. Throwing an error will close the socket. - errorConsumer.accept( - new IllegalStateException( - "Client received supported frame on stream 0: " + frame.toString())); - } - } - - private void handleFrame(int streamId, FrameType type, ByteBuf frame) { - Subscriber receiver = receivers.get(streamId); - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - } else { - switch (type) { - case ERROR: - receiver.onError(Exceptions.from(frame)); - receivers.remove(streamId); - break; - case NEXT_COMPLETE: - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); - break; - case CANCEL: - { - LimitableRequestPublisher sender = senders.remove(streamId); - receivers.remove(streamId); - if (sender != null) { - sender.cancel(); - } - break; - } - case NEXT: - receiver.onNext(payloadDecoder.apply(frame)); - break; - case REQUEST_N: - { - LimitableRequestPublisher sender = senders.get(streamId); - if (sender != null) { - int n = RequestNFrameFlyweight.requestN(frame); - sender.increaseRequestLimit(n); - sendProcessor.drain(); - } - break; - } - case COMPLETE: - receiver.onComplete(); - receivers.remove(streamId); - break; - default: - throw new IllegalStateException( - "Client received supported frame on stream " + streamId + ": " + frame.toString()); - } - } - } - - private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { - if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { - if (type == FrameType.ERROR) { - // message for stream that has never existed, we have a problem with - // the overall connection and must tear down - String errorMessage = ErrorFrameFlyweight.dataUtf8(frame); - - throw new IllegalStateException( - "Client received error for non-existent stream: " - + streamId - + " Message: " - + errorMessage); - } else { - throw new IllegalStateException( - "Client received message for non-existent stream: " - + streamId - + ", frame type: " - + type); - } - } - // receiving a frame after a given stream has been cancelled/completed, - // so ignore (cancellation is async so there is a race condition) - } - - private static class Lifecycle { - - private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = - AtomicReferenceFieldUpdater.newUpdater( - Lifecycle.class, Throwable.class, "terminationError"); - private volatile Throwable terminationError; - - public Mono active() { - return Mono.create( - sink -> { - if (terminationError == null) { - sink.success(); - } else { - sink.error(terminationError); - } - }); - } - - public Throwable getTerminationError() { - return terminationError; - } - - public void setTerminationError(Throwable err) { - TERMINATION_ERROR.compareAndSet(this, null, err); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java new file mode 100644 index 000000000..b43b14bae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import reactor.util.annotation.Nullable; + +/** + * Exception that represents an RSocket protocol error. + * + * @see ERROR + * Frame (0x0B) + */ +public class RSocketErrorException extends RuntimeException { + + private static final long serialVersionUID = -1628781753426267554L; + + private static final int MIN_ERROR_CODE = 0x00000001; + + private static final int MAX_ERROR_CODE = 0xFFFFFFFE; + + private final int errorCode; + + /** + * Constructor with a protocol error code and a message. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + */ + public RSocketErrorException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Alternative to {@link #RSocketErrorException(int, String)} with a root cause. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + * @param cause a root cause for the error + */ + public RSocketErrorException(int errorCode, String message, @Nullable Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + if (errorCode > MAX_ERROR_CODE && errorCode < MIN_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000001-0xFFFFFFFE]", this); + } + } + + /** + * Return the RSocket error code + * represented by this exception + * + * @return the RSocket protocol error code + */ + public int errorCode() { + return errorCode; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + " (0x" + + Integer.toHexString(errorCode) + + "): " + + getMessage(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 8e8afda0a..e23bcceb2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,48 +13,62 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.rsocket; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.InvalidSetupException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; -import io.rsocket.frame.VersionFlyweight; +import io.netty.buffer.Unpooled; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.lease.LeaseStats; +import io.rsocket.lease.Leases; import io.rsocket.plugins.DuplexConnectionInterceptor; -import io.rsocket.plugins.PluginRegistry; -import io.rsocket.plugins.Plugins; import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; +import io.rsocket.resume.ClientResume; +import io.rsocket.resume.ResumableFramesStore; +import io.rsocket.resume.ResumeStrategy; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.util.EmptyPayload; import java.time.Duration; -import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * Main entry point to create RSocket clients or servers as follows: + * + *

    + *
  • {@link ClientRSocketFactory} to connect as a client. Use {@link #connect()} for a default + * instance. + *
  • {@link ServerRSocketFactory} to start a server. Use {@link #receive()} for a default + * instance. + *
+ * + * @deprecated please use {@link RSocketConnector} and {@link RSocketServer}. + */ +@Deprecated +public final class RSocketFactory { -/** Factory for creating RSocket clients and servers. */ -public class RSocketFactory { /** - * Creates a factory that establishes client connections to other RSockets. + * Create a {@code ClientRSocketFactory} to connect to a remote RSocket endpoint. Internally + * delegates to {@link RSocketConnector}. * - * @return a client factory + * @return the {@code ClientRSocketFactory} instance */ public static ClientRSocketFactory connect() { return new ClientRSocketFactory(); } /** - * Creates a factory that receives server connections from client RSockets. + * Create a {@code ServerRSocketFactory} to accept connections from RSocket clients. Internally + * delegates to {@link RSocketServer}. * - * @return a server factory. + * @return the {@code ClientRSocketFactory} instance */ public static ServerRSocketFactory receive() { return new ServerRSocketFactory(); @@ -73,6 +87,9 @@ default Start transport(ClientTransport transport) { } public interface ServerTransportAcceptor { + + ServerTransport.ConnectionAcceptor toConnectionAcceptor(); + Start transport(Supplier> transport); default Start transport(ServerTransport transport) { @@ -80,335 +97,475 @@ default Start transport(ServerTransport transport) { } } + /** Factory to create and configure an RSocket client, and connect to a server. */ public static class ClientRSocketFactory implements ClientTransportAcceptor { - private Supplier> acceptor = - () -> rSocket -> new AbstractRSocket() {}; + private static final ClientResume CLIENT_RESUME = + new ClientResume(Duration.ofMinutes(2), Unpooled.EMPTY_BUFFER); - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); - - private Payload setupPayload = EmptyPayload.INSTANCE; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private final RSocketConnector connector; private Duration tickPeriod = Duration.ofSeconds(20); private Duration ackTimeout = Duration.ofSeconds(30); private int missedAcks = 3; - private String metadataMimeType = "application/binary"; - private String dataMimeType = "application/binary"; + private Resume resume; - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + public ClientRSocketFactory() { + this(RSocketConnector.create()); + } + public ClientRSocketFactory(RSocketConnector connector) { + this.connector = connector; + } + + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ public ClientRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + connector.interceptors(registry -> registry.forConnection(interceptor)); return this; } + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ + @Deprecated public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); + return addRequesterPlugin(interceptor); + } + + public ClientRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { + connector.interceptors(registry -> registry.forRequester(interceptor)); return this; } + /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ + @Deprecated public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); + return addResponderPlugin(interceptor); + } + + public ClientRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { + connector.interceptors(registry -> registry.forResponder(interceptor)); + return this; + } + + public ClientRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { + connector.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } /** - * Deprecated as Keep-Alive is not optional according to spec + * Deprecated without replacement as Keep-Alive is not optional according to spec * * @return this ClientRSocketFactory */ @Deprecated public ClientRSocketFactory keepAlive() { + connector.keepAlive(tickPeriod, ackTimeout.plus(tickPeriod.multipliedBy(missedAcks))); return this; } - public ClientRSocketFactory keepAlive( + public ClientTransportAcceptor keepAlive( Duration tickPeriod, Duration ackTimeout, int missedAcks) { this.tickPeriod = tickPeriod; this.ackTimeout = ackTimeout; this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) { this.tickPeriod = tickPeriod; + keepAlive(); return this; } public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) { this.ackTimeout = ackTimeout; + keepAlive(); return this; } public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) { this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) { - this.dataMimeType = dataMimeType; - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory dataMimeType(String dataMimeType) { - this.dataMimeType = dataMimeType; + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory metadataMimeType(String metadataMimeType) { - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); return this; } - @Override - public Start transport(Supplier transportClient) { - return new StartClient(transportClient); + public ClientRSocketFactory lease(Supplier> supplier) { + connector.lease(supplier); + return this; + } + + public ClientRSocketFactory lease() { + connector.lease(Leases::new); + return this; + } + + /** @deprecated without a replacement and no longer used. */ + @Deprecated + public ClientRSocketFactory singleSubscriberRequester() { + return this; + } + + /** + * Enables a reconnectable, shared instance of {@code Mono} so every subscriber will + * observe the same RSocket instance up on connection establishment.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     * }
+ * + * Apart of the shared behavior, if the connection is lost, the same {@code Mono} + * instance will transparently re-establish the connection for subsequent subscribers.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     *  r1.dispose()
+     *
+     *  assert r2.isDisposed()
+     *
+     *  RSocket r3 = sharedRSocketMono.block();
+     *  RSocket r4 = sharedRSocketMono.block();
+     *
+     *
+     *  assert r1 != r3;
+     *  assert r4 == r3;
+     *
+     * }
+ * + * Note, having reconnect() enabled does not eliminate the need to accompany each + * individual request with the corresponding retry logic.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  sharedRSocket.flatMap(rSocket -> rSocket.requestResponse(...))
+     *               .retryWhen(ownRetry)
+     *               .subscribe()
+     *
+     * }
+ * + * @param retrySpec a retry factory applied for {@link Mono#retryWhen(Retry)} + * @return a shared instance of {@code Mono}. + */ + public ClientRSocketFactory reconnect(Retry retrySpec) { + connector.reconnect(retrySpec); + return this; + } + + public ClientRSocketFactory resume() { + resume = resume != null ? resume : new Resume(); + connector.resume(resume); + return this; + } + + public ClientRSocketFactory resumeToken(Supplier supplier) { + resume(); + resume.token(supplier); + return this; + } + + public ClientRSocketFactory resumeStore( + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); + return this; + } + + public ClientRSocketFactory resumeSessionDuration(Duration sessionDuration) { + resume(); + resume.sessionDuration(sessionDuration); + return this; + } + + public ClientRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); + return this; + } + + public ClientRSocketFactory resumeStrategy(Supplier strategy) { + resume(); + resume.retry( + Retry.from( + signals -> signals.flatMap(s -> strategy.get().apply(CLIENT_RESUME, s.failure())))); + return this; + } + + public ClientRSocketFactory resumeCleanupOnKeepAlive() { + resume(); + resume.cleanupStoreOnKeepAlive(); + return this; + } + + public Start transport(Supplier transport) { + return () -> connector.connect(transport); } public ClientTransportAcceptor acceptor(Function acceptor) { - this.acceptor = () -> acceptor; - return StartClient::new; + return acceptor(() -> acceptor); + } + + public ClientTransportAcceptor acceptor(Supplier> acceptorSupplier) { + return acceptor( + (setup, sendingSocket) -> { + acceptorSupplier.get().apply(sendingSocket); + return Mono.empty(); + }); } - public ClientTransportAcceptor acceptor(Supplier> acceptor) { - this.acceptor = acceptor; - return StartClient::new; + public ClientTransportAcceptor acceptor(SocketAcceptor acceptor) { + connector.acceptor(acceptor); + return this; } public ClientRSocketFactory fragment(int mtu) { - this.mtu = mtu; + connector.fragment(mtu); return this; } + /** + * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In + * order to observe errors, it is recommended to add error handler using {@code doOnError} + * on the specific logical stream. In order to observe connection, or RSocket terminal + * errors, it is recommended to hook on {@link Closeable#onClose()} handler. + */ public ClientRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; return this; } public ClientRSocketFactory setupPayload(Payload payload) { - this.setupPayload = payload; + connector.setupPayload(payload); return this; } public ClientRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; - return this; - } - - private class StartClient implements Start { - private final Supplier transportClient; - - StartClient(Supplier transportClient) { - this.transportClient = transportClient; - } - - @Override - public Mono start() { - return transportClient - .get() - .connect() - .flatMap( - connection -> { - ByteBuf setupFrame = - SetupFrameFlyweight.encode( - allocator, - false, - false, - (int) tickPeriod.toMillis(), - (int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks), - metadataMimeType, - dataMimeType, - setupPayload.sliceMetadata(), - setupPayload.sliceData()); - - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins); - - RSocketClient rSocketClient = - new RSocketClient( - allocator, - multiplexer.asClientConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.clientSupplier(), - tickPeriod, - ackTimeout, - missedAcks); - - RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient); - - RSocket unwrappedServerSocket = acceptor.get().apply(wrappedRSocketClient); - - RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket); - - RSocketServer rSocketServer = - new RSocketServer( - allocator, - multiplexer.asServerConnection(), - wrappedRSocketServer, - payloadDecoder, - errorConsumer); - - return connection.sendOne(setupFrame).thenReturn(wrappedRSocketClient); - }); - } + connector.payloadDecoder(payloadDecoder); + return this; } } - public static class ServerRSocketFactory { - private SocketAcceptor acceptor; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + /** Factory to create, configure, and start an RSocket server. */ + public static class ServerRSocketFactory implements ServerTransportAcceptor { + private final RSocketServer server; + + private Resume resume; - private ServerRSocketFactory() {} + public ServerRSocketFactory() { + this(RSocketServer.create()); + } + + public ServerRSocketFactory(RSocketServer server) { + this.server = server; + } + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ + @Deprecated public ServerRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + server.interceptors(registry -> registry.forConnection(interceptor)); return this; } - + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ + @Deprecated public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); + return addRequesterPlugin(interceptor); + } + + public ServerRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { + server.interceptors(registry -> registry.forRequester(interceptor)); return this; } + /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ + @Deprecated public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); + return addResponderPlugin(interceptor); + } + + public ServerRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { + server.interceptors(registry -> registry.forResponder(interceptor)); + return this; + } + + public ServerRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { + server.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = acceptor; - return ServerStart::new; + server.acceptor(acceptor); + return this; } public ServerRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; + server.payloadDecoder(payloadDecoder); return this; } public ServerRSocketFactory fragment(int mtu) { - this.mtu = mtu; + server.fragment(mtu); return this; } + /** + * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In + * order to observe errors, it is recommended to add error handler using {@code doOnError} + * on the specific logical stream. In order to observe connection, or RSocket terminal + * errors, it is recommended to hook on {@link Closeable#onClose()} handler. + */ public ServerRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; - return this; - } - - private class ServerStart implements Start { - private final Supplier> transportServer; - - ServerStart(Supplier> transportServer) { - this.transportServer = transportServer; - } - - @Override - public Mono start() { - return transportServer - .get() - .start( - connection -> { - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins); - - return multiplexer - .asStreamZeroConnection() - .receive() - .next() - .flatMap(setupFrame -> processSetupFrame(multiplexer, setupFrame)); - }); - } - - private Mono processSetupFrame( - ClientServerInputMultiplexer multiplexer, ByteBuf setupFrame) { - int version = SetupFrameFlyweight.version(setupFrame); - if (version != SetupFrameFlyweight.CURRENT_VERSION) { - setupFrame.release(); - InvalidSetupException error = - new InvalidSetupException( - "Unsupported version " + VersionFlyweight.toString(version)); - return multiplexer - .asStreamZeroConnection() - .sendOne(ErrorFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 0, error)) - .doFinally(signalType -> multiplexer.dispose()); - } - - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); - int keepAliveInterval = setupPayload.keepAliveInterval(); - int keepAliveMaxLifetime = setupPayload.keepAliveMaxLifetime(); - - RSocketClient rSocketClient = - new RSocketClient( - allocator, - multiplexer.asServerConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.serverSupplier()); - - RSocket wrappedRSocketClient = plugins.applyClient(rSocketClient); - - return acceptor - .accept(setupPayload, wrappedRSocketClient) - .onErrorResume( - err -> - multiplexer - .asStreamZeroConnection() - .sendOne(rejectedSetupErrorFrame(err)) - .then(Mono.error(err))) - .doOnNext( - unwrappedServerSocket -> { - RSocket wrappedRSocketServer = plugins.applyServer(unwrappedServerSocket); - - RSocketServer rSocketServer = - new RSocketServer( - allocator, - multiplexer.asClientConnection(), - wrappedRSocketServer, - payloadDecoder, - errorConsumer, - keepAliveInterval, - keepAliveMaxLifetime); - }) - .doFinally(signalType -> setupPayload.release()) - .then(); - } - - private ByteBuf rejectedSetupErrorFrame(Throwable err) { - String msg = err.getMessage(); - return ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 0, - new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg)); - } + return this; + } + + public ServerRSocketFactory lease(Supplier> supplier) { + server.lease(supplier); + return this; + } + + public ServerRSocketFactory lease() { + server.lease(Leases::new); + return this; + } + + /** @deprecated without a replacement and no longer used. */ + @Deprecated + public ServerRSocketFactory singleSubscriberRequester() { + return this; + } + + public ServerRSocketFactory resume() { + resume = resume != null ? resume : new Resume(); + server.resume(resume); + return this; + } + + public ServerRSocketFactory resumeStore( + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); + return this; + } + + public ServerRSocketFactory resumeSessionDuration(Duration sessionDuration) { + resume(); + resume.sessionDuration(sessionDuration); + return this; + } + + public ServerRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); + return this; + } + + public ServerRSocketFactory resumeCleanupOnKeepAlive() { + resume(); + resume.cleanupStoreOnKeepAlive(); + return this; + } + + @Override + public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { + return server.asConnectionAcceptor(); + } + + @Override + public Start transport(Supplier> transport) { + return () -> server.bind(transport.get()); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java deleted file mode 100644 index 2b0eadaf2..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ /dev/null @@ -1,420 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.*; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.LimitableRequestPublisher; -import io.rsocket.internal.UnboundedProcessor; -import java.util.Collections; -import java.util.Map; -import java.util.function.Consumer; -import org.reactivestreams.Processor; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.Disposable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.SignalType; -import reactor.core.publisher.UnicastProcessor; - -/** Server side RSocket. Receives {@link ByteBuf}s from a {@link RSocketClient} */ -class RSocketServer implements RSocket { - - private final DuplexConnection connection; - private final RSocket requestHandler; - private final PayloadDecoder payloadDecoder; - private final Consumer errorConsumer; - - private final Map sendingSubscriptions; - private final Map> channelProcessors; - - private final UnboundedProcessor sendProcessor; - private final ByteBufAllocator allocator; - private KeepAliveHandler keepAliveHandler; - - /*client responder*/ - RSocketServer( - ByteBufAllocator allocator, - DuplexConnection connection, - RSocket requestHandler, - PayloadDecoder payloadDecoder, - Consumer errorConsumer) { - this(allocator, connection, requestHandler, payloadDecoder, errorConsumer, 0, 0); - } - - /*server responder*/ - RSocketServer( - ByteBufAllocator allocator, - DuplexConnection connection, - RSocket requestHandler, - PayloadDecoder payloadDecoder, - Consumer errorConsumer, - long tickPeriod, - long ackTimeout) { - this.allocator = allocator; - this.connection = connection; - this.requestHandler = requestHandler; - this.payloadDecoder = payloadDecoder; - this.errorConsumer = errorConsumer; - this.sendingSubscriptions = Collections.synchronizedMap(new IntObjectHashMap<>()); - this.channelProcessors = Collections.synchronizedMap(new IntObjectHashMap<>()); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - this.sendProcessor = new UnboundedProcessor<>(); - - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); - - Disposable receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); - - this.connection - .onClose() - .doFinally( - s -> { - cleanup(); - receiveDisposable.dispose(); - }) - .subscribe(null, errorConsumer); - - if (tickPeriod != 0) { - keepAliveHandler = - KeepAliveHandler.ofServer(new KeepAliveHandler.KeepAlive(tickPeriod, ackTimeout)); - - keepAliveHandler - .timeout() - .subscribe( - keepAlive -> { - String message = - String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis()); - errorConsumer.accept(new ConnectionErrorException(message)); - connection.dispose(); - }); - keepAliveHandler.send().subscribe(sendProcessor::onNext); - } else { - keepAliveHandler = null; - } - } - - private void handleSendProcessorError(Throwable t) { - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onError(t); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - } - - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onComplete(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - } - - @Override - public Mono fireAndForget(Payload payload) { - try { - return requestHandler.fireAndForget(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public Mono requestResponse(Payload payload) { - try { - return requestHandler.requestResponse(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public Flux requestStream(Payload payload) { - try { - return requestHandler.requestStream(payload); - } catch (Throwable t) { - return Flux.error(t); - } - } - - @Override - public Flux requestChannel(Publisher payloads) { - try { - return requestHandler.requestChannel(payloads); - } catch (Throwable t) { - return Flux.error(t); - } - } - - @Override - public Mono metadataPush(Payload payload) { - try { - return requestHandler.metadataPush(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public void dispose() { - connection.dispose(); - } - - @Override - public boolean isDisposed() { - return connection.isDisposed(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - private void cleanup() { - if (keepAliveHandler != null) { - keepAliveHandler.dispose(); - } - cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(); - - requestHandler.dispose(); - sendProcessor.dispose(); - } - - private synchronized void cleanUpSendingSubscriptions() { - sendingSubscriptions.values().forEach(Subscription::cancel); - sendingSubscriptions.clear(); - } - - private synchronized void cleanUpChannelProcessors() { - channelProcessors.values().forEach(Processor::onComplete); - channelProcessors.clear(); - } - - private void handleFrame(ByteBuf frame) { - try { - int streamId = FrameHeaderFlyweight.streamId(frame); - Subscriber receiver; - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - switch (frameType) { - case REQUEST_FNF: - handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame))); - break; - case REQUEST_RESPONSE: - handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame))); - break; - case CANCEL: - handleCancelFrame(streamId); - break; - case KEEPALIVE: - handleKeepAliveFrame(frame); - break; - case REQUEST_N: - handleRequestN(streamId, frame); - break; - case REQUEST_STREAM: - handleStream( - streamId, - requestStream(payloadDecoder.apply(frame)), - RequestStreamFrameFlyweight.initialRequestN(frame)); - break; - case REQUEST_CHANNEL: - handleChannel( - streamId, - payloadDecoder.apply(frame), - RequestChannelFrameFlyweight.initialRequestN(frame)); - break; - case METADATA_PUSH: - metadataPush(payloadDecoder.apply(frame)); - break; - case PAYLOAD: - // TODO: Hook in receiving socket. - break; - case LEASE: - // Lease must not be received here as this is the server end of the socket which sends - // leases. - break; - case NEXT: - receiver = channelProcessors.get(streamId); - if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); - } - break; - case COMPLETE: - receiver = channelProcessors.get(streamId); - if (receiver != null) { - receiver.onComplete(); - } - break; - case ERROR: - receiver = channelProcessors.get(streamId); - if (receiver != null) { - receiver.onError(new ApplicationErrorException(ErrorFrameFlyweight.dataUtf8(frame))); - } - break; - case NEXT_COMPLETE: - receiver = channelProcessors.get(streamId); - if (receiver != null) { - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); - } - break; - case SETUP: - handleError(streamId, new IllegalStateException("Setup frame received post setup.")); - break; - default: - handleError( - streamId, - new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType)); - break; - } - } finally { - ReferenceCountUtil.safeRelease(frame); - } - } - - private void handleFireAndForget(int streamId, Mono result) { - result - .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) - .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe(null, errorConsumer); - } - - private void handleRequestResponse(int streamId, Mono response) { - response - .doOnSubscribe(subscription -> sendingSubscriptions.put(streamId, subscription)) - .map(payload -> PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload)) - .switchIfEmpty( - Mono.fromCallable(() -> PayloadFrameFlyweight.encodeComplete(allocator, streamId))) - .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe(t1 -> sendProcessor.onNext(t1), t -> handleError(streamId, t)); - } - - private void handleStream(int streamId, Flux response, int initialRequestN) { - response - .transform( - frameFlux -> { - LimitableRequestPublisher payloads = - LimitableRequestPublisher.wrap(frameFlux); - sendingSubscriptions.put(streamId, payloads); - payloads.increaseRequestLimit(initialRequestN); - return payloads; - }) - .doFinally(signalType -> sendingSubscriptions.remove(streamId)) - .subscribe( - payload -> - sendProcessor.onNext( - PayloadFrameFlyweight.encodeNext(allocator, streamId, payload)), - t -> handleError(streamId, t), - () -> sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId))); - } - - private void handleChannel(int streamId, Payload payload, int initialRequestN) { - UnicastProcessor frames = UnicastProcessor.create(); - channelProcessors.put(streamId, frames); - - Flux payloads = - frames - .doOnCancel( - () -> sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId))) - .doOnError(t -> handleError(streamId, t)) - .doOnRequest( - l -> sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, l))) - .doFinally(signalType -> channelProcessors.remove(streamId)); - - // not chained, as the payload should be enqueued in the Unicast processor before this method - // returns - // and any later payload can be processed - frames.onNext(payload); - - handleStream(streamId, requestChannel(payloads), initialRequestN); - } - - private void handleKeepAliveFrame(ByteBuf frame) { - if (keepAliveHandler != null) { - keepAliveHandler.receive(frame); - } - } - - private void handleCancelFrame(int streamId) { - Subscription subscription = sendingSubscriptions.remove(streamId); - if (subscription != null) { - subscription.cancel(); - } - } - - private void handleError(int streamId, Throwable t) { - errorConsumer.accept(t); - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - - private void handleRequestN(int streamId, ByteBuf frame) { - final Subscription subscription = sendingSubscriptions.get(streamId); - if (subscription != null) { - int n = RequestNFrameFlyweight.requestN(frame); - subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java index f98901472..22697f130 100644 --- a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java @@ -1,12 +1,17 @@ package io.rsocket; +import java.util.function.BiFunction; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; /** * Extends the {@link RSocket} that allows an implementer to peek at the first request payload of a * channel. + * + * @deprecated as of 1.0 RC7 in favor of using {@link RSocket#requestChannel(Publisher)} with {@link + * Flux#switchOnFirst(BiFunction)} */ +@Deprecated public interface ResponderRSocket extends RSocket { /** * Implement this method to peak at the first payload of the incoming request stream without diff --git a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java index 0f6b99d0e..a42626e78 100644 --- a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,24 +17,77 @@ package io.rsocket; import io.rsocket.exceptions.SetupException; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** - * {@code RSocket} is a full duplex protocol where a client and server are identical in terms of - * both having the capability to initiate requests to their peer. This interface provides the - * contract where a server accepts a new {@code RSocket} for sending requests to the peer and - * returns a new {@code RSocket} that will be used to accept requests from it's peer. + * RSocket is a full duplex protocol where a client and server are identical in terms of both having + * the capability to initiate requests to their peer. This interface provides the contract where a + * client or server handles the {@code setup} for a new connection and creates a responder {@code + * RSocket} for accepting requests from the remote peer. */ public interface SocketAcceptor { /** - * Accepts a new {@code RSocket} used to send requests to the peer and returns another {@code - * RSocket} that is used for accepting requests from the peer. + * Handle the {@code SETUP} frame for a new connection and create a responder {@code RSocket} for + * handling requests from the remote peer. * - * @param setup Setup as sent by the client. - * @param sendingSocket Socket used to send requests to the peer. - * @return Socket to accept requests from the peer. + * @param setup the {@code setup} received from a client in a server scenario, or in a client + * scenario this is the setup about to be sent to the server. + * @param sendingSocket socket for sending requests to the remote peer. + * @return {@code RSocket} to accept requests with. * @throws SetupException If the acceptor needs to reject the setup of this socket. */ Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket); + + /** Create a {@code SocketAcceptor} that handles requests with the given {@code RSocket}. */ + static SocketAcceptor with(RSocket rsocket) { + return (setup, sendingRSocket) -> Mono.just(rsocket); + } + + /** Create a {@code SocketAcceptor} for fire-and-forget interactions with the given handler. */ + static SocketAcceptor forFireAndForget(Function> handler) { + return with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-response interactions with the given handler. */ + static SocketAcceptor forRequestResponse(Function> handler) { + return with( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-stream interactions with the given handler. */ + static SocketAcceptor forRequestStream(Function> handler) { + return with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-channel interactions with the given handler. */ + static SocketAcceptor forRequestChannel(Function, Flux> handler) { + return with( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return handler.apply(payloads); + } + }); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java new file mode 100644 index 000000000..9b5647c6f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; + +/** + * Default implementation of {@link ConnectionSetupPayload}. Primarily for internal use within + * RSocket Java but may be created in an application, e.g. for testing purposes. + */ +public class DefaultConnectionSetupPayload extends ConnectionSetupPayload { + + private final ByteBuf setupFrame; + + public DefaultConnectionSetupPayload(ByteBuf setupFrame) { + this.setupFrame = setupFrame; + } + + @Override + public boolean hasMetadata() { + return FrameHeaderCodec.hasMetadata(setupFrame); + } + + @Override + public ByteBuf sliceMetadata() { + final ByteBuf metadata = SetupFrameCodec.metadata(setupFrame); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return SetupFrameCodec.data(setupFrame); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public String metadataMimeType() { + return SetupFrameCodec.metadataMimeType(setupFrame); + } + + @Override + public String dataMimeType() { + return SetupFrameCodec.dataMimeType(setupFrame); + } + + @Override + public int keepAliveInterval() { + return SetupFrameCodec.keepAliveInterval(setupFrame); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameCodec.keepAliveMaxLifetime(setupFrame); + } + + @Override + public int getFlags() { + return FrameHeaderCodec.flags(setupFrame); + } + + @Override + public boolean willClientHonorLease() { + return SetupFrameCodec.honorLease(setupFrame); + } + + @Override + public boolean isResumeEnabled() { + return SetupFrameCodec.resumeEnabled(setupFrame); + } + + @Override + public ByteBuf resumeToken() { + return SetupFrameCodec.resumeToken(setupFrame); + } + + @Override + public ConnectionSetupPayload touch() { + setupFrame.touch(); + return this; + } + + @Override + public ConnectionSetupPayload touch(Object hint) { + setupFrame.touch(hint); + return this; + } + + @Override + protected void deallocate() { + setupFrame.release(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..2d2b96f7e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,32 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + + static boolean isValid(int mtu, Payload payload) { + if (mtu > 0) { + return true; + } + + if (payload.hasMetadata()) { + return (((FrameHeaderCodec.size() + + FrameLengthCodec.FRAME_LENGTH_SIZE + + FrameHeaderCodec.size() + + payload.data().readableBytes() + + payload.metadata().readableBytes()) + & ~FrameLengthCodec.FRAME_LENGTH_MASK) + == 0); + } else { + return (((FrameHeaderCodec.size() + + payload.data().readableBytes() + + FrameLengthCodec.FRAME_LENGTH_SIZE) + & ~FrameLengthCodec.FRAME_LENGTH_MASK) + == 0); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java new file mode 100644 index 000000000..02f1a51cc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -0,0 +1,580 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.lease.LeaseStats; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; +import reactor.util.retry.Retry; + +/** + * The main class to use to establish a connection to an RSocket server. + * + *

To connect over TCP using default settings: + * + *

{@code
+ * import io.rsocket.transport.netty.client.TcpClientTransport;
+ *
+ * Mono rocketMono =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * }
+ * + *

To customize connection settings before connecting: + * + *

{@code
+ * Mono rocketMono =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * }
+ */ +public class RSocketConnector { + private static final String CLIENT_TAG = "client"; + + private static final BiConsumer INVALIDATE_FUNCTION = + (r, i) -> r.onClose().subscribe(null, __ -> i.invalidate(), i::invalidate); + + private Payload setupPayload = EmptyPayload.INSTANCE; + private String metadataMimeType = "application/binary"; + private String dataMimeType = "application/binary"; + private Duration keepAliveInterval = Duration.ofSeconds(20); + private Duration keepAliveMaxLifeTime = Duration.ofSeconds(90); + + @Nullable private SocketAcceptor acceptor; + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Retry retrySpec; + private Resume resume; + private Supplier> leasesSupplier; + + private int mtu = 0; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketConnector() {} + + /** + * Static factory method to create an {@code RSocketConnector} instance and customize default + * settings before connecting. To connect only, use {@link #connectWith(ClientTransport)}. + */ + public static RSocketConnector create() { + return new RSocketConnector(); + } + + /** + * Static factory method to connect with default settings, effectively a shortcut for: + * + *
+   * RSocketConnector.create().connectWith(transport);
+   * 
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public static Mono connectWith(ClientTransport transport) { + return RSocketConnector.create().connect(() -> transport); + } + + /** + * Provide a {@code Payload} with data and/or metadata for the initial {@code SETUP} frame. Data + * and metadata should be formatted according to the MIME types specified via {@link + * #dataMimeType(String)} and {@link #metadataMimeType(String)}. + * + * @param payload the payload containing data and/or metadata for the {@code SETUP} frame + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector setupPayload(Payload payload) { + this.setupPayload = Objects.requireNonNull(payload); + return this; + } + + /** + * Set the MIME type to use for formatting payload data on the established connection. This is set + * in the initial {@code SETUP} frame sent to the server. + * + *

By default this is set to {@code "application/binary"}. + * + * @param dataMimeType the MIME type to be used for payload data + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector dataMimeType(String dataMimeType) { + this.dataMimeType = Objects.requireNonNull(dataMimeType); + return this; + } + + /** + * Set the MIME type to use for formatting payload metadata on the established connection. This is + * set in the initial {@code SETUP} frame sent to the server. + * + *

For metadata encoding, consider using one of the following encoders: + * + *

    + *
  • {@link io.rsocket.metadata.CompositeMetadataFlyweight Composite Metadata} + *
  • {@link io.rsocket.metadata.TaggingMetadataFlyweight Routing} + *
  • {@link io.rsocket.metadata.security.AuthMetadataFlyweight Authentication} + *
+ * + *

For more on the above metadata formats, see the corresponding protocol extensions + * + *

By default this is set to {@code "application/binary"}. + * + * @param metadataMimeType the MIME type to be used for payload metadata + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector metadataMimeType(String metadataMimeType) { + this.metadataMimeType = Objects.requireNonNull(metadataMimeType); + return this; + } + + /** + * Set the "Time Between {@code KEEPALIVE} Frames" which is how frequently {@code KEEPALIVE} + * frames should be emitted, and the "Max Lifetime" which is how long to allow between {@code + * KEEPALIVE} frames from the remote end before concluding that connectivity is lost. Both + * settings are specified in the initial {@code SETUP} frame sent to the server. The spec mentions + * the following: + * + *

    + *
  • For server-to-server connections, a reasonable time interval between client {@code + * KEEPALIVE} frames is 500ms. + *
  • For mobile-to-server connections, the time interval between client {@code KEEPALIVE} + * frames is often > 30,000ms. + *
+ * + *

By default these are set to 20 seconds and 90 seconds respectively. + * + * @param interval how frequently to emit KEEPALIVE frames + * @param maxLifeTime how long to allow between {@code KEEPALIVE} frames from the remote end + * before assuming that connectivity is lost; the value should be generous and allow for + * multiple missed {@code KEEPALIVE} frames. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector keepAlive(Duration interval, Duration maxLifeTime) { + if (!interval.negated().isNegative()) { + throw new IllegalArgumentException("`interval` for keepAlive must be > 0"); + } + if (!maxLifeTime.negated().isNegative()) { + throw new IllegalArgumentException("`maxLifeTime` for keepAlive must be > 0"); + } + this.keepAliveInterval = interval; + this.keepAliveMaxLifeTime = maxLifeTime; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketConnector interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Configure a client-side {@link SocketAcceptor} for responding to requests from the server. + * + *

A full-form example with access to the {@code SETUP} frame and the "sending" RSocket (the + * same as the one returned from {@link #connect(ClientTransport)}): + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example with just the handling RSocket: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.with(new RSocket() {...})))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example handling only request-response: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.forRequestResponse(payload -> ...))
+   *             .connect(transport);
+   * }
+ * + *

By default, {@code new RSocket(){}} is used which rejects all requests from the server with + * {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to use for responding to server requests + * @return the same instance for method chaining + */ + public RSocketConnector acceptor(SocketAcceptor acceptor) { + this.acceptor = acceptor; + return this; + } + + /** + * When this is enabled, the connect methods of this class return a special {@code Mono} + * that maintains a single, shared {@code RSocket} for all subscribers: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   * }
+ * + *

The {@code RSocket} remains cached until the connection is lost and after that, new attempts + * to subscribe or re-subscribe trigger a reconnect and result in a new shared {@code RSocket}: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  r1.dispose();
+   *
+   *  RSocket r3 = rsocketMono.block();
+   *  RSocket r4 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   *  assert r3 == r4;
+   *  assert r1 != r3;
+   *
+   * }
+ * + *

Downstream subscribers for individual requests still need their own retry logic to determine + * if or when failed requests should be retried which in turn triggers the shared reconnect: + * + *

{@code
+   * Mono rocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  rsocketMono.flatMap(rsocket -> rsocket.requestResponse(...))
+   *           .retryWhen(Retry.fixedDelay(1, Duration.ofSeconds(5)))
+   *           .subscribe()
+   * }
+ * + *

Note: this feature is mutually exclusive with {@link #resume(Resume)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled in which case a new connection is obtained per subscriber. + * + * @param retry a retry spec that declares the rules for reconnecting + * @return the same instance for method chaining + */ + public RSocketConnector reconnect(Retry retry) { + this.retrySpec = Objects.requireNonNull(retry); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are resumed automatically. For this + * to work the server must also support and have the Resume capability enabled. + * + *

See {@link Resume} for settings to customize the Resume capability. + * + *

Note: this feature is mutually exclusive with {@link #reconnect(Retry)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketConnector resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create().lease(Leases::new).connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @param supplier supplier for a {@link Leases} + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * broken down into fragments to fit that size. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector fragment(int mtu) { + if (mtu > 0 && mtu < FragmentationDuplexConnection.MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", + FragmentationDuplexConnection.MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketConnector payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * The final step to connect with the transport to use as input and the resulting {@code + * Mono} as output. Each subscriber to the returned {@code Mono} starts a new connection + * if neither {@link #reconnect(Retry) reconnect} nor {@link #resume(Resume)} are enabled. + * + *

The following transports are available (via additional RSocket Java modules): + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpClientTransport TcpClientTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketClientTransport + * WebsocketClientTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalClientTransport LocalClientTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(ClientTransport transport) { + return connect(() -> transport); + } + + /** + * Variant of {@link #connect(ClientTransport)} with a {@link Supplier} for the {@code + * ClientTransport}. + * + *

// TODO: when to use? + * + * @param transportSupplier supplier for the transport to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(Supplier transportSupplier) { + Mono connectionMono = + Mono.fromSupplier(transportSupplier).flatMap(t -> t.connect(mtu)); + return connectionMono + .flatMap( + connection -> { + ByteBuf resumeToken; + KeepAliveHandler keepAliveHandler; + DuplexConnection wrappedConnection; + + if (resume != null) { + resumeToken = resume.getTokenSupplier().get(); + ClientRSocketSession session = + new ClientRSocketSession( + connection, + resume.getSessionDuration(), + resume.getRetry(), + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken), + resume.getStreamTimeout(), + resume.isCleanupStoreOnKeepAlive()) + .continueWith(connectionMono) + .resumeToken(resumeToken); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler(session.resumableConnection()); + wrappedConnection = session.resumableConnection(); + } else { + resumeToken = Unpooled.EMPTY_BUFFER; + keepAliveHandler = new KeepAliveHandler.DefaultKeepAliveHandler(connection); + wrappedConnection = connection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedConnection, interceptors, true); + + boolean leaseEnabled = leasesSupplier != null; + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + mtu, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + ByteBuf setupFrame = + SetupFrameCodec.encode( + wrappedConnection.alloc(), + leaseEnabled, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + resumeToken, + metadataMimeType, + dataMimeType, + setupPayload); + + SocketAcceptor acceptor = + this.acceptor != null ? this.acceptor : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .flatMap( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + CLIENT_TAG, + wrappedConnection.alloc(), + leases.sender(), + leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseHandler, + mtu); + + return wrappedConnection + .sendOne(setupFrame) + .thenReturn(wrappedRSocketRequester); + }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java new file mode 100644 index 000000000..f7fb161fd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -0,0 +1,773 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; +import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.SynchronizedIntObjectHashMap; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.keepalive.KeepAliveFramesAcceptor; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.lease.RequesterLeaseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.UnicastProcessor; +import reactor.core.scheduler.Scheduler; +import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; + +/** + * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer + */ +class RSocketRequester implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); + } + + private volatile Throwable terminationError; + + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketRequester.class, Throwable.class, "terminationError"); + + private final DuplexConnection connection; + private final PayloadDecoder payloadDecoder; + private final StreamIdSupplier streamIdSupplier; + private final IntObjectMap senders; + private final IntObjectMap> receivers; + private final UnboundedProcessor sendProcessor; + private final int mtu; + private final RequesterLeaseHandler leaseHandler; + private final ByteBufAllocator allocator; + private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; + private final MonoProcessor onClose; + private final Scheduler serialScheduler; + + RSocketRequester( + DuplexConnection connection, + PayloadDecoder payloadDecoder, + StreamIdSupplier streamIdSupplier, + int mtu, + int keepAliveTickPeriod, + int keepAliveAckTimeout, + @Nullable KeepAliveHandler keepAliveHandler, + RequesterLeaseHandler leaseHandler, + Scheduler serialScheduler) { + this.connection = connection; + this.allocator = connection.alloc(); + this.payloadDecoder = payloadDecoder; + this.streamIdSupplier = streamIdSupplier; + this.mtu = mtu; + this.leaseHandler = leaseHandler; + this.senders = new SynchronizedIntObjectHashMap<>(); + this.receivers = new SynchronizedIntObjectHashMap<>(); + this.onClose = MonoProcessor.create(); + this.serialScheduler = serialScheduler; + + // DO NOT Change the order here. The Send processor must be subscribed to before receiving + this.sendProcessor = new UnboundedProcessor<>(); + + connection.onClose().subscribe(null, this::tryTerminateOnConnectionError, this::tryShutdown); + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + + connection.receive().subscribe(this::handleIncomingFrames, e -> {}); + + if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { + KeepAliveSupport keepAliveSupport = + new ClientKeepAliveSupport(this.allocator, keepAliveTickPeriod, keepAliveAckTimeout); + this.keepAliveFramesAcceptor = + keepAliveHandler.start( + keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); + } else { + keepAliveFramesAcceptor = null; + } + } + + @Override + public Mono fireAndForget(Payload payload) { + return handleFireAndForget(payload); + } + + @Override + public Mono requestResponse(Payload payload) { + return handleRequestResponse(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return handleRequestStream(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return handleChannel(Flux.from(payloads)); + } + + @Override + public Mono metadataPush(Payload payload) { + return handleMetadataPush(payload); + } + + @Override + public double availability() { + return Math.min(connection.availability(), leaseHandler.availability()); + } + + @Override + public void dispose() { + tryShutdown(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + private Mono handleFireAndForget(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("FireAndForgetMono allows only a single subscriber")); + } + + final int streamId = streamIdSupplier.nextStreamId(receivers); + final ByteBuf requestFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(requestFrame); + + return Mono.empty(); + }) + .subscribeOn(serialScheduler); + } + + private Mono handleRequestResponse(final Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final UnboundedProcessor sendProcessor = this.sendProcessor; + final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("RequestResponseMono allows only a single subscriber")); + } + + return receiver + .next() + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + @Override + void hookOnFirstRequest(long n) { + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload( + allocator, streamId, payload); + + receivers.put(streamId, receiver); + sendProcessor.onNext(requestResponseFrame); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + payload.release(); + } + } + + @Override + public void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId, receiver); + } + })) + .subscribeOn(serialScheduler) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + }); + } + + private Flux handleRequestStream(final Payload payload) { + if (payload.refCnt() <= 0) { + return Flux.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Flux.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final UnboundedProcessor sendProcessor = this.sendProcessor; + final UnicastProcessor receiver = UnicastProcessor.create(); + final AtomicBoolean once = new AtomicBoolean(); + + return Flux.defer( + () -> { + if (once.getAndSet(true)) { + return Flux.error( + new IllegalStateException("RequestStreamFlux allows only a single subscriber")); + } + + return receiver + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + @Override + void hookOnFirstRequest(long n) { + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, n, payload); + + receivers.put(streamId, receiver); + + sendProcessor.onNext(requestStreamFrame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext( + RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + payload.release(); + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId); + } + })) + .subscribeOn(serialScheduler, false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + }); + } + + private Flux handleChannel(Flux request) { + Throwable err = checkAvailable(); + if (err != null) { + return Flux.error(err); + } + + return request + .switchOnFirst( + (s, flux) -> { + Payload payload = s.get(); + if (payload != null) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + return Mono.error(t); + } + return handleChannel(payload, flux); + } else { + return flux; + } + }, + false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + } + + private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { + final UnboundedProcessor sendProcessor = this.sendProcessor; + + final UnicastProcessor receiver = UnicastProcessor.create(); + + return receiver + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + final BaseSubscriber upstreamSubscriber = + new BaseSubscriber() { + + boolean first = true; + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // noops + } + + @Override + protected void hookOnNext(Payload payload) { + if (first) { + // need to skip first since we have already sent it + // no need to release it since it was released earlier on the + // request + // establishment + // phase + first = false; + request(1); + return; + } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + // no need to send any errors. + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + receiver.onError(t); + return; + } + final ByteBuf frame = + PayloadFrameCodec.encodeNextReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnComplete() { + ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnError(Throwable t) { + ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); + sendProcessor.onNext(frame); + receiver.onError(t); + } + + @Override + protected void hookFinally(SignalType type) { + senders.remove(streamId, this); + } + }; + + @Override + void hookOnFirstRequest(long n) { + final int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + final ByteBuf frame = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); + + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); + + inboundFlux + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + senders.remove(streamId, upstreamSubscriber); + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + if (signalType == SignalType.ON_ERROR) { + upstreamSubscriber.cancel(); + } + receivers.remove(streamId, receiver); + } + + @Override + public void cancel() { + upstreamSubscriber.cancel(); + super.cancel(); + } + })) + .subscribeOn(serialScheduler, false); + } + + private Mono handleMetadataPush(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = this.terminationError; + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("MetadataPushMono allows only a single subscriber")); + } + + ByteBuf metadataPushFrame = + MetadataPushFrameCodec.encodeReleasingPayload(allocator, payload); + + sendProcessor.onNextPrioritized(metadataPushFrame); + + return Mono.empty(); + }); + } + + @Nullable + private Throwable checkAvailable() { + Throwable err = this.terminationError; + if (err != null) { + return err; + } + RequesterLeaseHandler lh = leaseHandler; + if (!lh.useLease()) { + return lh.leaseError(); + } + return null; + } + + private void handleIncomingFrames(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameType type = FrameHeaderCodec.frameType(frame); + if (streamId == 0) { + handleStreamZero(type, frame); + } else { + handleFrame(streamId, type, frame); + } + frame.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frame); + throw reactor.core.Exceptions.propagate(t); + } + } + + private void handleStreamZero(FrameType type, ByteBuf frame) { + switch (type) { + case ERROR: + tryTerminateOnZeroError(frame); + break; + case LEASE: + leaseHandler.receive(frame); + break; + case KEEPALIVE: + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.receive(frame); + } + break; + default: + // Ignore unknown frames. Throwing an error will close the socket. + if (LOGGER.isInfoEnabled()) { + LOGGER.info("Requester received unsupported frame on stream 0: " + frame.toString()); + } + } + } + + private void handleFrame(int streamId, FrameType type, ByteBuf frame) { + Subscriber receiver = receivers.get(streamId); + switch (type) { + case NEXT: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + break; + case NEXT_COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + receiver.onComplete(); + break; + case COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onComplete(); + receivers.remove(streamId); + break; + case ERROR: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onError(Exceptions.from(streamId, frame)); + receivers.remove(streamId); + break; + case CANCEL: + { + Subscription sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); + } + break; + } + case REQUEST_N: + { + Subscription sender = senders.get(streamId); + if (sender != null) { + long n = RequestNFrameCodec.requestN(frame); + sender.request(n); + } + break; + } + default: + throw new IllegalStateException( + "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); + } + } + + private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { + if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (type == FrameType.ERROR) { + // message for stream that has never existed, we have a problem with + // the overall connection and must tear down + String errorMessage = ErrorFrameCodec.dataUtf8(frame); + + throw new IllegalStateException( + "Client received error for non-existent stream: " + + streamId + + " Message: " + + errorMessage); + } else { + throw new IllegalStateException( + "Client received message for non-existent stream: " + + streamId + + ", frame type: " + + type); + } + } + // receiving a frame after a given stream has been cancelled/completed, + // so ignore (cancellation is async so there is a race condition) + } + + private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { + tryTerminate( + () -> + new ConnectionErrorException( + String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + } + + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } + + private void tryTerminateOnZeroError(ByteBuf errorFrame) { + tryTerminate(() -> Exceptions.from(0, errorFrame)); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(e); + } + } + } + + private void tryShutdown() { + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } + } + } + + private void terminate(Throwable e) { + connection.dispose(); + leaseHandler.dispose(); + + synchronized (receivers) { + receivers + .values() + .forEach( + receiver -> { + try { + receiver.onError(e); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + } + synchronized (senders) { + senders + .values() + .forEach( + sender -> { + try { + sender.cancel(); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + } + senders.clear(); + receivers.clear(); + sendProcessor.dispose(); + if (e == CLOSED_CHANNEL_EXCEPTION) { + onClose.onComplete(); + } else { + onClose.onError(e); + } + } + + private void handleSendProcessorError(Throwable t) { + connection.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java new file mode 100644 index 000000000..d3860e5f2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -0,0 +1,610 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.*; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.SynchronizedIntObjectHashMap; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.lease.ResponderLeaseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.function.LongConsumer; +import java.util.function.Supplier; +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.*; +import reactor.util.annotation.Nullable; + +/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ +class RSocketResponder implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); + + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + private final DuplexConnection connection; + private final RSocket requestHandler; + + @SuppressWarnings("deprecation") + private final io.rsocket.ResponderRSocket responderRSocket; + + private final PayloadDecoder payloadDecoder; + private final ResponderLeaseHandler leaseHandler; + private final Disposable leaseHandlerDisposable; + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketResponder.class, Throwable.class, "terminationError"); + + private final int mtu; + + private final IntObjectMap sendingSubscriptions; + private final IntObjectMap> channelProcessors; + + private final UnboundedProcessor sendProcessor; + private final ByteBufAllocator allocator; + + RSocketResponder( + DuplexConnection connection, + RSocket requestHandler, + PayloadDecoder payloadDecoder, + ResponderLeaseHandler leaseHandler, + int mtu) { + this.connection = connection; + this.allocator = connection.alloc(); + this.mtu = mtu; + + this.requestHandler = requestHandler; + this.responderRSocket = + (requestHandler instanceof io.rsocket.ResponderRSocket) + ? (io.rsocket.ResponderRSocket) requestHandler + : null; + + this.payloadDecoder = payloadDecoder; + this.leaseHandler = leaseHandler; + this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>(); + this.channelProcessors = new SynchronizedIntObjectHashMap<>(); + + // DO NOT Change the order here. The Send processor must be subscribed to before receiving + // connections + this.sendProcessor = new UnboundedProcessor<>(); + + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + + connection.receive().subscribe(this::handleFrame, e -> {}); + leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized); + + this.connection + .onClose() + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); + } + + private void handleSendProcessorError(Throwable t) { + sendingSubscriptions + .values() + .forEach( + subscription -> { + try { + subscription.cancel(); + } catch (Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + + channelProcessors + .values() + .forEach( + subscription -> { + try { + subscription.onError(t); + } catch (Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + } + + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } + + private void tryTerminateOnConnectionClose() { + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + cleanup(e); + } + } + } + + @Override + public Mono fireAndForget(Payload payload) { + try { + if (leaseHandler.useLease()) { + return requestHandler.fireAndForget(payload); + } else { + payload.release(); + return Mono.error(leaseHandler.leaseError()); + } + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Mono requestResponse(Payload payload) { + try { + if (leaseHandler.useLease()) { + return requestHandler.requestResponse(payload); + } else { + payload.release(); + return Mono.error(leaseHandler.leaseError()); + } + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Flux requestStream(Payload payload) { + try { + if (leaseHandler.useLease()) { + return requestHandler.requestStream(payload); + } else { + payload.release(); + return Flux.error(leaseHandler.leaseError()); + } + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Flux requestChannel(Publisher payloads) { + try { + if (leaseHandler.useLease()) { + return requestHandler.requestChannel(payloads); + } else { + return Flux.error(leaseHandler.leaseError()); + } + } catch (Throwable t) { + return Flux.error(t); + } + } + + private Flux requestChannel(Payload payload, Publisher payloads) { + try { + if (leaseHandler.useLease()) { + return responderRSocket.requestChannel(payload, payloads); + } else { + payload.release(); + return Flux.error(leaseHandler.leaseError()); + } + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Mono metadataPush(Payload payload) { + try { + return requestHandler.metadataPush(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public void dispose() { + tryTerminate(() -> new CancellationException("Disposed")); + } + + @Override + public boolean isDisposed() { + return connection.isDisposed(); + } + + @Override + public Mono onClose() { + return connection.onClose(); + } + + private void cleanup(Throwable e) { + cleanUpSendingSubscriptions(); + cleanUpChannelProcessors(e); + + connection.dispose(); + leaseHandlerDisposable.dispose(); + requestHandler.dispose(); + sendProcessor.dispose(); + } + + private synchronized void cleanUpSendingSubscriptions() { + sendingSubscriptions.values().forEach(Subscription::cancel); + sendingSubscriptions.clear(); + } + + private synchronized void cleanUpChannelProcessors(Throwable e) { + channelProcessors + .values() + .forEach( + payloadPayloadProcessor -> { + try { + payloadPayloadProcessor.onError(e); + } catch (Throwable t) { + // noops + } + }); + channelProcessors.clear(); + } + + private void handleFrame(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + Subscriber receiver; + FrameType frameType = FrameHeaderCodec.frameType(frame); + switch (frameType) { + case REQUEST_FNF: + handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame))); + break; + case REQUEST_RESPONSE: + handleRequestResponse(streamId, requestResponse(payloadDecoder.apply(frame))); + break; + case CANCEL: + handleCancelFrame(streamId); + break; + case REQUEST_N: + handleRequestN(streamId, frame); + break; + case REQUEST_STREAM: + long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); + Payload streamPayload = payloadDecoder.apply(frame); + handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null); + break; + case REQUEST_CHANNEL: + long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); + Payload channelPayload = payloadDecoder.apply(frame); + handleChannel(streamId, channelPayload, channelInitialRequestN); + break; + case METADATA_PUSH: + handleMetadataPush(metadataPush(payloadDecoder.apply(frame))); + break; + case PAYLOAD: + // TODO: Hook in receiving socket. + break; + case NEXT: + receiver = channelProcessors.get(streamId); + if (receiver != null) { + receiver.onNext(payloadDecoder.apply(frame)); + } + break; + case COMPLETE: + receiver = channelProcessors.get(streamId); + if (receiver != null) { + receiver.onComplete(); + } + break; + case ERROR: + receiver = channelProcessors.get(streamId); + if (receiver != null) { + receiver.onError(new ApplicationErrorException(ErrorFrameCodec.dataUtf8(frame))); + } + break; + case NEXT_COMPLETE: + receiver = channelProcessors.get(streamId); + if (receiver != null) { + receiver.onNext(payloadDecoder.apply(frame)); + receiver.onComplete(); + } + break; + case SETUP: + handleError(streamId, new IllegalStateException("Setup frame received post setup.")); + break; + case LEASE: + default: + handleError( + streamId, + new IllegalStateException("ServerRSocket: Unexpected frame type: " + frameType)); + break; + } + ReferenceCountUtil.safeRelease(frame); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frame); + throw Exceptions.propagate(t); + } + } + + private void handleFireAndForget(int streamId, Mono result) { + result.subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + sendingSubscriptions.put(streamId, subscription); + subscription.request(Long.MAX_VALUE); + } + + @Override + protected void hookOnError(Throwable throwable) {} + + @Override + protected void hookFinally(SignalType type) { + sendingSubscriptions.remove(streamId); + } + }); + } + + private void handleRequestResponse(int streamId, Mono response) { + final BaseSubscriber subscriber = + new BaseSubscriber() { + private boolean isEmpty = true; + + @Override + protected void hookOnNext(Payload payload) { + if (isEmpty) { + isEmpty = false; + } + + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + + ByteBuf byteBuf = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(allocator, streamId, payload); + sendProcessor.onNext(byteBuf); + } + + @Override + protected void hookOnError(Throwable throwable) { + handleError(streamId, throwable); + } + + @Override + protected void hookOnComplete() { + if (isEmpty) { + sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); + } + } + + @Override + protected void hookFinally(SignalType type) { + sendingSubscriptions.remove(streamId, this); + } + }; + + sendingSubscriptions.put(streamId, subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); + } + + private void handleStream( + int streamId, + Flux response, + long initialRequestN, + @Nullable UnicastProcessor requestChannel) { + final BaseSubscriber subscriber = + new BaseSubscriber() { + + @Override + protected void hookOnSubscribe(Subscription s) { + s.request(initialRequestN); + } + + @Override + protected void hookOnNext(Payload payload) { + try { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream + // is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + } + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; + } + + ByteBuf byteBuf = + PayloadFrameCodec.encodeNextReleasingPayload(allocator, streamId, payload); + sendProcessor.onNext(byteBuf); + } catch (Throwable e) { + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + } + cancel(); + handleError(streamId, e); + } + } + + @Override + protected void hookOnComplete() { + sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); + } + + @Override + protected void hookOnError(Throwable throwable) { + handleError(streamId, throwable); + } + + @Override + protected void hookOnCancel() { + // specifically for requestChannel case so when requester sends Cancel frame so the + // whole chain MUST be terminated + // Note: CancelFrame is redundant from the responder side due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving a CANCEL, the stream is terminated on the Responder. + // Upon sending a CANCEL, the stream is terminated on the Requester. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + try { + requestChannel.dispose(); + } catch (Exception e) { + // might be thrown back if stream is cancelled + } + } + } + + @Override + protected void hookFinally(SignalType type) { + sendingSubscriptions.remove(streamId); + } + }; + + sendingSubscriptions.put(streamId, subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); + } + + private void handleChannel(int streamId, Payload payload, long initialRequestN) { + UnicastProcessor frames = UnicastProcessor.create(); + channelProcessors.put(streamId, frames); + + Flux payloads = + frames + .doOnRequest( + new LongConsumer() { + boolean first = true; + + @Override + public void accept(long l) { + long n; + if (first) { + first = false; + n = l - 1L; + } else { + n = l; + } + if (n > 0) { + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + } + }) + .doFinally( + signalType -> { + if (channelProcessors.remove(streamId, frames)) { + if (signalType == SignalType.CANCEL) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } + } + }) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + + // not chained, as the payload should be enqueued in the Unicast processor before this method + // returns + // and any later payload can be processed + frames.onNext(payload); + + if (responderRSocket != null) { + handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames); + } else { + handleStream(streamId, requestChannel(payloads), initialRequestN, frames); + } + } + + private void handleMetadataPush(Mono result) { + result.subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(Long.MAX_VALUE); + } + + @Override + protected void hookOnError(Throwable throwable) {} + }); + } + + private void handleCancelFrame(int streamId) { + Subscription subscription = sendingSubscriptions.remove(streamId); + channelProcessors.remove(streamId); + + if (subscription != null) { + subscription.cancel(); + } + } + + private void handleError(int streamId, Throwable t) { + sendProcessor.onNext(ErrorFrameCodec.encode(allocator, streamId, t)); + } + + private void handleRequestN(int streamId, ByteBuf frame) { + Subscription subscription = sendingSubscriptions.get(streamId); + + if (subscription != null) { + long n = RequestNFrameCodec.requestN(frame); + subscription.request(n); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java new file mode 100644 index 000000000..c5734cecc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -0,0 +1,446 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.exceptions.InvalidSetupException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.SessionManager; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * The main class for starting an RSocket server. + * + *

For example: + * + *

{@code
+ * CloseableChannel closeable =
+ *         RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+ *                 .bind(TcpServerTransport.create("localhost", 7000))
+ *                 .block();
+ * }
+ */ +public final class RSocketServer { + private static final String SERVER_TAG = "server"; + + private SocketAcceptor acceptor = SocketAcceptor.with(new RSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Resume resume; + private Supplier> leasesSupplier = null; + + private int mtu = 0; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketServer() {} + + /** Static factory method to create an {@code RSocketServer}. */ + public static RSocketServer create() { + return new RSocketServer(); + } + + /** + * Static factory method to create an {@code RSocketServer} instance with the given {@code + * SocketAcceptor}. Effectively a shortcut for: + * + *
+   * RSocketServer.create().acceptor(...);
+   * 
+ * + * @param acceptor the acceptor to handle connections with + * @return the same instance for method chaining + * @see #acceptor(SocketAcceptor) + */ + public static RSocketServer create(SocketAcceptor acceptor) { + return RSocketServer.create().acceptor(acceptor); + } + + /** + * Set the acceptor to handle incoming connections and handle requests. + * + *

An example with access to the {@code SETUP} frame and sending RSocket for performing + * requests back to the client if needed: + * + *

{@code
+   * RSocketServer.create((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to provide the handling RSocket only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to handle request-response interactions only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.forRequestResponse(payload -> ...))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default, {@code new RSocket(){}} is used for handling which rejects requests from the + * client with {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to handle incoming connections and requests with + * @return the same instance for method chaining + */ + public RSocketServer acceptor(SocketAcceptor acceptor) { + Objects.requireNonNull(acceptor); + this.acceptor = acceptor; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketServer interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are transparently resumed. For this + * to work clients must also support and request to enable this when connecting. + * + *

Use the {@link Resume} argument to customize the Resume session duration, storage, retry + * logic, and others. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketServer resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. For + * this to work clients must also support and request to enable this when connecting. + * + *

Example usage: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .lease(Leases::new)
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default this is not enabled. + * + * @param supplier supplier for a {@link Leases} + * @return the same instance for method chaining + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketServer lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * fragmented. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer fragment(int mtu) { + if (mtu > 0 && mtu < FragmentationDuplexConnection.MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", + FragmentationDuplexConnection.MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketServer payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Start the server on the given transport. + * + *

The following transports are available from additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpServerTransport TcpServerTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketServerTransport + * WebsocketServerTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalServerTransport LocalServerTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @param the type of {@code Closeable} for the given transport + * @return a {@code Mono} with a {@code Closeable} that can be used to obtain information about + * the server, stop it, or be notified of when it is stopped. + */ + public Mono bind(ServerTransport transport) { + return Mono.defer( + new Supplier>() { + ServerSetup serverSetup = serverSetup(); + + @Override + public Mono get() { + return transport + .start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu) + .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); + } + }); + } + + /** + * Start the server on the given transport. Effectively is a shortcut for {@code + * .bind(ServerTransport).block()} + */ + public T bindNow(ServerTransport transport) { + return bind(transport).block(); + } + + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { + return new ServerTransport.ConnectionAcceptor() { + private final ServerSetup serverSetup = serverSetup(); + + @Override + public Mono apply(DuplexConnection connection) { + return acceptor(serverSetup, connection); + } + }; + } + + private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) { + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, interceptors, false); + + return multiplexer + .asSetupConnection() + .receive() + .next() + .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer)); + } + + private Mono acceptResume( + ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { + return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); + } + + private Mono accept( + ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) { + switch (FrameHeaderCodec.frameType(startFrame)) { + case SETUP: + return acceptSetup(serverSetup, startFrame, multiplexer); + case RESUME: + return acceptResume(serverSetup, startFrame, multiplexer); + default: + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "invalid setup frame: " + FrameHeaderCodec.frameType(startFrame))) + .doFinally( + signalType -> { + startFrame.release(); + multiplexer.dispose(); + }); + } + } + + private Mono acceptSetup( + ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) { + + if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + boolean leaseEnabled = leasesSupplier != null; + if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { + return serverSetup + .sendError(multiplexer, new InvalidSetupException("lease is not supported")) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + return serverSetup.acceptRSocketSetup( + setupFrame, + multiplexer, + (keepAliveHandler, wrappedMultiplexer) -> { + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(setupFrame); + + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + wrappedMultiplexer.asServerConnection(), + payloadDecoder, + StreamIdSupplier.serverSupplier(), + mtu, + setupPayload.keepAliveInterval(), + setupPayload.keepAliveMaxLifetime(), + keepAliveHandler, + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setupPayload, wrappedRSocketRequester) + .onErrorResume( + err -> + serverSetup + .sendError(multiplexer, rejectedSetupError(err)) + .then(Mono.error(err))) + .doOnNext( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + DuplexConnection connection = wrappedMultiplexer.asClientConnection(); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + SERVER_TAG, connection.alloc(), leases.sender(), leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + connection, + wrappedRSocketHandler, + payloadDecoder, + responderLeaseHandler, + mtu); + }) + .doFinally(signalType -> setupPayload.release()) + .then(); + }); + } + + private ServerSetup serverSetup() { + return resume != null ? createSetup() : new ServerSetup.DefaultServerSetup(); + } + + ServerSetup createSetup() { + return new ServerSetup.ResumableServerSetup( + new SessionManager(), + resume.getSessionDuration(), + resume.getStreamTimeout(), + resume.getStoreFactory(SERVER_TAG), + resume.isCleanupStoreOnKeepAlive()); + } + + private Exception rejectedSetupError(Throwable err) { + String msg = err.getMessage(); + return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java new file mode 100644 index 000000000..81f6625f0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Operators.MonoSubscriber; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class ReconnectMono extends Mono implements Invalidatable, Disposable, Scannable { + + final Mono source; + final BiConsumer onValueReceived; + final Consumer onValueExpired; + final ReconnectMainSubscriber mainSubscriber; + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ReconnectMono.class, "wip"); + + volatile ReconnectInner[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMono.class, ReconnectInner[].class, "subscribers"); + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_UNSUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_SUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] READY = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] TERMINATED = new ReconnectInner[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + ReconnectMono( + Mono source, + Consumer onValueExpired, + BiConsumer onValueReceived) { + this.source = source; + this.onValueExpired = onValueExpired; + this.onValueReceived = onValueReceived; + this.mainSubscriber = new ReconnectMainSubscriber<>(this); + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return source; + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + + final boolean isDisposed = isDisposed(); + if (key == Attr.TERMINATED) return isDisposed; + if (key == Attr.ERROR) return t; + + return null; + } + + @Override + public void dispose() { + this.terminate(new CancellationException("ReconnectMono has already been disposed")); + } + + @Override + public boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ReconnectInner inner = new ReconnectInner<>(actual, this); + actual.onSubscribe(inner); + + final int state = this.add(inner); + + if (state == READY_STATE) { + inner.complete(this.value); + } else if (state == TERMINATED_STATE) { + inner.onError(this.t); + } + } + + /** + * Block the calling thread indefinitely, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @return the value of this {@code ReconnectMono} + */ + @Override + @Nullable + public T block() { + return block(null); + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@code ReconnectMono} or {@code null} if the timeout is reached and + * the {@code ReconnectMono} has not completed + */ + @Override + @Nullable + @SuppressWarnings("uncheked") + public T block(@Nullable Duration timeout) { + try { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == READY) { + return this.value; + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.source.subscribe(this.mainSubscriber); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + ReconnectInner[] inners = this.subscribers; + + if (inners == READY) { + return this.value; + } + if (inners == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = + Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + void terminate(Throwable t) { + if (isDisposed()) { + return; + } + + // writes happens before volatile write + this.t = t; + + final ReconnectInner[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.mainSubscriber.dispose(); + + this.doFinally(); + + for (CoreSubscriber consumer : subscribers) { + consumer.onError(t); + } + } + + void complete() { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + return; + } + + final T value = this.value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.onValueReceived.accept(value, this); + + for (ReconnectInner consumer : subscribers) { + consumer.complete(value); + } + } + + void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + + if (value != null && isDisposed()) { + this.value = null; + this.onValueExpired.accept(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + // Check RSocket is not good + @Override + public void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final ReconnectInner[] subscribers = this.subscribers; + + if (subscribers == READY && SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED)) { + final T value = this.value; + this.value = null; + + if (value != null) { + this.onValueExpired.accept(value); + } + } + } + + int add(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + ReconnectInner[] b = new ReconnectInner[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.source.subscribe(this.mainSubscriber); + } + return ADDED_STATE; + } + } + } + + @SuppressWarnings("unchecked") + void remove(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + ReconnectInner[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new ReconnectInner[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } + + static final class ReconnectMainSubscriber implements CoreSubscriber { + + final ReconnectMono parent; + + volatile Subscription s; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMainSubscriber.class, Subscription.class, "s"); + + ReconnectMainSubscriber(ReconnectMono parent) { + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + final T value = p.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + p.doFinally(); + return; + } + + if (value == null) { + p.terminate(new IllegalStateException("Unexpected Completion of the Upstream")); + } else { + p.complete(); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + p.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // terminate upstream which means retryBackoff has exhausted + p.terminate(t); + } + + @Override + public void onNext(T value) { + if (this.s == Operators.cancelledSubscription()) { + this.parent.onValueExpired.accept(value); + return; + } + + final ReconnectMono p = this.parent; + + p.value = value; + // volatile write and check on racing + p.doFinally(); + } + + void dispose() { + Operators.terminate(S, this); + } + } + + static final class ReconnectInner extends MonoSubscriber { + final ReconnectMono parent; + + ReconnectInner(CoreSubscriber actual, ReconnectMono parent) { + super(actual); + this.parent = parent; + } + + @Override + public void cancel() { + if (!isCancelled()) { + super.cancel(); + this.parent.remove(this); + } + } + + @Override + public void onComplete() { + if (!isCancelled()) { + this.actual.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (isCancelled()) { + Operators.onErrorDropped(t, currentContext()); + } else { + this.actual.onError(t); + } + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return super.scanUnsafe(key); + } + } +} + +interface Invalidatable { + + void invalidate(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java new file mode 100644 index 000000000..05f8d6b3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java @@ -0,0 +1,188 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.context.Context; + +/** + * This is a support class for handling of request input, intended for use with {@link + * Operators#lift}. It ensures serial execution of cancellation vs first request signals and also + * provides hooks for separate handling of first vs subsequent {@link Subscription#request} + * invocations. + */ +abstract class RequestOperator + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + + Subscription s; + Fuseable.QueueSubscription qs; + + int streamId; + boolean firstRequest = true; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(RequestOperator.class, "wip"); + + RequestOperator(CoreSubscriber actual) { + this.actual = actual; + } + + /** + * Optional hook executed exactly once on the first {@link Subscription#request) invocation + * and right after the {@link Subscription#request} was propagated to the upstream subscription. + * + *

Note: this hook may not be invoked if cancellation happened before this invocation + */ + void hookOnFirstRequest(long n) {} + + /** + * Optional hook executed after the {@link Subscription#request} was propagated to the upstream + * subscription and excludes the first {@link Subscription#request} invocation. + */ + void hookOnRemainingRequests(long n) {} + + /** Optional hook executed after this {@link Subscription} cancelling. */ + void hookOnCancel() {} + + /** + * Optional hook executed after {@link org.reactivestreams.Subscriber} termination events + * (onError, onComplete). + * + * @param signalType the type of termination event that triggered the hook ({@link + * SignalType#ON_ERROR} or {@link SignalType#ON_COMPLETE}) + */ + void hookOnTerminal(SignalType signalType) {} + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Override + public void request(long n) { + this.s.request(n); + if (!firstRequest) { + try { + this.hookOnRemainingRequests(n); + } catch (Throwable throwable) { + onError(throwable); + } + return; + } + this.firstRequest = false; + + if (WIP.getAndIncrement(this) != 0) { + return; + } + int missed = 1; + + boolean firstLoop = true; + for (; ; ) { + if (firstLoop) { + firstLoop = false; + try { + this.hookOnFirstRequest(n); + } catch (Throwable throwable) { + onError(throwable); + return; + } + } else { + try { + this.hookOnCancel(); + } catch (Throwable throwable) { + onError(throwable); + } + return; + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return; + } + } + } + + @Override + public void cancel() { + this.s.cancel(); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + hookOnCancel(); + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + } + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(Payload t) { + this.actual.onNext(t); + } + + @Override + public void onError(Throwable t) { + this.actual.onError(t); + try { + this.hookOnTerminal(SignalType.ON_ERROR); + } catch (Throwable throwable) { + Operators.onErrorDropped(throwable, currentContext()); + } + } + + @Override + public void onComplete() { + this.actual.onComplete(); + try { + this.hookOnTerminal(SignalType.ON_COMPLETE); + } catch (Throwable throwable) { + Operators.onErrorDropped(throwable, currentContext()); + } + } + + @Override + public int requestFusion(int requestedMode) { + if (this.qs != null) { + return this.qs.requestFusion(requestedMode); + } else { + return Fuseable.NONE; + } + } + + @Override + public Payload poll() { + return this.qs.poll(); + } + + @Override + public int size() { + return this.qs.size(); + } + + @Override + public boolean isEmpty() { + return this.qs.isEmpty(); + } + + @Override + public void clear() { + this.qs.clear(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java new file mode 100644 index 000000000..48133af98 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -0,0 +1,177 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableFramesStore; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + +/** + * Simple holder of configuration settings for the RSocket Resume capability. This can be used to + * configure an {@link RSocketConnector} or an {@link RSocketServer} except for {@link + * #retry(Retry)} and {@link #token(Supplier)} which apply only to the client side. + */ +public class Resume { + private static final Logger logger = LoggerFactory.getLogger(Resume.class); + + private Duration sessionDuration = Duration.ofMinutes(2); + + /* Storage */ + private boolean cleanupStoreOnKeepAlive; + private Function storeFactory; + private Duration streamTimeout = Duration.ofSeconds(10); + + /* Client only */ + private Supplier tokenSupplier = ResumeFrameCodec::generateResumeToken; + private Retry retry = + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(16)) + .jitter(1.0) + .doBeforeRetry(signal -> logger.debug("Connection error", signal.failure())); + + public Resume() {} + + /** + * The maximum time for a client to keep trying to reconnect. During this time client and server + * continue to store unsent frames to keep the session warm and ready to resume. + * + *

By default this is set to 2 minutes. + * + * @param sessionDuration the max duration for a session + * @return the same instance for method chaining + */ + public Resume sessionDuration(Duration sessionDuration) { + this.sessionDuration = Objects.requireNonNull(sessionDuration); + return this; + } + + /** + * When this property is enabled, hints from {@code KEEPALIVE} frames about how much data has been + * received by the other side, is used to proactively clean frames from the {@link + * #storeFactory(Function) store}. + * + *

By default this is set to {@code false} in which case information from {@code KEEPALIVE} is + * ignored and old frames from the store are removed only when the store runs out of space. + * + * @return the same instance for method chaining + */ + public Resume cleanupStoreOnKeepAlive() { + this.cleanupStoreOnKeepAlive = true; + return this; + } + + /** + * Configure a factory to create the storage for buffering (or persisting) a window of frames that + * may need to be sent again to resume after a dropped connection. + * + *

By default {@link InMemoryResumableFramesStore} is used with its cache size set to 100,000 + * bytes. When the cache fills up, the oldest frames are gradually removed to create space for new + * ones. + * + * @param storeFactory the factory to use to create the store + * @return the same instance for method chaining + */ + public Resume storeFactory( + Function storeFactory) { + this.storeFactory = storeFactory; + return this; + } + + /** + * A {@link reactor.core.publisher.Flux#timeout(Duration) timeout} value to apply to the resumed + * session stream obtained from the {@link #storeFactory(Function) store} after a reconnect. The + * resume stream must not take longer than the specified time to emit each frame. + * + *

By default this is set to 10 seconds. + * + * @param streamTimeout the timeout value for resuming a session stream + * @return the same instance for method chaining + */ + public Resume streamTimeout(Duration streamTimeout) { + this.streamTimeout = Objects.requireNonNull(streamTimeout); + return this; + } + + /** + * Configure the logic for reconnecting. This setting is for use with {@link + * RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is set to: + * + *

{@code
+   * Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1))
+   *     .maxBackoff(Duration.ofSeconds(16))
+   *     .jitter(1.0)
+   * }
+ * + * @param retry the {@code Retry} spec to use when attempting to reconnect + * @return the same instance for method chaining + */ + public Resume retry(Retry retry) { + this.retry = retry; + return this; + } + + /** + * Customize the generation of the resume identification token used to resume. This setting is for + * use with {@link RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is {@code ResumeFrameFlyweight::generateResumeToken}. + * + * @param supplier a custom generator for a resume identification token + * @return the same instance for method chaining + */ + public Resume token(Supplier supplier) { + this.tokenSupplier = supplier; + return this; + } + + // Package private accessors + + Duration getSessionDuration() { + return sessionDuration; + } + + boolean isCleanupStoreOnKeepAlive() { + return cleanupStoreOnKeepAlive; + } + + Function getStoreFactory(String tag) { + return storeFactory != null + ? storeFactory + : token -> new InMemoryResumableFramesStore(tag, 100_000); + } + + Duration getStreamTimeout() { + return streamTimeout; + } + + Retry getRetry() { + return retry; + } + + Supplier getTokenSupplier() { + return tokenSupplier; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java new file mode 100644 index 000000000..337d17c64 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -0,0 +1,158 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveHandler.*; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.resume.*; +import java.time.Duration; +import java.util.function.BiFunction; +import java.util.function.Function; +import reactor.core.publisher.Mono; + +abstract class ServerSetup { + + abstract Mono acceptRSocketSetup( + ByteBuf frame, + ClientServerInputMultiplexer multiplexer, + BiFunction> then); + + abstract Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); + + void dispose() {} + + Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { + DuplexConnection duplexConnection = multiplexer.asSetupConnection(); + return duplexConnection + .sendOne(ErrorFrameCodec.encode(duplexConnection.alloc(), 0, exception)) + .onErrorResume(err -> Mono.empty()); + } + + static class DefaultServerSetup extends ServerSetup { + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + ClientServerInputMultiplexer multiplexer, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + return sendError(multiplexer, new UnsupportedSetupException("resume not supported")) + .doFinally( + signalType -> { + frame.release(); + multiplexer.dispose(); + }); + } else { + return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { + + return sendError(multiplexer, new RejectedResumeException("resume not supported")) + .doFinally( + signalType -> { + frame.release(); + multiplexer.dispose(); + }); + } + } + + static class ResumableServerSetup extends ServerSetup { + private final SessionManager sessionManager; + private final Duration resumeSessionDuration; + private final Duration resumeStreamTimeout; + private final Function resumeStoreFactory; + private final boolean cleanupStoreOnKeepAlive; + + ResumableServerSetup( + SessionManager sessionManager, + Duration resumeSessionDuration, + Duration resumeStreamTimeout, + Function resumeStoreFactory, + boolean cleanupStoreOnKeepAlive) { + this.sessionManager = sessionManager; + this.resumeSessionDuration = resumeSessionDuration; + this.resumeStreamTimeout = resumeStreamTimeout; + this.resumeStoreFactory = resumeStoreFactory; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + } + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + ClientServerInputMultiplexer multiplexer, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); + + ResumableDuplexConnection connection = + sessionManager + .save( + new ServerRSocketSession( + multiplexer.asClientServerConnection(), + resumeSessionDuration, + resumeStreamTimeout, + resumeStoreFactory, + resumeToken, + cleanupStoreOnKeepAlive)) + .resumableConnection(); + return then.apply( + new ResumableKeepAliveHandler(connection), + new ClientServerInputMultiplexer(connection)); + } else { + return then.apply(new DefaultKeepAliveHandler(multiplexer), multiplexer); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { + ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); + if (session != null) { + return session + .continueWith(multiplexer.asClientServerConnection()) + .resumeWith(frame) + .onClose() + .then(); + } else { + return sendError(multiplexer, new RejectedResumeException("unknown resume token")) + .doFinally( + s -> { + frame.release(); + multiplexer.dispose(); + }); + } + } + + @Override + public void dispose() { + sessionManager.dispose(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java similarity index 52% rename from rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java rename to rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java index f9985b4ac..15d39c993 100644 --- a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java +++ b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,18 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package io.rsocket.core; -package io.rsocket; - -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import io.netty.util.collection.IntObjectMap; +/** This API is not thread-safe and must be strictly used in serialized fashion */ final class StreamIdSupplier { + private static final int MASK = 0x7FFFFFFF; - private static final AtomicIntegerFieldUpdater STREAM_ID = - AtomicIntegerFieldUpdater.newUpdater(StreamIdSupplier.class, "streamId"); - private volatile int streamId; + private long streamId; - private StreamIdSupplier(int streamId) { + // Visible for testing + StreamIdSupplier(int streamId) { this.streamId = streamId; } @@ -36,8 +36,20 @@ static StreamIdSupplier serverSupplier() { return new StreamIdSupplier(0); } - int nextStreamId() { - return STREAM_ID.addAndGet(this, 2); + /** + * This methods provides new stream id and ensures there is no intersections with already running + * streams. This methods is not thread-safe. + * + * @param streamIds currently running streams store + * @return next stream id + */ + int nextStreamId(IntObjectMap streamIds) { + int streamId; + do { + this.streamId += 2; + streamId = (int) (this.streamId & MASK); + } while (streamId == 0 || streamIds.containsKey(streamId)); + return streamId; } boolean isBeforeOrCurrent(int streamId) { diff --git a/rsocket-core/src/main/java/io/rsocket/core/package-info.java b/rsocket-core/src/main/java/io/rsocket/core/package-info.java new file mode 100644 index 000000000..29db3f205 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/package-info.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains {@link io.rsocket.core.RSocketConnector RSocketConnector} and {@link + * io.rsocket.core.RSocketServer RSocketServer}, the main classes for connecting to or starting an + * RSocket server. + * + *

This package also contains a package private classes that implement support for the main + * RSocket interactions. + */ +@NonNullApi +package io.rsocket.core; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java index e92534b2a..cd0d46754 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Application layer logic generating a Reactive Streams {@code onError} event. @@ -32,10 +33,9 @@ public final class ApplicationErrorException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ApplicationErrorException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public ApplicationErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ApplicationErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.APPLICATION_ERROR; + public ApplicationErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java index 984e8249b..d51ba0fb7 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The Responder canceled the request but may have started processing it (similar to REJECTED but @@ -33,10 +34,9 @@ public final class CanceledException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public CanceledException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public CanceledException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public CanceledException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CANCELED; + public CanceledException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CANCELED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java index 3f4f4309d..80324aa90 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MUST wait for outstanding @@ -33,10 +34,9 @@ public final class ConnectionCloseException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionCloseException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionCloseException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionCloseException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_CLOSE; + public ConnectionCloseException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_CLOSE, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java index beaa3d0d0..b44714f7e 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MAY close the connection @@ -33,10 +34,9 @@ public final class ConnectionErrorException extends RSocketException implements * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionErrorException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_ERROR; + public ConnectionErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java new file mode 100644 index 000000000..079b561f9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +public class CustomRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @param cause the cause of this exception + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + if (errorCode > ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]", this); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java index 97de65a96..5c6eee614 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,22 @@ package io.rsocket.exceptions; -import static io.rsocket.frame.ErrorFrameFlyweight.*; +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; import io.netty.buffer.ByteBuf; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import java.util.Objects; /** Utility class that generates an exception from a frame. */ @@ -28,42 +40,56 @@ public final class Exceptions { private Exceptions() {} /** - * Create a {@link RSocketException} from a Frame that matches the error code it contains. + * Create a {@link RSocketErrorException} from a Frame that matches the error code it contains. * * @param frame the frame to retrieve the error code and message from - * @return a {@link RSocketException} that matches the error code in the Frame + * @return a {@link RSocketErrorException} that matches the error code in the Frame * @throws NullPointerException if {@code frame} is {@code null} */ - public static RuntimeException from(ByteBuf frame) { + public static RuntimeException from(int streamId, ByteBuf frame) { Objects.requireNonNull(frame, "frame must not be null"); - int errorCode = ErrorFrameFlyweight.errorCode(frame); - String message = ErrorFrameFlyweight.dataUtf8(frame); + int errorCode = ErrorFrameCodec.errorCode(frame); + String message = ErrorFrameCodec.dataUtf8(frame); - switch (errorCode) { - case APPLICATION_ERROR: - return new ApplicationErrorException(message); - case CANCELED: - return new CanceledException(message); - case CONNECTION_CLOSE: - return new ConnectionCloseException(message); - case CONNECTION_ERROR: - return new ConnectionErrorException(message); - case INVALID: - return new InvalidException(message); - case INVALID_SETUP: - return new InvalidSetupException(message); - case REJECTED: - return new RejectedException(message); - case REJECTED_RESUME: - return new RejectedResumeException(message); - case REJECTED_SETUP: - return new RejectedSetupException(message); - case UNSUPPORTED_SETUP: - return new UnsupportedSetupException(message); - default: - return new IllegalArgumentException( - String.format("Invalid Error frame: %d '%s'", errorCode, message)); + if (streamId == 0) { + switch (errorCode) { + case INVALID_SETUP: + return new InvalidSetupException(message); + case UNSUPPORTED_SETUP: + return new UnsupportedSetupException(message); + case REJECTED_SETUP: + return new RejectedSetupException(message); + case REJECTED_RESUME: + return new RejectedResumeException(message); + case CONNECTION_ERROR: + return new ConnectionErrorException(message); + case CONNECTION_CLOSE: + return new ConnectionCloseException(message); + default: + return new IllegalArgumentException( + String.format("Invalid Error frame in Stream ID 0: 0x%08X '%s'", errorCode, message)); + } + } else { + switch (errorCode) { + case APPLICATION_ERROR: + return new ApplicationErrorException(message); + case REJECTED: + return new RejectedException(message); + case CANCELED: + return new CanceledException(message); + case INVALID: + return new InvalidException(message); + default: + if (errorCode >= MIN_USER_ALLOWED_ERROR_CODE + || errorCode <= MAX_USER_ALLOWED_ERROR_CODE) { + return new CustomRSocketException(errorCode, message); + } + return new IllegalArgumentException( + String.format( + "Invalid Error frame in Stream ID %d: 0x%08X '%s'", + streamId, errorCode, message)); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java index 4783b1590..a1b77b8dd 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The request is invalid. @@ -32,10 +33,9 @@ public final class InvalidException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public InvalidException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID; + public InvalidException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java index b3705d5b7..b0889c5a6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The Setup frame is invalid for the server (it could be that the client is too recent for the old @@ -33,10 +34,9 @@ public final class InvalidSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidSetupException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public InvalidSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID_SETUP; + public InvalidSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java index 7508a1ee3..2b137282f 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,41 +16,48 @@ package io.rsocket.exceptions; -import java.util.Objects; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; -/** The root of the RSocket exception hierarchy. */ -public abstract class RSocketException extends RuntimeException { +/** + * The root of the RSocket exception hierarchy. + * + * @deprecated please use {@link RSocketErrorException} instead + */ +@Deprecated +public abstract class RSocketException extends RSocketErrorException { private static final long serialVersionUID = 2912815394105575423L; /** - * Constructs a new exception with the specified message. + * Constructs a new exception with the specified message and error code 0x201 (Application error). * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message) { - super(Objects.requireNonNull(message, "message must not be null")); + this(message, null); } /** - * Constructs a new exception with the specified message and cause. + * Constructs a new exception with the specified message and cause and error code 0x201 + * (Application error). * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message, @Nullable Throwable cause) { - super(Objects.requireNonNull(message, "message must not be null"), cause); + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); } /** - * Returns the RSocket error code - * represented by this exception + * Constructs a new exception with the specified error code, message and cause. * - * @return the RSocket error code + * @param errorCode the RSocket protocol error code + * @param message the message + * @param cause the cause of this exception */ - public abstract int errorCode(); + public RSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java index e7556e19a..baed84e1b 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Despite being a valid request, the Responder decided to reject it. The Responder guarantees that @@ -26,7 +27,7 @@ * @see Error * Codes */ -public final class RejectedException extends RSocketException implements Retryable { +public class RejectedException extends RSocketException implements Retryable { private static final long serialVersionUID = 3926231092835143715L; @@ -34,10 +35,9 @@ public final class RejectedException extends RSocketException implements Retryab * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedException(String message) { - super(message); + this(message, null); } /** @@ -45,14 +45,8 @@ public RejectedException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED; + public RejectedException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java index 0d4116538..8a99fcffb 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The server rejected the resume, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedResumeException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedResumeException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedResumeException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedResumeException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_RESUME; + public RejectedResumeException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_RESUME, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java index 1fa5f604e..c09a27e32 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The server rejected the setup, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedSetupException extends SetupException implements Retr * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_SETUP; + public RejectedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java index 2111a51b1..ed979c9e6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ package io.rsocket.exceptions; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + /** The root of the setup exception hierarchy. */ public abstract class SetupException extends RSocketException { @@ -25,10 +28,11 @@ public abstract class SetupException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} */ + @Deprecated public SetupException(String message) { - super(message); + this(message, null); } /** @@ -36,9 +40,21 @@ public SetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} + */ + @Deprecated + public SetupException(String message, @Nullable Throwable cause) { + this(ErrorFrameCodec.INVALID_SETUP, message, cause); + } + + /** + * Constructs a new exception with the specified error code, message and cause. + * + * @param errorCode the RSocket protocol code + * @param message the message + * @param cause the cause of this exception */ - public SetupException(String message, Throwable cause) { - super(message, cause); + public SetupException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java index 7d14bc5d2..7429ccd98 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Some (or all) of the parameters specified by the client are unsupported by the server. @@ -32,10 +33,9 @@ public final class UnsupportedSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public UnsupportedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public UnsupportedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public UnsupportedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.UNSUPPORTED_SETUP; + public UnsupportedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.UNSUPPORTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java index babf8194e..969aedded 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ /** - * The hierarchy of exceptions that can be returned by the API + * A hierarchy of exceptions that represent RSocket protocol error codes. * * @see Error * Codes diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index af492508d..5d89bb9ad 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,21 @@ package io.rsocket.fragmentation; +import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; + import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import java.util.Objects; import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; /** * A {@link DuplexConnection} implementation that fragments and reassembles {@link ByteBuf}s. @@ -29,140 +39,97 @@ * href="https://github.com/rsocket/rsocket/blob/master/Protocol.md#fragmentation-and-reassembly">Fragmentation * and Reassembly */ -public final class FragmentationDuplexConnection implements DuplexConnection { - public FragmentationDuplexConnection(DuplexConnection connection, int mtu) {} - - @Override - public Mono send(Publisher frames) { - return null; - } - - @Override - public Flux receive() { - return null; - } - - @Override - public Mono onClose() { - return null; - } - - @Override - public void dispose() {} - - /* - private final ByteBufAllocator byteBufAllocator; - +public final class FragmentationDuplexConnection extends ReassemblyDuplexConnection + implements DuplexConnection { + public static final int MIN_MTU_SIZE = 64; + private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); private final DuplexConnection delegate; + private final int mtu; + private final FrameReassembler frameReassembler; + private final boolean encodeLength; + private final String type; - private final FrameFragmenter frameFragmenter; - - private final IntObjectHashMap frameReassemblers = new IntObjectHashMap<>(); - - */ - /** - * Creates a new instance. - * - * @param delegate the {@link DuplexConnection} to decorate - * @param maxFragmentSize the maximum fragment size - * @throws NullPointerException if {@code delegate} is {@code null} - * @throws IllegalArgumentException if {@code maxFragmentSize} is not {@code positive} - */ - /* - // TODO: Remove once ByteBufAllocators are shared - public FragmentationDuplexConnection(DuplexConnection delegate, int maxFragmentSize) { - this(PooledByteBufAllocator.DEFAULT, delegate, maxFragmentSize); - } - - */ - /** - * Creates a new instance. - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @param delegate the {@link DuplexConnection} to decorate - * @param maxFragmentSize the maximum fragment size. A value of 0 indicates that frames should not - * be fragmented. - * @throws NullPointerException if {@code byteBufAllocator} or {@code delegate} are {@code null} - * @throws IllegalArgumentException if {@code maxFragmentSize} is not {@code positive} - */ - /* public FragmentationDuplexConnection( - ByteBufAllocator byteBufAllocator, DuplexConnection delegate, int maxFragmentSize) { - - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); - this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); + DuplexConnection delegate, int mtu, boolean encodeAndEncodeLength, String type) { + super(delegate, encodeAndEncodeLength); - NumberUtils.requireNonNegative(maxFragmentSize, "maxFragmentSize must be positive"); + Objects.requireNonNull(delegate, "delegate must not be null"); + this.encodeLength = encodeAndEncodeLength; + this.delegate = delegate; + this.mtu = assertMtu(mtu); + this.frameReassembler = new FrameReassembler(delegate.alloc()); + this.type = type; - this.frameFragmenter = new FrameFragmenter(byteBufAllocator, maxFragmentSize); - - delegate - .onClose() - .doFinally( - signalType -> { - Collection values; - synchronized (FragmentationDuplexConnection.this) { - values = frameReassemblers.values(); - } - values.forEach(FrameReassembler::dispose); - }) - .subscribe(); + delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); } - @Override - public double availability() { - return delegate.availability(); + private boolean shouldFragment(FrameType frameType, int readableBytes) { + return frameType.isFragmentable() && readableBytes > mtu; } - @Override - public void dispose() { - delegate.dispose(); + /*TODO this is nullable and not returning empty to workaround javac 11.0.3 compiler issue on ubuntu (at least) */ + @Nullable + public static Mono checkMtu(int mtu) { + if (isInsufficientMtu(mtu)) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + return Mono.error(new IllegalArgumentException(msg)); + } else { + return null; + } } - @Override - public boolean isDisposed() { - return delegate.isDisposed(); + private static int assertMtu(int mtu) { + if (isInsufficientMtu(mtu)) { + String msg = + String.format("smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } } - @Override - public Mono onClose() { - return delegate.onClose(); + private static boolean isInsufficientMtu(int mtu) { + return mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0; } @Override - public Flux receive() { - return delegate - .receive() - .map(AbstractionLeakingFrameUtils::fromAbstractionLeakingFrame) - .concatMap(t2 -> toReassembledFrames(t2.getT1(), t2.getT2())); + public Mono send(Publisher frames) { + return Flux.from(frames).concatMap(this::sendOne).then(); } @Override - public Mono send(Publisher frames) { - Objects.requireNonNull(frames, "frames must not be null"); - - return delegate.send( - Flux.from(frames) - .map(AbstractionLeakingFrameUtils::fromAbstractionLeakingFrame) - .concatMap(t2 -> toFragmentedFrames(t2.getT1(), t2.getT2()))); - } - - private Flux toFragmentedFrames(int streamId, io.rsocket.framing.Frame frame) { - return this.frameFragmenter - .fragment(frame) - .map(fragment -> toAbstractionLeakingFrame(byteBufAllocator, streamId, fragment)); + public Mono sendOne(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + int readableBytes = frame.readableBytes(); + if (shouldFragment(frameType, readableBytes)) { + if (logger.isDebugEnabled()) { + return delegate.send( + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength)) + .doOnNext( + byteBuf -> { + ByteBuf f = encodeLength ? FrameLengthCodec.frame(byteBuf) : byteBuf; + logger.debug( + "{} - stream id {} - frame type {} - \n {}", + type, + FrameHeaderCodec.streamId(f), + FrameHeaderCodec.frameType(f), + ByteBufUtil.prettyHexDump(f)); + })); + } else { + return delegate.send( + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength))); + } + } else { + return delegate.sendOne(encode(frame)); + } } - private Mono toReassembledFrames(int streamId, io.rsocket.framing.Frame fragment) { - FrameReassembler frameReassembler; - synchronized (this) { - frameReassembler = - frameReassemblers.computeIfAbsent( - streamId, i -> createFrameReassembler(byteBufAllocator)); + private ByteBuf encode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); + } else { + return frame; } - - return Mono.justOrEmpty(frameReassembler.reassemble(fragment)) - .map(frame -> toAbstractionLeakingFrame(byteBufAllocator, streamId, frame)); - }*/ + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index e9b4de243..4b8fd36e9 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,23 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + /** * The implementation of the RSocket fragmentation behavior. * @@ -24,168 +41,208 @@ * and Reassembly */ final class FrameFragmenter { - /* - private final ByteBufAllocator byteBufAllocator; - - private final Logger logger = LoggerFactory.getLogger(this.getClass()); - - private final int maxFragmentSize; - - */ - /** - * Creates a new instance - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @param maxFragmentSize the maximum size of each fragment - */ - /* - FrameFragmenter(ByteBufAllocator byteBufAllocator, int maxFragmentSize) { - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); - this.maxFragmentSize = maxFragmentSize; - } - - */ - /** - * Returns a {@link Flux} of fragments frames - * - * @param frame the {@link ByteBuf} to fragment - * @return a {@link Flux} of fragment frames - * @throws NullPointerException if {@code frame} is {@code null} - */ - /* - public Flux fragment(ByteBuf frame) { - Objects.requireNonNull(frame, "frame must not be null"); - - if (!shouldFragment(frame)) { - logger.debug("Not fragmenting {}", frame); - return Flux.just(frame); - } - - logger.debug("Fragmenting {}", frame); + static Publisher fragmentFrame( + ByteBufAllocator allocator, + int mtu, + final ByteBuf frame, + FrameType frameType, + boolean encodeLength) { + ByteBuf metadata = getMetadata(frame, frameType); + ByteBuf data = getData(frame, frameType); + int streamId = FrameHeaderCodec.streamId(frame); return Flux.generate( - () -> new FragmentationState((FragmentableFrame) frame), - this::generate, - FragmentationState::dispose); + new Consumer>() { + boolean first = true; + + @Override + public void accept(SynchronousSink sink) { + ByteBuf byteBuf; + if (first) { + first = false; + byteBuf = + encodeFirstFragment( + allocator, mtu, frame, frameType, streamId, metadata, data); + } else { + byteBuf = encodeFollowsFragment(allocator, mtu, streamId, metadata, data); + } + + sink.next(encode(allocator, byteBuf, encodeLength)); + if (!metadata.isReadable() && !data.isReadable()) { + sink.complete(); + } + } + }) + .doFinally(signalType -> ReferenceCountUtil.safeRelease(frame)); } - private FragmentationState generate(FragmentationState state, SynchronousSink sink) { - int fragmentLength = maxFragmentSize; - - ByteBuf metadata; - if (state.hasReadableMetadata()) { - metadata = state.readMetadataFragment(fragmentLength); - fragmentLength -= metadata.readableBytes(); - } else { - metadata = null; + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + ByteBuf frame, + FrameType frameType, + int streamId, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + int remaining = mtu - FrameHeaderCodec.size(); + + // substract the initial request n + switch (frameType) { + case REQUEST_STREAM: + case REQUEST_CHANNEL: + remaining -= Integer.BYTES; + break; + default: } - if (state.hasReadableMetadata()) { - ByteBuf fragment = state.createFrame(byteBufAllocator, false, metadata, null); - logger.debug("Fragment {}", fragment); - - sink.next(fragment); - return state; + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= 3; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); } - ByteBuf data; - data = state.hasReadableData() ? state.readDataFragment(fragmentLength) : null; - - if (state.hasReadableData()) { - ByteBuf fragment = state.createFrame(byteBufAllocator, false, metadata, data); - logger.debug("Fragment {}", fragment); - - sink.next(fragment); - return state; + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); } - ByteBuf fragment = state.createFrame(byteBufAllocator, true, metadata, data); - logger.debug("Final Fragment {}", fragment); - - sink.next(fragment); - sink.complete(); - return state; - } - - private int getFragmentableLength(FragmentableFrame fragmentableFrame) { - return fragmentableFrame.getMetadataLength().orElse(0) + fragmentableFrame.getDataLength(); - } - - private boolean shouldFragment(ByteBuf frame) { - if (maxFragmentSize == 0 || !(frame instanceof FragmentableFrame)) { - return false; + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_STREAM: + return RequestStreamFrameCodec.encode( + allocator, + streamId, + true, + RequestStreamFrameCodec.initialRequestN(frame), + metadataFragment, + dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameCodec.encode( + allocator, + streamId, + true, + false, + RequestChannelFrameCodec.initialRequestN(frame), + metadataFragment, + dataFragment); + // Payload and synthetic types + case PAYLOAD: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + case NEXT_COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, true, true, metadataFragment, dataFragment); + case COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, true, false, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); } - - FragmentableFrame fragmentableFrame = (FragmentableFrame) frame; - return !fragmentableFrame.isFollowsFlagSet() - && getFragmentableLength(fragmentableFrame) > maxFragmentSize; } - static final class FragmentationState implements Disposable { - - private final FragmentableFrame frame; - - private int dataIndex = 0; - - private boolean initialFragmentCreated = false; - - private int metadataIndex = 0; - - FragmentationState(FragmentableFrame frame) { - this.frame = frame; + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, int mtu, int streamId, ByteBuf metadata, ByteBuf data) { + // subtract the header bytes + int remaining = mtu - FrameHeaderCodec.size(); + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= 3; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); } - @Override - public void dispose() { - disposeQuietly(frame); + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); } - ByteBuf createFrame( - ByteBufAllocator byteBufAllocator, - boolean complete, - @Nullable ByteBuf metadata, - @Nullable ByteBuf data) { - - if (initialFragmentCreated) { - return createPayloadFrame(byteBufAllocator, !complete, data == null, metadata, data); - } else { - initialFragmentCreated = true; - return frame.createFragment(byteBufAllocator, metadata, data); - } - } - - boolean hasReadableData() { - return frame.getDataLength() - dataIndex > 0; - } + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameCodec.encode( + allocator, streamId, follows, false, true, metadataFragment, dataFragment); + } - boolean hasReadableMetadata() { - Integer metadataLength = frame.getUnsafeMetadataLength(); - return metadataLength != null && metadataLength - metadataIndex > 0; + static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameCodec.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameCodec.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameCodec.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameCodec.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameCodec.metadata(frame); + break; + default: + throw new IllegalStateException("unsupported fragment type"); + } + return metadata; + } else { + return Unpooled.EMPTY_BUFFER; } + } - ByteBuf readDataFragment(int length) { - int safeLength = min(length, frame.getDataLength() - dataIndex); - - ByteBuf fragment = frame.getUnsafeData().slice(dataIndex, safeLength); - - dataIndex += fragment.readableBytes(); - return fragment; + static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameCodec.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameCodec.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameCodec.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameCodec.data(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameCodec.data(frame); + break; + default: + throw new IllegalStateException("unsupported fragment type"); } + return data; + } - ByteBuf readMetadataFragment(int length) { - Integer metadataLength = frame.getUnsafeMetadataLength(); - ByteBuf metadata = frame.getUnsafeMetadata(); - - if (metadataLength == null || metadata == null) { - throw new IllegalStateException("Cannot read metadata fragment with no metadata"); - } - - int safeLength = min(length, metadataLength - metadataIndex); - - ByteBuf fragment = metadata.slice(metadataIndex, safeLength); - - metadataIndex += fragment.readableBytes(); - return fragment; + static ByteBuf encode(ByteBufAllocator allocator, ByteBuf frame, boolean encodeLength) { + if (encodeLength) { + return FrameLengthCodec.encode(allocator, frame.readableBytes(), frame); + } else { + return frame; } - }*/ + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java index a44883915..52068e5de 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,20 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.frame.*; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; +import reactor.core.publisher.SynchronousSink; +import reactor.util.annotation.Nullable; /** * The implementation of the RSocket reassembly behavior. @@ -25,143 +38,260 @@ * href="https://github.com/rsocket/rsocket/blob/master/Protocol.md#fragmentation-and-reassembly">Fragmentation * and Reassembly */ -final class FrameReassembler implements Disposable { - @Override - public void dispose() {} +final class FrameReassembler extends AtomicBoolean implements Disposable { - @Override - public boolean isDisposed() { - return false; - } - /* - private static final Recycler RECYCLER = createRecycler(FrameReassembler::new); + private static final long serialVersionUID = -4394598098863449055L; - private final Handle handle; + private static final Logger logger = LoggerFactory.getLogger(FrameReassembler.class); - private ByteBufAllocator byteBufAllocator; + final IntObjectMap headers; + final IntObjectMap metadata; + final IntObjectMap data; - private ReassemblyState state; + private final ByteBufAllocator allocator; - private FrameReassembler(Handle handle) { - this.handle = handle; + public FrameReassembler(ByteBufAllocator allocator) { + this.allocator = allocator; + this.headers = new IntObjectHashMap<>(); + this.metadata = new IntObjectHashMap<>(); + this.data = new IntObjectHashMap<>(); } @Override public void dispose() { - if (state != null) { - disposeQuietly(state); + if (compareAndSet(false, true)) { + synchronized (FrameReassembler.this) { + for (ByteBuf byteBuf : headers.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + headers.clear(); + + for (ByteBuf byteBuf : metadata.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + metadata.clear(); + + for (ByteBuf byteBuf : data.values()) { + ReferenceCountUtil.safeRelease(byteBuf); + } + data.clear(); + } } + } + + @Override + public boolean isDisposed() { + return get(); + } - byteBufAllocator = null; - state = null; - - handle.recycle(this); - } - - */ - /** - * Creates a new instance - * - * @param byteBufAllocator the {@link ByteBufAllocator} to use - * @return the {@code FrameReassembler} - * @throws NullPointerException if {@code byteBufAllocator} is {@code null} - */ - /* - static FrameReassembler createFrameReassembler(ByteBufAllocator byteBufAllocator) { - return RECYCLER.get().setByteBufAllocator(byteBufAllocator); - } - - */ - /** - * Reassembles a frame. If the frame is not a candidate for fragmentation, emits the frame. If - * frame is a candidate for fragmentation, accumulates the content until the final fragment. - * - * @param frame the frame to inspect for reassembly - * @return the reassembled frame if complete, otherwise {@code null} - * @throws NullPointerException if {@code frame} is {@code null} - */ - /* @Nullable - Frame reassemble(Frame frame) { - Objects.requireNonNull(frame, "frame must not be null"); + synchronized ByteBuf getHeader(int streamId) { + return headers.get(streamId); + } - if (!(frame instanceof FragmentableFrame)) { - return frame; - } + synchronized CompositeByteBuf getMetadata(int streamId) { + CompositeByteBuf byteBuf = metadata.get(streamId); - FragmentableFrame fragmentableFrame = (FragmentableFrame) frame; + if (byteBuf == null) { + byteBuf = allocator.compositeBuffer(); + metadata.put(streamId, byteBuf); + } - if (fragmentableFrame.isFollowsFlagSet()) { - if (state == null) { - state = new ReassemblyState(fragmentableFrame); - } else { - state.accumulate(fragmentableFrame); - } - } else if (state != null) { - state.accumulate(fragmentableFrame); + return byteBuf; + } - Frame reassembledFrame = state.createFrame(byteBufAllocator); - state.dispose(); - state = null; + synchronized CompositeByteBuf getData(int streamId) { + CompositeByteBuf byteBuf = data.get(streamId); - return reassembledFrame; - } else { - return fragmentableFrame; + if (byteBuf == null) { + byteBuf = allocator.compositeBuffer(); + data.put(streamId, byteBuf); } - return null; + return byteBuf; + } + + @Nullable + synchronized ByteBuf removeHeader(int streamId) { + return headers.remove(streamId); + } + + @Nullable + synchronized CompositeByteBuf removeMetadata(int streamId) { + return metadata.remove(streamId); } - FrameReassembler setByteBufAllocator(ByteBufAllocator byteBufAllocator) { - this.byteBufAllocator = - Objects.requireNonNull(byteBufAllocator, "byteBufAllocator must not be null"); + @Nullable + synchronized CompositeByteBuf removeData(int streamId) { + return data.remove(streamId); + } - return this; + synchronized void putHeader(int streamId, ByteBuf header) { + headers.put(streamId, header); } - static final class ReassemblyState implements Disposable { + void cancelAssemble(int streamId) { + ByteBuf header = removeHeader(streamId); + CompositeByteBuf metadata = removeMetadata(streamId); + CompositeByteBuf data = removeData(streamId); - private ByteBuf data; + if (header != null) { + ReferenceCountUtil.safeRelease(header); + } - private List fragments = new ArrayList<>(); + if (metadata != null) { + ReferenceCountUtil.safeRelease(metadata); + } - private ByteBuf metadata; + if (data != null) { + ReferenceCountUtil.safeRelease(data); + } + } - ReassemblyState(FragmentableFrame fragment) { - accumulate(fragment); + void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int streamId) { + ByteBuf header = removeHeader(streamId); + if (header != null) { + if (FrameHeaderCodec.hasMetadata(header)) { + ByteBuf assembledFrame = assembleFrameWithMetadata(frame, streamId, header); + sink.next(assembledFrame); + } else { + ByteBuf data = assembleData(frame, streamId); + ByteBuf assembledFrame = FragmentationCodec.encode(allocator, header, data); + sink.next(assembledFrame); + } + frame.release(); + } else { + sink.next(frame); } + } + + void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { + ByteBuf header = getHeader(streamId); + if (header == null) { + header = frame.copy(frame.readerIndex(), FrameHeaderCodec.size()); - @Override - public void dispose() { - fragments.forEach(Disposable::dispose); + if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { + long i = RequestChannelFrameCodec.initialRequestN(frame); + header.writeInt(i > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) i); + } + putHeader(streamId, header); } - void accumulate(FragmentableFrame fragment) { - fragments.add(fragment); - metadata = accumulateMetadata(fragment); - data = accumulateData(fragment); + if (FrameHeaderCodec.hasMetadata(frame)) { + CompositeByteBuf metadata = getMetadata(streamId); + switch (frameType) { + case REQUEST_FNF: + metadata.addComponents(true, RequestFireAndForgetFrameCodec.metadata(frame).retain()); + break; + case REQUEST_STREAM: + metadata.addComponents(true, RequestStreamFrameCodec.metadata(frame).retain()); + break; + case REQUEST_RESPONSE: + metadata.addComponents(true, RequestResponseFrameCodec.metadata(frame).retain()); + break; + case REQUEST_CHANNEL: + metadata.addComponents(true, RequestChannelFrameCodec.metadata(frame).retain()); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata.addComponents(true, PayloadFrameCodec.metadata(frame).retain()); + break; + default: + throw new IllegalStateException("unsupported fragment type"); + } } - Frame createFrame(ByteBufAllocator byteBufAllocator) { - FragmentableFrame root = fragments.get(0); - return root.createNonFragment(byteBufAllocator, metadata, data); + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameCodec.data(frame).retain(); + break; + case REQUEST_STREAM: + data = RequestStreamFrameCodec.data(frame).retain(); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameCodec.data(frame).retain(); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameCodec.data(frame).retain(); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameCodec.data(frame).retain(); + break; + default: + throw new IllegalStateException("unsupported fragment type"); } - private ByteBuf accumulateData(FragmentableFrame fragment) { - ByteBuf data = fragment.getUnsafeData(); - return this.data == null ? data.retain() : Unpooled.wrappedBuffer(this.data, data.retain()); + getData(streamId).addComponents(true, data); + frame.release(); + } + + void reassembleFrame(ByteBuf frame, SynchronousSink sink) { + try { + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + switch (frameType) { + case CANCEL: + case ERROR: + cancelAssemble(streamId); + } + + if (!frameType.isFragmentable()) { + sink.next(frame); + return; + } + + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + + if (hasFollows) { + handleFollowsFlag(frame, streamId, frameType); + } else { + handleNoFollowsFlag(frame, sink, streamId); + } + + } catch (Throwable t) { + logger.error("error reassemble frame", t); + sink.error(t); } + } - private @Nullable ByteBuf accumulateMetadata(FragmentableFrame fragment) { - ByteBuf metadata = fragment.getUnsafeMetadata(); + private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { + ByteBuf metadata; + CompositeByteBuf cm = removeMetadata(streamId); - if (metadata == null) { - return this.metadata; + ByteBuf decodedMetadata = PayloadFrameCodec.metadata(frame); + if (decodedMetadata != null) { + if (cm != null) { + metadata = cm.addComponents(true, decodedMetadata.retain()); + } else { + metadata = PayloadFrameCodec.metadata(frame).retain(); } + } else { + metadata = cm; + } + + ByteBuf data = assembleData(frame, streamId); + + return FragmentationCodec.encode(allocator, header, metadata, data); + } - return this.metadata == null - ? metadata.retain() - : Unpooled.wrappedBuffer(this.metadata, metadata.retain()); + private ByteBuf assembleData(ByteBuf frame, int streamId) { + ByteBuf data; + CompositeByteBuf cd = removeData(streamId); + if (cd != null) { + cd.addComponents(true, PayloadFrameCodec.data(frame).retain()); + data = cd; + } else { + data = Unpooled.EMPTY_BUFFER; } - }*/ + + return data; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java new file mode 100644 index 000000000..6060c0c20 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java @@ -0,0 +1,92 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameLengthCodec; +import java.util.Objects; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A {@link DuplexConnection} implementation that reassembles {@link ByteBuf}s. + * + * @see Fragmentation + * and Reassembly + */ +public class ReassemblyDuplexConnection implements DuplexConnection { + private final DuplexConnection delegate; + private final FrameReassembler frameReassembler; + private final boolean decodeLength; + + public ReassemblyDuplexConnection(DuplexConnection delegate, boolean decodeLength) { + Objects.requireNonNull(delegate, "delegate must not be null"); + this.decodeLength = decodeLength; + this.delegate = delegate; + this.frameReassembler = new FrameReassembler(delegate.alloc()); + + delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); + } + + @Override + public Mono send(Publisher frames) { + return delegate.send(frames); + } + + @Override + public Mono sendOne(ByteBuf frame) { + return delegate.sendOne(frame); + } + + private ByteBuf decode(ByteBuf frame) { + if (decodeLength) { + return FrameLengthCodec.frame(frame).retain(); + } else { + return frame; + } + } + + @Override + public Flux receive() { + return delegate + .receive() + .handle( + (byteBuf, sink) -> { + ByteBuf decode = decode(byteBuf); + frameReassembler.reassembleFrame(decode, sink); + }); + } + + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + @Override + public void dispose() { + delegate.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java index 4431f98dd..8cc3fb41a 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java similarity index 55% rename from rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java index 349a43c3a..d0d929f0f 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java @@ -3,10 +3,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -public class CancelFrameFlyweight { - private CancelFrameFlyweight() {} +public class CancelFrameCodec { + private CancelFrameCodec() {} public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId) { - return FrameHeaderFlyweight.encode(allocator, streamId, FrameType.CANCEL, 0); + return FrameHeaderCodec.encode(allocator, streamId, FrameType.CANCEL, 0); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java deleted file mode 100644 index f07f5f004..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java +++ /dev/null @@ -1,83 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; - -class DataAndMetadataFlyweight { - public static final int FRAME_LENGTH_MASK = 0xFFFFFF; - - private DataAndMetadataFlyweight() {} - - private static void encodeLength(final ByteBuf byteBuf, final int length) { - if ((length & ~FRAME_LENGTH_MASK) != 0) { - throw new IllegalArgumentException("Length is larger than 24 bits"); - } - // Write each byte separately in reverse order, this mean we can write 1 << 23 without - // overflowing. - byteBuf.writeByte(length >> 16); - byteBuf.writeByte(length >> 8); - byteBuf.writeByte(length); - } - - private static int decodeLength(final ByteBuf byteBuf) { - int length = (byteBuf.readByte() & 0xFF) << 16; - length |= (byteBuf.readByte() & 0xFF) << 8; - length |= byteBuf.readByte() & 0xFF; - return length; - } - - static ByteBuf encodeOnlyMetadata( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata) { - return allocator.compositeBuffer(2).addComponents(true, header, metadata); - } - - static ByteBuf encodeOnlyData(ByteBufAllocator allocator, final ByteBuf header, ByteBuf data) { - return allocator.compositeBuffer(2).addComponents(true, header, data); - } - - static ByteBuf encode( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata, ByteBuf data) { - - int length = metadata.readableBytes(); - encodeLength(header, length); - - return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); - } - - static ByteBuf metadataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { - if (hasMetadata) { - int length = decodeLength(byteBuf); - return byteBuf.readSlice(length); - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf metadata(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - ByteBuf metadata = metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { - if (hasMetadata) { - /*moves reader index*/ - int length = decodeLength(byteBuf); - byteBuf.skipBytes(length); - } - if (byteBuf.readableBytes() > 0) { - return byteBuf.slice(); - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf data(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - ByteBuf data = dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java similarity index 70% rename from rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java index 55e23541e..dcacb57dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java @@ -3,28 +3,35 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; -import io.rsocket.exceptions.RSocketException; +import io.rsocket.RSocketErrorException; import java.nio.charset.StandardCharsets; -public class ErrorFrameFlyweight { +public class ErrorFrameCodec { - // defined error codes + // defined zero stream id error codes public static final int INVALID_SETUP = 0x00000001; public static final int UNSUPPORTED_SETUP = 0x00000002; public static final int REJECTED_SETUP = 0x00000003; public static final int REJECTED_RESUME = 0x00000004; public static final int CONNECTION_ERROR = 0x00000101; public static final int CONNECTION_CLOSE = 0x00000102; + // defined non-zero stream id error codes public static final int APPLICATION_ERROR = 0x00000201; public static final int REJECTED = 0x00000202; public static final int CANCELED = 0x00000203; public static final int INVALID = 0x00000204; + // defined user-allowed error codes range + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; public static ByteBuf encode( ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.ERROR, 0); + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.ERROR, 0); - int errorCode = errorCodeFromException(t); + int errorCode = + t instanceof RSocketErrorException + ? ((RSocketErrorException) t).errorCode() + : APPLICATION_ERROR; header.writeInt(errorCode); @@ -37,17 +44,9 @@ public static ByteBuf encode(ByteBufAllocator allocator, int streamId, Throwable return encode(allocator, streamId, t, data); } - public static int errorCodeFromException(Throwable t) { - if (t instanceof RSocketException) { - return ((RSocketException) t).errorCode(); - } - - return APPLICATION_ERROR; - } - public static int errorCode(ByteBuf byteBuf) { byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); + byteBuf.skipBytes(FrameHeaderCodec.size()); int i = byteBuf.readInt(); byteBuf.resetReaderIndex(); return i; @@ -55,7 +54,7 @@ public static int errorCode(ByteBuf byteBuf) { public static ByteBuf data(ByteBuf byteBuf) { byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); ByteBuf slice = byteBuf.slice(); byteBuf.resetReaderIndex(); return slice; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java deleted file mode 100644 index ccbff374e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java +++ /dev/null @@ -1,74 +0,0 @@ -package io.rsocket.frame; - -/** - * The types of {@link Error} that can be set. - * - * @see Error - * Codes - */ -public final class ErrorType { - - /** - * Application layer logic generating a Reactive Streams onError event. Stream ID MUST be > 0. - */ - public static final int APPLICATION_ERROR = 0x00000201;; - - /** - * The Responder canceled the request but may have started processing it (similar to REJECTED but - * doesn't guarantee lack of side-effects). Stream ID MUST be > 0. - */ - public static final int CANCELED = 0x00000203; - - /** - * The connection is being terminated. Stream ID MUST be 0. Sender or Receiver of this frame MUST - * wait for outstanding streams to terminate before closing the connection. New requests MAY not - * be accepted. - */ - public static final int CONNECTION_CLOSE = 0x00000102; - - /** - * The connection is being terminated. Stream ID MUST be 0. Sender or Receiver of this frame MAY - * close the connection immediately without waiting for outstanding streams to terminate. - */ - public static final int CONNECTION_ERROR = 0x00000101; - - /** The request is invalid. Stream ID MUST be > 0. */ - public static final int INVALID = 0x00000204; - - /** - * The Setup frame is invalid for the server (it could be that the client is too recent for the - * old server). Stream ID MUST be 0. - */ - public static final int INVALID_SETUP = 0x00000001; - - /** - * Despite being a valid request, the Responder decided to reject it. The Responder guarantees - * that it didn't process the request. The reason for the rejection is explained in the Error Data - * section. Stream ID MUST be > 0. - */ - public static final int REJECTED = 0x00000202; - - /** - * The server rejected the resume, it can specify the reason in the payload. Stream ID MUST be 0. - */ - public static final int REJECTED_RESUME = 0x00000004; - - /** - * The server rejected the setup, it can specify the reason in the payload. Stream ID MUST be 0. - */ - public static final int REJECTED_SETUP = 0x00000003; - - /** Reserved. */ - public static final int RESERVED = 0x00000000; - - /** Reserved for Extension Use. */ - public static final int RESERVED_FOR_EXTENSION = 0xFFFFFFFF; - - /** - * Some (or all) of the parameters specified by the client are unsupported by the server. Stream - * ID MUST be 0. - */ - public static final int UNSUPPORTED_SETUP = 0x00000002; - - private ErrorType() {} -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java new file mode 100644 index 000000000..418926596 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java @@ -0,0 +1,67 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class ExtensionFrameCodec { + private ExtensionFrameCodec() {} + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + int extendedType, + @Nullable ByteBuf metadata, + ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = FrameHeaderCodec.FLAGS_I; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.EXT, flags); + header.writeInt(extendedType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int extendedType(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java deleted file mode 100644 index 18516eb27..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java +++ /dev/null @@ -1,64 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import javax.annotation.Nullable; - -public class ExtensionFrameFlyweight { - private ExtensionFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - int extendedType, - @Nullable ByteBuf metadata, - ByteBuf data) { - - int flags = FrameHeaderFlyweight.FLAGS_I; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.EXT, flags); - header.writeInt(extendedType); - if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } - } - - public static int extendedType(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int i = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return i; - } - - public static ByteBuf data(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - // Extended type - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - // Extended type - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java new file mode 100644 index 000000000..de228b271 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java @@ -0,0 +1,19 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +/** FragmentationFlyweight is used to re-assemble frames */ +public class FragmentationCodec { + public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, ByteBuf data) { + return encode(allocator, header, null, data); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { + + final boolean hasMetadata = metadata != null; + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java new file mode 100644 index 000000000..ea011e503 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java @@ -0,0 +1,103 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import reactor.util.annotation.Nullable; + +class FrameBodyCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + + private FrameBodyCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + byte b = byteBuf.readByte(); + int length = (b & 0xFF) << 16; + byte b1 = byteBuf.readByte(); + length |= (b1 & 0xFF) << 8; + byte b2 = byteBuf.readByte(); + length |= b2 & 0xFF; + return length; + } + + static ByteBuf encode( + ByteBufAllocator allocator, + final ByteBuf header, + @Nullable ByteBuf metadata, + boolean hasMetadata, + @Nullable ByteBuf data) { + + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (hasMetadata) { + int length = metadata.readableBytes(); + encodeLength(header, length); + } + + if (addMetadata && addData) { + return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); + } else if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else if (addData) { + return allocator.compositeBuffer(2).addComponents(true, header, data); + } else { + return header; + } + } + + static ByteBuf metadataWithoutMarking(ByteBuf byteBuf) { + int length = decodeLength(byteBuf); + return byteBuf.readSlice(length); + } + + static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { + if (hasMetadata) { + /*moves reader index*/ + int length = decodeLength(byteBuf); + byteBuf.skipBytes(length); + } + if (byteBuf.readableBytes() > 0) { + return byteBuf.readSlice(byteBuf.readableBytes()); + } else { + return Unpooled.EMPTY_BUFFER; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java similarity index 85% rename from rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index 7dbe8053a..28f39459d 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -12,7 +12,7 @@ * *

Not thread-safe. Assumed to be used single-threaded */ -public final class FrameHeaderFlyweight { +public final class FrameHeaderCodec { /** (I)gnore flag: a value of 0 indicates the protocol can't ignore this frame */ public static final int FLAGS_I = 0b10_0000_0000; /** (M)etadata flag: a value of 1 indicates the frame contains metadata */ @@ -38,14 +38,14 @@ public final class FrameHeaderFlyweight { disableFrameTypeCheck = Boolean.getBoolean(DISABLE_FRAME_TYPE_CHECK); } - private FrameHeaderFlyweight() {} + private FrameHeaderCodec() {} static ByteBuf encodeStreamZero( final ByteBufAllocator allocator, final FrameType frameType, int flags) { return encode(allocator, 0, frameType, flags); } - static ByteBuf encode( + public static ByteBuf encode( final ByteBufAllocator allocator, final int streamId, final FrameType frameType, int flags) { if (!frameType.canHaveMetadata() && ((flags & FLAGS_M) == FLAGS_M)) { throw new IllegalStateException("bad value for metadata flag"); @@ -56,6 +56,10 @@ static ByteBuf encode( return allocator.buffer().writeInt(streamId).writeShort(typeAndFlags); } + public static boolean hasFollows(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_F) == FLAGS_F; + } + public static int streamId(ByteBuf byteBuf) { byteBuf.markReaderIndex(); int streamId = byteBuf.readInt(); @@ -75,6 +79,19 @@ public static boolean hasMetadata(ByteBuf byteBuf) { return (flags(byteBuf) & FLAGS_M) == FLAGS_M; } + /** + * faster version of {@link #frameType(ByteBuf)} which does not replace PAYLOAD with synthetic + * type + */ + public static FrameType nativeFrameType(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + int typeAndFlags = byteBuf.readShort() & 0xFFFF; + FrameType result = FrameType.fromEncodedType(typeAndFlags >> FRAME_TYPE_SHIFT); + byteBuf.resetReaderIndex(); + return result; + } + public static FrameType frameType(ByteBuf byteBuf) { byteBuf.markReaderIndex(); byteBuf.skipBytes(Integer.BYTES); @@ -113,7 +130,7 @@ public static void ensureFrameType(final FrameType frameType, ByteBuf byteBuf) { } } - static int size() { + public static int size() { return HEADER_SIZE; } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java similarity index 95% rename from rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java index 622160061..f6c19c8ee 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java @@ -7,11 +7,11 @@ * Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections * for transports that need to send length */ -public class FrameLengthFlyweight { +public class FrameLengthCodec { public static final int FRAME_LENGTH_MASK = 0xFFFFFF; public static final int FRAME_LENGTH_SIZE = 3; - private FrameLengthFlyweight() {} + private FrameLengthCodec() {} private static void encodeLength(final ByteBuf byteBuf, final int length) { if ((length & ~FRAME_LENGTH_MASK) != 0) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java new file mode 100644 index 000000000..66d18c8a7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -0,0 +1,117 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; + +public class FrameUtil { + + private FrameUtil() {} + + public static String toString(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + StringBuilder payload = new StringBuilder(); + + payload + .append("\nFrame => Stream ID: ") + .append(streamId) + .append(" Type: ") + .append(frameType) + .append(" Flags: 0b") + .append(Integer.toBinaryString(FrameHeaderCodec.flags(frame))) + .append(" Length: " + frame.readableBytes()); + + if (frameType.hasInitialRequestN()) { + payload.append(" InitialRequestN: ").append(RequestStreamFrameCodec.initialRequestN(frame)); + } + + if (frameType == FrameType.REQUEST_N) { + payload.append(" RequestN: ").append(RequestNFrameCodec.requestN(frame)); + } + + if (FrameHeaderCodec.hasMetadata(frame)) { + payload.append("\nMetadata:\n"); + + ByteBufUtil.appendPrettyHexDump(payload, getMetadata(frame, frameType)); + } + + payload.append("\nData:\n"); + ByteBufUtil.appendPrettyHexDump(payload, getData(frame, frameType)); + + return payload.toString(); + } + + private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameCodec.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameCodec.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameCodec.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameCodec.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameCodec.metadata(frame); + break; + case METADATA_PUSH: + metadata = MetadataPushFrameCodec.metadata(frame); + break; + case SETUP: + metadata = SetupFrameCodec.metadata(frame); + break; + case LEASE: + metadata = LeaseFrameCodec.metadata(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return metadata; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + private static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameCodec.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameCodec.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameCodec.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameCodec.data(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameCodec.data(frame); + break; + case SETUP: + data = SetupFrameCodec.data(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return data; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java new file mode 100644 index 000000000..56a93d869 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java @@ -0,0 +1,159 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +class GenericFrameCodec { + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + final Payload payload) { + return encodeReleasingPayload(allocator, frameType, streamId, complete, next, 0, payload); + } + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + int requestN, + final Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, frameType, streamId, false, complete, next, requestN, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return encode(allocator, frameType, streamId, fragmentFollows, false, false, 0, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + int requestN, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + if (fragmentFollows) { + flags |= FrameHeaderCodec.FLAGS_F; + } + + if (complete) { + flags |= FrameHeaderCodec.FLAGS_C; + } + + if (next) { + flags |= FrameHeaderCodec.FLAGS_N; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, frameType, flags); + + if (requestN > 0) { + header.writeInt(requestN); + } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + int idx = byteBuf.readerIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.readerIndex(idx); + return data; + } + + @Nullable + static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static ByteBuf dataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + static ByteBuf metadataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static int initialRequestN(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int i = byteBuf.skipBytes(FrameHeaderCodec.size()).readInt(); + byteBuf.resetReaderIndex(); + return i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java similarity index 61% rename from rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java index e4e6029b3..752d5b3eb 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java @@ -3,7 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -public class KeepAliveFrameFlyweight { +public class KeepAliveFrameCodec { /** * (R)espond: Set by the sender of the KEEPALIVE, to which the responder MUST reply with a * KEEPALIVE without the R flag set @@ -12,7 +12,7 @@ public class KeepAliveFrameFlyweight { public static final long LAST_POSITION_MASK = 0x8000000000000000L; - private KeepAliveFrameFlyweight() {} + private KeepAliveFrameCodec() {} public static ByteBuf encode( final ByteBufAllocator allocator, @@ -20,7 +20,7 @@ public static ByteBuf encode( final long lastPosition, final ByteBuf data) { final int flags = respond ? FLAGS_KEEPALIVE_R : 0; - ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); + ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); long lp = 0; if (lastPosition > 0) { @@ -29,27 +29,27 @@ public static ByteBuf encode( header.writeLong(lp); - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); + return FrameBodyCodec.encode(allocator, header, null, false, data); } public static boolean respondFlag(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); - int flags = FrameHeaderFlyweight.flags(byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + int flags = FrameHeaderCodec.flags(byteBuf); return (flags & FLAGS_KEEPALIVE_R) == FLAGS_KEEPALIVE_R; } public static long lastPosition(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); byteBuf.markReaderIndex(); - long l = byteBuf.skipBytes(FrameHeaderFlyweight.size()).readLong(); + long l = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); byteBuf.resetReaderIndex(); return l; } public static ByteBuf data(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); byteBuf.markReaderIndex(); - ByteBuf slice = byteBuf.skipBytes(FrameHeaderFlyweight.size() + Long.BYTES).slice(); + ByteBuf slice = byteBuf.skipBytes(FrameHeaderCodec.size() + Long.BYTES).slice(); byteBuf.resetReaderIndex(); return slice; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFlyweight.java deleted file mode 100644 index 8747da0bb..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFlyweight.java +++ /dev/null @@ -1,62 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import javax.annotation.Nullable; - -public class LeaseFlyweight { - - public static ByteBuf encode( - final ByteBufAllocator allocator, - final int ttl, - final int numRequests, - @Nullable final ByteBuf metadata) { - - int flags = 0; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - ByteBuf header = - FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.LEASE, flags) - .writeInt(ttl) - .writeInt(numRequests); - - return DataAndMetadataFlyweight.encodeOnlyMetadata(allocator, header, metadata); - } - - public static int ttl(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int ttl = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return ttl; - } - - public static int numRequests(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - byteBuf.markReaderIndex(); - // Ttl - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - int numRequests = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return numRequests; - } - - public static ByteBuf metadata(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - if (FrameHeaderFlyweight.hasMetadata(byteBuf)) { - byteBuf.markReaderIndex(); - // Ttl + Num of requests - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES * 2); - ByteBuf metadata = byteBuf.slice(); - byteBuf.resetReaderIndex(); - return metadata; - } else { - return Unpooled.EMPTY_BUFFER; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java new file mode 100644 index 000000000..f20c25d3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java @@ -0,0 +1,83 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class LeaseFrameCodec { + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final int ttl, + final int numRequests, + @Nullable final ByteBuf metadata) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = + FrameHeaderCodec.encodeStreamZero(allocator, FrameType.LEASE, flags) + .writeInt(ttl) + .writeInt(numRequests); + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else { + return header; + } + } + + public static int ttl(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int ttl = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return ttl; + } + + public static int numRequests(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + // Ttl + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + int numRequests = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return numRequests; + } + + @Nullable + public static ByteBuf metadata(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + if (FrameHeaderCodec.hasMetadata(byteBuf)) { + byteBuf.markReaderIndex(); + // Ttl + Num of requests + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES * 2); + ByteBuf metadata = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return metadata; + } else { + return null; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java new file mode 100644 index 000000000..d8ffe3eef --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java @@ -0,0 +1,43 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; + +public class MetadataPushFrameCodec { + + public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { + if (!payload.hasMetadata()) { + throw new IllegalStateException( + "Metadata push requires to have metadata present" + " in the given Payload"); + } + final ByteBuf metadata = payload.metadata().retain(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } + return encode(allocator, metadata); + } + + public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { + ByteBuf header = + FrameHeaderCodec.encodeStreamZero( + allocator, FrameType.METADATA_PUSH, FrameHeaderCodec.FLAGS_M); + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } + + public static ByteBuf metadata(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int headerSize = FrameHeaderCodec.size(); + int metadataLength = byteBuf.readableBytes() - headerSize; + byteBuf.skipBytes(headerSize); + ByteBuf metadata = byteBuf.readSlice(metadataLength); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java deleted file mode 100644 index d37b573ba..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class MetadataPushFrameFlyweight { - public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { - ByteBuf header = - FrameHeaderFlyweight.encodeStreamZero( - allocator, FrameType.METADATA_PUSH, FrameHeaderFlyweight.FLAGS_M); - return allocator.compositeBuffer(2).addComponents(true, header, metadata); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - byteBuf.markReaderIndex(); - int headerSize = FrameHeaderFlyweight.size(); - int metadataLength = byteBuf.readableBytes() - headerSize; - byteBuf.skipBytes(headerSize); - ByteBuf metadata = byteBuf.readSlice(metadataLength); - byteBuf.resetReaderIndex(); - return metadata; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java new file mode 100644 index 000000000..1ae9c6750 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class PayloadFrameCodec { + + private PayloadFrameCodec() {} + + public static ByteBuf encodeNextReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, false, payload); + } + + public static ByteBuf encodeNextCompleteReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, true, payload); + } + + static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.PAYLOAD, streamId, complete, true, payload); + } + + public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { + return encode(allocator, streamId, false, true, false, null, null); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.PAYLOAD, streamId, fragmentFollows, complete, next, 0, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java deleted file mode 100644 index 83f2406dd..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class PayloadFrameFlyweight { - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.PAYLOAD); - - private PayloadFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, complete, next, 0, metadata, data); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - fragmentFollows, - complete, - next, - 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - } - - public static ByteBuf encodeNextComplete( - ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - true, - true, - 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - } - - public static ByteBuf encodeNext(ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - false, - true, - 0, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - } - - public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { - return FLYWEIGHT.encode(allocator, streamId, false, true, false, 0, null, null); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java new file mode 100644 index 000000000..60906083d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java @@ -0,0 +1,69 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestChannelFrameCodec { + + private RequestChannelFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, + int streamId, + boolean complete, + long initialRequestN, + Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_CHANNEL, streamId, complete, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_CHANNEL, + streamId, + fragmentFollows, + complete, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java deleted file mode 100644 index fb6ecebb0..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java +++ /dev/null @@ -1,38 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class RequestChannelFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_CHANNEL); - - private RequestChannelFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - long requestN, - ByteBuf metadata, - ByteBuf data) { - - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, complete, false, reqN, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.dataWithRequestN(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadataWithRequestN(byteBuf); - } - - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java new file mode 100644 index 000000000..b91199179 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestFireAndForgetFrameCodec { + + private RequestFireAndForgetFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_FNF, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_FNF, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java deleted file mode 100644 index 680374f71..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java +++ /dev/null @@ -1,29 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class RequestFireAndForgetFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_FNF); - - private RequestFireAndForgetFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - ByteBuf metadata, - ByteBuf data) { - - return FLYWEIGHT.encode(allocator, streamId, fragmentFollows, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java deleted file mode 100644 index 8196c56d8..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java +++ /dev/null @@ -1,105 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import javax.annotation.Nullable; - -class RequestFlyweight { - FrameType frameType; - - RequestFlyweight(FrameType frameType) { - this.frameType = frameType; - } - - ByteBuf encode( - final ByteBufAllocator allocator, - final int streamId, - boolean fragmentFollows, - @Nullable ByteBuf metadata, - ByteBuf data) { - return encode(allocator, streamId, fragmentFollows, false, false, 0, metadata, data); - } - - ByteBuf encode( - final ByteBufAllocator allocator, - final int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - int requestN, - @Nullable ByteBuf metadata, - ByteBuf data) { - int flags = 0; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - if (fragmentFollows) { - flags |= FrameHeaderFlyweight.FLAGS_F; - } - - if (complete) { - flags |= FrameHeaderFlyweight.FLAGS_C; - } - - if (next) { - flags |= FrameHeaderFlyweight.FLAGS_N; - } - - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, frameType, flags); - - if (requestN > 0) { - header.writeInt(requestN); - } - - if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } - } - - ByteBuf data(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - int idx = byteBuf.readerIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.readerIndex(idx); - return data; - } - - ByteBuf metadata(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - ByteBuf dataWithRequestN(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } - - ByteBuf metadataWithRequestN(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - int initialRequestN(ByteBuf byteBuf) { - byteBuf.markReaderIndex(); - int i = byteBuf.skipBytes(FrameHeaderFlyweight.size()).readInt(); - byteBuf.resetReaderIndex(); - return i; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java new file mode 100644 index 000000000..66bdd46f4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java @@ -0,0 +1,30 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class RequestNFrameCodec { + private RequestNFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, long requestN) { + + if (requestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.REQUEST_N, 0); + return header.writeInt(reqN); + } + + public static long requestN(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.REQUEST_N, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i == Integer.MAX_VALUE ? Long.MAX_VALUE : i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java deleted file mode 100644 index 5a4c4c273..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class RequestNFrameFlyweight { - private RequestNFrameFlyweight() {} - - public static ByteBuf encode( - final ByteBufAllocator allocator, final int streamId, long requestN) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, reqN); - } - - public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId, int requestN) { - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.REQUEST_N, 0); - - if (requestN < 1) { - throw new IllegalArgumentException("request n is less than 1"); - } - - return header.writeInt(requestN); - } - - public static int requestN(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.REQUEST_N, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int i = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return i; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java new file mode 100644 index 000000000..4a37acfd5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java @@ -0,0 +1,37 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestResponseFrameCodec { + + private RequestResponseFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_RESPONSE, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_RESPONSE, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java deleted file mode 100644 index efbffbd40..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestResponseFrameFlyweight { - private static final RequestFlyweight FLYWEIGHT = - new RequestFlyweight(FrameType.REQUEST_RESPONSE); - - private RequestResponseFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - return encode( - allocator, streamId, fragmentFollows, payload.sliceMetadata(), payload.sliceData()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode(allocator, streamId, fragmentFollows, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java new file mode 100644 index 000000000..2f5dbf0d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java @@ -0,0 +1,64 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestStreamFrameCodec { + + private RequestStreamFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_STREAM, streamId, false, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_STREAM, + streamId, + fragmentFollows, + false, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java deleted file mode 100644 index 3e858f5d4..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java +++ /dev/null @@ -1,76 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestStreamFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_STREAM); - - private RequestStreamFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - Payload payload) { - return encode( - allocator, - streamId, - fragmentFollows, - requestN, - payload.sliceMetadata(), - payload.sliceData()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - int requestN, - Payload payload) { - return encode( - allocator, - streamId, - fragmentFollows, - requestN, - payload.sliceMetadata(), - payload.sliceData()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - ByteBuf metadata, - ByteBuf data) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, fragmentFollows, reqN, metadata, data); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - int requestN, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, false, false, requestN, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.dataWithRequestN(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadataWithRequestN(byteBuf); - } - - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFlyweight.java deleted file mode 100644 index 899957718..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFlyweight.java +++ /dev/null @@ -1,3 +0,0 @@ -package io.rsocket.frame; - -public class ResumeFlyweight {} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java new file mode 100644 index 000000000..aae89f7ab --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.UUID; + +public class ResumeFrameCodec { + static final int CURRENT_VERSION = SetupFrameCodec.CURRENT_VERSION; + + public static ByteBuf encode( + ByteBufAllocator allocator, + ByteBuf token, + long lastReceivedServerPos, + long firstAvailableClientPos) { + + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME, 0); + byteBuf.writeInt(CURRENT_VERSION); + token.markReaderIndex(); + byteBuf.writeShort(token.readableBytes()); + byteBuf.writeBytes(token); + token.resetReaderIndex(); + byteBuf.writeLong(lastReceivedServerPos); + byteBuf.writeLong(firstAvailableClientPos); + + return byteBuf; + } + + public static int version(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int version = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + + return version; + } + + public static ByteBuf token(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + ByteBuf token = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + + return token; + } + + public static long lastReceivedServerPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + long lastReceivedServerPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedServerPos; + } + + public static long firstAvailableClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + // last received server position + byteBuf.skipBytes(Long.BYTES); + long firstAvailableClientPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return firstAvailableClientPos; + } + + public static ByteBuf generateResumeToken() { + UUID uuid = UUID.randomUUID(); + ByteBuf bb = Unpooled.buffer(16); + bb.writeLong(uuid.getMostSignificantBits()); + bb.writeLong(uuid.getLeastSignificantBits()); + return bb; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFlyweight.java deleted file mode 100644 index d947a3f76..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFlyweight.java +++ /dev/null @@ -1,3 +0,0 @@ -package io.rsocket.frame; - -public class ResumeOkFlyweight {} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java new file mode 100644 index 000000000..2b6951e49 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java @@ -0,0 +1,22 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class ResumeOkFrameCodec { + + public static ByteBuf encode(final ByteBufAllocator allocator, long lastReceivedClientPos) { + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME_OK, 0); + byteBuf.writeLong(lastReceivedClientPos); + return byteBuf; + } + + public static long lastReceivedClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME_OK, byteBuf); + byteBuf.markReaderIndex(); + long lastReceivedClientPosition = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedClientPosition; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java similarity index 62% rename from rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java index 2a5992419..547e2436e 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java @@ -4,9 +4,11 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; +import io.rsocket.Payload; import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; -public class SetupFrameFlyweight { +public class SetupFrameCodec { /** * A flag used to indicate that the client requires connection resumption, if possible (the frame * contains a Resume Identification Token) @@ -16,9 +18,9 @@ public class SetupFrameFlyweight { /** A flag used to indicate that the client will honor LEASE sent by the server */ public static final int FLAGS_WILL_HONOR_LEASE = 0b00_0100_0000; - public static final int CURRENT_VERSION = VersionFlyweight.encode(1, 0); + public static final int CURRENT_VERSION = VersionCodec.encode(1, 0); - private static final int VERSION_FIELD_OFFSET = FrameHeaderFlyweight.size(); + private static final int VERSION_FIELD_OFFSET = FrameHeaderCodec.size(); private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + Integer.BYTES; private static final int KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET = KEEPALIVE_INTERVAL_FIELD_OFFSET + Integer.BYTES; @@ -27,63 +29,59 @@ public class SetupFrameFlyweight { public static ByteBuf encode( final ByteBufAllocator allocator, - boolean lease, - boolean resume, + final boolean lease, final int keepaliveInterval, final int maxLifetime, final String metadataMimeType, final String dataMimeType, - final ByteBuf metadata, - final ByteBuf data) { + final Payload setupPayload) { return encode( allocator, lease, - resume, keepaliveInterval, maxLifetime, Unpooled.EMPTY_BUFFER, metadataMimeType, dataMimeType, - metadata, - data); + setupPayload); } public static ByteBuf encode( final ByteBufAllocator allocator, - boolean lease, - boolean resume, + final boolean lease, final int keepaliveInterval, final int maxLifetime, final ByteBuf resumeToken, final String metadataMimeType, final String dataMimeType, - final ByteBuf metadata, - final ByteBuf data) { + final Payload setupPayload) { + + final ByteBuf data = setupPayload.sliceData(); + final boolean hasMetadata = setupPayload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? setupPayload.sliceMetadata() : null; int flags = 0; - if (resume) { - throw new IllegalArgumentException("RESUME_ENABLE not supported"); - } - /* - if (resume) { + if (resumeToken.readableBytes() > 0) { flags |= FLAGS_RESUME_ENABLE; - }*/ + } if (lease) { flags |= FLAGS_WILL_HONOR_LEASE; } - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; } - ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.SETUP, flags); + final ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.SETUP, flags); header.writeInt(CURRENT_VERSION).writeInt(keepaliveInterval).writeInt(maxLifetime); if ((flags & FLAGS_RESUME_ENABLE) != 0) { + resumeToken.markReaderIndex(); header.writeShort(resumeToken.readableBytes()).writeBytes(resumeToken); + resumeToken.resetReaderIndex(); } // Write metadata mime-type @@ -95,21 +93,27 @@ public static ByteBuf encode( length = ByteBufUtil.utf8Bytes(dataMimeType); header.writeByte(length); ByteBufUtil.writeUtf8(header, dataMimeType); - if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); } public static int version(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.SETUP, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.SETUP, byteBuf); byteBuf.markReaderIndex(); int version = byteBuf.skipBytes(VERSION_FIELD_OFFSET).readInt(); byteBuf.resetReaderIndex(); return version; } + public static String humanReadableVersion(ByteBuf byteBuf) { + int encodedVersion = version(byteBuf); + return VersionCodec.major(encodedVersion) + "." + VersionCodec.minor(encodedVersion); + } + + public static boolean isSupportedVersion(ByteBuf byteBuf) { + return CURRENT_VERSION == version(byteBuf); + } + public static int resumeTokenLength(ByteBuf byteBuf) { byteBuf.markReaderIndex(); int tokenLength = byteBuf.skipBytes(VARIABLE_DATA_OFFSET).readShort() & 0xFFFF; @@ -132,18 +136,43 @@ public static int keepAliveMaxLifetime(ByteBuf byteBuf) { } public static boolean honorLease(ByteBuf byteBuf) { - return (FLAGS_WILL_HONOR_LEASE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; + return (FLAGS_WILL_HONOR_LEASE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; } public static boolean resumeEnabled(ByteBuf byteBuf) { - return (FLAGS_RESUME_ENABLE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_RESUME_ENABLE; + return (FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE; + } + + public static ByteBuf resumeToken(ByteBuf byteBuf) { + if (resumeEnabled(byteBuf)) { + byteBuf.markReaderIndex(); + // header + int resumePos = + FrameHeaderCodec.size() + + + // version + Integer.BYTES + + + // keep-alive interval + Integer.BYTES + + + // keep-alive maxLifeTime + Integer.BYTES; + + int tokenLength = byteBuf.skipBytes(resumePos).readShort() & 0xFFFF; + ByteBuf resumeToken = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + return resumeToken; + } else { + return Unpooled.EMPTY_BUFFER; + } } public static String metadataMimeType(ByteBuf byteBuf) { int skip = bytesToSkipToMimeType(byteBuf); byteBuf.markReaderIndex(); - int length = byteBuf.skipBytes(skip).readByte(); - String mimeType = byteBuf.readSlice(length).toString(StandardCharsets.UTF_8); + int length = byteBuf.skipBytes(skip).readUnsignedByte(); + String mimeType = byteBuf.slice(byteBuf.readerIndex(), length).toString(StandardCharsets.UTF_8); byteBuf.resetReaderIndex(); return mimeType; } @@ -158,28 +187,32 @@ public static String dataMimeType(ByteBuf byteBuf) { return mimeType; } + @Nullable public static ByteBuf metadata(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); skipToPayload(byteBuf); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } public static ByteBuf data(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); byteBuf.markReaderIndex(); skipToPayload(byteBuf); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); byteBuf.resetReaderIndex(); return data; } private static int bytesToSkipToMimeType(ByteBuf byteBuf) { int bytesToSkip = VARIABLE_DATA_OFFSET; - if ((FLAGS_RESUME_ENABLE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { - bytesToSkip = resumeTokenLength(byteBuf) + Short.BYTES; + if ((FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { + bytesToSkip += resumeTokenLength(byteBuf) + Short.BYTES; } return bytesToSkip; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java similarity index 96% rename from rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java index e238b3fe2..35e4aa86a 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java @@ -16,7 +16,7 @@ package io.rsocket.frame; -public class VersionFlyweight { +public class VersionCodec { public static int encode(int major, int minor) { return (major << 16) | (minor & 0xFFFF); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index 692dcb363..e6874c097 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -3,8 +3,15 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; -import io.rsocket.util.ByteBufPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.DefaultPayload; import java.nio.ByteBuffer; /** Default Frame decoder that copies the frames contents for easy of use. */ @@ -14,45 +21,49 @@ class DefaultPayloadDecoder implements PayloadDecoder { public Payload apply(ByteBuf byteBuf) { ByteBuf m; ByteBuf d; - FrameType type = FrameHeaderFlyweight.frameType(byteBuf); + FrameType type = FrameHeaderCodec.frameType(byteBuf); switch (type) { case REQUEST_FNF: - d = RequestFireAndForgetFrameFlyweight.data(byteBuf); - m = RequestFireAndForgetFrameFlyweight.metadata(byteBuf); + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); break; case REQUEST_RESPONSE: - d = RequestResponseFrameFlyweight.data(byteBuf); - m = RequestResponseFrameFlyweight.metadata(byteBuf); + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); break; case REQUEST_STREAM: - d = RequestStreamFrameFlyweight.data(byteBuf); - m = RequestStreamFrameFlyweight.metadata(byteBuf); + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); break; case REQUEST_CHANNEL: - d = RequestChannelFrameFlyweight.data(byteBuf); - m = RequestChannelFrameFlyweight.metadata(byteBuf); + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); break; case NEXT: case NEXT_COMPLETE: - d = PayloadFrameFlyweight.data(byteBuf); - m = PayloadFrameFlyweight.metadata(byteBuf); + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); break; case METADATA_PUSH: d = Unpooled.EMPTY_BUFFER; - m = MetadataPushFrameFlyweight.metadata(byteBuf); + m = MetadataPushFrameCodec.metadata(byteBuf); break; default: throw new IllegalArgumentException("unsupported frame type: " + type); } - ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); ByteBuffer data = ByteBuffer.allocateDirect(d.readableBytes()); - data.put(d.nioBuffer()); data.flip(); - metadata.put(m.nioBuffer()); - metadata.flip(); - return ByteBufPayload.create(data, metadata); + if (m != null) { + ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); + metadata.put(m.nioBuffer()); + metadata.flip(); + + return DefaultPayload.create(data, metadata); + } + + return DefaultPayload.create(data); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java index 0b63590e8..3a0dc7bb5 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java @@ -3,7 +3,14 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.util.ByteBufPayload; /** @@ -15,37 +22,37 @@ public class ZeroCopyPayloadDecoder implements PayloadDecoder { public Payload apply(ByteBuf byteBuf) { ByteBuf m; ByteBuf d; - FrameType type = FrameHeaderFlyweight.frameType(byteBuf); + FrameType type = FrameHeaderCodec.frameType(byteBuf); switch (type) { case REQUEST_FNF: - d = RequestFireAndForgetFrameFlyweight.data(byteBuf); - m = RequestFireAndForgetFrameFlyweight.metadata(byteBuf); + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); break; case REQUEST_RESPONSE: - d = RequestResponseFrameFlyweight.data(byteBuf); - m = RequestResponseFrameFlyweight.metadata(byteBuf); + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); break; case REQUEST_STREAM: - d = RequestStreamFrameFlyweight.data(byteBuf); - m = RequestStreamFrameFlyweight.metadata(byteBuf); + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); break; case REQUEST_CHANNEL: - d = RequestChannelFrameFlyweight.data(byteBuf); - m = RequestChannelFrameFlyweight.metadata(byteBuf); + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); break; case NEXT: case NEXT_COMPLETE: - d = PayloadFrameFlyweight.data(byteBuf); - m = PayloadFrameFlyweight.metadata(byteBuf); + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); break; case METADATA_PUSH: d = Unpooled.EMPTY_BUFFER; - m = MetadataPushFrameFlyweight.metadata(byteBuf); + m = MetadataPushFrameCodec.metadata(byteBuf); break; default: throw new IllegalArgumentException("unsupported frame type: " + type); } - return ByteBufPayload.create(d.retain(), m.retain()); + return ByteBufPayload.create(d.retain(), m != null ? m.retain() : null); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java new file mode 100644 index 000000000..82e8acaf3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame.decoder; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java new file mode 100644 index 000000000..69f6d6860 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java new file mode 100644 index 000000000..9668e5e18 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java @@ -0,0 +1,30 @@ +package io.rsocket.internal; + +import io.rsocket.DuplexConnection; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +public abstract class BaseDuplexConnection implements DuplexConnection { + private MonoProcessor onClose = MonoProcessor.create(); + + public BaseDuplexConnection() { + onClose.doFinally(s -> doOnClose()).subscribe(); + } + + protected abstract void doOnClose(); + + @Override + public final Mono onClose() { + return onClose; + } + + @Override + public final void dispose() { + onClose.onComplete(); + } + + @Override + public final boolean isDisposed() { + return onClose.isDisposed(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index e6178bd5b..038120efc 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -17,12 +17,13 @@ package io.rsocket.internal; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.FrameType; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameUtil; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.plugins.InitializingInterceptorRegistry; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,37 +46,55 @@ */ public class ClientServerInputMultiplexer implements Closeable { private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + private static final InitializingInterceptorRegistry emptyInterceptorRegistry = + new InitializingInterceptorRegistry(); - private final DuplexConnection streamZeroConnection; + private final DuplexConnection setupConnection; private final DuplexConnection serverConnection; private final DuplexConnection clientConnection; private final DuplexConnection source; + private final DuplexConnection clientServerConnection; - public ClientServerInputMultiplexer(DuplexConnection source, PluginRegistry plugins) { + public ClientServerInputMultiplexer(DuplexConnection source) { + this(source, emptyInterceptorRegistry, false); + } + + public ClientServerInputMultiplexer( + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { this.source = source; - final MonoProcessor> streamZero = MonoProcessor.create(); + final MonoProcessor> setup = MonoProcessor.create(); final MonoProcessor> server = MonoProcessor.create(); final MonoProcessor> client = MonoProcessor.create(); - source = plugins.applyConnection(Type.SOURCE, source); - streamZeroConnection = - plugins.applyConnection(Type.STREAM_ZERO, new InternalDuplexConnection(source, streamZero)); + source = registry.initConnection(Type.SOURCE, source); + setupConnection = + registry.initConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); serverConnection = - plugins.applyConnection(Type.SERVER, new InternalDuplexConnection(source, server)); + registry.initConnection(Type.SERVER, new InternalDuplexConnection(source, server)); clientConnection = - plugins.applyConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); + registry.initConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); + clientServerConnection = new InternalDuplexConnection(source, client, server); source .receive() .groupBy( frame -> { - int streamId = FrameHeaderFlyweight.streamId(frame); + int streamId = FrameHeaderCodec.streamId(frame); final Type type; if (streamId == 0) { - if (FrameHeaderFlyweight.frameType(frame) == FrameType.SETUP) { - type = Type.STREAM_ZERO; - } else { - type = Type.CLIENT; + switch (FrameHeaderCodec.frameType(frame)) { + case SETUP: + case RESUME: + case RESUME_OK: + type = Type.SETUP; + break; + case LEASE: + case KEEPALIVE: + case ERROR: + type = isClient ? Type.CLIENT : Type.SERVER; + break; + default: + type = isClient ? Type.SERVER : Type.CLIENT; } } else if ((streamId & 0b1) == 0) { type = Type.SERVER; @@ -87,8 +106,8 @@ public ClientServerInputMultiplexer(DuplexConnection source, PluginRegistry plug .subscribe( group -> { switch (group.key()) { - case STREAM_ZERO: - streamZero.onNext(group); + case SETUP: + setup.onNext(group); break; case SERVER: @@ -100,10 +119,11 @@ public ClientServerInputMultiplexer(DuplexConnection source, PluginRegistry plug break; } }, - t -> { - LOGGER.error("Error receiving frame:", t); - dispose(); - }); + t -> {}); + } + + public DuplexConnection asClientServerConnection() { + return clientServerConnection; } public DuplexConnection asServerConnection() { @@ -114,8 +134,8 @@ public DuplexConnection asClientConnection() { return clientConnection; } - public DuplexConnection asStreamZeroConnection() { - return streamZeroConnection; + public DuplexConnection asSetupConnection() { + return setupConnection; } @Override @@ -135,20 +155,21 @@ public Mono onClose() { private static class InternalDuplexConnection implements DuplexConnection { private final DuplexConnection source; - private final MonoProcessor> processor; + private final MonoProcessor>[] processors; private final boolean debugEnabled; + @SafeVarargs public InternalDuplexConnection( - DuplexConnection source, MonoProcessor> processor) { + DuplexConnection source, MonoProcessor>... processors) { this.source = source; - this.processor = processor; + this.processors = processors; this.debugEnabled = LOGGER.isDebugEnabled(); } @Override public Mono send(Publisher frame) { if (debugEnabled) { - frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + f.toString())); + frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + FrameUtil.toString(f))); } return source.send(frame); @@ -157,7 +178,7 @@ public Mono send(Publisher frame) { @Override public Mono sendOne(ByteBuf frame) { if (debugEnabled) { - LOGGER.debug("sending -> " + frame.toString()); + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); } return source.sendOne(frame); @@ -165,14 +186,23 @@ public Mono sendOne(ByteBuf frame) { @Override public Flux receive() { - return processor.flatMapMany( - f -> { - if (debugEnabled) { - return f.doOnNext(frame -> LOGGER.debug("receiving -> " + frame.toString())); - } else { - return f; - } - }); + return Flux.fromArray(processors) + .flatMap( + p -> + p.flatMapMany( + f -> { + if (debugEnabled) { + return f.doOnNext( + frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } else { + return f; + } + })); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java deleted file mode 100755 index 17372ea01..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicBoolean; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; - -/** */ -public class LimitableRequestPublisher extends Flux implements Subscription { - private final Publisher source; - - private final AtomicBoolean canceled; - - private long internalRequested; - - private long externalRequested; - - private volatile boolean subscribed; - - private volatile @Nullable Subscription internalSubscription; - - private LimitableRequestPublisher(Publisher source) { - this.source = source; - this.canceled = new AtomicBoolean(); - } - - public static LimitableRequestPublisher wrap(Publisher source) { - return new LimitableRequestPublisher<>(source); - } - - @Override - public void subscribe(CoreSubscriber destination) { - synchronized (this) { - if (subscribed) { - throw new IllegalStateException("only one subscriber at a time"); - } - - subscribed = true; - } - - destination.onSubscribe(new InnerSubscription()); - source.subscribe(new InnerSubscriber(destination)); - } - - public void increaseRequestLimit(long n) { - synchronized (this) { - externalRequested = Operators.addCap(n, externalRequested); - } - - requestN(); - } - - @Override - public void request(long n) { - increaseRequestLimit(n); - } - - private void requestN() { - long r; - synchronized (this) { - if (internalSubscription == null) { - return; - } - - r = Math.min(internalRequested, externalRequested); - externalRequested -= r; - internalRequested -= r; - } - - if (r > 0) { - internalSubscription.request(r); - } - } - - public void cancel() { - if (canceled.compareAndSet(false, true) && internalSubscription != null) { - internalSubscription.cancel(); - internalSubscription = null; - subscribed = false; - } - } - - private class InnerSubscriber implements Subscriber { - Subscriber destination; - - private InnerSubscriber(Subscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - synchronized (LimitableRequestPublisher.this) { - LimitableRequestPublisher.this.internalSubscription = s; - - if (canceled.get()) { - s.cancel(); - subscribed = false; - LimitableRequestPublisher.this.internalSubscription = null; - } - } - - requestN(); - } - - @Override - public void onNext(T t) { - try { - destination.onNext(t); - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - destination.onError(t); - } - - @Override - public void onComplete() { - destination.onComplete(); - } - } - - private class InnerSubscription implements Subscription { - @Override - public void request(long n) { - synchronized (LimitableRequestPublisher.this) { - internalRequested = Operators.addCap(n, internalRequested); - } - - requestN(); - } - - @Override - public void cancel() { - LimitableRequestPublisher.this.cancel(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java deleted file mode 100644 index 0d2e5988e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java +++ /dev/null @@ -1,581 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.BiFunction; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.Scannable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; -import reactor.util.annotation.Nullable; -import reactor.util.context.Context; - -/** @deprecated in favour of {@link Flux#switchOnFirst(BiFunction)} */ -@Deprecated -public final class SwitchTransformFlux extends Flux { - - final Publisher source; - final BiFunction, Publisher> transformer; - - public SwitchTransformFlux( - Publisher source, BiFunction, Publisher> transformer) { - this.source = Objects.requireNonNull(source, "source"); - this.transformer = Objects.requireNonNull(transformer, "transformer"); - } - - @Override - public int getPrefetch() { - return 1; - } - - @Override - @SuppressWarnings("unchecked") - public void subscribe(CoreSubscriber actual) { - if (actual instanceof Fuseable.ConditionalSubscriber) { - source.subscribe( - new SwitchTransformConditionalOperator<>( - (Fuseable.ConditionalSubscriber) actual, transformer)); - return; - } - source.subscribe(new SwitchTransformOperator<>(actual, transformer)); - } - - static final class SwitchTransformOperator extends Flux - implements CoreSubscriber, Subscription, Scannable { - - final CoreSubscriber outer; - final BiFunction, Publisher> transformer; - - Subscription s; - Throwable throwable; - - volatile boolean done; - volatile T first; - - volatile CoreSubscriber inner; - - @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater INNER = - AtomicReferenceFieldUpdater.newUpdater( - SwitchTransformOperator.class, CoreSubscriber.class, "inner"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "wip"); - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "once"); - - SwitchTransformOperator( - CoreSubscriber outer, - BiFunction, Publisher> transformer) { - this.outer = outer; - this.transformer = transformer; - } - - @Override - @Nullable - public Object scanUnsafe(Attr key) { - if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); - if (key == Attr.PREFETCH) return 1; - - return null; - } - - @Override - public Context currentContext() { - CoreSubscriber actual = inner; - - if (actual != null) { - return actual.currentContext(); - } - - return outer.currentContext(); - } - - @Override - public void cancel() { - if (s != Operators.cancelledSubscription()) { - Subscription s = this.s; - this.s = Operators.cancelledSubscription(); - - if (WIP.getAndIncrement(this) == 0) { - INNER.lazySet(this, null); - - T f = first; - if (f != null) { - first = null; - Operators.onDiscard(f, currentContext()); - } - } - - s.cancel(); - } - } - - @Override - public void subscribe(CoreSubscriber actual) { - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - INNER.lazySet(this, actual); - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); - } - } - - @Override - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - s.request(1); - } - } - - @Override - public void onNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return; - } - - CoreSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return; - } - } - - i.onNext(t); - } - - @Override - public void onError(Throwable t) { - if (done) { - Operators.onErrorDropped(t, currentContext()); - return; - } - - throwable = t; - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.error(outer, t); - } - } - - @Override - public void onComplete() { - if (done) { - return; - } - - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.complete(outer); - } - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - if (--n > 0) { - s.request(n); - } - } else { - s.request(n); - } - } - } - - boolean drainRegular() { - if (WIP.getAndIncrement(this) != 0) { - return false; - } - - T f = first; - int m = 1; - boolean sent = false; - Subscription s = this.s; - CoreSubscriber a = inner; - - for (; ; ) { - if (f != null) { - first = null; - - if (s == Operators.cancelledSubscription()) { - Operators.onDiscard(f, a.currentContext()); - return true; - } - - a.onNext(f); - f = null; - sent = true; - } - - if (s == Operators.cancelledSubscription()) { - return sent; - } - - if (done) { - Throwable t = throwable; - if (t != null) { - a.onError(t); - } else { - a.onComplete(); - } - return sent; - } - - m = WIP.addAndGet(this, -m); - - if (m == 0) { - return sent; - } - } - } - } - - static final class SwitchTransformConditionalOperator extends Flux - implements Fuseable.ConditionalSubscriber, Subscription, Scannable { - - final Fuseable.ConditionalSubscriber outer; - final BiFunction, Publisher> transformer; - - Subscription s; - Throwable throwable; - - volatile boolean done; - volatile T first; - - volatile Fuseable.ConditionalSubscriber inner; - - @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater< - SwitchTransformConditionalOperator, Fuseable.ConditionalSubscriber> - INNER = - AtomicReferenceFieldUpdater.newUpdater( - SwitchTransformConditionalOperator.class, - Fuseable.ConditionalSubscriber.class, - "inner"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "wip"); - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "once"); - - SwitchTransformConditionalOperator( - Fuseable.ConditionalSubscriber outer, - BiFunction, Publisher> transformer) { - this.outer = outer; - this.transformer = transformer; - } - - @Override - @Nullable - public Object scanUnsafe(Attr key) { - if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); - if (key == Attr.PREFETCH) return 1; - - return null; - } - - @Override - public Context currentContext() { - CoreSubscriber actual = inner; - - if (actual != null) { - return actual.currentContext(); - } - - return outer.currentContext(); - } - - @Override - public void cancel() { - if (s != Operators.cancelledSubscription()) { - Subscription s = this.s; - this.s = Operators.cancelledSubscription(); - - if (WIP.getAndIncrement(this) == 0) { - INNER.lazySet(this, null); - - T f = first; - if (f != null) { - first = null; - Operators.onDiscard(f, currentContext()); - } - } - - s.cancel(); - } - } - - @Override - @SuppressWarnings("unchecked") - public void subscribe(CoreSubscriber actual) { - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - if (actual instanceof Fuseable.ConditionalSubscriber) { - INNER.lazySet(this, (Fuseable.ConditionalSubscriber) actual); - } else { - INNER.lazySet(this, new ConditionalSubscriberAdapter<>(actual)); - } - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); - } - } - - @Override - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - s.request(1); - } - } - - @Override - public void onNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return; - } - - CoreSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return; - } - } - - i.onNext(t); - } - - @Override - public boolean tryOnNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return false; - } - - Fuseable.ConditionalSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return true; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return false; - } - } - - return i.tryOnNext(t); - } - - @Override - public void onError(Throwable t) { - if (done) { - Operators.onErrorDropped(t, currentContext()); - return; - } - - throwable = t; - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.error(outer, t); - } - } - - @Override - public void onComplete() { - if (done) { - return; - } - - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.complete(outer); - } - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - if (--n > 0) { - s.request(n); - } - } else { - s.request(n); - } - } - } - - boolean drainRegular() { - if (WIP.getAndIncrement(this) != 0) { - return false; - } - - T f = first; - int m = 1; - boolean sent = false; - Subscription s = this.s; - CoreSubscriber a = inner; - - for (; ; ) { - if (f != null) { - first = null; - - if (s == Operators.cancelledSubscription()) { - Operators.onDiscard(f, a.currentContext()); - return true; - } - - a.onNext(f); - f = null; - sent = true; - } - - if (s == Operators.cancelledSubscription()) { - return sent; - } - - if (done) { - Throwable t = throwable; - if (t != null) { - a.onError(t); - } else { - a.onComplete(); - } - return sent; - } - - m = WIP.addAndGet(this, -m); - - if (m == 0) { - return sent; - } - } - } - } - - static final class ConditionalSubscriberAdapter implements Fuseable.ConditionalSubscriber { - - final CoreSubscriber delegate; - - ConditionalSubscriberAdapter(CoreSubscriber delegate) { - this.delegate = delegate; - } - - @Override - public Context currentContext() { - return delegate.currentContext(); - } - - @Override - public void onSubscribe(Subscription s) { - delegate.onSubscribe(s); - } - - @Override - public void onNext(T t) { - delegate.onNext(t); - } - - @Override - public void onError(Throwable t) { - delegate.onError(t); - } - - @Override - public void onComplete() { - delegate.onComplete(); - } - - @Override - public boolean tryOnNext(T t) { - delegate.onNext(t); - return true; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java b/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java new file mode 100644 index 000000000..fd6bf0aed --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/SynchronizedIntObjectHashMap.java @@ -0,0 +1,748 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.rsocket.internal; + +import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; + +import io.netty.util.collection.IntObjectMap; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A hash map implementation of {@link IntObjectMap} that uses open addressing for keys. To minimize + * the memory footprint, this class uses open addressing rather than chaining. Collisions are + * resolved using linear probing. Deletions implement compaction, so cost of remove can approach + * O(N) for full maps, which makes a small loadFactor recommended. + * + * @param The value type stored in the map. + */ +public class SynchronizedIntObjectHashMap implements IntObjectMap { + + /** Default initial capacity. Used if not specified in the constructor */ + public static final int DEFAULT_CAPACITY = 8; + + /** Default load factor. Used if not specified in the constructor */ + public static final float DEFAULT_LOAD_FACTOR = 0.5f; + + /** + * Placeholder for null values, so we can use the actual null to mean available. (Better than + * using a placeholder for available: less references for GC processing.) + */ + private static final Object NULL_VALUE = new Object(); + + /** The maximum number of elements allowed without allocating more space. */ + private int maxSize; + + /** The load factor for the map. Used to calculate {@link #maxSize}. */ + private final float loadFactor; + + private int[] keys; + private V[] values; + private int size; + private int mask; + + private final Set keySet = new KeySet(); + private final Set> entrySet = new EntrySet(); + private final Iterable> entries = PrimitiveIterator::new; + + public SynchronizedIntObjectHashMap() { + this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + public SynchronizedIntObjectHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public SynchronizedIntObjectHashMap(int initialCapacity, float loadFactor) { + if (loadFactor <= 0.0f || loadFactor > 1.0f) { + // Cannot exceed 1 because we can never store more than capacity elements; + // using a bigger loadFactor would trigger rehashing before the desired load is reached. + throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); + } + + this.loadFactor = loadFactor; + + // Adjust the initial capacity if necessary. + int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); + mask = capacity - 1; + + // Allocate the arrays. + keys = new int[capacity]; + @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) + V[] temp = (V[]) new Object[capacity]; + values = temp; + + // Initialize the maximum size value. + maxSize = calcMaxSize(capacity); + } + + private static T toExternal(T value) { + assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; + return value == NULL_VALUE ? null : value; + } + + @SuppressWarnings("unchecked") + private static T toInternal(T value) { + return value == null ? (T) NULL_VALUE : value; + } + + public synchronized V[] getValuesCopy() { + V[] values = this.values; + return Arrays.copyOf(values, values.length); + } + + @Override + public synchronized V get(int key) { + int index = indexOf(key); + return index == -1 ? null : toExternal(values[index]); + } + + @Override + public synchronized V put(int key, V value) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (; ; ) { + if (values[index] == null) { + // Found empty slot, use it. + keys[index] = key; + values[index] = toInternal(value); + growSize(); + return null; + } + if (keys[index] == key) { + // Found existing entry with this key, just replace the value. + V previousValue = values[index]; + values[index] = toInternal(value); + return toExternal(previousValue); + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. + throw new IllegalStateException("Unable to insert"); + } + } + } + + @Override + public synchronized void putAll(Map sourceMap) { + if (sourceMap instanceof SynchronizedIntObjectHashMap) { + // Optimization - iterate through the arrays. + @SuppressWarnings("unchecked") + SynchronizedIntObjectHashMap source = (SynchronizedIntObjectHashMap) sourceMap; + for (int i = 0; i < source.values.length; ++i) { + V sourceValue = source.values[i]; + if (sourceValue != null) { + put(source.keys[i], sourceValue); + } + } + return; + } + + // Otherwise, just add each entry. + for (Entry entry : sourceMap.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public synchronized V remove(int key) { + int index = indexOf(key); + if (index == -1) { + return null; + } + + V prev = values[index]; + removeAt(index); + return toExternal(prev); + } + + @Override + public synchronized int size() { + return size; + } + + @Override + public synchronized boolean isEmpty() { + return size == 0; + } + + @Override + public synchronized void clear() { + Arrays.fill(keys, 0); + Arrays.fill(values, null); + size = 0; + } + + @Override + public synchronized boolean containsKey(int key) { + return indexOf(key) >= 0; + } + + @Override + public synchronized boolean containsValue(Object value) { + @SuppressWarnings("unchecked") + V v1 = toInternal((V) value); + for (V v2 : values) { + // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). + if (v2 != null && v2.equals(v1)) { + return true; + } + } + return false; + } + + @Override + public synchronized Iterable> entries() { + return entries; + } + + @Override + public synchronized Collection values() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + return iter.next().value(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public synchronized int hashCode() { + // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys + // array, which may have different lengths for two maps of same size(), so the + // capacity cannot be used as input for hashing but the size can. + int hash = size; + for (int key : keys) { + // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. + // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, + // or the devastatingly bad memory locality of visiting value objects. + // Also, it's important to use a hash function that does not depend on the ordering + // of terms, only their values; since the map is an unordered collection and + // entries can end up in different positions in different maps that have the same + // elements, but with different history of puts/removes, due to conflicts. + hash ^= hashCode(key); + } + return hash; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IntObjectMap)) { + return false; + } + @SuppressWarnings("rawtypes") + IntObjectMap other = (IntObjectMap) obj; + synchronized (this) { + if (size != other.size()) { + return false; + } + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + int key = keys[i]; + Object otherValue = other.get(key); + if (value == NULL_VALUE) { + if (otherValue != null) { + return false; + } + } else if (!value.equals(otherValue)) { + return false; + } + } + } + } + return true; + } + + @Override + public synchronized boolean containsKey(Object key) { + return containsKey(objectToKey(key)); + } + + @Override + public synchronized V get(Object key) { + return get(objectToKey(key)); + } + + @Override + public synchronized V put(Integer key, V value) { + return put(objectToKey(key), value); + } + + @Override + public synchronized V remove(Object key) { + return remove(objectToKey(key)); + } + + @Override + public synchronized Set keySet() { + return keySet; + } + + @Override + public synchronized Set> entrySet() { + return entrySet; + } + + private int objectToKey(Object key) { + return ((Integer) key).intValue(); + } + + /** + * Locates the index for the given key. This method probes using double hashing. + * + * @param key the key for an entry in the map. + * @return the index where the key was found, or {@code -1} if no entry is found for that key. + */ + private int indexOf(int key) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (; ; ) { + if (values[index] == null) { + // It's available, so no chance that this value exists anywhere in the map. + return -1; + } + if (key == keys[index]) { + return index; + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + return -1; + } + } + } + + /** Returns the hashed index for the given key. */ + private int hashIndex(int key) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array + // bounds. + return hashCode(key) & mask; + } + + /** Returns the hash code for the key. */ + private static int hashCode(int key) { + return key; + } + + /** Get the next sequential index after {@code index} and wraps if necessary. */ + private int probeNext(int index) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array + // bounds. + return (index + 1) & mask; + } + + /** Grows the map size after an insertion. If necessary, performs a rehash of the map. */ + private void growSize() { + size++; + + if (size > maxSize) { + if (keys.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Max capacity reached at size=" + size); + } + + // Double the capacity. + rehash(keys.length << 1); + } + } + + /** + * Removes entry at the given index position. Also performs opportunistic, incremental rehashing + * if necessary to not break conflict chains. + * + * @param index the index position of the element to remove. + * @return {@code true} if the next item was moved back. {@code false} otherwise. + */ + private boolean removeAt(final int index) { + --size; + // Clearing the key is not strictly necessary (for GC like in a regular collection), + // but recommended for security. The memory location is still fresh in the cache anyway. + keys[index] = 0; + values[index] = null; + + // In the interval from index to the next available entry, the arrays may have entries + // that are displaced from their base position due to prior conflicts. Iterate these + // entries and move them back if possible, optimizing future lookups. + // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. + + int nextFree = index; + int i = probeNext(index); + for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { + int key = keys[i]; + int bucket = hashIndex(key); + if (i < bucket && (bucket <= nextFree || nextFree <= i) + || bucket <= nextFree && nextFree <= i) { + // Move the displaced entry "back" to the first available position. + keys[nextFree] = key; + values[nextFree] = value; + // Put the first entry after the displaced entry + keys[i] = 0; + values[i] = null; + nextFree = i; + } + } + return nextFree != index; + } + + /** Calculates the maximum size allowed before rehashing. */ + private int calcMaxSize(int capacity) { + // Clip the upper bound so that there will always be at least one available slot. + int upperBound = capacity - 1; + return Math.min(upperBound, (int) (capacity * loadFactor)); + } + + /** + * Rehashes the map for the given capacity. + * + * @param newCapacity the new capacity for the map. + */ + private void rehash(int newCapacity) { + int[] oldKeys = keys; + V[] oldVals = values; + + keys = new int[newCapacity]; + @SuppressWarnings({"unchecked", "SuspiciousArrayCast"}) + V[] temp = (V[]) new Object[newCapacity]; + values = temp; + + maxSize = calcMaxSize(newCapacity); + mask = newCapacity - 1; + + // Insert to the new arrays. + for (int i = 0; i < oldVals.length; ++i) { + V oldVal = oldVals[i]; + if (oldVal != null) { + // Inlined put(), but much simpler: we don't need to worry about + // duplicated keys, growing/rehashing, or failing to insert. + int oldKey = oldKeys[i]; + int index = hashIndex(oldKey); + + for (; ; ) { + if (values[index] == null) { + keys[index] = oldKey; + values[index] = oldVal; + break; + } + + // Conflict, keep probing. Can wrap around, but never reaches startIndex again. + index = probeNext(index); + } + } + } + } + + @Override + public synchronized String toString() { + if (isEmpty()) { + return "{}"; + } + StringBuilder sb = new StringBuilder(4 * size); + sb.append('{'); + boolean first = true; + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + if (!first) { + sb.append(", "); + } + sb.append(keyToString(keys[i])) + .append('=') + .append(value == this ? "(this Map)" : toExternal(value)); + first = false; + } + } + return sb.append('}').toString(); + } + + /** + * Helper method called by {@link #toString()} in order to convert a single map key into a string. + * This is protected to allow subclasses to override the appearance of a given key. + */ + protected String keyToString(int key) { + return Integer.toString(key); + } + + /** Set implementation for iterating over the entries of the map. */ + private final class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new MapIterator(); + } + + @Override + public int size() { + return SynchronizedIntObjectHashMap.this.size(); + } + } + + /** Set implementation for iterating over the keys. */ + private final class KeySet extends AbstractSet { + @Override + public int size() { + return SynchronizedIntObjectHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return SynchronizedIntObjectHashMap.this.containsKey(o); + } + + @Override + public boolean remove(Object o) { + return SynchronizedIntObjectHashMap.this.remove(o) != null; + } + + @Override + public boolean retainAll(Collection retainedKeys) { + synchronized (SynchronizedIntObjectHashMap.this) { + boolean changed = false; + for (Iterator> iter = entries().iterator(); iter.hasNext(); ) { + PrimitiveEntry entry = iter.next(); + if (!retainedKeys.contains(entry.key())) { + changed = true; + iter.remove(); + } + } + return changed; + } + } + + @Override + public void clear() { + SynchronizedIntObjectHashMap.this.clear(); + } + + @Override + public Iterator iterator() { + synchronized (SynchronizedIntObjectHashMap.this) { + final Iterator> iter = entrySet.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + synchronized (SynchronizedIntObjectHashMap.this) { + return iter.hasNext(); + } + } + + @Override + public Integer next() { + synchronized (SynchronizedIntObjectHashMap.this) { + return iter.next().getKey(); + } + } + + @Override + public void remove() { + synchronized (SynchronizedIntObjectHashMap.this) { + iter.remove(); + } + } + }; + } + } + } + + /** + * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link + * #next()}. + */ + private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { + private int prevIndex = -1; + private int nextIndex = -1; + private int entryIndex = -1; + + private void scanNext() { + while (++nextIndex != values.length && values[nextIndex] == null) {} + } + + @Override + public boolean hasNext() { + synchronized (SynchronizedIntObjectHashMap.this) { + if (nextIndex == -1) { + scanNext(); + } + return nextIndex != values.length; + } + } + + @Override + public PrimitiveEntry next() { + synchronized (SynchronizedIntObjectHashMap.this) { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + prevIndex = nextIndex; + scanNext(); + + // Always return the same Entry object, just change its index each time. + entryIndex = prevIndex; + return this; + } + } + + @Override + public void remove() { + synchronized (SynchronizedIntObjectHashMap.this) { + if (prevIndex == -1) { + throw new IllegalStateException("next must be called before each remove."); + } + if (removeAt(prevIndex)) { + // removeAt may move elements "back" in the array if they have been displaced because + // their + // spot in the + // array was occupied when they were inserted. If this occurs then the nextIndex is now + // invalid and + // should instead point to the prevIndex which now holds an element which was "moved + // back". + nextIndex = prevIndex; + } + prevIndex = -1; + } + } + + // Entry implementation. Since this implementation uses a single Entry, we coalesce that + // into the Iterator object (potentially making loop optimization much easier). + + @Override + public int key() { + synchronized (SynchronizedIntObjectHashMap.this) { + return keys[entryIndex]; + } + } + + @Override + public V value() { + synchronized (SynchronizedIntObjectHashMap.this) { + return toExternal(values[entryIndex]); + } + } + + @Override + public void setValue(V value) { + synchronized (SynchronizedIntObjectHashMap.this) { + values[entryIndex] = toInternal(value); + } + } + } + + /** Iterator used by the {@link Map} interface. */ + private final class MapIterator implements Iterator> { + private final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + synchronized (SynchronizedIntObjectHashMap.this) { + return iter.hasNext(); + } + } + + @Override + public Entry next() { + synchronized (SynchronizedIntObjectHashMap.this) { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + iter.next(); + + return new MapEntry(iter.entryIndex); + } + } + + @Override + public void remove() { + synchronized (SynchronizedIntObjectHashMap.this) { + iter.remove(); + } + } + } + + /** A single entry in the map. */ + final class MapEntry implements Entry { + private final int entryIndex; + + MapEntry(int entryIndex) { + this.entryIndex = entryIndex; + } + + @Override + public Integer getKey() { + synchronized (SynchronizedIntObjectHashMap.this) { + verifyExists(); + return keys[entryIndex]; + } + } + + @Override + public V getValue() { + synchronized (SynchronizedIntObjectHashMap.this) { + verifyExists(); + return toExternal(values[entryIndex]); + } + } + + @Override + public V setValue(V value) { + synchronized (SynchronizedIntObjectHashMap.this) { + verifyExists(); + V prevValue = toExternal(values[entryIndex]); + values[entryIndex] = toInternal(value); + return prevValue; + } + } + + private void verifyExists() { + if (values[entryIndex] == null) { + throw new IllegalStateException("The map entry has been removed"); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index 607a7ec73..cb8b5d63d 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -16,7 +16,8 @@ package io.rsocket.internal; -import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; import java.util.Objects; import java.util.Queue; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; @@ -42,40 +43,58 @@ public final class UnboundedProcessor extends FluxProcessor implements Fuseable.QueueSubscription, Fuseable { + final Queue queue; + final Queue priorityQueue; + + volatile boolean done; + Throwable error; + // important to not loose the downstream too early and miss discard hook, while + // having relevant hasDownstreams() + boolean hasDownstream; + volatile CoreSubscriber actual; + + volatile boolean cancelled; + + volatile int once; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater ONCE = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); + volatile int wip; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + volatile int discardGuard; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater DISCARD_GUARD = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); + + volatile long requested; + @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); - final Queue queue; - volatile boolean done; - Throwable error; - volatile CoreSubscriber actual; - volatile boolean cancelled; - volatile int once; - volatile int wip; - volatile long requested; - volatile boolean outputFused; + boolean outputFused; public UnboundedProcessor() { - this.queue = Queues.unboundedMultiproducer().get(); + this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); } @Override public int getBufferSize() { - return Queues.capacity(this.queue); + return Integer.MAX_VALUE; } @Override public Object scanUnsafe(Attr key) { if (Attr.BUFFERED == key) return queue.size(); + if (Attr.PREFETCH == key) return Integer.MAX_VALUE; return super.scanUnsafe(key); } @@ -83,6 +102,7 @@ void drainRegular(Subscriber a) { int missed = 1; final Queue q = queue; + final Queue pq = priorityQueue; for (; ; ) { @@ -92,10 +112,18 @@ void drainRegular(Subscriber a) { while (r != e) { boolean d = done; - T t = q.poll(); - boolean empty = t == null; + T t; + boolean empty; + + if (!pq.isEmpty()) { + t = pq.poll(); + empty = false; + } else { + t = q.poll(); + empty = t == null; + } - if (checkTerminated(d, empty, a, q)) { + if (checkTerminated(d, empty, a)) { return; } @@ -109,7 +137,7 @@ void drainRegular(Subscriber a) { } if (r == e) { - if (checkTerminated(done, q.isEmpty(), a, q)) { + if (checkTerminated(done, q.isEmpty() && pq.isEmpty(), a)) { return; } } @@ -128,13 +156,11 @@ void drainRegular(Subscriber a) { void drainFused(Subscriber a) { int missed = 1; - final Queue q = queue; - for (; ; ) { if (cancelled) { - q.clear(); - actual = null; + this.clear(); + hasDownstream = false; return; } @@ -143,7 +169,7 @@ void drainFused(Subscriber a) { a.onNext(null); if (d) { - actual = null; + hasDownstream = false; Throwable ex = error; if (ex != null) { @@ -163,6 +189,9 @@ void drainFused(Subscriber a) { public void drain() { if (WIP.getAndIncrement(this) != 0) { + if (cancelled) { + this.clear(); + } return; } @@ -187,20 +216,15 @@ public void drain() { } } - boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q) { + boolean checkTerminated(boolean d, boolean empty, Subscriber a) { if (cancelled) { - while (!q.isEmpty()) { - T t = q.poll(); - if (t != null) { - ReferenceCountUtil.safeRelease(t); - } - } - actual = null; + this.clear(); + hasDownstream = false; return true; } if (d && empty) { Throwable e = error; - actual = null; + hasDownstream = false; if (e != null) { a.onError(e); } else { @@ -232,11 +256,28 @@ public Context currentContext() { return actual != null ? actual.currentContext() : Context.empty(); } + public void onNextPrioritized(T t) { + if (done || cancelled) { + Operators.onNextDropped(t, currentContext()); + release(t); + return; + } + + if (!priorityQueue.offer(t)) { + Throwable ex = + Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); + onError(Operators.onOperatorError(null, ex, t, currentContext())); + release(t); + return; + } + drain(); + } + @Override public void onNext(T t) { if (done || cancelled) { Operators.onNextDropped(t, currentContext()); - ReferenceCountUtil.safeRelease(t); + release(t); return; } @@ -244,7 +285,7 @@ public void onNext(T t) { Throwable ex = Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); onError(Operators.onOperatorError(null, ex, t, currentContext())); - ReferenceCountUtil.safeRelease(t); + release(t); return; } drain(); @@ -282,7 +323,7 @@ public void subscribe(CoreSubscriber actual) { actual.onSubscribe(this); this.actual = actual; if (cancelled) { - this.actual = null; + this.hasDownstream = false; } else { drain(); } @@ -309,38 +350,56 @@ public void cancel() { cancelled = true; if (WIP.getAndIncrement(this) == 0) { - clear(); - actual = null; + this.clear(); + hasDownstream = false; } } - @Override - public T peek() { - return queue.peek(); - } - @Override @Nullable public T poll() { + Queue pq = this.priorityQueue; + if (!pq.isEmpty()) { + return pq.poll(); + } return queue.poll(); } @Override public int size() { - return queue.size(); + return priorityQueue.size() + queue.size(); } @Override public boolean isEmpty() { - return queue.isEmpty(); + return priorityQueue.isEmpty() && queue.isEmpty(); } @Override public void clear() { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - ReferenceCountUtil.safeRelease(t); + if (DISCARD_GUARD.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + + for (; ; ) { + while (!queue.isEmpty()) { + T t = queue.poll(); + if (t != null) { + release(t); + } + } + while (!priorityQueue.isEmpty()) { + T t = priorityQueue.poll(); + if (t != null) { + release(t); + } + } + + missed = DISCARD_GUARD.addAndGet(this, -missed); + if (missed == 0) { + break; } } } @@ -382,6 +441,19 @@ public long downstreamCount() { @Override public boolean hasDownstreams() { - return actual != null; + return hasDownstream; + } + + void release(T t) { + if (t instanceof ReferenceCounted) { + ReferenceCounted refCounted = (ReferenceCounted) t; + if (refCounted.refCnt() > 0) { + try { + refCounted.release(); + } catch (Throwable ex) { + // no ops + } + } + } } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java deleted file mode 100644 index 1e616b427..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java +++ /dev/null @@ -1,190 +0,0 @@ -package io.rsocket.internal; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.function.LongSupplier; -import java.util.stream.Stream; -import org.reactivestreams.Processor; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Disposable; -import reactor.core.Scannable; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.Operators; -import reactor.util.annotation.Nullable; -import reactor.util.context.Context; -import reactor.util.function.Tuple2; - -public class UnicastMonoProcessor extends Mono - implements Processor, - CoreSubscriber, - Disposable, - Subscription, - Scannable, - LongSupplier { - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(UnicastMonoProcessor.class, "once"); - - private final MonoProcessor processor; - - @SuppressWarnings("unused") - private volatile int once; - - private UnicastMonoProcessor() { - this.processor = MonoProcessor.create(); - } - - public static UnicastMonoProcessor create() { - return new UnicastMonoProcessor<>(); - } - - @Override - public Stream actuals() { - return processor.actuals(); - } - - @Override - public boolean isScanAvailable() { - return processor.isScanAvailable(); - } - - @Override - public String name() { - return processor.name(); - } - - @Override - public String stepName() { - return processor.stepName(); - } - - @Override - public Stream steps() { - return processor.steps(); - } - - @Override - public Stream parents() { - return processor.parents(); - } - - @Override - @Nullable - public T scan(Attr key) { - return processor.scan(key); - } - - @Override - public T scanOrDefault(Attr key, T defaultValue) { - return processor.scanOrDefault(key, defaultValue); - } - - @Override - public Stream> tags() { - return processor.tags(); - } - - @Override - public long getAsLong() { - return processor.getAsLong(); - } - - @Override - public void onSubscribe(Subscription s) { - processor.onSubscribe(s); - } - - @Override - public void onNext(O o) { - processor.onNext(o); - } - - @Override - public void onError(Throwable t) { - processor.onError(t); - } - - @Nullable - public Throwable getError() { - return processor.getError(); - } - - public boolean isCancelled() { - return processor.isCancelled(); - } - - public boolean isError() { - return processor.isError(); - } - - public boolean isSuccess() { - return processor.isSuccess(); - } - - public boolean isTerminated() { - return processor.isTerminated(); - } - - @Nullable - public O peek() { - return processor.peek(); - } - - public long downstreamCount() { - return processor.downstreamCount(); - } - - public boolean hasDownstreams() { - return processor.hasDownstreams(); - } - - @Override - public void onComplete() { - processor.onComplete(); - } - - @Override - public void request(long n) { - processor.request(n); - } - - @Override - public void cancel() { - processor.cancel(); - } - - @Override - public void dispose() { - processor.dispose(); - } - - @Override - public Context currentContext() { - return processor.currentContext(); - } - - @Override - public boolean isDisposed() { - return processor.isDisposed(); - } - - @Override - public Object scanUnsafe(Attr key) { - return processor.scanUnsafe(key); - } - - @Override - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - processor.subscribe(actual); - } else { - Operators.error( - actual, - new IllegalStateException("UnicastMonoProcessor allows only a single Subscriber")); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java new file mode 100644 index 000000000..6939b0f7a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java @@ -0,0 +1,258 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; + +import java.util.AbstractQueue; +import java.util.Iterator; + +abstract class BaseLinkedQueuePad0 extends AbstractQueue implements MessagePassingQueue { + long p00, p01, p02, p03, p04, p05, p06, p07; + long p10, p11, p12, p13, p14, p15, p16; +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 { + static final long P_NODE_OFFSET = + fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); + + private LinkedQueueNode producerNode; + + final void spProducerNode(LinkedQueueNode newValue) { + producerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvProducerNode() { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, P_NODE_OFFSET); + } + + @SuppressWarnings("unchecked") + final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { + return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); + } + + final LinkedQueueNode lpProducerNode() { + return producerNode; + } +} + +abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef { + long p01, p02, p03, p04, p05, p06, p07; + long p10, p11, p12, p13, p14, p15, p16, p17; +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueConsumerNodeRef extends BaseLinkedQueuePad1 { + private static final long C_NODE_OFFSET = + fieldOffset(BaseLinkedQueueConsumerNodeRef.class, "consumerNode"); + + private LinkedQueueNode consumerNode; + + final void spConsumerNode(LinkedQueueNode newValue) { + consumerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvConsumerNode() { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, C_NODE_OFFSET); + } + + final LinkedQueueNode lpConsumerNode() { + return consumerNode; + } +} + +abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef { + long p01, p02, p03, p04, p05, p06, p07; + long p10, p11, p12, p13, p14, p15, p16, p17; +} + +/** + * A base data structure for concurrent linked queues. For convenience also pulled in common single + * consumer methods since at this time there's no plan to implement MC. + * + * @param + * @author nitsanw + */ +abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 { + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + protected final LinkedQueueNode newNode() { + return new LinkedQueueNode(); + } + + protected final LinkedQueueNode newNode(E e) { + return new LinkedQueueNode(e); + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+ * The accuracy of the value returned by this method is subject to races with producer/consumer + * threads. In particular when racing with the consumer thread this method may under estimate the + * size.
+ * + * @see java.util.Queue#size() + */ + @Override + public final int size() { + // Read consumer first, this is important because if the producer is node is 'older' than the + // consumer + // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size. + LinkedQueueNode chaserNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + int size = 0; + // must chase the nodes all the way to the producer node, but there's no need to count beyond + // expected head. + while (chaserNode != producerNode + && // don't go passed producer node + chaserNode != null + && // stop at last node + size < Integer.MAX_VALUE) // stop at max int + { + LinkedQueueNode next; + next = chaserNode.lvNext(); + // check if this node has been consumed, if so return what we have + if (next == chaserNode) { + return size; + } + chaserNode = next; + size++; + } + return size; + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * Queue is empty when producerNode is the same as consumerNode. An alternative implementation + * would be to observe the producerNode.value is null, which also means an empty queue because + * only the consumerNode.value is allowed to be null. + * + * @see MessagePassingQueue#isEmpty() + */ + @Override + public final boolean isEmpty() { + return lvConsumerNode() == lvProducerNode(); + } + + protected E getSingleConsumerNodeValue( + LinkedQueueNode currConsumerNode, LinkedQueueNode nextNode) { + // we have to null out the value because we are going to hang on to the node + final E nextValue = nextNode.getAndNullValue(); + + // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones + // alive. + // We use a reference to self instead of null because null is already a meaningful value (the + // next of + // producer node is null). + currConsumerNode.soNext(currConsumerNode); + spConsumerNode(nextNode); + // currConsumerNode is now no longer referenced and can be collected + return nextValue; + } + + @Override + public E relaxedPoll() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + final LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() { + final LinkedQueueNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public int drain(Consumer c) { + long result = 0; // use long to force safepoint into loop below + int drained; + do { + drained = drain(c, 4096); + result += drained; + } while (drained == 4096 && result <= Integer.MAX_VALUE - 4096); + return (int) result; + } + + @Override + public int drain(Consumer c, int limit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + + if (nextNode == null) { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + int idleCounter = 0; + while (exit.keepRunning()) { + for (int i = 0; i < 4096; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + if (nextNode == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + + idleCounter = 0; + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + } + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java new file mode 100644 index 000000000..635779df3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -0,0 +1,663 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.CircularArrayOffsetCalculator.allocate; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcElementOffset; +import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.calcElementOffset; +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.lvElement; +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.soElement; + +import io.rsocket.internal.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import io.rsocket.internal.jctools.util.PortableJvmInfo; +import io.rsocket.internal.jctools.util.Pow2; +import io.rsocket.internal.jctools.util.RangeUtil; +import java.util.AbstractQueue; +import java.util.Iterator; + +abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { + long p01, p02, p03, p04, p05, p06, p07; + long p10, p11, p12, p13, p14, p15, p16, p17; +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueProducerFields extends BaseMpscLinkedArrayQueuePad1 { + private static final long P_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields { + long p01, p02, p03, p04, p05, p06, p07; + long p10, p11, p12, p13, p14, p15, p16, p17; +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueConsumerFields extends BaseMpscLinkedArrayQueuePad2 { + private static final long C_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueConsumerFields.class, "consumerIndex"); + + private volatile long consumerIndex; + protected long consumerMask; + protected E[] consumerBuffer; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields { + long p0, p1, p2, p3, p4, p5, p6, p7; + long p10, p11, p12, p13, p14, p15, p16, p17; +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueColdProducerFields + extends BaseMpscLinkedArrayQueuePad3 { + private static final long P_LIMIT_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueColdProducerFields.class, "producerLimit"); + + private volatile long producerLimit; + protected long producerMask; + protected E[] producerBuffer; + + final long lvProducerLimit() { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +/** + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in + * linked chunks of the initial size. The queue grows only when the current buffer is full and + * elements are not copied on resize, instead a link to the new buffer is stored in the old buffer + * for the consumer to follow.
+ * + * @param + */ +public abstract class BaseMpscLinkedArrayQueue + extends BaseMpscLinkedArrayQueueColdProducerFields + implements MessagePassingQueue, QueueProgressIndicators { + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + private static final Object BUFFER_CONSUMED = new Object(); + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + private static final int RETRY = 1; + private static final int QUEUE_FULL = 2; + private static final int QUEUE_RESIZE = 3; + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the + * chunk size. Must be 2 or more. + */ + public BaseMpscLinkedArrayQueue(final int initialCapacity) { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + E[] buffer = allocate(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + soProducerLimit(mask); // we know it's all empty to start with + } + + @Override + public final int size() { + // NOTE: because indices are on even numbers we cannot use the size util. + + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = lvProducerIndex(); + after = lvConsumerIndex(); + if (before == after) { + size = ((currentProducerIndex - after) >> 1); + break; + } + } + // Long overflow is impossible, so size is always positive. Integer overflow is possible for the + // unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + @Override + public final boolean isEmpty() { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (this.lvConsumerIndex() == this.lvProducerIndex()); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + + long mask; + E[] buffer; + long pIndex; + + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + + if (casProducerIndex(pIndex, pIndex + 2)) { + break; + } + } + // INDEX visible before ELEMENT + final long offset = modifiedCalcElementOffset(pIndex, mask); + soElement(buffer, offset, e); // release element e + return true; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcElementOffset(index, mask); + Object e = lvElement(buffer, offset); // LoadLoad + if (e == null) { + if (index != lvProducerIndex()) { + // poll() == null iff queue is empty, null element is not strong enough indicator, so we + // must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvElement(buffer, offset); + } while (e == null); + } else { + return null; + } + } + + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + + soElement(buffer, offset, null); // release element null + soConsumerIndex(index + 2); // release cIndex + return (E) e; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcElementOffset(index, mask); + Object e = lvElement(buffer, offset); // LoadLoad + if (e == null && index != lvProducerIndex()) { + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + /** We do not inline resize into this method because we do not resize on fill. */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + + if (cIndex + bufferCapacity > pIndex) { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) { + // retry from top + return RETRY; + } else { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } + // full and cannot grow + else if (availableInQueue(pIndex, cIndex) <= 0) { + // offer should return false; + return QUEUE_FULL; + } + // grab index for resize -> set lower bit + else if (casProducerIndex(pIndex, pIndex + 1)) { + // trigger a resize + return QUEUE_RESIZE; + } else { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** @return available elements in queue * 2 */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private E[] nextBuffer(final E[] buffer, final long mask) { + final long offset = nextArrayOffset(mask); + final E[] nextBuffer = (E[]) lvElement(buffer, offset); + consumerBuffer = nextBuffer; + consumerMask = (length(nextBuffer) - 2) << 1; + soElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private long nextArrayOffset(long mask) { + return modifiedCalcElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(E[] nextBuffer, long index) { + final long offset = modifiedCalcElementOffset(index, consumerMask); + final E n = lvElement(nextBuffer, offset); // LoadLoad + if (n == null) { + throw new IllegalStateException("new buffer must have at least one element"); + } + soElement(nextBuffer, offset, null); // StoreStore + soConsumerIndex(index + 2); + return n; + } + + private E newBufferPeek(E[] nextBuffer, long index) { + final long offset = modifiedCalcElementOffset(index, consumerMask); + final E n = lvElement(nextBuffer, offset); // LoadLoad + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcElementOffset(index, mask); + Object e = lvElement(buffer, offset); // LoadLoad + if (e == null) { + return null; + } + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + soElement(buffer, offset, null); + soConsumerIndex(index + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPeek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcElementOffset(index, mask); + Object e = lvElement(buffer, offset); // LoadLoad + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + @Override + public int fill(Supplier s) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int batchSize) { + long mask; + E[] buffer; + long pIndex; + int claimedSlots; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful + // CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful + // casProducerIndex are they + // safe to use. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + long batchIndex = Math.min(producerLimit, pIndex + 2 * batchSize); + + if (pIndex >= producerLimit || producerLimit < batchIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + + for (int i = 0; i < claimedSlots; i++) { + final long offset = modifiedCalcElementOffset(pIndex + 2 * i, mask); + soElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy w, ExitCondition exit) { + + while (exit.keepRunning()) { + if (fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + int idleCounter = 0; + while (exit.keepRunning() && fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + idleCounter = w.idle(idleCounter); + } + } + } + } + + @Override + public int drain(Consumer c) { + return drain(c, capacity()); + } + + @Override + public int drain(final Consumer c, final int limit) { + // Impl note: there are potentially some small gains to be had by manually inlining + // relaxedPoll() and hoisting + // reused fields out to reduce redundant reads. + int i = 0; + E m; + for (; i < limit && (m = relaxedPoll()) != null; i++) { + c.accept(m); + } + return i; + } + + @Override + public void drain(Consumer c, WaitStrategy w, ExitCondition exit) { + int idleCounter = 0; + while (exit.keepRunning()) { + E e = relaxedPoll(); + if (e == null) { + idleCounter = w.idle(idleCounter); + continue; + } + idleCounter = 0; + c.accept(e); + } + } + + /** + * Get an iterator for this queue. This method is thread safe. + * + *

The iterator provides a best-effort snapshot of the elements in the queue. The returned + * iterator is not guaranteed to return elements in queue order, and races with the consumer + * thread may cause gaps in the sequence of returned elements. Like {link #relaxedPoll}, the + * iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(); + } + + private final class WeakIterator implements Iterator { + + private long nextIndex; + private E nextElement; + private E[] currentBuffer; + private int currentBufferLength; + + WeakIterator() { + setBuffer(consumerBuffer); + nextElement = getNext(); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + E e = nextElement; + nextElement = getNext(); + return e; + } + + private void setBuffer(E[] buffer) { + this.currentBuffer = buffer; + this.currentBufferLength = length(buffer); + this.nextIndex = 0; + } + + private E getNext() { + while (true) { + while (nextIndex < currentBufferLength - 1) { + long offset = calcElementOffset(nextIndex++); + E e = lvElement(currentBuffer, offset); + if (e != null && e != JUMP) { + return e; + } + } + long offset = calcElementOffset(currentBufferLength - 1); + Object nextArray = lvElement(currentBuffer, offset); + if (nextArray == BUFFER_CONSUMED) { + // Consumer may have passed us, just jump to the current consumer buffer + setBuffer(consumerBuffer); + } else if (nextArray != null) { + setBuffer((E[]) nextArray); + } else { + return null; + } + } + } + } + + private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s) { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final E[] newBuffer; + try { + newBuffer = allocate(newBufferLength); + } catch (OutOfMemoryError oom) { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + + final long offsetInOld = modifiedCalcElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcElementOffset(pIndex, newMask); + + soElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array + soElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked + + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + + // INDEX visible before ELEMENT, consistent with consumer expectation + + // make resize visible to consumer + soElement(oldBuffer, offsetInOld, JUMP); + } + + /** @return next buffer size(inclusive of next array pointer) */ + protected abstract int getNextBufferSize(E[] buffer); + + /** @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java new file mode 100644 index 000000000..d746fccbb --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/CircularArrayOffsetCalculator.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; + +import io.rsocket.internal.jctools.util.InternalAPI; + +@InternalAPI +public final class CircularArrayOffsetCalculator { + @SuppressWarnings("unchecked") + public static E[] allocate(int capacity) { + return (E[]) new Object[capacity]; + } + + /** + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the array for a given index. + */ + public static long calcElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java new file mode 100644 index 000000000..1b7d43166 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import io.rsocket.internal.jctools.util.InternalAPI; + +@InternalAPI +public final class IndexedQueueSizeUtil { + public static int size(IndexedQueue iq) { + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = iq.lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = iq.lvProducerIndex(); + after = iq.lvConsumerIndex(); + if (before == after) { + size = (currentProducerIndex - after); + break; + } + } + // Long overflow is impossible (), so size is always positive. Integer overflow is possible for + // the unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + public static boolean isEmpty(IndexedQueue iq) { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (iq.lvConsumerIndex() == iq.lvProducerIndex()); + } + + @InternalAPI + public interface IndexedQueue { + long lvConsumerIndex(); + + long lvProducerIndex(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java new file mode 100644 index 000000000..5e7831128 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java @@ -0,0 +1,39 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; + +/** This is used for method substitution in the LinkedArray classes code generation. */ +final class LinkedArrayQueueUtil { + private LinkedArrayQueueUtil() {} + + static int length(Object[] buf) { + return buf.length; + } + + /** + * This method assumes index is actually (index << 1) because lower bit is used for resize. This + * is compensated for by reducing the element shift. The computation is constant folded, so + * there's no cost. + */ + static long modifiedCalcElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); + } + + static long nextArrayOffset(Object[] curr) { + return REF_ARRAY_BASE + ((long) (length(curr) - 1) << REF_ELEMENT_SHIFT); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java new file mode 100644 index 000000000..6ea69e330 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.util.UnsafeAccess.fieldOffset; + +final class LinkedQueueNode { + private static final long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class, "next"); + + private E value; + private volatile LinkedQueueNode next; + + LinkedQueueNode() { + this(null); + } + + LinkedQueueNode(E val) { + spValue(val); + } + + /** + * Gets the current value and nulls out the reference to it from this node. + * + * @return value + */ + public E getAndNullValue() { + E temp = lpValue(); + spValue(null); + return temp; + } + + public E lpValue() { + return value; + } + + public void spValue(E newValue) { + value = newValue; + } + + public void soNext(LinkedQueueNode n) { + UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); + } + + public LinkedQueueNode lvNext() { + return next; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java new file mode 100644 index 000000000..e0c3d0ee1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +public interface MessagePassingQueue { + int UNBOUNDED_CAPACITY = -1; + + interface Supplier { + T get(); + } + + interface Consumer { + void accept(T e); + } + + interface WaitStrategy { + int idle(int idleCounter); + } + + interface ExitCondition { + + boolean keepRunning(); + } + + boolean offer(T e); + + T poll(); + + T peek(); + + int size(); + + void clear(); + + boolean isEmpty(); + + int capacity(); + + boolean relaxedOffer(T e); + + T relaxedPoll(); + + T relaxedPeek(); + + int drain(Consumer c); + + int fill(Supplier s); + + int drain(Consumer c, int limit); + + int fill(Supplier s, int limit); + + void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + + void fill(Supplier s, WaitStrategy wait, ExitCondition exit); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java new file mode 100644 index 000000000..59eab33a1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; + +import io.rsocket.internal.jctools.util.PortableJvmInfo; + +/** + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked + * chunks of the initial size. The queue grows only when the current chunk is full and elements are + * not copied on resize, instead a link to the new chunk is stored in the old chunk for the consumer + * to follow.
+ * + * @param + */ +public class MpscUnboundedArrayQueue extends BaseMpscLinkedArrayQueue { + long p0, p1, p2, p3, p4, p5, p6, p7; + long p10, p11, p12, p13, p14, p15, p16, p17; + + public MpscUnboundedArrayQueue(int chunkSize) { + super(chunkSize); + } + + @Override + protected long availableInQueue(long pIndex, long cIndex) { + return Integer.MAX_VALUE; + } + + @Override + public int capacity() { + return MessagePassingQueue.UNBOUNDED_CAPACITY; + } + + @Override + public int drain(Consumer c) { + return drain(c, 4096); + } + + @Override + public int fill(Supplier s) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + final int capacity = 4096; + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + protected int getNextBufferSize(E[] buffer) { + return length(buffer); + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return mask; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java new file mode 100644 index 000000000..6418cc947 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** + * This interface is provided for monitoring purposes only and is only available on queues where it + * is easy to provide it. The producer/consumer progress indicators usually correspond with the + * number of elements offered/polled, but they are not guaranteed to maintain that semantic. + * + * @author nitsanw + */ +public interface QueueProgressIndicators { + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * producer threads by subtracting the 2 results though in extreme cases (if producers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements passed into the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the producer progress index + */ + long currentProducerIndex(); + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * consumer threads by subtracting the 2 results though in extreme cases (if consumers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements taken out of the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the consumer progress index + */ + long currentConsumerIndex(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java new file mode 100644 index 000000000..f233e9597 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/InternalAPI.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * This annotation marks classes and methods which may be public for any reason (to support better + * testing or reduce code duplication) but are not intended as public API and may change between + * releases without the change being considered a breaking API change (a major release). + */ +@Target({ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR}) +@Retention(RetentionPolicy.SOURCE) +public @interface InternalAPI {} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java new file mode 100644 index 000000000..2d567d60d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/PortableJvmInfo.java @@ -0,0 +1,23 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +/** JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) */ +@InternalAPI +public interface PortableJvmInfo { + int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); + int CPUs = Runtime.getRuntime().availableProcessors(); + int RECOMENDED_OFFER_BATCH = CPUs * 4; + int RECOMENDED_POLL_BATCH = CPUs * 4; +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java new file mode 100644 index 000000000..d8c66d89e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/Pow2.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +/** Power of 2 utility functions. */ +@InternalAPI +public final class Pow2 { + public static final int MAX_POW2 = 1 << 30; + + /** + * @param value from which next positive power of two will be found. + * @return the next positive power of 2, this value if it is a power of 2. Negative values are + * mapped to 1. + * @throws IllegalArgumentException is value is more than MAX_POW2 or less than 0 + */ + public static int roundToPowerOfTwo(final int value) { + if (value > MAX_POW2) { + throw new IllegalArgumentException( + "There is no larger power of 2 int for value:" + value + " since it exceeds 2^31."); + } + if (value < 0) { + throw new IllegalArgumentException("Given value:" + value + ". Expecting value >= 0."); + } + final int nextPow2 = 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + return nextPow2; + } + + /** + * @param value to be tested to see if it is a power of two. + * @return true if the value is a power of 2 otherwise false. + */ + public static boolean isPowerOfTwo(final int value) { + return (value & (value - 1)) == 0; + } + + /** + * Align a value to the next multiple up of alignment. If the value equals an alignment multiple + * then it is returned unchanged. + * + * @param value to be aligned up. + * @param alignment to be used, must be a power of 2. + * @return the value aligned to the next boundary. + */ + public static long align(final long value, final int alignment) { + if (!isPowerOfTwo(alignment)) { + throw new IllegalArgumentException("alignment must be a power of 2:" + alignment); + } + return (value + (alignment - 1)) & ~(alignment - 1); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java new file mode 100644 index 000000000..77a0582ca --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/RangeUtil.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +@InternalAPI +public final class RangeUtil { + public static long checkPositive(long n, String name) { + if (n <= 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); + } + + return n; + } + + public static int checkPositiveOrZero(int n, String name) { + if (n < 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= 0)"); + } + + return n; + } + + public static int checkLessThan(int n, int expected, String name) { + if (n >= expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: < " + expected + ')'); + } + + return n; + } + + public static int checkLessThanOrEqual(int n, long expected, String name) { + if (n > expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: <= " + expected + ')'); + } + + return n; + } + + public static int checkGreaterThanOrEqual(int n, int expected, String name) { + if (n < expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= " + expected + ')'); + } + + return n; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java new file mode 100755 index 000000000..793e64505 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeAccess.java @@ -0,0 +1,80 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import sun.misc.Unsafe; + +/** + * Why should we resort to using Unsafe?
+ * + *

    + *
  1. To construct class fields which allow volatile/ordered/plain access: This requirement is + * covered by {@link AtomicReferenceFieldUpdater} and similar but their performance is + * arguably worse than the DIY approach (depending on JVM version) while Unsafe + * intrinsification is a far lesser challenge for JIT compilers. + *
  2. To construct flavors of {@link AtomicReferenceArray}. + *
  3. Other use cases exist but are not present in this library yet. + *
+ * + * @author nitsanw + */ +@InternalAPI +public class UnsafeAccess { + public static final boolean SUPPORTS_GET_AND_SET; + public static final Unsafe UNSAFE; + + static { + Unsafe instance; + try { + final Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + instance = (Unsafe) field.get(null); + } catch (Exception ignored) { + // Some platforms, notably Android, might not have a sun.misc.Unsafe + // implementation with a private `theUnsafe` static instance. In this + // case we can try and call the default constructor, which proves + // sufficient for Android usage. + try { + Constructor c = Unsafe.class.getDeclaredConstructor(); + c.setAccessible(true); + instance = c.newInstance(); + } catch (Exception e) { + SUPPORTS_GET_AND_SET = false; + throw new RuntimeException(e); + } + } + + boolean getAndSetSupport = false; + try { + Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); + getAndSetSupport = true; + } catch (Exception ignored) { + } + + UNSAFE = instance; + SUPPORTS_GET_AND_SET = getAndSetSupport; + } + + public static long fieldOffset(Class clz, String fieldName) throws RuntimeException { + try { + return UNSAFE.objectFieldOffset(clz.getDeclaredField(fieldName)); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java new file mode 100644 index 000000000..d8309c5c5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/util/UnsafeRefArrayAccess.java @@ -0,0 +1,103 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.util; + +import static io.rsocket.internal.jctools.util.UnsafeAccess.UNSAFE; + +/** + * A concurrent access enabling class used by circular array based queues this class exposes an + * offset computation method along with differently memory fenced load/store methods into the + * underlying array. The class is pre-padded and the array is padded on either side to help with + * False sharing prvention. It is expected theat subclasses handle post padding. + * + *

Offset calculation is separate from access to enable the reuse of a give compute offset. + * + *

Load/Store methods using a buffer parameter are provided to allow the prevention of + * final field reload after a LoadLoad barrier. + * + *

+ * + * @author nitsanw + */ +@InternalAPI +public final class UnsafeRefArrayAccess { + public static final long REF_ARRAY_BASE; + public static final int REF_ELEMENT_SHIFT; + + static { + final int scale = UnsafeAccess.UNSAFE.arrayIndexScale(Object[].class); + if (4 == scale) { + REF_ELEMENT_SHIFT = 2; + } else if (8 == scale) { + REF_ELEMENT_SHIFT = 3; + } else { + throw new IllegalStateException("Unknown pointer size: " + scale); + } + REF_ARRAY_BASE = UnsafeAccess.UNSAFE.arrayBaseOffset(Object[].class); + } + + /** + * A plain store (no ordering/fences) of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @param e an orderly kitty + */ + public static void spElement(E[] buffer, long offset, E e) { + UNSAFE.putObject(buffer, offset, e); + } + + /** + * An ordered store(store + StoreStore barrier) of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset} + * @param e an orderly kitty + */ + public static void soElement(E[] buffer, long offset, E e) { + UNSAFE.putOrderedObject(buffer, offset, e); + } + + /** + * A plain load (no ordering/fences) of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lpElement(E[] buffer, long offset) { + return (E) UNSAFE.getObject(buffer, offset); + } + + /** + * A volatile load (load + LoadLoad barrier) of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lvElement(E[] buffer, long offset) { + return (E) UNSAFE.getObjectVolatile(buffer, offset); + } + + /** + * @param index desirable element index + * @return the offset in bytes within the array for a given index. + */ + public static long calcElementOffset(long index) { + return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java index 09918f3d1..07ddfab41 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java @@ -18,5 +18,7 @@ * Internal package and must not be used outside this project. There are no guarantees for * API compatibility. */ -@javax.annotation.ParametersAreNonnullByDefault +@NonNullApi package io.rsocket.internal; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java new file mode 100644 index 000000000..6fc96d6d2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java @@ -0,0 +1,8 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; + +public interface KeepAliveFramesAcceptor { + + void receive(ByteBuf keepAliveFrame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java new file mode 100644 index 000000000..2535c342b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java @@ -0,0 +1,57 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.keepalive.KeepAliveSupport.KeepAlive; +import io.rsocket.resume.ResumableDuplexConnection; +import java.util.function.Consumer; + +public interface KeepAliveHandler { + + KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onFrameSent, + Consumer onTimeout); + + class DefaultKeepAliveHandler implements KeepAliveHandler { + private final Closeable duplexConnection; + + public DefaultKeepAliveHandler(Closeable duplexConnection) { + this.duplexConnection = duplexConnection; + } + + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + duplexConnection.onClose().doFinally(s -> keepAliveSupport.stop()).subscribe(); + return keepAliveSupport + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(onTimeout) + .start(); + } + } + + class ResumableKeepAliveHandler implements KeepAliveHandler { + private final ResumableDuplexConnection resumableDuplexConnection; + + public ResumableKeepAliveHandler(ResumableDuplexConnection resumableDuplexConnection) { + this.resumableDuplexConnection = resumableDuplexConnection; + } + + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + resumableDuplexConnection.onResume(keepAliveSupport::start); + resumableDuplexConnection.onDisconnect(keepAliveSupport::stop); + return keepAliveSupport + .resumeState(resumableDuplexConnection) + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(keepAlive -> resumableDuplexConnection.disconnect()) + .start(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java new file mode 100644 index 000000000..db29d8030 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -0,0 +1,170 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.resume.ResumeStateHolder; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; + +public abstract class KeepAliveSupport implements KeepAliveFramesAcceptor { + final ByteBufAllocator allocator; + private final Duration keepAliveInterval; + private final Duration keepAliveTimeout; + private final long keepAliveTimeoutMillis; + private volatile Consumer onTimeout; + private volatile Consumer onFrameSent; + private volatile Disposable ticksDisposable; + private final AtomicBoolean started = new AtomicBoolean(); + + private volatile ResumeStateHolder resumeStateHolder; + private volatile long lastReceivedMillis; + + private KeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + this.allocator = allocator; + this.keepAliveInterval = Duration.ofMillis(keepAliveInterval); + this.keepAliveTimeout = Duration.ofMillis(keepAliveTimeout); + this.keepAliveTimeoutMillis = keepAliveTimeout; + } + + public KeepAliveSupport start() { + this.lastReceivedMillis = System.currentTimeMillis(); + if (started.compareAndSet(false, true)) { + ticksDisposable = Flux.interval(keepAliveInterval).subscribe(v -> onIntervalTick()); + } + return this; + } + + public void stop() { + if (started.compareAndSet(true, false)) { + ticksDisposable.dispose(); + } + } + + @Override + public void receive(ByteBuf keepAliveFrame) { + this.lastReceivedMillis = System.currentTimeMillis(); + if (resumeStateHolder != null) { + long remoteLastReceivedPos = remoteLastReceivedPosition(keepAliveFrame); + resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); + } + if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { + long localLastReceivedPos = localLastReceivedPosition(); + send( + KeepAliveFrameCodec.encode( + allocator, + false, + localLastReceivedPos, + KeepAliveFrameCodec.data(keepAliveFrame).retain())); + } + } + + public KeepAliveSupport resumeState(ResumeStateHolder resumeStateHolder) { + this.resumeStateHolder = resumeStateHolder; + return this; + } + + public KeepAliveSupport onSendKeepAliveFrame(Consumer onFrameSent) { + this.onFrameSent = onFrameSent; + return this; + } + + public KeepAliveSupport onTimeout(Consumer onTimeout) { + this.onTimeout = onTimeout; + return this; + } + + abstract void onIntervalTick(); + + void send(ByteBuf frame) { + if (onFrameSent != null) { + onFrameSent.accept(frame); + } + } + + void tryTimeout() { + long now = System.currentTimeMillis(); + if (now - lastReceivedMillis >= keepAliveTimeoutMillis) { + if (onTimeout != null) { + onTimeout.accept(new KeepAlive(keepAliveInterval, keepAliveTimeout)); + } + stop(); + } + } + + long localLastReceivedPosition() { + return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; + } + + long remoteLastReceivedPosition(ByteBuf keepAliveFrame) { + return KeepAliveFrameCodec.lastPosition(keepAliveFrame); + } + + public static final class ServerKeepAliveSupport extends KeepAliveSupport { + + public ServerKeepAliveSupport( + ByteBufAllocator allocator, int keepAlivePeriod, int keepAliveTimeout) { + super(allocator, keepAlivePeriod, keepAliveTimeout); + } + + @Override + void onIntervalTick() { + tryTimeout(); + } + } + + public static final class ClientKeepAliveSupport extends KeepAliveSupport { + + public ClientKeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + super(allocator, keepAliveInterval, keepAliveTimeout); + } + + @Override + void onIntervalTick() { + tryTimeout(); + send( + KeepAliveFrameCodec.encode( + allocator, true, localLastReceivedPosition(), Unpooled.EMPTY_BUFFER)); + } + } + + public static final class KeepAlive { + private final Duration tickPeriod; + private final Duration timeoutMillis; + + public KeepAlive(Duration tickPeriod, Duration timeoutMillis) { + this.tickPeriod = tickPeriod; + this.timeoutMillis = timeoutMillis; + } + + public Duration getTickPeriod() { + return tickPeriod; + } + + public Duration getTimeout() { + return timeoutMillis; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java new file mode 100644 index 000000000..d94a93cad --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support classes for sending and keeping track of KEEPALIVE frames from the remote. */ +@NonNullApi +package io.rsocket.keepalive; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java index 62ce16907..673b4a480 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,20 @@ package io.rsocket.lease; import io.netty.buffer.ByteBuf; -import javax.annotation.Nullable; +import io.netty.buffer.Unpooled; +import io.rsocket.Availability; +import reactor.util.annotation.Nullable; /** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ -public interface Lease { +public interface Lease extends Availability { + + static Lease create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { + return LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); + } + + static Lease create(int timeToLiveMillis, int numberOfRequests) { + return create(timeToLiveMillis, numberOfRequests, Unpooled.EMPTY_BUFFER); + } /** * Number of requests allowed by this lease. @@ -30,11 +40,30 @@ public interface Lease { int getAllowedRequests(); /** - * Number of seconds that this lease is valid from the time it is received. + * Initial number of requests allowed by this lease. + * + * @return initial number of requests allowed by this lease. + */ + default int getStartingAllowedRequests() { + throw new UnsupportedOperationException("Not implemented"); + } + + /** + * Number of milliseconds that this lease is valid from the time it is received. * - * @return Number of seconds that this lease is valid from the time it is received. + * @return Number of milliseconds that this lease is valid from the time it is received. */ - int getTtl(); + int getTimeToLiveMillis(); + + /** + * Number of milliseconds that this lease is still valid from now. + * + * @param now millis since epoch + * @return Number of milliseconds that this lease is still valid from now, or 0 if expired. + */ + default int getRemainingTimeToLiveMillis(long now) { + return isEmpty() ? 0 : (int) Math.max(0, expiry() - now); + } /** * Absolute time since epoch at which this lease will expire. @@ -48,7 +77,6 @@ public interface Lease { * * @return Metadata for the lease. */ - @Nullable ByteBuf getMetadata(); /** @@ -69,4 +97,14 @@ default boolean isExpired() { default boolean isExpired(long now) { return now > expiry(); } + + /** Checks if the lease has not expired and there are allowed requests available */ + default boolean isValid() { + return !isExpired() && getAllowedRequests() > 0; + } + + /** Checks if the lease is empty(default value if no lease was received yet) */ + default boolean isEmpty() { + return getAllowedRequests() == 0 && getTimeToLiveMillis() == 0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java index 010afcda7..7abb8aab9 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,42 +17,51 @@ package io.rsocket.lease; import io.netty.buffer.ByteBuf; -import io.rsocket.frame.LeaseFlyweight; +import io.netty.buffer.Unpooled; +import java.util.concurrent.atomic.AtomicInteger; import reactor.util.annotation.Nullable; -public final class LeaseImpl implements Lease { - - private final int allowedRequests; - private final int ttl; +public class LeaseImpl implements Lease { + private final int timeToLiveMillis; + private final AtomicInteger allowedRequests; + private final int startingAllowedRequests; + private final ByteBuf metadata; private final long expiry; - private final @Nullable ByteBuf metadata; - public LeaseImpl(int allowedRequests, int ttl) { - this(allowedRequests, ttl, null); + static LeaseImpl create(int timeToLiveMillis, int numberOfRequests, @Nullable ByteBuf metadata) { + assertLease(timeToLiveMillis, numberOfRequests); + return new LeaseImpl(timeToLiveMillis, numberOfRequests, metadata); + } + + static LeaseImpl empty() { + return new LeaseImpl(0, 0, null); } - public LeaseImpl(int allowedRequests, int ttl, ByteBuf metadata) { - this.allowedRequests = allowedRequests; - this.ttl = ttl; - expiry = System.currentTimeMillis() + ttl; - this.metadata = metadata; + private LeaseImpl(int timeToLiveMillis, int allowedRequests, @Nullable ByteBuf metadata) { + this.allowedRequests = new AtomicInteger(allowedRequests); + this.startingAllowedRequests = allowedRequests; + this.timeToLiveMillis = timeToLiveMillis; + this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + this.expiry = timeToLiveMillis == 0 ? 0 : now() + timeToLiveMillis; } - public LeaseImpl(ByteBuf leaseFrame) { - this( - LeaseFlyweight.numRequests(leaseFrame), - LeaseFlyweight.ttl(leaseFrame), - LeaseFlyweight.metadata(leaseFrame)); + public int getTimeToLiveMillis() { + return timeToLiveMillis; } @Override public int getAllowedRequests() { - return allowedRequests; + return Math.max(0, allowedRequests.get()); + } + + @Override + public int getStartingAllowedRequests() { + return startingAllowedRequests; } @Override - public int getTtl() { - return ttl; + public ByteBuf getMetadata() { + return metadata; } @Override @@ -61,19 +70,56 @@ public long expiry() { } @Override - public ByteBuf getMetadata() { - return metadata; + public boolean isValid() { + return !isEmpty() && getAllowedRequests() > 0 && !isExpired(); + } + + /** + * try use 1 allowed request of Lease + * + * @return true if used successfully, false if Lease is expired or no allowed requests available + */ + public boolean use() { + if (isExpired()) { + return false; + } + int remaining = + allowedRequests.accumulateAndGet(1, (cur, update) -> Math.max(-1, cur - update)); + return remaining >= 0; + } + + @Override + public double availability() { + return isValid() ? getAllowedRequests() / (double) getStartingAllowedRequests() : 0.0; } @Override public String toString() { + long now = now(); return "LeaseImpl{" - + "allowedRequests=" - + allowedRequests - + ", ttl=" - + ttl - + ", expiry=" - + expiry + + "timeToLiveMillis=" + + timeToLiveMillis + + ", allowedRequests=" + + getAllowedRequests() + + ", startingAllowedRequests=" + + startingAllowedRequests + + ", expired=" + + isExpired(now) + + ", remainingTimeToLiveMillis=" + + getRemainingTimeToLiveMillis(now) + '}'; } + + private static long now() { + return System.currentTimeMillis(); + } + + private static void assertLease(int timeToLiveMillis, int numberOfRequests) { + if (numberOfRequests <= 0) { + throw new IllegalArgumentException("Number of requests must be positive"); + } + if (timeToLiveMillis <= 0) { + throw new IllegalArgumentException("Time-to-live must be positive"); + } + } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java similarity index 55% rename from rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java rename to rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java index 25b443dd6..791f5a023 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseStats.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,15 @@ * limitations under the License. */ -package io.rsocket.transport.netty; +package io.rsocket.lease; -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; +public interface LeaseStats { -final class TcpUriHandlerTest implements UriHandlerTest { + void onEvent(EventType eventType); - @Override - public String getInvalidUri() { - return "http://test"; - } - - @Override - public UriHandler getUriHandler() { - return new TcpUriHandler(); - } - - @Override - public String getValidUri() { - return "tcp://test:9898"; + enum EventType { + ACCEPT, + REJECT, + TERMINATE } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Leases.java b/rsocket-core/src/main/java/io/rsocket/lease/Leases.java new file mode 100644 index 000000000..4c90e38ce --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/Leases.java @@ -0,0 +1,65 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import reactor.core.publisher.Flux; + +public class Leases { + private static final Function> noopLeaseSender = leaseStats -> Flux.never(); + private static final Consumer> noopLeaseReceiver = leases -> {}; + + private Function> leaseSender = noopLeaseSender; + private Consumer> leaseReceiver = noopLeaseReceiver; + private Optional stats = Optional.empty(); + + public static Leases create() { + return new Leases<>(); + } + + public Leases sender(Function, Flux> leaseSender) { + this.leaseSender = leaseSender; + return this; + } + + public Leases receiver(Consumer> leaseReceiver) { + this.leaseReceiver = leaseReceiver; + return this; + } + + public Leases stats(T stats) { + this.stats = Optional.of(Objects.requireNonNull(stats)); + return this; + } + + @SuppressWarnings("unchecked") + public Function, Flux> sender() { + return (Function, Flux>) leaseSender; + } + + public Consumer> receiver() { + return leaseReceiver; + } + + @SuppressWarnings("unchecked") + public Optional stats() { + return (Optional) stats; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java new file mode 100644 index 000000000..3b6cec62c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.lease; + +import io.rsocket.exceptions.RejectedException; +import java.util.Objects; +import reactor.util.annotation.Nullable; + +public class MissingLeaseException extends RejectedException { + private static final long serialVersionUID = -6169748673403858959L; + + public MissingLeaseException(Lease lease, String tag) { + super(leaseMessage(Objects.requireNonNull(lease), Objects.requireNonNull(tag))); + } + + public MissingLeaseException(String tag) { + super(leaseMessage(null, Objects.requireNonNull(tag))); + } + + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + + static String leaseMessage(@Nullable Lease lease, String tag) { + if (lease == null) { + return String.format("[%s] Missing leases", tag); + } + if (lease.isEmpty()) { + return String.format("[%s] Lease was not received yet", tag); + } + boolean expired = lease.isExpired(); + int allowedRequests = lease.getAllowedRequests(); + return String.format( + "[%s] Missing leases. Expired: %b, allowedRequests: %d", tag, expired, allowedRequests); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java new file mode 100644 index 000000000..fd569a2c8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java @@ -0,0 +1,113 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import java.util.function.Consumer; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.ReplayProcessor; + +public interface RequesterLeaseHandler extends Availability, Disposable { + + boolean useLease(); + + Exception leaseError(); + + void receive(ByteBuf leaseFrame); + + void dispose(); + + final class Impl implements RequesterLeaseHandler { + private final String tag; + private final ReplayProcessor receivedLease; + private volatile LeaseImpl currentLease = LeaseImpl.empty(); + + public Impl(String tag, Consumer> leaseReceiver) { + this.tag = tag; + receivedLease = ReplayProcessor.create(1); + leaseReceiver.accept(receivedLease); + } + + @Override + public boolean useLease() { + return currentLease.use(); + } + + @Override + public Exception leaseError() { + LeaseImpl l = this.currentLease; + String t = this.tag; + if (!l.isValid()) { + return new MissingLeaseException(l, t); + } else { + return new MissingLeaseException(t); + } + } + + @Override + public void receive(ByteBuf leaseFrame) { + int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); + LeaseImpl lease = LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); + currentLease = lease; + receivedLease.onNext(lease); + } + + @Override + public void dispose() { + receivedLease.onComplete(); + } + + @Override + public boolean isDisposed() { + return receivedLease.isTerminated(); + } + + @Override + public double availability() { + return currentLease.availability(); + } + } + + RequesterLeaseHandler None = + new RequesterLeaseHandler() { + @Override + public boolean useLease() { + return true; + } + + @Override + public Exception leaseError() { + throw new AssertionError("Error not possible with NOOP leases handler"); + } + + @Override + public void receive(ByteBuf leaseFrame) {} + + @Override + public void dispose() {} + + @Override + public double availability() { + return 1.0; + } + }; +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java new file mode 100644 index 000000000..df8787cb7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java @@ -0,0 +1,146 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Flux; +import reactor.util.annotation.Nullable; + +public interface ResponderLeaseHandler extends Availability { + + boolean useLease(); + + Exception leaseError(); + + Disposable send(Consumer leaseFrameSender); + + final class Impl implements ResponderLeaseHandler { + private volatile LeaseImpl currentLease = LeaseImpl.empty(); + private final String tag; + private final ByteBufAllocator allocator; + private final Function, Flux> leaseSender; + private final Optional leaseStatsOption; + private final T leaseStats; + + public Impl( + String tag, + ByteBufAllocator allocator, + Function, Flux> leaseSender, + Optional leaseStatsOption) { + this.tag = tag; + this.allocator = allocator; + this.leaseSender = leaseSender; + this.leaseStatsOption = leaseStatsOption; + this.leaseStats = leaseStatsOption.orElse(null); + } + + @Override + public boolean useLease() { + boolean success = currentLease.use(); + onUseEvent(success, leaseStats); + return success; + } + + @Override + public Exception leaseError() { + LeaseImpl l = currentLease; + String t = tag; + if (!l.isValid()) { + return new MissingLeaseException(l, t); + } else { + return new MissingLeaseException(t); + } + } + + @Override + public Disposable send(Consumer leaseFrameSender) { + return leaseSender + .apply(leaseStatsOption) + .doOnTerminate(this::onTerminateEvent) + .subscribe( + lease -> { + currentLease = create(lease); + leaseFrameSender.accept(createLeaseFrame(lease)); + }); + } + + @Override + public double availability() { + return currentLease.availability(); + } + + private ByteBuf createLeaseFrame(Lease lease) { + return LeaseFrameCodec.encode( + allocator, lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); + } + + private void onTerminateEvent() { + T ls = leaseStats; + if (ls != null) { + ls.onEvent(LeaseStats.EventType.TERMINATE); + } + } + + private void onUseEvent(boolean success, @Nullable T ls) { + if (ls != null) { + LeaseStats.EventType eventType = + success ? LeaseStats.EventType.ACCEPT : LeaseStats.EventType.REJECT; + ls.onEvent(eventType); + } + } + + private static LeaseImpl create(Lease lease) { + if (lease instanceof LeaseImpl) { + return (LeaseImpl) lease; + } else { + return LeaseImpl.create( + lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); + } + } + } + + ResponderLeaseHandler None = + new ResponderLeaseHandler() { + @Override + public boolean useLease() { + return true; + } + + @Override + public Exception leaseError() { + throw new AssertionError("Error not possible with NOOP leases handler"); + } + + @Override + public Disposable send(Consumer leaseFrameSender) { + return Disposables.disposed(); + } + + @Override + public double availability() { + return 1.0; + } + }; +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java index 6700c10d9..342ab27f7 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,14 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * Contains support classes for the Lease feature of the RSocket protocol. + * + * @see Resuming + * Operation + */ +@NonNullApi package io.rsocket.lease; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java new file mode 100644 index 000000000..d908abb3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -0,0 +1,335 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.util.CharByteBufUtil; + +public class AuthMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + static final int USERNAME_BYTES_LENGTH = 1; + static final int AUTH_TYPE_ID_LENGTH = 1; + + static final char[] EMPTY_CHARS_ARRAY = new char[0]; + + private AuthMetadataCodec() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + int actualASCIILength = ByteBufUtil.utf8Bytes(customAuthType); + if (actualASCIILength != customAuthType.length()) { + throw new IllegalArgumentException("custom auth type must be US_ASCII characters only"); + } + if (actualASCIILength < 1 || actualASCIILength > 128) { + throw new IllegalArgumentException( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + int capacity = 1 + actualASCIILength; + ByteBuf headerBuffer = allocator.buffer(capacity, capacity); + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + headerBuffer.writeByte(actualASCIILength - 1); + + ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + if (authType == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE + || authType == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE) { + throw new IllegalArgumentException("only allowed AuthType should be used"); + } + + int capacity = AUTH_TYPE_ID_LENGTH; + ByteBuf headerBuffer = + allocator + .buffer(capacity, capacity) + .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 255 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + + int usernameLength = CharByteBufUtil.utf8Bytes(username); + if (usernameLength > 255) { + throw new IllegalArgumentException( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + int passwordLength = CharByteBufUtil.utf8Bytes(password); + int capacity = AUTH_TYPE_ID_LENGTH + USERNAME_BYTES_LENGTH + usernameLength + passwordLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) + .writeByte(usernameLength); + + CharByteBufUtil.writeUtf8(buffer, username); + CharByteBufUtil.writeUtf8(buffer, password); + + return buffer; + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + + int tokenLength = CharByteBufUtil.utf8Bytes(token); + int capacity = AUTH_TYPE_ID_LENGTH + tokenLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.BEARER.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + CharByteBufUtil.writeUtf8(buffer, token); + + return buffer; + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + WellKnownAuthType wkn = WellKnownAuthType.fromString(authType); + if (wkn == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE) { + return AuthMetadataCodec.encodeMetadata(allocator, authType, metadata); + } else { + return AuthMetadataCodec.encodeMetadata(allocator, wkn, metadata); + } + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + byte lengthOrId = metadata.getByte(0); + return (lengthOrId & STREAM_METADATA_LENGTH_MASK) != lengthOrId; + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType readWellKnownAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode Well Know Auth type. Not enough readable bytes"); + } + byte lengthOrId = metadata.readByte(); + int normalizedId = (byte) (lengthOrId & STREAM_METADATA_LENGTH_MASK); + + if (normalizedId != lengthOrId) { + return WellKnownAuthType.fromIdentifier(normalizedId); + } + + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence readCustomAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Not enough readable bytes"); + } + + byte encodedLength = metadata.readByte(); + if (encodedLength < 0) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Incorrect auth type length"); + } + + // encoded length is realLength - 1 in order to avoid intersection with 0x00 authtype + int realLength = encodedLength + 1; + if (metadata.readableBytes() < realLength) { + throw new IllegalArgumentException( + "Unable to decode custom Auth type. Malformed length or auth type string"); + } + + return metadata.readCharSequence(realLength, CharsetUtil.US_ASCII); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf readPayload(ByteBuf metadata) { + if (metadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return metadata.readSlice(metadata.readableBytes()); + } + + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { + short usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); + } + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + short usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, simpleAuthMetadata.readableBytes()); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is + * username length and the subsequent number of bytes equal to decoded length + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + if (bearerAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); + } + + private static short readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode custom username. Not enough readable bytes"); + } + + short usernameLength = simpleAuthMetadata.readUnsignedByte(); + + if (simpleAuthMetadata.readableBytes() < usernameLength) { + throw new IllegalArgumentException( + "Unable to decode username. Malformed username length or content"); + } + + return usernameLength; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java new file mode 100644 index 000000000..4a48921b1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java @@ -0,0 +1,241 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataFlyweight.computeNextEntryIndex; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.hasEntry; +import static io.rsocket.metadata.CompositeMetadataFlyweight.isWellKnownMimeType; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadata.Entry; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import reactor.util.annotation.Nullable; + +/** + * An {@link Iterable} wrapper around a {@link ByteBuf} that exposes metadata entry information at + * each decoding step. This is only possible on frame types used to initiate interactions, if the + * SETUP metadata mime type was {@link WellKnownMimeType#MESSAGE_RSOCKET_COMPOSITE_METADATA}. + * + *

This allows efficient incremental decoding of the entries (without moving the source's {@link + * io.netty.buffer.ByteBuf#readerIndex()}). The buffer is assumed to contain just enough bytes to + * represent one or more entries (mime type compressed or not). The decoding stops when the buffer + * reaches 0 readable bytes, and fails if it contains bytes but not enough to correctly decode an + * entry. + * + *

A note on future-proofness: it is possible to come across a compressed mime type that this + * implementation doesn't recognize. This is likely to be due to the use of a byte id that is merely + * reserved in this implementation, but maps to a {@link WellKnownMimeType} in the implementation + * that encoded the metadata. This can be detected by detecting that an entry is a {@link + * ReservedMimeTypeEntry}. In this case {@link Entry#getMimeType()} will return {@code null}. The + * encoded id can be retrieved using {@link ReservedMimeTypeEntry#getType()}. The byte and content + * buffer should be kept around and re-encoded using {@link + * CompositeMetadataFlyweight#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, + * ByteBuf)} in case passing that entry through is required. + */ +public final class CompositeMetadata implements Iterable { + + private final boolean retainSlices; + + private final ByteBuf source; + + public CompositeMetadata(ByteBuf source, boolean retainSlices) { + this.source = source; + this.retainSlices = retainSlices; + } + + /** + * Turn this {@link CompositeMetadata} into a sequential {@link Stream}. + * + * @return the composite metadata sequential {@link Stream} + */ + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + /** + * An {@link Iterator} that lazily decodes {@link Entry} in this composite metadata. + * + * @return the composite metadata {@link Iterator} + */ + @Override + public Iterator iterator() { + return new Iterator() { + + private int entryIndex = 0; + + @Override + public boolean hasNext() { + return hasEntry(CompositeMetadata.this.source, this.entryIndex); + } + + @Override + public Entry next() { + ByteBuf[] headerAndData = + decodeMimeAndContentBuffersSlices( + CompositeMetadata.this.source, + this.entryIndex, + CompositeMetadata.this.retainSlices); + + ByteBuf header = headerAndData[0]; + ByteBuf data = headerAndData[1]; + + this.entryIndex = computeNextEntryIndex(this.entryIndex, header, data); + + if (!isWellKnownMimeType(header)) { + CharSequence typeString = decodeMimeTypeFromMimeBuffer(header); + if (typeString == null) { + throw new IllegalStateException("MIME type cannot be null"); + } + + return new ExplicitMimeTimeEntry(data, typeString.toString()); + } + + byte id = decodeMimeIdFromMimeBuffer(header); + WellKnownMimeType type = WellKnownMimeType.fromIdentifier(id); + + if (WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE == type) { + return new ReservedMimeTypeEntry(data, id); + } + + return new WellKnownMimeTypeEntry(data, type); + } + }; + } + + /** An entry in the {@link CompositeMetadata}. */ + public interface Entry { + + /** + * Returns the un-decoded content of the {@link Entry}. + * + * @return the un-decoded content of the {@link Entry} + */ + ByteBuf getContent(); + + /** + * Returns the MIME type of the entry, if it can be decoded. + * + * @return the MIME type of the entry, if it can be decoded, otherwise {@code null}. + */ + @Nullable + String getMimeType(); + } + + /** An {@link Entry} backed by an explicitly declared MIME type. */ + public static final class ExplicitMimeTimeEntry implements Entry { + + private final ByteBuf content; + + private final String type; + + public ExplicitMimeTimeEntry(ByteBuf content, String type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type; + } + } + + /** + * An {@link Entry} backed by a {@link WellKnownMimeType} entry, but one that is not understood by + * this implementation. + */ + public static final class ReservedMimeTypeEntry implements Entry { + private final ByteBuf content; + private final int type; + + public ReservedMimeTypeEntry(ByteBuf content, int type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + /** + * {@inheritDoc} Since this entry represents a compressed id that couldn't be decoded, this is + * always {@code null}. + */ + @Override + public String getMimeType() { + return null; + } + + /** + * Returns the reserved, but unknown {@link WellKnownMimeType} for this entry. Range is 0-127 + * (inclusive). + * + * @return the reserved, but unknown {@link WellKnownMimeType} for this entry + */ + public int getType() { + return this.type; + } + } + + /** An {@link Entry} backed by a {@link WellKnownMimeType}. */ + public static final class WellKnownMimeTypeEntry implements Entry { + + private final ByteBuf content; + private final WellKnownMimeType type; + + public WellKnownMimeTypeEntry(ByteBuf content, WellKnownMimeType type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type.getString(); + } + + /** + * Returns the {@link WellKnownMimeType} for this entry. + * + * @return the {@link WellKnownMimeType} for this entry + */ + public WellKnownMimeType getType() { + return this.type; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java new file mode 100644 index 000000000..5e00abba8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java @@ -0,0 +1,385 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.util.NumberUtils; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + */ +public class CompositeMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private CompositeMetadataCodec() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return currentEntryIndex + + headerSlice.readableBytes() // this includes the mime length byte + + 3 // 3 bytes of the content length, which are excluded from the slice + + contentSlice.readableBytes(); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + compositeMetadata.markReaderIndex(); + compositeMetadata.readerIndex(entryIndex); + + if (compositeMetadata.isReadable()) { + ByteBuf mime; + int ridx = compositeMetadata.readerIndex(); + byte mimeIdOrLength = compositeMetadata.readByte(); + if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + mime = + retainSlices + ? compositeMetadata.retainedSlice(ridx, 1) + : compositeMetadata.slice(ridx, 1); + } else { + // M flag unset, remaining 7 bits are the length of the mime + int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; + + if (compositeMetadata.isReadable( + mimeLength)) { // need to be able to read an extra mimeLength bytes + // here we need a way for the returned ByteBuf to differentiate between a + // 1-byte length mime type and a 1 byte encoded mime id, preferably without + // re-applying the byte mask. The easiest way is to include the initial byte + // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte + // buffer == full mime string. + mime = + retainSlices + ? + // we accommodate that we don't read from current readerIndex, but + // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 + compositeMetadata.retainedSlice(ridx, mimeLength + 1) + : compositeMetadata.slice(ridx, mimeLength + 1); + // we thus need to skip the bytes we just sliced, but not the flag/length byte + // which was already skipped in initial read + compositeMetadata.skipBytes(mimeLength); + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + + if (compositeMetadata.isReadable(3)) { + // ensures the length medium can be read + final int metadataLength = compositeMetadata.readUnsignedMedium(); + if (compositeMetadata.isReadable(metadataLength)) { + ByteBuf metadata = + retainSlices + ? compositeMetadata.readRetainedSlice(metadataLength) + : compositeMetadata.readSlice(metadataLength); + compositeMetadata.resetReaderIndex(); + return new ByteBuf[] {mime, metadata}; + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + compositeMetadata.resetReaderIndex(); + throw new IllegalArgumentException( + String.format("entry index %d is larger than buffer size", entryIndex)); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & + * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + if (mimeBuffer.readableBytes() != 1) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); + } + return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + if (flyweightMimeBuffer.readableBytes() < 2) { + throw new IllegalStateException("unable to decode explicit MIME type"); + } + // the encoded length is assumed to be kept at the start of the buffer + // but also assumed to be irrelevant because the rest of the slice length + // actually already matches _decoded_length + flyweightMimeBuffer.skipBytes(1); + int mimeStringLength = flyweightMimeBuffer.readableBytes(); + return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), + metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); + } else { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), + metadata); + } + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return compositeMetadata.writerIndex() - entryIndex > 0; + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return header.readableBytes() == 1; + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), + metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); + // reserve 1 byte for the customMime length + // /!\ careful not to read that first byte, which is random at this point + int writerIndexInitial = metadataHeader.writerIndex(); + metadataHeader.writerIndex(writerIndexInitial + 1); + + // write the custom mime in UTF8 but validate it is all ASCII-compatible + // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) + int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); + if (!ByteBufUtil.isText( + metadataHeader, metadataHeader.readerIndex() + 1, customMimeLength, CharsetUtil.US_ASCII)) { + metadataHeader.release(); + throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); + } + if (customMimeLength < 1 || customMimeLength > 128) { + metadataHeader.release(); + throw new IllegalArgumentException( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + metadataHeader.markWriterIndex(); + + // go back to beginning and write the length + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + metadataHeader.writerIndex(writerIndexInitial); + metadataHeader.writeByte(customMimeLength - 1); + + // go back to post-mime type and write the metadata content length + metadataHeader.resetWriterIndex(); + NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); + + return metadataHeader; + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); + + NumberUtils.encodeUnsignedMedium(buffer, metadataLength); + + return buffer; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java new file mode 100644 index 000000000..9916dfd3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java @@ -0,0 +1,262 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + * + * @deprecated in favor of {@link CompositeMetadataCodec} + */ +@Deprecated +public class CompositeMetadataFlyweight { + + private CompositeMetadataFlyweight() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return CompositeMetadataCodec.computeNextEntryIndex( + currentEntryIndex, headerSlice, contentSlice); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + return CompositeMetadataCodec.decodeMimeAndContentBuffersSlices( + compositeMetadata, entryIndex, retainSlices); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link CompositeMetadataCodec#STREAM_METADATA_KNOWN_MASK} mask ({@code + * firstByte & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + return CompositeMetadataCodec.decodeMimeIdFromMimeBuffer(mimeBuffer); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + return CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(flyweightMimeBuffer); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, customMimeType, metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, knownMimeType, metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + compositeMetaData, allocator, mimeType, metadata); + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return CompositeMetadataCodec.hasEntry(compositeMetadata, entryIndex); + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return CompositeMetadataCodec.isWellKnownMimeType(header); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, unknownCompressedMimeType, metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + return CompositeMetadataCodec.encodeMetadataHeader(allocator, customMime, metadataLength); + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + return CompositeMetadataCodec.encodeMetadataHeader(allocator, mimeType, metadataLength); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java new file mode 100644 index 000000000..d1f2643dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java @@ -0,0 +1,18 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; + +/** + * Routing Metadata extension from + * https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class RoutingMetadata extends TaggingMetadata { + private static final WellKnownMimeType ROUTING_MIME_TYPE = + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING; + + public RoutingMetadata(ByteBuf content) { + super(ROUTING_MIME_TYPE.getString(), content); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java new file mode 100644 index 000000000..e22d97106 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java @@ -0,0 +1,64 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * Tagging metadata from https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class TaggingMetadata implements Iterable, CompositeMetadata.Entry { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + private String mimeType; + private ByteBuf content; + + public TaggingMetadata(String mimeType, ByteBuf content) { + this.mimeType = mimeType; + this.content = content; + } + + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + return content.readerIndex() < content.capacity(); + } + + @Override + public String next() { + int tagLength = TAG_LENGTH_MAX & content.readByte(); + if (tagLength > 0) { + return content.readSlice(tagLength).toString(StandardCharsets.UTF_8); + } else { + return ""; + } + } + }; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.mimeType; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java new file mode 100644 index 000000000..d766cf59f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java @@ -0,0 +1,76 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +/** + * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * TaggingMetadata} for an Iterator-like approach to decoding entries. + * + * @author linux_china + */ +public class TaggingMetadataCodec { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + /** + * create routing metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return routing metadata + */ + public static RoutingMetadata createRoutingMetadata( + ByteBufAllocator allocator, Collection tags) { + return new RoutingMetadata(createTaggingContent(allocator, tags)); + } + + /** + * create tagging metadata from composite metadata entry + * + * @param entry composite metadata entry + * @return tagging metadata + */ + public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { + return new TaggingMetadata(entry.getMimeType(), entry.getContent()); + } + + /** + * create tagging metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param tags tag values + * @return Tagging Metadata + */ + public static TaggingMetadata createTaggingMetadata( + ByteBufAllocator allocator, String knownMimeType, Collection tags) { + return new TaggingMetadata(knownMimeType, createTaggingContent(allocator, tags)); + } + + /** + * create tagging content + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return tagging content + */ + public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { + CompositeByteBuf taggingContent = allocator.compositeBuffer(); + for (String key : tags) { + int length = ByteBufUtil.utf8Bytes(key); + if (length == 0 || length > TAG_LENGTH_MAX) { + continue; + } + ByteBuf byteBuf = allocator.buffer().writeByte(length); + byteBuf.writeCharSequence(key, StandardCharsets.UTF_8); + taggingContent.addComponent(true, byteBuf); + } + return taggingContent; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java new file mode 100644 index 000000000..718528358 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java @@ -0,0 +1,62 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.util.Collection; + +/** + * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * TaggingMetadata} for an Iterator-like approach to decoding entries. + * + * @deprecated in favor of {@link TaggingMetadataCodec} + * @author linux_china + */ +@Deprecated +public class TaggingMetadataFlyweight { + /** + * create routing metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return routing metadata + */ + public static RoutingMetadata createRoutingMetadata( + ByteBufAllocator allocator, Collection tags) { + return TaggingMetadataCodec.createRoutingMetadata(allocator, tags); + } + + /** + * create tagging metadata from composite metadata entry + * + * @param entry composite metadata entry + * @return tagging metadata + */ + public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { + return TaggingMetadataCodec.createTaggingMetadata(entry); + } + + /** + * create tagging metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param tags tag values + * @return Tagging Metadata + */ + public static TaggingMetadata createTaggingMetadata( + ByteBufAllocator allocator, String knownMimeType, Collection tags) { + return TaggingMetadataCodec.createTaggingMetadata(allocator, knownMimeType, tags); + } + + /** + * create tagging content + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return tagging content + */ + public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { + return TaggingMetadataCodec.createTaggingContent(allocator, tags); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java new file mode 100644 index 000000000..d276a9436 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +/** + * Represents decoded tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public final class TracingMetadata { + + final long traceIdHigh; + final long traceId; + private final boolean hasParentId; + final long parentId; + final long spanId; + final boolean isEmpty; + final boolean isNotSampled; + final boolean isSampled; + final boolean isDebug; + + TracingMetadata( + long traceIdHigh, + long traceId, + long spanId, + boolean hasParentId, + long parentId, + boolean isEmpty, + boolean isNotSampled, + boolean isSampled, + boolean isDebug) { + this.traceIdHigh = traceIdHigh; + this.traceId = traceId; + this.spanId = spanId; + this.hasParentId = hasParentId; + this.parentId = parentId; + this.isEmpty = isEmpty; + this.isNotSampled = isNotSampled; + this.isSampled = isSampled; + this.isDebug = isDebug; + } + + /** When non-zero, the trace containing this span uses 128-bit trace identifiers. */ + public long traceIdHigh() { + return traceIdHigh; + } + + /** Unique 8-byte identifier for a trace, set on all spans within it. */ + public long traceId() { + return traceId; + } + + /** Indicates if the parent's {@link #spanId} or if this the root span in a trace. */ + public final boolean hasParent() { + return hasParentId; + } + + /** Returns the parent's {@link #spanId} where zero implies absent. */ + public long parentId() { + return parentId; + } + + /** + * Unique 8-byte identifier of this span within a trace. + * + *

A span is uniquely identified in storage by ({@linkplain #traceId}, {@linkplain #spanId}). + */ + public long spanId() { + return spanId; + } + + /** Indicates that trace IDs should be accepted for tracing. */ + public boolean isSampled() { + return isSampled; + } + + /** Indicates that trace IDs should be force traced. */ + public boolean isDebug() { + return isDebug; + } + + /** Includes that there is sampling information and no trace IDs. */ + public boolean isEmpty() { + return isEmpty; + } + + /** + * Indicated that sampling decision is present. If {@code false} means that decision is unknown + * and says explicitly that {@link #isDebug()} and {@link #isSampled()} also returns {@code + * false}. + */ + public boolean isDecided() { + return isNotSampled || isDebug || isSampled; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java new file mode 100644 index 000000000..eb44956f6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java @@ -0,0 +1,172 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Represents codes for tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public class TracingMetadataCodec { + + static final int FLAG_EXTENDED_TRACE_ID_SIZE = 0b0000_1000; + static final int FLAG_INCLUDE_PARENT_ID = 0b0000_0100; + static final int FLAG_NOT_SAMPLED = 0b0001_0000; + static final int FLAG_SAMPLED = 0b0010_0000; + static final int FLAG_DEBUG = 0b0100_0000; + static final int FLAG_IDS_SET = 0b1000_0000; + + public static ByteBuf encodeEmpty(ByteBufAllocator allocator, Flags flag) { + + return encode(allocator, true, 0, 0, false, 0, 0, false, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, + long traceIdHigh, + long traceId, + long spanId, + long parentId, + Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, parentId, true, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, long traceIdHigh, long traceId, long spanId, Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, 0, false, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, long parentId, Flags flag) { + + return encode(allocator, false, 0, traceId, false, spanId, parentId, true, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, Flags flag) { + return encode(allocator, false, 0, traceId, false, spanId, 0, false, flag); + } + + static ByteBuf encode( + ByteBufAllocator allocator, + boolean isEmpty, + long traceIdHigh, + long traceId, + boolean extendedTraceId, + long spanId, + long parentId, + boolean includesParent, + Flags flag) { + int size = + 1 + + (isEmpty + ? 0 + : (Long.BYTES + + Long.BYTES + + (extendedTraceId ? Long.BYTES : 0) + + (includesParent ? Long.BYTES : 0))); + final ByteBuf buffer = allocator.buffer(size); + + int byteFlags = 0; + switch (flag) { + case NOT_SAMPLE: + byteFlags |= FLAG_NOT_SAMPLED; + break; + case SAMPLE: + byteFlags |= FLAG_SAMPLED; + break; + case DEBUG: + byteFlags |= FLAG_DEBUG; + break; + } + + if (isEmpty) { + return buffer.writeByte(byteFlags); + } + + byteFlags |= FLAG_IDS_SET; + + if (extendedTraceId) { + byteFlags |= FLAG_EXTENDED_TRACE_ID_SIZE; + } + + if (includesParent) { + byteFlags |= FLAG_INCLUDE_PARENT_ID; + } + + buffer.writeByte(byteFlags); + + if (extendedTraceId) { + buffer.writeLong(traceIdHigh); + } + + buffer.writeLong(traceId).writeLong(spanId); + + if (includesParent) { + buffer.writeLong(parentId); + } + + return buffer; + } + + public static TracingMetadata decode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + try { + byte flags = byteBuf.readByte(); + boolean isNotSampled = (flags & FLAG_NOT_SAMPLED) == FLAG_NOT_SAMPLED; + boolean isSampled = (flags & FLAG_SAMPLED) == FLAG_SAMPLED; + boolean isDebug = (flags & FLAG_DEBUG) == FLAG_DEBUG; + boolean isIDSet = (flags & FLAG_IDS_SET) == FLAG_IDS_SET; + + if (!isIDSet) { + return new TracingMetadata(0, 0, 0, false, 0, true, isNotSampled, isSampled, isDebug); + } + + boolean extendedTraceId = + (flags & FLAG_EXTENDED_TRACE_ID_SIZE) == FLAG_EXTENDED_TRACE_ID_SIZE; + + long traceIdHigh; + if (extendedTraceId) { + traceIdHigh = byteBuf.readLong(); + } else { + traceIdHigh = 0; + } + + long traceId = byteBuf.readLong(); + long spanId = byteBuf.readLong(); + + boolean includesParent = (flags & FLAG_INCLUDE_PARENT_ID) == FLAG_INCLUDE_PARENT_ID; + + long parentId; + if (includesParent) { + parentId = byteBuf.readLong(); + } else { + parentId = 0; + } + + return new TracingMetadata( + traceIdHigh, + traceId, + spanId, + includesParent, + parentId, + false, + isNotSampled, + isSampled, + isDebug); + } finally { + byteBuf.resetReaderIndex(); + } + } + + public enum Flags { + UNDECIDED, + NOT_SAMPLE, + SAMPLE, + DEBUG + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java new file mode 100644 index 000000000..66c98701c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java new file mode 100644 index 000000000..e78e87629 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -0,0 +1,167 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Mime Types, as defined in the eponymous extension. Such mime types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownMimeType { + UNPARSEABLE_MIME_TYPE("UNPARSEABLE_MIME_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_MIME_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + APPLICATION_AVRO("application/avro", (byte) 0x00), + APPLICATION_CBOR("application/cbor", (byte) 0x01), + APPLICATION_GRAPHQL("application/graphql", (byte) 0x02), + APPLICATION_GZIP("application/gzip", (byte) 0x03), + APPLICATION_JAVASCRIPT("application/javascript", (byte) 0x04), + APPLICATION_JSON("application/json", (byte) 0x05), + APPLICATION_OCTET_STREAM("application/octet-stream", (byte) 0x06), + APPLICATION_PDF("application/pdf", (byte) 0x07), + APPLICATION_THRIFT("application/vnd.apache.thrift.binary", (byte) 0x08), + APPLICATION_PROTOBUF("application/vnd.google.protobuf", (byte) 0x09), + APPLICATION_XML("application/xml", (byte) 0x0A), + APPLICATION_ZIP("application/zip", (byte) 0x0B), + AUDIO_AAC("audio/aac", (byte) 0x0C), + AUDIO_MP3("audio/mp3", (byte) 0x0D), + AUDIO_MP4("audio/mp4", (byte) 0x0E), + AUDIO_MPEG3("audio/mpeg3", (byte) 0x0F), + AUDIO_MPEG("audio/mpeg", (byte) 0x10), + AUDIO_OGG("audio/ogg", (byte) 0x11), + AUDIO_OPUS("audio/opus", (byte) 0x12), + AUDIO_VORBIS("audio/vorbis", (byte) 0x13), + IMAGE_BMP("image/bmp", (byte) 0x14), + IMAGE_GIF("image/gif", (byte) 0x15), + IMAGE_HEIC_SEQUENCE("image/heic-sequence", (byte) 0x16), + IMAGE_HEIC("image/heic", (byte) 0x17), + IMAGE_HEIF_SEQUENCE("image/heif-sequence", (byte) 0x18), + IMAGE_HEIF("image/heif", (byte) 0x19), + IMAGE_JPEG("image/jpeg", (byte) 0x1A), + IMAGE_PNG("image/png", (byte) 0x1B), + IMAGE_TIFF("image/tiff", (byte) 0x1C), + MULTIPART_MIXED("multipart/mixed", (byte) 0x1D), + TEXT_CSS("text/css", (byte) 0x1E), + TEXT_CSV("text/csv", (byte) 0x1F), + TEXT_HTML("text/html", (byte) 0x20), + TEXT_PLAIN("text/plain", (byte) 0x21), + TEXT_XML("text/xml", (byte) 0x22), + VIDEO_H264("video/H264", (byte) 0x23), + VIDEO_H265("video/H265", (byte) 0x24), + VIDEO_VP8("video/VP8", (byte) 0x25), + APPLICATION_HESSIAN("application/x-hessian", (byte) 0x26), + APPLICATION_JAVA_OBJECT("application/x-java-object", (byte) 0x27), + APPLICATION_CLOUDEVENTS_JSON("application/cloudevents+json", (byte) 0x28), + + // ... reserved for future use ... + MESSAGE_RSOCKET_MIMETYPE("message/x.rsocket.mime-type.v0", (byte) 0x7A), + MESSAGE_RSOCKET_ACCEPT_MIMETYPES("message/x.rsocket.accept-mime-types.v0", (byte) 0x7B), + MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), + MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), + MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), + MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); + + static final WellKnownMimeType[] TYPES_BY_MIME_ID; + static final Map TYPES_BY_MIME_STRING; + + static { + // precompute an array of all valid mime ids, filling the blanks with the RESERVED enum + TYPES_BY_MIME_ID = new WellKnownMimeType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_MIME_ID, UNKNOWN_RESERVED_MIME_TYPE); + // also prepare a Map of the types by mime string + TYPES_BY_MIME_STRING = new HashMap<>(128); + + for (WellKnownMimeType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_MIME_ID[value.getIdentifier()] = value; + TYPES_BY_MIME_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownMimeType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownMimeType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_MIME_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_MIME_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownMimeType}, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownMimeType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_MIME_TYPE; + } + return TYPES_BY_MIME_ID[id]; + } + + /** + * Find the {@link WellKnownMimeType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownMimeType}, the {@link + * #UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeType the looked up mime type + * @return the matching {@link WellKnownMimeType}, or {@link #UNPARSEABLE_MIME_TYPE} if none + * matches + */ + public static WellKnownMimeType fromString(String mimeType) { + if (mimeType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_MIME_TYPE's text has been used + if (mimeType.equals(UNKNOWN_RESERVED_MIME_TYPE.str)) { + return UNPARSEABLE_MIME_TYPE; + } + + return TYPES_BY_MIME_STRING.getOrDefault(mimeType, UNPARSEABLE_MIME_TYPE); + } + + /** @return the byte identifier of the mime type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the mime type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java new file mode 100644 index 000000000..3fb9ae1d6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains implementations of RSocket protocol extensions related + * to the use of metadata. + */ +@NonNullApi +package io.rsocket.metadata; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java new file mode 100644 index 000000000..e1a8ba449 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java @@ -0,0 +1,194 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.metadata.AuthMetadataCodec; + +/** @deprecated in favor of {@link io.rsocket.metadata.AuthMetadataCodec} */ +@Deprecated +public class AuthMetadataFlyweight { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private AuthMetadataFlyweight() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + return AuthMetadataCodec.encodeMetadata(allocator, customAuthType, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + return AuthMetadataCodec.encodeMetadata(allocator, WellKnownAuthType.cast(authType), metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 255 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + return AuthMetadataCodec.encodeSimpleMetadata(allocator, username, password); + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + return AuthMetadataCodec.encodeBearerMetadata(allocator, token); + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + return AuthMetadataCodec.encodeMetadataWithCompression(allocator, authType, metadata); + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + return AuthMetadataCodec.isWellKnownAuthType(metadata); + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType decodeWellKnownAuthType(ByteBuf metadata) { + return WellKnownAuthType.cast(AuthMetadataCodec.readWellKnownAuthType(metadata)); + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence decodeCustomAuthType(ByteBuf metadata) { + return AuthMetadataCodec.readCustomAuthType(metadata); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf decodePayload(ByteBuf metadata) { + return AuthMetadataCodec.readPayload(metadata); + } + + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf decodeUsername(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readUsername(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf decodePassword(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readPassword(simpleAuthMetadata); + } + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] decodeUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readUsernameAsCharArray(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodePasswordAsCharArray(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readPasswordAsCharArray(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is + * username length and the subsequent number of bytes equal to decoded length + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodeBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + return AuthMetadataCodec.readBearerTokenAsCharArray(bearerAuthMetadata); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java new file mode 100644 index 000000000..24e5ff0db --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java @@ -0,0 +1,147 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata.security; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + * + * @deprecated in favor of {@link io.rsocket.metadata.WellKnownAuthType} + */ +@Deprecated +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + static io.rsocket.metadata.WellKnownAuthType cast(WellKnownAuthType wellKnownAuthType) { + byte identifier = wellKnownAuthType.identifier; + if (identifier == io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.getIdentifier()) { + return io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } else if (identifier + == io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.getIdentifier()) { + return io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + } else { + return io.rsocket.metadata.WellKnownAuthType.fromIdentifier(identifier); + } + } + + static WellKnownAuthType cast(io.rsocket.metadata.WellKnownAuthType wellKnownAuthType) { + byte identifier = wellKnownAuthType.getIdentifier(); + if (identifier == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.identifier) { + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } else if (identifier == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.identifier) { + return WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + } else { + return TYPES_BY_AUTH_ID[identifier]; + } + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/package-info.java b/rsocket-core/src/main/java/io/rsocket/package-info.java index 243c1ab52..6fe74fb38 100644 --- a/rsocket-core/src/main/java/io/rsocket/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,4 +14,16 @@ * limitations under the License. */ +/** + * Contains key contracts of the RSocket programming model including {@link io.rsocket.RSocket + * RSocket} for performing or handling RSocket interactions, {@link io.rsocket.SocketAcceptor + * SocketAcceptor} for declaring responders, {@link io.rsocket.Payload Payload} for access to the + * content of a payload, and others. + * + *

To connect to or start a server see {@link io.rsocket.core.RSocketConnector RSocketConnector} + * and {@link io.rsocket.core.RSocketServer RSocketServer} in {@link io.rsocket.core}. + */ +@NonNullApi package io.rsocket; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java index 2b10f8746..6b2a7a71b 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,11 +19,15 @@ import io.rsocket.DuplexConnection; import java.util.function.BiFunction; -/** */ +/** + * Contract to decorate a {@link DuplexConnection} and intercept the sending and receiving of + * RSocket frames at the transport level. + */ public @FunctionalInterface interface DuplexConnectionInterceptor extends BiFunction { + enum Type { - STREAM_ZERO, + SETUP, CLIENT, SERVER, SOURCE diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java new file mode 100644 index 000000000..fc032847c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -0,0 +1,56 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; + +/** + * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. + * This is not intended for direct use by applications. + */ +public class InitializingInterceptorRegistry extends InterceptorRegistry { + + public DuplexConnection initConnection( + DuplexConnectionInterceptor.Type type, DuplexConnection connection) { + for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { + connection = interceptor.apply(type, connection); + } + return connection; + } + + public RSocket initRequester(RSocket rsocket) { + for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public RSocket initResponder(RSocket rsocket) { + for (RSocketInterceptor interceptor : getResponderInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public SocketAcceptor initSocketAcceptor(SocketAcceptor acceptor) { + for (SocketAcceptorInterceptor interceptor : getSocketAcceptorInterceptors()) { + acceptor = interceptor.apply(acceptor); + } + return acceptor; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java new file mode 100644 index 000000000..427fa15ae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -0,0 +1,120 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +/** + * Provides support for registering interceptors at the following levels: + * + *

    + *
  • {@link #forConnection(DuplexConnectionInterceptor)} -- transport level + *
  • {@link #forSocketAcceptor(SocketAcceptorInterceptor)} -- for accepting new connections + *
  • {@link #forRequester(RSocketInterceptor)} -- for performing of requests + *
  • {@link #forResponder(RSocketInterceptor)} -- for responding to requests + *
+ */ +public class InterceptorRegistry { + private List requesterInteceptors = new ArrayList<>(); + private List responderInterceptors = new ArrayList<>(); + private List socketAcceptorInterceptors = new ArrayList<>(); + private List connectionInterceptors = new ArrayList<>(); + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. + */ + public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { + requesterInteceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forRequester(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forRequester(Consumer> consumer) { + consumer.accept(requesterInteceptors); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for resonding to + * requests. + */ + public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { + responderInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forResponder(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forResponder(Consumer> consumer) { + consumer.accept(responderInterceptors); + return this; + } + + /** + * Add a {@link SocketAcceptorInterceptor} that will intercept the accepting of new connections. + */ + public InterceptorRegistry forSocketAcceptor(SocketAcceptorInterceptor interceptor) { + socketAcceptorInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forSocketAcceptor(SocketAcceptorInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forSocketAcceptor(Consumer> consumer) { + consumer.accept(socketAcceptorInterceptors); + return this; + } + + /** Add a {@link DuplexConnectionInterceptor}. */ + public InterceptorRegistry forConnection(DuplexConnectionInterceptor interceptor) { + connectionInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forConnection(DuplexConnectionInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forConnection(Consumer> consumer) { + consumer.accept(connectionInterceptors); + return this; + } + + List getRequesterInteceptors() { + return requesterInteceptors; + } + + List getResponderInterceptors() { + return responderInterceptors; + } + + List getConnectionInterceptors() { + return connectionInterceptors; + } + + List getSocketAcceptorInterceptors() { + return socketAcceptorInterceptors; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java new file mode 100644 index 000000000..d7d9742d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +/** + * Interceptor that adds {@link Flux#limitRate(int, int)} to publishers of outbound streams that + * breaks down or aggregates demand values from the remote end (i.e. {@code REQUEST_N} frames) into + * batches of a uniform size. For example the remote may request {@code Long.MAXVALUE} or it may + * start requesting one at a time, in both cases with the limit set to 64, the publisher will see a + * demand of 64 to start and subsequent batches of 48, i.e. continuing to prefetch and refill an + * internal queue when it falls to 75% full. The high and low tide marks are configurable. + * + *

See static factory methods to create an instance for a requester or for a responder. + * + *

Note: keep in mind that the {@code limitRate} operator always uses requests + * the same request values, even if the remote requests less than the limit. For example given a + * limit of 64, if the remote requests 4, 64 will be prefetched of which 4 will be sent and 60 will + * be cached. + * + * @since 1.0 + */ +public class LimitRateInterceptor implements RSocketInterceptor { + + private final int highTide; + private final int lowTide; + private final boolean requesterProxy; + + private LimitRateInterceptor(int highTide, int lowTide, boolean requesterProxy) { + this.highTide = highTide; + this.lowTide = lowTide; + this.requesterProxy = requesterProxy; + } + + @Override + public RSocket apply(RSocket socket) { + return requesterProxy ? new RequesterProxy(socket) : new ResponderProxy(socket); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int prefetchRate) { + return forResponder(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions with more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, false); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int prefetchRate) { + return forRequester(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions with + * more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, true); + } + + /** Responder side proxy, limits response streams. */ + private class ResponderProxy extends RSocketProxy { + + ResponderProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestStream(Payload payload) { + return super.requestStream(payload).limitRate(highTide, lowTide); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(payloads).limitRate(highTide, lowTide); + } + } + + /** Requester side proxy, limits channel request stream. */ + private class RequesterProxy extends RSocketProxy { + + RequesterProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(Flux.from(payloads).limitRate(highTide, lowTide)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java deleted file mode 100644 index 873f6babb..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -import io.rsocket.DuplexConnection; -import io.rsocket.RSocket; -import java.util.ArrayList; -import java.util.List; - -public class PluginRegistry { - private List connections = new ArrayList<>(); - private List clients = new ArrayList<>(); - private List servers = new ArrayList<>(); - - public PluginRegistry() {} - - public PluginRegistry(PluginRegistry defaults) { - this.connections.addAll(defaults.connections); - this.clients.addAll(defaults.clients); - this.servers.addAll(defaults.servers); - } - - public void addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - connections.add(interceptor); - } - - public void addClientPlugin(RSocketInterceptor interceptor) { - clients.add(interceptor); - } - - public void addServerPlugin(RSocketInterceptor interceptor) { - servers.add(interceptor); - } - - public RSocket applyClient(RSocket rSocket) { - for (RSocketInterceptor i : clients) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public RSocket applyServer(RSocket rSocket) { - for (RSocketInterceptor i : servers) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public DuplexConnection applyConnection( - DuplexConnectionInterceptor.Type type, DuplexConnection connection) { - for (DuplexConnectionInterceptor i : connections) { - connection = i.apply(type, connection); - } - - return connection; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java b/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java deleted file mode 100644 index 1ac147687..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -/** JVM wide plugins for RSocket */ -public class Plugins { - private static PluginRegistry DEFAULT = new PluginRegistry(); - - private Plugins() {} - - public static void interceptConnection(DuplexConnectionInterceptor interceptor) { - DEFAULT.addConnectionPlugin(interceptor); - } - - public static void interceptClient(RSocketInterceptor interceptor) { - DEFAULT.addClientPlugin(interceptor); - } - - public static void interceptServer(RSocketInterceptor interceptor) { - DEFAULT.addServerPlugin(interceptor); - } - - public static PluginRegistry defaultPlugins() { - return DEFAULT; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java index 0bad0faed..0cd4bb8f6 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,5 +19,10 @@ import io.rsocket.RSocket; import java.util.function.Function; -/** */ +/** + * Contract to decorate an {@link RSocket}, providing a way to intercept interactions. This can be + * applied to a {@link InterceptorRegistry#forRequester(RSocketInterceptor) requester} or {@link + * InterceptorRegistry#forResponder(RSocketInterceptor) responder} {@code RSocket} of a client or + * server. + */ public @FunctionalInterface interface RSocketInterceptor extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java new file mode 100644 index 000000000..6dd850ba9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.SocketAcceptor; +import java.util.function.Function; + +/** + * Contract to decorate a {@link SocketAcceptor}, providing access to connection {@code setup} + * information and the ability to also decorate the sockets for requesting and responding. + * + *

This could be used as an alternative to registering an individual "requester" {@code + * RSocketInterceptor} and "responder" {@code RSocketInterceptor}. + */ +public @FunctionalInterface interface SocketAcceptorInterceptor + extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java new file mode 100644 index 000000000..fd9e1f01a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contracts for interception of transports, connections, and requests in in RSocket Java. */ +@NonNullApi +package io.rsocket.plugins; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java new file mode 100644 index 000000000..ed9450357 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.internal.ClientServerInputMultiplexer; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class ClientRSocketSession implements RSocketSession> { + private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); + + private final ResumableDuplexConnection resumableConnection; + private volatile Mono newConnection; + private volatile ByteBuf resumeToken; + private final ByteBufAllocator allocator; + + public ClientRSocketSession( + DuplexConnection duplexConnection, + Duration resumeSessionDuration, + Retry retry, + ResumableFramesStore resumableFramesStore, + Duration resumeStreamTimeout, + boolean cleanupStoreOnKeepAlive) { + this.allocator = duplexConnection.alloc(); + this.resumableConnection = + new ResumableDuplexConnection( + "client", + duplexConnection, + resumableFramesStore, + resumeStreamTimeout, + cleanupStoreOnKeepAlive); + + /*session completed: release token initially retained in resumeToken(ByteBuf)*/ + onClose().doFinally(s -> resumeToken.release()).subscribe(); + + resumableConnection + .connectionErrors() + .flatMap( + err -> { + logger.debug("Client session connection error. Starting new connection"); + AtomicBoolean once = new AtomicBoolean(); + return newConnection + .delaySubscription( + once.compareAndSet(false, true) + ? retry.generateCompanion(Flux.just(new RetrySignal(err))) + : Mono.empty()) + .retryWhen(retry) + .timeout(resumeSessionDuration); + }) + .map(ClientServerInputMultiplexer::new) + .subscribe( + multiplexer -> { + /*reconnect resumable connection*/ + reconnect(multiplexer.asClientServerConnection()); + long impliedPosition = resumableConnection.impliedPosition(); + long position = resumableConnection.position(); + logger.debug( + "Client ResumableConnection reconnected. Sending RESUME frame with state: [impliedPos: {}, pos: {}]", + impliedPosition, + position); + /*Connection is established again: send RESUME frame to server, listen for RESUME_OK*/ + sendFrame( + ResumeFrameCodec.encode( + allocator, + /*retain so token is not released once sent as part of resume frame*/ + resumeToken.retain(), + impliedPosition, + position)) + .then(multiplexer.asSetupConnection().receive().next()) + .subscribe(this::resumeWith); + }, + err -> { + logger.debug("Client ResumableConnection reconnect timeout"); + resumableConnection.dispose(); + }); + } + + @Override + public ClientRSocketSession continueWith(Mono connectionFactory) { + this.newConnection = connectionFactory; + return this; + } + + @Override + public ClientRSocketSession resumeWith(ByteBuf resumeOkFrame) { + logger.debug("ResumeOK FRAME received"); + long remotePos = remotePos(resumeOkFrame); + long remoteImpliedPos = remoteImpliedPos(resumeOkFrame); + resumeOkFrame.release(); + + resumableConnection.resume( + remotePos, + remoteImpliedPos, + pos -> + pos.then() + /*Resumption is impossible: send CONNECTION_ERROR*/ + .onErrorResume( + err -> + sendFrame( + ErrorFrameCodec.encode( + allocator, 0, errorFrameThrowable(remoteImpliedPos))) + .then(Mono.fromRunnable(resumableConnection::dispose)) + /*Resumption is impossible: no need to return control to ResumableConnection*/ + .then(Mono.never()))); + return this; + } + + public ClientRSocketSession resumeToken(ByteBuf resumeToken) { + /*retain so token is not released once sent as part of setup frame*/ + this.resumeToken = resumeToken.retain(); + return this; + } + + @Override + public void reconnect(DuplexConnection connection) { + resumableConnection.reconnect(connection); + } + + @Override + public ResumableDuplexConnection resumableConnection() { + return resumableConnection; + } + + @Override + public ByteBuf token() { + return resumeToken; + } + + private Mono sendFrame(ByteBuf frame) { + return resumableConnection.sendOne(frame).onErrorResume(err -> Mono.empty()); + } + + private static long remoteImpliedPos(ByteBuf resumeOkFrame) { + return ResumeOkFrameCodec.lastReceivedClientPos(resumeOkFrame); + } + + private static long remotePos(ByteBuf resumeOkFrame) { + return -1; + } + + private static ConnectionErrorException errorFrameThrowable(long impliedPos) { + return new ConnectionErrorException("resumption_server_pos=[" + impliedPos + "]"); + } + + private static class RetrySignal implements Retry.RetrySignal { + + private final Throwable ex; + + RetrySignal(Throwable ex) { + this.ex = ex; + } + + @Override + public long totalRetries() { + return 0; + } + + @Override + public long totalRetriesInARow() { + return 0; + } + + @Override + public Throwable failure() { + return ex; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java new file mode 100644 index 000000000..415a77f92 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.time.Duration; + +public class ClientResume { + private final Duration sessionDuration; + private final ByteBuf resumeToken; + + public ClientResume(Duration sessionDuration, ByteBuf resumeToken) { + this.sessionDuration = sessionDuration; + this.resumeToken = resumeToken; + } + + public Duration sessionDuration() { + return sessionDuration; + } + + public ByteBuf resumeToken() { + return resumeToken; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java new file mode 100644 index 000000000..461be02d2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java @@ -0,0 +1,77 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import java.time.Duration; +import java.util.Objects; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#backoff(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated +public class ExponentialBackoffResumeStrategy implements ResumeStrategy { + private volatile Duration next; + private final Duration firstBackoff; + private final Duration maxBackoff; + private final int factor; + + public ExponentialBackoffResumeStrategy(Duration firstBackoff, Duration maxBackoff, int factor) { + this.firstBackoff = Objects.requireNonNull(firstBackoff, "firstBackoff"); + this.maxBackoff = Objects.requireNonNull(maxBackoff, "maxBackoff"); + this.factor = requirePositive(factor); + } + + @Override + public Publisher apply(ClientResume clientResume, Throwable throwable) { + return Flux.defer(() -> Mono.delay(next()).thenReturn(toString())); + } + + Duration next() { + next = + next == null + ? firstBackoff + : Duration.ofMillis(Math.min(maxBackoff.toMillis(), next.toMillis() * factor)); + return next; + } + + private static int requirePositive(int value) { + if (value <= 0) { + throw new IllegalArgumentException("Value must be positive: " + value); + } else { + return value; + } + } + + @Override + public String toString() { + return "ExponentialBackoffResumeStrategy{" + + "next=" + + next + + ", firstBackoff=" + + firstBackoff + + ", maxBackoff=" + + maxBackoff + + ", factor=" + + factor + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java new file mode 100644 index 000000000..1875b7eac --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java @@ -0,0 +1,240 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.util.Queue; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.util.concurrent.Queues; + +public class InMemoryResumableFramesStore implements ResumableFramesStore { + private static final Logger logger = LoggerFactory.getLogger(InMemoryResumableFramesStore.class); + private static final long SAVE_REQUEST_SIZE = Long.MAX_VALUE; + + private final MonoProcessor disposed = MonoProcessor.create(); + volatile long position; + volatile long impliedPosition; + volatile int cacheSize; + final Queue cachedFrames; + private final String tag; + private final int cacheLimit; + private volatile int upstreamFrameRefCnt; + + public InMemoryResumableFramesStore(String tag, int cacheSizeBytes) { + this.tag = tag; + this.cacheLimit = cacheSizeBytes; + this.cachedFrames = cachedFramesQueue(cacheSizeBytes); + } + + public Mono saveFrames(Flux frames) { + MonoProcessor completed = MonoProcessor.create(); + frames + .doFinally(s -> completed.onComplete()) + .subscribe(new FramesSubscriber(SAVE_REQUEST_SIZE)); + return completed; + } + + @Override + public void releaseFrames(long remoteImpliedPos) { + long pos = position; + logger.debug( + "{} Removing frames for local: {}, remote implied: {}", tag, pos, remoteImpliedPos); + long removeSize = Math.max(0, remoteImpliedPos - pos); + while (removeSize > 0) { + ByteBuf cachedFrame = cachedFrames.poll(); + if (cachedFrame != null) { + removeSize -= releaseTailFrame(cachedFrame); + } else { + break; + } + } + if (removeSize > 0) { + throw new IllegalStateException( + String.format( + "Local and remote state disagreement: " + + "need to remove additional %d bytes, but cache is empty", + removeSize)); + } else if (removeSize < 0) { + throw new IllegalStateException( + "Local and remote state disagreement: " + "local and remote frame sizes are not equal"); + } else { + logger.debug("{} Removed frames. Current cache size: {}", tag, cacheSize); + } + } + + @Override + public Flux resumeStream() { + return Flux.generate( + () -> new ResumeStreamState(cachedFrames.size(), upstreamFrameRefCnt), + (state, sink) -> { + if (state.next()) { + /*spsc queue has no iterator - iterating by consuming*/ + ByteBuf frame = cachedFrames.poll(); + if (state.shouldRetain(frame)) { + frame.retain(); + } + cachedFrames.offer(frame); + sink.next(frame); + } else { + sink.complete(); + logger.debug("{} Resuming stream completed", tag); + } + return state; + }); + } + + @Override + public long framePosition() { + return position; + } + + @Override + public long frameImpliedPosition() { + return impliedPosition; + } + + @Override + public void resumableFrameReceived(ByteBuf frame) { + /*called on transport thread so non-atomic on volatile is safe*/ + impliedPosition += frame.readableBytes(); + } + + @Override + public Mono onClose() { + return disposed; + } + + @Override + public void dispose() { + cacheSize = 0; + ByteBuf frame = cachedFrames.poll(); + while (frame != null) { + frame.release(); + frame = cachedFrames.poll(); + } + disposed.onComplete(); + } + + @Override + public boolean isDisposed() { + return disposed.isTerminated(); + } + + /* this method and saveFrame() won't be called concurrently, + * so non-atomic on volatile is safe*/ + private int releaseTailFrame(ByteBuf content) { + int frameSize = content.readableBytes(); + cacheSize -= frameSize; + position += frameSize; + content.release(); + return frameSize; + } + + /*this method and releaseTailFrame() won't be called concurrently, + * so non-atomic on volatile is safe*/ + void saveFrame(ByteBuf frame) { + if (upstreamFrameRefCnt == 0) { + upstreamFrameRefCnt = frame.refCnt(); + } + + int frameSize = frame.readableBytes(); + long availableSize = cacheLimit - cacheSize; + while (availableSize < frameSize) { + ByteBuf cachedFrame = cachedFrames.poll(); + if (cachedFrame != null) { + availableSize += releaseTailFrame(cachedFrame); + } else { + break; + } + } + if (availableSize >= frameSize) { + cachedFrames.offer(frame.retain()); + cacheSize += frameSize; + } else { + position += frameSize; + } + } + + static class ResumeStreamState { + private final int cacheSize; + private final int expectedRefCnt; + private int cacheCounter; + + public ResumeStreamState(int cacheSize, int expectedRefCnt) { + this.cacheSize = cacheSize; + this.expectedRefCnt = expectedRefCnt; + } + + public boolean next() { + if (cacheCounter < cacheSize) { + cacheCounter++; + return true; + } else { + return false; + } + } + + public boolean shouldRetain(ByteBuf frame) { + return frame.refCnt() == expectedRefCnt; + } + } + + static Queue cachedFramesQueue(int size) { + return Queues.get(size).get(); + } + + class FramesSubscriber implements Subscriber { + private final long firstRequestSize; + private final long refillSize; + private int received; + private Subscription s; + + public FramesSubscriber(long requestSize) { + this.firstRequestSize = requestSize; + this.refillSize = firstRequestSize / 2; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + s.request(firstRequestSize); + } + + @Override + public void onNext(ByteBuf byteBuf) { + saveFrame(byteBuf); + if (firstRequestSize != Long.MAX_VALUE && ++received == refillSize) { + received = 0; + s.request(refillSize); + } + } + + @Override + public void onError(Throwable t) { + logger.info("unexpected onError signal: {}, {}", t.getClass(), t.getMessage()); + } + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java new file mode 100644 index 000000000..bd447c8a9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import java.time.Duration; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#fixedDelay(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated +public class PeriodicResumeStrategy implements ResumeStrategy { + private final Duration interval; + + public PeriodicResumeStrategy(Duration interval) { + this.interval = interval; + } + + @Override + public Publisher apply(ClientResume clientResumeConfiguration, Throwable throwable) { + return Mono.delay(interval).thenReturn(toString()); + } + + @Override + public String toString() { + return "PeriodicResumeStrategy{" + "interval=" + interval + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java new file mode 100644 index 000000000..7ec0abaee --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import reactor.core.publisher.Mono; + +public interface RSocketSession extends Closeable { + + ByteBuf token(); + + ResumableDuplexConnection resumableConnection(); + + RSocketSession continueWith(T ConnectionFactory); + + RSocketSession resumeWith(ByteBuf resumeFrame); + + void reconnect(DuplexConnection connection); + + @Override + default Mono onClose() { + return resumableConnection().onClose(); + } + + @Override + default void dispose() { + resumableConnection().dispose(); + } + + @Override + default boolean isDisposed() { + return resumableConnection().isDisposed(); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java b/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java similarity index 56% rename from rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java rename to rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java index ed8e6cd1d..6553e5ec5 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/RequestListener.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,19 @@ * limitations under the License. */ -package io.rsocket.transport.local; +package io.rsocket.resume; -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; +import reactor.core.publisher.Flux; +import reactor.core.publisher.ReplayProcessor; -final class LocalUriHandlerTest implements UriHandlerTest { +class RequestListener { + private final ReplayProcessor requests = ReplayProcessor.create(1); - @Override - public String getInvalidUri() { - return "http://test"; + public Flux apply(Flux flux) { + return flux.doOnRequest(requests::onNext); } - @Override - public UriHandler getUriHandler() { - return new LocalUriHandler(); - } - - @Override - public String getValidUri() { - return "local:test"; + public Flux requests() { + return requests; } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java new file mode 100644 index 000000000..461d71228 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -0,0 +1,450 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameHeaderCodec; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.*; +import reactor.util.concurrent.Queues; + +public class ResumableDuplexConnection implements DuplexConnection, ResumeStateHolder { + private static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); + private static final Throwable closedChannelException = new ClosedChannelException(); + + private final String tag; + private final ResumableFramesStore resumableFramesStore; + private final Duration resumeStreamTimeout; + private final boolean cleanupOnKeepAlive; + + private final ReplayProcessor connections = ReplayProcessor.create(1); + private final EmitterProcessor connectionErrors = EmitterProcessor.create(); + private volatile DuplexConnection curConnection; + /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ + private final FluxProcessor downStreamFrames = ReplayProcessor.create(0); + private final FluxProcessor resumeSaveFrames = EmitterProcessor.create(); + private final MonoProcessor resumeSaveCompleted = MonoProcessor.create(); + private final Queue actions = Queues.unboundedMultiproducer().get(); + private final AtomicInteger actionsWip = new AtomicInteger(); + private final AtomicBoolean disposed = new AtomicBoolean(); + + private final Mono framesSent; + private final RequestListener downStreamRequestListener = new RequestListener(); + private final RequestListener resumeSaveStreamRequestListener = new RequestListener(); + private final UnicastProcessor> upstreams = UnicastProcessor.create(); + private final UpstreamFramesSubscriber upstreamSubscriber = + new UpstreamFramesSubscriber( + Queues.SMALL_BUFFER_SIZE, + downStreamRequestListener.requests(), + resumeSaveStreamRequestListener.requests(), + this::dispatch); + + private volatile Runnable onResume; + private volatile Runnable onDisconnect; + private volatile int state; + private volatile Disposable resumedStreamDisposable = Disposables.disposed(); + + public ResumableDuplexConnection( + String tag, + DuplexConnection duplexConnection, + ResumableFramesStore resumableFramesStore, + Duration resumeStreamTimeout, + boolean cleanupOnKeepAlive) { + this.tag = tag; + this.resumableFramesStore = resumableFramesStore; + this.resumeStreamTimeout = resumeStreamTimeout; + this.cleanupOnKeepAlive = cleanupOnKeepAlive; + + resumableFramesStore + .saveFrames(resumeSaveStreamRequestListener.apply(resumeSaveFrames)) + .subscribe(resumeSaveCompleted); + + upstreams.flatMap(Function.identity()).subscribe(upstreamSubscriber); + + framesSent = + connections + .switchMap( + c -> { + logger.debug("Switching transport: {}", tag); + return c.send(downStreamRequestListener.apply(downStreamFrames)) + .doFinally( + s -> + logger.debug( + "{} Transport send completed: {}, {}", tag, s, c.toString())) + .onErrorResume(err -> Mono.never()); + }) + .then() + .cache(); + + reconnect(duplexConnection); + } + + @Override + public ByteBufAllocator alloc() { + return curConnection.alloc(); + } + + public void disconnect() { + DuplexConnection c = this.curConnection; + if (c != null) { + disconnect(c); + } + } + + public void onDisconnect(Runnable onDisconnectAction) { + this.onDisconnect = onDisconnectAction; + } + + public void onResume(Runnable onResumeAction) { + this.onResume = onResumeAction; + } + + /*reconnected by session after error. After this downstream can receive frames, + * but sending in suppressed until resume() is called*/ + public void reconnect(DuplexConnection connection) { + if (curConnection == null) { + logger.debug("{} Resumable duplex connection started with connection: {}", tag, connection); + state = State.CONNECTED; + onNewConnection(connection); + } else { + logger.debug( + "{} Resumable duplex connection reconnected with connection: {}", tag, connection); + /*race between sendFrame and doResumeStart may lead to ongoing upstream frames + written before resume complete*/ + dispatch(new ResumeStart(connection)); + } + } + + /*after receiving RESUME (Server) or RESUME_OK (Client) + calculate and send resume frames */ + public void resume( + long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { + /*race between sendFrame and doResume may lead to duplicate frames on resume store*/ + dispatch(new Resume(remotePos, remoteImpliedPos, resumeFrameSent)); + } + + @Override + public Mono sendOne(ByteBuf frame) { + return curConnection.sendOne(frame); + } + + @Override + public Mono send(Publisher frames) { + upstreams.onNext(Flux.from(frames)); + return framesSent; + } + + @Override + public Flux receive() { + return connections.switchMap( + c -> + c.receive() + .doOnNext( + f -> { + if (isResumableFrame(f)) { + resumableFramesStore.resumableFrameReceived(f); + } + }) + .onErrorResume(err -> Mono.never())); + } + + public long position() { + return resumableFramesStore.framePosition(); + } + + @Override + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); + } + + @Override + public void onImpliedPosition(long remoteImpliedPos) { + logger.debug("Got remote position from keep-alive: {}", remoteImpliedPos); + if (cleanupOnKeepAlive) { + dispatch(new ReleaseFrames(remoteImpliedPos)); + } + } + + @Override + public Mono onClose() { + return Flux.merge(connections.last().flatMap(Closeable::onClose), resumeSaveCompleted).then(); + } + + @Override + public void dispose() { + if (disposed.compareAndSet(false, true)) { + logger.debug("Resumable connection disposed: {}, {}", tag, this); + upstreams.onComplete(); + connections.onComplete(); + connectionErrors.onComplete(); + resumeSaveFrames.onComplete(); + curConnection.dispose(); + upstreamSubscriber.dispose(); + resumedStreamDisposable.dispose(); + resumableFramesStore.dispose(); + } + } + + @Override + public double availability() { + return curConnection.availability(); + } + + @Override + public boolean isDisposed() { + return disposed.get(); + } + + private void sendFrame(ByteBuf f) { + if (disposed.get()) { + f.release(); + return; + } + /*resuming from store so no need to save again*/ + if (state != State.RESUME && isResumableFrame(f)) { + resumeSaveFrames.onNext(f); + } + /*filter frames coming from upstream before actual resumption began, + * to preserve frames ordering*/ + if (state != State.RESUME_STARTED) { + downStreamFrames.onNext(f); + } + } + + Flux connectionErrors() { + return connectionErrors; + } + + private void dispatch(Object action) { + actions.offer(action); + if (actionsWip.getAndIncrement() == 0) { + do { + Object a = actions.poll(); + if (a instanceof ByteBuf) { + sendFrame((ByteBuf) a); + } else { + ((Runnable) a).run(); + } + } while (actionsWip.decrementAndGet() != 0); + } + } + + private void doResumeStart(DuplexConnection connection) { + state = State.RESUME_STARTED; + resumedStreamDisposable.dispose(); + upstreamSubscriber.resumeStart(); + onNewConnection(connection); + } + + private void doResume( + long remotePosition, + long remoteImpliedPosition, + Function, Mono> sendResumeFrame) { + long localPosition = position(); + long localImpliedPosition = impliedPosition(); + + logger.debug("Resumption start"); + logger.debug( + "Resumption states. local: [pos: {}, impliedPos: {}], remote: [pos: {}, impliedPos: {}]", + localPosition, + localImpliedPosition, + remotePosition, + remoteImpliedPosition); + + long remoteImpliedPos = + calculateRemoteImpliedPos( + localPosition, localImpliedPosition, + remotePosition, remoteImpliedPosition); + + Mono impliedPositionOrError; + if (remoteImpliedPos >= 0) { + state = State.RESUME; + releaseFramesToPosition(remoteImpliedPos); + impliedPositionOrError = Mono.just(localImpliedPosition); + } else { + impliedPositionOrError = + Mono.error( + new ResumeStateException( + localPosition, localImpliedPosition, + remotePosition, remoteImpliedPosition)); + } + + sendResumeFrame + .apply(impliedPositionOrError) + .doOnSuccess( + v -> { + Runnable r = this.onResume; + if (r != null) { + r.run(); + } + }) + .then( + streamResumedFrames( + resumableFramesStore + .resumeStream() + .timeout(resumeStreamTimeout) + .doFinally(s -> dispatch(new ResumeComplete()))) + .doOnError(err -> dispose())) + .onErrorResume(err -> Mono.empty()) + .subscribe(); + } + + static long calculateRemoteImpliedPos( + long pos, long impliedPos, long remotePos, long remoteImpliedPos) { + if (remotePos <= impliedPos && pos <= remoteImpliedPos) { + return remoteImpliedPos; + } else { + return -1L; + } + } + + private void doResumeComplete() { + logger.debug("Completing resumption"); + state = State.RESUME_COMPLETED; + upstreamSubscriber.resumeComplete(); + } + + private Mono streamResumedFrames(Flux frames) { + return Mono.create( + s -> { + ResumeFramesSubscriber subscriber = + new ResumeFramesSubscriber( + downStreamRequestListener.requests(), this::dispatch, s::error, s::success); + s.onDispose(subscriber); + resumedStreamDisposable = subscriber; + frames.subscribe(subscriber); + }); + } + + private void onNewConnection(DuplexConnection connection) { + curConnection = connection; + connection.onClose().doFinally(v -> disconnect(connection)).subscribe(); + connections.onNext(connection); + } + + private void disconnect(DuplexConnection connection) { + /*do not report late disconnects on old connection if new one is available*/ + if (curConnection == connection && state != State.DISCONNECTED) { + connection.dispose(); + state = State.DISCONNECTED; + logger.debug( + "{} Inner connection disconnected: {}", + tag, + closedChannelException.getClass().getSimpleName()); + connectionErrors.onNext(closedChannelException); + Runnable r = this.onDisconnect; + if (r != null) { + r.run(); + } + } + } + + /*remove frames confirmed by implied pos, + set current pos accordingly*/ + private void releaseFramesToPosition(long remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + + static boolean isResumableFrame(ByteBuf frame) { + switch (FrameHeaderCodec.nativeFrameType(frame)) { + case REQUEST_CHANNEL: + case REQUEST_STREAM: + case REQUEST_RESPONSE: + case REQUEST_FNF: + case REQUEST_N: + case CANCEL: + case ERROR: + case PAYLOAD: + return true; + default: + return false; + } + } + + static class State { + static int CONNECTED = 0; + static int RESUME_STARTED = 1; + static int RESUME = 2; + static int RESUME_COMPLETED = 3; + static int DISCONNECTED = 4; + } + + class ResumeStart implements Runnable { + private final DuplexConnection connection; + + public ResumeStart(DuplexConnection connection) { + this.connection = connection; + } + + @Override + public void run() { + doResumeStart(connection); + } + } + + class Resume implements Runnable { + private final long remotePos; + private final long remoteImpliedPos; + private final Function, Mono> resumeFrameSent; + + public Resume( + long remotePos, long remoteImpliedPos, Function, Mono> resumeFrameSent) { + this.remotePos = remotePos; + this.remoteImpliedPos = remoteImpliedPos; + this.resumeFrameSent = resumeFrameSent; + } + + @Override + public void run() { + doResume(remotePos, remoteImpliedPos, resumeFrameSent); + } + } + + private class ResumeComplete implements Runnable { + + @Override + public void run() { + doResumeComplete(); + } + } + + private class ReleaseFrames implements Runnable { + private final long remoteImpliedPos; + + public ReleaseFrames(long remoteImpliedPos) { + this.remoteImpliedPos = remoteImpliedPos; + } + + @Override + public void run() { + releaseFramesToPosition(remoteImpliedPos); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java new file mode 100644 index 000000000..3a30544b6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java @@ -0,0 +1,55 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** Store for resumable frames */ +public interface ResumableFramesStore extends Closeable { + + /** + * Save resumable frames for potential resumption + * + * @param frames {@link Flux} of resumable frames + * @return {@link Mono} which completes once all resume frames are written + */ + Mono saveFrames(Flux frames); + + /** Release frames from tail of the store up to remote implied position */ + void releaseFrames(long remoteImpliedPos); + + /** + * @return {@link Flux} of frames from store tail to head. It should terminate with error if + * frames are not continuous + */ + Flux resumeStream(); + + /** @return Local frame position as defined by RSocket protocol */ + long framePosition(); + + /** @return Implied frame position as defined by RSocket protocol */ + long frameImpliedPosition(); + + /** + * Received resumable frame as defined by RSocket protocol. Implementation must increment frame + * implied position + */ + void resumableFrameReceived(ByteBuf frame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java deleted file mode 100644 index f85949c60..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import java.util.*; -import reactor.core.publisher.Flux; - -public class ResumeCache { - private final ResumePositionCounter strategy; - private final int maxBufferSize; - - private final LinkedHashMap frames = new LinkedHashMap<>(); - private int lastRemotePosition = 0; - private int currentPosition = 0; - private int bufferSize; - - public ResumeCache(ResumePositionCounter strategy, int maxBufferSize) { - this.strategy = strategy; - this.maxBufferSize = maxBufferSize; - } - - public void updateRemotePosition(int remotePosition) { - if (remotePosition > currentPosition) { - throw new IllegalStateException( - "Remote ahead of " + lastRemotePosition + " , expected " + remotePosition); - } - - if (remotePosition == lastRemotePosition) { - return; - } - - if (remotePosition < lastRemotePosition) { - throw new IllegalStateException( - "Remote position moved back from " + lastRemotePosition + " to " + remotePosition); - } - - lastRemotePosition = remotePosition; - - Iterator> positions = frames.entrySet().iterator(); - - while (positions.hasNext()) { - Map.Entry cachePosition = positions.next(); - - if (cachePosition.getKey() <= remotePosition) { - positions.remove(); - bufferSize -= strategy.cost(cachePosition.getValue()); - cachePosition.getValue().release(); - } - - // TODO check for a bad position - } - } - - public void sent(ByteBuf frame) { - if (ResumeUtil.isTracked(frame)) { - frames.put(currentPosition, frame.copy()); - bufferSize += strategy.cost(frame); - - currentPosition += ResumeUtil.offset(frame); - - if (frames.size() > maxBufferSize) { - ByteBuf f = frames.remove(first(frames)); - bufferSize -= strategy.cost(f); - } - } - } - - private int first(LinkedHashMap frames) { - return frames.keySet().iterator().next(); - } - - public Flux resend(int remotePosition) { - updateRemotePosition(remotePosition); - - if (remotePosition == currentPosition) { - return Flux.empty(); - } - - List resend = new ArrayList<>(); - - for (Map.Entry cachePosition : frames.entrySet()) { - if (remotePosition < cachePosition.getKey()) { - resend.add(cachePosition.getValue()); - } - - // TODO error handling - } - - return Flux.fromIterable(resend); - } - - public int getCurrentPosition() { - return currentPosition; - } - - public int getRemotePosition() { - return lastRemotePosition; - } - - public int getEarliestResendPosition() { - if (frames.isEmpty()) { - return currentPosition; - } else { - return first(frames); - } - } - - public int size() { - return bufferSize; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java new file mode 100644 index 000000000..4facdd3c1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeFramesSubscriber.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; + +class ResumeFramesSubscriber implements Subscriber, Disposable { + private final Flux requests; + private final Consumer onNext; + private final Consumer onError; + private final Runnable onComplete; + private final AtomicBoolean disposed = new AtomicBoolean(); + private volatile Disposable requestsDisposable; + private volatile Subscription subscription; + + public ResumeFramesSubscriber( + Flux requests, + Consumer onNext, + Consumer onError, + Runnable onComplete) { + this.requests = requests; + this.onNext = onNext; + this.onError = onError; + this.onComplete = onComplete; + } + + @Override + public void onSubscribe(Subscription s) { + if (isDisposed()) { + s.cancel(); + } else { + this.subscription = s; + this.requestsDisposable = requests.subscribe(s::request); + } + } + + @Override + public void onNext(ByteBuf frame) { + this.onNext.accept(frame); + } + + @Override + public void onError(Throwable t) { + this.onError.accept(t); + requestsDisposable.dispose(); + } + + @Override + public void onComplete() { + this.onComplete.run(); + requestsDisposable.dispose(); + } + + @Override + public void dispose() { + if (disposed.compareAndSet(false, true)) { + if (subscription != null) { + subscription.cancel(); + requestsDisposable.dispose(); + } + } + } + + @Override + public boolean isDisposed() { + return disposed.get(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java new file mode 100644 index 000000000..1fae24b07 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +class ResumeStateException extends RuntimeException { + private static final long serialVersionUID = -5393753463377588732L; + private final long localPos; + private final long localImpliedPos; + private final long remotePos; + private final long remoteImpliedPos; + + public ResumeStateException( + long localPos, long localImpliedPos, long remotePos, long remoteImpliedPos) { + this.localPos = localPos; + this.localImpliedPos = localImpliedPos; + this.remotePos = remotePos; + this.remoteImpliedPos = remoteImpliedPos; + } + + public long getLocalPos() { + return localPos; + } + + public long getLocalImpliedPos() { + return localImpliedPos; + } + + public long getRemotePos() { + return remotePos; + } + + public long getRemoteImpliedPos() { + return remoteImpliedPos; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java new file mode 100644 index 000000000..31687a24b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +public interface ResumeStateHolder { + + long impliedPosition(); + + void onImpliedPosition(long remoteImpliedPos); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java new file mode 100644 index 000000000..d9dec9f54 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import java.util.function.BiFunction; +import org.reactivestreams.Publisher; +import reactor.util.retry.Retry; + +/** + * @deprecated as of 1.0 RC7 in favor of using {@link io.rsocket.core.Resume#retry(Retry)} via + * {@link io.rsocket.core.RSocketConnector} or {@link io.rsocket.core.RSocketServer}. + */ +@Deprecated +@FunctionalInterface +public interface ResumeStrategy extends BiFunction> {} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java deleted file mode 100644 index 8f33f7951..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBufUtil; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.UUID; - -public final class ResumeToken { - // TODO consider best format to store this - private final byte[] resumeToken; - - protected ResumeToken(byte[] resumeToken) { - this.resumeToken = resumeToken; - } - - public static ResumeToken bytes(byte[] token) { - return new ResumeToken(token); - } - - public static ResumeToken generate() { - return new ResumeToken(getBytesFromUUID(UUID.randomUUID())); - } - - static byte[] getBytesFromUUID(UUID uuid) { - ByteBuffer bb = ByteBuffer.wrap(new byte[16]); - bb.putLong(uuid.getMostSignificantBits()); - bb.putLong(uuid.getLeastSignificantBits()); - - return bb.array(); - } - - @Override - public int hashCode() { - return Arrays.hashCode(resumeToken); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof ResumeToken) { - return Arrays.equals(resumeToken, ((ResumeToken) obj).resumeToken); - } - - return false; - } - - @Override - public String toString() { - return ByteBufUtil.hexDump(resumeToken); - } - - public byte[] toByteArray() { - return resumeToken; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java deleted file mode 100644 index 36558d7ce..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.FrameType; - -public class ResumeUtil { - public static boolean isTracked(FrameType frameType) { - switch (frameType) { - case REQUEST_CHANNEL: - case REQUEST_STREAM: - case REQUEST_RESPONSE: - case REQUEST_FNF: - // case METADATA_PUSH: - case REQUEST_N: - case CANCEL: - case ERROR: - case PAYLOAD: - return true; - default: - return false; - } - } - - public static boolean isTracked(ByteBuf frame) { - return isTracked(FrameHeaderFlyweight.frameType(frame)); - } - - public static int offset(ByteBuf frame) { - return 0; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java new file mode 100644 index 000000000..b54ce644f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -0,0 +1,160 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import java.time.Duration; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.FluxProcessor; +import reactor.core.publisher.Mono; +import reactor.core.publisher.ReplayProcessor; + +public class ServerRSocketSession implements RSocketSession { + private static final Logger logger = LoggerFactory.getLogger(ServerRSocketSession.class); + + private final ResumableDuplexConnection resumableConnection; + /*used instead of EmitterProcessor because its autocancel=false capability had no expected effect*/ + private final FluxProcessor newConnections = + ReplayProcessor.create(0); + private final ByteBufAllocator allocator; + private final ByteBuf resumeToken; + + public ServerRSocketSession( + DuplexConnection duplexConnection, + Duration resumeSessionDuration, + Duration resumeStreamTimeout, + Function resumeStoreFactory, + ByteBuf resumeToken, + boolean cleanupStoreOnKeepAlive) { + this.allocator = duplexConnection.alloc(); + this.resumeToken = resumeToken; + this.resumableConnection = + new ResumableDuplexConnection( + "server", + duplexConnection, + resumeStoreFactory.apply(resumeToken), + resumeStreamTimeout, + cleanupStoreOnKeepAlive); + + Mono timeout = + resumableConnection + .connectionErrors() + .flatMap( + err -> { + logger.debug("Starting session timeout due to error", err); + return newConnections + .next() + .doOnNext(c -> logger.debug("Connection after error: {}", c)) + .timeout(resumeSessionDuration); + }) + .then() + .cast(DuplexConnection.class); + + newConnections + .mergeWith(timeout) + .subscribe( + connection -> { + reconnect(connection); + logger.debug("Server ResumableConnection reconnected: {}", connection); + }, + err -> { + logger.debug("Server ResumableConnection reconnect timeout"); + resumableConnection.dispose(); + }); + } + + @Override + public ServerRSocketSession continueWith(DuplexConnection connectionFactory) { + logger.debug("Server continued with connection: {}", connectionFactory); + newConnections.onNext(connectionFactory); + return this; + } + + @Override + public ServerRSocketSession resumeWith(ByteBuf resumeFrame) { + logger.debug("Resume FRAME received"); + long remotePos = remotePos(resumeFrame); + long remoteImpliedPos = remoteImpliedPos(resumeFrame); + resumeFrame.release(); + + resumableConnection.resume( + remotePos, + remoteImpliedPos, + pos -> + pos.flatMap(impliedPos -> sendFrame(ResumeOkFrameCodec.encode(allocator, impliedPos))) + .onErrorResume( + err -> + sendFrame(ErrorFrameCodec.encode(allocator, 0, errorFrameThrowable(err))) + .then(Mono.fromRunnable(resumableConnection::dispose)) + /*Resumption is impossible: no need to return control to ResumableConnection*/ + .then(Mono.never()))); + return this; + } + + @Override + public void reconnect(DuplexConnection connection) { + resumableConnection.reconnect(connection); + } + + @Override + public ResumableDuplexConnection resumableConnection() { + return resumableConnection; + } + + @Override + public ByteBuf token() { + return resumeToken; + } + + private Mono sendFrame(ByteBuf frame) { + logger.debug("Sending Resume frame: {}", frame); + return resumableConnection.sendOne(frame).onErrorResume(e -> Mono.empty()); + } + + private static long remotePos(ByteBuf resumeFrame) { + return ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + } + + private static long remoteImpliedPos(ByteBuf resumeFrame) { + return ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + } + + private static RejectedResumeException errorFrameThrowable(Throwable err) { + String msg; + if (err instanceof ResumeStateException) { + ResumeStateException resumeException = ((ResumeStateException) err); + msg = + String.format( + "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", + resumeException.getRemotePos(), + resumeException.getRemoteImpliedPos(), + resumeException.getLocalPos(), + resumeException.getLocalImpliedPos()); + } else { + msg = String.format("resume_internal_error: %s", err.getMessage()); + } + return new RejectedResumeException(msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java new file mode 100644 index 000000000..1d5c23bd6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import reactor.util.annotation.Nullable; + +public class SessionManager { + private volatile boolean isDisposed; + private final Map sessions = new ConcurrentHashMap<>(); + + public ServerRSocketSession save(ServerRSocketSession session) { + if (isDisposed) { + session.dispose(); + } else { + ByteBuf token = session.token().retain(); + session + .onClose() + .doOnSuccess( + v -> { + if (isDisposed || sessions.get(token) == session) { + sessions.remove(token); + } + token.release(); + }) + .subscribe(); + ServerRSocketSession prevSession = sessions.remove(token); + if (prevSession != null) { + prevSession.dispose(); + } + sessions.put(token, session); + } + return session; + } + + @Nullable + public ServerRSocketSession get(ByteBuf resumeToken) { + return sessions.get(resumeToken); + } + + public void dispose() { + isDisposed = true; + sessions.values().forEach(ServerRSocketSession::dispose); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java b/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java new file mode 100644 index 000000000..f010a05bd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/UpstreamFramesSubscriber.java @@ -0,0 +1,159 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.concurrent.Queues; + +class UpstreamFramesSubscriber implements Subscriber, Disposable { + private static final Logger logger = LoggerFactory.getLogger(UpstreamFramesSubscriber.class); + + private final AtomicBoolean disposed = new AtomicBoolean(); + private final Consumer itemConsumer; + private final Disposable downstreamRequestDisposable; + private final Disposable resumeSaveStreamDisposable; + + private volatile Subscription subs; + private volatile boolean resumeStarted; + private final Queue framesCache; + private long request; + private long downStreamRequestN; + private long resumeSaveStreamRequestN; + + UpstreamFramesSubscriber( + int estimatedDownstreamRequest, + Flux downstreamRequests, + Flux resumeSaveStreamRequests, + Consumer itemConsumer) { + this.itemConsumer = itemConsumer; + this.framesCache = Queues.unbounded(estimatedDownstreamRequest).get(); + + downstreamRequestDisposable = downstreamRequests.subscribe(requestN -> requestN(0, requestN)); + + resumeSaveStreamDisposable = + resumeSaveStreamRequests.subscribe(requestN -> requestN(requestN, 0)); + } + + @Override + public void onSubscribe(Subscription s) { + this.subs = s; + if (!isDisposed()) { + doRequest(); + } else { + s.cancel(); + } + } + + @Override + public void onNext(ByteBuf item) { + processFrame(item); + } + + @Override + public void onError(Throwable t) { + dispose(); + } + + @Override + public void onComplete() { + dispose(); + } + + public void resumeStart() { + resumeStarted = true; + } + + public void resumeComplete() { + ByteBuf frame = framesCache.poll(); + while (frame != null) { + itemConsumer.accept(frame); + frame = framesCache.poll(); + } + resumeStarted = false; + doRequest(); + } + + @Override + public void dispose() { + if (disposed.compareAndSet(false, true)) { + releaseCache(); + if (subs != null) { + subs.cancel(); + } + resumeSaveStreamDisposable.dispose(); + downstreamRequestDisposable.dispose(); + } + } + + @Override + public boolean isDisposed() { + return disposed.get(); + } + + private void requestN(long resumeStreamRequest, long downStreamRequest) { + synchronized (this) { + downStreamRequestN = Operators.addCap(downStreamRequestN, downStreamRequest); + resumeSaveStreamRequestN = Operators.addCap(resumeSaveStreamRequestN, resumeStreamRequest); + + long requests = Math.min(downStreamRequestN, resumeSaveStreamRequestN); + if (requests > 0) { + downStreamRequestN -= requests; + resumeSaveStreamRequestN -= requests; + logger.debug("Upstream subscriber requestN: {}", requests); + request = Operators.addCap(request, requests); + } + } + doRequest(); + } + + private void doRequest() { + if (subs != null && !resumeStarted) { + synchronized (this) { + long r = request; + if (r > 0) { + subs.request(r); + request = 0; + } + } + } + } + + private void releaseCache() { + ByteBuf frame = framesCache.poll(); + while (frame != null && frame.refCnt() > 0) { + frame.release(); + } + } + + private void processFrame(ByteBuf item) { + if (resumeStarted) { + framesCache.offer(item); + } else { + itemConsumer.accept(item); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java similarity index 62% rename from rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java rename to rsocket-core/src/main/java/io/rsocket/resume/package-info.java index 273058731..98744386a 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java @@ -14,22 +14,14 @@ * limitations under the License. */ -package io.rsocket.resume; - -import io.netty.buffer.ByteBuf; - /** - * Calculates the cost of a Frame when stored in the ResumeCache. Two obvious and provided - * strategies are simple frame counts and size in bytes. + * Contains support classes for the RSocket resume capability. + * + * @see Resuming + * Operation */ -public interface ResumePositionCounter { - int cost(ByteBuf f); - - static ResumePositionCounter size() { - return ResumeUtil::offset; - } +@NonNullApi +package io.rsocket.resume; - static ResumePositionCounter frames() { - return f -> 1; - } -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java index d5a8fe775..25fd67097 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java @@ -26,7 +26,8 @@ public interface ClientTransport extends Transport { * Returns a {@code Publisher}, every subscription to which returns a single {@code * DuplexConnection}. * + * @param mtu The mtu used for fragmentation - if set to zero fragmentation will be disabled * @return {@code Publisher}, every subscription returns a single {@code DuplexConnection}. */ - Mono connect(); + Mono connect(int mtu); } diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java index 28af3fd4c..3adc90cc8 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java @@ -29,10 +29,11 @@ public interface ServerTransport extends Transport { * Starts this server. * * @param acceptor An acceptor to process a newly accepted {@code DuplexConnection} + * @param mtu The mtu used for fragmentation - if set to zero fragmentation will be disabled * @return A handle to retrieve information about a started server. * @throws NullPointerException if {@code acceptor} is {@code null} */ - Mono start(ConnectionAcceptor acceptor); + Mono start(ConnectionAcceptor acceptor, int mtu); /** A contract to accept a new {@code DuplexConnection}. */ interface ConnectionAcceptor extends Function> { diff --git a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java index 86e7c311a..00536122a 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,8 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Client and server transport contracts for pluggable transports. */ +@NonNullApi package io.rsocket.transport; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java b/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java deleted file mode 100644 index ec3d4ab3c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Optional; -import java.util.ServiceLoader; - -/** Maps a {@link URI} to a {@link ClientTransport} or {@link ServerTransport}. */ -public interface UriHandler { - - /** - * Load all registered instances of {@code UriHandler}. - * - * @return all registered instances of {@code UriHandler} - */ - static ServiceLoader loadServices() { - return ServiceLoader.load(UriHandler.class); - } - - /** - * Returns an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildClient(URI uri); - - /** - * Returns an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildServer(URI uri); -} diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java deleted file mode 100644 index 5275d2304..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static io.rsocket.uri.UriHandler.loadServices; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.ServiceLoader; -import reactor.core.publisher.Mono; - -/** - * Registry for looking up transports by URI. - * - *

Uses the Jar Services mechanism with services defined by {@link UriHandler}. - */ -public class UriTransportRegistry { - private static final ClientTransport FAILED_CLIENT_LOOKUP = - () -> Mono.error(new UnsupportedOperationException()); - private static final ServerTransport FAILED_SERVER_LOOKUP = - acceptor -> Mono.error(new UnsupportedOperationException()); - - private List handlers; - - public UriTransportRegistry(ServiceLoader services) { - handlers = new ArrayList<>(); - services.forEach(handlers::add); - } - - public static UriTransportRegistry fromServices() { - ServiceLoader services = loadServices(); - - return new UriTransportRegistry(services); - } - - public static ClientTransport clientForUri(String uri) { - return UriTransportRegistry.fromServices().findClient(uri); - } - - private ClientTransport findClient(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildClient(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_CLIENT_LOOKUP; - } - - public static ServerTransport serverForUri(String uri) { - return UriTransportRegistry.fromServices().findServer(uri); - } - - private ServerTransport findServer(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildServer(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_SERVER_LOOKUP; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index 0b33e5e22..4cf33fa86 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -21,13 +21,14 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; import io.rsocket.Payload; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.Charset; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; public final class ByteBufPayload extends AbstractReferenceCounted implements Payload { private static final Recycler RECYCLER = @@ -45,62 +46,6 @@ private ByteBufPayload(final Handle handle) { this.handle = handle; } - @Override - public boolean hasMetadata() { - return metadata != null; - } - - @Override - public ByteBuf sliceMetadata() { - return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); - } - - @Override - public ByteBuf sliceData() { - return data.slice(); - } - - @Override - public ByteBufPayload retain() { - super.retain(); - return this; - } - - @Override - public ByteBufPayload retain(int increment) { - super.retain(increment); - return this; - } - - @Override - public ByteBufPayload touch() { - data.touch(); - if (metadata != null) { - metadata.touch(); - } - return this; - } - - @Override - public ByteBufPayload touch(Object hint) { - data.touch(hint); - if (metadata != null) { - metadata.touch(hint); - } - return this; - } - - @Override - protected void deallocate() { - data.release(); - data = null; - if (metadata != null) { - metadata.release(); - metadata = null; - } - handle.recycle(this); - } - /** * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" * @@ -168,9 +113,10 @@ public static Payload create(ByteBuf data) { public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); - payload.setRefCnt(1); payload.data = data; payload.metadata = metadata; + // unsure data and metadata is set before refCnt change + payload.setRefCnt(1); return payload; } @@ -179,4 +125,95 @@ public static Payload create(Payload payload) { payload.sliceData().retain(), payload.hasMetadata() ? payload.sliceMetadata().retain() : null); } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public ByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public ByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public ByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public ByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + data = null; + if (metadata != null) { + metadata.release(); + metadata = null; + } + handle.recycle(this); + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link ByteBufPayload#ensureAccessible()} to try to guard against using the + * buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java new file mode 100644 index 000000000..328fb8435 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -0,0 +1,210 @@ +package io.rsocket.util; + +import static io.netty.util.internal.StringUtil.isSurrogate; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.MathUtil; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.util.Arrays; + +public class CharByteBufUtil { + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + + private CharByteBufUtil() {} + + /** + * Returns the exact bytes length of UTF8 character sequence. + * + *

This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[])}. + */ + public static int utf8Bytes(final char[] seq) { + return utf8ByteCount(seq, 0, seq.length); + } + + /** + * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[], int, + * int)}. + */ + public static int utf8Bytes(final char[] seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final char[] seq, int start, int end) { + int i = start; + // ASCII fast path + while (i < end && seq[i] < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final char[] seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq[i]; + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + final char c2; + try { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. + c2 = seq[++i]; + } catch (IndexOutOfBoundsException ignored) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(c2)) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length)) { + throw new IndexOutOfBoundsException( + "expected: 0 <= start(" + + start + + ") <= end (" + + end + + ") <= seq.length(" + + seq.length + + ')'); + } + return seq; + } + + /** + * Encode a {@code char[]} in UTF-8 and write it + * into {@link ByteBuf}. + * + *

This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, char[] seq) { + return writeUtf8(buf, seq, 0, seq.length); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. + * + * @return actual number of bytes written + */ + public static int writeUtf8(ByteBuf buf, char[] seq, int start, int end) { + return writeUtf8(buf, buf.writerIndex(), checkCharSequenceBounds(seq, start, end), start, end); + } + + // Fast-Path implementation + static int writeUtf8(ByteBuf buffer, int writerIndex, char[] seq, int start, int end) { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq[i]; + if (c < 0x80) { + buffer.setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + final char c2; + if (seq.length > ++i) { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we + // will + // re-throw a more informative exception describing the problem. + c2 = seq[i]; + } else { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); + } else { + buffer.setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + buffer.writerIndex(writerIndex); + return writerIndex - oldWriterIndex; + } + + private static int writeUtf8Surrogate(ByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer.setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + + public static char[] readUtf8(ByteBuf byteBuf, int length) { + CharsetDecoder charsetDecoder = CharsetUtil.UTF_8.newDecoder(); + int en = (int) (length * (double) charsetDecoder.maxCharsPerByte()); + char[] ca = new char[en]; + + CharBuffer charBuffer = CharBuffer.wrap(ca); + ByteBuffer byteBuffer = + byteBuf.nioBufferCount() == 1 + ? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length) + : byteBuf.nioBuffer(byteBuf.readerIndex(), length); + byteBuffer.mark(); + try { + CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); + if (!cr.isUnderflow()) cr.throwException(); + cr = charsetDecoder.flush(charBuffer); + if (!cr.isUnderflow()) cr.throwException(); + + byteBuffer.reset(); + byteBuf.skipBytes(length); + + return safeTrim(charBuffer.array(), charBuffer.position()); + } catch (CharacterCodingException x) { + // Substitution is always enabled, + // so this shouldn't happen + throw new IllegalStateException("unable to decode char array from the given buffer", x); + } + } + + private static char[] safeTrim(char[] ca, int len) { + if (len == ca.length) return ca; + else return Arrays.copyOf(ca, len); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index 71bbf3874..58f282110 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -23,7 +23,7 @@ import java.nio.CharBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; /** * An implementation of {@link Payload}. This implementation is not thread-safe, and hence @@ -40,66 +40,6 @@ private DefaultPayload(ByteBuffer data, @Nullable ByteBuffer metadata) { this.metadata = metadata; } - @Override - public boolean hasMetadata() { - return metadata != null; - } - - @Override - public ByteBuf sliceMetadata() { - return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); - } - - @Override - public ByteBuf sliceData() { - return Unpooled.wrappedBuffer(data); - } - - @Override - public ByteBuffer getMetadata() { - return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); - } - - @Override - public ByteBuffer getData() { - return data.duplicate(); - } - - @Override - public int refCnt() { - return 1; - } - - @Override - public DefaultPayload retain() { - return this; - } - - @Override - public DefaultPayload retain(int increment) { - return this; - } - - @Override - public DefaultPayload touch() { - return this; - } - - @Override - public DefaultPayload touch(Object hint) { - return this; - } - - @Override - public boolean release() { - return false; - } - - @Override - public boolean release(int decrement) { - return false; - } - /** * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data)" * @@ -167,4 +107,74 @@ public static Payload create(Payload payload) { Unpooled.copiedBuffer(payload.sliceData()), payload.hasMetadata() ? Unpooled.copiedBuffer(payload.sliceMetadata()) : null); } + + @Override + public boolean hasMetadata() { + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); + } + + @Override + public ByteBuf sliceData() { + return Unpooled.wrappedBuffer(data); + } + + @Override + public ByteBuffer getMetadata() { + return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); + } + + @Override + public ByteBuffer getData() { + return data.duplicate(); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public DefaultPayload retain() { + return this; + } + + @Override + public DefaultPayload retain(int increment) { + return this; + } + + @Override + public DefaultPayload touch() { + return this; + } + + @Override + public DefaultPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java b/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java deleted file mode 100644 index c87a08220..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.util; - -import java.util.Arrays; -import reactor.core.Disposable; - -/** Utilities for working with the {@link Disposable} type. */ -public final class DisposableUtils { - - private DisposableUtils() {} - - /** - * Calls the {@link Disposable#dispose()} method if the instance is not null. If any exceptions - * are thrown during disposal, suppress them. - * - * @param disposables the {@link Disposable}s to dispose - */ - public static void disposeQuietly(Disposable... disposables) { - Arrays.stream(disposables) - .forEach( - disposable -> { - try { - if (disposable != null) { - disposable.dispose(); - } - } catch (RuntimeException e) { - // Suppress any exceptions during disposal - } - }); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java index d5eda1d6b..99df97d70 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java @@ -40,6 +40,16 @@ public ByteBuf sliceData() { return Unpooled.EMPTY_BUFFER; } + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + @Override public int refCnt() { return 1; diff --git a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java index 12e3cee45..3ff720447 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java +++ b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java @@ -16,6 +16,7 @@ package io.rsocket.util; +import io.netty.buffer.ByteBuf; import java.util.Objects; public final class NumberUtils { @@ -143,4 +144,21 @@ public static int requireUnsignedShort(int i) { return i; } + + /** + * Encode an unsigned medium integer on 3 bytes / 24 bits. This can be decoded directly by the + * {@link ByteBuf#readUnsignedMedium()} method. + * + * @param byteBuf the {@link ByteBuf} into which to write the bits + * @param i the medium integer to encode + * @see #requireUnsignedMedium(int) + */ + public static void encodeUnsignedMedium(ByteBuf byteBuf, int i) { + requireUnsignedMedium(i); + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(i >> 16); + byteBuf.writeByte(i >> 8); + byteBuf.writeByte(i); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java b/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java deleted file mode 100644 index 30385195c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.util; - -import io.netty.util.Recycler; -import io.netty.util.Recycler.Handle; -import java.util.Objects; -import java.util.function.Function; - -/** A factory for creating {@link Recycler}s. */ -public final class RecyclerFactory { - - /** - * Creates a new {@link Recycler}. - * - * @param newObjectCreator the {@link Function} to create a new object - * @param the type being recycled. - * @return the {@link Recycler} - * @throws NullPointerException if {@code newObjectCreator} is {@code null} - */ - public static Recycler createRecycler(Function, T> newObjectCreator) { - Objects.requireNonNull(newObjectCreator, "newObjectCreator must not be null"); - - return new Recycler() { - - @Override - protected T newObject(Handle handle) { - return newObjectCreator.apply(handle); - } - }; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/package-info.java b/rsocket-core/src/main/java/io/rsocket/util/package-info.java index 79123d3b2..2fac3327f 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/util/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,8 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Shared utility classes and {@link io.rsocket.Payload} implementations. */ +@NonNullApi package io.rsocket.util; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java deleted file mode 100644 index 84e6a6a43..000000000 --- a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java +++ /dev/null @@ -1,172 +0,0 @@ -package io.rsocket; - -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.FrameType; -import io.rsocket.frame.KeepAliveFrameFlyweight; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.util.DefaultPayload; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Consumer; -import java.util.function.Supplier; -import java.util.stream.Stream; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.MethodSource; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -public class KeepAliveTest { - private static final int CLIENT_REQUESTER_TICK_PERIOD = 100; - private static final int CLIENT_REQUESTER_TIMEOUT = 700; - private static final int CLIENT_REQUESTER_MISSED_ACKS = 3; - private static final int SERVER_RESPONDER_TICK_PERIOD = 100; - private static final int SERVER_RESPONDER_TIMEOUT = 1000; - - static Stream> testData() { - return Stream.of( - requester( - CLIENT_REQUESTER_TICK_PERIOD, CLIENT_REQUESTER_TIMEOUT, CLIENT_REQUESTER_MISSED_ACKS), - responder(SERVER_RESPONDER_TICK_PERIOD, SERVER_RESPONDER_TIMEOUT)); - } - - static Supplier requester(int tickPeriod, int timeout, int missedAcks) { - return () -> { - TestDuplexConnection connection = new TestDuplexConnection(); - Errors errors = new Errors(); - RSocketClient rSocket = - new RSocketClient( - ByteBufAllocator.DEFAULT, - connection, - DefaultPayload::create, - errors, - StreamIdSupplier.clientSupplier(), - Duration.ofMillis(tickPeriod), - Duration.ofMillis(timeout), - missedAcks); - return new TestData(rSocket, errors, connection); - }; - } - - static Supplier responder(int tickPeriod, int timeout) { - return () -> { - TestDuplexConnection connection = new TestDuplexConnection(); - AbstractRSocket handler = new AbstractRSocket() {}; - Errors errors = new Errors(); - RSocketServer rSocket = - new RSocketServer( - ByteBufAllocator.DEFAULT, - connection, - handler, - DefaultPayload::create, - errors, - tickPeriod, - timeout); - return new TestData(rSocket, errors, connection); - }; - } - - @ParameterizedTest - @MethodSource("testData") - void keepAlives(Supplier testDataSupplier) { - TestData testData = testDataSupplier.get(); - TestDuplexConnection connection = testData.connection(); - - Flux.interval(Duration.ofMillis(100)) - .subscribe( - n -> - connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); - - Mono.delay(Duration.ofMillis(1500)).block(); - - RSocket rSocket = testData.rSocket(); - List errors = testData.errors().errors(); - - Assertions.assertThat(rSocket.isDisposed()).isFalse(); - Assertions.assertThat(errors).isEmpty(); - } - - @ParameterizedTest - @MethodSource("testData") - void keepAlivesMissing(Supplier testDataSupplier) { - TestData testData = testDataSupplier.get(); - RSocket rSocket = testData.rSocket(); - - Mono.delay(Duration.ofMillis(1500)).block(); - - List errors = testData.errors().errors(); - Assertions.assertThat(rSocket.isDisposed()).isTrue(); - Assertions.assertThat(errors).hasSize(1); - Throwable throwable = errors.get(0); - Assertions.assertThat(throwable).isInstanceOf(ConnectionErrorException.class); - } - - @Test - void clientRequesterRespondsToKeepAlives() { - TestData testData = requester(100, 700, 3).get(); - TestDuplexConnection connection = testData.connection(); - - Mono.delay(Duration.ofMillis(100)) - .subscribe( - l -> - connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); - - Mono keepAliveResponse = - Flux.from(connection.getSentAsPublisher()) - .filter( - f -> - FrameHeaderFlyweight.frameType(f) == FrameType.KEEPALIVE - && !KeepAliveFrameFlyweight.respondFlag(f)) - .next() - .then(); - - StepVerifier.create(keepAliveResponse).expectComplete().verify(Duration.ofSeconds(5)); - } - - static class TestData { - private final RSocket rSocket; - private final Errors errors; - private final TestDuplexConnection connection; - - public TestData(RSocket rSocket, Errors errors, TestDuplexConnection connection) { - this.rSocket = rSocket; - this.errors = errors; - this.connection = connection; - } - - public TestDuplexConnection connection() { - return connection; - } - - public RSocket rSocket() { - return rSocket; - } - - public Errors errors() { - return errors; - } - } - - static class Errors implements Consumer { - private final List errors = new ArrayList<>(); - - @Override - public void accept(Throwable throwable) { - errors.add(throwable); - } - - public List errors() { - return new ArrayList<>(errors); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java deleted file mode 100644 index 2224ba393..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.*; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.*; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -public class RSocketClientTest { - - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - - @Test(timeout = 2_000) - public void testKeepAlive() throws Exception { - assertThat("Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(KEEPALIVE)); - } - - @Test(timeout = 2_000) - public void testInvalidFrameOnStream0() { - rule.connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 0, 10)); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(IllegalStateException.class))); - } - - @Test(timeout = 2_000) - public void testStreamInitialN() { - Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); - - BaseSubscriber subscriber = - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - // don't request here - // subscription.request(3); - } - }; - stream.subscribe(subscriber); - - subscriber.request(5); - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat("sent frame count", sent.size(), is(1)); - - ByteBuf f = sent.get(0); - - assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); - assertThat("initial request n", RequestStreamFrameFlyweight.initialRequestN(f), is(5)); - } - - @Test(timeout = 2_000) - public void testHandleSetupException() { - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("boom"))); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(RejectedSetupException.class))); - } - - @Test(timeout = 2_000) - public void testHandleApplicationException() { - rule.connection.clearSendReceiveBuffers(); - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(); - response.subscribe(responseSub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, new ApplicationErrorException("error"))); - - verify(responseSub).onError(any(ApplicationErrorException.class)); - } - - @Test(timeout = 2_000) - public void testHandleValidFrame() { - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNext( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - - verify(sub).onComplete(); - } - - @Test(timeout = 2_000) - public void testRequestReplyWithCancel() { - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - - try { - response.block(Duration.ofMillis(100)); - } catch (IllegalStateException ise) { - } - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat( - "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); - } - - @Test(timeout = 2_000) - public void testRequestReplyErrorOnSend() { - rule.connection.setAvailability(0); // Fails send - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(10); - response.subscribe(responseSub); - - this.rule.assertNoConnectionErrors(); - - verify(responseSub).onSubscribe(any(Subscription.class)); - - // TODO this should get the error reported through the response subscription - // verify(responseSub).onError(any(RuntimeException.class)); - } - - @Test(timeout = 2_000) - public void testLazyRequestResponse() { - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - int streamId = sendRequestResponse(response); - rule.connection.clearSendReceiveBuffers(); - int streamId2 = sendRequestResponse(response); - assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); - } - - @Test - public void testChannelRequestCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - Flux request = Flux.never().doOnCancel(cancelled::onComplete); - rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, - Flux.error(new IllegalStateException("Channel request not cancelled")) - .delaySubscription(Duration.ofSeconds(1))) - .blockFirst(); - } - - public int sendRequestResponse(Publisher response) { - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNextComplete( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - verify(sub).onNext(any(Payload.class)); - verify(sub).onComplete(); - return streamId; - } - - public static class ClientSocketRule extends AbstractSocketRule { - @Override - protected RSocketClient newRSocket() { - return new RSocketClient( - ByteBufAllocator.DEFAULT, - connection, - DefaultPayload::create, - throwable -> errors.add(throwable), - StreamIdSupplier.clientSupplier(), - Duration.ofMillis(100), - Duration.ofMillis(10_000), - 4); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (ByteBuf frame : connection.getSent()) { - FrameType frameType = frameType(frame); - if (frameType == expectedFrameType) { - return FrameHeaderFlyweight.streamId(frame); - } - framesFound.add(frameType); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java deleted file mode 100644 index 1d3417e02..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java +++ /dev/null @@ -1,166 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.*; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.Mono; - -public class RSocketServerTest { - - @Rule public final ServerSocketRule rule = new ServerSocketRule(); - - @Test(timeout = 2000) - @Ignore - public void testHandleKeepAlive() throws Exception { - rule.connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER)); - ByteBuf sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); - /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - KeepAliveFrameFlyweight.respondFlag(sent), - is(false)); - } - - @Test(timeout = 2000) - @Ignore - public void testHandleResponseFrameNoError() throws Exception { - final int streamId = 4; - rule.connection.clearSendReceiveBuffers(); - - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - Collection> sendSubscribers = rule.connection.getSendSubscribers(); - assertThat("Request not sent.", sendSubscribers, hasSize(1)); - assertThat("Unexpected error.", rule.errors, is(empty())); - Subscriber sendSub = sendSubscribers.iterator().next(); - assertThat( - "Unexpected frame sent.", - frameType(rule.connection.awaitSend()), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); - } - - @Test(timeout = 2000) - @Ignore - public void testHandlerEmitsError() throws Exception { - final int streamId = 4; - rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat( - "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); - } - - @Test(timeout = 2_0000) - public void testCancel() { - final int streamId = 4; - final AtomicBoolean cancelled = new AtomicBoolean(); - rule.setAcceptingSocket( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.never().doOnCancel(() -> cancelled.set(true)); - } - }); - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); - } - - public static class ServerSocketRule extends AbstractSocketRule { - - private RSocket acceptingSocket; - - @Override - protected void init() { - acceptingSocket = - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - }; - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - @Override - protected RSocketServer newRSocket() { - return new RSocketServer( - ByteBufAllocator.DEFAULT, - connection, - acceptingSocket, - DefaultPayload::create, - throwable -> errors.add(throwable)); - } - - private void sendRequest(int streamId, FrameType frameType) { - ByteBuf request; - - switch (frameType) { - case REQUEST_STREAM: - request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, 1, EmptyPayload.INSTANCE); - break; - case REQUEST_RESPONSE: - request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, EmptyPayload.INSTANCE); - break; - default: - throw new IllegalArgumentException("unsupported type: " + frameType); - } - - connection.addToReceivedBuffer(request); - connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2)); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java deleted file mode 100644 index 7fcf46674..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.test.util.LocalDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.ArrayList; -import org.hamcrest.MatcherAssert; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -public class RSocketTest { - - @Rule public final SocketRule rule = new SocketRule(); - - public static void assertError(String s, String mode, ArrayList errors) { - for (Throwable t : errors) { - if (t.toString().equals(s)) { - return; - } - } - - Assert.fail("Expected " + mode + " connection error: " + s + " other errors " + errors.size()); - } - - @Test(timeout = 2_000) - public void testRequestReplyNoError() { - StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) - .expectNextCount(1) - .expectComplete() - .verify(); - } - - @Test(timeout = 2000) - public void testHandlerEmitsError() { - rule.setRequestAcceptor( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.error(new NullPointerException("Deliberate exception.")); - } - }); - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); - verify(subscriber).onError(any(ApplicationErrorException.class)); - - // Client sees error through normal API - rule.assertNoClientErrors(); - - rule.assertServerError("java.lang.NullPointerException: Deliberate exception."); - } - - @Test(timeout = 2000) - public void testChannel() throws Exception { - Flux requests = - Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); - Flux responses = rule.crs.requestChannel(requests); - StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); - } - - public static class SocketRule extends ExternalResource { - - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; - private RSocketClient crs; - - @SuppressWarnings("unused") - private RSocketServer srs; - - private RSocket requestAcceptor; - private ArrayList clientErrors = new ArrayList<>(); - private ArrayList serverErrors = new ArrayList<>(); - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } - - protected void init() { - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); - - LocalDuplexConnection serverConnection = - new LocalDuplexConnection("server", clientProcessor, serverProcessor); - LocalDuplexConnection clientConnection = - new LocalDuplexConnection("client", serverProcessor, clientProcessor); - - requestAcceptor = - null != requestAcceptor - ? requestAcceptor - : new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.never(); - } - - @Override - public Flux requestChannel(Publisher payloads) { - Flux.from(payloads) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")) - .subscribe(); - - return Flux.range(1, 10) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")); - } - }; - - srs = - new RSocketServer( - ByteBufAllocator.DEFAULT, - serverConnection, - requestAcceptor, - DefaultPayload::create, - throwable -> serverErrors.add(throwable)); - - crs = - new RSocketClient( - ByteBufAllocator.DEFAULT, - clientConnection, - DefaultPayload::create, - throwable -> clientErrors.add(throwable), - StreamIdSupplier.clientSupplier()); - } - - public void setRequestAcceptor(RSocket requestAcceptor) { - this.requestAcceptor = requestAcceptor; - init(); - } - - public void assertNoErrors() { - assertNoClientErrors(); - assertNoServerErrors(); - } - - public void assertNoClientErrors() { - MatcherAssert.assertThat( - "Unexpected error on the client connection.", clientErrors, is(empty())); - } - - public void assertNoServerErrors() { - MatcherAssert.assertThat( - "Unexpected error on the server connection.", serverErrors, is(empty())); - } - - public void assertClientError(String s) { - assertError(s, "client", this.clientErrors); - } - - public void assertServerError(String s) { - assertError(s, "server", this.serverErrors); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/TestScheduler.java b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java new file mode 100644 index 000000000..7bc98d45d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java @@ -0,0 +1,80 @@ +package io.rsocket; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.scheduler.Scheduler; +import reactor.util.concurrent.Queues; + +/** + * This is an implementation of scheduler which allows task execution on the caller thread or + * scheduling it for thread which are currently working (with "work stealing" behaviour) + */ +public final class TestScheduler implements Scheduler { + + public static final Scheduler INSTANCE = new TestScheduler(); + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(TestScheduler.class, "wip"); + + final Worker sharedWorker = new TestWorker(this); + final Queue tasks = Queues.unboundedMultiproducer().get(); + + private TestScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + tasks.offer(task); + if (WIP.getAndIncrement(this) != 0) { + return Disposables.never(); + } + + int missed = 1; + + for (; ; ) { + for (; ; ) { + Runnable runnable = tasks.poll(); + + if (runnable == null) { + break; + } + + try { + runnable.run(); + } catch (Throwable t) { + Exceptions.throwIfFatal(t); + } + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return Disposables.never(); + } + } + } + + @Override + public Worker createWorker() { + return sharedWorker; + } + + static class TestWorker implements Worker { + + final TestScheduler parent; + + TestWorker(TestScheduler parent) { + this.parent = parent; + } + + @Override + public Disposable schedule(Runnable task) { + return parent.schedule(task); + } + + @Override + public void dispose() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..800e5d678 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,153 @@ +package io.rsocket.buffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + this.delegate = delegate; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + Assertions.assertThat(tracker) + .allSatisfy( + buf -> + Assertions.assertThat(buf) + .matches(bb -> bb.refCnt() == 0, "buffer should be released")); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java similarity index 75% rename from rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java rename to rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 22568bfcc..ac5832aaa 100644 --- a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -14,12 +14,13 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; -import java.util.concurrent.ConcurrentLinkedQueue; -import org.junit.Assert; import org.junit.rules.ExternalResource; import org.junit.runner.Description; import org.junit.runners.model.Statement; @@ -30,16 +31,16 @@ public abstract class AbstractSocketRule extends ExternalReso protected TestDuplexConnection connection; protected Subscriber connectSub; protected T socket; - protected ConcurrentLinkedQueue errors; + protected LeaksTrackingByteBufAllocator allocator; @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { - connection = new TestDuplexConnection(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); init(); base.evaluate(); } @@ -52,9 +53,11 @@ protected void init() { protected abstract T newRSocket(); - public void assertNoConnectionErrors() { - if (errors.size() > 1) { - Assert.fail("No connection errors expected: " + errors.peek().toString()); - } + public ByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java new file mode 100644 index 000000000..8eb5dee09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java @@ -0,0 +1,90 @@ +package io.rsocket.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.Test; + +class ConnectionSetupPayloadTest { + private static final int KEEP_ALIVE_INTERVAL = 5; + private static final int KEEP_ALIVE_MAX_LIFETIME = 500; + private static final String METADATA_TYPE = "metadata_type"; + private static final String DATA_TYPE = "data_type"; + + @Test + void testSetupPayloadWithDataMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {2, 1, 0}); + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = true; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertTrue(setupPayload.willClientHonorLease()); + assertEquals(KEEP_ALIVE_INTERVAL, setupPayload.keepAliveInterval()); + assertEquals(KEEP_ALIVE_MAX_LIFETIME, setupPayload.keepAliveMaxLifetime()); + assertEquals(METADATA_TYPE, SetupFrameCodec.metadataMimeType(frame)); + assertEquals(DATA_TYPE, SetupFrameCodec.dataMimeType(frame)); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(payload.metadata(), setupPayload.metadata()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithNoMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = null; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertFalse(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithEmptyMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.EMPTY_BUFFER; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + private static ByteBuf encodeSetupFrame(boolean leaseEnabled, Payload setupPayload) { + return SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + leaseEnabled, + KEEP_ALIVE_INTERVAL, + KEEP_ALIVE_MAX_LIFETIME, + Unpooled.EMPTY_BUFFER, + METADATA_TYPE, + DATA_TYPE, + setupPayload); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java new file mode 100644 index 000000000..d98f86113 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveHandler.DefaultKeepAliveHandler; +import static io.rsocket.keepalive.KeepAliveHandler.ResumableKeepAliveHandler; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class KeepAliveTest { + private static final int KEEP_ALIVE_INTERVAL = 100; + private static final int KEEP_ALIVE_TIMEOUT = 1000; + private static final int RESUMABLE_KEEP_ALIVE_TIMEOUT = 200; + + private RSocketState requesterState; + private ResumableRSocketState resumableRequesterState; + + static RSocketState requester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + RSocketRequester rSocket = + new RSocketRequester( + connection, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + tickPeriod, + timeout, + new DefaultKeepAliveHandler(connection), + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + return new RSocketState(rSocket, allocator, connection); + } + + static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + ResumableDuplexConnection resumableConnection = + new ResumableDuplexConnection( + "test", + connection, + new InMemoryResumableFramesStore("test", 10_000), + Duration.ofSeconds(10), + false); + + RSocketRequester rSocket = + new RSocketRequester( + resumableConnection, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + tickPeriod, + timeout, + new ResumableKeepAliveHandler(resumableConnection), + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + return new ResumableRSocketState(rSocket, connection, resumableConnection, allocator); + } + + @BeforeEach + void setUp() { + requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + resumableRequesterState = resumableRequester(KEEP_ALIVE_INTERVAL, RESUMABLE_KEEP_ALIVE_TIMEOUT); + } + + @Test + void rSocketNotDisposedOnPresentKeepAlives() { + TestDuplexConnection connection = requesterState.connection(); + + Flux.interval(Duration.ofMillis(100)) + .subscribe( + n -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); + + Mono.delay(Duration.ofMillis(2000)).block(); + + RSocket rSocket = requesterState.rSocket(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + } + + @Test + void noKeepAlivesSentAfterRSocketDispose() { + requesterState.rSocket().dispose(); + StepVerifier.create( + Flux.from(requesterState.connection().getSentAsPublisher()) + .take(Duration.ofMillis(500))) + .expectComplete() + .verify(Duration.ofSeconds(1)); + } + + @Test + void rSocketDisposedOnMissingKeepAlives() { + RSocket rSocket = requesterState.rSocket(); + + Mono.delay(Duration.ofMillis(2000)).block(); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + rSocket + .onClose() + .as(StepVerifier::create) + .expectError(ConnectionErrorException.class) + .verify(Duration.ofMillis(100)); + } + + @Test + void clientRequesterSendsKeepAlives() { + RSocketState RSocketState = requester(100, 1000); + TestDuplexConnection connection = RSocketState.connection(); + + StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(3)) + .expectNextMatches(this::keepAliveFrameWithRespondFlag) + .expectNextMatches(this::keepAliveFrameWithRespondFlag) + .expectNextMatches(this::keepAliveFrameWithRespondFlag) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + void requesterRespondsToKeepAlives() { + RSocketState RSocketState = requester(100_000, 100_000); + TestDuplexConnection connection = RSocketState.connection(); + Mono.delay(Duration.ofMillis(100)) + .subscribe( + l -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); + + StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(1)) + .expectNextMatches(this::keepAliveFrameWithoutRespondFlag) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDisconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection testConnection = rSocketState.connection(); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + + resumableDuplexConnection.disconnect(); + + StepVerifier.create(Flux.from(testConnection.getSentAsPublisher()).take(Duration.ofMillis(500))) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + void resumableRequesterKeepAlivesAfterReconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + resumableDuplexConnection.disconnect(); + TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); + resumableDuplexConnection.reconnect(newTestConnection); + resumableDuplexConnection.resume(0, 0, ignored -> Mono.empty()); + + StepVerifier.create(Flux.from(newTestConnection.getSentAsPublisher()).take(1)) + .expectNextMatches(this::keepAliveFrame) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDispose() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + rSocketState.rSocket().dispose(); + StepVerifier.create( + Flux.from(rSocketState.connection().getSentAsPublisher()).take(Duration.ofMillis(500))) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @Test + void resumableRSocketsNotDisposedOnMissingKeepAlives() { + RSocket rSocket = resumableRequesterState.rSocket(); + TestDuplexConnection connection = resumableRequesterState.connection(); + + Mono.delay(Duration.ofMillis(500)).block(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + Assertions.assertThat(connection.isDisposed()).isTrue(); + } + + private boolean keepAliveFrame(ByteBuf frame) { + return FrameHeaderCodec.frameType(frame) == FrameType.KEEPALIVE; + } + + private boolean keepAliveFrameWithRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && KeepAliveFrameCodec.respondFlag(frame); + } + + private boolean keepAliveFrameWithoutRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && !KeepAliveFrameCodec.respondFlag(frame); + } + + static class RSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator; + + public RSocketState( + RSocket rSocket, LeaksTrackingByteBufAllocator allocator, TestDuplexConnection connection) { + this.rSocket = rSocket; + this.connection = connection; + this.allocator = allocator; + } + + public TestDuplexConnection connection() { + return connection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } + + static class ResumableRSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final ResumableDuplexConnection resumableDuplexConnection; + private final LeaksTrackingByteBufAllocator allocator; + + public ResumableRSocketState( + RSocket rSocket, + TestDuplexConnection connection, + ResumableDuplexConnection resumableDuplexConnection, + LeaksTrackingByteBufAllocator allocator) { + this.rSocket = rSocket; + this.connection = connection; + this.resumableDuplexConnection = resumableDuplexConnection; + this.allocator = allocator; + } + + public TestDuplexConnection connection() { + return connection; + } + + public ResumableDuplexConnection resumableDuplexConnection() { + return resumableDuplexConnection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..ed9f1ec4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,99 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK / 2]; + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK / 2 + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java new file mode 100644 index 000000000..ab336b8cd --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -0,0 +1,323 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.SETUP; +import static org.assertj.core.data.Offset.offset; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.lease.*; +import io.rsocket.lease.MissingLeaseException; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.EmitterProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class RSocketLeaseTest { + private static final String TAG = "test"; + + private RSocket rSocketRequester; + private ResponderLeaseHandler responderLeaseHandler; + private ByteBufAllocator byteBufAllocator; + private TestDuplexConnection connection; + private RSocketResponder rSocketResponder; + + private EmitterProcessor leaseSender = EmitterProcessor.create(); + private Flux leaseReceiver; + private RequesterLeaseHandler requesterLeaseHandler; + + @BeforeEach + void setUp() { + PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + connection = new TestDuplexConnection(byteBufAllocator); + requesterLeaseHandler = new RequesterLeaseHandler.Impl(TAG, leases -> leaseReceiver = leases); + responderLeaseHandler = + new ResponderLeaseHandler.Impl<>( + TAG, byteBufAllocator, stats -> leaseSender, Optional.empty()); + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); + rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + requesterLeaseHandler, + TestScheduler.INSTANCE); + + RSocket mockRSocketHandler = mock(RSocket.class); + when(mockRSocketHandler.metadataPush(any())).thenReturn(Mono.empty()); + when(mockRSocketHandler.fireAndForget(any())).thenReturn(Mono.empty()); + when(mockRSocketHandler.requestResponse(any())).thenReturn(Mono.empty()); + when(mockRSocketHandler.requestStream(any())).thenReturn(Flux.empty()); + when(mockRSocketHandler.requestChannel(any())).thenReturn(Flux.empty()); + + rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + mockRSocketHandler, + payloadDecoder, + responderLeaseHandler, + 0); + } + + @Test + public void serverRSocketFactoryRejectsUnsupportedLease() { + Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); + ByteBuf setupFrame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + true, + 1000, + 30_000, + "application/octet-stream", + "application/octet-stream", + payload); + + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer(setupFrame); + + Collection sent = connection.getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf error = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); + Assertions.assertThat(Exceptions.from(0, error).getMessage()) + .isEqualTo("lease is not supported"); + } + + @Test + public void clientRSocketFactorySetsLeaseFlag() { + TestClientTransport clientTransport = new TestClientTransport(); + RSocketConnector.create().lease(Leases::new).connect(clientTransport).block(); + + Collection sent = clientTransport.testConnection().getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf setup = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterMissingLeaseRequestsAreRejected(Function> interaction) { + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + + StepVerifier.create(interaction.apply(rSocketRequester)) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterPresentLeaseRequestsAreAccepted(Function> interaction) { + requesterLeaseHandler.receive(leaseFrame(5_000, 2, Unpooled.EMPTY_BUFFER)); + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(1.0, offset(1e-2)); + + Flux.from(interaction.apply(rSocketRequester)) + .take(Duration.ofMillis(500)) + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofSeconds(5)); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.5, offset(1e-2)); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterDepletedAllowedLeaseRequestsAreRejected( + Function> interaction) { + requesterLeaseHandler.receive(leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER)); + interaction.apply(rSocketRequester); + + Flux.from(interaction.apply(rSocketRequester)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterExpiredLeaseRequestsAreRejected(Function> interaction) { + requesterLeaseHandler.receive(leaseFrame(50, 1, Unpooled.EMPTY_BUFFER)); + + Flux.defer(() -> interaction.apply(rSocketRequester)) + .delaySubscription(Duration.ofMillis(200)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + } + + @Test + void requesterAvailabilityRespectsTransport() { + requesterLeaseHandler.receive(leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER)); + double unavailable = 0.0; + connection.setAvailability(unavailable); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + } + + @ParameterizedTest + @MethodSource("interactions") + void responderMissingLeaseRequestsAreRejected(Function> interaction) { + StepVerifier.create(interaction.apply(rSocketResponder)) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("interactions") + void responderPresentLeaseRequestsAreAccepted(Function> interaction) { + leaseSender.onNext(Lease.create(5_000, 2)); + + Flux.from(interaction.apply(rSocketResponder)) + .take(Duration.ofMillis(500)) + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("interactions") + void responderDepletedAllowedLeaseRequestsAreRejected( + Function> interaction) { + leaseSender.onNext(Lease.create(5_000, 1)); + + Flux responder = Flux.from(interaction.apply(rSocketResponder)); + responder.subscribe(); + Flux.from(interaction.apply(rSocketResponder)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("interactions") + void expiredLeaseRequestsAreRejected(Function> interaction) { + leaseSender.onNext(Lease.create(50, 1)); + + Flux.from(interaction.apply(rSocketRequester)) + .delaySubscription(Duration.ofMillis(100)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + } + + @Test + void sendLease() { + ByteBuf metadata = byteBufAllocator.buffer(); + Charset utf8 = StandardCharsets.UTF_8; + String metadataContent = "test"; + metadata.writeCharSequence(metadataContent, utf8); + int ttl = 5_000; + int numberOfRequests = 2; + leaseSender.onNext(Lease.create(5_000, 2, metadata)); + + ByteBuf leaseFrame = + connection + .getSent() + .stream() + .filter(f -> FrameHeaderCodec.frameType(f) == FrameType.LEASE) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); + + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) + .isEqualTo(metadataContent); + } + + @Test + void receiveLease() { + Collection receivedLeases = new ArrayList<>(); + leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); + + ByteBuf metadata = byteBufAllocator.buffer(); + Charset utf8 = StandardCharsets.UTF_8; + String metadataContent = "test"; + metadata.writeCharSequence(metadataContent, utf8); + int ttl = 5_000; + int numberOfRequests = 2; + + ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); + + connection.addToReceivedBuffer(leaseFrame); + + Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); + Lease receivedLease = receivedLeases.iterator().next(); + Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); + Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); + Assertions.assertThat(receivedLease.getMetadata().toString(utf8)).isEqualTo(metadataContent); + } + + ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { + return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); + } + + static Stream>> interactions() { + return Stream.of( + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test")))); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java new file mode 100644 index 000000000..dc76b5450 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import io.rsocket.RSocket; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.Iterator; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Exceptions; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketReconnectTest { + + private Queue retries = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldBeASharedReconnectableInstanceOfRSocketMono() { + TestClientTransport[] testClientTransport = + new TestClientTransport[] {new TestClientTransport()}; + Mono rSocketMono = + RSocketConnector.create() + .reconnect(Retry.indefinitely()) + .connect(() -> testClientTransport[0]); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + testClientTransport[0] = new TestClientTransport(); + + RSocket rSocket3 = rSocketMono.block(); + RSocket rSocket4 = rSocketMono.block(); + + Assertions.assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { + + Mono rSocketMono = RSocketConnector.connectWith(new TestClientTransport()); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isNotEqualTo(rSocket2); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + Consumer onRetry() { + return context -> retries.add(context); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java new file mode 100644 index 000000000..4cd3a3a26 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.test.util.RaceTestUtils; + +class RSocketRequesterSubscribersTest { + + private static final Set REQUEST_TYPES = + new HashSet<>( + Arrays.asList( + FrameType.METADATA_PUSH, + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL)); + + private LeaksTrackingByteBufAllocator allocator; + private RSocket rSocketRequester; + private TestDuplexConnection connection; + + @BeforeEach + void setUp() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); + rSocketRequester = + new RSocketRequester( + connection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriber(Function> interaction) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + response.subscribe(assertSubscriberA); + response.subscribe(assertSubscriberB); + + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInCaseOfRacing(Function> interaction) { + for (int i = 1; i < 20000; i += 2) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> response.subscribe(assertSubscriberA), () -> response.subscribe(assertSubscriberB)); + + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) + .anySatisfy(as -> as.assertError(IllegalStateException.class)); + + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))); + connection.clearSendReceiveBuffers(); + } + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInteractionsAreLazy(Function> interaction) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + Assertions.assertThat(connection.getSent().size()).isEqualTo(0); + } + + static long requestFramesCount(Collection frames) { + return frames + .stream() + .filter(frame -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(frame))) + .count(); + } + + static Stream>> allInteractions() { + return Stream.of( + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + // rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + rSocket -> rSocket.metadataPush(DefaultPayload.create("", "test"))); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java similarity index 82% rename from rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java index ae3bfc489..de6f86c57 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -1,6 +1,8 @@ -package io.rsocket; +package io.rsocket.core; -import io.rsocket.RSocketClientTest.ClientSocketRule; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; @@ -16,18 +18,18 @@ import reactor.test.StepVerifier; @RunWith(Parameterized.class) -public class RSocketClientTerminationTest { +public class RSocketRequesterTerminationTest { @Rule public final ClientSocketRule rule = new ClientSocketRule(); private Function> interaction; - public RSocketClientTerminationTest(Function> interaction) { + public RSocketRequesterTerminationTest(Function> interaction) { this.interaction = interaction; } @Test public void testCurrentStreamIsTerminatedOnConnectionClose() { - RSocketClient rSocket = rule.socket; + RSocketRequester rSocket = rule.socket; Mono.delay(Duration.ofSeconds(1)).doOnNext(v -> rule.connection.dispose()).subscribe(); @@ -38,7 +40,7 @@ public void testCurrentStreamIsTerminatedOnConnectionClose() { @Test public void testSubsequentStreamIsTerminatedAfterConnectionClose() { - RSocketClient rSocket = rule.socket; + RSocketRequester rSocket = rule.socket; rule.connection.dispose(); StepVerifier.create(interaction.apply(rSocket)) diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java new file mode 100644 index 000000000..1ba75f75a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -0,0 +1,1031 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.UnicastProcessor; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketRequesterTest { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { + rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); + Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testStreamInitialN() { + Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); + + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + // don't request here + } + }; + stream.subscribe(subscriber); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + subscriber.request(5); + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat("sent frame count", sent.size(), is(1)); + + ByteBuf f = sent.get(0); + + assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); + assertThat("initial request n", RequestStreamFrameCodec.initialRequestN(f), is(5L)); + assertThat("should be released", f.release(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleSetupException() { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); + Assertions.assertThatThrownBy(() -> rule.socket.onClose().block()) + .isInstanceOf(RejectedSetupException.class); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleApplicationException() { + rule.connection.clearSendReceiveBuffers(); + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(); + response.subscribe(responseSub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); + + verify(responseSub).onError(any(ApplicationErrorException.class)); + + Assertions.assertThat(rule.connection.getSent()) + // requestResponseFrame + .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleValidFrame() { + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + + verify(sub).onComplete(); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyWithCancel() { + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + + try { + response.block(Duration.ofMillis(100)); + } catch (IllegalStateException ise) { + } + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat( + "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); + assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); + Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("invalid") + @Timeout(2_000) + public void testRequestReplyErrorOnSend() { + rule.connection.setAvailability(0); // Fails send + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(10); + response.subscribe(responseSub); + + this.rule + .socket + .onClose() + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofMillis(100)); + + verify(responseSub).onSubscribe(any(Subscription.class)); + + rule.assertHasNoLeaks(); + // TODO this should get the error reported through the response subscription + // verify(responseSub).onError(any(RuntimeException.class)); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = Flux.never().doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation2() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testChannelRequestServerSideCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + UnicastProcessor request = UnicastProcessor.create(); + request.onNext(EmptyPayload.INSTANCE); + rule.socket.requestChannel(request).subscribe(cancelled); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + + Assertions.assertThat(request.isDisposed()).isTrue(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testCorrectFrameOrder() { + MonoProcessor delayer = MonoProcessor.create(); + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) {} + }; + rule.socket + .requestChannel( + Flux.concat(Flux.just(0).delayUntil(i -> delayer), Flux.range(1, 999)) + .map(i -> DefaultPayload.create(i + ""))) + .subscribe(subscriber); + + subscriber.request(1); + subscriber.request(Long.MAX_VALUE); + delayer.onComplete(); + + Iterator iterator = rule.connection.getSent().iterator(); + + ByteBuf initialFrame = iterator.next(); + + Assertions.assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + Assertions.assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)) + .isEqualTo(Long.MAX_VALUE); + Assertions.assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + .isEqualTo("0"); + Assertions.assertThat(initialFrame.release()).isTrue(); + + Assertions.assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + rule.assertHasNoLeaks(); + }); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + + @Test + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + .expectSubscription() + .then( + () -> + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + Assertions.assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("racingCases") + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < 10000; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + try { + clientSocketRule + .apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + Assertions.assertThat(clientSocketRule.connection.getSent()) + .allMatch(ReferenceCounted::release); + + clientSocketRule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + // + // Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameCodec.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameCodec.encode(allocator, streamId, new RuntimeException("test")); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next( + ByteBufPayload.create("d", "m"), + ByteBufPayload.create("d1", "m1"), + ByteBufPayload.create("d2", "m2")); + + assertSubscriber.cancel(); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next(ByteBufPayload.create("d", "m")); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(responsesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + Publisher response; + + switch (frameType) { + case REQUEST_FNF: + response = + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty())); + break; + case REQUEST_RESPONSE: + response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); + break; + case REQUEST_STREAM: + response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p)); + break; + case REQUEST_CHANNEL: + response = rule.socket.requestChannel(testPublisher.flux()); + break; + default: + throw new UnsupportedOperationException("illegal case"); + } + + response.subscribe(assertSubscriber); + testPublisher.next(ByteBufPayload.create("d")); + + int streamId = rule.getStreamIdForRequestType(frameType); + + if (responsesCnt > 0) { + for (int i = 0; i < responsesCnt - 1; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes()))); + } + + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + true, + true, + null, + Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes()))); + } + + if (framesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, streamId, framesCnt)); + } + + for (int i = 1; i < framesCnt; i++) { + testPublisher.next(ByteBufPayload.create("d" + i)); + } + + Assertions.assertThat(rule.connection.getSent()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) + .hasSize(framesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) + .allMatch(ByteBuf::release); + + Assertions.assertThat(assertSubscriber.isTerminated()) + .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) + .isTrue(); + + Assertions.assertThat(assertSubscriber.values()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames received", + frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(p -> p.release()); + + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload( + BiFunction> sourceProducer) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + + Publisher source = sourceProducer.apply(invalidPayload, rule.socket); + + StepVerifier.create(source, 0) + .expectError(IllegalReferenceCountException.class) + .verify(Duration.ofMillis(100)); + } + + private static Stream>> refCntCases() { + return Stream.of( + (p, r) -> r.fireAndForget(p), + (p, r) -> r.requestResponse(p), + (p, r) -> r.requestStream(p), + (p, r) -> r.requestChannel(Mono.just(p)), + (p, r) -> + r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1)))); + } + + @Test + public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { + Payload payload1 = ByteBufPayload.create("abc1"); + Mono fnf1 = rule.socket.fireAndForget(payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Mono fnf2 = rule.socket.fireAndForget(payload2); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + // checks that fnf2 should have id 1 even though it was generated later than fnf1 + AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber2.assertTerminated().assertNoError(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 1) + // ensures that this is fnf1 with abc2 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes()))) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + // checks that fnf1 should have id 3 even though it was generated earlier + AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber1.assertTerminated().assertNoError(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 3) + // ensures that this is fnf1 with abc1 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes()))) + .matches(ReferenceCounted::release); + } + + @ParameterizedTest + @MethodSource("requestNInteractions") + public void ensuresThatNoOpsMustHappenUntilFirstRequestN( + FrameType frameType, BiFunction> interaction) { + Payload payload1 = ByteBufPayload.create("abc1"); + Publisher interaction1 = interaction.apply(rule, payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Publisher interaction2 = interaction.apply(rule, payload2); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); + interaction1.subscribe(assertSubscriber1); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(0); + interaction2.subscribe(assertSubscriber2); + assertSubscriber1.assertNotTerminated().assertNoError(); + assertSubscriber2.assertNotTerminated().assertNoError(); + // even though we subscribed, nothing should happen until the first requestN + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + // first request on the second interaction to ensure that stream id issuing on the first request + assertSubscriber2.request(1); + + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + assertSubscriber1.request(1); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + } + + private static Stream requestNInteractions() { + return Stream.of( + Arguments.of( + REQUEST_RESPONSE, + (BiFunction>) + (rule, payload) -> rule.socket.requestResponse(payload)), + Arguments.of( + REQUEST_STREAM, + (BiFunction>) + (rule, payload) -> rule.socket.requestStream(payload)), + Arguments.of( + REQUEST_CHANNEL, + (BiFunction>) + (rule, payload) -> rule.socket.requestChannel(Flux.just(payload)))); + } + + @ParameterizedTest + @MethodSource("streamIdRacingCases") + public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( + BiFunction> interaction1, + BiFunction> interaction2) { + for (int i = 1; i < 10000; i += 4) { + Payload payload = DefaultPayload.create("test"); + Publisher publisher1 = interaction1.apply(rule, payload); + Publisher publisher2 = interaction2.apply(rule, payload); + RaceTestUtils.race( + () -> publisher1.subscribe(AssertSubscriber.create()), + () -> publisher2.subscribe(AssertSubscriber.create())); + + Assertions.assertThat(rule.connection.getSent()) + .extracting(FrameHeaderCodec::streamId) + .containsExactly(i, i + 2); + rule.connection.getSent().clear(); + } + } + + public static Stream streamIdRacingCases() { + return Stream.of( + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + (BiFunction>) + (r, p) -> r.socket.requestResponse(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + (BiFunction>) + (r, p) -> r.socket.requestStream(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p))), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p)), + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p))); + } + + public int sendRequestResponse(Publisher response) { + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextCompleteReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + verify(sub).onNext(any(Payload.class)); + verify(sub).onComplete(); + return streamId; + } + + public static class ClientSocketRule extends AbstractSocketRule { + @Override + protected RSocketRequester newRSocket() { + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + } + + public int getStreamIdForRequestType(FrameType expectedFrameType) { + assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); + List framesFound = new ArrayList<>(); + for (ByteBuf frame : connection.getSent()) { + FrameType frameType = frameType(frame); + if (frameType == expectedFrameType) { + return FrameHeaderCodec.streamId(frame); + } + framesFound.add(frameType); + } + throw new AssertionError( + "No frames sent with frame type: " + + expectedFrameType + + ", frames found: " + + framesFound); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java new file mode 100644 index 000000000..036dc2eef --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -0,0 +1,823 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.util.Collection; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketResponderTest { + + ServerSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); + rule = new ServerSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleKeepAlive() throws Exception { + rule.connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); + ByteBuf sent = rule.connection.awaitSend(); + assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); + /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ + assertThat( + "Unexpected keep-alive frame respond flag.", + KeepAliveFrameCodec.respondFlag(sent), + is(false)); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleResponseFrameNoError() throws Exception { + final int streamId = 4; + rule.connection.clearSendReceiveBuffers(); + + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + Collection> sendSubscribers = rule.connection.getSendSubscribers(); + assertThat("Request not sent.", sendSubscribers, hasSize(1)); + Subscriber sendSub = sendSubscribers.iterator().next(); + assertThat( + "Unexpected frame sent.", + frameType(rule.connection.awaitSend()), + anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandlerEmitsError() throws Exception { + final int streamId = 4; + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + assertThat( + "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); + } + + @Test + @Timeout(20_000) + public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.never().doOnCancel(() -> cancelled.set(true)); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); + + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final RSocket acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload p) { + p.release(); + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + p.release(); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .doOnNext(Payload::release) + .subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + } + }); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); + runnable.run(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) + .matches(bb -> ErrorFrameCodec.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches(ReferenceCounted::release); + + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + } + + rule.assertHasNoLeaks(); + } + + @Test + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return Flux.never(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + }, + assertSubscriber::cancel); + + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.error(new RuntimeException()); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(framesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return testPublisher.flux(); + } + }, + 1); + + rule.sendRequest(1, frameType, ByteBufPayload.create("d")); + + // if responses number is bigger than 1 we have to send one extra requestN + if (responsesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, 1, responsesCnt - 1)); + } + + // respond with specific number of elements + for (int i = 0; i < responsesCnt; i++) { + testPublisher.next(ByteBufPayload.create("rd" + i)); + } + + // Listen to incoming frames. Valid for RequestChannel case only + if (framesCnt > 1) { + for (int i = 1; i < responsesCnt; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("d" + (i + 1)).getBytes()))); + } + } + + if (responsesCnt > 0) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)); + } + + if (framesCnt > 1) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", + frameType, framesCnt - 1) + .hasSize(1) + .first() + .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); + } + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + .hasSize(framesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) { + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Mono.just(invalidPayload); + } + + @Override + public Flux requestStream(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + }); + + rule.sendRequest(1, frameType); + + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches( + bb -> frameType(bb) == ERROR, + "Expect frame type to be {" + + ERROR + + "} but was {" + + frameType(rule.connection.getSent().iterator().next()) + + "}"); + } + + private static Stream refCntCases() { + return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + public static class ServerSocketRule extends AbstractSocketRule { + + private RSocket acceptingSocket; + private volatile int prefetch; + + @Override + protected void init() { + acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + }; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = Integer.MAX_VALUE; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = prefetch; + super.init(); + } + + @Override + protected RSocketResponder newRSocket() { + return new RSocketResponder( + connection, acceptingSocket, PayloadDecoder.ZERO_COPY, ResponderLeaseHandler.None, 0); + } + + private void sendRequest(int streamId, FrameType frameType) { + sendRequest(streamId, frameType, EmptyPayload.INSTANCE); + } + + private void sendRequest(int streamId, FrameType frameType, Payload payload) { + ByteBuf request; + + switch (frameType) { + case REQUEST_CHANNEL: + request = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, prefetch, payload); + break; + case REQUEST_STREAM: + request = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, prefetch, payload); + break; + case REQUEST_RESPONSE: + request = RequestResponseFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + case REQUEST_FNF: + request = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + default: + throw new IllegalArgumentException("unsupported type: " + frameType); + } + + connection.addToReceivedBuffer(request); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java new file mode 100644 index 000000000..073ebfd06 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -0,0 +1,43 @@ +package io.rsocket.core; + +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestServerTransport; +import org.assertj.core.api.Assertions; +import org.junit.Test; + +public class RSocketServerFragmentationTest { + + @Test + public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketServer.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketServer.create().fragment(100).bind(new TestServerTransport()).block(); + } + + @Test + public void serverSucceedsWithDisabledFragmentation() { + RSocketServer.create().bind(new TestServerTransport()).block(); + } + + @Test + public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketConnector.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketConnector.create().fragment(100).connect(new TestClientTransport()).block(); + } + + @Test + public void clientSucceedsWithDisabledFragmentation() { + RSocketConnector.connectWith(new TestClientTransport()).block(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java new file mode 100644 index 000000000..692894fd6 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -0,0 +1,506 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.LocalDuplexConnection; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.DirectProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +public class RSocketTest { + + @Rule public final SocketRule rule = new SocketRule(); + + @Test + public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { + RSocket requestHandlingRSocket = + new RSocket() { + final Disposable disposable = Disposables.single(); + + @Override + public void dispose() { + disposable.dispose(); + } + + @Override + public boolean isDisposed() { + return disposable.isDisposed(); + } + }; + rule.setRequestAcceptor(requestHandlingRSocket); + rule.crs + .onClose() + .as(StepVerifier::create) + .expectSubscription() + .then(rule.crs::dispose) + .expectComplete() + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); + } + + @Test(timeout = 2_000) + public void testRequestReplyNoError() { + StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) + .expectNextCount(1) + .expectComplete() + .verify(); + } + + @Test(timeout = 2000) + public void testHandlerEmitsError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new NullPointerException("Deliberate exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("Deliberate exception.")) + .verify(Duration.ofMillis(100)); + } + + @Test(timeout = 2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("Deliberate Custom exception.") + .hasFieldOrPropertyWithValue("errorCode", 0x00000501)) + .verify(); + } + + @Test(timeout = 2000) + public void testRequestPropagatesCorrectlyForRequestChannel() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + // specifically limits request to 3 in order to prevent 256 request from limitRate + // hidden on the responder side + .limitRequest(3); + } + }); + + Flux.range(0, 3) + .map(i -> DefaultPayload.create("" + i)) + .as(rule.crs::requestChannel) + .as(publisher -> StepVerifier.create(publisher, 3)) + .expectSubscription() + .expectNextCount(3) + .expectComplete() + .verify(Duration.ofMillis(5000)); + } + + @Test(timeout = 2000) + public void testStream() throws Exception { + Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testChannel() throws Exception { + Flux requests = + Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testErrorPropagatesCorrectly() { + AtomicReference error = new AtomicReference<>(); + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).doOnError(error::set); + } + }); + Flux requests = Flux.error(new RuntimeException("test")); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectErrorMessage("test").verify(); + Assertions.assertThat(error.get()).isNull(); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void + requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + cancelFromRequesterSubscriber( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + void initRequestChannelCase( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(responderSubscriber); + return responderPublisher.flux(); + } + }); + + rule.crs.requestChannel(requesterPublisher).subscribe(requesterSubscriber); + + requesterPublisher.assertWasSubscribed(); + requesterSubscriber.assertSubscribed(); + + responderSubscriber.assertNotSubscribed(); + responderPublisher.assertWasNotSubscribed(); + + // firstRequest + requesterSubscriber.request(1); + requesterPublisher.assertMaxRequested(1); + requesterPublisher.next(DefaultPayload.create("initialData", "initialMetadata")); + + responderSubscriber.assertSubscribed(); + responderPublisher.assertWasSubscribed(); + } + + void nextFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that outerUpstream and innerSubscriber is not terminated so the requestChannel + requesterPublisher.assertSubscribers(1); + responderSubscriber.assertNotTerminated(); + + responderSubscriber.request(6); + requesterPublisher.next( + DefaultPayload.create("d1", "m1"), + DefaultPayload.create("d2"), + DefaultPayload.create("d3", "m3"), + DefaultPayload.create("d4"), + DefaultPayload.create("d5", "m5")); + + List innerPayloads = responderSubscriber.awaitAndAssertNextValueCount(6).values(); + Assertions.assertThat(innerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("initialData", "d1", "d2", "d3", "d4", "d5"); + Assertions.assertThat(innerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, true, false, true, false, true); + Assertions.assertThat(innerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("initialMetadata", "m1", "", "m3", "", "m5"); + } + + void completeFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + requesterPublisher.complete(); + responderSubscriber.assertTerminated(); + requesterPublisher.assertNoSubscribers(); + } + + void cancelFromResponderSubscriber( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + responderSubscriber.cancel(); + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void nextFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that downstream is not terminated so the requestChannel state is half-closed + responderPublisher.assertSubscribers(1); + requesterSubscriber.assertNotTerminated(); + + // ensures responderPublisher can send messages and outerSubscriber can receive them + requesterSubscriber.request(5); + responderPublisher.next( + DefaultPayload.create("rd1", "rm1"), + DefaultPayload.create("rd2"), + DefaultPayload.create("rd3", "rm3"), + DefaultPayload.create("rd4"), + DefaultPayload.create("rd5", "rm5")); + + List outerPayloads = requesterSubscriber.awaitAndAssertNextValueCount(5).values(); + Assertions.assertThat(outerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("rd1", "rd2", "rd3", "rd4", "rd5"); + Assertions.assertThat(outerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, false, true, false, true); + Assertions.assertThat(outerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("rm1", "", "rm3", "", "rm5"); + } + + void completeFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that after sending complete inner upstream is closed + responderPublisher.complete(); + requesterSubscriber.assertTerminated(); + responderPublisher.assertNoSubscribers(); + } + + void cancelFromRequesterSubscriber( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterSubscriber.cancel(); + // error should be propagated + responderSubscriber.assertTerminated(); + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + public static class SocketRule extends ExternalResource { + + DirectProcessor serverProcessor; + DirectProcessor clientProcessor; + private RSocketRequester crs; + + @SuppressWarnings("unused") + private RSocketResponder srs; + + private RSocket requestAcceptor; + + private LeaksTrackingByteBufAllocator allocator; + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + init(); + base.evaluate(); + } + }; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + protected void init() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + serverProcessor = DirectProcessor.create(); + clientProcessor = DirectProcessor.create(); + + LocalDuplexConnection serverConnection = + new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); + LocalDuplexConnection clientConnection = + new LocalDuplexConnection("client", allocator, serverProcessor, clientProcessor); + + clientConnection.onClose().doFinally(__ -> serverConnection.dispose()).subscribe(); + serverConnection.onClose().doFinally(__ -> clientConnection.dispose()).subscribe(); + + requestAcceptor = + null != requestAcceptor + ? requestAcceptor + : new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10) + .map( + i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")) + .subscribe(); + + return Flux.range(1, 10) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + }; + + srs = + new RSocketResponder( + serverConnection, + requestAcceptor, + PayloadDecoder.DEFAULT, + ResponderLeaseHandler.None, + 0); + + crs = + new RSocketRequester( + clientConnection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + } + + public void setRequestAcceptor(RSocket requestAcceptor) { + this.requestAcceptor = requestAcceptor; + init(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java new file mode 100644 index 000000000..968a1a793 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -0,0 +1,868 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.mockito.Mockito; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ReconnectMonoTests { + + private Queue retries = new ConcurrentLinkedQueue<>(); + private Queue> received = new ConcurrentLinkedQueue<>(); + private Queue expired = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldExpireValueOnRacingDisposeAndNext() { + Hooks.onErrorDropped(t -> {}); + Hooks.onNextDropped(System.out::println); + for (int i = 0; i < 100000; i++) { + final int index = i; + final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; + Subscription mockSubscription = Mockito.mock(Subscription.class); + final Mono stringMono = + new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + actual.onSubscribe(mockSubscription); + monoSubscribers[0] = actual; + } + }; + + final ReconnectMono reconnectMono = + stringMono + .doOnDiscard(Object.class, System.out::println) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Mockito.verify(mockSubscription).cancel(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + + Assertions.assertThat(expired).containsOnly("value" + i); + } else { + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + RaceTestUtils.race( + () -> reconnectMono.subscribe(processor), () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(racerProcessor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values = new String[1]; + + RaceTestUtils.race( + () -> values[0] = reconnectMono.block(timeout), () -> reconnectMono.subscribe(processor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(values).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values1 = new String[1]; + String[] values2 = new String[1]; + + RaceTestUtils.race( + () -> values1[0] = reconnectMono.block(timeout), + () -> values2[0] = reconnectMono.block(timeout)); + + Assertions.assertThat(values2).containsExactly("value" + i); + Assertions.assertThat(values1).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add( + new ReconnectMono.ReconnectInner<>(MonoProcessor.create(), reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Throwable error = processor.getError(); + + if (error instanceof CancellationException) { + Assertions.assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(error) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Unexpected Completion of the Upstream"); + } + + Assertions.assertThat(expired).isEmpty(); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .isInstanceOf(RuntimeException.class) + .hasMessage("test"); + } + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono() + .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .matches(t -> Exceptions.isRetryExhausted(t)) + .hasCause(runtimeException); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldThrowOnBlocking() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read"); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + publisher.error(new RuntimeException("test")); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("ReconnectMono terminated with an error")); + } + + @Test + public void shouldBeScannable() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final Mono parent = publisher.mono(); + final ReconnectMono reconnectMono = + parent.as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final Scannable scannableOfReconnect = Scannable.from(reconnectMono); + + Assertions.assertThat( + (List) + scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) + .hasSize(1) + .containsExactly(publisher.mono().getClass()); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(false); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + final Scannable scannableOfMonoProcessor = Scannable.from(processor); + + Assertions.assertThat( + (List) + scannableOfMonoProcessor + .parents() + .map(s -> s.getClass()) + .collect(Collectors.toList())) + .hasSize(3) + .containsExactly( + ReconnectMono.ReconnectInner.class, ReconnectMono.class, publisher.mono().getClass()); + + reconnectMono.dispose(); + + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(true); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + .isInstanceOf(CancellationException.class); + } + + @Test + public void shouldNotExpiredIfNotCompleted() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + reconnectMono.invalidate(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + publisher.assertSubscribers(1); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + + publisher.assertSubscribers(0); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldNotEmitUntilCompletion() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(processor.peek()).isEqualTo("test"); + } + + @Test + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + processor.cancel(); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.EMPTY_SUBSCRIBED); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isFalse(); + Assertions.assertThat(processor.peek()).isNull(); + } + + @Test + public void shouldExpireValueOnDispose() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono) + .expectSubscription() + .then(() -> publisher.next("value")) + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + + reconnectMono.dispose(); + + Assertions.assertThat(expired).hasSize(1); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(reconnectMono.isDisposed()).isTrue(); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void shouldNotifyAllTheSubscribers() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor sub1 = MonoProcessor.create(); + final MonoProcessor sub2 = MonoProcessor.create(); + final MonoProcessor sub3 = MonoProcessor.create(); + final MonoProcessor sub4 = MonoProcessor.create(); + + reconnectMono.subscribe(sub1); + reconnectMono.subscribe(sub2); + reconnectMono.subscribe(sub3); + reconnectMono.subscribe(sub4); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(4); + + final ArrayList> processors = new ArrayList<>(200); + + for (int i = 0; i < 100; i++) { + final MonoProcessor subA = MonoProcessor.create(); + final MonoProcessor subB = MonoProcessor.create(); + processors.add(subA); + processors.add(subB); + RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); + } + + Assertions.assertThat(reconnectMono.subscribers).hasSize(204); + + sub1.dispose(); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(203); + + publisher.next("value"); + + Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); + Assertions.assertThat(sub2.peek()).isEqualTo("value"); + Assertions.assertThat(sub3.peek()).isEqualTo("value"); + Assertions.assertThat(sub4.peek()).isEqualTo("value"); + + for (MonoProcessor sub : processors) { + Assertions.assertThat(sub.peek()).isEqualTo("value"); + Assertions.assertThat(sub.isTerminated()).isTrue(); + } + + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldExpireValueExactlyOnce() { + for (int i = 0; i < 1000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received) + .hasSize(2) + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value", reconnectMono)); + + Assertions.assertThat(cold.subscribeCount()).isEqualTo(2); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldTimeoutRetryWithVirtualTime() { + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + // then + StepVerifier.withVirtualTime( + () -> + Mono.error(new RuntimeException("Something went wrong")) + .retryWhen( + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(minBackoff)) + .doAfterRetry(onRetry()) + .maxBackoff(Duration.ofSeconds(maxBackoff))) + .timeout(Duration.ofSeconds(timeout)) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) + .subscribeOn(Schedulers.elastic())) + .expectSubscription() + .thenAwait(Duration.ofSeconds(timeout)) + .expectError(TimeoutException.class) + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryNoBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.max(2).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); + assertRetries(IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryFixedBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.fixedDelay(1, Duration.ofMillis(500)).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .expectNoEvent(Duration.ofMillis(300)) + .thenAwait(Duration.ofMillis(300)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryExponentialBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .jitter(0.0d) + .doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .thenAwait(Duration.ofMillis(100)) + .thenAwait(Duration.ofMillis(200)) + .thenAwait(Duration.ofMillis(400)) + .thenAwait(Duration.ofMillis(500)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + Consumer onRetry() { + return context -> retries.add(context); + } + + BiConsumer onValue() { + return (v, __) -> received.add(Tuples.of(v, __)); + } + + Consumer onExpire() { + return (v) -> expired.add(v); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + static boolean isRetryExhausted(Throwable e, Class cause) { + return Exceptions.isRetryExhausted(e) && cause.isInstance(e.getCause()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java similarity index 62% rename from rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java rename to rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index 2326f338d..2957a051e 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -1,29 +1,29 @@ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.transport.ServerTransport.ConnectionAcceptor; import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.*; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import org.junit.Ignore; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.core.publisher.UnicastProcessor; import reactor.test.StepVerifier; -@Ignore public class SetupRejectionTest { @Test @@ -32,13 +32,13 @@ void responderRejectSetup() { String errorMsg = "error"; RejectingAcceptor acceptor = new RejectingAcceptor(errorMsg); - RSocketFactory.receive().acceptor(acceptor).transport(transport).start().block(); + RSocketServer.create().acceptor(acceptor).bind(transport).block(); transport.connect(); ByteBuf sentFrame = transport.awaitSent(); - assertThat(FrameHeaderFlyweight.frameType(sentFrame)).isEqualTo(FrameType.ERROR); - RuntimeException error = Exceptions.from(sentFrame); + assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); + RuntimeException error = Exceptions.from(0, sentFrame); assertThat(errorMsg).isEqualTo(error.getMessage()); assertThat(error).isInstanceOf(RejectedSetupException.class); RSocket acceptorSender = acceptor.senderRSocket().block(); @@ -46,50 +46,61 @@ void responderRejectSetup() { } @Test + @Disabled("FIXME: needs to be revised") void requesterStreamsTerminatedOnZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); - List errors = new ArrayList<>(); - RSocketClient rSocket = - new RSocketClient( - ByteBufAllocator.DEFAULT, + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + RSocketRequester rSocket = + new RSocketRequester( conn, DefaultPayload::create, - errors::add, - StreamIdSupplier.clientSupplier()); + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); String errorMsg = "error"; - Mono.delay(Duration.ofMillis(100)) - .doOnTerminate( - () -> - conn.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException(errorMsg)))) - .subscribe(); - - StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .doOnRequest( + ignored -> + conn.addToReceivedBuffer( + ErrorFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 0, + new RejectedSetupException(errorMsg))))) .expectErrorMatches( err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage())) .verify(Duration.ofSeconds(5)); - assertThat(errors).hasSize(1); assertThat(rSocket.isDisposed()).isTrue(); } @Test void requesterNewStreamsTerminatedAfterZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); - RSocketClient rSocket = - new RSocketClient( - ByteBufAllocator.DEFAULT, + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + RSocketRequester rSocket = + new RSocketRequester( conn, DefaultPayload::create, - err -> {}, - StreamIdSupplier.clientSupplier()); + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); conn.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); StepVerifier.create( rSocket @@ -121,10 +132,12 @@ public Mono senderRSocket() { private static class SingleConnectionTransport implements ServerTransport { - private final TestDuplexConnection conn = new TestDuplexConnection(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection conn = new TestDuplexConnection(allocator); @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { return Mono.just(new TestCloseable(acceptor, conn)); } @@ -138,17 +151,7 @@ public ByteBuf awaitSent() { public void connect() { Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); - ByteBuf setup = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - false, - false, - 0, - 42, - "mdMime", - "dMime", - payload.sliceMetadata(), - payload.sliceData()); + ByteBuf setup = SetupFrameCodec.encode(allocator, false, 0, 42, "mdMime", "dMime", payload); conn.addToReceivedBuffer(setup); } diff --git a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java similarity index 55% rename from rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java rename to rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 008e0f45a..00248b6d8 100644 --- a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -14,43 +14,48 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.internal.SynchronizedIntObjectHashMap; import org.junit.Test; public class StreamIdSupplierTest { @Test public void testClientSequence() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); - assertEquals(1, s.nextStreamId()); - assertEquals(3, s.nextStreamId()); - assertEquals(5, s.nextStreamId()); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + assertEquals(5, s.nextStreamId(map)); } @Test public void testServerSequence() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); - assertEquals(2, s.nextStreamId()); - assertEquals(4, s.nextStreamId()); - assertEquals(6, s.nextStreamId()); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + assertEquals(6, s.nextStreamId(map)); } @Test public void testClientIsValid() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.clientSupplier(); assertFalse(s.isBeforeOrCurrent(1)); assertFalse(s.isBeforeOrCurrent(3)); - s.nextStreamId(); + s.nextStreamId(map); assertTrue(s.isBeforeOrCurrent(1)); assertFalse(s.isBeforeOrCurrent(3)); - s.nextStreamId(); + s.nextStreamId(map); assertTrue(s.isBeforeOrCurrent(3)); // negative @@ -63,16 +68,17 @@ public void testClientIsValid() { @Test public void testServerIsValid() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); StreamIdSupplier s = StreamIdSupplier.serverSupplier(); assertFalse(s.isBeforeOrCurrent(2)); assertFalse(s.isBeforeOrCurrent(4)); - s.nextStreamId(); + s.nextStreamId(map); assertTrue(s.isBeforeOrCurrent(2)); assertFalse(s.isBeforeOrCurrent(4)); - s.nextStreamId(); + s.nextStreamId(map); assertTrue(s.isBeforeOrCurrent(4)); // negative @@ -82,4 +88,32 @@ public void testServerIsValid() { // client also accepted (checked externally) assertTrue(s.isBeforeOrCurrent(1)); } + + @Test + public void testWrap() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + StreamIdSupplier s = new StreamIdSupplier(Integer.MAX_VALUE - 3); + + assertEquals(2147483646, s.nextStreamId(map)); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + + s = new StreamIdSupplier(Integer.MAX_VALUE - 2); + + assertEquals(2147483647, s.nextStreamId(map)); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + } + + @Test + public void testSkipFound() { + IntObjectMap map = new SynchronizedIntObjectHashMap<>(); + map.put(5, new Object()); + map.put(9, new Object()); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + assertEquals(7, s.nextStreamId(map)); + assertEquals(11, s.nextStreamId(map)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java similarity index 97% rename from rsocket-core/src/test/java/io/rsocket/TestingStuff.java rename to rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java index 64c790053..e0ebf5064 100644 --- a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java @@ -1,4 +1,4 @@ -package io.rsocket; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -16,6 +16,6 @@ public void testStuff() { ByteBuf byteBuf = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(f1)); System.out.println(ByteBufUtil.prettyHexDump(byteBuf)); - ConnectionSetupPayload.create(byteBuf); + new DefaultConnectionSetupPayload(byteBuf); } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index c7bbfadf6..b3f596a37 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -16,131 +16,220 @@ package io.rsocket.exceptions; +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + final class ExceptionsTest { - /* @DisplayName("from returns ApplicationErrorException") @Test void fromApplicationException() { - ByteBuf byteBuf = createErrorFrame(APPLICATION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(ApplicationErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); } @DisplayName("from returns CanceledException") @Test void fromCanceledException() { - ByteBuf byteBuf = createErrorFrame(CANCELED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(CanceledException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); } @DisplayName("from returns ConnectionCloseException") @Test void fromConnectionCloseException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_CLOSE, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionCloseException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); } @DisplayName("from returns ConnectionErrorException") @Test void fromConnectionErrorException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); } @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") @Test void fromIllegalErrorFrame() { - ByteBuf byteBuf = createErrorFrame(0x00000000, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) - .isInstanceOf(IllegalArgumentException.class) - .withFailMessage("Invalid Error frame: %d, '%s'", 0, "test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidException") @Test void fromInvalidException() { - ByteBuf byteBuf = createErrorFrame(INVALID, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(InvalidException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidSetupException") @Test void fromInvalidSetupException() { - ByteBuf byteBuf = createErrorFrame(INVALID_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(InvalidSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedException") @Test void fromRejectedException() { - ByteBuf byteBuf = createErrorFrame(REJECTED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(RejectedException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedResumeException") @Test void fromRejectedResumeException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_RESUME, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedResumeException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedSetupException") @Test void fromRejectedSetupException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedSetupException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns UnsupportedSetupException") @Test void fromUnsupportedSetupException() { - ByteBuf byteBuf = createErrorFrame(UNSUPPORTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(UnsupportedSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } + + @DisplayName("from returns CustomRSocketException") + @Test + void fromCustomRSocketException() { + for (int i = 0; i < 1000; i++) { + int randomCode = + ThreadLocalRandom.current().nextBoolean() + ? ThreadLocalRandom.current() + .nextInt(Integer.MIN_VALUE, ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE) + : ThreadLocalRandom.current() + .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); + ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } } @DisplayName("from throws NullPointerException with null frame") @Test void fromWithNullFrame() { assertThatNullPointerException() - .isThrownBy(() -> Exceptions.from(null)) + .isThrownBy(() -> Exceptions.from(0, null)) .withMessage("frame must not be null"); } - private ByteBuf createErrorFrame(int errorCode, String message) { - ByteBuf byteBuf = Unpooled.buffer(); - - ErrorFrameFlyweight.encode(byteBuf, 0, errorCode, Unpooled.copiedBuffer(message, UTF_8)); - - return byteBuf; - }*/ + private ByteBuf createErrorFrame(int streamId, int errorCode, String message) { + return ErrorFrameCodec.encode( + UnpooledByteBufAllocator.DEFAULT, streamId, new TestRSocketException(errorCode, message)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java index 8c39e8250..ccf7649d2 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -17,27 +17,22 @@ package io.rsocket.exceptions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; interface RSocketExceptionTest { - @DisplayName("constructor throws NullPointerException with null message") + @DisplayName("constructor does not throw NullPointerException with null message") @Test default void constructorWithNullMessage() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null)) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } - @DisplayName("constructor throws NullPointerException with null message and cause") + @DisplayName("constructor does not throw NullPointerException with null message and cause") @Test default void constructorWithNullMessageAndCause() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null, new Exception())) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } @DisplayName("errorCode returns specified value") diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java new file mode 100644 index 000000000..6c2e63730 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -0,0 +1,39 @@ +package io.rsocket.exceptions; + +public class TestRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message) { + super(message); + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index 6b25ac902..932df4283 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,367 +16,94 @@ package io.rsocket.fragmentation; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.*; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + final class FragmentationDuplexConnectionTest { - /* + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + { + Mockito.when(delegate.onClose()).thenReturn(Mono.never()); + } + @SuppressWarnings("unchecked") - private final ArgumentCaptor> publishers = + private final ArgumentCaptor> publishers = ArgumentCaptor.forClass(Publisher.class); + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test void constructorInvalidMaxFragmentSize() { assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, Integer.MIN_VALUE)) - .withMessage("maxFragmentSize must be positive"); + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, Integer.MIN_VALUE, false, "")) + .withMessage("smallest allowed mtu size is 64 bytes, provided: -2147483648"); } - @DisplayName("constructor throws NullPointerException with null byteBufAllocator") + @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test - void constructorNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, delegate, 2)) - .withMessage("byteBufAllocator must not be null"); + void constructorMtuLessThanMin() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, 2, false, "")) + .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); } @DisplayName("constructor throws NullPointerException with null delegate") @Test void constructorNullDelegate() { assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, null, 2)) + .isThrownBy(() -> new FragmentationDuplexConnection(null, 64, false, "")) .withMessage("delegate must not be null"); } - @DisplayName("reassembles data") - @Test - void reassembleData() { - ByteBuf data = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, null, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2))); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2))); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2))); - - when(delegate.receive()).thenReturn(Flux.just(fragment1, fragment2, fragment3)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2) - .receive() - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, null)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null)); - - when(delegate.receive()).thenReturn(Flux.just(fragment1, fragment2, fragment3)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2) - .receive() - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, - 1, - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1))); - - Frame fragment4 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2))); - - Frame fragment5 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2))); - - when(delegate.receive()) - .thenReturn(Flux.just(fragment1, fragment2, fragment3, fragment4, fragment5)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2) - .receive() - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); - } - - @DisplayName("does not reassemble a non-fragment frame") - @Test - void reassembleNonFragment() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, (ByteBuf) null, null)); - - when(delegate.receive()).thenReturn(Flux.just(frame.retain())); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2) - .receive() - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); - } - - @DisplayName("does not reassemble non fragmentable frame") - @Test - void reassembleNonFragmentableFrame() { - Frame frame = toAbstractionLeakingFrame(DEFAULT, 1, createTestCancelFrame()); - - when(delegate.receive()).thenReturn(Flux.just(frame.retain())); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2) - .receive() - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); - } - @DisplayName("fragments data") @Test void sendData() { - ByteBuf data = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, null, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2))); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2))); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2))); + ByteBuf encode = + RequestResponseFrameCodec.encode( + allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); - } + new FragmentationDuplexConnection(delegate, 64, false, "").sendOne(encode.retain()); - @DisplayName("does not fragment with size equal to maxFragmentLength") - @Test - void sendEqualToMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("does not fragment an already-fragmented frame") - @Test - void sendFragment() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, (ByteBuf) null, null)); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("does not fragment with size smaller than maxFragmentLength") - @Test - void sendLessThanMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(1))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("fragments metadata") - @Test - void sendMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, null)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null)); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); - } - - @DisplayName("fragments metadata and data") - @Test - void sendMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, false, 1, metadata, data)); - - Frame fragment1 = - toAbstractionLeakingFrame( - DEFAULT, 1, createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null)); - - Frame fragment2 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null)); - - Frame fragment3 = - toAbstractionLeakingFrame( - DEFAULT, - 1, - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1))); - - Frame fragment4 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2))); - - Frame fragment5 = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); verify(delegate).send(publishers.capture()); StepVerifier.create(Flux.from(publishers.getValue())) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .expectNext(fragment4) - .expectNext(fragment5) + .expectNextCount(17) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); + }) .verifyComplete(); } - - @DisplayName("does not fragment non-fragmentable frame") - @Test - void sendNonFragmentable() { - Frame frame = toAbstractionLeakingFrame(DEFAULT, 1, createTestCancelFrame()); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 2).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - } - - @DisplayName("send throws NullPointerException with null frames") - @Test - void sendNullFrames() { - when(delegate.onClose()).thenReturn(Mono.never()); - - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(DEFAULT, delegate, 2).send(null)) - .withMessage("frames must not be null"); - } - - @DisplayName("does not fragment with zero maxFragmentLength") - @Test - void sendZeroMaxFragmentLength() { - Frame frame = - toAbstractionLeakingFrame( - DEFAULT, 1, createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2))); - - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(DEFAULT, delegate, 0).sendOne(frame.retain()); - verify(delegate).send(publishers.capture()); - - StepVerifier.create(Flux.from(publishers.getValue())).expectNext(frame).verifyComplete(); - }*/ } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java new file mode 100644 index 000000000..ff62b56f2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java @@ -0,0 +1,57 @@ +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameUtil; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +public class FragmentationIntegrationTest { + private static byte[] data = new byte[128]; + private static byte[] metadata = new byte[128]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + + @DisplayName("fragments and reassembles data") + @Test + void fragmentAndReassembleData() { + ByteBuf frame = + PayloadFrameCodec.encodeNextCompleteReleasingPayload( + allocator, 2, DefaultPayload.create(data)); + System.out.println(FrameUtil.toString(frame)); + + frame.retain(); + + Publisher fragments = + FrameFragmenter.fragmentFrame( + allocator, 64, frame, FrameHeaderCodec.frameType(frame), false); + + FrameReassembler reassembler = new FrameReassembler(allocator); + + ByteBuf assembled = + Flux.from(fragments) + .doOnNext(byteBuf -> System.out.println(FrameUtil.toString(byteBuf))) + .handle(reassembler::reassembleFrame) + .blockLast(); + + System.out.println("assembled"); + String s = FrameUtil.toString(assembled); + System.out.println(s); + + Assert.assertEquals(FrameHeaderCodec.frameType(frame), FrameHeaderCodec.frameType(assembled)); + Assert.assertEquals(frame.readableBytes(), assembled.readableBytes()); + Assert.assertEquals(PayloadFrameCodec.data(frame), PayloadFrameCodec.data(assembled)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java index 8cf3edb96..60dbef74b 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java @@ -16,158 +16,335 @@ package io.rsocket.fragmentation; -final class FrameFragmenterTest { - /* - @DisplayName("constructor throws NullPointerException with null ByteBufAllocator") - @Test - void constructorNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> new FrameFragmenter(null, 2)) - .withMessage("byteBufAllocator must not be null"); - } - - @DisplayName("fragments data") - @Test - void fragmentData() { - ByteBuf data = getRandomByteBuf(6); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, null, data); +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.frame.*; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2)); +final class FrameFragmenterTest { + private static byte[] data = new byte[4096]; + private static byte[] metadata = new byte[4096]; - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2)); + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2)); + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); + @Test + void testGettingData() { + ByteBuf rr = + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); + ByteBuf fnf = + RequestFireAndForgetFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); + ByteBuf rs = + RequestStreamFrameCodec.encode(allocator, 1, true, 1, null, Unpooled.wrappedBuffer(data)); + ByteBuf rc = + RequestChannelFrameCodec.encode( + allocator, 1, true, false, 1, null, Unpooled.wrappedBuffer(data)); + + ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(fnf, FrameType.REQUEST_FNF); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(rs, FrameType.REQUEST_STREAM); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); + + data = FrameFragmenter.getData(rc, FrameType.REQUEST_CHANNEL); + Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); + data.release(); } - @DisplayName("does not fragment with size equal to maxFragmentLength") @Test - void fragmentEqualToMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2)); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + void testGettingMetadata() { + ByteBuf rr = + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); + ByteBuf fnf = + RequestFireAndForgetFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); + ByteBuf rs = + RequestStreamFrameCodec.encode( + allocator, 1, true, 1, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); + ByteBuf rc = + RequestChannelFrameCodec.encode( + allocator, + 1, + true, + false, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); + + ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(fnf, FrameType.REQUEST_FNF); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(rs, FrameType.REQUEST_STREAM); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); + + data = FrameFragmenter.getMetadata(rc, FrameType.REQUEST_CHANNEL); + Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); + data.release(); } - @DisplayName("does not fragment an already-fragmented frame") @Test - void fragmentFragment() { - PayloadFrame frame = createPayloadFrame(DEFAULT, true, true, (ByteBuf) null, null); + void returnEmptBufferWhenNoMetadataPresent() { + ByteBuf rr = + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); + Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); + data.release(); } - @DisplayName("does not fragment with size smaller than maxFragmentLength") + @DisplayName("encode first frame") @Test - void fragmentLessThanMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(1)); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) - .verifyComplete(); + void encodeFirstFrameWithData() { + ByteBuf rr = + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rr, + FrameType.REQUEST_RESPONSE, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); + + ByteBuf data = RequestResponseFrameCodec.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); } - @DisplayName("fragments metadata") + @DisplayName("encode first channel frame") @Test - void fragmentMetadata() { - ByteBuf metadata = getRandomByteBuf(6); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, null); - - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); - - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); - - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null); - - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .verifyComplete(); + void encodeFirstWithDataChannel() { + ByteBuf rc = + RequestChannelFrameCodec.encode( + allocator, 1, true, false, 10, null, Unpooled.wrappedBuffer(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_CHANNEL, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(10, RequestChannelFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); + + ByteBuf data = RequestChannelFrameCodec.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); } - @DisplayName("fragments metadata and data") + @DisplayName("encode first stream frame") @Test - void fragmentMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); - - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, data); - - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); - - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); - - PayloadFrame fragment3 = - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1)); - - PayloadFrame fragment4 = createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2)); - - PayloadFrame fragment5 = createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2)); + void encodeFirstWithDataStream() { + ByteBuf rc = + RequestStreamFrameCodec.encode(allocator, 1, true, 50, null, Unpooled.wrappedBuffer(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_STREAM, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); + + ByteBuf data = RequestStreamFrameCodec.data(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); + Assert.assertEquals(byteBuf, data); + + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); + } - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(fragment1) - .expectNext(fragment2) - .expectNext(fragment3) - .expectNext(fragment4) - .expectNext(fragment5) - .verifyComplete(); + @DisplayName("encode first frame with only metadata") + @Test + void encodeFirstFrameWithMetadata() { + ByteBuf rr = + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rr, + FrameType.REQUEST_RESPONSE, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); + + ByteBuf data = RequestResponseFrameCodec.data(fragment); + Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); + + Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); } - @DisplayName("does not fragment non-fragmentable frame") + @DisplayName("encode first stream frame with data and metadata") @Test - void fragmentNonFragmentable() { - CancelFrame frame = createTestCancelFrame(); + void encodeFirstWithDataAndMetadataStream() { + ByteBuf rc = + RequestStreamFrameCodec.encode( + allocator, 1, true, 50, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); + + ByteBuf fragment = + FrameFragmenter.encodeFirstFragment( + allocator, + 256, + rc, + FrameType.REQUEST_STREAM, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); + + Assert.assertEquals(256, fragment.readableBytes()); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); + + ByteBuf data = RequestStreamFrameCodec.data(fragment); + Assert.assertEquals(0, data.readableBytes()); + + ByteBuf metadata = RequestStreamFrameCodec.metadata(fragment); + ByteBuf byteBuf = Unpooled.wrappedBuffer(this.metadata).readSlice(metadata.readableBytes()); + Assert.assertEquals(byteBuf, metadata); + + Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); + } - new FrameFragmenter(DEFAULT, 2) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) + @DisplayName("fragments frame with only data") + @Test + void fragmentData() { + ByteBuf rr = + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .expectNextCount(1) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); + }) + .expectNextCount(2) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); + }) .verifyComplete(); } - @DisplayName("fragment throws NullPointerException with null frame") + @DisplayName("fragments frame with only metadata") @Test - void fragmentWithNullFrame() { - assertThatNullPointerException() - .isThrownBy(() -> new FrameFragmenter(DEFAULT, 2).fragment(null)) - .withMessage("frame must not be null"); + void fragmentMetadata() { + ByteBuf rr = + RequestStreamFrameCodec.encode( + allocator, 1, true, 10, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .expectNextCount(1) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); + }) + .expectNextCount(2) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); + }) + .verifyComplete(); } - @DisplayName("does not fragment with zero maxFragmentLength") + @DisplayName("fragments frame with data and metadata") @Test - void fragmentZeroMaxFragmentLength() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, false, null, getRandomByteBuf(2)); - - new FrameFragmenter(DEFAULT, 0) - .fragment(frame) - .as(StepVerifier::create) - .expectNext(frame) + void fragmentDataAndMetadata() { + ByteBuf rr = + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); + + Publisher fragments = + FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); + + StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); + }) + .expectNextCount(6) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); + }) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); + }) .verifyComplete(); - }*/ + } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java index 467f6c2e7..56f7fcf90 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java @@ -16,108 +16,468 @@ package io.rsocket.fragmentation; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.frame.*; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + final class FrameReassemblerTest { - /* - @DisplayName("createFrameReassembler throws NullPointerException") - @Test - void createFrameReassemblerNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> createFrameReassembler(null)) - .withMessage("byteBufAllocator must not be null"); + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); } + private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + @DisplayName("reassembles data") @Test void reassembleData() { - ByteBuf data = getRandomByteBuf(6); + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, null, data); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, null, data.slice(0, 2)); + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data)); + + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); + } + + @DisplayName("pass through frames without follows") + @Test + void passthrough() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, false, null, Unpooled.wrappedBuffer(data))); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, false, null, data.slice(2, 2)); + FrameReassembler reassembler = new FrameReassembler(allocator); - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, false, null, data.slice(4, 2)); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents(true, Unpooled.wrappedBuffer(FrameReassemblerTest.data)); - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); } @DisplayName("reassembles metadata") @Test void reassembleMetadata() { - ByteBuf metadata = getRandomByteBuf(6); + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, null); + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - PayloadFrame fragment3 = createPayloadFrame(DEFAULT, false, true, metadata.slice(4, 2), null); - - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); - - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) + .verifyComplete(); } - @DisplayName("reassembles metadata and data") + @DisplayName("reassembles metadata request channel") @Test - void reassembleMetadataAndData() { - ByteBuf metadata = getRandomByteBuf(5); - ByteBuf data = getRandomByteBuf(5); + void reassembleMetadataChannel() { + List byteBufs = + Arrays.asList( + RequestChannelFrameCodec.encode( + allocator, + 1, + true, + false, + 100, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); - RequestStreamFrame frame = createRequestStreamFrame(DEFAULT, false, 1, metadata, data); + FrameReassembler reassembler = new FrameReassembler(allocator); - RequestStreamFrame fragment1 = - createRequestStreamFrame(DEFAULT, true, 1, metadata.slice(0, 2), null); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - PayloadFrame fragment2 = createPayloadFrame(DEFAULT, true, true, metadata.slice(2, 2), null); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - PayloadFrame fragment3 = - createPayloadFrame(DEFAULT, true, false, metadata.slice(4, 1), data.slice(0, 1)); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestChannelFrameCodec.metadata(byteBuf); + Assert.assertEquals(metadata, m); + Assert.assertEquals(100, RequestChannelFrameCodec.initialRequestN(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + + ReferenceCountUtil.safeRelease(metadata); + } + + @DisplayName("reassembles metadata request stream") + @Test + void reassembleMetadataStream() { + List byteBufs = + Arrays.asList( + RequestStreamFrameCodec.encode( + allocator, 1, true, 250, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); - PayloadFrame fragment4 = createPayloadFrame(DEFAULT, true, false, null, data.slice(1, 2)); + FrameReassembler reassembler = new FrameReassembler(allocator); - PayloadFrame fragment5 = createPayloadFrame(DEFAULT, false, false, null, data.slice(3, 2)); + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); - FrameReassembler frameReassembler = createFrameReassembler(DEFAULT); + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - assertThat(frameReassembler.reassemble(fragment1)).isNull(); - assertThat(frameReassembler.reassemble(fragment2)).isNull(); - assertThat(frameReassembler.reassemble(fragment3)).isNull(); - assertThat(frameReassembler.reassemble(fragment4)).isNull(); - assertThat(frameReassembler.reassemble(fragment5)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestStreamFrameCodec.metadata(byteBuf); + Assert.assertEquals(metadata, m); + Assert.assertEquals(250, RequestChannelFrameCodec.initialRequestN(byteBuf)); + ReferenceCountUtil.safeRelease(byteBuf); + }) + .verifyComplete(); + + ReferenceCountUtil.safeRelease(metadata); } - @DisplayName("does not reassemble a non-fragment frame") + @DisplayName("reassembles metadata and data") @Test - void reassembleNonFragment() { - PayloadFrame frame = createPayloadFrame(DEFAULT, false, true, (ByteBuf) null, null); + void reassembleMetadataAndData() { + + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + FrameReassembler reassembler = new FrameReassembler(allocator); + + Flux assembled = Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.data), + Unpooled.wrappedBuffer(FrameReassemblerTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata), + Unpooled.wrappedBuffer(FrameReassemblerTest.metadata)); - assertThat(createFrameReassembler(DEFAULT).reassemble(frame)).isEqualTo(frame); + StepVerifier.create(assembled) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); + }) + .verifyComplete(); + ReferenceCountUtil.safeRelease(data); + ReferenceCountUtil.safeRelease(metadata); } - @DisplayName("does not reassemble non fragmentable frame") + @DisplayName("cancel removes inflight frames") @Test - void reassembleNonFragmentableFrame() { - CancelFrame frame = createTestCancelFrame(); + public void cancelBeforeAssembling() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); - assertThat(createFrameReassembler(DEFAULT).reassemble(frame)).isEqualTo(frame); + FrameReassembler reassembler = new FrameReassembler(allocator); + Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); + + Assert.assertTrue(reassembler.headers.containsKey(1)); + Assert.assertTrue(reassembler.metadata.containsKey(1)); + Assert.assertTrue(reassembler.data.containsKey(1)); + + Flux.just(CancelFrameCodec.encode(allocator, 1)) + .handle(reassembler::reassembleFrame) + .blockLast(); + + Assert.assertFalse(reassembler.headers.containsKey(1)); + Assert.assertFalse(reassembler.metadata.containsKey(1)); + Assert.assertFalse(reassembler.data.containsKey(1)); } - @DisplayName("reassemble throws NullPointerException with null frame") + @DisplayName("dispose should clean up maps") @Test - void reassembleNullFrame() { - assertThatNullPointerException() - .isThrownBy(() -> createFrameReassembler(DEFAULT).reassemble(null)) - .withMessage("frame must not be null"); - }*/ + public void dispose() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); + + FrameReassembler reassembler = new FrameReassembler(allocator); + Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); + + Assert.assertTrue(reassembler.headers.containsKey(1)); + Assert.assertTrue(reassembler.metadata.containsKey(1)); + Assert.assertTrue(reassembler.data.containsKey(1)); + + reassembler.dispose(); + + Assert.assertFalse(reassembler.headers.containsKey(1)); + Assert.assertFalse(reassembler.metadata.containsKey(1)); + Assert.assertFalse(reassembler.data.containsKey(1)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java new file mode 100644 index 000000000..b083d6841 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java @@ -0,0 +1,272 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +final class ReassembleDuplexConnectionTest { + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @DisplayName("reassembles data") + @Test + void reassembleData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata") + @Test + void reassembleMetadata() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata and data") + @Test + void reassembleMetadataAndData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble a non-fragment frame") + @Test + void reassembleNonFragment() { + ByteBuf encode = + RequestResponseFrameCodec.encode(allocator, 1, false, null, Unpooled.wrappedBuffer(data)); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals( + Unpooled.wrappedBuffer(data), RequestResponseFrameCodec.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble non fragmentable frame") + @Test + void reassembleNonFragmentableFrame() { + ByteBuf encode = CancelFrameCodec.encode(allocator, 2); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.CANCEL, FrameHeaderCodec.frameType(byteBuf)); + }) + .verifyComplete(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index b22a95c0b..63300c718 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; import org.assertj.core.presentation.StandardRepresentation; public final class ByteBufRepresentation extends StandardRepresentation { @@ -24,7 +25,17 @@ public final class ByteBufRepresentation extends StandardRepresentation { @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + try { + String normalBufferString = object.toString(); + String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } catch (IllegalReferenceCountException e) { + // noops + } } return super.fallbackToStringOf(object); diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java similarity index 65% rename from rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java index fa663432c..dc04c1141 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java @@ -8,13 +8,13 @@ import io.rsocket.exceptions.ApplicationErrorException; import org.junit.jupiter.api.Test; -class ErrorFrameFlyweightTest { +class ErrorFrameCodecTest { @Test void testEncode() { ByteBuf frame = - ErrorFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000b000000012c000000020164", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java new file mode 100644 index 000000000..28209393e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java @@ -0,0 +1,62 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtensionFrameCodecTest { + + @Test + void extensionDataMetadata() { + ByteBuf metadata = bytebuf("md"); + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionData() { + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertNull(ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionMetadata() { + ByteBuf metadata = bytebuf("md"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode( + ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(0, ExtensionFrameCodec.data(extension).readableBytes()); + extension.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java deleted file mode 100644 index e337d4332..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java +++ /dev/null @@ -1,62 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class ExtensionFrameFlyweightTest { - - @Test - void extensionDataMetadata() { - ByteBuf metadata = bytebuf("md"); - ByteBuf data = bytebuf("d"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(metadata, ExtensionFrameFlyweight.metadata(extension)); - Assertions.assertEquals(data, ExtensionFrameFlyweight.data(extension)); - extension.release(); - } - - @Test - void extensionData() { - ByteBuf data = bytebuf("d"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); - - Assertions.assertFalse(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(0, ExtensionFrameFlyweight.metadata(extension).readableBytes()); - Assertions.assertEquals(data, ExtensionFrameFlyweight.data(extension)); - extension.release(); - } - - @Test - void extensionMetadata() { - ByteBuf metadata = bytebuf("md"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(metadata, ExtensionFrameFlyweight.metadata(extension)); - Assertions.assertEquals(0, ExtensionFrameFlyweight.data(extension).readableBytes()); - extension.release(); - } - - private static ByteBuf bytebuf(String str) { - return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java similarity index 52% rename from rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java index a17fcc205..15788e631 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java @@ -7,7 +7,7 @@ import io.netty.buffer.ByteBufAllocator; import org.junit.jupiter.api.Test; -class FrameHeaderFlyweightTest { +class FrameHeaderCodecTest { // Taken from spec private static final int FRAME_MAX_SIZE = 16_777_215; @@ -15,10 +15,10 @@ class FrameHeaderFlyweightTest { void typeAndFlag() { FrameType frameType = FrameType.REQUEST_FNF; int flags = 0b1110110111; - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); - assertEquals(flags, FrameHeaderFlyweight.flags(header)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(header)); + assertEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); header.release(); } @@ -26,11 +26,11 @@ void typeAndFlag() { void typeAndFlagTruncated() { FrameType frameType = FrameType.SETUP; int flags = 0b11110110111; // 1 bit too many - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); - assertNotEquals(flags, FrameHeaderFlyweight.flags(header)); - assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderFlyweight.flags(header)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(header)); + assertNotEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); header.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java new file mode 100644 index 000000000..ac19dc754 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java @@ -0,0 +1,264 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class GenericFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length + // | | | | ⌌Encoded Metadata + // | | | | | ⌌Encoded Data + // __|________|_________|______|____|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "000010000000011900000000010000026d6464"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithEmptyMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length (0) + // | | | | ⌌Encoded Data + // __|________|_________|_______|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "00000e0000000119000000000100000064"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithNullMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Data + // __|________|_________|_____| + // ↓<-> ↓↓ <-> ↓↓ <-> ↓↓↓ + String expected = "00000b0000000118000000000164"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void requestResponseDataMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestResponseData() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestResponseFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestResponseMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestResponseFrameCodec.data(request); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestStreamDataMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Integer.MAX_VALUE + 1L, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(Long.MAX_VALUE, actualRequest); + assertEquals("md", metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamData() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestStreamFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertNull(metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + ByteBuf data = RequestStreamFrameCodec.data(request); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfDataAndMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfData() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestFireAndForgetFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestFnfMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestFireAndForgetFrameCodec.data(request); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("md", metadata); + assertTrue(data.readableBytes() == 0); + request.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java index eb55e89cd..bc013e024 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java @@ -14,18 +14,18 @@ class KeepaliveFrameFlyweightTest { @Test void canReadData() { ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - ByteBuf frame = KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, data); - assertTrue(KeepAliveFrameFlyweight.respondFlag(frame)); - assertEquals(data, KeepAliveFrameFlyweight.data(frame)); + ByteBuf frame = KeepAliveFrameCodec.encode(ByteBufAllocator.DEFAULT, true, 0, data); + assertTrue(KeepAliveFrameCodec.respondFlag(frame)); + assertEquals(data, KeepAliveFrameCodec.data(frame)); frame.release(); } @Test void testEncoding() { ByteBuf frame = - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( ByteBufAllocator.DEFAULT, true, 0, Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000f000000000c80000000000000000064", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFlyweightTest.java deleted file mode 100644 index 5c226f309..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFlyweightTest.java +++ /dev/null @@ -1,42 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class LeaseFlyweightTest { - - @Test - void leaseMetadata() { - ByteBuf metadata = bytebuf("md"); - int ttl = 1; - int numRequests = 42; - ByteBuf lease = LeaseFlyweight.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(lease)); - Assertions.assertEquals(ttl, LeaseFlyweight.ttl(lease)); - Assertions.assertEquals(numRequests, LeaseFlyweight.numRequests(lease)); - Assertions.assertEquals(metadata, LeaseFlyweight.metadata(lease)); - lease.release(); - } - - @Test - void leaseAbsentMetadata() { - int ttl = 1; - int numRequests = 42; - ByteBuf lease = LeaseFlyweight.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); - - Assertions.assertFalse(FrameHeaderFlyweight.hasMetadata(lease)); - Assertions.assertEquals(ttl, LeaseFlyweight.ttl(lease)); - Assertions.assertEquals(numRequests, LeaseFlyweight.numRequests(lease)); - Assertions.assertEquals(0, LeaseFlyweight.metadata(lease).readableBytes()); - lease.release(); - } - - private static ByteBuf bytebuf(String str) { - return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java new file mode 100644 index 000000000..73c3bde5e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java @@ -0,0 +1,42 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class LeaseFrameCodecTest { + + @Test + void leaseMetadata() { + ByteBuf metadata = bytebuf("md"); + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertEquals(metadata, LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + @Test + void leaseAbsentMetadata() { + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertNull(LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java index 9ef89326a..aecbb31ce 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java @@ -15,9 +15,9 @@ public class PayloadFlyweightTest { void nextCompleteDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); - String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); Assertions.assertEquals("md", metadata); nextComplete.release(); @@ -27,11 +27,11 @@ void nextCompleteDataMetadata() { void nextCompleteData() { Payload payload = DefaultPayload.create("d"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); - ByteBuf metadata = PayloadFrameFlyweight.metadata(nextComplete); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(nextComplete); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); nextComplete.release(); } @@ -42,9 +42,9 @@ void nextCompleteMetaData() { Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer("md".getBytes(StandardCharsets.UTF_8))); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - ByteBuf data = PayloadFrameFlyweight.data(nextComplete); - String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf data = PayloadFrameCodec.data(nextComplete); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertTrue(data.readableBytes() == 0); Assertions.assertEquals("md", metadata); nextComplete.release(); @@ -53,9 +53,10 @@ void nextCompleteMetaData() { @Test void nextDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); - String metadata = PayloadFrameFlyweight.metadata(next).toString(StandardCharsets.UTF_8); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(next).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); Assertions.assertEquals("md", metadata); next.release(); @@ -64,11 +65,24 @@ void nextDataMetadata() { @Test void nextData() { Payload payload = DefaultPayload.create("d"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); - ByteBuf metadata = PayloadFrameFlyweight.metadata(next); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); + next.release(); + } + + @Test + void nextDataEmptyMetadata() { + Payload payload = DefaultPayload.create("d".getBytes(), new byte[0]); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertEquals(metadata.readableBytes(), 0); next.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java deleted file mode 100644 index 9acec2c81..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java +++ /dev/null @@ -1,249 +0,0 @@ -package io.rsocket.frame; - -import static org.junit.jupiter.api.Assertions.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Test; - -class RequestFlyweightTest { - @Test - void testEncoding() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("000010000000011900000000010000026d6464", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void testEncodingWithEmptyMetadata() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("00000e0000000119000000000100000064", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void testEncodingWithNullMetadata() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("00000b0000000118000000000164", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void requestResponseDataMetadata() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestResponseFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestResponseFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestResponseData() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestResponseFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestResponseFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); - request.release(); - } - - @Test - void requestResponseMetadata() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - ByteBuf data = RequestResponseFrameFlyweight.data(request); - String metadata = - RequestResponseFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertTrue(data.readableBytes() == 0); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestStreamDataMetadata() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Integer.MAX_VALUE + 1L, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(Integer.MAX_VALUE, actualRequest); - assertEquals("md", metadata); - assertEquals("d", data); - request.release(); - } - - @Test - void requestStreamData() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 42, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestStreamFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); - assertTrue(metadata.readableBytes() == 0); - assertEquals("d", data); - request.release(); - } - - @Test - void requestStreamMetadata() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 42, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - ByteBuf data = RequestStreamFrameFlyweight.data(request); - String metadata = - RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); - assertTrue(data.readableBytes() == 0); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestFnfDataAndMetadata() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestFireAndForgetFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestFireAndForgetFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestFnfData() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestFireAndForgetFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestFireAndForgetFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); - request.release(); - } - - @Test - void requestFnfMetadata() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - ByteBuf data = RequestFireAndForgetFrameFlyweight.data(request); - String metadata = - RequestFireAndForgetFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("md", metadata); - assertTrue(data.readableBytes() == 0); - request.release(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java similarity index 63% rename from rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java index 4411b98c9..e38258040 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java @@ -7,12 +7,12 @@ import io.netty.buffer.ByteBufUtil; import org.junit.jupiter.api.Test; -class RequestNFrameFlyweightTest { +class RequestNFrameCodecTest { @Test void testEncoding() { - ByteBuf frame = RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, 5); + ByteBuf frame = RequestNFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, 5); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000a00000001200000000005", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java new file mode 100644 index 000000000..fe05335d2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.Arrays; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +public class ResumeFrameCodecTest { + + @Test + void testEncoding() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); + Assert.assertEquals(ResumeFrameCodec.CURRENT_VERSION, ResumeFrameCodec.version(byteBuf)); + Assert.assertEquals(token, ResumeFrameCodec.token(byteBuf)); + Assert.assertEquals(21, ResumeFrameCodec.lastReceivedServerPos(byteBuf)); + Assert.assertEquals(12, ResumeFrameCodec.firstAvailableClientPos(byteBuf)); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java new file mode 100644 index 000000000..33dd8eb70 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -0,0 +1,16 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.Assert; +import org.junit.Test; + +public class ResumeOkFrameCodecTest { + + @Test + public void testEncoding() { + ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); + Assert.assertEquals(42, ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java new file mode 100644 index 000000000..9607ad327 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java @@ -0,0 +1,57 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +class SetupFrameCodecTest { + @Test + void testEncodingNoResume() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, false, 5, 500, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertFalse(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(0, SetupFrameCodec.resumeToken(frame).readableBytes()); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(metadata, SetupFrameCodec.metadata(frame)); + assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } + + @Test + void testEncodingResume() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 5, 500, token, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertTrue(SetupFrameCodec.honorLease(frame)); + assertTrue(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(token, SetupFrameCodec.resumeToken(frame)); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(metadata, SetupFrameCodec.metadata(frame)); + assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java deleted file mode 100644 index 36c9946aa..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java +++ /dev/null @@ -1,75 +0,0 @@ -package io.rsocket.frame; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Test; - -class SetupFrameFlyweightTest { - @Test - void validFrame() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - ByteBuf frame = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - false, - false, - 5, - 500, - "metadata_type", - "data_type", - metadata, - data); - - assertEquals(FrameType.SETUP, FrameHeaderFlyweight.frameType(frame)); - assertEquals("metadata_type", SetupFrameFlyweight.metadataMimeType(frame)); - assertEquals("data_type", SetupFrameFlyweight.dataMimeType(frame)); - assertEquals(metadata, SetupFrameFlyweight.metadata(frame)); - assertEquals(data, SetupFrameFlyweight.data(frame)); - assertEquals(SetupFrameFlyweight.CURRENT_VERSION, SetupFrameFlyweight.version(frame)); - frame.release(); - } - - @Test - void resumeNotSupported() { - assertThrows( - IllegalArgumentException.class, - () -> - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - false, - true, - 5, - 500, - "", - "", - Unpooled.EMPTY_BUFFER, - Unpooled.EMPTY_BUFFER)); - } - - @Test - public void testEncoding() { - ByteBuf frame = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - false, - false, - 5000, - 60000, - "mdmt", - "dmt", - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - assertEquals( - "00002100000000050000010000000013880000ea60046d646d7403646d740000026d6464", - ByteBufUtil.hexDump(frame)); - frame.release(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java similarity index 58% rename from rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java index 3f311c7ef..be7fb837b 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java @@ -20,29 +20,29 @@ import org.junit.jupiter.api.Test; -public class VersionFlyweightTest { +public class VersionCodecTest { @Test public void simple() { - int version = VersionFlyweight.encode(1, 0); - assertEquals(1, VersionFlyweight.major(version)); - assertEquals(0, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(1, 0); + assertEquals(1, VersionCodec.major(version)); + assertEquals(0, VersionCodec.minor(version)); assertEquals(0x00010000, version); - assertEquals("1.0", VersionFlyweight.toString(version)); + assertEquals("1.0", VersionCodec.toString(version)); } @Test public void complex() { - int version = VersionFlyweight.encode(0x1234, 0x5678); - assertEquals(0x1234, VersionFlyweight.major(version)); - assertEquals(0x5678, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(0x1234, 0x5678); + assertEquals(0x1234, VersionCodec.major(version)); + assertEquals(0x5678, VersionCodec.minor(version)); assertEquals(0x12345678, version); - assertEquals("4660.22136", VersionFlyweight.toString(version)); + assertEquals("4660.22136", VersionCodec.toString(version)); } @Test public void noShortOverflow() { - int version = VersionFlyweight.encode(43210, 43211); - assertEquals(43210, VersionFlyweight.major(version)); - assertEquals(43211, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(43210, 43211); + assertEquals(43210, VersionCodec.major(version)); + assertEquals(43211, VersionCodec.minor(version)); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java index 6cbf050ac..63acc40aa 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java @@ -18,60 +18,212 @@ import static org.junit.Assert.assertEquals; +import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.plugins.PluginRegistry; +import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.*; +import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; public class ClientServerInputMultiplexerTest { private TestDuplexConnection source; - private ClientServerInputMultiplexer multiplexer; - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private ClientServerInputMultiplexer clientMultiplexer; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private ClientServerInputMultiplexer serverMultiplexer; @Before public void setup() { - source = new TestDuplexConnection(); - multiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry()); + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); } @Test - public void testSplits() { + public void clientSplits() { AtomicInteger clientFrames = new AtomicInteger(); AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger connectionFrames = new AtomicInteger(); + AtomicInteger setupFrames = new AtomicInteger(); - multiplexer + clientMultiplexer .asClientConnection() .receive() .doOnNext(f -> clientFrames.incrementAndGet()) .subscribe(); - multiplexer + clientMultiplexer .asServerConnection() .receive() .doOnNext(f -> serverFrames.incrementAndGet()) .subscribe(); - multiplexer - .asStreamZeroConnection() + clientMultiplexer + .asSetupConnection() .receive() - .doOnNext(f -> connectionFrames.incrementAndGet()) + .doOnNext(f -> setupFrames.incrementAndGet()) .subscribe(); - source.addToReceivedBuffer(ErrorFrameFlyweight.encode(allocator, 1, new Exception())); + source.addToReceivedBuffer(errorFrame(1)); assertEquals(1, clientFrames.get()); assertEquals(0, serverFrames.get()); - assertEquals(0, connectionFrames.get()); + assertEquals(0, setupFrames.get()); - source.addToReceivedBuffer(ErrorFrameFlyweight.encode(allocator, 2, new Exception())); - assertEquals(1, clientFrames.get()); + source.addToReceivedBuffer(errorFrame(1)); + assertEquals(2, clientFrames.get()); + assertEquals(0, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(leaseFrame()); + assertEquals(3, clientFrames.get()); + assertEquals(0, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(keepAliveFrame()); + assertEquals(4, clientFrames.get()); + assertEquals(0, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(errorFrame(2)); + assertEquals(4, clientFrames.get()); assertEquals(1, serverFrames.get()); - assertEquals(0, connectionFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(errorFrame(0)); + assertEquals(5, clientFrames.get()); + assertEquals(1, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(metadataPushFrame()); + assertEquals(5, clientFrames.get()); + assertEquals(2, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(setupFrame()); + assertEquals(5, clientFrames.get()); + assertEquals(2, serverFrames.get()); + assertEquals(1, setupFrames.get()); + + source.addToReceivedBuffer(resumeFrame()); + assertEquals(5, clientFrames.get()); + assertEquals(2, serverFrames.get()); + assertEquals(2, setupFrames.get()); + + source.addToReceivedBuffer(resumeOkFrame()); + assertEquals(5, clientFrames.get()); + assertEquals(2, serverFrames.get()); + assertEquals(3, setupFrames.get()); + } + + @Test + public void serverSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + AtomicInteger setupFrames = new AtomicInteger(); + + serverMultiplexer + .asClientConnection() + .receive() + .doOnNext(f -> clientFrames.incrementAndGet()) + .subscribe(); + serverMultiplexer + .asServerConnection() + .receive() + .doOnNext(f -> serverFrames.incrementAndGet()) + .subscribe(); + serverMultiplexer + .asSetupConnection() + .receive() + .doOnNext(f -> setupFrames.incrementAndGet()) + .subscribe(); - source.addToReceivedBuffer(ErrorFrameFlyweight.encode(allocator, 1, new Exception())); + source.addToReceivedBuffer(errorFrame(1)); + assertEquals(1, clientFrames.get()); + assertEquals(0, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(errorFrame(1)); + assertEquals(2, clientFrames.get()); + assertEquals(0, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(leaseFrame()); assertEquals(2, clientFrames.get()); assertEquals(1, serverFrames.get()); - assertEquals(0, connectionFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(keepAliveFrame()); + assertEquals(2, clientFrames.get()); + assertEquals(2, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(errorFrame(2)); + assertEquals(2, clientFrames.get()); + assertEquals(3, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(errorFrame(0)); + assertEquals(2, clientFrames.get()); + assertEquals(4, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(metadataPushFrame()); + assertEquals(3, clientFrames.get()); + assertEquals(4, serverFrames.get()); + assertEquals(0, setupFrames.get()); + + source.addToReceivedBuffer(setupFrame()); + assertEquals(3, clientFrames.get()); + assertEquals(4, serverFrames.get()); + assertEquals(1, setupFrames.get()); + + source.addToReceivedBuffer(resumeFrame()); + assertEquals(3, clientFrames.get()); + assertEquals(4, serverFrames.get()); + assertEquals(2, setupFrames.get()); + + source.addToReceivedBuffer(resumeOkFrame()); + assertEquals(3, clientFrames.get()); + assertEquals(4, serverFrames.get()); + assertEquals(3, setupFrames.get()); + } + + private ByteBuf resumeFrame() { + return ResumeFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER, 0, 0); + } + + private ByteBuf setupFrame() { + return SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + false, + 0, + 42, + "application/octet-stream", + "application/octet-stream", + DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER)); + } + + private ByteBuf leaseFrame() { + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf errorFrame(int i) { + return ErrorFrameCodec.encode(allocator, i, new Exception()); + } + + private ByteBuf resumeOkFrame() { + return ResumeOkFrameCodec.encode(allocator, 0); + } + + private ByteBuf keepAliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf metadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java new file mode 100644 index 000000000..d73f92b85 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java @@ -0,0 +1,23 @@ +package io.rsocket.internal; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import reactor.core.scheduler.Scheduler; + +public class SchedulerUtils { + + public static void warmup(Scheduler scheduler) throws InterruptedException { + warmup(scheduler, 10000); + } + + public static void warmup(Scheduler scheduler, int times) throws InterruptedException { + scheduler.start(); + + // warm up + CountDownLatch latch = new CountDownLatch(times); + for (int i = 0; i < times; i++) { + scheduler.schedule(latch::countDown); + } + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java deleted file mode 100644 index 2297d6bfa..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java +++ /dev/null @@ -1,444 +0,0 @@ -package io.rsocket.internal; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasItem; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.Test; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; -import reactor.test.publisher.TestPublisher; -import reactor.test.util.RaceTestUtils; -import reactor.util.context.Context; - -public class SwitchTransformFluxTest { - - @Test - public void shouldBeAbleToCancelSubscription() throws InterruptedException { - for (int j = 0; j < 10; j++) { - ArrayList capturedElements = new ArrayList<>(); - ArrayList capturedCompletions = new ArrayList<>(); - for (int i = 0; i < 1000; i++) { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong captureElement = new AtomicLong(0L); - AtomicBoolean captureCompletion = new AtomicBoolean(false); - AtomicLong requested = new AtomicLong(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .doOnCancel(latch::countDown) - .transform( - flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); - - publisher.next(1L); - - switchTransformed.subscribe( - captureElement::set, - __ -> {}, - () -> captureCompletion.set(true), - s -> - new Thread( - () -> - RaceTestUtils.race( - publisher::complete, - () -> - RaceTestUtils.race( - s::cancel, () -> s.request(1), Schedulers.parallel()), - Schedulers.parallel())) - .start()); - - Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); - Assert.assertEquals(requested.get(), 1L); - capturedElements.add(captureElement.get()); - capturedCompletions.add(captureCompletion.get()); - } - - Assume.assumeThat(capturedElements, hasItem(equalTo(0L))); - Assume.assumeThat(capturedCompletions, hasItem(equalTo(false))); - } - } - - @Test - public void shouldRequestExpectedAmountOfElements() throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong capture = new AtomicLong(); - AtomicLong requested = new AtomicLong(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); - - publisher.next(1L); - - switchTransformed.subscribe( - capture::set, - __ -> {}, - latch::countDown, - s -> { - for (int i = 0; i < 10000; i++) { - RaceTestUtils.race(() -> s.request(1), () -> s.request(1)); - } - RaceTestUtils.race(publisher::complete, publisher::complete); - }); - - latch.await(5, TimeUnit.SECONDS); - - Assert.assertEquals(capture.get(), 1L); - Assert.assertEquals(requested.get(), 20000L); - } - - @Test - public void shouldReturnCorrectContextOnEmptySource() { - Flux switchTransformed = - Flux.empty() - .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - StepVerifier.create(switchTransformed, 0) - .expectSubscription() - .thenRequest(1) - .expectAccessibleContext() - .contains("a", "c") - .contains("c", "d") - .then() - .expectComplete() - .verify(); - } - - @Test - public void shouldNotFailOnIncorrectPublisherBehavior() { - TestPublisher publisher = - TestPublisher.createNoncompliant(TestPublisher.Violation.CLEANUP_ON_TERMINATE); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> innerFlux.subscriberContext(Context.of("a", "b")))); - - StepVerifier.create( - new Flux() { - @Override - public void subscribe(CoreSubscriber actual) { - switchTransformed.subscribe(actual); - publisher.next(1L); - } - }, - 0) - .thenRequest(1) - .expectNext(1L) - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext(2L) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .expectError() - .verifyThenAssertThat() - .hasDroppedErrors(3) - .tookLessThan(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - // @Test - // public void shouldNotFailOnIncorrePu - - @Test - public void shouldBeAbleToAccessUpstreamContext() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext("2") - .expectAccessibleContext() - .contains("a", "b") - .contains("c", "d") - .then() - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldNotHangWhenOneElementUpstream() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - publisher.complete(); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void backpressureTest() { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext("2") - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - - Assert.assertEquals(2L, requested.get()); - } - - @Test - public void backpressureConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(2L, requested.get()); - } - - @Test - public void backpressureHiddenConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf).hide())) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(10001L, requested.get()); - } - - @Test - public void backpressureDrawbackOnConditionalInTransformTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> innerFlux.map(String::valueOf).filter(e -> false))); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(10001L, requested.get()); - } - - @Test - public void shouldErrorOnOverflowTest() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .then(() -> publisher.next(2L)) - .expectError() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldPropagateonCompleteCorrectly() { - Flux switchTransformed = - Flux.empty() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - StepVerifier.create(switchTransformed).expectComplete().verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldPropagateErrorCorrectly() { - Flux switchTransformed = - Flux.error(new RuntimeException("hello")) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - StepVerifier.create(switchTransformed) - .expectErrorMessage("hello") - .verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldBeAbleToBeCancelledProperly() { - TestPublisher publisher = TestPublisher.createCold(); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - } - - @Test - public void shouldBeAbleToCatchDiscardedElement() { - TestPublisher publisher = TestPublisher.createCold(); - Integer[] discarded = new Integer[1]; - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .doOnDiscard(Integer.class, e -> discarded[0] = e); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - - Assert.assertArrayEquals(new Integer[] {1}, discarded); - } - - @Test - public void shouldBeAbleToCatchDiscardedElementInCaseOfConditional() { - TestPublisher publisher = TestPublisher.createCold(); - Integer[] discarded = new Integer[1]; - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .filter(t -> true) - .doOnDiscard(Integer.class, e -> discarded[0] = e); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - - Assert.assertArrayEquals(new Integer[] {1}, discarded); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 0dc7d9090..7bf975543 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -17,6 +17,7 @@ package io.rsocket.internal; import io.rsocket.Payload; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.util.concurrent.CountDownLatch; import org.junit.Assert; @@ -82,6 +83,36 @@ public void testOnNextAfterSubscribe_1000() throws Exception { testOnNextAfterSubscribeN(1000); } + @Test + public void testPrioritizedSending() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.next().block(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + + @Test + public void testPrioritizedFused() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.poll(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + public void testOnNextAfterSubscribeN(int n) throws Exception { CountDownLatch latch = new CountDownLatch(n); UnboundedProcessor processor = new UnboundedProcessor<>(); diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java new file mode 100644 index 000000000..84a589a8d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -0,0 +1,1154 @@ +/* + * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.subscriber; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.context.Context; + +/** + * A Subscriber implementation that hosts assertion tests for its state and allows asynchronous + * cancellation and requesting. + * + *

To create a new instance of {@link AssertSubscriber}, you have the choice between these static + * methods: + * + *

    + *
  • {@link AssertSubscriber#create()}: create a new {@link AssertSubscriber} and requests an + * unbounded number of elements. + *
  • {@link AssertSubscriber#create(long)}: create a new {@link AssertSubscriber} and requests + * {@code n} elements (can be 0 if you want no initial demand). + *
+ * + *

If you are testing asynchronous publishers, don't forget to use one of the {@code await*()} + * methods to wait for the data to assert. + * + *

You can extend this class but only the onNext, onError and onComplete can be overridden. You + * can call {@link #request(long)} and {@link #cancel()} from any thread or from within the + * overridable methods but you should avoid calling the assertXXX methods asynchronously. + * + *

Usage: + * + *

{@code
+ * AssertSubscriber
+ *   .subscribe(publisher)
+ *   .await()
+ *   .assertValues("ABC", "DEF");
+ * }
+ * + * @param the value type. + * @author Sebastien Deleuze + * @author David Karnok + * @author Anatoly Kadyshev + * @author Stephane Maldini + * @author Brian Clozel + */ +public class AssertSubscriber implements CoreSubscriber, Subscription { + + /** Default timeout for waiting next values to be received */ + public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater NEXT_VALUES = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, Subscription.class, "s"); + + private final Context context; + + private final List errors = new LinkedList<>(); + + private final CountDownLatch cdl = new CountDownLatch(1); + + volatile Subscription s; + + volatile long requested; + + volatile List values = new LinkedList<>(); + + /** The fusion mode to request. */ + private int requestedFusionMode = -1; + + /** The established fusion mode. */ + private volatile int establishedFusionMode = -1; + + /** The fuseable QueueSubscription in case a fusion mode was specified. */ + private Fuseable.QueueSubscription qs; + + private int subscriptionCount = 0; + + private int completionCount = 0; + + private volatile long valueCount = 0L; + + private volatile long nextValueAssertedCount = 0L; + + private Duration valuesTimeout = DEFAULT_VALUES_TIMEOUT; + + private boolean valuesStorage = true; + + // + // ============================================================================================================== + // Static methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throws an {@link AssertionError} with the specified error message + * supplier. + * + * @param timeout the timeout duration + * @param errorMessageSupplier the error message supplier + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, Supplier errorMessageSupplier, BooleanSupplier conditionSupplier) { + + Objects.requireNonNull(errorMessageSupplier); + Objects.requireNonNull(conditionSupplier); + Objects.requireNonNull(timeout); + + long timeoutNs = timeout.toNanos(); + long startTime = System.nanoTime(); + do { + if (conditionSupplier.getAsBoolean()) { + return; + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } while (System.nanoTime() - startTime < timeoutNs); + throw new AssertionError(errorMessageSupplier.get()); + } + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throw an {@link AssertionError} with the specified error message. + * + * @param timeout the timeout duration + * @param errorMessage the error message + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, final String errorMessage, BooleanSupplier conditionSupplier) { + await( + timeout, + new Supplier() { + @Override + public String get() { + return errorMessage; + } + }, + conditionSupplier); + } + + /** + * Create a new {@link AssertSubscriber} that requests an unbounded number of elements. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create() { + return new AssertSubscriber<>(); + } + + /** + * Create a new {@link AssertSubscriber} that requests initially {@code n} elements. You can then + * manage the demand with {@link Subscription#request(long)}. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param n Number of elements to request (can be 0 if you want no initial demand). + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create(long n) { + return new AssertSubscriber<>(n); + } + + // + // ============================================================================================================== + // constructors + // + // ============================================================================================================== + + public AssertSubscriber() { + this(Context.empty(), Long.MAX_VALUE); + } + + public AssertSubscriber(long n) { + this(Context.empty(), n); + } + + public AssertSubscriber(Context context) { + this(context, Long.MAX_VALUE); + } + + public AssertSubscriber(Context context, long n) { + if (n < 0) { + throw new IllegalArgumentException("initialRequest >= required but it was " + n); + } + this.context = context; + REQUESTED.lazySet(this, n); + } + + // + // ============================================================================================================== + // Configuration + // + // ============================================================================================================== + + /** + * Enable or disabled the values storage. It is enabled by default, and can be disable in order to + * be able to perform performance benchmarks or tests with a huge amount values. + * + * @param enabled enable value storage? + * @return this + */ + public final AssertSubscriber configureValuesStorage(boolean enabled) { + this.valuesStorage = enabled; + return this; + } + + /** + * Configure the timeout in seconds for waiting next values to be received (3 seconds by default). + * + * @param timeout the new default value timeout duration + * @return this + */ + public final AssertSubscriber configureValuesTimeout(Duration timeout) { + this.valuesTimeout = timeout; + return this; + } + + /** + * Returns the established fusion mode or -1 if it was not enabled + * + * @return the fusion mode, see Fuseable constants + */ + public final int establishedFusionMode() { + return establishedFusionMode; + } + + // + // ============================================================================================================== + // Assertions + // + // ============================================================================================================== + + /** + * Assert a complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertComplete() { + assertNoError(); + int c = completionCount; + if (c == 0) { + throw new AssertionError("Not completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert the specified values have been received. Values storage should be enabled to use this + * method. + * + * @param expectedValues the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertContainValues(Set expectedValues) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + if (expectedValues.size() > values.size()) { + throw new AssertionError("Actual contains fewer elements" + values, null); + } + + Iterator expected = expectedValues.iterator(); + + for (; ; ) { + boolean n2 = expected.hasNext(); + if (n2) { + T t2 = expected.next(); + if (!values.contains(t2)) { + throw new AssertionError( + "The element is not contained in the " + + "received results" + + " = " + + valueAndClass(t2), + null); + } + } else { + break; + } + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertError() { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @param clazz The class of the exception contained in the error signal + * @return this + */ + public final AssertSubscriber assertError(Class clazz) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + Throwable e = errors.get(0); + if (!clazz.isInstance(e)) { + throw new AssertionError( + "Error class incompatible: expected = " + clazz + ", actual = " + e, null); + } + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + public final AssertSubscriber assertErrorMessage(String message) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + assertionError("No error", null); + } + if (s == 1) { + if (!Objects.equals(message, errors.get(0).getMessage())) { + assertionError( + "Error class incompatible: expected = \"" + + message + + "\", actual = \"" + + errors.get(0).getMessage() + + "\"", + null); + } + } + if (s > 1) { + assertionError("Multiple errors: " + s, null); + } + + return this; + } + + /** + * Assert an error signal has been received. + * + * @param expectation A method that can verify the exception contained in the error signal and + * throw an exception (like an {@link AssertionError}) if the exception is not valid. + * @return this + */ + public final AssertSubscriber assertErrorWith(Consumer expectation) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + expectation.accept(errors.get(0)); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert that the upstream was a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertFuseableSource() { + if (qs == null) { + throw new AssertionError("Upstream was not Fuseable"); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionEnabled() { + if (establishedFusionMode != Fuseable.SYNC && establishedFusionMode != Fuseable.ASYNC) { + throw new AssertionError("Fusion was not enabled"); + } + return this; + } + + public final AssertSubscriber assertFusionMode(int expectedMode) { + if (establishedFusionMode != expectedMode) { + throw new AssertionError( + "Wrong fusion mode: expected: " + + fusionModeName(expectedMode) + + ", actual: " + + fusionModeName(establishedFusionMode)); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionRejected() { + if (establishedFusionMode != Fuseable.NONE) { + throw new AssertionError("Fusion was granted"); + } + return this; + } + + /** + * Assert no error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNoError() { + int s = errors.size(); + if (s == 1) { + Throwable e = errors.get(0); + String valueAndClass = e == null ? null : e + " (" + e.getClass().getSimpleName() + ")"; + throw new AssertionError("Error present: " + valueAndClass, null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert no values have been received. + * + * @return this + */ + public final AssertSubscriber assertNoValues() { + if (valueCount != 0) { + throw new AssertionError( + "No values expected but received: [length = " + values.size() + "] " + values, null); + } + return this; + } + + /** + * Assert that the upstream was not a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertNonFuseableSource() { + if (qs != null) { + throw new AssertionError("Upstream was Fuseable"); + } + return this; + } + + /** + * Assert no complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotComplete() { + int c = completionCount; + if (c == 1) { + throw new AssertionError("Completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert no subscription occurred. + * + * @return this + */ + public final AssertSubscriber assertNotSubscribed() { + int s = subscriptionCount; + + if (s == 1) { + throw new AssertionError("OnSubscribe called once", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert no complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotTerminated() { + if (cdl.getCount() == 0) { + throw new AssertionError("Terminated", null); + } + return this; + } + + /** + * Assert subscription occurred (once). + * + * @return this + */ + public final AssertSubscriber assertSubscribed() { + int s = subscriptionCount; + + if (s == 0) { + throw new AssertionError("OnSubscribe not called", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert either complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertTerminated() { + if (cdl.getCount() != 0) { + throw new AssertionError("Not terminated", null); + } + return this; + } + + /** + * Assert {@code n} values has been received. + * + * @param n the expected value count + * @return this + */ + public final AssertSubscriber assertValueCount(long n) { + if (valueCount != n) { + throw new AssertionError( + "Different value count: expected = " + n + ", actual = " + valueCount, null); + } + return this; + } + + /** + * Assert the specified values have been received in the same order read by the passed {@link + * Iterable}. Values storage should be enabled to use this method. + * + * @param expectedSequence the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertValueSequence(Iterable expectedSequence) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + Iterator actual = values.iterator(); + Iterator expected = expectedSequence.iterator(); + int i = 0; + for (; ; ) { + boolean n1 = actual.hasNext(); + boolean n2 = expected.hasNext(); + if (n1 && n2) { + T t1 = actual.next(); + T t2 = expected.next(); + if (!Objects.equals(t1, t2)) { + throw new AssertionError( + "The element with index " + + i + + " does not match: expected = " + + valueAndClass(t2) + + ", actual = " + + valueAndClass(t1), + null); + } + i++; + } else if (n1 && !n2) { + throw new AssertionError("Actual contains more elements" + values, null); + } else if (!n1 && n2) { + throw new AssertionError("Actual contains fewer elements: " + values, null); + } else { + break; + } + } + return this; + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectedValues the values to assert + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValues(T... expectedValues) { + return assertValueSequence(Arrays.asList(expectedValues)); + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValuesWith(Consumer... expectations) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + final int expectedValueCount = expectations.length; + if (expectedValueCount != values.size()) { + throw new AssertionError( + "Different value count: expected = " + expectedValueCount + ", actual = " + valueCount, + null); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = values.get(i); + consumer.accept(actualValue); + } + return this; + } + + // + // ============================================================================================================== + // Await methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until a complete successfully or error signal is received. + * + * @return this + */ + public final AssertSubscriber await() { + if (cdl.getCount() == 0) { + return this; + } + try { + cdl.await(); + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + return this; + } + + /** + * Blocking method that waits until a complete successfully or error signal is received or until a + * timeout occurs. + * + * @param timeout The timeout value + * @return this + */ + public final AssertSubscriber await(Duration timeout) { + if (cdl.getCount() == 0) { + return this; + } + try { + if (!cdl.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new AssertionError("No complete or error signal before timeout"); + } + return this; + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + } + + /** + * Blocking method that waits until {@code n} next values have been received. + * + * @param n the value count to assert + * @return this + */ + public final AssertSubscriber awaitAndAssertNextValueCount(final long n) { + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + n, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d", + valueCount - nextValueAssertedCount, n, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + n)); + nextValueAssertedCount += n; + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * values provided) to assert them. + * + * @param values the values to assert + * @return this + */ + @SafeVarargs + @SuppressWarnings("unchecked") + public final AssertSubscriber awaitAndAssertNextValues(T... values) { + final int expectedNum = values.length; + final List> expectations = new ArrayList<>(); + for (int i = 0; i < expectedNum; i++) { + final T expectedValue = values[i]; + expectations.add( + actualValue -> { + if (!actualValue.equals(expectedValue)) { + throw new AssertionError( + String.format( + "Expected Next signal: %s, but got: %s", expectedValue, actualValue)); + } + }); + } + awaitAndAssertNextValuesWith(expectations.toArray((Consumer[]) new Consumer[0])); + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * expectations provided) to assert them. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + */ + @SafeVarargs + public final AssertSubscriber awaitAndAssertNextValuesWith(Consumer... expectations) { + valuesStorage = true; + final int expectedValueCount = expectations.length; + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + expectedValueCount, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d ms", + valueCount - nextValueAssertedCount, expectedValueCount, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + expectedValueCount)); + List nextValuesSnapshot; + List empty = new ArrayList<>(); + for (; ; ) { + nextValuesSnapshot = values; + if (NEXT_VALUES.compareAndSet(this, values, empty)) { + break; + } + } + if (nextValuesSnapshot.size() < expectedValueCount) { + throw new AssertionError( + String.format( + "Expected %d number of signals but received %d", + expectedValueCount, nextValuesSnapshot.size())); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = nextValuesSnapshot.get(i); + consumer.accept(actualValue); + } + nextValueAssertedCount += expectedValueCount; + return this; + } + + // + // ============================================================================================================== + // Overrides + // + // ============================================================================================================== + + @Override + public void cancel() { + Subscription a = s; + if (a != Operators.cancelledSubscription()) { + a = S.getAndSet(this, Operators.cancelledSubscription()); + if (a != null && a != Operators.cancelledSubscription()) { + a.cancel(); + } + } + } + + final boolean isCancelled() { + return s == Operators.cancelledSubscription(); + } + + public final boolean isTerminated() { + return cdl.getCount() == 0; + } + + @Override + public void onComplete() { + completionCount++; + cdl.countDown(); + } + + @Override + public void onError(Throwable t) { + errors.add(t); + cdl.countDown(); + } + + @Override + public void onNext(T t) { + if (establishedFusionMode == Fuseable.ASYNC) { + for (; ; ) { + t = qs.poll(); + if (t == null) { + break; + } + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + subscriptionCount++; + int requestMode = requestedFusionMode; + if (requestMode >= 0) { + if (!setWithoutRequesting(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } else { + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; + + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; + } + + onNext(v); + } + } else { + requestDeferred(); + } + } else { + requestDeferred(); + } + } + } else { + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (establishedFusionMode != Fuseable.SYNC) { + normalRequest(n); + } + } + } + + @Override + @NonNull + public Context currentContext() { + return context; + } + + /** + * Setup what fusion mode should be requested from the incoming Subscription if it happens to be + * QueueSubscription + * + * @param requestMode the mode to request, see Fuseable constants + * @return this + */ + public final AssertSubscriber requestedFusionMode(int requestMode) { + this.requestedFusionMode = requestMode; + return this; + } + + public Subscription upstream() { + return s; + } + + // + // ============================================================================================================== + // Non public methods + // + // ============================================================================================================== + + protected final void normalRequest(long n) { + Subscription a = s; + if (a != null) { + a.request(n); + } else { + Operators.addCap(REQUESTED, this, n); + + a = s; + + if (a != null) { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + a.request(r); + } + } + } + } + + /** Requests the deferred amount if not zero. */ + protected final void requestDeferred() { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + } + + /** + * Atomically sets the single subscription and requests the missed amount from it. + * + * @param s + * @return false if this arbiter is cancelled or there was a subscription already set + */ + protected final boolean set(Subscription s) { + Objects.requireNonNull(s, "s"); + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + + return true; + } + + a = this.s; + + if (a != Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + + Operators.reportSubscriptionSet(); + return false; + } + + /** + * Sets the Subscription once but does not request anything. + * + * @param s the Subscription to set + * @return true if successful, false if the current subscription is not null + */ + protected final boolean setWithoutRequesting(Subscription s) { + Objects.requireNonNull(s, "s"); + for (; ; ) { + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + return true; + } + } + } + + /** + * Prepares and throws an AssertionError exception based on the message, cause, the active state + * and the potential errors so far. + * + * @param message the message + * @param cause the optional Throwable cause + * @throws AssertionError as expected + */ + protected final void assertionError(String message, Throwable cause) { + StringBuilder b = new StringBuilder(); + + if (cdl.getCount() != 0) { + b.append("(active) "); + } + b.append(message); + + List err = errors; + if (!err.isEmpty()) { + b.append(" (+ ").append(err.size()).append(" errors)"); + } + AssertionError e = new AssertionError(b.toString(), cause); + + for (Throwable t : err) { + e.addSuppressed(t); + } + + throw e; + } + + protected final String fusionModeName(int mode) { + switch (mode) { + case -1: + return "Disabled"; + case Fuseable.NONE: + return "None"; + case Fuseable.SYNC: + return "Sync"; + case Fuseable.ASYNC: + return "Async"; + default: + return "Unknown(" + mode + ")"; + } + } + + protected final String valueAndClass(Object o) { + if (o == null) { + return null; + } + return o + " (" + o.getClass().getSimpleName() + ")"; + } + + public List values() { + return values; + } + + public final AssertSubscriber assertNoEvents() { + return assertNoValues().assertNoError().assertNotComplete(); + } + + @SafeVarargs + public final AssertSubscriber assertIncomplete(T... values) { + return assertValues(values).assertNotComplete().assertNoError(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java new file mode 100644 index 000000000..d5b2eeb41 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +import static org.junit.Assert.*; + +import io.netty.buffer.Unpooled; +import java.time.Duration; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +public class LeaseImplTest { + + @Test + public void emptyLeaseNoAvailability() { + LeaseImpl empty = LeaseImpl.empty(); + Assertions.assertTrue(empty.isEmpty()); + Assertions.assertFalse(empty.isValid()); + Assertions.assertEquals(0.0, empty.availability(), 1e-5); + } + + @Test + public void emptyLeaseUseNoAvailability() { + LeaseImpl empty = LeaseImpl.empty(); + boolean success = empty.use(); + assertFalse(success); + Assertions.assertEquals(0.0, empty.availability(), 1e-5); + } + + @Test + public void leaseAvailability() { + LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); + Assertions.assertEquals(1.0, lease.availability(), 1e-5); + } + + @Test + public void leaseUseDecreasesAvailability() { + LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); + boolean success = lease.use(); + Assertions.assertTrue(success); + Assertions.assertEquals(0.5, lease.availability(), 1e-5); + Assertions.assertTrue(lease.isValid()); + success = lease.use(); + Assertions.assertTrue(success); + Assertions.assertEquals(0.0, lease.availability(), 1e-5); + Assertions.assertFalse(lease.isValid()); + Assertions.assertEquals(0, lease.getAllowedRequests()); + success = lease.use(); + Assertions.assertFalse(success); + } + + @Test + public void leaseTimeout() { + int numberOfRequests = 1; + LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); + Mono.delay(Duration.ofMillis(100)).block(); + boolean success = lease.use(); + Assertions.assertFalse(success); + Assertions.assertTrue(lease.isExpired()); + Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); + Assertions.assertFalse(lease.isValid()); + } + + @Test + public void useLeaseChangesAllowedRequests() { + int numberOfRequests = 2; + LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); + lease.use(); + assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java new file mode 100644 index 000000000..bd5e4295a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataFlyweightTest.java @@ -0,0 +1,554 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer; +import static org.assertj.core.api.Assertions.*; + +import io.netty.buffer.*; +import io.netty.util.CharsetUtil; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.Test; + +class CompositeMetadataFlyweightTest { + + static String byteToBitsString(byte b) { + return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); + } + + static String toHeaderBits(ByteBuf encoded) { + encoded.markReaderIndex(); + byte headerByte = encoded.readByte(); + String byteAsString = byteToBitsString(headerByte); + encoded.resetReaderIndex(); + return byteAsString; + } + // ==== + + @Test + void customMimeHeaderLatin1_encodingFails() { + String mimeNotAscii = "mime/typé"; + + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void customMimeHeaderLength0_encodingFails() { + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "", 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLength127() { + StringBuilder builder = new StringBuilder(127); + for (int i = 0; i < 127; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(127 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(127, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLength128() { + StringBuilder builder = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(128 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(128, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLength129_encodingFails() { + StringBuilder builder = new StringBuilder(129); + for (int i = 0; i < 129; i++) { + builder.append('a'); + } + + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, builder.toString(), 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLengthOne() { + String mimeString = "w"; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()).as("mime length").isZero(); // encoded as actual length - 1 + + assertThat(header.readCharSequence(1, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderLengthTwo() { + String mimeString = "ww"; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(2 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(2, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataFlyweight.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void customMimeHeaderUtf8_encodingFails() { + String mimeNotAscii = + "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes + assertThatIllegalArgumentException() + .isThrownBy( + () -> + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void decodeEntryAtEndOfBuffer() { + ByteBuf fakeEntry = Unpooled.buffer(); + + assertThatIllegalArgumentException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeIdMinusTwoWhenMoreThanOneByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(2); + fakeIdBuffer.writeInt(200); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeIdMinusTwoWhenZeroByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(0); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeStringNullIfLengthOne() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(1); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeStringNullIfLengthZero() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeTypeSkipsFirstByte() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(128); + fakeTypeBuffer.writeCharSequence("example", CharsetUtil.US_ASCII); + + assertThat(decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)).hasToString("example"); + } + + @Test + void encodeMetadataCustomTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, "foo", 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, ByteBufAllocator.DEFAULT, "foo", ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeMetadataKnownTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), + 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_OCTET_STREAM, + ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeMetadataReservedTypeDelegates() { + ByteBuf expected = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, (byte) 120, 2); + + CompositeByteBuf test = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadata( + test, ByteBufAllocator.DEFAULT, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + } + + @Test + void encodeTryCompressWithCompressableType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( + target, + UnpooledByteBufAllocator.DEFAULT, + WellKnownMimeType.APPLICATION_AVRO.getString(), + metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + } + + @Test + void encodeTryCompressWithCustomType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = UnpooledByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataFlyweight.encodeAndAddMetadataWithCompression( + target, UnpooledByteBufAllocator.DEFAULT, "custom/example", metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + } + + @Test + void hasEntry() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + + CompositeByteBuf buffer = + Unpooled.compositeBuffer() + .addComponent( + true, + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)) + .addComponent( + true, + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0)); + + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 0)).isTrue(); + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 4)).isTrue(); + assertThat(CompositeMetadataFlyweight.hasEntry(buffer, 8)).isFalse(); + } + + @Test + void isWellKnownMimeType() { + ByteBuf wellKnown = Unpooled.buffer().writeByte(0); + assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(wellKnown)).isTrue(); + + ByteBuf explicit = Unpooled.buffer().writeByte(2).writeChar('a'); + assertThat(CompositeMetadataFlyweight.isWellKnownMimeType(explicit)).isFalse(); + } + + @Test + void knownMimeHeader120_reserved() { + byte mime = (byte) 120; + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader(ByteBufAllocator.DEFAULT, mime, 0); + + assertThat(mime) + .as("smoke test RESERVED_120 unsigned 7 bits representation") + .isEqualTo((byte) 0b01111000); + + assertThat(toHeaderBits(encoded)).startsWith("1").isEqualTo("11111000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void knownMimeHeader127_compositeMetadata() { + WellKnownMimeType mime = WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA; + assertThat(mime.getIdentifier()) + .as("smoke test COMPOSITE unsigned 7 bits representation") + .isEqualTo((byte) 127) + .isEqualTo((byte) 0b01111111); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("11111111") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111111"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void knownMimeHeaderZero_avro() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + assertThat(mime.getIdentifier()) + .as("smoke test AVRO unsigned 7 bits representation") + .isEqualTo((byte) 0) + .isEqualTo((byte) 0b00000000); + ByteBuf encoded = + CompositeMetadataFlyweight.encodeMetadataHeader( + ByteBufAllocator.DEFAULT, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("10000000") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("10000000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + } + + @Test + void encodeCustomHeaderAsciiCheckSkipsFirstByte() { + final ByteBuf badBuf = Unpooled.copiedBuffer("é00000000000", CharsetUtil.UTF_8); + badBuf.writerIndex(0); + assertThat(badBuf.readerIndex()).isZero(); + + ByteBufAllocator allocator = + new AbstractByteBufAllocator() { + @Override + public boolean isDirectBufferPooled() { + return false; + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + }; + + assertThatCode( + () -> CompositeMetadataFlyweight.encodeMetadataHeader(allocator, "custom/type", 0)) + .doesNotThrowAnyException(); + + assertThat(badBuf.readByte()).isEqualTo((byte) 10); + assertThat(badBuf.readCharSequence(11, CharsetUtil.UTF_8)).hasToString("custom/type"); + assertThat(badBuf.readUnsignedMedium()).isEqualTo(0); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java new file mode 100644 index 000000000..f06bdcc0c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadata.ReservedMimeTypeEntry; +import io.rsocket.metadata.CompositeMetadata.WellKnownMimeTypeEntry; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import java.util.Iterator; +import java.util.Spliterator; +import org.junit.jupiter.api.Test; + +class CompositeMetadataTest { + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryOnDoneBufferThrowsIllegalArgument() { + ByteBuf fakeBuffer = ByteBufUtils.getRandomByteBuf(0); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeBuffer, false); + + assertThatIllegalArgumentException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("entry index 0 is larger than buffer size"); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeThreeEntries() { + // metadata 1: well known + WellKnownMimeType mimeType1 = WellKnownMimeType.APPLICATION_PDF; + ByteBuf metadata1 = Unpooled.buffer(); + metadata1.writeCharSequence("abcdefghijkl", CharsetUtil.UTF_8); + + // metadata 2: custom + String mimeType2 = "application/custom"; + ByteBuf metadata2 = Unpooled.buffer(); + metadata2.writeChar('E'); + metadata2.writeChar('∑'); + metadata2.writeChar('é'); + metadata2.writeBoolean(true); + metadata2.writeChar('W'); + + // metadata 3: reserved but unknown + byte reserved = 120; + assertThat(WellKnownMimeType.fromIdentifier(reserved)) + .as("ensure UNKNOWN RESERVED used in test") + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + ByteBuf metadata3 = Unpooled.buffer(); + metadata3.writeByte(88); + + CompositeByteBuf compositeMetadataBuffer = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType1, metadata1); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType2, metadata2); + CompositeMetadataFlyweight.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, reserved, metadata3); + + Iterator iterator = new CompositeMetadata(compositeMetadataBuffer, true).iterator(); + + assertThat(iterator.next()) + .as("entry1") + .isNotNull() + .satisfies( + e -> + assertThat(e.getMimeType()).as("entry1 mime type").isEqualTo(mimeType1.getString())) + .satisfies( + e -> + assertThat(((WellKnownMimeTypeEntry) e).getType()) + .as("entry1 mime id") + .isEqualTo(WellKnownMimeType.APPLICATION_PDF)) + .satisfies( + e -> + assertThat(e.getContent().toString(CharsetUtil.UTF_8)) + .as("entry1 decoded") + .isEqualTo("abcdefghijkl")); + + assertThat(iterator.next()) + .as("entry2") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry2 mime type").isEqualTo(mimeType2)) + .satisfies( + e -> assertThat(e.getContent()).as("entry2 decoded").isEqualByComparingTo(metadata2)); + + assertThat(iterator.next()) + .as("entry3") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry3 mime type").isNull()) + .satisfies( + e -> + assertThat(((ReservedMimeTypeEntry) e).getType()) + .as("entry3 mime id") + .isEqualTo(reserved)) + .satisfies( + e -> assertThat(e.getContent()).as("entry3 decoded").isEqualByComparingTo(metadata3)); + + assertThat(iterator.hasNext()).as("has no more than 3 entries").isFalse(); + } + + @Test + void streamIsNotParallel() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().isParallel()).as("isParallel").isFalse(); + } + + @Test + void streamSpliteratorCharacteristics() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().spliterator()) + .matches(s -> s.hasCharacteristics(Spliterator.ORDERED), "ORDERED") + .matches(s -> s.hasCharacteristics(Spliterator.DISTINCT), "DISTINCT") + .matches(s -> s.hasCharacteristics(Spliterator.NONNULL), "NONNULL") + .matches(s -> !s.hasCharacteristics(Spliterator.SIZED), "not SIZED"); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java new file mode 100644 index 000000000..d1fbb50b0 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java @@ -0,0 +1,47 @@ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBufAllocator; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * Tagging metadata test + * + * @author linux_china + */ +public class TaggingMetadataTest { + private ByteBufAllocator byteBufAllocator = ByteBufAllocator.DEFAULT; + + @Test + public void testParseTags() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", String.join("", Collections.nCopies(129, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataFlyweight.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } + + @Test + public void testEmptyTagAndOverLengthTag() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", "", String.join("", Collections.nCopies(256, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataFlyweight.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags.subList(0, 1)) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java new file mode 100644 index 000000000..cb8478c13 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java @@ -0,0 +1,209 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCounted; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +public class TracingMetadataCodecTest { + + private static Stream flags() { + return Stream.of(TracingMetadataCodec.Flags.values()); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeEmptyTrace(TracingMetadataCodec.Flags expectedFlag) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encodeEmpty(allocator, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(TracingMetadata::isEmpty) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode64(allocator, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encode64(allocator, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128( + allocator, traceIdHigh, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128(allocator, traceIdHigh, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java new file mode 100644 index 000000000..316aaf091 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class WellKnownMimeTypeTest { + + @Test + void fromIdentifierGreaterThan127() { + assertThat(WellKnownMimeType.fromIdentifier(128)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierMatchFromMimeType() { + for (WellKnownMimeType mimeType : WellKnownMimeType.values()) { + if (mimeType == WellKnownMimeType.UNPARSEABLE_MIME_TYPE + || mimeType == WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE) { + continue; + } + assertThat(WellKnownMimeType.fromString(mimeType.toString())) + .as("mimeType string for " + mimeType.name()) + .isSameAs(mimeType); + + assertThat(WellKnownMimeType.fromIdentifier(mimeType.getIdentifier())) + .as("mimeType ID for " + mimeType.name()) + .isSameAs(mimeType); + } + } + + @Test + void fromIdentifierNegative() { + assertThat(WellKnownMimeType.fromIdentifier(-1)) + .isSameAs(WellKnownMimeType.fromIdentifier(-2)) + .isSameAs(WellKnownMimeType.fromIdentifier(-12)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierReserved() { + assertThat(WellKnownMimeType.fromIdentifier(120)) + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + } + + @Test + void fromStringUnknown() { + assertThat(WellKnownMimeType.fromString("foo/bar")) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromStringUnknownReservedStillReturnsUnparseable() { + assertThat( + WellKnownMimeType.fromString(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE.getString())) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java new file mode 100644 index 000000000..13d910e15 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java @@ -0,0 +1,470 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class AuthMetadataFlyweightTest { + + public static final int AUTH_TYPE_ID_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 1; + public static final String TEST_BEARER_TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; + + @Test + void shouldCorrectlyEncodeData() { + String username = "test"; + String password = "tset1234"; + + int usernameLength = username.length(); + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData1() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData2() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎1234567#4? "; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + private static void checkSimpleAuthMetadataEncoding( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) + .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); + Assertions.assertThat(byteBuf.readUnsignedByte()).isEqualTo((short) usernameLength); + + Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(byteBuf.readCharSequence(passwordLength, CharsetUtil.UTF_8)) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + private static void checkSimpleAuthMetadataEncodingUsingDecoders( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.SIMPLE); + byteBuf.markReaderIndex(); + Assertions.assertThat(AuthMetadataFlyweight.decodeUsername(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(AuthMetadataFlyweight.decodePassword(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(password); + byteBuf.resetReaderIndex(); + + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeUsernameAsCharArray(byteBuf))) + .isEqualTo(username); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodePasswordAsCharArray(byteBuf))) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + @Test + void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { + String username = + "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + String password = "tset1234"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray())) + .hasMessage( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + @Test + void shouldEncodeBearerMetadata() { + String testToken = TEST_BEARER_TOKEN; + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeBearerMetadata( + ByteBufAllocator.DEFAULT, testToken.toCharArray()); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(testToken, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(testToken, byteBuf); + } + + private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat( + byteBuf.readUnsignedByte() & ~AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK) + .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); + Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) + .isEqualTo(testToken); + } + + private static void checkBearerAuthMetadataEncodingUsingDecoders( + String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.BEARER); + byteBuf.markReaderIndex(); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeBearerTokenAsCharArray(byteBuf))) + .isEqualTo(testToken); + byteBuf.resetReaderIndex(); + Assertions.assertThat( + AuthMetadataFlyweight.decodePayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + .isEqualTo(testToken); + } + + @Test + void shouldEncodeCustomAuth() { + String payloadAsAText = "testsecuritybuffer"; + ByteBuf testSecurityPayload = + Unpooled.wrappedBuffer(payloadAsAText.getBytes(CharsetUtil.UTF_8)); + + String customAuthType = "myownauthtype"; + ByteBuf buffer = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); + + checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); + } + + private static void checkCustomAuthMetadataEncoding( + ByteBuf testSecurityPayload, String customAuthType, ByteBuf buffer) { + Assertions.assertThat(buffer.capacity()) + .isEqualTo(1 + customAuthType.length() + testSecurityPayload.capacity()); + Assertions.assertThat(buffer.readUnsignedByte()) + .isEqualTo((short) (customAuthType.length() - 1)); + Assertions.assertThat( + buffer.readCharSequence(customAuthType.length(), CharsetUtil.US_ASCII).toString()) + .isEqualTo(customAuthType); + Assertions.assertThat(buffer.readSlice(testSecurityPayload.capacity())) + .isEqualTo(testSecurityPayload); + + ReferenceCountUtil.release(buffer); + } + + @Test + void shouldThrowOnNonASCIIChars() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = "1234567#4? 𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage("custom auth type must be US_ASCII characters only"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + // 130 chars + String customAuthType = + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType1() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = ""; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldEncodeUsingWellKnownAuthType() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer(3, 3).writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType2() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + buffer.release(); + } + + @Test + void shouldCompressMetadata() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "simple", + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldCompressMetadata1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "bearer", + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldNotCompressMetadata() { + ByteBuf testMetadataPayload = + Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); + String customAuthType = "testauthtype"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); + + checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); + } + + @Test + void shouldConfirmWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadSimpleWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.BEARER; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType2() { + ByteBuf metadata = + ByteBufAllocator.DEFAULT + .buffer() + .writeByte(3 | AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK); + WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { + ByteBuf metadata = ByteBufAllocator.DEFAULT.buffer().writeByte(3); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldThrowExceptionIsNotEnoughReadableBytes() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); + } + + private static void checkDecodeWellKnowAuthTypeCorrectly( + ByteBuf metadata, WellKnownAuthType expectedType) { + int initialReaderIndex = metadata.readerIndex(); + + WellKnownAuthType wellKnownAuthType = AuthMetadataFlyweight.decodeWellKnownAuthType(metadata); + + Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); + Assertions.assertThat(metadata.readerIndex()) + .isNotEqualTo(initialReaderIndex) + .isEqualTo(initialReaderIndex + 1); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadCustomEncodedAuthType() { + String testAuthType = "TestAuthType"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); + checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); + } + + @Test + void shouldThrowExceptionOnEmptyMetadata() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeCustomAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) + .hasMessage("Unable to decode custom Auth type. Incorrect auth type length"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_length() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) + .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); + } + + private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { + int initialReaderIndex = byteBuf.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.decodeCustomAuthType(byteBuf).toString()) + .isEqualTo(testAuthType); + Assertions.assertThat(byteBuf.readerIndex()) + .isEqualTo(initialReaderIndex + testAuthType.length() + 1); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java new file mode 100644 index 000000000..9da66d424 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java @@ -0,0 +1,93 @@ +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.util.Arrays; +import org.junit.Assert; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +public class InMemoryResumeStoreTest { + + @Test + void saveWithoutTailRemoval() { + InMemoryResumableFramesStore store = inMemoryStore(25); + ByteBuf frame = frameMock(10); + store.saveFrames(Flux.just(frame)).block(); + Assert.assertEquals(1, store.cachedFrames.size()); + Assert.assertEquals(frame.readableBytes(), store.cacheSize); + Assert.assertEquals(0, store.position); + } + + @Test + void saveRemoveOneFromTail() { + InMemoryResumableFramesStore store = inMemoryStore(25); + ByteBuf frame1 = frameMock(20); + ByteBuf frame2 = frameMock(10); + store.saveFrames(Flux.just(frame1, frame2)).block(); + Assert.assertEquals(1, store.cachedFrames.size()); + Assert.assertEquals(frame2.readableBytes(), store.cacheSize); + Assert.assertEquals(frame1.readableBytes(), store.position); + } + + @Test + void saveRemoveTwoFromTail() { + InMemoryResumableFramesStore store = inMemoryStore(25); + ByteBuf frame1 = frameMock(10); + ByteBuf frame2 = frameMock(10); + ByteBuf frame3 = frameMock(20); + store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); + Assert.assertEquals(1, store.cachedFrames.size()); + Assert.assertEquals(frame3.readableBytes(), store.cacheSize); + Assert.assertEquals(size(frame1, frame2), store.position); + } + + @Test + void saveBiggerThanStore() { + InMemoryResumableFramesStore store = inMemoryStore(25); + ByteBuf frame1 = frameMock(10); + ByteBuf frame2 = frameMock(10); + ByteBuf frame3 = frameMock(30); + store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); + Assert.assertEquals(0, store.cachedFrames.size()); + Assert.assertEquals(0, store.cacheSize); + Assert.assertEquals(size(frame1, frame2, frame3), store.position); + } + + @Test + void releaseFrames() { + InMemoryResumableFramesStore store = inMemoryStore(100); + ByteBuf frame1 = frameMock(10); + ByteBuf frame2 = frameMock(10); + ByteBuf frame3 = frameMock(30); + store.saveFrames(Flux.just(frame1, frame2, frame3)).block(); + store.releaseFrames(20); + Assert.assertEquals(1, store.cachedFrames.size()); + Assert.assertEquals(frame3.readableBytes(), store.cacheSize); + Assert.assertEquals(size(frame1, frame2), store.position); + } + + @Test + void receiveImpliedPosition() { + InMemoryResumableFramesStore store = inMemoryStore(100); + ByteBuf frame1 = frameMock(10); + ByteBuf frame2 = frameMock(30); + store.resumableFrameReceived(frame1); + store.resumableFrameReceived(frame2); + Assert.assertEquals(size(frame1, frame2), store.frameImpliedPosition()); + } + + private int size(ByteBuf... byteBufs) { + return Arrays.stream(byteBufs).mapToInt(ByteBuf::readableBytes).sum(); + } + + private static InMemoryResumableFramesStore inMemoryStore(int size) { + return new InMemoryResumableFramesStore("test", size); + } + + private static ByteBuf frameMock(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 7); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java deleted file mode 100644 index 18ab589dc..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -public class ResumeCacheTest { - /*private Frame CANCEL = Frame.Cancel.from(1); - private Frame STREAM = - Frame.Request.from(1, FrameType.REQUEST_STREAM, DefaultPayload.create("Test"), 100); - - private ResumeCache cache = new ResumeCache(ResumePositionCounter.frames(), 2); - - @Test - public void startsEmpty() { - Flux x = cache.resend(0); - assertEquals(0L, (long) x.count().block()); - cache.updateRemotePosition(0); - } - - @Test(expected = IllegalStateException.class) - public void failsForFutureUpdatePosition() { - cache.updateRemotePosition(1); - } - - @Test(expected = IllegalStateException.class) - public void failsForFutureResend() { - cache.resend(1); - } - - @Test - public void updatesPositions() { - assertEquals(0, cache.getRemotePosition()); - assertEquals(0, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - - cache.updateRemotePosition(14); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(CANCEL); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - - cache.updateRemotePosition(20); - - assertEquals(20, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(20, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(20, cache.getRemotePosition()); - assertEquals(34, cache.getCurrentPosition()); - assertEquals(20, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - } - - @Test - public void supportsZeroBuffer() { - cache = new ResumeCache(ResumePositionCounter.frames(), 0); - - cache.sent(STREAM); - cache.sent(STREAM); - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(42, cache.getCurrentPosition()); - assertEquals(42, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - } - - @Test - public void supportsFrameCountBuffers() { - cache = new ResumeCache(ResumePositionCounter.size(), 100); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(0, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(14, cache.size()); - - cache.updateRemotePosition(14); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(CANCEL); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(6, cache.size()); - }*/ -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java new file mode 100644 index 000000000..7d2a7bcc8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCalculatorTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ResumeCalculatorTest { + + @BeforeEach + void setUp() {} + + @Test + void clientResumeSuccess() { + long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, -1, 3); + Assertions.assertEquals(3, position); + } + + @Test + void clientResumeError() { + long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, -1, 3); + Assertions.assertEquals(-1, position); + } + + @Test + void serverResumeSuccess() { + long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 42, 4, 23); + Assertions.assertEquals(23, position); + } + + @Test + void serverResumeErrorClientState() { + long position = ResumableDuplexConnection.calculateRemoteImpliedPos(1, 3, 4, 23); + Assertions.assertEquals(-1, position); + } + + @Test + void serverResumeErrorServerState() { + long position = ResumableDuplexConnection.calculateRemoteImpliedPos(4, 42, 4, 1); + Assertions.assertEquals(-1, position); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java new file mode 100644 index 000000000..d86276466 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ResumeExpBackoffTest.java @@ -0,0 +1,75 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.time.Duration; +import java.util.List; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +public class ResumeExpBackoffTest { + + @Test + void backOffSeries() { + Duration firstBackoff = Duration.ofSeconds(1); + Duration maxBackoff = Duration.ofSeconds(32); + int factor = 2; + ExponentialBackoffResumeStrategy strategy = + new ExponentialBackoffResumeStrategy(firstBackoff, maxBackoff, factor); + + List expected = + Flux.just(1, 2, 4, 8, 16, 32, 32).map(Duration::ofSeconds).collectList().block(); + + List actual = Flux.range(1, 7).map(v -> strategy.next()).collectList().block(); + + Assertions.assertThat(actual).isEqualTo(expected); + } + + @Test + void nullFirstBackoff() { + assertThrows( + NullPointerException.class, + () -> { + ExponentialBackoffResumeStrategy strategy = + new ExponentialBackoffResumeStrategy(Duration.ofSeconds(1), null, 42); + }); + } + + @Test + void nullMaxBackoff() { + assertThrows( + NullPointerException.class, + () -> { + ExponentialBackoffResumeStrategy strategy = + new ExponentialBackoffResumeStrategy(null, Duration.ofSeconds(1), 42); + }); + } + + @Test + void negativeFactor() { + assertThrows( + IllegalArgumentException.class, + () -> { + ExponentialBackoffResumeStrategy strategy = + new ExponentialBackoffResumeStrategy( + Duration.ofSeconds(1), Duration.ofSeconds(32), -1); + }); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java deleted file mode 100644 index 7a2fafc8a..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -import static org.junit.Assert.assertEquals; - -import java.util.UUID; -import org.junit.Test; - -public class ResumeTokenTest { - @Test - public void testFromUuid() { - UUID x = UUID.fromString("3bac9870-3873-403a-99f4-9728aa8c7860"); - - ResumeToken t = ResumeToken.bytes(ResumeToken.getBytesFromUUID(x)); - ResumeToken t2 = ResumeToken.bytes(ResumeToken.getBytesFromUUID(x)); - - assertEquals("3bac98703873403a99f49728aa8c7860", t.toString()); - - assertEquals(t.hashCode(), t2.hashCode()); - assertEquals(t, t2); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java deleted file mode 100644 index fdd308d10..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.resume; - -public class ResumeUtilTest { - /*private Frame CANCEL = Frame.Cancel.from(1); - private Frame STREAM = - Frame.Request.from(1, FrameType.REQUEST_STREAM, DefaultPayload.create("Test"), 100); - - @Test - public void testSupportedTypes() { - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_STREAM)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_CHANNEL)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_RESPONSE)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_N)); - assertTrue(ResumeUtil.isTracked(FrameType.CANCEL)); - assertTrue(ResumeUtil.isTracked(FrameType.ERROR)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_FNF)); - assertTrue(ResumeUtil.isTracked(FrameType.PAYLOAD)); - } - - @Test - public void testUnsupportedTypes() { - assertFalse(ResumeUtil.isTracked(FrameType.METADATA_PUSH)); - assertFalse(ResumeUtil.isTracked(FrameType.RESUME)); - assertFalse(ResumeUtil.isTracked(FrameType.RESUME_OK)); - assertFalse(ResumeUtil.isTracked(FrameType.SETUP)); - assertFalse(ResumeUtil.isTracked(FrameType.EXT)); - assertFalse(ResumeUtil.isTracked(FrameType.KEEPALIVE)); - } - - @Test - public void testOffset() { - assertEquals(6, ResumeUtil.offset(CANCEL)); - assertEquals(14, ResumeUtil.offset(STREAM)); - }*/ -} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index d945dd45d..58323c066 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import org.reactivestreams.Publisher; import reactor.core.publisher.DirectProcessor; @@ -25,17 +26,22 @@ import reactor.core.publisher.MonoProcessor; public class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final DirectProcessor send; private final DirectProcessor receive; private final MonoProcessor onClose; private final String name; public LocalDuplexConnection( - String name, DirectProcessor send, DirectProcessor receive) { + String name, + ByteBufAllocator allocator, + DirectProcessor send, + DirectProcessor receive) { this.name = name; + this.allocator = allocator; this.send = send; this.receive = receive; - onClose = MonoProcessor.create(); + this.onClose = MonoProcessor.create(); } @Override @@ -52,6 +58,11 @@ public Flux receive() { return receive.doOnNext(f -> System.out.println(name + " - " + f.toString())); } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public void dispose() { onClose.onComplete(); diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java new file mode 100644 index 000000000..a30e75875 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -0,0 +1,26 @@ +package io.rsocket.test.util; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ClientTransport; +import reactor.core.publisher.Mono; + +public class TestClientTransport implements ClientTransport { + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(allocator); + + @Override + public Mono connect(int mtu) { + return Mono.just(testDuplexConnection); + } + + public TestDuplexConnection testConnection() { + return testDuplexConnection; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index fd48cd9d3..17a19b8c9 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; @@ -27,6 +28,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; @@ -40,18 +42,24 @@ public class TestDuplexConnection implements DuplexConnection { private final LinkedBlockingQueue sent; private final DirectProcessor sentPublisher; + private final FluxSink sendSink; private final DirectProcessor received; + private final FluxSink receivedSink; private final MonoProcessor onClose; private final ConcurrentLinkedQueue> sendSubscribers; + private final ByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection() { - sent = new LinkedBlockingQueue<>(); - received = DirectProcessor.create(); - sentPublisher = DirectProcessor.create(); - sendSubscribers = new ConcurrentLinkedQueue<>(); - onClose = MonoProcessor.create(); + public TestDuplexConnection(ByteBufAllocator allocator) { + this.allocator = allocator; + this.sent = new LinkedBlockingQueue<>(); + this.received = DirectProcessor.create(); + this.receivedSink = received.sink(); + this.sentPublisher = DirectProcessor.create(); + this.sendSink = sentPublisher.sink(); + this.sendSubscribers = new ConcurrentLinkedQueue<>(); + this.onClose = MonoProcessor.create(); } @Override @@ -65,7 +73,7 @@ public Mono send(Publisher frames) { .doOnNext( frame -> { sent.offer(frame); - sentPublisher.onNext(frame); + sendSink.next(frame); }) .doOnError(throwable -> logger.error("Error in send stream on test connection.", throwable)) .subscribe(subscriber); @@ -78,6 +86,11 @@ public Flux receive() { return received; } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public double availability() { return availability; @@ -116,7 +129,7 @@ public Publisher getSentAsPublisher() { public void addToReceivedBuffer(ByteBuf... received) { for (ByteBuf frame : received) { - this.received.onNext(frame); + this.receivedSink.next(frame); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java new file mode 100644 index 000000000..325496148 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -0,0 +1,54 @@ +package io.rsocket.test.util; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ServerTransport; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +public class TestServerTransport implements ServerTransport { + private final MonoProcessor conn = MonoProcessor.create(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @Override + public Mono start(ConnectionAcceptor acceptor, int mtu) { + conn.flatMap(acceptor::apply) + .subscribe(ignored -> {}, err -> disposeConnection(), this::disposeConnection); + return Mono.just( + new Closeable() { + @Override + public Mono onClose() { + return conn.then(); + } + + @Override + public void dispose() { + conn.onComplete(); + } + + @Override + public boolean isDisposed() { + return conn.isTerminated(); + } + }); + } + + private void disposeConnection() { + TestDuplexConnection c = conn.peek(); + if (c != null) { + c.dispose(); + } + } + + public TestDuplexConnection connect() { + TestDuplexConnection c = new TestDuplexConnection(allocator); + conn.onNext(c); + return c; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java deleted file mode 100644 index 46634e94b..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.core.publisher.Mono; - -public final class TestUriHandler implements UriHandler { - - private static final String SCHEME = "test"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(() -> Mono.just(new TestDuplexConnection())); - } - - @Override - public Optional buildServer(URI uri) { - return Optional.empty(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java deleted file mode 100644 index 9e7b92f65..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.DuplexConnection; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import org.junit.Test; - -public class UriTransportRegistryTest { - @Test - public void testTestRegistered() { - ClientTransport test = UriTransportRegistry.clientForUri("test://test"); - - DuplexConnection duplexConnection = test.connect().block(); - - assertTrue(duplexConnection instanceof TestDuplexConnection); - } - - @Test(expected = UnsupportedOperationException.class) - public void testTestUnregistered() { - ClientTransport test = UriTransportRegistry.clientForUri("mailto://bonson@baulsupp.net"); - - test.connect().block(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java new file mode 100644 index 000000000..2ad944d09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java @@ -0,0 +1,64 @@ +package io.rsocket.util; + +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ByteBufPayloadTest { + + @Test + public void shouldIndicateThatItHasMetadata() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = ByteBufPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + ByteBufPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldThrowExceptionIfAccessAfterRelease() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.release()).isTrue(); + + Assertions.assertThatThrownBy(payload::hasMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::data).isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::metadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::touch) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(() -> payload.touch("test")) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getDataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java index 45ee4eacb..6bae0886b 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -16,10 +16,13 @@ package io.rsocket.util; -import static org.hamcrest.MatcherAssert.*; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import io.netty.buffer.Unpooled; import io.rsocket.Payload; +import java.nio.ByteBuffer; +import org.assertj.core.api.Assertions; import org.junit.Test; public class DefaultPayloadTest { @@ -48,4 +51,27 @@ public void staticMethods() { assertDataAndMetadata(DefaultPayload.create(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); assertDataAndMetadata(DefaultPayload.create(DATA_VAL), DATA_VAL, null); } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = DefaultPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata2() { + Payload payload = + DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java index 988bd523d..46e0f77f4 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java @@ -18,6 +18,8 @@ import static org.assertj.core.api.Assertions.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -158,4 +160,28 @@ void requireUnsignedShortOverFlow() { .isThrownBy(() -> NumberUtils.requireUnsignedShort(1 << 16)) .withMessage("%d is larger than 16 bits", 1 << 16); } + + @Test + void encodeUnsignedMedium() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 129); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(129); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(129); + } + + @Test + void encodeUnsignedMediumLarge() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 0xFFFFFC); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(16777212); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(-4); + } } diff --git a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 068667aa7..000000000 --- a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.uri.TestUriHandler diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index 04cd59c50..01e80cfa1 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -22,10 +22,13 @@ dependencies { implementation project(':rsocket-core') implementation project(':rsocket-transport-local') implementation project(':rsocket-transport-netty') + runtimeOnly 'ch.qos.logback:logback-classic' testImplementation project(':rsocket-test') testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' // TODO: Remove after JUnit5 migration testCompileOnly 'junit:junit' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index 386154f20..b532c0140 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,59 +16,48 @@ package io.rsocket.examples.transport.tcp.channel; -import io.rsocket.AbstractRSocket; -import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public final class ChannelEchoClient { + private static final Logger logger = LoggerFactory.getLogger(ChannelEchoClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + + SocketAcceptor echoAcceptor = + SocketAcceptor.forRequestChannel( + payloads -> + Flux.from(payloads) + .map(Payload::getDataUtf8) + .map(s -> "Echo: " + s) + .map(DefaultPayload::create)); + + RSocketServer.create(echoAcceptor) + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); socket .requestChannel( Flux.interval(Duration.ofMillis(1000)).map(i -> DefaultPayload.create("Hello"))) .map(Payload::getDataUtf8) - .doOnNext(System.out::println) + .doOnNext(logger::debug) .take(10) .doFinally(signalType -> socket.dispose()) .then() .block(); } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads) - .map(Payload::getDataUtf8) - .map(s -> "Echo: " + s) - .map(DefaultPayload::create); - } - }); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java index c0a271d66..3eba5a800 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,12 @@ package io.rsocket.examples.transport.tcp.duplex; -import io.rsocket.AbstractRSocket; +import static io.rsocket.SocketAcceptor.forRequestStream; + import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; @@ -30,36 +32,30 @@ public final class DuplexClient { public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setup, reactiveSocket) -> { - reactiveSocket + + RSocketServer.create( + (setup, rsocket) -> { + rsocket .requestStream(DefaultPayload.create("Hello-Bidi")) .map(Payload::getDataUtf8) .log() .subscribe(); - return Mono.just(new AbstractRSocket() {}); + return Mono.just(new RSocket() {}); }) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); - RSocket socket = - RSocketFactory.connect() + RSocket rsocket = + RSocketConnector.create() .acceptor( - rSocket -> - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofSeconds(1)) - .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)); - } - }) - .transport(TcpClientTransport.create("localhost", 7000)) - .start() + forRequestStream( + payload -> + Flux.interval(Duration.ofSeconds(1)) + .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)))) + .connect(TcpClientTransport.create("localhost", 7000)) .block(); - socket.onClose().block(); + rsocket.onClose().block(); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java new file mode 100644 index 000000000..3eaebd89a --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java @@ -0,0 +1,158 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease; + +import static java.time.Duration.ofSeconds; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseStats; +import io.rsocket.lease.Leases; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.util.Date; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class LeaseExample { + private static final String SERVER_TAG = "server"; + private static final String CLIENT_TAG = "client"; + + public static void main(String[] args) { + + CloseableChannel server = + RSocketServer.create( + (setup, sendingRSocket) -> Mono.just(new ServerRSocket(sendingRSocket))) + .lease( + () -> + Leases.create() + .sender(new LeaseSender(SERVER_TAG, 7_000, 5)) + .receiver(new LeaseReceiver(SERVER_TAG)) + .stats(new NoopStats())) + .bind(TcpServerTransport.create("localhost", 7000)) + .block(); + + RSocket clientRSocket = + RSocketConnector.create() + .lease( + () -> + Leases.create() + .sender(new LeaseSender(CLIENT_TAG, 3_000, 5)) + .receiver(new LeaseReceiver(CLIENT_TAG))) + .acceptor( + SocketAcceptor.forRequestResponse( + payload -> Mono.just(DefaultPayload.create("Client Response " + new Date())))) + .connect(TcpClientTransport.create(server.address())) + .block(); + + Flux.interval(ofSeconds(1)) + .flatMap( + signal -> { + System.out.println("Client requester availability: " + clientRSocket.availability()); + return clientRSocket + .requestResponse(DefaultPayload.create("Client request " + new Date())) + .doOnError(err -> System.out.println("Client request error: " + err)) + .onErrorResume(err -> Mono.empty()); + }) + .subscribe(resp -> System.out.println("Client requester response: " + resp.getDataUtf8())); + + clientRSocket.onClose().block(); + server.dispose(); + } + + private static class LeaseSender implements Function, Flux> { + private final String tag; + private final int ttlMillis; + private final int allowedRequests; + + public LeaseSender(String tag, int ttlMillis, int allowedRequests) { + this.tag = tag; + this.ttlMillis = ttlMillis; + this.allowedRequests = allowedRequests; + } + + @Override + public Flux apply(Optional leaseStats) { + System.out.println( + String.format("%s stats are %s", tag, leaseStats.isPresent() ? "present" : "absent")); + return Flux.interval(ofSeconds(1), ofSeconds(10)) + .onBackpressureLatest() + .map( + tick -> { + System.out.println( + String.format( + "%s responder sends new leases: ttl: %d, requests: %d", + tag, ttlMillis, allowedRequests)); + return Lease.create(ttlMillis, allowedRequests); + }); + } + } + + private static class LeaseReceiver implements Consumer> { + private final String tag; + + public LeaseReceiver(String tag) { + this.tag = tag; + } + + @Override + public void accept(Flux receivedLeases) { + receivedLeases.subscribe( + l -> + System.out.println( + String.format( + "%s received leases - ttl: %d, requests: %d", + tag, l.getTimeToLiveMillis(), l.getAllowedRequests()))); + } + } + + private static class NoopStats implements LeaseStats { + + @Override + public void onEvent(EventType eventType) {} + } + + private static class ServerRSocket implements RSocket { + private final RSocket senderRSocket; + + public ServerRSocket(RSocket senderRSocket) { + this.senderRSocket = senderRSocket; + } + + @Override + public Mono requestResponse(Payload payload) { + System.out.println("Server requester availability: " + senderRSocket.availability()); + senderRSocket + .requestResponse(DefaultPayload.create("Server request " + new Date())) + .doOnError(err -> System.out.println("Server request error: " + err)) + .onErrorResume(err -> Mono.empty()) + .subscribe( + resp -> System.out.println("Server requester response: " + resp.getDataUtf8())); + + return Mono.just(DefaultPayload.create("Server Response " + new Date())); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java new file mode 100644 index 000000000..67a85b67f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -0,0 +1,84 @@ +package io.rsocket.examples.transport.tcp.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.plugins.LimitRateInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class LimitRateInterceptorExample { + + private static final Logger logger = LoggerFactory.getLogger(LimitRateInterceptorExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.interval(Duration.ofMillis(100)) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)); + } + })) + .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) + .bind(TcpServerTransport.create("localhost", 7000)) + .subscribe(); + + RSocket socket = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(LimitRateInterceptor.forRequester(64))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + logger.debug( + "\n\nStart of requestStream interaction\n" + "----------------------------------\n"); + + socket + .requestStream(DefaultPayload.create("Hello")) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .block(); + + logger.debug( + "\n\nStart of requestChannel interaction\n" + "-----------------------------------\n"); + + socket + .requestChannel( + Flux.generate( + () -> 1L, + (s, sink) -> { + sink.next(DefaultPayload.create("Next " + s)); + return ++s; + }) + .doOnRequest(e -> logger.debug("Client publisher receives request for " + e))) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 537485fa4..85faeee82 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,65 +16,54 @@ package io.rsocket.examples.transport.tcp.requestresponse; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; public final class HelloWorldClient { + private static final Logger logger = LoggerFactory.getLogger(HelloWorldClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setupPayload, reactiveSocket) -> - Mono.just( - new AbstractRSocket() { - boolean fail = true; - @Override - public Mono requestResponse(Payload p) { - if (fail) { - fail = false; - return Mono.error(new Throwable()); - } else { - return Mono.just(p); - } - } - })) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); + RSocket rsocket = + new RSocket() { + boolean fail = true; - RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + @Override + public Mono requestResponse(Payload p) { + if (fail) { + fail = false; + return Mono.error(new Throwable("Simulated error")); + } else { + return Mono.just(p); + } + } + }; - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocketServer.create(SocketAcceptor.with(rsocket)) + .bind(TcpServerTransport.create("localhost", 7000)) + .subscribe(); - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + for (int i = 0; i < 3; i++) { + socket + .requestResponse(DefaultPayload.create("Hello")) + .map(Payload::getDataUtf8) + .onErrorReturn("error") + .doOnNext(logger::debug) + .block(); + } socket.dispose(); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java new file mode 100644 index 000000000..6724ca93f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java @@ -0,0 +1,141 @@ +package io.rsocket.examples.transport.tcp.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.io.BufferedInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + +class Files { + private static final Logger logger = LoggerFactory.getLogger(Files.class); + + public static Flux fileSource(String fileName, int chunkSizeBytes) { + return Flux.generate( + () -> new FileState(fileName, chunkSizeBytes), FileState::consumeNext, FileState::dispose); + } + + public static Subscriber fileSink(String fileName, int windowSize) { + return new Subscriber() { + Subscription s; + int requests = windowSize; + OutputStream outputStream; + int receivedBytes; + int receivedCount; + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + this.s.request(requests); + } + + @Override + public void onNext(Payload payload) { + ByteBuf data = payload.data(); + receivedBytes += data.readableBytes(); + receivedCount += 1; + logger.debug("Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); + if (outputStream == null) { + outputStream = open(fileName); + } + write(outputStream, data); + payload.release(); + + requests--; + if (requests == windowSize / 2) { + requests += windowSize; + s.request(windowSize); + } + } + + private void write(OutputStream outputStream, ByteBuf byteBuf) { + try { + byteBuf.readBytes(outputStream, byteBuf.readableBytes()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable t) { + close(outputStream); + } + + @Override + public void onComplete() { + close(outputStream); + } + + private OutputStream open(String filename) { + try { + /*do not buffer for demo purposes*/ + return new FileOutputStream(filename); + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + private void close(OutputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + } + } + } + }; + } + + private static class FileState { + private final String fileName; + private final int chunkSizeBytes; + private BufferedInputStream inputStream; + private byte[] chunkBytes; + + public FileState(String fileName, int chunkSizeBytes) { + this.fileName = fileName; + this.chunkSizeBytes = chunkSizeBytes; + } + + public FileState consumeNext(SynchronousSink sink) { + if (inputStream == null) { + InputStream in = getClass().getClassLoader().getResourceAsStream(fileName); + if (in == null) { + sink.error(new FileNotFoundException(fileName)); + return this; + } + this.inputStream = new BufferedInputStream(in); + this.chunkBytes = new byte[chunkSizeBytes]; + } + try { + int consumedBytes = inputStream.read(chunkBytes); + if (consumedBytes == -1) { + sink.complete(); + } else { + sink.next(Unpooled.copiedBuffer(chunkBytes, 0, consumedBytes)); + } + } catch (IOException e) { + sink.error(e); + } + return this; + } + + public void dispose() { + if (inputStream != null) { + try { + inputStream.close(); + } catch (IOException e) { + } + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java new file mode 100644 index 000000000..93b54e146 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -0,0 +1,118 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.util.retry.Retry; + +public class ResumeFileTransfer { + + /*amount of file chunks requested by subscriber: n, refilled on n/2 of received items*/ + private static final int PREFETCH_WINDOW_SIZE = 4; + private static final Logger logger = LoggerFactory.getLogger(ResumeFileTransfer.class); + + public static void main(String[] args) { + + Resume resume = + new Resume() + .sessionDuration(Duration.ofMinutes(5)) + .retry( + Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)) + .doBeforeRetry(s -> logger.debug("Disconnected. Trying to resume..."))); + + RequestCodec codec = new RequestCodec(); + + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> { + Request request = codec.decode(payload); + payload.release(); + String fileName = request.getFileName(); + int chunkSize = request.getChunkSize(); + + Flux ticks = Flux.interval(Duration.ofMillis(500)).onBackpressureDrop(); + + return Files.fileSource(fileName, chunkSize) + .map(DefaultPayload::create) + .zipWith(ticks, (p, tick) -> p); + })) + .resume(resume) + .bind(TcpServerTransport.create("localhost", 8000)) + .block(); + + RSocket client = + RSocketConnector.create() + .resume(resume) + .connect(TcpClientTransport.create("localhost", 8001)) + .block(); + + client + .requestStream(codec.encode(new Request(16, "lorem.txt"))) + .doFinally(s -> server.dispose()) + .subscribe(Files.fileSink("rsocket-examples/out/lorem_output.txt", PREFETCH_WINDOW_SIZE)); + + server.onClose().block(); + } + + private static class RequestCodec { + + public Payload encode(Request request) { + String encoded = request.getChunkSize() + ":" + request.getFileName(); + return DefaultPayload.create(encoded); + } + + public Request decode(Payload payload) { + String encoded = payload.getDataUtf8(); + String[] chunkSizeAndFileName = encoded.split(":"); + int chunkSize = Integer.parseInt(chunkSizeAndFileName[0]); + String fileName = chunkSizeAndFileName[1]; + return new Request(chunkSize, fileName); + } + } + + private static class Request { + private final int chunkSize; + private final String fileName; + + public Request(int chunkSize, String fileName) { + this.chunkSize = chunkSize; + this.fileName = fileName; + } + + public int getChunkSize() { + return chunkSize; + } + + public String getFileName() { + return fileName; + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md new file mode 100644 index 000000000..55e761fe8 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md @@ -0,0 +1,29 @@ +1. Start socat. It is used for emulation of transport disconnects + +`socat -d TCP-LISTEN:8001,fork,reuseaddr TCP:localhost:8000` + +2. start `ResumeFileTransfer.main` + +3. terminate/start socat periodically for session resumption + +`ResumeFileTransfer` output is as follows + +``` +Received file chunk: 7. Total size: 112 +Received file chunk: 8. Total size: 128 +Received file chunk: 9. Total size: 144 +Received file chunk: 10. Total size: 160 +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Received file chunk: 11. Total size: 176 +Received file chunk: 12. Total size: 192 +Received file chunk: 13. Total size: 208 +Received file chunk: 14. Total size: 224 +Received file chunk: 15. Total size: 240 +Received file chunk: 16. Total size: 256 +``` + +It transfers file from `resources/lorem.txt` to `build/out/lorem_output.txt` in chunks of 16 bytes every 500 millis diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java index 57a659c1d..6ac329d56 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,51 +16,43 @@ package io.rsocket.examples.transport.tcp.stream; -import io.rsocket.*; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public final class StreamingClient { + private static final Logger logger = LoggerFactory.getLogger(StreamingClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.interval(Duration.ofMillis(100)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); socket .requestStream(DefaultPayload.create("Hello")) .map(Payload::getDataUtf8) - .doOnNext(System.out::println) + .doOnNext(logger::debug) .take(10) .then() .doFinally(signalType -> socket.dispose()) .then() .block(); } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofMillis(100)) - .map(aLong -> DefaultPayload.create("Interval: " + aLong)); - } - }); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java new file mode 100644 index 000000000..2ab73116d --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.HashMap; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketHeadersSample { + static final Payload payload1 = ByteBufPayload.create("Hello "); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor acceptor = + RSocketServer.create(SocketAcceptor.with(new ServerRSocket())) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer disposableServer = + HttpServer.create() + .host("localhost") + .port(0) + .route( + routes -> + routes.ws( + "/", + (in, out) -> { + if (in.headers().containsValue("Authorization", "test", true)) { + DuplexConnection connection = + new ReassemblyDuplexConnection( + new WebsocketDuplexConnection((Connection) in), false); + return acceptor.apply(connection).then(out.neverComplete()); + } + + return out.sendClose( + HttpResponseStatus.UNAUTHORIZED.code(), + HttpResponseStatus.UNAUTHORIZED.reasonPhrase()); + })) + .bindNow(); + + WebsocketClientTransport clientTransport = + WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); + + clientTransport.setTransportHeaders( + () -> { + HashMap map = new HashMap<>(); + map.put("Authorization", "test"); + return map; + }); + + RSocket socket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport) + .block(); + + Flux.range(0, 100) + .concatMap(i -> socket.fireAndForget(payload1.retain())) + // .doOnNext(p -> { + //// System.out.println(p.getDataUtf8()); + // p.release(); + // }) + .blockLast(); + socket.dispose(); + + WebsocketClientTransport clientTransport2 = + WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); + + RSocket rSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport2) + .block(); + + // expect error here because of closed channel + rSocket.requestResponse(payload1).block(); + } + + private static class ServerRSocket implements RSocket { + + @Override + public Mono fireAndForget(Payload payload) { + // System.out.println(payload.getDataUtf8()); + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).subscribeOn(Schedulers.single()); + } + } +} diff --git a/rsocket-examples/src/main/resources/log4j.properties b/rsocket-examples/src/main/resources/log4j.properties deleted file mode 100644 index 035f18ebd..000000000 --- a/rsocket-examples/src/main/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=DEBUG, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n \ No newline at end of file diff --git a/rsocket-examples/src/main/resources/logback.xml b/rsocket-examples/src/main/resources/logback.xml new file mode 100644 index 000000000..17dd8b5e3 --- /dev/null +++ b/rsocket-examples/src/main/resources/logback.xml @@ -0,0 +1,35 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + + + diff --git a/rsocket-examples/src/main/resources/lorem.txt b/rsocket-examples/src/main/resources/lorem.txt new file mode 100644 index 000000000..e035ea86d --- /dev/null +++ b/rsocket-examples/src/main/resources/lorem.txt @@ -0,0 +1,32 @@ +Alteration literature to or an sympathize mr imprudence. Of is ferrars subject as enjoyed or tedious cottage. +Procuring as in resembled by in agreeable. Next long no gave mr eyes. Admiration advantages no he celebrated so pianoforte unreserved. +Not its herself forming charmed amiable. Him why feebly expect future now. + +Situation admitting promotion at or to perceived be. Mr acuteness we as estimable enjoyment up. +An held late as felt know. Learn do allow solid to grave. Middleton suspicion age her attention. +Chiefly several bed its wishing. Is so moments on chamber pressed to. Doubtful yet way properly answered humanity its desirous. + Minuter believe service arrived civilly add all. Acuteness allowance an at eagerness favourite in extensive exquisite ye. + + Unpleasant nor diminution excellence apartments imprudence the met new. Draw part them he an to he roof only. + Music leave say doors him. Tore bred form if sigh case as do. Staying he no looking if do opinion. + Sentiments way understood end partiality and his. + + Ladyship it daughter securing procured or am moreover mr. Put sir she exercise vicinity cheerful wondered. + Continual say suspicion provision you neglected sir curiosity unwilling. Simplicity end themselves increasing led day sympathize yet. + General windows effects not are drawing man garrets. Common indeed garden you his ladies out yet. Preference imprudence contrasted to remarkably in on. + Taken now you him trees tears any. Her object giving end sister except oppose. + + No comfort do written conduct at prevent manners on. Celebrated contrasted discretion him sympathize her collecting occasional. + Do answered bachelor occasion in of offended no concerns. Supply worthy warmth branch of no ye. Voice tried known to as my to. + Though wished merits or be. Alone visit use these smart rooms ham. No waiting in on enjoyed placing it inquiry. + + So insisted received is occasion advanced honoured. Among ready to which up. Attacks smiling and may out assured moments man nothing outward. + Thrown any behind afford either the set depend one temper. Instrument melancholy in acceptance collecting frequently be if. + Zealously now pronounce existence add you instantly say offending. Merry their far had widen was. Concerns no in expenses raillery formerly. + + As am hastily invited settled at limited civilly fortune me. Really spring in extent an by. Judge but built gay party world. + Of so am he remember although required. Bachelor unpacked be advanced at. Confined in declared marianne is vicinity. + + In alteration insipidity impression by travelling reasonable up motionless. Of regard warmth by unable sudden garden ladies. + No kept hung am size spot no. Likewise led and dissuade rejoiced welcomed husbands boy. Do listening on he suspected resembled. + Water would still if to. Position boy required law moderate was may. \ No newline at end of file diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index 627b1d7da..e2471f2fc 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,12 +23,13 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; import io.rsocket.test.TestSubscriber; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; @@ -38,7 +39,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; @@ -48,35 +48,54 @@ public class IntegrationTest { - private static final RSocketInterceptor clientPlugin; - private static final RSocketInterceptor serverPlugin; - private static final DuplexConnectionInterceptor connectionPlugin; - public static volatile boolean calledClient = false; - public static volatile boolean calledServer = false; - public static volatile boolean calledFrame = false; + private static final RSocketInterceptor requesterInterceptor; + private static final RSocketInterceptor responderInterceptor; + private static final SocketAcceptorInterceptor clientAcceptorInterceptor; + private static final SocketAcceptorInterceptor serverAcceptorInterceptor; + private static final DuplexConnectionInterceptor connectionInterceptor; + + private static volatile boolean calledRequester = false; + private static volatile boolean calledResponder = false; + private static volatile boolean calledClientAcceptor = false; + private static volatile boolean calledServerAcceptor = false; + private static volatile boolean calledFrame = false; static { - clientPlugin = + requesterInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override public Mono requestResponse(Payload payload) { - calledClient = true; + calledRequester = true; return reactiveSocket.requestResponse(payload); } }; - serverPlugin = + responderInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override public Mono requestResponse(Payload payload) { - calledServer = true; + calledResponder = true; return reactiveSocket.requestResponse(payload); } }; - connectionPlugin = + clientAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledClientAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + serverAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledServerAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + connectionInterceptor = (type, connection) -> { calledFrame = true; return connection; @@ -95,17 +114,8 @@ public void startup() { requestCount = new AtomicInteger(); disconnectionCounter = new CountDownLatch(1); - TcpServerTransport serverTransport = TcpServerTransport.create(0); - server = - RSocketFactory.receive() - .addServerPlugin(serverPlugin) - .addConnectionPlugin(connectionPlugin) - .errorConsumer( - t -> { - errorCount.incrementAndGet(); - }) - .acceptor( + RSocketServer.create( (setup, sendingSocket) -> { sendingSocket .onClose() @@ -113,7 +123,7 @@ public void startup() { .subscribe(); return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")) @@ -132,16 +142,24 @@ public Flux requestChannel(Publisher payloads) { } }); }) - .transport(serverTransport) - .start() + .interceptors( + registry -> + registry + .forResponder(responderInterceptor) + .forSocketAcceptor(serverAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) .block(); client = - RSocketFactory.connect() - .addClientPlugin(clientPlugin) - .addConnectionPlugin(connectionPlugin) - .transport(TcpClientTransport.create(server.address())) - .start() + RSocketConnector.create() + .interceptors( + registry -> + registry + .forRequester(requesterInterceptor) + .forSocketAcceptor(clientAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .connect(TcpClientTransport.create(server.address())) .block(); } @@ -154,8 +172,10 @@ public void teardown() { public void testRequest() { client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); assertThat("Server did not see the request.", requestCount.get(), is(1)); - assertTrue(calledClient); - assertTrue(calledServer); + assertTrue(calledRequester); + assertTrue(calledResponder); + assertTrue(calledClientAcceptor); + assertTrue(calledServerAcceptor); assertTrue(calledFrame); } @@ -181,8 +201,6 @@ public void testCallRequestWithErrorAndThenRequest() { } catch (Throwable t) { } - Assert.assertEquals(1, errorCount.incrementAndGet()); - testRequest(); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 6c8f0e8fa..48e5baaa7 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -1,9 +1,10 @@ package io.rsocket.integration; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.test.SlowTest; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; @@ -14,32 +15,26 @@ import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public class InteractionsLoadTest { @Test @SlowTest public void channel() { - TcpServerTransport serverTransport = TcpServerTransport.create(0); - CloseableChannel server = - RSocketFactory.receive() - .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() + RSocketServer.create(SocketAcceptor.with(new EchoRSocket())) + .bind(TcpServerTransport.create("localhost", 0)) .block(Duration.ofSeconds(10)); - TcpClientTransport transport = TcpClientTransport.create(server.address()); - - RSocket client = - RSocketFactory.connect().transport(transport).start().block(Duration.ofSeconds(10)); + RSocket clientRSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .block(Duration.ofSeconds(10)); int concurrency = 16; Flux.range(1, concurrency) .flatMap( v -> - client + clientRSocket .requestChannel( input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) .limitRate(10000), @@ -70,7 +65,8 @@ private static Flux input() { return interval; } - private static class EchoRSocket extends AbstractRSocket { + private static class EchoRSocket implements RSocket { + @Override public Flux requestChannel(Publisher payloads) { return Flux.from(payloads) diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index f5d048508..de27bcb9b 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -19,10 +19,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -40,26 +40,20 @@ import reactor.core.scheduler.Schedulers; public class TcpIntegrationTest { - private AbstractRSocket handler; + private RSocket handler; private CloseableChannel server; @Before public void startup() { - TcpServerTransport serverTransport = TcpServerTransport.create(0); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .bind(TcpServerTransport.create("localhost", 0)) .block(); } private RSocket buildClient() { - return RSocketFactory.connect() - .transport(TcpClientTransport.create(server.address())) - .start() - .block(); + return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } @After @@ -67,10 +61,10 @@ public void cleanup() { server.dispose(); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testCompleteWithoutNext() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.empty(); @@ -83,10 +77,10 @@ public Flux requestStream(Payload payload) { assertFalse(hasElements); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testSingleStream() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.just(DefaultPayload.create("RESPONSE", "METADATA")); @@ -100,10 +94,10 @@ public Flux requestStream(Payload payload) { assertEquals("RESPONSE", result.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testZeroPayload() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.just(EmptyPayload.INSTANCE); @@ -117,10 +111,10 @@ public Flux requestStream(Payload payload) { assertEquals("", result.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testRequestResponseErrors() { handler = - new AbstractRSocket() { + new RSocket() { boolean first = true; @Override @@ -151,7 +145,7 @@ public Mono requestResponse(Payload payload) { assertEquals("SUCCESS", response2.getDataUtf8()); } - @Test(timeout = 5_000L) + @Test(timeout = 15_000L) public void testTwoConcurrentStreams() throws InterruptedException { ConcurrentHashMap> map = new ConcurrentHashMap<>(); UnicastProcessor processor1 = UnicastProcessor.create(); @@ -160,7 +154,7 @@ public void testTwoConcurrentStreams() throws InterruptedException { map.put("REQUEST2", processor2); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return map.get(payload.getDataUtf8()); diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index ec1d41bf9..7d34ba478 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,60 +16,41 @@ package io.rsocket.integration; -import io.rsocket.*; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; import org.junit.Test; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public class TestingStreaming { - private Supplier> serverSupplier = - () -> LocalServerTransport.create("test"); - - private Supplier clientSupplier = () -> LocalClientTransport.create("test"); + LocalServerTransport serverTransport = LocalServerTransport.create("test"); @Test(expected = ApplicationErrorException.class) public void testRangeButThrowException() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .doOnNext( - i -> { - if (i > 3) { - throw new RuntimeException("BOOM!"); - } - }) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .doOnNext( + i -> { + if (i > 3) { + throw new RuntimeException("BOOM!"); + } + }) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -85,29 +66,13 @@ public void testRangeOfConsumers() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -119,10 +84,7 @@ public Flux requestStream(Payload payload) { } private Flux consumer(String s) { - return RSocketFactory.connect() - .errorConsumer(Throwable::printStackTrace) - .transport(clientSupplier) - .start() + return RSocketConnector.connectWith(LocalClientTransport.create("test")) .flatMapMany( rSocket -> { AtomicInteger count = new AtomicInteger(); @@ -135,31 +97,15 @@ private Flux consumer(String s) { @Test public void testSingleConsumer() { Closeable server = null; - try { server = - RSocketFactory.receive() - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 10_000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); consumer("1").blockLast(); diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java new file mode 100644 index 000000000..e29066f02 --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java @@ -0,0 +1,75 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import reactor.core.publisher.Mono; + +class DisconnectableClientTransport implements ClientTransport { + private final ClientTransport clientTransport; + private final AtomicReference curConnection = new AtomicReference<>(); + private long nextConnectPermitMillis; + + public DisconnectableClientTransport(ClientTransport clientTransport) { + this.clientTransport = clientTransport; + } + + @Override + public Mono connect(int mtu) { + return Mono.defer( + () -> + now() < nextConnectPermitMillis + ? Mono.error(new ClosedChannelException()) + : clientTransport + .connect(mtu) + .map( + c -> { + if (curConnection.compareAndSet(null, c)) { + return c; + } else { + throw new IllegalStateException( + "Transport supports at most 1 connection"); + } + })); + } + + public void disconnect() { + disconnectFor(Duration.ZERO); + } + + public void disconnectPermanently() { + disconnectFor(Duration.ofDays(42)); + } + + public void disconnectFor(Duration cooldown) { + DuplexConnection cur = curConnection.getAndSet(null); + if (cur != null) { + nextConnectPermitMillis = now() + cooldown.toMillis(); + cur.dispose(); + } else { + throw new IllegalStateException("Trying to disconnect while not connected"); + } + } + + private static long now() { + return System.currentTimeMillis(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java new file mode 100644 index 000000000..b2dad0022 --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -0,0 +1,229 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.test.SlowTest; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; + +@SlowTest +public class ResumeIntegrationTest { + private static final String SERVER_HOST = "localhost"; + private static final int SERVER_PORT = 0; + + @Test + void timeoutOnPermanentDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 5; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)).subscribe(v -> clientTransport.disconnectPermanently()); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(7)); + } + + @Test + public void reconnectOnDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 15; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Flux.just(3, 20, 40, 75) + .flatMap(v -> Mono.delay(Duration.ofSeconds(v))) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(7))); + + AtomicInteger counter = new AtomicInteger(-1); + StepVerifier.create( + rSocket + .requestChannel(testRequest()) + .take(Duration.ofSeconds(600)) + .map(Payload::getDataUtf8) + .timeout(Duration.ofSeconds(12)) + .doOnNext(x -> throwOnNonContinuous(counter, x)) + .then() + .doFinally(s -> closeable.dispose())) + .expectComplete() + .verify(); + } + + @Test + public void reconnectOnMissingSession() { + + int serverSessionDuration = 2; + + CloseableChannel closeable = newServerRSocket(serverSessionDuration).block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + int clientSessionDurationSeconds = 10; + + RSocket rSocket = newClientRSocket(clientTransport, clientSessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(3))); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(rSocket.onClose()) + .expectErrorMatches( + err -> + err instanceof RejectedResumeException + && "unknown resume token".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + } + + @Test + void serverMissingResume() { + CloseableChannel closeableChannel = + RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)) + .block(); + + RSocket rSocket = + RSocketConnector.create() + .resume(new Resume()) + .connect(clientTransport(closeableChannel.address())) + .block(); + + StepVerifier.create(rSocket.onClose().doFinally(s -> closeableChannel.dispose())) + .expectErrorMatches( + err -> + err instanceof UnsupportedSetupException + && "resume not supported".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + } + + static ClientTransport clientTransport(InetSocketAddress address) { + return TcpClientTransport.create(address); + } + + static ServerTransport serverTransport(String host, int port) { + return TcpServerTransport.create(host, port); + } + + private static Flux testRequest() { + return Flux.interval(Duration.ofMillis(500)) + .map(v -> DefaultPayload.create("client_request")) + .onBackpressureDrop(); + } + + private void throwOnNonContinuous(AtomicInteger counter, String x) { + int curValue = Integer.parseInt(x); + int prevValue = counter.get(); + if (prevValue >= 0) { + int dif = curValue - prevValue; + if (dif != 1) { + throw new IllegalStateException( + String.format( + "Payload values are expected to be continuous numbers: %d %d", + prevValue, curValue)); + } + } + counter.set(curValue); + } + + private static Mono newClientRSocket( + DisconnectableClientTransport clientTransport, int sessionDurationSeconds) { + return RSocketConnector.create() + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", 500_000)) + .cleanupStoreOnKeepAlive() + .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) + .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) + .connect(clientTransport); + } + + private static Mono newServerRSocket() { + return newServerRSocket(15); + } + + private static Mono newServerRSocket(int sessionDurationSeconds) { + return RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .cleanupStoreOnKeepAlive() + .storeFactory(t -> new InMemoryResumableFramesStore("server", 500_000))) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)); + } + + private static class TestResponderRSocket implements RSocket { + + AtomicInteger counter = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return duplicate( + Flux.interval(Duration.ofMillis(1)) + .onBackpressureLatest() + .publishOn(Schedulers.elastic()), + 20) + .map(v -> DefaultPayload.create(String.valueOf(counter.getAndIncrement()))) + .takeUntilOther(Flux.from(payloads).then()); + } + + private Flux duplicate(Flux f, int n) { + Flux r = Flux.empty(); + for (int i = 0; i < n; i++) { + r = r.mergeWith(f); + } + return r; + } + } +} diff --git a/rsocket-examples/src/test/resources/log4j.properties b/rsocket-examples/src/test/resources/log4j.properties deleted file mode 100644 index 51731fc15..000000000 --- a/rsocket-examples/src/test/resources/log4j.properties +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss,SSS} %5p [%t] (%F) - %m%n -#log4j.logger.io.rsocket.FrameLogger=Debug \ No newline at end of file diff --git a/rsocket-examples/src/test/resources/logback-test.xml b/rsocket-examples/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-examples/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index a2c8b73c7..748f95de6 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -34,6 +34,7 @@ dependencies { testCompileOnly 'junit:junit' testImplementation 'org.hamcrest:hamcrest-library' testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'ch.qos.logback:logback-classic' } description = 'Transparent Load Balancer for RSocket' diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index 7bea75318..65ce80934 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -279,7 +279,6 @@ private synchronized void addSockets(int numberOfNewSocket) { if (optional.isPresent()) { RSocketSupplier supplier = optional.get(); WeightedSocket socket = new WeightedSocket(supplier, lowerQuantile, higherQuantile); - activeSockets.add(socket); } else { break; } @@ -356,7 +355,8 @@ private synchronized void quickSlowestRS() { } if (slowest != null) { - activeSockets.remove(slowest); + logger.debug("Disposing slowest WeightedSocket {}", slowest); + slowest.dispose(); } } @@ -374,10 +374,11 @@ public synchronized double availability() { } private synchronized RSocket select() { + refreshSockets(); + if (activeSockets.isEmpty()) { return FAILING_REACTIVE_SOCKET; } - refreshSockets(); int size = activeSockets.size(); if (size == 1) { @@ -460,7 +461,7 @@ public synchronized String toString() { @Override public void dispose() { - synchronized (this) {; + synchronized (this) { activeSockets.forEach(WeightedSocket::dispose); activeSockets.clear(); onClose.onComplete(); @@ -535,7 +536,7 @@ public Mono onClose() { * Wrapper of a RSocket, it computes statistics about the req/resp calls and update availability * accordingly. */ - private class WeightedSocket extends AbstractRSocket implements LoadBalancerSocketMetrics { + private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; private final Quantile lowerQuantile; @@ -553,6 +554,7 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock private AtomicLong pendingStreams; // number of active streams private volatile double availability = 0.0; + private final MonoProcessor onClose = MonoProcessor.create(); WeightedSocket( RSocketSupplier factory, @@ -572,13 +574,16 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); this.pendingStreams = new AtomicLong(); + logger.debug("Creating WeightedSocket {} from factory {}", WeightedSocket.this, factory); + WeightedSocket.this .onClose() .doFinally( s -> { pool.accept(factory); activeSockets.remove(WeightedSocket.this); - refreshSockets(); + logger.debug( + "Removed {} from factory {} from activeSockets", WeightedSocket.this, factory); }) .subscribe(); @@ -587,7 +592,11 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock .retryBackoff(weightedSocketRetries, weightedSocketBackOff, weightedSocketMaxBackOff) .doOnError( throwable -> { - logger.error("error while connecting {}", throwable); + logger.error( + "error while connecting {} from factory {}", + WeightedSocket.this, + factory, + throwable); WeightedSocket.this.dispose(); }) .subscribe( @@ -597,7 +606,8 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock .onClose() .doFinally( signalType -> { - System.out.println("RSocket closed"); + logger.info( + "RSocket {} from factory {} closed", WeightedSocket.this, factory); WeightedSocket.this.dispose(); }) .subscribe(); @@ -607,7 +617,7 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock .onClose() .doFinally( signalType -> { - System.out.println("Factory closed"); + logger.info("Factory {} closed", factory); rSocket.dispose(); }) .subscribe(); @@ -617,20 +627,30 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock .onClose() .doFinally( signalType -> { - System.out.println("WeightedSocket closed"); + logger.info( + "WeightedSocket {} from factory {} closed", + WeightedSocket.this, + factory); rSocket.dispose(); }) .subscribe(); - synchronized (LoadBalancedRSocketMono.this) { + /*synchronized (LoadBalancedRSocketMono.this) { if (activeSockets.size() >= targetAperture) { quickSlowestRS(); pendingSockets -= 1; } - } - + }*/ rSocketMono.onNext(rSocket); availability = 1.0; + if (!WeightedSocket.this + .isDisposed()) { // May be already disposed because of retryBackoff delay + activeSockets.add(WeightedSocket.this); + logger.debug( + "Added WeightedSocket {} from factory {} to activeSockets", + WeightedSocket.this, + factory); + } }); } @@ -772,6 +792,21 @@ public double availability() { return availability; } + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + @Override public String toString() { return "WeightedSocket(" @@ -828,11 +863,11 @@ public long lastTimeUsedMillis() { */ private class LatencySubscriber implements Subscriber { private final Subscriber child; - private final LoadBalancedRSocketMono.WeightedSocket socket; + private final WeightedSocket socket; private final AtomicBoolean done; private long start; - LatencySubscriber(Subscriber child, LoadBalancedRSocketMono.WeightedSocket socket) { + LatencySubscriber(Subscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; this.done = new AtomicBoolean(false); @@ -892,9 +927,9 @@ public void onComplete() { */ private class CountingSubscriber implements Subscriber { private final Subscriber child; - private final LoadBalancedRSocketMono.WeightedSocket socket; + private final WeightedSocket socket; - CountingSubscriber(Subscriber child, LoadBalancedRSocketMono.WeightedSocket socket) { + CountingSubscriber(Subscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; } @@ -915,8 +950,8 @@ public void onError(Throwable t) { socket.pendingStreams.decrementAndGet(); child.onError(t); if (t instanceof TransportException || t instanceof ClosedChannelException) { - activeSockets.remove(socket); - refreshSockets(); + logger.debug("Disposing {} from activeSockets because of error {}", socket, t); + socket.dispose(); } } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java index 35615d2a2..1683ee125 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java @@ -84,6 +84,9 @@ private synchronized void handleNewFactories(Collection newFact } factoryPool.addAll(added); + if (!added.isEmpty()) { + changed = true; + } if (changed && logger.isDebugEnabled()) { StringBuilder msgBuilder = new StringBuilder(); @@ -104,8 +107,10 @@ private synchronized void handleNewFactories(Collection newFact @Override public synchronized void accept(RSocketSupplier rSocketSupplier) { - leasedSuppliers.remove(rSocketSupplier); - if (!rSocketSupplier.isDisposed()) { + boolean contained = leasedSuppliers.remove(rSocketSupplier); + if (contained + && !rSocketSupplier + .isDisposed()) { // only added leasedSupplier back to factoryPool if it's still there factoryPool.add(rSocketSupplier); } } @@ -119,6 +124,7 @@ public synchronized Optional get() { if (rSocketSupplier.availability() > 0.0) { factoryPool.remove(0); leasedSuppliers.add(rSocketSupplier); + logger.debug("Added {} to leasedSuppliers", rSocketSupplier); optional = Optional.of(rSocketSupplier); } } else if (poolSize > 1) { @@ -143,10 +149,12 @@ public synchronized Optional get() { if (factory0.availability() > factory1.availability()) { factoryPool.remove(i0); leasedSuppliers.add(factory0); + logger.debug("Added {} to leasedSuppliers", factory0); optional = Optional.of(factory0); } else { factoryPool.remove(i1); leasedSuppliers.add(factory1); + logger.debug("Added {} to leasedSuppliers", factory1); optional = Optional.of(factory1); } } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java new file mode 100644 index 000000000..55ce5646c --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +package io.rsocket.client.filter; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java new file mode 100644 index 000000000..ec21dee96 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +package io.rsocket.client; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java new file mode 100644 index 000000000..cfb071175 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +package io.rsocket.stat; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java index b529e426c..4baa106c5 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -19,19 +19,15 @@ import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSocketSupplier; -import io.rsocket.util.EmptyPayload; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CompletableFuture; import java.util.function.Function; import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -39,11 +35,8 @@ public class LoadBalancedRSocketMonoTest { @Test(timeout = 10_000L) public void testNeverSelectFailingFactories() throws InterruptedException { - InetSocketAddress local0 = InetSocketAddress.createUnresolved("localhost", 7000); - InetSocketAddress local1 = InetSocketAddress.createUnresolved("localhost", 7001); - TestingRSocket socket = new TestingRSocket(Function.identity()); - RSocketSupplier failing = failingClient(local0); + RSocketSupplier failing = failingClient(); RSocketSupplier succeeding = succeedingFactory(socket); List factories = Arrays.asList(failing, succeeding); @@ -52,9 +45,6 @@ public void testNeverSelectFailingFactories() throws InterruptedException { @Test(timeout = 10_000L) public void testNeverSelectFailingSocket() throws InterruptedException { - InetSocketAddress local0 = InetSocketAddress.createUnresolved("localhost", 7000); - InetSocketAddress local1 = InetSocketAddress.createUnresolved("localhost", 7001); - TestingRSocket socket = new TestingRSocket(Function.identity()); TestingRSocket failingSocket = new TestingRSocket(Function.identity()) { @@ -76,6 +66,33 @@ public double availability() { testBalancer(clients); } + @Test(timeout = 10_000L) + public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDelivered() { + TestingRSocket socket = new TestingRSocket(Function.identity()); + + CompletableFuture laterSupplier = new CompletableFuture<>(); + Flux> factories = + Flux.create( + s -> { + s.next(Collections.emptyList()); + + laterSupplier.handle( + (RSocketSupplier result, Throwable t) -> { + s.next(Collections.singletonList(result)); + return null; + }); + }); + + LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(factories); + + Assert.assertEquals(0.0, balancer.availability(), 0); + + laterSupplier.complete(succeedingFactory(socket)); + balancer.rSocketMono.block(); + + Assert.assertEquals(1.0, balancer.availability(), 0); + } + private void testBalancer(List factories) throws InterruptedException { Publisher> src = s -> { @@ -92,39 +109,6 @@ private void testBalancer(List factories) throws InterruptedExc Flux.range(0, 100).flatMap(i -> balancer).blockLast(); } - private void makeAcall(RSocket balancer) throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - - balancer - .requestResponse(EmptyPayload.INSTANCE) - .subscribe( - new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(1L); - } - - @Override - public void onNext(Payload payload) { - System.out.println("Successfully receiving a response"); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - Assert.assertTrue(false); - latch.countDown(); - } - - @Override - public void onComplete() { - latch.countDown(); - } - }); - - latch.await(); - } - private static RSocketSupplier succeedingFactory(RSocket socket) { RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); @@ -135,7 +119,7 @@ private static RSocketSupplier succeedingFactory(RSocket socket) { return mock; } - private static RSocketSupplier failingClient(SocketAddress sa) { + private static RSocketSupplier failingClient() { RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); Mockito.when(mock.availability()).thenReturn(0.0); diff --git a/rsocket-load-balancer/src/test/resources/log4j.properties b/rsocket-load-balancer/src/test/resources/log4j.properties deleted file mode 100644 index 8fc3a9cdd..000000000 --- a/rsocket-load-balancer/src/test/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-load-balancer/src/test/resources/logback-test.xml b/rsocket-load-balancer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-load-balancer/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle index 5f2aeb16f..4be616623 100644 --- a/rsocket-micrometer/build.gradle +++ b/rsocket-micrometer/build.gradle @@ -27,8 +27,6 @@ dependencies { implementation 'org.slf4j:slf4j-api' - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java index 20d58dcb7..c8b22382a 100644 --- a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -20,8 +20,9 @@ import io.micrometer.core.instrument.*; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; import java.util.Objects; @@ -82,6 +83,11 @@ final class MicrometerDuplexConnection implements DuplexConnection { this.frameCounters = new FrameCounters(connectionType, meterRegistry, tags); } + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + @Override public void dispose() { delegate.dispose(); @@ -185,7 +191,7 @@ private static Counter counter( @Override public void accept(ByteBuf frame) { - FrameType frameType = FrameHeaderFlyweight.frameType(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); switch (frameType) { case SETUP: diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle index 3009b5135..5ec1a8061 100644 --- a/rsocket-test/build.gradle +++ b/rsocket-test/build.gradle @@ -26,8 +26,6 @@ dependencies { api 'org.hdrhistogram:HdrHistogram' api 'org.junit.jupiter:junit-jupiter-api' - compileOnly 'com.google.code.findbugs:jsr305' - implementation 'io.projectreactor:reactor-test' implementation 'org.assertj:assertj-core' implementation 'org.mockito:mockito-core' diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index ec143b7ab..6f562875f 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,8 @@ import io.rsocket.Closeable; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.function.BiFunction; @@ -47,17 +48,13 @@ public ClientSetupRule( this.serverInit = address -> - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, rsocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); this.clientConnector = (address, server) -> - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java new file mode 100644 index 000000000..3830ec1bc --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java @@ -0,0 +1,17 @@ +package io.rsocket.test; + +import java.lang.annotation.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +/** + * {@code @PerfTest} is used to signal that the annotated test class or method is performance test, + * and is disabled unless enabled via setting the {@code TEST_PERF_ENABLED} environment variable to + * {@code true}. + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EnabledIfEnvironmentVariable(named = "TEST_PERF_ENABLED", matches = "(?i)true") +@Test +public @interface PerfTest {} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java index d1ff04c79..9017e854b 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -20,7 +20,9 @@ import io.rsocket.RSocket; import io.rsocket.util.ByteBufPayload; import java.time.Duration; +import java.util.function.BiFunction; import org.HdrHistogram.Recorder; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -49,7 +51,18 @@ public Recorder startTracker(Duration interval) { return histogram; } - public Flux startPingPong(int count, final Recorder histogram) { + public Flux requestResponsePingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestResponse, count, histogram); + } + + public Flux requestStreamPingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestStream, count, histogram); + } + + Flux pingPong( + BiFunction> interaction, + int count, + final Recorder histogram) { return client .flatMapMany( rsocket -> @@ -57,8 +70,7 @@ public Flux startPingPong(int count, final Recorder histogram) { .flatMap( i -> { long start = System.nanoTime(); - return rsocket - .requestResponse(payload.retain()) + return Flux.from(interaction.apply(rsocket, payload.retain())) .doOnNext(Payload::release) .doFinally( signalType -> { diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java index 2f54ddb50..47f40a59d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java @@ -16,13 +16,13 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; import io.rsocket.util.ByteBufPayload; import java.util.concurrent.ThreadLocalRandom; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class PingHandler implements SocketAcceptor { @@ -42,12 +42,18 @@ public PingHandler(byte[] data) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { payload.release(); return Mono.just(pong.retain()); } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.range(0, 100).map(v -> pong.retain()); + } }); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java index 18c23057a..1e66abc5e 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java @@ -22,6 +22,7 @@ import io.rsocket.Payload; import io.rsocket.frame.*; import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; /** Test instances of all frame types. */ public final class TestFrames { @@ -32,80 +33,76 @@ private TestFrames() {} /** @return {@link ByteBuf} representing test instance of Cancel frame */ public static ByteBuf createTestCancelFrame() { - return CancelFrameFlyweight.encode(allocator, 1); + return CancelFrameCodec.encode(allocator, 1); } /** @return {@link ByteBuf} representing test instance of Error frame */ public static ByteBuf createTestErrorFrame() { - return ErrorFrameFlyweight.encode(allocator, 1, new RuntimeException()); + return ErrorFrameCodec.encode(allocator, 1, new RuntimeException()); } /** @return {@link ByteBuf} representing test instance of Extension frame */ public static ByteBuf createTestExtensionFrame() { - return ExtensionFrameFlyweight.encode( + return ExtensionFrameCodec.encode( allocator, 1, 1, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Keep-Alive frame */ public static ByteBuf createTestKeepaliveFrame() { - return KeepAliveFrameFlyweight.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); + return KeepAliveFrameCodec.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Lease frame */ public static ByteBuf createTestLeaseFrame() { - return LeaseFlyweight.encode(allocator, 1, 1, null); + return LeaseFrameCodec.encode(allocator, 1, 1, null); } /** @return {@link ByteBuf} representing test instance of Metadata-Push frame */ public static ByteBuf createTestMetadataPushFrame() { - return MetadataPushFrameFlyweight.encode(allocator, Unpooled.EMPTY_BUFFER); + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Payload frame */ public static ByteBuf createTestPayloadFrame() { - return PayloadFrameFlyweight.encode( - allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); + return PayloadFrameCodec.encode(allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Request-Channel frame */ public static ByteBuf createTestRequestChannelFrame() { - return RequestChannelFrameFlyweight.encode( + return RequestChannelFrameCodec.encode( allocator, 1, false, false, 1, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Fire-and-Forget frame */ public static ByteBuf createTestRequestFireAndForgetFrame() { - return RequestFireAndForgetFrameFlyweight.encode( - allocator, 1, false, null, Unpooled.EMPTY_BUFFER); + return RequestFireAndForgetFrameCodec.encode(allocator, 1, false, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Request-N frame */ public static ByteBuf createTestRequestNFrame() { - return RequestNFrameFlyweight.encode(allocator, 1, 1); + return RequestNFrameCodec.encode(allocator, 1, 1); } /** @return {@link ByteBuf} representing test instance of Request-Response frame */ public static ByteBuf createTestRequestResponseFrame() { - return RequestResponseFrameFlyweight.encode(allocator, 1, false, emptyPayload); + return RequestResponseFrameCodec.encodeReleasingPayload(allocator, 1, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Request-Stream frame */ public static ByteBuf createTestRequestStreamFrame() { - return RequestStreamFrameFlyweight.encode(allocator, 1, false, 1L, emptyPayload); + return RequestStreamFrameCodec.encodeReleasingPayload(allocator, 1, 1L, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Setup frame */ public static ByteBuf createTestSetupFrame() { - return SetupFrameFlyweight.encode( + return SetupFrameCodec.encode( allocator, false, - false, 1, 1, Unpooled.EMPTY_BUFFER, "metadataType", "dataType", - null, - Unpooled.EMPTY_BUFFER); + EmptyPayload.INSTANCE); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index 57a2e5c3c..d48700445 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -16,14 +16,14 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.util.DefaultPayload; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -public class TestRSocket extends AbstractRSocket { +public class TestRSocket implements RSocket { private final String data; private final String metadata; @@ -55,6 +55,6 @@ public Mono fireAndForget(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { // TODO is defensive copy neccesary? - return Flux.from(payloads).map(DefaultPayload::create); + return Flux.from(payloads).map(Payload::retain); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index fc6301d7d..fc059c7d1 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,28 +19,62 @@ import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.Disposable; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; public interface TransportTest { + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = DefaultPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @BeforeEach + default void setUp() { + Hooks.onOperatorDebug(); + } + @AfterEach default void close() { getTransportPair().dispose(); + Hooks.resetOnOperatorDebug(); } default Payload createTestPayload(int metadataPresent) { @@ -54,12 +88,12 @@ default Payload createTestPayload(int metadataPresent) { metadata1 = ""; break; default: - metadata1 = "metadata"; + metadata1 = MOCK_METADATA; break; } String metadata = metadata1; - return DefaultPayload.create("test-data", metadata); + return DefaultPayload.create(MOCK_DATA, metadata); } @DisplayName("makes 10 fireAndForget requests") @@ -73,6 +107,17 @@ default void fireAndForget10() { .verify(getTimeout()); } + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD)) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + default RSocket getClient() { return getTransportPair().getClient(); } @@ -92,6 +137,17 @@ default void metadataPush10() { .verify(getTimeout()); } + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 0 payloads") @Test default void requestChannel0() { @@ -127,6 +183,19 @@ default void requestChannel200_000() { .verify(getTimeout()); } + @DisplayName("makes 1 requestChannel request with 200 large payloads") + @Test + default void largePayloadRequestChannel200() { + Flux payloads = Flux.range(0, 200).map(__ -> LARGE_PAYLOAD); + + getClient() + .requestChannel(payloads) + .as(StepVerifier::create) + .expectNextCount(200) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 20,000 payloads") @Test default void requestChannel20_000() { @@ -157,14 +226,18 @@ default void requestChannel2_000_000() { @DisplayName("makes 1 requestChannel request with 3 payloads") @Test default void requestChannel3() { - Flux payloads = Flux.range(0, 3).map(this::createTestPayload); + AtomicLong requested = new AtomicLong(); + Flux payloads = + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); getClient() .requestChannel(payloads) - .as(StepVerifier::create) + .as(publisher -> StepVerifier.create(publisher, 3)) .expectNextCount(3) .expectComplete() .verify(getTimeout()); + + Assertions.assertThat(requested.get()).isEqualTo(3L); } @DisplayName("makes 1 requestChannel request with 512 payloads") @@ -223,6 +296,17 @@ default void requestResponse100() { .verify(getTimeout()); } + @DisplayName("makes 100 requestResponse requests") + @Test + default void largePayloadRequestResponse100() { + Flux.range(1, 100) + .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8)) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 10,000 requestResponse requests") @Test default void requestResponse10_000() { @@ -283,7 +367,7 @@ default void assertPayload(Payload p) { } default void assertChannelPayload(Payload p) { - if (!"test-data".equals(p.getDataUtf8()) || !"metadata".equals(p.getMetadataUtf8())) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { throw new IllegalStateException("Unexpected payload"); } } @@ -304,16 +388,12 @@ public TransportPair( T address = addressSupplier.get(); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); client = - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java b/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java deleted file mode 100644 index ad45e106a..000000000 --- a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; - -import io.rsocket.uri.UriHandler; -import java.net.URI; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -public interface UriHandlerTest { - - @DisplayName("returns empty Optional client with invalid URI") - @Test - default void buildClientInvalidUri() { - assertThat(getUriHandler().buildClient(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildClient throws NullPointerException with null uri") - @Test - default void buildClientNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildClient(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns client with value URI") - @Test - default void buildClientValidUri() { - assertThat(getUriHandler().buildClient(URI.create(getValidUri()))).isNotEmpty(); - } - - @DisplayName("returns empty Optional server with invalid URI") - @Test - default void buildServerInvalidUri() { - assertThat(getUriHandler().buildServer(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildServer throws NullPointerException with null uri") - @Test - default void buildServerNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildServer(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns server with value URI") - @Test - default void buildServerValidUri() { - assertThat(getUriHandler().buildServer(URI.create(getValidUri()))).isNotEmpty(); - } - - String getInvalidUri(); - - UriHandler getUriHandler(); - - String getValidUri(); -} diff --git a/rsocket-test/src/main/resources/words.shakespeare.txt.gz b/rsocket-test/src/main/resources/words.shakespeare.txt.gz new file mode 100644 index 000000000..422a4b331 Binary files /dev/null and b/rsocket-test/src/main/resources/words.shakespeare.txt.gz differ diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle index 8c3226065..a5ba84d5c 100644 --- a/rsocket-transport-local/build.gradle +++ b/rsocket-transport-local/build.gradle @@ -24,8 +24,6 @@ plugins { dependencies { api project(':rsocket-core') - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index 91e8c3e57..d69bd65e8 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -17,14 +17,17 @@ package io.rsocket.transport.local; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; +import io.rsocket.internal.UnboundedProcessor; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalServerTransport.ServerDuplexConnectionAcceptor; import java.util.Objects; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; /** * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} in the @@ -34,25 +37,42 @@ public final class LocalClientTransport implements ClientTransport { private final String name; - private LocalClientTransport(String name) { + private final ByteBufAllocator allocator; + + private LocalClientTransport(String name, ByteBufAllocator allocator) { this.name = name; + this.allocator = allocator; } /** * Creates a new instance. * - * @param name the name of the {@link ServerTransport} instance to connect to + * @param name the name of the {@link ClientTransport} instance to connect to * @return a new instance * @throws NullPointerException if {@code name} is {@code null} */ public static LocalClientTransport create(String name) { Objects.requireNonNull(name, "name must not be null"); - return new LocalClientTransport(name); + return create(name, ByteBufAllocator.DEFAULT); } - @Override - public Mono connect() { + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @param allocator the allocator used by {@link ClientTransport} instance + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name, ByteBufAllocator allocator) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(allocator, "allocator must not be null"); + + return new LocalClientTransport(name, allocator); + } + + private Mono connect() { return Mono.defer( () -> { ServerDuplexConnectionAcceptor server = LocalServerTransport.findServer(name); @@ -60,13 +80,29 @@ public Mono connect() { return Mono.error(new IllegalArgumentException("Could not find server: " + name)); } - UnicastProcessor in = UnicastProcessor.create(); - UnicastProcessor out = UnicastProcessor.create(); + UnboundedProcessor in = new UnboundedProcessor<>(); + UnboundedProcessor out = new UnboundedProcessor<>(); MonoProcessor closeNotifier = MonoProcessor.create(); - server.accept(new LocalDuplexConnection(out, in, closeNotifier)); + server.accept(new LocalDuplexConnection(allocator, out, in, closeNotifier)); - return Mono.just((DuplexConnection) new LocalDuplexConnection(in, out, closeNotifier)); + return Mono.just( + (DuplexConnection) new LocalDuplexConnection(allocator, in, out, closeNotifier)); + }); + } + + @Override + public Mono connect(int mtu) { + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + Mono connect = isError != null ? isError : connect(); + + return connect.map( + duplexConnection -> { + if (mtu > 0) { + return new FragmentationDuplexConnection(duplexConnection, mtu, false, "client"); + } else { + return new ReassemblyDuplexConnection(duplexConnection, false); + } }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index a295d2b00..afaa14f95 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.transport.local; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; @@ -28,6 +29,7 @@ /** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ final class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final Flux in; private final MonoProcessor onClose; @@ -42,7 +44,12 @@ final class LocalDuplexConnection implements DuplexConnection { * @param onClose the closing notifier * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} */ - LocalDuplexConnection(Flux in, Subscriber out, MonoProcessor onClose) { + LocalDuplexConnection( + ByteBufAllocator allocator, + Flux in, + Subscriber out, + MonoProcessor onClose) { + this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); this.in = Objects.requireNonNull(in, "in must not be null"); this.out = Objects.requireNonNull(out, "out must not be null"); this.onClose = Objects.requireNonNull(onClose, "onClose must not be null"); @@ -73,12 +80,18 @@ public Flux receive() { public Mono send(Publisher frames) { Objects.requireNonNull(frames, "frames must not be null"); - return Flux.from(frames) - .doOnNext( - byteBuf -> { - byteBuf.retain(); - out.onNext(byteBuf); - }) - .then(); + return Flux.from(frames).doOnNext(out::onNext).then(); + } + + @Override + public Mono sendOne(ByteBuf frame) { + Objects.requireNonNull(frame, "frame must not be null"); + out.onNext(frame); + return Mono.empty(); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index 68e7d462f..382b4533a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -18,6 +18,8 @@ import io.rsocket.Closeable; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; @@ -76,6 +78,20 @@ public static void dispose(String name) { registry.remove(name); } + /** + * Retrieves an instance of {@link ServerDuplexConnectionAcceptor} based on the name of its {@code + * LocalServerTransport}. Returns {@code null} if that server is not registered. + * + * @param name the name of the server to retrieve + * @return the server if it has been registered, {@code null} otherwise + * @throws NullPointerException if {@code name} is {@code null} + */ + static @Nullable ServerDuplexConnectionAcceptor findServer(String name) { + Objects.requireNonNull(name, "name must not be null"); + + return registry.get(name); + } + /** * Returns a new {@link LocalClientTransport} that is connected to this {@code * LocalServerTransport}. @@ -88,34 +104,23 @@ public LocalClientTransport clientTransport() { } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - return Mono.create( - sink -> { - ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = - new ServerDuplexConnectionAcceptor(name, acceptor); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : Mono.create( + sink -> { + ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = + new ServerDuplexConnectionAcceptor(name, acceptor, mtu); - if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { - throw new IllegalStateException("name already registered: " + name); - } + if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { + throw new IllegalStateException("name already registered: " + name); + } - sink.success(serverDuplexConnectionAcceptor); - }); - } - - /** - * Retrieves an instance of {@link ServerDuplexConnectionAcceptor} based on the name of its {@code - * LocalServerTransport}. Returns {@code null} if that server is not registered. - * - * @param name the name of the server to retrieve - * @return the server if it has been registered, {@code null} otherwise - * @throws NullPointerException if {@code name} is {@code null} - */ - static @Nullable ServerDuplexConnectionAcceptor findServer(String name) { - Objects.requireNonNull(name, "name must not be null"); - - return registry.get(name); + sink.success(serverDuplexConnectionAcceptor); + }); } /** @@ -138,6 +143,8 @@ static class ServerDuplexConnectionAcceptor implements Consumer onClose = MonoProcessor.create(); + private final int mtu; + /** * Creates a new instance * @@ -145,17 +152,25 @@ static class ServerDuplexConnectionAcceptor implements Consumer 0) { + duplexConnection = + new FragmentationDuplexConnection(duplexConnection, mtu, false, "server"); + } else { + duplexConnection = new ReassemblyDuplexConnection(duplexConnection, false); + } + acceptor.apply(duplexConnection).subscribe(); } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java deleted file mode 100644 index 89c816d7a..000000000 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link LocalClientTransport}s and {@link - * LocalServerTransport}s. - */ -public final class LocalUriHandler implements UriHandler { - - private static final String SCHEME = "local"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalClientTransport.create(uri.getSchemeSpecificPart())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalServerTransport.create(uri.getSchemeSpecificPart())); - } -} diff --git a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 6ff8ffb50..000000000 --- a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.local.LocalUriHandler diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java index 92478b0bd..4cfee9a01 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -32,8 +32,8 @@ void connect() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(closeable -> LocalClientTransport.create(serverTransport.getName()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -43,7 +43,7 @@ void connect() { @Test void connectNoServer() { LocalClientTransport.create("test-name") - .connect() + .connect(0) .as(StepVerifier::create) .verifyErrorMessage("Could not find server: test-name"); } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java index 58a287948..9228e2d05 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ package io.rsocket.transport.local; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.test.PingHandler; @@ -28,18 +29,15 @@ public final class LocalPingPong { public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(LocalServerTransport.create("test-local-server")) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("test-local-server")) .block(); Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("test-local-server")) - .start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("test-local-server")); PingClient pingClient = new PingClient(client); @@ -48,7 +46,7 @@ public static void main(String... args) { int count = 1_000_000_000; pingClient - .startPingPong(count, recorder) + .requestResponsePingPong(count, recorder) .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) .blockLast(); } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java index 7fb350432..1656ed08d 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -63,7 +63,7 @@ void findServer() { LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -97,7 +97,7 @@ void named() { @Test void start() { LocalServerTransport.createEphemeral() - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -107,7 +107,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null)) + .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java deleted file mode 100644 index f6b5cda7e..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class LocalUriTransportRegistryTest { - - @DisplayName("local URI returns LocalClientTransport") - @Test - void clientForUri() { - assertThat(UriTransportRegistry.clientForUri("local:test1")) - .isInstanceOf(LocalClientTransport.class); - } - - @DisplayName("non-local URI does not return LocalClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("http://localhost")) - .isNotInstanceOf(LocalClientTransport.class); - } - - @DisplayName("local URI returns LocalServerTransport") - @Test - void serverForUri() { - assertThat(UriTransportRegistry.serverForUri("local:test1")) - .isInstanceOf(LocalServerTransport.class); - } - - @DisplayName("non-local URI does not return LocalServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("http://localhost")) - .isNotInstanceOf(LocalServerTransport.class); - } -} diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle index 5865c3756..64e483c90 100644 --- a/rsocket-transport-netty/build.gradle +++ b/rsocket-transport-netty/build.gradle @@ -30,12 +30,13 @@ if (osdetector.classifier in ["linux-x86_64"] || ["osx-x86_64"] || ["windows-x86 dependencies { api project(':rsocket-core') api 'io.projectreactor.netty:reactor-netty' - - compileOnly 'com.google.code.findbugs:jsr305' + api 'org.slf4j:slf4j-api' testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.mockito:mockito-junit-jupiter' testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java index 68d7ab175..e2c134653 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java @@ -16,8 +16,8 @@ package io.rsocket.transport.netty; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_SIZE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java deleted file mode 100644 index b84201ac9..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/SendPublisher.java +++ /dev/null @@ -1,295 +0,0 @@ -package io.rsocket.transport.netty; - -import io.netty.buffer.ByteBuf; -import io.netty.channel.Channel; -import io.netty.channel.ChannelPromise; -import io.netty.channel.EventLoop; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.ReferenceCounted; -import java.util.Queue; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Function; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; -import reactor.util.concurrent.Queues; - -class SendPublisher extends Flux { - - private static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(SendPublisher.class, "wip"); - - private static final int MAX_SIZE = Queues.SMALL_BUFFER_SIZE; - private static final int REFILL_SIZE = MAX_SIZE / 2; - private static final AtomicReferenceFieldUpdater INNER_SUBSCRIBER = - AtomicReferenceFieldUpdater.newUpdater(SendPublisher.class, Object.class, "innerSubscriber"); - private static final AtomicIntegerFieldUpdater TERMINATED = - AtomicIntegerFieldUpdater.newUpdater(SendPublisher.class, "terminated"); - private final Publisher source; - private final Channel channel; - private final EventLoop eventLoop; - - private final Queue queue; - private final AtomicBoolean completed = new AtomicBoolean(); - private final Function transformer; - private final SizeOf sizeOf; - - @SuppressWarnings("unused") - private volatile int terminated; - - private int pending; - - @SuppressWarnings("unused") - private volatile int wip; - - @SuppressWarnings("unused") - private volatile Object innerSubscriber; - - private long requested; - - private long requestedUpstream = MAX_SIZE; - - private boolean fuse; - - @SuppressWarnings("unchecked") - SendPublisher( - Publisher source, - Channel channel, - Function transformer, - SizeOf sizeOf) { - this(Queues.small().get(), source, channel, transformer, sizeOf); - } - - @SuppressWarnings("unchecked") - SendPublisher( - Queue queue, - Publisher source, - Channel channel, - Function transformer, - SizeOf sizeOf) { - this.source = source; - this.channel = channel; - this.queue = queue; - this.eventLoop = channel.eventLoop(); - this.transformer = transformer; - this.sizeOf = sizeOf; - - fuse = queue instanceof Fuseable.QueueSubscription; - } - - private ChannelPromise writeCleanupPromise(V poll) { - return channel - .newPromise() - .addListener( - future -> { - try { - if (requested != Long.MAX_VALUE) { - requested--; - } - requestedUpstream--; - pending--; - - InnerSubscriber is = (InnerSubscriber) INNER_SUBSCRIBER.get(SendPublisher.this); - if (is != null) { - is.tryRequestMoreUpstream(); - tryComplete(is); - } - } finally { - if (poll.refCnt() > 0) { - ReferenceCountUtil.safeRelease(poll); - } - } - }); - } - - private void tryComplete(InnerSubscriber is) { - if (pending == 0 - && completed.get() - && queue.isEmpty() - && terminated == 0 - && !is.pendingFlush.get()) { - TERMINATED.set(SendPublisher.this, 1); - is.destination.onComplete(); - } - } - - @Override - public void subscribe(CoreSubscriber destination) { - InnerSubscriber innerSubscriber = new InnerSubscriber(destination); - if (!INNER_SUBSCRIBER.compareAndSet(this, null, innerSubscriber)) { - Operators.error( - destination, new IllegalStateException("SendPublisher only allows one subscription")); - } else { - InnerSubscription innerSubscription = new InnerSubscription(innerSubscriber); - destination.onSubscribe(innerSubscription); - source.subscribe(innerSubscriber); - } - } - - @FunctionalInterface - interface SizeOf { - int size(V v); - } - - private class InnerSubscriber implements Subscriber { - final CoreSubscriber destination; - volatile Subscription s; - private AtomicBoolean pendingFlush = new AtomicBoolean(); - - private InnerSubscriber(CoreSubscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - this.s = s; - s.request(MAX_SIZE); - tryDrain(); - } - - @Override - public void onNext(ByteBuf t) { - if (terminated == 0) { - if (!fuse && !queue.offer(t)) { - throw new IllegalStateException("missing back pressure"); - } - tryDrain(); - } - } - - @Override - public void onError(Throwable t) { - if (TERMINATED.compareAndSet(SendPublisher.this, 0, 1)) { - try { - s.cancel(); - destination.onError(t); - } finally { - if (!queue.isEmpty()) { - queue.forEach(ReferenceCountUtil::safeRelease); - } - } - } - } - - @Override - public void onComplete() { - if (completed.compareAndSet(false, true)) { - tryDrain(); - } - } - - private void tryRequestMoreUpstream() { - if (requestedUpstream <= REFILL_SIZE && s != null) { - long u = MAX_SIZE - requestedUpstream; - requestedUpstream = Operators.addCap(requestedUpstream, u); - s.request(u); - } - } - - private void flush() { - try { - channel.flush(); - pendingFlush.set(false); - tryComplete(this); - } catch (Throwable t) { - onError(t); - } - } - - private void tryDrain() { - if (wip == 0 && terminated == 0 && WIP.getAndIncrement(SendPublisher.this) == 0) { - try { - if (eventLoop.inEventLoop()) { - drain(); - } else { - eventLoop.execute(this::drain); - } - } catch (Throwable t) { - onError(t); - } - } - } - - private void drain() { - try { - boolean scheduleFlush; - int missed = 1; - for (; ; ) { - scheduleFlush = false; - - long r = Math.min(requested, requestedUpstream); - while (r-- > 0) { - ByteBuf ByteBuf = queue.poll(); - if (ByteBuf != null && terminated == 0) { - V poll = transformer.apply(ByteBuf); - int readableBytes = sizeOf.size(poll); - pending++; - if (channel.isWritable() && readableBytes <= channel.bytesBeforeUnwritable()) { - channel.write(poll, writeCleanupPromise(poll)); - scheduleFlush = true; - } else { - scheduleFlush = false; - channel.writeAndFlush(poll, writeCleanupPromise(poll)); - } - - tryRequestMoreUpstream(); - } else { - break; - } - } - - if (scheduleFlush) { - pendingFlush.set(true); - eventLoop.execute(this::flush); - } - - if (terminated == 1) { - break; - } - - missed = WIP.addAndGet(SendPublisher.this, -missed); - if (missed == 0) { - break; - } - } - } catch (Throwable t) { - onError(t); - } - } - } - - private class InnerSubscription implements Subscription { - private final InnerSubscriber innerSubscriber; - - private InnerSubscription(InnerSubscriber innerSubscriber) { - this.innerSubscriber = innerSubscriber; - } - - @Override - public void request(long n) { - if (eventLoop.inEventLoop()) { - requested = Operators.addCap(n, requested); - innerSubscriber.tryDrain(); - } else { - eventLoop.execute(() -> request(n)); - } - } - - @Override - public void cancel() { - TERMINATED.set(SendPublisher.this, 1); - while (!queue.isEmpty()) { - ByteBuf poll = queue.poll(); - if (poll != null) { - ReferenceCountUtil.safeRelease(poll); - } - } - } - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index 57e3ff0a9..b7081593c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -19,102 +19,86 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.internal.BaseDuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; -import reactor.core.Disposable; -import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; -import reactor.netty.FutureMono; /** An implementation of {@link DuplexConnection} that connects via TCP. */ -public final class TcpDuplexConnection implements DuplexConnection { +public final class TcpDuplexConnection extends BaseDuplexConnection { private final Connection connection; - private final Disposable channelClosed; - private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private final boolean encodeLength; + /** * Creates a new instance * - * @param connection the {@link Connection} to for managing the server + * @param connection the {@link Connection} for managing the server */ public TcpDuplexConnection(Connection connection) { - this.connection = Objects.requireNonNull(connection, "connection must not be null"); - this.channelClosed = - FutureMono.from(connection.channel().closeFuture()) - .doFinally( - s -> { - if (!isDisposed()) { - dispose(); - } - }) - .subscribe(); + this(connection, true); } - @Override - public void dispose() { - connection.dispose(); + /** + * Creates a new instance + * + * @param encodeLength indicates if this connection should encode the length or not. + * @param connection the {@link Connection} to for managing the server + */ + public TcpDuplexConnection(Connection connection, boolean encodeLength) { + this.encodeLength = encodeLength; + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + + connection + .channel() + .closeFuture() + .addListener( + future -> { + if (!isDisposed()) dispose(); + }); } @Override - public boolean isDisposed() { - return connection.isDisposed(); + public ByteBufAllocator alloc() { + return connection.channel().alloc(); } @Override - public Mono onClose() { - return connection - .onDispose() - .doFinally( - s -> { - if (!channelClosed.isDisposed()) { - channelClosed.dispose(); - } - }); + protected void doOnClose() { + if (!connection.isDisposed()) { + connection.dispose(); + } } @Override public Flux receive() { - return connection - .inbound() - .receive() - .map( - byteBuf -> { - ByteBuf frame = FrameLengthFlyweight.frame(byteBuf); - frame.retain(); - return frame; - }); + return connection.inbound().receive().map(this::decode); } @Override public Mono send(Publisher frames) { - return Flux.from(frames) - .transform( - frameFlux -> { - if (frameFlux instanceof Fuseable.QueueSubscription) { - Fuseable.QueueSubscription queueSubscription = - (Fuseable.QueueSubscription) frameFlux; - queueSubscription.requestFusion(Fuseable.ASYNC); - return new SendPublisher<>( - queueSubscription, - frameFlux, - connection.channel(), - frame -> - FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame) - .retain(), - ByteBuf::readableBytes); - } else { - return new SendPublisher<>( - frameFlux, - connection.channel(), - frame -> - FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame) - .retain(), - ByteBuf::readableBytes); - } - }) - .then(); + if (frames instanceof Mono) { + return connection.outbound().sendObject(((Mono) frames).map(this::encode)).then(); + } + return connection.outbound().send(Flux.from(frames).map(this::encode)).then(); + } + + private ByteBuf encode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); + } else { + return frame; + } + } + + private ByteBuf decode(ByteBuf frame) { + if (encodeLength) { + return FrameLengthCodec.frame(frame).retain(); + } else { + return frame; + } } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java deleted file mode 100644 index d4ebd57b7..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.netty.tcp.TcpServer; - -/** - * An implementation of {@link UriHandler} that creates {@link TcpClientTransport}s and {@link - * TcpServerTransport}s. - */ -public final class TcpUriHandler implements UriHandler { - - private static final String SCHEME = "tcp"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(TcpClientTransport.create(uri.getHost(), uri.getPort())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of( - TcpServerTransport.create(TcpServer.create().host(uri.getHost()).port(uri.getPort()))); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index aa94aa0bb..0183ef19d 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -16,17 +16,15 @@ package io.rsocket.transport.netty; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; +import io.rsocket.internal.BaseDuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; -import reactor.core.Disposable; -import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.netty.Connection; -import reactor.netty.FutureMono; -import reactor.util.concurrent.Queues; /** * An implementation of {@link DuplexConnection} that connects via a Websocket. @@ -35,10 +33,9 @@ * for message oriented transports so this must be specifically dropped from Frames sent and * stitched back on for frames received. */ -public final class WebsocketDuplexConnection implements DuplexConnection { +public final class WebsocketDuplexConnection extends BaseDuplexConnection { private final Connection connection; - private final Disposable channelClosed; /** * Creates a new instance @@ -47,37 +44,26 @@ public final class WebsocketDuplexConnection implements DuplexConnection { */ public WebsocketDuplexConnection(Connection connection) { this.connection = Objects.requireNonNull(connection, "connection must not be null"); - this.channelClosed = - FutureMono.from(connection.channel().closeFuture()) - .doFinally( - s -> { - if (!isDisposed()) { - dispose(); - } - }) - .subscribe(); - } - @Override - public void dispose() { - connection.dispose(); + connection + .channel() + .closeFuture() + .addListener( + future -> { + if (!isDisposed()) dispose(); + }); } @Override - public boolean isDisposed() { - return connection.isDisposed(); + public ByteBufAllocator alloc() { + return connection.channel().alloc(); } @Override - public Mono onClose() { - return connection - .onDispose() - .doFinally( - s -> { - if (!channelClosed.isDisposed()) { - channelClosed.dispose(); - } - }); + protected void doOnClose() { + if (!connection.isDisposed()) { + connection.dispose(); + } } @Override @@ -87,32 +73,15 @@ public Flux receive() { @Override public Mono send(Publisher frames) { - return Flux.from(frames) - .transform( - frameFlux -> { - if (frameFlux instanceof Fuseable.QueueSubscription) { - Fuseable.QueueSubscription queueSubscription = - (Fuseable.QueueSubscription) frameFlux; - queueSubscription.requestFusion(Fuseable.ASYNC); - return new SendPublisher<>( - queueSubscription, - frameFlux, - connection.channel(), - this::toBinaryWebSocketFrame, - binaryWebSocketFrame -> binaryWebSocketFrame.content().readableBytes()); - } else { - return new SendPublisher<>( - Queues.small().get(), - frameFlux, - connection.channel(), - this::toBinaryWebSocketFrame, - binaryWebSocketFrame -> binaryWebSocketFrame.content().readableBytes()); - } - }) + if (frames instanceof Mono) { + return connection + .outbound() + .sendObject(((Mono) frames).map(BinaryWebSocketFrame::new)) + .then(); + } + return connection + .outbound() + .sendObject(Flux.from(frames).map(BinaryWebSocketFrame::new)) .then(); } - - private BinaryWebSocketFrame toBinaryWebSocketFrame(ByteBuf frame) { - return new BinaryWebSocketFrame(frame.retain()); - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java deleted file mode 100644 index 6438c4e28..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static io.rsocket.transport.netty.UriUtils.getPort; -import static io.rsocket.transport.netty.UriUtils.isSecure; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link WebsocketClientTransport}s and {@link - * WebsocketServerTransport}s. - */ -public final class WebsocketUriHandler implements UriHandler { - - private static final List SCHEME = Arrays.asList("ws", "wss", "http", "https"); - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - return Optional.of(WebsocketClientTransport.create(uri)); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - int port = isSecure(uri) ? getPort(uri, 443) : getPort(uri, 80); - - return Optional.of(WebsocketServerTransport.create(uri.getHost(), port)); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index 291494f3b..22f139310 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -17,6 +17,8 @@ package io.rsocket.transport.netty.client; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -73,7 +75,7 @@ public static TcpClientTransport create(String bindAddress, int port) { public static TcpClientTransport create(InetSocketAddress address) { Objects.requireNonNull(address, "address must not be null"); - TcpClient tcpClient = TcpClient.create().addressSupplier(() -> address); + TcpClient tcpClient = TcpClient.create().remoteAddress(() -> address); return create(tcpClient); } @@ -91,10 +93,21 @@ public static TcpClientTransport create(TcpClient client) { } @Override - public Mono connect() { - return client - .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec())) - .connect() - .map(TcpDuplexConnection::new); + public Mono connect(int mtu) { + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : client + .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec())) + .connect() + .map( + c -> { + if (mtu > 0) { + return new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), mtu, true, "client"); + } else { + return new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); + } + }); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 111a37e98..747401210 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -16,10 +16,13 @@ package io.rsocket.transport.netty.client; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static io.rsocket.transport.netty.UriUtils.getPort; import static io.rsocket.transport.netty.UriUtils.isSecure; import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -32,6 +35,7 @@ import java.util.function.Supplier; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.WebsocketClientSpec; import reactor.netty.tcp.TcpClient; /** @@ -40,9 +44,11 @@ */ public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware { + private static final String DEFAULT_PATH = "/"; + private final HttpClient client; - private String path; + private final String path; private Supplier> transportHeaders = Collections::emptyMap; @@ -87,7 +93,7 @@ public static WebsocketClientTransport create(String bindAddress, int port) { public static WebsocketClientTransport create(InetSocketAddress address) { Objects.requireNonNull(address, "address must not be null"); - TcpClient client = TcpClient.create().addressSupplier(() -> address); + TcpClient client = TcpClient.create().remoteAddress(() -> address); return create(client); } @@ -115,7 +121,7 @@ public static WebsocketClientTransport create(URI uri) { public static WebsocketClientTransport create(TcpClient client) { Objects.requireNonNull(client, "client must not be null"); - return create(HttpClient.from(client), "/"); + return create(HttpClient.from(client), DEFAULT_PATH); } /** @@ -130,17 +136,41 @@ public static WebsocketClientTransport create(HttpClient client, String path) { Objects.requireNonNull(client, "client must not be null"); Objects.requireNonNull(path, "path must not be null"); + path = path.startsWith(DEFAULT_PATH) ? path : (DEFAULT_PATH + path); + return new WebsocketClientTransport(client, path); } + private static TcpClient createClient(URI uri) { + if (isSecure(uri)) { + return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443)); + } else { + return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80)); + } + } + @Override - public Mono connect() { - return client - .headers(headers -> transportHeaders.get().forEach(headers::set)) - .websocket() - .uri(path) - .connect() - .map(WebsocketDuplexConnection::new); + public Mono connect(int mtu) { + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : client + .headers(headers -> transportHeaders.get().forEach(headers::set)) + .websocket( + WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()) + .uri(path) + .connect() + .map( + c -> { + DuplexConnection connection = new WebsocketDuplexConnection(c); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection(connection, mtu, false, "client"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); + } + return connection; + }); } @Override @@ -148,12 +178,4 @@ public void setTransportHeaders(Supplier> transportHeaders) this.transportHeaders = Objects.requireNonNull(transportHeaders, "transportHeaders must not be null"); } - - private static TcpClient createClient(URI uri) { - if (isSecure(uri)) { - return TcpClient.create().secure().host(uri.getHost()).port(getPort(uri, 443)); - } else { - return TcpClient.create().host(uri.getHost()).port(getPort(uri, 80)); - } - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java new file mode 100644 index 000000000..26fb44535 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -0,0 +1,40 @@ +package io.rsocket.transport.netty.server; + +import static io.netty.channel.ChannelHandler.*; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.transport.ServerTransport; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.netty.http.server.HttpServer; + +abstract class BaseWebsocketServerTransport implements ServerTransport { + private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketServerTransport.class); + private static final ChannelHandler pongHandler = new PongHandler(); + + static Function serverConfigurer = + server -> + server.tcpConfiguration( + tcpServer -> + tcpServer.doOnConnection(connection -> connection.addHandlerLast(pongHandler))); + + @Sharable + private static class PongHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof PongWebSocketFrame) { + logger.debug("received WebSocket Pong Frame"); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + ctx.fireChannelRead(msg); + } + } + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java index f6e83bc36..c0340c7a2 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -28,7 +28,7 @@ */ public final class CloseableChannel implements Closeable { - private DisposableChannel channel; + private final DisposableChannel channel; /** * Creates a new instance diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index 6965499a8..56dd59d45 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -16,6 +16,9 @@ package io.rsocket.transport.netty.server; +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -89,17 +92,29 @@ public static TcpServerTransport create(TcpServer server) { } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - - return server - .doOnConnection( - c -> { - c.addHandlerLast(new RSocketLengthCodec()); - TcpDuplexConnection connection = new TcpDuplexConnection(c); - acceptor.apply(connection).then(Mono.never()).subscribe(c.disposeSubscriber()); - }) - .bind() - .map(CloseableChannel::new); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : server + .doOnConnection( + c -> { + c.addHandlerLast(new RSocketLengthCodec()); + DuplexConnection connection; + if (mtu > 0) { + connection = + new FragmentationDuplexConnection( + new TcpDuplexConnection(c, false), mtu, true, "server"); + } else { + connection = new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); + } + acceptor + .apply(connection) + .then(Mono.never()) + .subscribe(c.disposeSubscriber()); + }) + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index b9bb43e6e..bd19f18b0 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -16,7 +16,12 @@ package io.rsocket.transport.netty.server; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.util.Objects; @@ -27,6 +32,7 @@ import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServerRoutes; +import reactor.netty.http.server.WebsocketServerSpec; import reactor.netty.http.websocket.WebsocketInbound; import reactor.netty.http.websocket.WebsocketOutbound; @@ -34,7 +40,7 @@ * An implementation of {@link ServerTransport} that connects via Websocket and listens on specified * routes. */ -public final class WebsocketRouteTransport implements ServerTransport { +public final class WebsocketRouteTransport extends BaseWebsocketServerTransport { private final String path; @@ -51,21 +57,23 @@ public final class WebsocketRouteTransport implements ServerTransport */ public WebsocketRouteTransport( HttpServer server, Consumer routesBuilder, String path) { - - this.server = Objects.requireNonNull(server, "server must not be null"); + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); this.path = Objects.requireNonNull(path, "path must not be null"); } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); return server .route( routes -> { routesBuilder.accept(routes); - routes.ws(path, newHandler(acceptor)); + routes.ws( + path, + newHandler(acceptor, mtu), + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()); }) .bind() .map(CloseableChannel::new); @@ -78,13 +86,28 @@ public Mono start(ConnectionAcceptor acceptor) { * @return a new Websocket handler * @throws NullPointerException if {@code acceptor} is {@code null} */ - static BiFunction> newHandler( + public static BiFunction> newHandler( ConnectionAcceptor acceptor) { + return newHandler(acceptor, 0); + } - Objects.requireNonNull(acceptor, "acceptor must not be null"); - + /** + * Creates a new Websocket handler + * + * @param acceptor the {@link ConnectionAcceptor} to use with the handler + * @param mtu the fragment size + * @return a new Websocket handler + * @throws NullPointerException if {@code acceptor} is {@code null} + */ + public static BiFunction> newHandler( + ConnectionAcceptor acceptor, int mtu) { return (in, out) -> { - WebsocketDuplexConnection connection = new WebsocketDuplexConnection((Connection) in); + DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); + if (mtu > 0) { + connection = new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); + } return acceptor.apply(connection).then(out.neverComplete()); }; } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index b6ef5eaea..1a0b32cf0 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -16,30 +16,41 @@ package io.rsocket.transport.netty.server; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; +import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; +import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.WebsocketServerSpec; /** * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a * Websocket. */ -public final class WebsocketServerTransport - implements ServerTransport, TransportHeaderAware { +public final class WebsocketServerTransport extends BaseWebsocketServerTransport + implements TransportHeaderAware { + private static final Logger logger = LoggerFactory.getLogger(WebsocketServerTransport.class); private final HttpServer server; private Supplier> transportHeaders = Collections::emptyMap; private WebsocketServerTransport(HttpServer server) { - this.server = server; + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); } /** @@ -88,7 +99,7 @@ public static WebsocketServerTransport create(InetSocketAddress address) { * @return a new instance * @throws NullPointerException if {@code server} is {@code null} */ - public static WebsocketServerTransport create(HttpServer server) { + public static WebsocketServerTransport create(final HttpServer server) { Objects.requireNonNull(server, "server must not be null"); return new WebsocketServerTransport(server); @@ -101,16 +112,33 @@ public void setTransportHeaders(Supplier> transportHeaders) } @Override - public Mono start(ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor, int mtu) { Objects.requireNonNull(acceptor, "acceptor must not be null"); - return server - .handle( - (request, response) -> { - transportHeaders.get().forEach(response::addHeader); - return response.sendWebsocket(WebsocketRouteTransport.newHandler(acceptor)); - }) - .bind() - .map(CloseableChannel::new); + Mono isError = FragmentationDuplexConnection.checkMtu(mtu); + return isError != null + ? isError + : server + .handle( + (request, response) -> { + transportHeaders.get().forEach(response::addHeader); + return response.sendWebsocket( + (in, out) -> { + DuplexConnection connection = + new WebsocketDuplexConnection((Connection) in); + if (mtu > 0) { + connection = + new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); + } + return acceptor.apply(connection).then(out.neverComplete()); + }, + WebsocketServerSpec.builder() + .maxFramePayloadLength(FRAME_LENGTH_MASK) + .build()); + }) + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index ec7ddcb80..000000000 --- a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.netty.TcpUriHandler -io.rsocket.transport.netty.WebsocketUriHandler diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java new file mode 100644 index 000000000..23041ec65 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class FragmentTest { + private RSocket handler; + private CloseableChannel server; + private String message = null; + private String metaData = null; + private String responseMessage = null; + + private static Stream cases() { + return Stream.of(Arguments.of(0, 64), Arguments.of(64, 0), Arguments.of(64, 64)); + } + + public void startup(int frameSize) { + int randomPort = ThreadLocalRandom.current().nextInt(10_000, 20_000); + StringBuilder message = new StringBuilder(); + StringBuilder responseMessage = new StringBuilder(); + StringBuilder metaData = new StringBuilder(); + for (int i = 0; i < 100; i++) { + message.append("REQUEST "); + responseMessage.append("RESPONSE "); + metaData.append("METADATA "); + } + this.message = message.toString(); + this.responseMessage = responseMessage.toString(); + this.metaData = metaData.toString(); + + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); + server = + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .fragment(frameSize) + .bind(serverTransport) + .block(); + } + + private RSocket buildClient(int frameSize) { + return RSocketConnector.create() + .fragment(frameSize) + .connect(TcpClientTransport.create(server.address())) + .block(); + } + + @AfterEach + public void cleanup() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentNoMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentNoMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentRequestMetaDataOnly(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentRequestMetaDataOnly-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentBothMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + Payload responsePayload = DefaultPayload.create(responseMessage); + System.out.println( + "-------------------------------------------------testFragmentBothMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage, metaData)); + } + + @Override + public Mono requestResponse(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Mono.just(DefaultPayload.create(responseMessage, metaData)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java new file mode 100644 index 000000000..b9c0d4f60 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -0,0 +1,80 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class RSocketFactoryNettyTransportFragmentationTest { + + static Stream> arguments() { + return Stream.of(TcpServerTransport.create(0), WebsocketServerTransport.create(0)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .fragment(100) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + CloseableChannel server = + RSocketServer.create(mockAcceptor()).fragment(100).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.create() + .fragment(100) + .connect(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + CloseableChannel server = RSocketServer.create(mockAcceptor()).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + private SocketAcceptor mockAcceptor() { + SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); + Mockito.when(mock.accept(Mockito.any(), Mockito.any())) + .thenReturn(Mono.just(Mockito.mock(RSocket.class))); + return mock; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java index f32d28a0b..6fd3de791 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -2,8 +2,9 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; @@ -41,19 +42,15 @@ void rejectSetupTcp( Mono serverRequester = acceptor.requesterRSocket(); CloseableChannel channel = - RSocketFactory.receive() - .acceptor(acceptor) - .transport(serverTransport.apply(new InetSocketAddress("localhost", 0))) - .start() + RSocketServer.create(acceptor) + .bind(serverTransport.apply(new InetSocketAddress("localhost", 0))) .block(Duration.ofSeconds(5)); ErrorConsumer errorConsumer = new ErrorConsumer(); RSocket clientRequester = - RSocketFactory.connect() - .errorConsumer(errorConsumer) - .transport(clientTransport.apply(channel.address())) - .start() + RSocketConnector.connectWith(clientTransport.apply(channel.address())) + .doOnError(errorConsumer) .block(Duration.ofSeconds(5)); StepVerifier.create(errorConsumer.errors().next()) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java index 719c8e2cf..88c64648c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java @@ -17,32 +17,81 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PerfTest; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.TcpClientTransport; import java.time.Duration; import org.HdrHistogram.Recorder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; +@PerfTest public final class TcpPing { + private static final int INTERACTIONS_COUNT = 1_000_000_000; + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); - public static void main(String... args) { - Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(TcpClientTransport.create(7878)) - .start(); + @BeforeEach + void setUp() { + System.out.println("Starting ping-pong test (TCP transport)"); + System.out.println("port: " + port); + } - PingClient pingClient = new PingClient(client); + @Test + void requestResponseTest() { + PingClient pingClient = newPingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + pingClient + .requestResponsePingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + @Test + void requestStreamTest() { + PingClient pingClient = newPingClient(); Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - int count = 1_000_000_000; + pingClient + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + @Test + void requestStreamResumableTest() { + PingClient pingClient = newResumablePingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); pingClient - .startPingPong(count, recorder) - .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) .blockLast(); } + + private static PingClient newPingClient() { + return newPingClient(false); + } + + private static PingClient newResumablePingClient() { + return newPingClient(true); + } + + private static PingClient newPingClient(boolean isResumable) { + RSocketConnector connector = RSocketConnector.create(); + if (isResumable) { + connector.resume(new Resume()); + } + Mono rSocket = + connector + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30)) + .connect(TcpClientTransport.create(port)); + + return new PingClient(rSocket); + } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java index ef5f6dbc0..338868470 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -16,19 +16,29 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.TcpServerTransport; public final class TcpPongServer { + private static final boolean isResume = + Boolean.valueOf(System.getProperty("RSOCKET_TEST_RESUME", "false")); + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(7878)) - .start() + System.out.println("Starting TCP ping-pong server"); + System.out.println("port: " + port); + System.out.println("resume enabled: " + isResume); + + RSocketServer server = RSocketServer.create(new PingHandler()); + if (isResume) { + server.resume(new Resume()); + } + server + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create("localhost", port)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java new file mode 100644 index 000000000..95bebd6aa --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -0,0 +1,55 @@ +package io.rsocket.transport.netty; + +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import reactor.core.Exceptions; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +public class TcpSecureTransportTest implements TransportTest { + private final TransportPair transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + address -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(10); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java deleted file mode 100644 index a71cc27f9..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class TcpUriTransportRegistryTest { - - @DisplayName("non-tcp URI does not return TcpClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("tcp URI returns TcpClientTransport") - @Test - void clientForUriTcp() { - assertThat(UriTransportRegistry.clientForUri("tcp://test:9898")) - .isInstanceOf(TcpClientTransport.class); - } - - @DisplayName("non-tcp URI does not return TcpServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("tcp URI returns TcpServerTransport") - @Test - void serverForUriTcp() { - assertThat(UriTransportRegistry.serverForUri("tcp://test:9898")) - .isInstanceOf(TcpServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java new file mode 100644 index 000000000..2deb4a4a8 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java @@ -0,0 +1,128 @@ +package io.rsocket.transport.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.URI; + +/** + * This is an example of a WebSocket client. + * + *

In order to run this example you need a compatible WebSocket server. Therefore you can either + * start the WebSocket server from the examples or connect to an existing WebSocket server such as + * ws://echo.websocket.org. + * + *

The client will attempt to connect to the URI passed to it as the first argument. You don't + * have to specify any arguments if you want to connect to the example WebSocket server, as this is + * the default. + */ +public final class WebSocketClient { + + static final String URL = System.getProperty("url", "ws://127.0.0.1:7878/websocket"); + + public static void main(String[] args) throws Exception { + URI uri = new URI(URL); + String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port; + if (uri.getPort() == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else if ("wss".equalsIgnoreCase(scheme)) { + port = 443; + } else { + port = -1; + } + } else { + port = uri.getPort(); + } + + if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { + System.err.println("Only WS(S) is supported."); + return; + } + + final boolean ssl = "wss".equalsIgnoreCase(scheme); + final SslContext sslCtx; + if (ssl) { + sslCtx = + SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + } else { + sslCtx = null; + } + + EventLoopGroup group = new NioEventLoopGroup(); + try { + // Connect with V13 (RFC 6455 aka HyBi-17). You can change it to V08 or V00. + // If you change it to V00, ping is not supported and remember to change + // HttpResponseDecoder to WebSocketHttpResponseDecoder in the pipeline. + final WebSocketClientHandler handler = + new WebSocketClientHandler( + WebSocketClientHandshakerFactory.newHandshaker( + uri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders())); + + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + } + p.addLast( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + WebSocketClientCompressionHandler.INSTANCE, + handler); + } + }); + + Channel ch = b.connect(uri.getHost(), port).sync().channel(); + handler.handshakeFuture().sync(); + + BufferedReader console = new BufferedReader(new InputStreamReader(System.in)); + while (true) { + String msg = console.readLine(); + if (msg == null) { + break; + } else if ("bye".equals(msg.toLowerCase())) { + ch.writeAndFlush(new CloseWebSocketFrame()); + ch.closeFuture().sync(); + break; + } else if ("ping".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PingWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else if ("pong".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PongWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else { + WebSocketFrame frame = new TextWebSocketFrame(msg); + ch.writeAndFlush(frame); + } + } + } finally { + group.shutdownGracefully(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java new file mode 100644 index 000000000..092cad2c7 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java @@ -0,0 +1,90 @@ +package io.rsocket.transport.netty; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.util.CharsetUtil; + +public class WebSocketClientHandler extends SimpleChannelInboundHandler { + + private final WebSocketClientHandshaker handshaker; + private ChannelPromise handshakeFuture; + + public WebSocketClientHandler(WebSocketClientHandshaker handshaker) { + this.handshaker = handshaker; + } + + public ChannelFuture handshakeFuture() { + return handshakeFuture; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handshakeFuture = ctx.newPromise(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + handshaker.handshake(ctx.channel()); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + System.out.println("WebSocket Client disconnected!"); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + Channel ch = ctx.channel(); + if (!handshaker.isHandshakeComplete()) { + try { + handshaker.finishHandshake(ch, (FullHttpResponse) msg); + System.out.println("WebSocket Client connected!"); + handshakeFuture.setSuccess(); + } catch (WebSocketHandshakeException e) { + System.out.println("WebSocket Client failed to connect"); + handshakeFuture.setFailure(e); + } + return; + } + + if (msg instanceof FullHttpResponse) { + FullHttpResponse response = (FullHttpResponse) msg; + throw new IllegalStateException( + "Unexpected FullHttpResponse (getStatus=" + + response.status() + + ", content=" + + response.content().toString(CharsetUtil.UTF_8) + + ')'); + } + + WebSocketFrame frame = (WebSocketFrame) msg; + if (frame instanceof TextWebSocketFrame) { + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + System.out.println("WebSocket Client received message: " + textFrame.text()); + } else if (frame instanceof PongWebSocketFrame) { + System.out.println("WebSocket Client received pong"); + } else if (frame instanceof CloseWebSocketFrame) { + System.out.println("WebSocket Client received closing"); + ch.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + if (!handshakeFuture.isDone()) { + handshakeFuture.setFailure(cause); + } + ctx.close(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java new file mode 100644 index 000000000..c418dea0f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java @@ -0,0 +1,49 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.net.URI; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebSocketTransportIntegrationTest { + + @Test + public void sendStreamOfDataWithExternalHttpServerTest() { + ServerTransport.ConnectionAcceptor acceptor = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(0, 10).map(i -> DefaultPayload.create(String.valueOf(i))))) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .route(router -> router.ws("/test", WebsocketRouteTransport.newHandler(acceptor))) + .bindNow(); + + RSocket rsocket = + RSocketConnector.connectWith( + WebsocketClientTransport.create( + URI.create("ws://" + server.host() + ":" + server.port() + "/test"))) + .block(); + + StepVerifier.create(rsocket.requestStream(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectNextCount(10) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java index 9b03d1fe2..a784a43c0 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.WebsocketClientTransport; import java.time.Duration; @@ -28,7 +29,9 @@ public final class WebsocketPing { public static void main(String... args) { Mono client = - RSocketFactory.connect().transport(WebsocketClientTransport.create(7878)).start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(7878)); PingClient pingClient = new PingClient(client); @@ -37,7 +40,7 @@ public static void main(String... args) { int count = 1_000_000_000; pingClient - .startPingPong(count, recorder) + .requestResponsePingPong(count, recorder) .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) .blockLast(); } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java new file mode 100644 index 000000000..e2ee9e521 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -0,0 +1,152 @@ +package io.rsocket.transport.netty; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebsocketPingPongIntegrationTest { + private static final String host = "localhost"; + private static final int port = 8088; + + private Closeable server; + + @AfterEach + void tearDown() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("provideServerTransport") + void webSocketPingPong(ServerTransport serverTransport) { + server = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .bind(serverTransport) + .block(); + + String expectedData = "data"; + String expectedPing = "ping"; + + PingSender pingSender = new PingSender(); + + HttpClient httpClient = + HttpClient.create() + .tcpConfiguration( + tcpClient -> + tcpClient + .doOnConnected(b -> b.addHandlerLast(pingSender)) + .host(host) + .port(port)); + + RSocket rSocket = + RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")).block(); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPing(expectedPing)) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + pingSender + .receivePong() + .as(StepVerifier::create) + .expectNextMatches(expectedPing::equals) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPong()) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + private static Stream provideServerTransport() { + return Stream.of( + Arguments.of(WebsocketServerTransport.create(host, port)), + Arguments.of( + new WebsocketRouteTransport( + HttpServer.create().host(host).port(port), routes -> {}, "/"))); + } + + private static class PingSender extends ChannelInboundHandlerAdapter { + private final MonoProcessor channel = MonoProcessor.create(); + private final MonoProcessor pong = MonoProcessor.create(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof PongWebSocketFrame) { + pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (!channel.isTerminated() && ch.isWritable()) { + channel.onNext(ctx.channel()); + } + super.channelWritabilityChanged(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (ch.isWritable()) { + channel.onNext(ch); + } + super.handlerAdded(ctx); + } + + public Mono sendPing(String data) { + return send( + new PingWebSocketFrame(Unpooled.wrappedBuffer(data.getBytes(StandardCharsets.UTF_8)))); + } + + public Mono sendPong() { + return send(new PongWebSocketFrame()); + } + + public Mono receivePong() { + return pong; + } + + private Mono send(WebSocketFrame webSocketFrame) { + return channel.doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java index c94a8c539..84dc816be 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,17 +16,17 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.WebsocketServerTransport; public final class WebsocketPongServer { public static void main(String... args) { - RSocketFactory.receive() - .acceptor(new PingHandler()) - .transport(WebsocketServerTransport.create(7878)) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(WebsocketServerTransport.create(7878)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index c1d608979..ec33060b2 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -38,7 +38,7 @@ final class WebsocketSecureTransportTest implements TransportTest { (address, server) -> WebsocketClientTransport.create( HttpClient.create() - .addressSupplier(server::address) + .remoteAddress(server::address) .secure( ssl -> ssl.sslContext( @@ -53,7 +53,7 @@ final class WebsocketSecureTransportTest implements TransportTest { HttpServer server = HttpServer.from( TcpServer.create() - .addressSupplier(() -> address) + .bindAddress(() -> address) .secure( ssl -> ssl.sslContext( diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java deleted file mode 100644 index 72a700b0e..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class WebsocketUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "amqp://test"; - } - - @Override - public UriHandler getUriHandler() { - return new WebsocketUriHandler(); - } - - @Override - public String getValidUri() { - return "ws://test:9898"; - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java deleted file mode 100644 index 5688f14ed..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class WebsocketUriTransportRegistryTest { - - @DisplayName("non-ws URI does not return WebsocketClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("ws URI returns WebsocketClientTransport") - @Test - void clientForUriWebsocket() { - assertThat(UriTransportRegistry.clientForUri("ws://test:9898")) - .isInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("non-ws URI does not return WebsocketServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("ws URI returns WebsocketServerTransport") - @Test - void serverForUriWebsocket() { - assertThat(UriTransportRegistry.serverForUri("ws://test:9898")) - .isInstanceOf(WebsocketServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java index 388001fb6..e0bdb9cd7 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java @@ -37,8 +37,8 @@ void connect() { TcpServerTransport serverTransport = TcpServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(context -> TcpClientTransport.create(context.address()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(context -> TcpClientTransport.create(context.address()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -47,7 +47,7 @@ void connect() { @DisplayName("create generates error if server not started") @Test void connectNoServer() { - TcpClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + TcpClientTransport.create(8000).connect(0).as(StepVerifier::create).verifyError(); } @DisplayName("creates client with BindAddress") diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java index 202c5b3f3..fc035c536 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -16,6 +16,7 @@ package io.rsocket.transport.netty.client; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; @@ -23,14 +24,67 @@ import java.net.InetSocketAddress; import java.net.URI; import java.util.Collections; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; import reactor.test.StepVerifier; +@ExtendWith(MockitoExtension.class) final class WebsocketClientTransportTest { + @Test + @Disabled + public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { + ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class); + HttpClient httpClient = Mockito.spy(HttpClient.create()); + Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any()); + Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture()); + + WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, ""); + + clientTransport.connect(0).subscribe(); + + Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK); + } + + @Test + @Disabled + public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() { + ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class); + HttpClient httpClient = Mockito.spy(HttpClient.create()); + Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any()); + Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture()); + + WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, ""); + + clientTransport.connect(65536 - 10000).subscribe(); + + Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK); + } + + @Test + @Disabled + public void + testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() { + ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class); + HttpClient httpClient = Mockito.spy(HttpClient.create()); + Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any()); + Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture()); + + WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, ""); + + clientTransport.connect(65536 + 10000).subscribe(); + + Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK); + } + @DisplayName("connects to server") @Test void connect() { @@ -39,8 +93,8 @@ void connect() { WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) - .flatMap(context -> WebsocketClientTransport.create(context.address()).connect()) + .start(duplexConnection -> Mono.empty(), 0) + .flatMap(context -> WebsocketClientTransport.create(context.address()).connect(0)) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -49,19 +103,31 @@ void connect() { @DisplayName("create generates error if server not started") @Test void connectNoServer() { - WebsocketClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + WebsocketClientTransport.create(8000).connect(0).as(StepVerifier::create).verifyError(); } @DisplayName("creates client with BindAddress") @Test void createBindAddress() { - assertThat(WebsocketClientTransport.create("test-bind-address", 8000)).isNotNull(); + assertThat(WebsocketClientTransport.create("test-bind-address", 8000)) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); } @DisplayName("creates client with HttpClient") @Test void createHttpClient() { - assertThat(WebsocketClientTransport.create(HttpClient.create(), "/")).isNotNull(); + assertThat(WebsocketClientTransport.create(HttpClient.create(), "/")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with HttpClient and path without root") + @Test + void createHttpClientWithPathWithoutRoot() { + assertThat(WebsocketClientTransport.create(HttpClient.create(), "test")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); } @DisplayName("creates client with InetSocketAddress") @@ -70,7 +136,8 @@ void createInetSocketAddress() { assertThat( WebsocketClientTransport.create( InetSocketAddress.createUnresolved("test-bind-address", 8000))) - .isNotNull(); + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); } @DisplayName("create throws NullPointerException with null bindAddress") @@ -122,7 +189,17 @@ void createPort() { @DisplayName("creates client with URI") @Test void createUri() { - assertThat(WebsocketClientTransport.create(URI.create("ws://test-host/"))).isNotNull(); + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with URI path") + @Test + void createUriPath() { + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host/test"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); } @DisplayName("sets transport headers") diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java index 15a216b96..b6cbfea34 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java @@ -70,7 +70,7 @@ void createNullTcpClient() { @DisplayName("creates server with port") @Test void createPort() { - assertThat(TcpServerTransport.create(8000)).isNotNull(); + assertThat(TcpServerTransport.create("localhost", 8000)).isNotNull(); } @DisplayName("creates client with TcpServer") @@ -87,7 +87,7 @@ void start() { TcpServerTransport serverTransport = TcpServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -97,7 +97,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> TcpServerTransport.create(8000).start(null)) + .isThrownBy(() -> TcpServerTransport.create("localhost", 8000).start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java index 66822890a..e94bef13c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java @@ -16,7 +16,6 @@ package io.rsocket.transport.netty.server; -import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; import org.junit.jupiter.api.DisplayName; @@ -57,20 +56,6 @@ void constructorNullServer() { .withMessage("server must not be null"); } - @DisplayName("creates a new handler") - @Test - void newHandler() { - assertThat(WebsocketRouteTransport.newHandler(duplexConnection -> null)).isNotNull(); - } - - @DisplayName("newHandler throws NullPointerException with null acceptor") - @Test - void newHandlerNullAcceptor() { - assertThatNullPointerException() - .isThrownBy(() -> WebsocketRouteTransport.newHandler(null)) - .withMessage("acceptor must not be null"); - } - @DisplayName("starts server") @Test void start() { @@ -78,7 +63,7 @@ void start() { new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -91,7 +76,7 @@ void startNullAcceptor() { .isThrownBy( () -> new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path") - .start(null)) + .start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index d1a6b374e..249a3e12a 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -16,19 +16,89 @@ package io.rsocket.transport.netty.server; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; import java.net.InetSocketAddress; import java.util.Collections; +import java.util.function.BiFunction; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import reactor.core.publisher.Mono; import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.HttpServerRequest; +import reactor.netty.http.server.HttpServerResponse; import reactor.test.StepVerifier; final class WebsocketServerTransportTest { + // @Test + public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { + ArgumentCaptor captor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer httpServer = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> httpServer).when(httpServer).handle(captor.capture()); + Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(httpServer); + + serverTransport.start(c -> Mono.empty(), 0).subscribe(); + + HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); + HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); + + captor.getValue().apply(httpServerRequest, httpServerResponse); + + Mockito.verify(httpServerResponse) + .sendWebsocket( + Mockito.nullable(String.class), Mockito.eq(FRAME_LENGTH_MASK), Mockito.any()); + } + + // @Test + public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() { + ArgumentCaptor captor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer httpServer = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> httpServer).when(httpServer).handle(captor.capture()); + Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(httpServer); + + serverTransport.start(c -> Mono.empty(), 1000).subscribe(); + + HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); + HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); + + captor.getValue().apply(httpServerRequest, httpServerResponse); + + Mockito.verify(httpServerResponse) + .sendWebsocket( + Mockito.nullable(String.class), Mockito.eq(FRAME_LENGTH_MASK), Mockito.any()); + } + + // @Test + public void + testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() { + ArgumentCaptor captor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer httpServer = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> httpServer).when(httpServer).handle(captor.capture()); + Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind(); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(httpServer); + + serverTransport.start(c -> Mono.empty(), 65536 + 1000).subscribe(); + + HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); + HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); + + captor.getValue().apply(httpServerRequest, httpServerResponse); + + Mockito.verify(httpServerResponse) + .sendWebsocket( + Mockito.nullable(String.class), Mockito.eq(FRAME_LENGTH_MASK), Mockito.any()); + } + @DisplayName("creates server with BindAddress") @Test void createBindAddress() { @@ -102,7 +172,7 @@ void start() { WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); serverTransport - .start(duplexConnection -> Mono.empty()) + .start(duplexConnection -> Mono.empty(), 0) .as(StepVerifier::create) .expectNextCount(1) .verifyComplete(); @@ -112,7 +182,7 @@ void start() { @Test void startNullAcceptor() { assertThatNullPointerException() - .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null)) + .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null, 0)) .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml index 49b11d6fb..f9dec2bbe 100644 --- a/rsocket-transport-netty/src/test/resources/logback-test.xml +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -24,6 +24,8 @@ + + diff --git a/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 000000000..ca6ee9cea --- /dev/null +++ b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 16630076a..25c3feee5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -13,13 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +plugins { + id 'com.gradle.enterprise' version '3.1' +} rootProject.name = 'rsocket-java' include 'rsocket-core' -include 'rsocket-examples' include 'rsocket-load-balancer' include 'rsocket-micrometer' include 'rsocket-test' include 'rsocket-transport-local' include 'rsocket-transport-netty' +include 'rsocket-bom' + +include 'rsocket-examples' +include 'benchmarks' + + + +gradleEnterprise { + buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' + } +} +