diff --git a/.travis.yml b/.travis.yml index 116d2d2ff..4722957c8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -23,7 +23,7 @@ matrix: - jdk: openjdk8 - jdk: openjdk11 env: SKIP_RELEASE=true - - jdk: openjdk12 + - jdk: openjdk14 env: SKIP_RELEASE=true env: diff --git a/README.md b/README.md index 173c3e1ad..f8110a31e 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,26 @@ Releases are available via Maven Central. Example: ```groovy +repositories { + mavenCentral() +} +dependencies { + implementation 'io.rsocket:rsocket-core:1.0.0' + implementation 'io.rsocket:rsocket-transport-netty:1.0.0' +} +``` + +Snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). + +Example: + +```groovy +repositories { + maven { url 'https://oss.jfrog.org/oss-snapshot-local' } +} dependencies { - implementation 'io.rsocket:rsocket-core:1.0.0-RC3' - implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC3' -// implementation 'io.rsocket:rsocket-core:1.0.0-RC4-SNAPSHOT' -// implementation 'io.rsocket:rsocket-transport-netty:1.0.0-RC4-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.0.1-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.0.1-SNAPSHOT' } ``` @@ -57,7 +72,7 @@ package io.rsocket.transport.netty; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.util.DefaultPayload; import reactor.core.publisher.Flux; @@ -67,14 +82,14 @@ import java.net.URI; public class ExampleClient { public static void main(String[] args) { WebsocketClientTransport ws = WebsocketClientTransport.create(URI.create("ws://rsocket-demo.herokuapp.com/ws")); - RSocket client = RSocketFactory.connect().keepAlive().transport(ws).start().block(); + RSocket clientRSocket = RSocketConnector.connectWith(ws).block(); try { - Flux s = client.requestStream(DefaultPayload.create("peace")); + Flux s = clientRSocket.requestStream(DefaultPayload.create("peace")); s.take(10).doOnNext(p -> System.out.println(p.getDataUtf8())).blockLast(); } finally { - client.dispose(); + clientRSocket.dispose(); } } } @@ -89,12 +104,10 @@ or you will get a memory leak. Used correctly this will reduce latency and incre ### Example Server setup ```java -RSocketFactory.receive() +RSocketServer.create(new PingHandler()) // Enable Zero Copy - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(7878)) - .start() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(7878)) .block() .onClose() .block(); @@ -102,12 +115,12 @@ RSocketFactory.receive() ### Example Client setup ```java -Mono client = - RSocketFactory.connect() +RSocket clientRSocket = + RSocketConnector.create() // Enable Zero Copy - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(TcpClientTransport.create(7878)) - .start(); + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(TcpClientTransport.create(7878)) + .block(); ``` ## Bugs and Feedback diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..6ba6755a6 --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,47 @@ +## Usage of JMH tasks + +Only execute specific benchmark(s) (wildcards are added before and after): +``` +../gradlew jmh --include="(BenchmarkPrimary|OtherBench)" +``` +If you want to specify the wildcards yourself, you can pass the full regexp: +``` +../gradlew jmh --fullInclude=.*MyBenchmark.* +``` + +Specify extra profilers: +``` +../gradlew jmh --profilers="gc,stack" +``` + +Prominent profilers (for full list call `jmhProfilers` task): +- comp - JitCompilations, tune your iterations +- stack - which methods used most time +- gc - print garbage collection stats +- hs_thr - thread usage + +Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: +``` +./gradlew jmh --format=csv +``` + +Specify JVM arguments: +``` +../gradlew jmh --jvmArgs="-Dtest.cluster=local" +``` + +Run in verification mode (execute benchmarks with minimum of fork/warmup-/benchmark-iterations): +``` +../gradlew jmh --verify=true +``` + +## Comparing with the baseline +If you wish you run two sets of benchmarks, one for the current change and another one for the "baseline", +there is an additional task `jmhBaseline` that will use the latest release: +``` +../gradlew jmh jmhBaseline --include=MyBenchmark +``` + +## Resources +- http://tutorials.jenkov.com/java-performance/jmh.html (Introduction) +- http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ (Samples) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle new file mode 100644 index 000000000..f07f7c6f5 --- /dev/null +++ b/benchmarks/build.gradle @@ -0,0 +1,167 @@ +apply plugin: 'java' +apply plugin: 'idea' + +configurations { + current + baseline { + resolutionStrategy.cacheChangingModulesFor 0, 'seconds' + } +} + +dependencies { + // Use the baseline to avoid using new APIs in the benchmarks + compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + + implementation "org.openjdk.jmh:jmh-core:1.21" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.21" + + current project(':rsocket-core') + current project(':rsocket-transport-local') + baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { + changing = true + } + baseline "io.rsocket:rsocket-transport-local:${perfBaselineVersion}", { + changing = true + } +} + +task jmhProfilers(type: JavaExec, description:'Lists the available profilers for the jmh task', group: 'Development') { + classpath = sourceSets.main.runtimeClasspath + main = 'org.openjdk.jmh.Main' + args '-lprof' +} + +task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.current +} + +task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + classpath = sourceSets.main.runtimeClasspath + configurations.baseline +} + +clean { + delete "${projectDir}/src/main/generated" +} + +class JmhExecTask extends JavaExec { + + private String include; + private String fullInclude; + private String exclude; + private String format = "json"; + private String profilers; + private String jmhJvmArgs; + private String verify; + + public JmhExecTask() { + super(); + } + + public String getInclude() { + return include; + } + + @Option(option = "include", description="configure bench inclusion using substring") + public void setInclude(String include) { + this.include = include; + } + + public String getFullInclude() { + return fullInclude; + } + + @Option(option = "fullInclude", description = "explicitly configure bench inclusion using full JMH style regexp") + public void setFullInclude(String fullInclude) { + this.fullInclude = fullInclude; + } + + public String getExclude() { + return exclude; + } + + @Option(option = "exclude", description = "explicitly configure bench exclusion using full JMH style regexp") + public void setExclude(String exclude) { + this.exclude = exclude; + } + + public String getFormat() { + return format; + } + + @Option(option = "format", description = "configure report format") + public void setFormat(String format) { + this.format = format; + } + + public String getProfilers() { + return profilers; + } + + @Option(option = "profilers", description = "configure jmh profiler(s) to use, comma separated") + public void setProfilers(String profilers) { + this.profilers = profilers; + } + + public String getJmhJvmArgs() { + return jmhJvmArgs; + } + + @Option(option = "jvmArgs", description = "configure additional JMH JVM arguments, comma separated") + public void setJmhJvmArgs(String jvmArgs) { + this.jmhJvmArgs = jvmArgs; + } + + public String getVerify() { + return verify; + } + + @Option(option = "verify", description = "run in verify mode") + public void setVerify(String verify) { + this.verify = verify; + } + + @TaskAction + public void exec() { + setMain("org.openjdk.jmh.Main"); + File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); + + if (include != null) { + args(".*" + include + ".*"); + } + else if (fullInclude != null) { + args(fullInclude); + } + + if(exclude != null) { + args("-e", exclude); + } + if(verify != null) { // execute benchmarks with the minimum amount of execution (only to check if they are working) + System.out.println("Running in verify mode"); + args("-f", 1); + args("-wi", 1); + args("-i", 1); + } + args("-foe", "true"); //fail-on-error + args("-v", "NORMAL"); //verbosity [SILENT, NORMAL, EXTRA] + if(profilers != null) { + for (String prof : profilers.split(",")) { + args("-prof", prof); + } + } + args("-jvmArgsPrepend", "-Xmx3072m"); + args("-jvmArgsPrepend", "-Xms3072m"); + if(jmhJvmArgs != null) { + for(String jvmArg : jmhJvmArgs.split(" ")) { + args("-jvmArgsPrepend", jvmArg); + } + } + args("-rf", format); + args("-rff", resultFile); + + System.out.println("\nExecuting JMH with: " + getArgs() + "\n"); + resultFile.getParentFile().mkdirs(); + + super.exec(); + } +} diff --git a/rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java similarity index 71% rename from rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java rename to benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java index ace985a39..2e6fa6acc 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/MaxPerfSubscriber.java +++ b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java @@ -5,12 +5,12 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; -public class MaxPerfSubscriber implements CoreSubscriber { +public class MaxPerfSubscriber extends CountDownLatch implements CoreSubscriber { - final CountDownLatch latch = new CountDownLatch(1); final Blackhole blackhole; public MaxPerfSubscriber(Blackhole blackhole) { + super(1); this.blackhole = blackhole; } @@ -20,19 +20,18 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(Payload payload) { - payload.release(); + public void onNext(T payload) { blackhole.consume(payload); } @Override public void onError(Throwable t) { blackhole.consume(t); - latch.countDown(); + countDown(); } @Override public void onComplete() { - latch.countDown(); + countDown(); } } diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java new file mode 100644 index 000000000..7a7a1fdd6 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsMaxPerfSubscriber extends MaxPerfSubscriber { + + public PayloadsMaxPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java new file mode 100644 index 000000000..efc116958 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsPerfSubscriber extends PerfSubscriber { + + public PayloadsPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java similarity index 72% rename from rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java rename to benchmarks/src/main/java/io/rsocket/PerfSubscriber.java index 98c5edd3b..92577d95c 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/PerfSubscriber.java +++ b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java @@ -5,14 +5,14 @@ import org.reactivestreams.Subscription; import reactor.core.CoreSubscriber; -public class PerfSubscriber implements CoreSubscriber { +public class PerfSubscriber extends CountDownLatch implements CoreSubscriber { - final CountDownLatch latch = new CountDownLatch(1); final Blackhole blackhole; Subscription s; public PerfSubscriber(Blackhole blackhole) { + super(1); this.blackhole = blackhole; } @@ -23,8 +23,7 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(Payload payload) { - payload.release(); + public void onNext(T payload) { blackhole.consume(payload); s.request(1); } @@ -32,11 +31,11 @@ public void onNext(Payload payload) { @Override public void onError(Throwable t) { blackhole.consume(t); - latch.countDown(); + countDown(); } @Override public void onComplete() { - latch.countDown(); + countDown(); } } diff --git a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java similarity index 57% rename from rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java rename to benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java index 476d6c814..f78843f5b 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java +++ b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java @@ -1,18 +1,29 @@ -package io.rsocket; - +package io.rsocket.core; + +import io.rsocket.AbstractRSocket; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.PayloadsMaxPerfSubscriber; +import io.rsocket.PayloadsPerfSubscriber; +import io.rsocket.RSocket; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; import io.rsocket.util.EmptyPayload; +import java.lang.reflect.Field; +import java.util.Queue; +import java.util.concurrent.locks.LockSupport; import java.util.stream.IntStream; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; import org.openjdk.jmh.annotations.Measurement; import org.openjdk.jmh.annotations.Mode; import org.openjdk.jmh.annotations.Scope; import org.openjdk.jmh.annotations.Setup; import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.infra.Blackhole; import org.reactivestreams.Publisher; @@ -36,12 +47,25 @@ public class RSocketPerf { RSocket client; Closeable server; + Queue clientsQueue; + + @TearDown + public void tearDown() { + client.dispose(); + server.dispose(); + } + + @TearDown(Level.Iteration) + public void awaitToBeConsumed() { + while (!clientsQueue.isEmpty()) { + LockSupport.parkNanos(1000); + } + } @Setup - public void setUp() { + public void setUp() throws NoSuchFieldException, IllegalAccessException { server = - RSocketFactory.receive() - .acceptor( + RSocketServer.create( (setup, sendingSocket) -> Mono.just( new AbstractRSocket() { @@ -69,73 +93,77 @@ public Flux requestChannel(Publisher payloads) { return Flux.from(payloads); } })) - .transport(LocalServerTransport.create("server")) - .start() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("server")) .block(); client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("server")) - .start() + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("server")) .block(); + + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); } @Benchmark @SuppressWarnings("unchecked") - public PerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.fireAndForget(PAYLOAD).subscribe((CoreSubscriber) subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestResponse(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) + public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestStream(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public MaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) + public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { - MaxPerfSubscriber subscriber = new MaxPerfSubscriber(blackhole); + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); client.requestStream(PAYLOAD).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public PerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) + public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) throws InterruptedException { - PerfSubscriber subscriber = new PerfSubscriber(blackhole); + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } @Benchmark - public MaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) + public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) throws InterruptedException { - MaxPerfSubscriber subscriber = new MaxPerfSubscriber(blackhole); + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); client.requestChannel(PAYLOAD_FLUX).subscribe(subscriber); - subscriber.latch.await(); + subscriber.await(); return subscriber; } diff --git a/rsocket-core/src/jmh/java/io/rsocket/StreamIdSupplierPerf.java b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java similarity index 66% rename from rsocket-core/src/jmh/java/io/rsocket/StreamIdSupplierPerf.java rename to benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java index c198b7a19..6b4f3f624 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/StreamIdSupplierPerf.java +++ b/benchmarks/src/main/java/io/rsocket/core/StreamIdSupplierPerf.java @@ -1,8 +1,16 @@ -package io.rsocket; +package io.rsocket.core; import io.netty.util.collection.IntObjectMap; import io.rsocket.internal.SynchronizedIntObjectHashMap; -import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.infra.Blackhole; @BenchmarkMode(Mode.Throughput) diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java similarity index 76% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java index 139114466..b4ac808d0 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderFlyweightPerf.java @@ -16,7 +16,7 @@ public class FrameHeaderFlyweightPerf { @Benchmark public void encode(Input input) { - ByteBuf byteBuf = FrameHeaderFlyweight.encodeStreamZero(input.allocator, FrameType.SETUP, 0); + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(input.allocator, FrameType.SETUP, 0); boolean release = byteBuf.release(); input.bh.consume(release); } @@ -24,9 +24,9 @@ public void encode(Input input) { @Benchmark public void decode(Input input) { ByteBuf frame = input.frame; - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - int streamId = FrameHeaderFlyweight.streamId(frame); - int flags = FrameHeaderFlyweight.flags(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + int flags = FrameHeaderCodec.flags(frame); input.bh.consume(streamId); input.bh.consume(flags); input.bh.consume(frameType); @@ -44,7 +44,7 @@ public void setup(Blackhole bh) { this.bh = bh; this.frameType = FrameType.REQUEST_RESPONSE; allocator = ByteBufAllocator.DEFAULT; - frame = FrameHeaderFlyweight.encode(allocator, 123, FrameType.SETUP, 0); + frame = FrameHeaderCodec.encode(allocator, 123, FrameType.SETUP, 0); } @TearDown diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/frame/FrameTypePerf.java rename to benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java diff --git a/rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java similarity index 90% rename from rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java rename to benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java index 89e50a6e9..01d82a08f 100644 --- a/rsocket-core/src/jmh/java/io/rsocket/frame/PayloadFlyweightPerf.java +++ b/benchmarks/src/main/java/io/rsocket/frame/PayloadFlyweightPerf.java @@ -18,7 +18,7 @@ public class PayloadFlyweightPerf { @Benchmark public void encode(Input input) { ByteBuf encode = - PayloadFrameFlyweight.encode( + PayloadFrameCodec.encode( input.allocator, 100, false, @@ -33,8 +33,8 @@ public void encode(Input input) { @Benchmark public void decode(Input input) { ByteBuf frame = input.payload; - ByteBuf data = PayloadFrameFlyweight.data(frame); - ByteBuf metadata = PayloadFrameFlyweight.metadata(frame); + ByteBuf data = PayloadFrameCodec.data(frame); + ByteBuf metadata = PayloadFrameCodec.metadata(frame); input.bh.consume(data); input.bh.consume(metadata); } @@ -57,7 +57,7 @@ public void setup(Blackhole bh) { // Encode a payload and then copy it a single bytebuf payload = allocator.buffer(); ByteBuf encode = - PayloadFrameFlyweight.encode( + PayloadFrameCodec.encode( allocator, 100, false, diff --git a/rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java similarity index 100% rename from rsocket-core/src/jmh/java/io/rsocket/metadata/WellKnownMimeTypePerf.java rename to benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java diff --git a/build.gradle b/build.gradle index c354869df..f579b3ae0 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,27 +15,25 @@ */ plugins { - id 'com.gradle.build-scan' version '2.4.2' - id 'com.github.sherter.google-java-format' version '0.8' apply false - id 'com.jfrog.artifactory' version '4.9.10' apply false + id 'com.jfrog.artifactory' version '4.11.0' apply false id 'com.jfrog.bintray' version '1.8.4' apply false - id 'me.champeau.gradle.jmh' version '0.4.8' apply false - id 'io.spring.dependency-management' version '1.0.7.RELEASE' apply false + id 'me.champeau.gradle.jmh' version '0.5.0' apply false + id 'io.spring.dependency-management' version '1.0.8.RELEASE' apply false id 'io.morethan.jmhreport' version '0.9.0' apply false } subprojects { apply plugin: 'io.spring.dependency-management' apply plugin: 'com.github.sherter.google-java-format' - - ext['reactor-bom.version'] = 'Dysprosium-RELEASE' + + ext['reactor-bom.version'] = 'Dysprosium-SR7' ext['logback.version'] = '1.2.3' ext['findbugs.version'] = '3.0.2' - ext['netty-bom.version'] = '4.1.37.Final' - ext['netty-boringssl.version'] = '2.0.25.Final' + ext['netty-bom.version'] = '4.1.48.Final' + ext['netty-boringssl.version'] = '2.0.30.Final' ext['hdrhistogram.version'] = '2.1.10' - ext['mockito.version'] = '2.25.1' + ext['mockito.version'] = '3.2.0' ext['slf4j.version'] = '1.7.25' ext['jmh.version'] = '1.21' ext['junit.version'] = '5.5.2' @@ -64,7 +62,6 @@ subprojects { dependencies { dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" - dependency "com.google.code.findbugs:jsr305:${ext['findbugs.version']}" dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}" dependency "org.assertj:assertj-core:${ext['assertj.version']}" @@ -90,11 +87,18 @@ subprojects { repositories { mavenCentral() - if (version.endsWith('BUILD-SNAPSHOT') || project.hasProperty('platformVersion')) { + if (version.endsWith('SNAPSHOT') || project.hasProperty('platformVersion')) { maven { url 'http://repo.spring.io/libs-snapshot' } + maven { + url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' + } } } + tasks.withType(GenerateModuleMetadata) { + enabled = false + } + plugins.withType(JavaPlugin) { compileJava { sourceCompatibility = 1.8 @@ -104,21 +108,61 @@ subprojects { } javadoc { + def jdk = JavaVersion.current().majorVersion + def jdkJavadoc = "https://docs.oracle.com/javase/$jdk/docs/api/" + if (JavaVersion.current().isJava11Compatible()) { + jdkJavadoc = "https://docs.oracle.com/en/java/javase/$jdk/docs/api/" + } options.with { - links 'https://docs.oracle.com/javase/8/docs/api/' + links jdkJavadoc links 'https://projectreactor.io/docs/core/release/api/' links 'https://netty.io/4.1/api/' } } + tasks.named("javadoc").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + test { useJUnitPlatform() systemProperty "io.netty.leakDetection.level", "ADVANCED" } - tasks.named("javadoc").configure { - onlyIf { System.getenv('SKIP_RELEASE') != "true" } + //all test tasks will show FAILED for each test method, + // common exclusions, no scanning + project.tasks.withType(Test).all { + testLogging { + events "FAILED" + showExceptions true + exceptionFormat "FULL" + stackTraceFilters "ENTRY_POINT" + maxGranularity 3 + } + + if (JavaVersion.current().isJava9Compatible()) { + println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" + jvmArgs = ["-XX:MaxGCPauseMillis=20"] + } + + systemProperty("java.awt.headless", "true") + systemProperty("reactor.trace.cancel", "true") + systemProperty("reactor.trace.nocapacity", "true") + systemProperty("testGroups", project.properties.get("testGroups")) + scanForTestClasses = false + exclude '**/*Abstract*.*' + + //allow re-run of failed tests only without special test tasks failing + // because the filter is too restrictive + filter.setFailOnNoMatchingTests(false) + + //display intermediate results for special test tasks + afterSuite { desc, result -> + if (!desc.parent) { // will match the outermost suite + println('\n' + "${desc} Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") + } + } } } @@ -145,7 +189,6 @@ subprojects { } } } - } apply from: "${rootDir}/gradle/publications.gradle" diff --git a/ci/travis.sh b/ci/travis.sh index 9154da33b..d190a59ec 100755 --- a/ci/travis.sh +++ b/ci/travis.sh @@ -5,13 +5,24 @@ if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then echo -e "Building PR #$TRAVIS_PULL_REQUEST [$TRAVIS_PULL_REQUEST_SLUG/$TRAVIS_PULL_REQUEST_BRANCH => $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH]" ./gradlew build -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then +elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] && [ "$TRAVIS_BRANCH" == "develop" ] ; then - echo -e "Building Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH" + echo -e "Building Develop Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" ./gradlew \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ -PversionSuffix="-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ + build artifactoryPublish --stacktrace + +elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ] && [ "$bintrayUser" != "" ] ; then + + echo -e "Building Branch Snapshot $TRAVIS_REPO_SLUG/$TRAVIS_BRANCH/$TRAVIS_BUILD_NUMBER" + ./gradlew \ + -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ + -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PversionSuffix="-${TRAVIS_BRANCH//\//-}-SNAPSHOT" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build artifactoryPublish --stacktrace elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bintrayUser" != "" ] ; then @@ -21,6 +32,7 @@ elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ] && [ "$bin -Pversion="$TRAVIS_TAG" \ -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" \ -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" \ + -PbuildNumber="$TRAVIS_BUILD_NUMBER" \ build bintrayUpload --stacktrace else diff --git a/gradle.properties b/gradle.properties index b85cba325..b0b107ec4 100644 --- a/gradle.properties +++ b/gradle.properties @@ -11,4 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # -version=1.0.0-RC6 +version=1.0.1 +perfBaselineVersion=1.0.0 diff --git a/gradle/artifactory.gradle b/gradle/artifactory.gradle index 7f4369242..cdffb2741 100644 --- a/gradle/artifactory.gradle +++ b/gradle/artifactory.gradle @@ -33,6 +33,10 @@ if (project.hasProperty('bintrayUser') && project.hasProperty('bintrayKey')) { defaults { publications(publishing.publications.maven) } + + if (project.hasProperty('buildNumber')) { + clientConfig.info.setBuildNumber(project.property('buildNumber').toString()) + } } } tasks.named("artifactoryPublish").configure { diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 29953ea14..5c2d1cf01 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 7c4388a92..a4b442974 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-6.3-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index cccdd3d51..83f2acfdc 100755 --- a/gradlew +++ b/gradlew @@ -1,5 +1,21 @@ #!/usr/bin/env sh +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + ############################################################################## ## ## Gradle start up script for UN*X @@ -28,7 +44,7 @@ APP_NAME="Gradle" APP_BASE_NAME=`basename "$0"` # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. MAX_FD="maximum" @@ -109,8 +125,8 @@ if $darwin; then GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" fi -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then APP_HOME=`cygpath --path --mixed "$APP_HOME"` CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` JAVACMD=`cygpath --unix "$JAVACMD"` diff --git a/gradlew.bat b/gradlew.bat index e95643d6a..24467a141 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,3 +1,19 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + @if "%DEBUG%" == "" @echo off @rem ########################################################################## @rem @@ -14,7 +30,7 @@ set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle index ca48a87c0..2efc20a91 100755 --- a/rsocket-bom/build.gradle +++ b/rsocket-bom/build.gradle @@ -22,9 +22,11 @@ plugins { description = 'RSocket Java Bill of materials.' +def excluded = ["rsocket-examples", "benchmarks"] + dependencies { constraints { - parent.subprojects.findAll { it.name != project.name }.sort { "$it.name" }.each { + parent.subprojects.findAll { it.name != project.name && !excluded.contains(it.name) } .sort { "$it.name" }.each { api it } } @@ -34,12 +36,6 @@ publishing { publications { maven(MavenPublication) { from components.javaPlatform - // remove scope information from published BOM - pom.withXml { - asNode().dependencyManagement.first().dependencies.first().each { - it.remove(it.scope.first()) - } - } } } } \ No newline at end of file diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index d62452619..41adbd7a8 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -29,8 +29,6 @@ dependencies { implementation 'org.slf4j:slf4j-api' - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' testImplementation 'org.junit.jupiter:junit-jupiter-api' @@ -46,6 +44,4 @@ dependencies { testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' } -description = "Core functionality for the RSocket library" - -apply from: 'jmh.gradle' +description = "Core functionality for the RSocket library" \ No newline at end of file diff --git a/rsocket-core/jmh.gradle b/rsocket-core/jmh.gradle deleted file mode 100644 index 2a2b4d7cd..000000000 --- a/rsocket-core/jmh.gradle +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -dependencies { - jmh configurations.api - jmh configurations.implementation - jmh 'org.openjdk.jmh:jmh-core' - jmh 'org.openjdk.jmh:jmh-generator-annprocess' - jmh 'io.projectreactor:reactor-test' - jmh project(':rsocket-transport-local') -} - -jmhCompileGeneratedClasses.enabled = false - -jmh { - includeTests = false - profilers = ['gc'] - resultFormat = 'JSON' - - jvmArgs = ['-XX:+UnlockCommercialFeatures', '-XX:+FlightRecorder'] - // jvmArgsAppend = ['-XX:+UseG1GC', '-Xms4g', '-Xmx4g'] -} - -jmhJar { - from project.configurations.jmh -} - -tasks.jmh.finalizedBy tasks.jmhReport - -jmhReport { - jmhResultPath = project.file('build/reports/jmh/results.json') - jmhReportOutput = project.file('build/reports/jmh') -} diff --git a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java index c099a3120..7f39956dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java @@ -16,48 +16,21 @@ package io.rsocket; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; /** * An abstract implementation of {@link RSocket}. All request handling methods emit {@link * UnsupportedOperationException} and hence must be overridden to provide a valid implementation. + * + * @deprecated as of 1.0 in favor of implementing {@link RSocket} directly which has default + * methods. */ +@Deprecated public abstract class AbstractRSocket implements RSocket { private final MonoProcessor onClose = MonoProcessor.create(); - @Override - public Mono fireAndForget(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Fire and forget not implemented.")); - } - - @Override - public Mono requestResponse(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); - } - - @Override - public Flux requestStream(Payload payload) { - payload.release(); - return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); - } - - @Override - public Mono metadataPush(Payload payload) { - payload.release(); - return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); - } - @Override public void dispose() { onClose.onComplete(); diff --git a/rsocket-core/src/main/java/io/rsocket/Closeable.java b/rsocket-core/src/main/java/io/rsocket/Closeable.java index 5eb871e18..2ea9a0371 100644 --- a/rsocket-core/src/main/java/io/rsocket/Closeable.java +++ b/rsocket-core/src/main/java/io/rsocket/Closeable.java @@ -16,17 +16,21 @@ package io.rsocket; +import org.reactivestreams.Subscriber; import reactor.core.Disposable; import reactor.core.publisher.Mono; -/** */ +/** An interface which allows listening to when a specific instance of this interface is closed */ public interface Closeable extends Disposable { /** - * Returns a {@code Publisher} that completes when this {@code RSocket} is closed. A {@code - * RSocket} can be closed by explicitly calling {@link RSocket#dispose()} or when the underlying - * transport connection is closed. + * Returns a {@link Mono} that terminates when the instance is terminated by any reason. Note, in + * case of error termination, the cause of error will be propagated as an error signal through + * {@link org.reactivestreams.Subscriber#onError(Throwable)}. Otherwise, {@link + * Subscriber#onComplete()} will be called. * - * @return A {@code Publisher} that completes when this {@code RSocket} close is complete. + * @return a {@link Mono} to track completion with success or error of the underlying resource. + * When the underlying resource is an `RSocket`, the {@code Mono} exposes stream 0 (i.e. + * connection level) errors. */ Mono onClose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index 8762e0489..ece2aa9fa 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,28 +18,22 @@ import io.netty.buffer.ByteBuf; import io.netty.util.AbstractReferenceCounted; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; -import javax.annotation.Nullable; +import io.rsocket.core.DefaultConnectionSetupPayload; +import reactor.util.annotation.Nullable; /** - * Exposed to server for determination of ResponderRSocket based on mime types and SETUP - * metadata/data + * Exposes information from the {@code SETUP} frame to a server, as well as to client responders. */ public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { - public static ConnectionSetupPayload create(final ByteBuf setupFrame) { - return new DefaultConnectionSetupPayload(setupFrame); - } + public abstract String metadataMimeType(); + + public abstract String dataMimeType(); public abstract int keepAliveInterval(); public abstract int keepAliveMaxLifetime(); - public abstract String metadataMimeType(); - - public abstract String dataMimeType(); - public abstract int getFlags(); public abstract boolean willClientHonorLease(); @@ -64,96 +58,15 @@ public ConnectionSetupPayload retain(int increment) { @Override public abstract ConnectionSetupPayload touch(); - @Override - public abstract ConnectionSetupPayload touch(Object hint); - - private static final class DefaultConnectionSetupPayload extends ConnectionSetupPayload { - private final ByteBuf setupFrame; - - public DefaultConnectionSetupPayload(ByteBuf setupFrame) { - this.setupFrame = setupFrame; - } - - @Override - public boolean hasMetadata() { - return FrameHeaderFlyweight.hasMetadata(setupFrame); - } - - @Override - public int keepAliveInterval() { - return SetupFrameFlyweight.keepAliveInterval(setupFrame); - } - - @Override - public int keepAliveMaxLifetime() { - return SetupFrameFlyweight.keepAliveMaxLifetime(setupFrame); - } - - @Override - public String metadataMimeType() { - return SetupFrameFlyweight.metadataMimeType(setupFrame); - } - - @Override - public String dataMimeType() { - return SetupFrameFlyweight.dataMimeType(setupFrame); - } - - @Override - public int getFlags() { - return FrameHeaderFlyweight.flags(setupFrame); - } - - @Override - public boolean willClientHonorLease() { - return SetupFrameFlyweight.honorLease(setupFrame); - } - - @Override - public boolean isResumeEnabled() { - return SetupFrameFlyweight.resumeEnabled(setupFrame); - } - - @Override - public ByteBuf resumeToken() { - return SetupFrameFlyweight.resumeToken(setupFrame); - } - - @Override - public ConnectionSetupPayload touch() { - setupFrame.touch(); - return this; - } - - @Override - public ConnectionSetupPayload touch(Object hint) { - setupFrame.touch(hint); - return this; - } - - @Override - protected void deallocate() { - setupFrame.release(); - } - - @Override - public ByteBuf sliceMetadata() { - return SetupFrameFlyweight.metadata(setupFrame); - } - - @Override - public ByteBuf sliceData() { - return SetupFrameFlyweight.data(setupFrame); - } - - @Override - public ByteBuf data() { - return sliceData(); - } - - @Override - public ByteBuf metadata() { - return sliceMetadata(); - } + /** + * Create a {@code ConnectionSetupPayload}. + * + * @deprecated as of 1.0 RC7. Please, use {@link + * DefaultConnectionSetupPayload#DefaultConnectionSetupPayload(ByteBuf) + * DefaultConnectionSetupPayload} constructor. + */ + @Deprecated + public static ConnectionSetupPayload create(final ByteBuf setupFrame) { + return new DefaultConnectionSetupPayload(setupFrame); } } diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 7739a34c0..6190d24e3 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import java.nio.channels.ClosedChannelException; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -30,9 +31,9 @@ public interface DuplexConnection extends Availability, Closeable { * Sends the source of Frames on this connection and returns the {@code Publisher} representing * the result of this send. * - *

Flow control

+ *

Flow control * - * The passed {@code Publisher} must + *

The passed {@code Publisher} must * * @param frames Stream of {@code Frame}s to send on the connection. * @return {@code Publisher} that completes when all the frames are written on the connection @@ -56,20 +57,20 @@ default Mono sendOne(ByteBuf frame) { /** * Returns a stream of all {@code Frame}s received on this connection. * - *

Completion

+ *

Completion * - * Returned {@code Publisher} MUST never emit a completion event ({@link + *

Returned {@code Publisher} MUST never emit a completion event ({@link * Subscriber#onComplete()}. * - *

Error

+ *

Error * - * Returned {@code Publisher} can error with various transport errors. If the underlying physical - * connection is closed by the peer, then the returned stream from here MUST emit an - * {@link ClosedChannelException}. + *

Returned {@code Publisher} can error with various transport errors. If the underlying + * physical connection is closed by the peer, then the returned stream from here MUST + * emit an {@link ClosedChannelException}. * - *

Multiple Subscriptions

+ *

Multiple Subscriptions * - * Returned {@code Publisher} is not required to support multiple concurrent subscriptions. + *

Returned {@code Publisher} is not required to support multiple concurrent subscriptions. * RSocket will never have multiple subscriptions to this source. Implementations MUST * emit an {@link IllegalStateException} for subsequent concurrent subscriptions, if they do not * support multiple concurrent subscriptions. @@ -78,6 +79,13 @@ default Mono sendOne(ByteBuf frame) { */ Flux receive(); + /** + * Returns the assigned {@link ByteBufAllocator}. + * + * @return the {@link ByteBufAllocator} + */ + ByteBufAllocator alloc(); + @Override default double availability() { return isDisposed() ? 0.0 : 1.0; diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java index 5468b4de8..773c93dc2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -33,7 +33,10 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono fireAndForget(Payload payload); + default Mono fireAndForget(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Fire-and-Forget not implemented.")); + } /** * Request-Response interaction model of {@code RSocket}. @@ -42,7 +45,10 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} containing at most a single {@code Payload} representing the * response. */ - Mono requestResponse(Payload payload); + default Mono requestResponse(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); + } /** * Request-Stream interaction model of {@code RSocket}. @@ -50,7 +56,10 @@ public interface RSocket extends Availability, Closeable { * @param payload Request payload. * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. */ - Flux requestStream(Payload payload); + default Flux requestStream(Payload payload) { + payload.release(); + return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); + } /** * Request-Channel interaction model of {@code RSocket}. @@ -58,7 +67,9 @@ public interface RSocket extends Availability, Closeable { * @param payloads Stream of request payloads. * @return Stream of response payloads. */ - Flux requestChannel(Publisher payloads); + default Flux requestChannel(Publisher payloads) { + return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); + } /** * Metadata-Push interaction model of {@code RSocket}. @@ -67,10 +78,26 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono metadataPush(Payload payload); + default Mono metadataPush(Payload payload) { + payload.release(); + return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); + } @Override default double availability() { return isDisposed() ? 0.0 : 1.0; } + + @Override + default void dispose() {} + + @Override + default boolean isDisposed() { + return false; + } + + @Override + default Mono onClose() { + return Mono.never(); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java new file mode 100644 index 000000000..b43b14bae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import reactor.util.annotation.Nullable; + +/** + * Exception that represents an RSocket protocol error. + * + * @see ERROR + * Frame (0x0B) + */ +public class RSocketErrorException extends RuntimeException { + + private static final long serialVersionUID = -1628781753426267554L; + + private static final int MIN_ERROR_CODE = 0x00000001; + + private static final int MAX_ERROR_CODE = 0xFFFFFFFE; + + private final int errorCode; + + /** + * Constructor with a protocol error code and a message. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + */ + public RSocketErrorException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Alternative to {@link #RSocketErrorException(int, String)} with a root cause. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + * @param cause a root cause for the error + */ + public RSocketErrorException(int errorCode, String message, @Nullable Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + if (errorCode > MAX_ERROR_CODE && errorCode < MIN_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000001-0xFFFFFFFE]", this); + } + } + + /** + * Return the RSocket error code + * represented by this exception + * + * @return the RSocket protocol error code + */ + public int errorCode() { + return errorCode; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + " (0x" + + Integer.toHexString(errorCode) + + "): " + + getMessage(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java index 44f64e550..e23bcceb2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java @@ -1,11 +1,11 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,57 +13,62 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package io.rsocket; -import static io.rsocket.internal.ClientSetup.DefaultClientSetup; -import static io.rsocket.internal.ClientSetup.ResumableClientSetup; - import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.InvalidSetupException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.ResumeFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.netty.buffer.Unpooled; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.ClientServerInputMultiplexer; -import io.rsocket.internal.ClientSetup; -import io.rsocket.internal.ServerSetup; -import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.lease.LeaseStats; import io.rsocket.lease.Leases; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.plugins.*; -import io.rsocket.resume.*; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; +import io.rsocket.resume.ClientResume; +import io.rsocket.resume.ResumableFramesStore; +import io.rsocket.resume.ResumeStrategy; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.util.ConnectionUtils; -import io.rsocket.util.EmptyPayload; -import io.rsocket.util.MultiSubscriberRSocket; import java.time.Duration; -import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +/** + * Main entry point to create RSocket clients or servers as follows: + * + *

    + *
  • {@link ClientRSocketFactory} to connect as a client. Use {@link #connect()} for a default + * instance. + *
  • {@link ServerRSocketFactory} to start a server. Use {@link #receive()} for a default + * instance. + *
+ * + * @deprecated please use {@link RSocketConnector} and {@link RSocketServer}. + */ +@Deprecated +public final class RSocketFactory { -/** Factory for creating RSocket clients and servers. */ -public class RSocketFactory { /** - * Creates a factory that establishes client connections to other RSockets. + * Create a {@code ClientRSocketFactory} to connect to a remote RSocket endpoint. Internally + * delegates to {@link RSocketConnector}. * - * @return a client factory + * @return the {@code ClientRSocketFactory} instance */ public static ClientRSocketFactory connect() { return new ClientRSocketFactory(); } /** - * Creates a factory that receives server connections from client RSockets. + * Create a {@code ServerRSocketFactory} to accept connections from RSocket clients. Internally + * delegates to {@link RSocketServer}. * - * @return a server factory. + * @return the {@code ClientRSocketFactory} instance */ public static ServerRSocketFactory receive() { return new ServerRSocketFactory(); @@ -92,52 +97,58 @@ default Start transport(ServerTransport transport) { } } + /** Factory to create and configure an RSocket client, and connect to a server. */ public static class ClientRSocketFactory implements ClientTransportAcceptor { - private static final String CLIENT_TAG = "client"; - - private SocketAcceptor acceptor = (setup, sendingSocket) -> Mono.just(new AbstractRSocket() {}); - - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); + private static final ClientResume CLIENT_RESUME = + new ClientResume(Duration.ofMinutes(2), Unpooled.EMPTY_BUFFER); - private Payload setupPayload = EmptyPayload.INSTANCE; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private final RSocketConnector connector; private Duration tickPeriod = Duration.ofSeconds(20); private Duration ackTimeout = Duration.ofSeconds(30); private int missedAcks = 3; - private String metadataMimeType = "application/binary"; - private String dataMimeType = "application/binary"; + private Resume resume; - private boolean resumeEnabled; - private boolean resumeCleanupStoreOnKeepAlive; - private Supplier resumeTokenSupplier = ResumeFrameFlyweight::generateResumeToken; - private Function resumeStoreFactory = - token -> new InMemoryResumableFramesStore(CLIENT_TAG, 100_000); - private Duration resumeSessionDuration = Duration.ofMinutes(2); - private Duration resumeStreamTimeout = Duration.ofSeconds(10); - private Supplier resumeStrategySupplier = - () -> - new ExponentialBackoffResumeStrategy(Duration.ofSeconds(1), Duration.ofSeconds(16), 2); - - private boolean multiSubscriberRequester = true; - private boolean leaseEnabled; - private Supplier> leasesSupplier = Leases::new; + public ClientRSocketFactory() { + this(RSocketConnector.create()); + } - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + public ClientRSocketFactory(RSocketConnector connector) { + this.connector = connector; + } + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ public ClientRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + connector.interceptors(registry -> registry.forConnection(interceptor)); return this; } + /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ @Deprecated public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { @@ -145,7 +156,7 @@ public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { } public ClientRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - plugins.addRequesterPlugin(interceptor); + connector.interceptors(registry -> registry.forRequester(interceptor)); return this; } @@ -156,309 +167,296 @@ public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { } public ClientRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - plugins.addResponderPlugin(interceptor); + connector.interceptors(registry -> registry.forResponder(interceptor)); return this; } public ClientRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - plugins.addSocketAcceptorPlugin(interceptor); + connector.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } /** - * Deprecated as Keep-Alive is not optional according to spec + * Deprecated without replacement as Keep-Alive is not optional according to spec * * @return this ClientRSocketFactory */ @Deprecated public ClientRSocketFactory keepAlive() { + connector.keepAlive(tickPeriod, ackTimeout.plus(tickPeriod.multipliedBy(missedAcks))); return this; } - public ClientRSocketFactory keepAlive( + public ClientTransportAcceptor keepAlive( Duration tickPeriod, Duration ackTimeout, int missedAcks) { this.tickPeriod = tickPeriod; this.ackTimeout = ackTimeout; this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) { this.tickPeriod = tickPeriod; + keepAlive(); return this; } public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) { this.ackTimeout = ackTimeout; + keepAlive(); return this; } public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) { this.missedAcks = missedAcks; + keepAlive(); return this; } public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) { - this.dataMimeType = dataMimeType; - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory dataMimeType(String dataMimeType) { - this.dataMimeType = dataMimeType; + connector.dataMimeType(dataMimeType); return this; } public ClientRSocketFactory metadataMimeType(String metadataMimeType) { - this.metadataMimeType = metadataMimeType; + connector.metadataMimeType(metadataMimeType); return this; } - public ClientRSocketFactory lease(Supplier> leasesSupplier) { - this.leaseEnabled = true; - this.leasesSupplier = Objects.requireNonNull(leasesSupplier); + public ClientRSocketFactory lease(Supplier> supplier) { + connector.lease(supplier); return this; } public ClientRSocketFactory lease() { - this.leaseEnabled = true; + connector.lease(Leases::new); return this; } + /** @deprecated without a replacement and no longer used. */ + @Deprecated public ClientRSocketFactory singleSubscriberRequester() { - this.multiSubscriberRequester = false; + return this; + } + + /** + * Enables a reconnectable, shared instance of {@code Mono} so every subscriber will + * observe the same RSocket instance up on connection establishment.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     * }
+ * + * Apart of the shared behavior, if the connection is lost, the same {@code Mono} + * instance will transparently re-establish the connection for subsequent subscribers.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  RSocket r1 = sharedRSocketMono.block();
+     *  RSocket r2 = sharedRSocketMono.block();
+     *
+     *  assert r1 == r2;
+     *
+     *  r1.dispose()
+     *
+     *  assert r2.isDisposed()
+     *
+     *  RSocket r3 = sharedRSocketMono.block();
+     *  RSocket r4 = sharedRSocketMono.block();
+     *
+     *
+     *  assert r1 != r3;
+     *  assert r4 == r3;
+     *
+     * }
+ * + * Note, having reconnect() enabled does not eliminate the need to accompany each + * individual request with the corresponding retry logic.
+ * For example: + * + *
{@code
+     * Mono sharedRSocketMono =
+     *   RSocketFactory
+     *                .connect()
+     *                .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+     *                .transport(transport)
+     *                .start();
+     *
+     *  sharedRSocket.flatMap(rSocket -> rSocket.requestResponse(...))
+     *               .retryWhen(ownRetry)
+     *               .subscribe()
+     *
+     * }
+ * + * @param retrySpec a retry factory applied for {@link Mono#retryWhen(Retry)} + * @return a shared instance of {@code Mono}. + */ + public ClientRSocketFactory reconnect(Retry retrySpec) { + connector.reconnect(retrySpec); return this; } public ClientRSocketFactory resume() { - this.resumeEnabled = true; + resume = resume != null ? resume : new Resume(); + connector.resume(resume); return this; } - public ClientRSocketFactory resumeToken(Supplier resumeTokenSupplier) { - this.resumeTokenSupplier = Objects.requireNonNull(resumeTokenSupplier); + public ClientRSocketFactory resumeToken(Supplier supplier) { + resume(); + resume.token(supplier); return this; } public ClientRSocketFactory resumeStore( - Function resumeStoreFactory) { - this.resumeStoreFactory = resumeStoreFactory; + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); return this; } public ClientRSocketFactory resumeSessionDuration(Duration sessionDuration) { - this.resumeSessionDuration = Objects.requireNonNull(sessionDuration); + resume(); + resume.sessionDuration(sessionDuration); return this; } - public ClientRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) { - this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout); + public ClientRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); return this; } - public ClientRSocketFactory resumeStrategy(Supplier resumeStrategy) { - this.resumeStrategySupplier = Objects.requireNonNull(resumeStrategy); + public ClientRSocketFactory resumeStrategy(Supplier strategy) { + resume(); + resume.retry( + Retry.from( + signals -> signals.flatMap(s -> strategy.get().apply(CLIENT_RESUME, s.failure())))); return this; } public ClientRSocketFactory resumeCleanupOnKeepAlive() { - resumeCleanupStoreOnKeepAlive = true; + resume(); + resume.cleanupStoreOnKeepAlive(); return this; } - @Override - public Start transport(Supplier transportClient) { - return new StartClient(transportClient); + public Start transport(Supplier transport) { + return () -> connector.connect(transport); } public ClientTransportAcceptor acceptor(Function acceptor) { return acceptor(() -> acceptor); } - public ClientTransportAcceptor acceptor(Supplier> acceptor) { - return acceptor((setup, sendingSocket) -> Mono.just(acceptor.get().apply(sendingSocket))); + public ClientTransportAcceptor acceptor(Supplier> acceptorSupplier) { + return acceptor( + (setup, sendingSocket) -> { + acceptorSupplier.get().apply(sendingSocket); + return Mono.empty(); + }); } public ClientTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = acceptor; - return StartClient::new; + connector.acceptor(acceptor); + return this; } public ClientRSocketFactory fragment(int mtu) { - this.mtu = mtu; + connector.fragment(mtu); return this; } + /** + * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In + * order to observe errors, it is recommended to add error handler using {@code doOnError} + * on the specific logical stream. In order to observe connection, or RSocket terminal + * errors, it is recommended to hook on {@link Closeable#onClose()} handler. + */ public ClientRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; return this; } public ClientRSocketFactory setupPayload(Payload payload) { - this.setupPayload = payload; + connector.setupPayload(payload); return this; } public ClientRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; - return this; - } - - private class StartClient implements Start { - private final Supplier transportClient; - - StartClient(Supplier transportClient) { - this.transportClient = transportClient; - } - - @Override - public Mono start() { - return newConnection() - .flatMap( - connection -> { - ClientSetup clientSetup = clientSetup(connection); - ByteBuf resumeToken = clientSetup.resumeToken(); - KeepAliveHandler keepAliveHandler = clientSetup.keepAliveHandler(); - DuplexConnection wrappedConnection = clientSetup.connection(); - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(wrappedConnection, plugins, true); - - boolean isLeaseEnabled = leaseEnabled; - Leases leases = leasesSupplier.get(); - RequesterLeaseHandler requesterLeaseHandler = - isLeaseEnabled - ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - allocator, - multiplexer.asClientConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.clientSupplier(), - keepAliveTickPeriod(), - keepAliveTimeout(), - keepAliveHandler, - requesterLeaseHandler); - - if (multiSubscriberRequester) { - rSocketRequester = new MultiSubscriberRSocket(rSocketRequester); - } - - RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); - - ByteBuf setupFrame = - SetupFrameFlyweight.encode( - allocator, - isLeaseEnabled, - keepAliveTickPeriod(), - keepAliveTimeout(), - resumeToken, - metadataMimeType, - dataMimeType, - setupPayload); - - ConnectionSetupPayload setup = ConnectionSetupPayload.create(setupFrame); - - return plugins - .applySocketAcceptorInterceptor(acceptor) - .accept(setup, wrappedRSocketRequester) - .flatMap( - rSocketHandler -> { - RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - isLeaseEnabled - ? new ResponderLeaseHandler.Impl<>( - CLIENT_TAG, - allocator, - leases.sender(), - errorConsumer, - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - allocator, - multiplexer.asServerConnection(), - wrappedRSocketHandler, - payloadDecoder, - errorConsumer, - responderLeaseHandler); - - return wrappedConnection - .sendOne(setupFrame) - .thenReturn(wrappedRSocketRequester); - }); - }); - } - - private int keepAliveTickPeriod() { - return (int) tickPeriod.toMillis(); - } - - private int keepAliveTimeout() { - return (int) (ackTimeout.toMillis() + tickPeriod.toMillis() * missedAcks); - } - - private ClientSetup clientSetup(DuplexConnection startConnection) { - if (resumeEnabled) { - ByteBuf resumeToken = resumeTokenSupplier.get(); - return new ResumableClientSetup( - allocator, - startConnection, - newConnection(), - resumeToken, - resumeStoreFactory.apply(resumeToken), - resumeSessionDuration, - resumeStreamTimeout, - resumeStrategySupplier, - resumeCleanupStoreOnKeepAlive); - } else { - return new DefaultClientSetup(startConnection); - } - } - - private Mono newConnection() { - return transportClient.get().connect(mtu); - } + connector.payloadDecoder(payloadDecoder); + return this; } } - public static class ServerRSocketFactory { - private static final String SERVER_TAG = "server"; - - private SocketAcceptor acceptor; - private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); + /** Factory to create, configure, and start an RSocket server. */ + public static class ServerRSocketFactory implements ServerTransportAcceptor { + private final RSocketServer server; - private boolean resumeSupported; - private Duration resumeSessionDuration = Duration.ofSeconds(120); - private Duration resumeStreamTimeout = Duration.ofSeconds(10); - private Function resumeStoreFactory = - token -> new InMemoryResumableFramesStore(SERVER_TAG, 100_000); + private Resume resume; - private boolean multiSubscriberRequester = true; - private boolean leaseEnabled; - private Supplier> leasesSupplier = Leases::new; - - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - private boolean resumeCleanupStoreOnKeepAlive; + public ServerRSocketFactory() { + this(RSocketServer.create()); + } - private ServerRSocketFactory() {} + public ServerRSocketFactory(RSocketServer server) { + this.server = server; + } + /** + * @deprecated this method is deprecated and deliberately has no effect anymore. Right now, in + * order configure the custom {@link ByteBufAllocator} it is recommended to use the + * following setup for Reactor Netty based transport:
+ * 1. For Client:
+ *
{@code
+     * TcpClient.create()
+     *          ...
+     *          .bootstrap(bootstrap -> bootstrap.option(ChannelOption.ALLOCATOR, clientAllocator))
+     * }
+ *
+ * 2. For server:
+ *
{@code
+     * TcpServer.create()
+     *          ...
+     *          .bootstrap(serverBootstrap -> serverBootstrap.childOption(ChannelOption.ALLOCATOR, serverAllocator))
+     * }
+ * Or in case of local transport, to use corresponding factory method {@code + * LocalClientTransport.creat(String, ByteBufAllocator)} + * @param allocator instance of {@link ByteBufAllocator} + * @return this factory instance + */ + @Deprecated public ServerRSocketFactory byteBufAllocator(ByteBufAllocator allocator) { - Objects.requireNonNull(allocator); - this.allocator = allocator; return this; } public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); + server.interceptors(registry -> registry.forConnection(interceptor)); return this; } /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ @@ -468,7 +466,7 @@ public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { } public ServerRSocketFactory addRequesterPlugin(RSocketInterceptor interceptor) { - plugins.addRequesterPlugin(interceptor); + server.interceptors(registry -> registry.forRequester(interceptor)); return this; } @@ -479,265 +477,95 @@ public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { } public ServerRSocketFactory addResponderPlugin(RSocketInterceptor interceptor) { - plugins.addResponderPlugin(interceptor); + server.interceptors(registry -> registry.forResponder(interceptor)); return this; } public ServerRSocketFactory addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - plugins.addSocketAcceptorPlugin(interceptor); + server.interceptors(registry -> registry.forSocketAcceptor(interceptor)); return this; } public ServerTransportAcceptor acceptor(SocketAcceptor acceptor) { - this.acceptor = acceptor; - return new ServerStart<>(); + server.acceptor(acceptor); + return this; } public ServerRSocketFactory frameDecoder(PayloadDecoder payloadDecoder) { - this.payloadDecoder = payloadDecoder; + server.payloadDecoder(payloadDecoder); return this; } public ServerRSocketFactory fragment(int mtu) { - this.mtu = mtu; + server.fragment(mtu); return this; } + /** + * @deprecated this handler is deliberately no-ops and is deprecated with no replacement. In + * order to observe errors, it is recommended to add error handler using {@code doOnError} + * on the specific logical stream. In order to observe connection, or RSocket terminal + * errors, it is recommended to hook on {@link Closeable#onClose()} handler. + */ public ServerRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; return this; } - public ServerRSocketFactory lease(Supplier> leasesSupplier) { - this.leaseEnabled = true; - this.leasesSupplier = Objects.requireNonNull(leasesSupplier); + public ServerRSocketFactory lease(Supplier> supplier) { + server.lease(supplier); return this; } public ServerRSocketFactory lease() { - this.leaseEnabled = true; + server.lease(Leases::new); return this; } + /** @deprecated without a replacement and no longer used. */ + @Deprecated public ServerRSocketFactory singleSubscriberRequester() { - this.multiSubscriberRequester = false; return this; } public ServerRSocketFactory resume() { - this.resumeSupported = true; + resume = resume != null ? resume : new Resume(); + server.resume(resume); return this; } public ServerRSocketFactory resumeStore( - Function resumeStoreFactory) { - this.resumeStoreFactory = resumeStoreFactory; + Function storeFactory) { + resume(); + resume.storeFactory(storeFactory); return this; } public ServerRSocketFactory resumeSessionDuration(Duration sessionDuration) { - this.resumeSessionDuration = Objects.requireNonNull(sessionDuration); + resume(); + resume.sessionDuration(sessionDuration); return this; } - public ServerRSocketFactory resumeStreamTimeout(Duration resumeStreamTimeout) { - this.resumeStreamTimeout = Objects.requireNonNull(resumeStreamTimeout); + public ServerRSocketFactory resumeStreamTimeout(Duration streamTimeout) { + resume(); + resume.streamTimeout(streamTimeout); return this; } public ServerRSocketFactory resumeCleanupOnKeepAlive() { - resumeCleanupStoreOnKeepAlive = true; - return this; - } - - private class ServerStart implements Start, ServerTransportAcceptor { - private Supplier> transportServer; - - @Override - public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { - return new ServerTransport.ConnectionAcceptor() { - private final ServerSetup serverSetup = serverSetup(); - - @Override - public Mono apply(DuplexConnection connection) { - return acceptor(serverSetup, connection); - } - }; - } - - @Override - @SuppressWarnings("unchecked") - public Start transport(Supplier> transport) { - this.transportServer = (Supplier) transport; - return (Start) this::start; - } - - private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) { - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins, false); - - return multiplexer - .asSetupConnection() - .receive() - .next() - .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer)); - } - - private Mono acceptResume( - ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { - return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); - } - - private Mono accept( - ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) { - switch (FrameHeaderFlyweight.frameType(startFrame)) { - case SETUP: - return acceptSetup(serverSetup, startFrame, multiplexer); - case RESUME: - return acceptResume(serverSetup, startFrame, multiplexer); - default: - return acceptUnknown(startFrame, multiplexer); - } - } - - private Mono acceptSetup( - ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) { - - if (!SetupFrameFlyweight.isSupportedVersion(setupFrame)) { - return sendError( - multiplexer, - new InvalidSetupException( - "Unsupported version: " - + SetupFrameFlyweight.humanReadableVersion(setupFrame))) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); - } - - boolean isLeaseEnabled = leaseEnabled; - - if (SetupFrameFlyweight.honorLease(setupFrame) && !isLeaseEnabled) { - return sendError(multiplexer, new InvalidSetupException("lease is not supported")) - .doFinally( - signalType -> { - setupFrame.release(); - multiplexer.dispose(); - }); - } - - return serverSetup.acceptRSocketSetup( - setupFrame, - multiplexer, - (keepAliveHandler, wrappedMultiplexer) -> { - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); - - Leases leases = leasesSupplier.get(); - RequesterLeaseHandler requesterLeaseHandler = - isLeaseEnabled - ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) - : RequesterLeaseHandler.None; - - RSocket rSocketRequester = - new RSocketRequester( - allocator, - wrappedMultiplexer.asServerConnection(), - payloadDecoder, - errorConsumer, - StreamIdSupplier.serverSupplier(), - setupPayload.keepAliveInterval(), - setupPayload.keepAliveMaxLifetime(), - keepAliveHandler, - requesterLeaseHandler); - - if (multiSubscriberRequester) { - rSocketRequester = new MultiSubscriberRSocket(rSocketRequester); - } - RSocket wrappedRSocketRequester = plugins.applyRequester(rSocketRequester); - - return plugins - .applySocketAcceptorInterceptor(acceptor) - .accept(setupPayload, wrappedRSocketRequester) - .onErrorResume( - err -> sendError(multiplexer, rejectedSetupError(err)).then(Mono.error(err))) - .doOnNext( - rSocketHandler -> { - RSocket wrappedRSocketHandler = plugins.applyResponder(rSocketHandler); - - ResponderLeaseHandler responderLeaseHandler = - isLeaseEnabled - ? new ResponderLeaseHandler.Impl<>( - SERVER_TAG, - allocator, - leases.sender(), - errorConsumer, - leases.stats()) - : ResponderLeaseHandler.None; - - RSocket rSocketResponder = - new RSocketResponder( - allocator, - wrappedMultiplexer.asClientConnection(), - wrappedRSocketHandler, - payloadDecoder, - errorConsumer, - responderLeaseHandler); - }) - .doFinally(signalType -> setupPayload.release()) - .then(); - }); - } - - @Override - public Mono start() { - return Mono.defer( - new Supplier>() { - - ServerSetup serverSetup = serverSetup(); - - @Override - public Mono get() { - return transportServer - .get() - .start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu) - .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); - } - }); - } - - private ServerSetup serverSetup() { - return resumeSupported - ? new ServerSetup.ResumableServerSetup( - allocator, - new SessionManager(), - resumeSessionDuration, - resumeStreamTimeout, - resumeStoreFactory, - resumeCleanupStoreOnKeepAlive) - : new ServerSetup.DefaultServerSetup(allocator); - } - - private Mono acceptUnknown(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { - return sendError( - multiplexer, - new InvalidSetupException( - "invalid setup frame: " + FrameHeaderFlyweight.frameType(frame))) - .doFinally( - signalType -> { - frame.release(); - multiplexer.dispose(); - }); - } - - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } - - private Exception rejectedSetupError(Throwable err) { - String msg = err.getMessage(); - return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); - } + resume(); + resume.cleanupStoreOnKeepAlive(); + return this; + } + + @Override + public ServerTransport.ConnectionAcceptor toConnectionAcceptor() { + return server.asConnectionAcceptor(); + } + + @Override + public Start transport(Supplier> transport) { + return () -> server.bind(transport.get()); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java deleted file mode 100644 index cb17ff539..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketRequester.java +++ /dev/null @@ -1,616 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; -import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.util.ReferenceCountUtil; -import io.netty.util.collection.IntObjectMap; -import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.exceptions.Exceptions; -import io.rsocket.frame.*; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.RateLimitableRequestPublisher; -import io.rsocket.internal.SynchronizedIntObjectHashMap; -import io.rsocket.internal.UnboundedProcessor; -import io.rsocket.internal.UnicastMonoProcessor; -import io.rsocket.keepalive.KeepAliveFramesAcceptor; -import io.rsocket.keepalive.KeepAliveHandler; -import io.rsocket.keepalive.KeepAliveSupport; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.util.OnceConsumer; -import java.nio.channels.ClosedChannelException; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; -import java.util.function.LongConsumer; -import java.util.function.Supplier; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; -import org.reactivestreams.Processor; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.*; -import reactor.util.concurrent.Queues; - -/** - * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer - */ -class RSocketRequester implements RSocket { - private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = - AtomicReferenceFieldUpdater.newUpdater( - RSocketRequester.class, Throwable.class, "terminationError"); - private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); - - static { - CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); - } - - private final DuplexConnection connection; - private final PayloadDecoder payloadDecoder; - private final Consumer errorConsumer; - private final StreamIdSupplier streamIdSupplier; - private final IntObjectMap senders; - private final IntObjectMap> receivers; - private final UnboundedProcessor sendProcessor; - private final RequesterLeaseHandler leaseHandler; - private final ByteBufAllocator allocator; - private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; - private volatile Throwable terminationError; - - RSocketRequester( - ByteBufAllocator allocator, - DuplexConnection connection, - PayloadDecoder payloadDecoder, - Consumer errorConsumer, - StreamIdSupplier streamIdSupplier, - int keepAliveTickPeriod, - int keepAliveAckTimeout, - @Nullable KeepAliveHandler keepAliveHandler, - RequesterLeaseHandler leaseHandler) { - this.allocator = allocator; - this.connection = connection; - this.payloadDecoder = payloadDecoder; - this.errorConsumer = errorConsumer; - this.streamIdSupplier = streamIdSupplier; - this.leaseHandler = leaseHandler; - this.senders = new SynchronizedIntObjectHashMap<>(); - this.receivers = new SynchronizedIntObjectHashMap<>(); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - this.sendProcessor = new UnboundedProcessor<>(); - - connection - .onClose() - .doFinally(signalType -> tryTerminateOnConnectionClose()) - .subscribe(null, errorConsumer); - connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); - - connection.receive().subscribe(this::handleIncomingFrames, errorConsumer); - - if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { - KeepAliveSupport keepAliveSupport = - new ClientKeepAliveSupport(allocator, keepAliveTickPeriod, keepAliveAckTimeout); - this.keepAliveFramesAcceptor = - keepAliveHandler.start( - keepAliveSupport, sendProcessor::onNext, this::tryTerminateOnKeepAlive); - } else { - keepAliveFramesAcceptor = null; - } - } - - @Override - public Mono fireAndForget(Payload payload) { - return handleFireAndForget(payload); - } - - @Override - public Mono requestResponse(Payload payload) { - return handleRequestResponse(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return handleRequestStream(payload); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return handleChannel(Flux.from(payloads)); - } - - @Override - public Mono metadataPush(Payload payload) { - return handleMetadataPush(payload); - } - - @Override - public double availability() { - return Math.min(connection.availability(), leaseHandler.availability()); - } - - @Override - public void dispose() { - connection.dispose(); - } - - @Override - public boolean isDisposed() { - return connection.isDisposed(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - private Mono handleFireAndForget(Payload payload) { - Throwable err = checkAvailable(); - if (err != null) { - payload.release(); - return Mono.error(err); - } - - final int streamId = streamIdSupplier.nextStreamId(receivers); - - return emptyUnicastMono() - .doOnSubscribe( - new OnceConsumer() { - @Override - public void acceptOnce(@Nonnull Subscription subscription) { - ByteBuf requestFrame = - RequestFireAndForgetFrameFlyweight.encode( - allocator, - streamId, - false, - payload.hasMetadata() ? payload.sliceMetadata().retain() : null, - payload.sliceData().retain()); - payload.release(); - - sendProcessor.onNext(requestFrame); - } - }); - } - - private Mono handleRequestResponse(final Payload payload) { - Throwable err = checkAvailable(); - if (err != null) { - payload.release(); - return Mono.error(err); - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - final UnboundedProcessor sendProcessor = this.sendProcessor; - - UnicastMonoProcessor receiver = UnicastMonoProcessor.create(); - receivers.put(streamId, receiver); - - return receiver - .doOnSubscribe( - new OnceConsumer() { - @Override - public void acceptOnce(@Nonnull Subscription subscription) { - final ByteBuf requestFrame = - RequestResponseFrameFlyweight.encode( - allocator, - streamId, - false, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - payload.release(); - - sendProcessor.onNext(requestFrame); - } - }) - .doOnError(t -> sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t))) - .doFinally( - s -> { - if (s == SignalType.CANCEL) { - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - } - removeStreamReceiver(streamId); - }); - } - - private Flux handleRequestStream(final Payload payload) { - Throwable err = checkAvailable(); - if (err != null) { - payload.release(); - return Flux.error(err); - } - - int streamId = streamIdSupplier.nextStreamId(receivers); - - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - - receivers.put(streamId, receiver); - - return receiver - .doOnRequest( - new LongConsumer() { - - boolean firstRequest = true; - - @Override - public void accept(long n) { - if (firstRequest && !receiver.isDisposed()) { - firstRequest = false; - sendProcessor.onNext( - RequestStreamFrameFlyweight.encode( - allocator, - streamId, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain())); - payload.release(); - } else if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); - } - } - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - } - }) - .doFinally(s -> removeStreamReceiver(streamId)); - } - - private Flux handleChannel(Flux request) { - Throwable err = checkAvailable(); - if (err != null) { - return Flux.error(err); - } - - final UnboundedProcessor sendProcessor = this.sendProcessor; - final UnicastProcessor receiver = UnicastProcessor.create(); - final int streamId = streamIdSupplier.nextStreamId(receivers); - - return receiver - .doOnRequest( - new LongConsumer() { - - boolean firstRequest = true; - - @Override - public void accept(long n) { - if (firstRequest) { - firstRequest = false; - request - .transform( - f -> { - RateLimitableRequestPublisher wrapped = - RateLimitableRequestPublisher.wrap(f, Queues.SMALL_BUFFER_SIZE); - // Need to set this to one for first the frame - wrapped.request(1); - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - - return wrapped; - }) - .subscribe( - new BaseSubscriber() { - - boolean firstPayload = true; - - @Override - protected void hookOnNext(Payload payload) { - final ByteBuf frame; - - if (firstPayload) { - firstPayload = false; - frame = - RequestChannelFrameFlyweight.encode( - allocator, - streamId, - false, - false, - n, - payload.sliceMetadata().retain(), - payload.sliceData().retain()); - } else { - frame = - PayloadFrameFlyweight.encode( - allocator, streamId, false, false, true, payload); - } - - sendProcessor.onNext(frame); - payload.release(); - } - - @Override - protected void hookOnComplete() { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext( - PayloadFrameFlyweight.encodeComplete(allocator, streamId)); - } - if (firstPayload) { - receiver.onComplete(); - } - } - - @Override - protected void hookOnError(Throwable t) { - errorConsumer.accept(t); - receiver.dispose(); - } - }); - } else { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n)); - } - } - } - }) - .doOnError( - t -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) && !receiver.isDisposed()) { - sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId)); - } - }) - .doFinally(s -> removeStreamReceiverAndSender(streamId)); - } - - private Mono handleMetadataPush(Payload payload) { - Throwable err = this.terminationError; - if (err != null) { - payload.release(); - return Mono.error(err); - } - - return emptyUnicastMono() - .doOnSubscribe( - new OnceConsumer() { - @Override - public void acceptOnce(@Nonnull Subscription subscription) { - ByteBuf metadataPushFrame = - MetadataPushFrameFlyweight.encode(allocator, payload.sliceMetadata().retain()); - payload.release(); - - sendProcessor.onNext(metadataPushFrame); - } - }); - } - - private static UnicastMonoProcessor emptyUnicastMono() { - UnicastMonoProcessor result = UnicastMonoProcessor.create(); - result.onComplete(); - return result; - } - - private Throwable checkAvailable() { - Throwable err = this.terminationError; - if (err != null) { - return err; - } - RequesterLeaseHandler lh = leaseHandler; - if (!lh.useLease()) { - return lh.leaseError(); - } - return null; - } - - private boolean contains(int streamId) { - return receivers.containsKey(streamId); - } - - private void handleIncomingFrames(ByteBuf frame) { - try { - int streamId = FrameHeaderFlyweight.streamId(frame); - FrameType type = FrameHeaderFlyweight.frameType(frame); - if (streamId == 0) { - handleStreamZero(type, frame); - } else { - handleFrame(streamId, type, frame); - } - frame.release(); - } catch (Throwable t) { - ReferenceCountUtil.safeRelease(frame); - throw reactor.core.Exceptions.propagate(t); - } - } - - private void handleStreamZero(FrameType type, ByteBuf frame) { - switch (type) { - case ERROR: - tryTerminateOnZeroError(frame); - break; - case LEASE: - leaseHandler.receive(frame); - break; - case KEEPALIVE: - if (keepAliveFramesAcceptor != null) { - keepAliveFramesAcceptor.receive(frame); - } - break; - default: - // Ignore unknown frames. Throwing an error will close the socket. - errorConsumer.accept( - new IllegalStateException( - "Client received supported frame on stream 0: " + frame.toString())); - } - } - - private void handleFrame(int streamId, FrameType type, ByteBuf frame) { - Subscriber receiver = receivers.get(streamId); - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - } else { - switch (type) { - case ERROR: - receiver.onError(Exceptions.from(frame)); - receivers.remove(streamId); - break; - case NEXT_COMPLETE: - receiver.onNext(payloadDecoder.apply(frame)); - receiver.onComplete(); - break; - case CANCEL: - { - RateLimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - break; - } - case NEXT: - receiver.onNext(payloadDecoder.apply(frame)); - break; - case REQUEST_N: - { - RateLimitableRequestPublisher sender = senders.get(streamId); - if (sender != null) { - int n = RequestNFrameFlyweight.requestN(frame); - sender.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); - } - break; - } - case COMPLETE: - receiver.onComplete(); - receivers.remove(streamId); - break; - default: - throw new IllegalStateException( - "Client received supported frame on stream " + streamId + ": " + frame.toString()); - } - } - } - - private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { - if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { - if (type == FrameType.ERROR) { - // message for stream that has never existed, we have a problem with - // the overall connection and must tear down - String errorMessage = ErrorFrameFlyweight.dataUtf8(frame); - - throw new IllegalStateException( - "Client received error for non-existent stream: " - + streamId - + " Message: " - + errorMessage); - } else { - throw new IllegalStateException( - "Client received message for non-existent stream: " - + streamId - + ", frame type: " - + type); - } - } - // receiving a frame after a given stream has been cancelled/completed, - // so ignore (cancellation is async so there is a race condition) - } - - private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { - tryTerminate( - () -> - new ConnectionErrorException( - String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); - } - - private void tryTerminateOnConnectionClose() { - tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); - } - - private void tryTerminateOnZeroError(ByteBuf errorFrame) { - tryTerminate(() -> Exceptions.from(errorFrame)); - } - - private void tryTerminate(Supplier errorSupplier) { - if (terminationError == null) { - Exception e = errorSupplier.get(); - if (TERMINATION_ERROR.compareAndSet(this, null, e)) { - terminate(e); - } - } - } - - private void terminate(Exception e) { - connection.dispose(); - leaseHandler.dispose(); - - synchronized (receivers) { - receivers - .values() - .forEach( - receiver -> { - try { - receiver.onError(e); - } catch (Throwable t) { - errorConsumer.accept(t); - } - }); - } - synchronized (senders) { - senders - .values() - .forEach( - sender -> { - try { - sender.cancel(); - } catch (Throwable t) { - errorConsumer.accept(t); - } - }); - } - senders.clear(); - receivers.clear(); - sendProcessor.dispose(); - errorConsumer.accept(e); - } - - private void removeStreamReceiver(int streamId) { - /*on termination receivers are explicitly cleared to avoid removing from map while iterating over one - of its views*/ - if (terminationError == null) { - receivers.remove(streamId); - } - } - - private void removeStreamReceiverAndSender(int streamId) { - /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one - of its views*/ - if (terminationError == null) { - receivers.remove(streamId); - RateLimitableRequestPublisher sender = senders.remove(streamId); - if (sender != null) { - sender.cancel(); - } - } - } - - private void handleSendProcessorError(Throwable t) { - connection.dispose(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java index f98901472..22697f130 100644 --- a/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/ResponderRSocket.java @@ -1,12 +1,17 @@ package io.rsocket; +import java.util.function.BiFunction; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; /** * Extends the {@link RSocket} that allows an implementer to peek at the first request payload of a * channel. + * + * @deprecated as of 1.0 RC7 in favor of using {@link RSocket#requestChannel(Publisher)} with {@link + * Flux#switchOnFirst(BiFunction)} */ +@Deprecated public interface ResponderRSocket extends RSocket { /** * Implement this method to peak at the first payload of the incoming request stream without diff --git a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java index 85c731eea..a42626e78 100644 --- a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ package io.rsocket; import io.rsocket.exceptions.SetupException; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** @@ -38,4 +41,53 @@ public interface SocketAcceptor { * @throws SetupException If the acceptor needs to reject the setup of this socket. */ Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket); + + /** Create a {@code SocketAcceptor} that handles requests with the given {@code RSocket}. */ + static SocketAcceptor with(RSocket rsocket) { + return (setup, sendingRSocket) -> Mono.just(rsocket); + } + + /** Create a {@code SocketAcceptor} for fire-and-forget interactions with the given handler. */ + static SocketAcceptor forFireAndForget(Function> handler) { + return with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-response interactions with the given handler. */ + static SocketAcceptor forRequestResponse(Function> handler) { + return with( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-stream interactions with the given handler. */ + static SocketAcceptor forRequestStream(Function> handler) { + return with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-channel interactions with the given handler. */ + static SocketAcceptor forRequestChannel(Function, Flux> handler) { + return with( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return handler.apply(payloads); + } + }); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java deleted file mode 100644 index fbac4b1a0..000000000 --- a/rsocket-core/src/main/java/io/rsocket/buffer/AbstractTupleByteBuf.java +++ /dev/null @@ -1,607 +0,0 @@ -package io.rsocket.buffer; - -import io.netty.buffer.AbstractReferenceCountedByteBuf; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.util.internal.SystemPropertyUtil; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.FileChannel; -import java.nio.channels.ScatteringByteChannel; -import java.nio.charset.Charset; - -abstract class AbstractTupleByteBuf extends AbstractReferenceCountedByteBuf { - static final int DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT = - SystemPropertyUtil.getInt("io.netty.allocator.directMemoryCacheAlignment", 0); - static final ByteBuffer EMPTY_NIO_BUFFER = Unpooled.EMPTY_BUFFER.nioBuffer(); - static final int NOT_ENOUGH_BYTES_AT_MAX_CAPACITY_CODE = 3; - - final ByteBufAllocator allocator; - final int capacity; - - AbstractTupleByteBuf(ByteBufAllocator allocator, int capacity) { - super(Integer.MAX_VALUE); - - this.capacity = capacity; - this.allocator = allocator; - super.writerIndex(capacity); - } - - abstract long calculateRelativeIndex(int index); - - abstract ByteBuf getPart(int index); - - @Override - public ByteBuffer nioBuffer(int index, int length) { - checkIndex(index, length); - - ByteBuffer[] buffers = nioBuffers(index, length); - - if (buffers.length == 1) { - return buffers[0].duplicate(); - } - - ByteBuffer merged = - BufferUtil.allocateDirectAligned(length, DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT) - .order(order()); - for (ByteBuffer buf : buffers) { - merged.put(buf); - } - - merged.flip(); - return merged; - } - - @Override - public ByteBuffer[] nioBuffers(int index, int length) { - checkIndex(index, length); - if (length == 0) { - return new ByteBuffer[] {EMPTY_NIO_BUFFER}; - } - return _nioBuffers(index, length); - } - - protected abstract ByteBuffer[] _nioBuffers(int index, int length); - - @Override - protected byte _getByte(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - return byteBuf.getByte(calculatedIndex); - } - - @Override - protected short _getShort(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - final int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Short.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getShort(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff); - } else { - return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8); - } - } - - @Override - protected short _getShortLE(int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - final int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Short.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getShortLE(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8); - } else { - return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff); - } - } - - @Override - protected int _getUnsignedMedium(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + 3 <= byteBuf.writerIndex()) { - return byteBuf.getUnsignedMedium(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (_getShort(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; - } else { - return _getShort(index) & 0xFFFF | (_getByte(index + 2) & 0xFF) << 16; - } - } - - @Override - protected int _getUnsignedMediumLE(int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + 3 <= byteBuf.writerIndex()) { - return byteBuf.getUnsignedMediumLE(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return _getShortLE(index) & 0xffff | (_getByte(index + 2) & 0xff) << 16; - } else { - return (_getShortLE(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; - } - } - - @Override - protected int _getInt(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Integer.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getInt(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (_getShort(index) & 0xffff) << 16 | _getShort(index + 2) & 0xffff; - } else { - return _getShort(index) & 0xFFFF | (_getShort(index + 2) & 0xFFFF) << 16; - } - } - - @Override - protected int _getIntLE(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Integer.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getIntLE(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return _getShortLE(index) & 0xffff | (_getShortLE(index + 2) & 0xffff) << 16; - } else { - return (_getShortLE(index) & 0xffff) << 16 | _getShortLE(index + 2) & 0xffff; - } - } - - @Override - protected long _getLong(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Long.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getLong(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (_getInt(index) & 0xffffffffL) << 32 | _getInt(index + 4) & 0xffffffffL; - } else { - return _getInt(index) & 0xFFFFFFFFL | (_getInt(index + 4) & 0xFFFFFFFFL) << 32; - } - } - - @Override - protected long _getLongLE(final int index) { - long ri = calculateRelativeIndex(index); - ByteBuf byteBuf = getPart(index); - - int calculatedIndex = (int) (ri & Integer.MAX_VALUE); - - if (calculatedIndex + Long.BYTES <= byteBuf.writerIndex()) { - return byteBuf.getLongLE(calculatedIndex); - } else if (order() == ByteOrder.BIG_ENDIAN) { - return (_getInt(index) & 0xffffffffL) << 32 | _getInt(index + 4) & 0xffffffffL; - } else { - return _getInt(index) & 0xFFFFFFFFL | (_getInt(index + 4) & 0xFFFFFFFFL) << 32; - } - } - - @Override - public ByteBufAllocator alloc() { - return allocator; - } - - @Override - public int capacity() { - return capacity; - } - - @Override - public ByteBuf capacity(int newCapacity) { - throw new UnsupportedOperationException(); - } - - @Override - public int maxCapacity() { - return capacity; - } - - @Override - public ByteOrder order() { - return ByteOrder.LITTLE_ENDIAN; - } - - @Override - public ByteBuf order(ByteOrder endianness) { - return this; - } - - @Override - public ByteBuf unwrap() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isReadOnly() { - return true; - } - - @Override - public ByteBuf asReadOnly() { - return this; - } - - @Override - public boolean isWritable() { - return false; - } - - @Override - public boolean isWritable(int size) { - return false; - } - - @Override - public ByteBuf writerIndex(int writerIndex) { - return this; - } - - @Override - public final int writerIndex() { - return capacity; - } - - @Override - public ByteBuf setIndex(int readerIndex, int writerIndex) { - return this; - } - - @Override - public ByteBuf clear() { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf discardReadBytes() { - return this; - } - - @Override - public ByteBuf discardSomeReadBytes() { - return this; - } - - @Override - public ByteBuf ensureWritable(int minWritableBytes) { - return this; - } - - @Override - public int ensureWritable(int minWritableBytes, boolean force) { - return NOT_ENOUGH_BYTES_AT_MAX_CAPACITY_CODE; - } - - @Override - public ByteBuf setFloatLE(int index, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setDoubleLE(int index, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBoolean(int index, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setByte(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setShort(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setShortLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setMedium(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setMediumLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setInt(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setIntLE(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setLong(int index, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setLongLE(int index, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setChar(int index, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setFloat(int index, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setDouble(int index, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, byte[] src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setBytes(int index, ByteBuffer src) { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, InputStream in, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, ScatteringByteChannel in, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int setBytes(int index, FileChannel in, long position, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int setCharSequence(int index, CharSequence sequence, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf setZero(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBoolean(boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeByte(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeShort(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeShortLE(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeMedium(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeMediumLE(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeInt(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeIntLE(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeLong(long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeLongLE(long value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeChar(int value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeFloat(float value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeDouble(double value) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(byte[] src) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeBytes(ByteBuffer src) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(InputStream in, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(ScatteringByteChannel in, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeBytes(FileChannel in, long position, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuf writeZero(int length) { - throw new UnsupportedOperationException(); - } - - @Override - public int writeCharSequence(CharSequence sequence, Charset charset) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteBuffer internalNioBuffer(int index, int length) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean hasArray() { - return false; - } - - @Override - public byte[] array() { - throw new UnsupportedOperationException(); - } - - @Override - public int arrayOffset() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean hasMemoryAddress() { - return false; - } - - @Override - public long memoryAddress() { - throw new UnsupportedOperationException(); - } - - @Override - protected void _setByte(int index, int value) {} - - @Override - protected void _setShort(int index, int value) {} - - @Override - protected void _setShortLE(int index, int value) {} - - @Override - protected void _setMedium(int index, int value) {} - - @Override - protected void _setMediumLE(int index, int value) {} - - @Override - protected void _setInt(int index, int value) {} - - @Override - protected void _setIntLE(int index, int value) {} - - @Override - protected void _setLong(int index, long value) {} - - @Override - protected void _setLongLE(int index, long value) {} -} diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/BufferUtil.java b/rsocket-core/src/main/java/io/rsocket/buffer/BufferUtil.java deleted file mode 100644 index 476583ab3..000000000 --- a/rsocket-core/src/main/java/io/rsocket/buffer/BufferUtil.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.rsocket.buffer; - -import java.lang.reflect.Field; -import java.nio.Buffer; -import java.nio.ByteBuffer; -import java.security.AccessController; -import java.security.PrivilegedExceptionAction; -import sun.misc.Unsafe; - -abstract class BufferUtil { - - private static final Unsafe UNSAFE; - - static { - Unsafe unsafe; - try { - final PrivilegedExceptionAction action = - () -> { - final Field f = Unsafe.class.getDeclaredField("theUnsafe"); - f.setAccessible(true); - - return (Unsafe) f.get(null); - }; - - unsafe = AccessController.doPrivileged(action); - } catch (final Exception ex) { - throw new RuntimeException(ex); - } - - UNSAFE = unsafe; - } - - private static final long BYTE_BUFFER_ADDRESS_FIELD_OFFSET; - - static { - try { - BYTE_BUFFER_ADDRESS_FIELD_OFFSET = - UNSAFE.objectFieldOffset(Buffer.class.getDeclaredField("address")); - } catch (final Exception ex) { - throw new RuntimeException(ex); - } - } - - /** - * Allocate a new direct {@link ByteBuffer} that is aligned on a given alignment boundary. - * - * @param capacity required for the buffer. - * @param alignment boundary at which the buffer should begin. - * @return a new {@link ByteBuffer} with the required alignment. - * @throws IllegalArgumentException if the alignment is not a power of 2. - */ - static ByteBuffer allocateDirectAligned(final int capacity, final int alignment) { - if (alignment == 0) { - return ByteBuffer.allocateDirect(capacity); - } - - if (!isPowerOfTwo(alignment)) { - throw new IllegalArgumentException("Must be a power of 2: alignment=" + alignment); - } - - final ByteBuffer buffer = ByteBuffer.allocateDirect(capacity + alignment); - - final long address = UNSAFE.getLong(buffer, BYTE_BUFFER_ADDRESS_FIELD_OFFSET); - final int remainder = (int) (address & (alignment - 1)); - final int offset = alignment - remainder; - - buffer.limit(capacity + offset); - buffer.position(offset); - - return buffer.slice(); - } - - private static boolean isPowerOfTwo(final int value) { - return value > 0 && ((value & (~value + 1)) == value); - } - - private BufferUtil() {} -} diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java deleted file mode 100644 index 66c68009a..000000000 --- a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple2ByteBuf.java +++ /dev/null @@ -1,392 +0,0 @@ -package io.rsocket.buffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.channels.GatheringByteChannel; -import java.nio.charset.Charset; - -class Tuple2ByteBuf extends AbstractTupleByteBuf { - - private static final long ONE_MASK = 0x100000000L; - private static final long TWO_MASK = 0x200000000L; - private static final long MASK = 0x700000000L; - - private final ByteBuf one; - private final ByteBuf two; - private final int oneReadIndex; - private final int twoReadIndex; - private final int oneReadableBytes; - private final int twoReadableBytes; - private final int twoRelativeIndex; - - private boolean freed; - - Tuple2ByteBuf(ByteBufAllocator allocator, ByteBuf one, ByteBuf two) { - super(allocator, one.readableBytes() + two.readableBytes()); - - this.one = one; - this.two = two; - - this.oneReadIndex = one.readerIndex(); - this.twoReadIndex = two.readerIndex(); - - this.oneReadableBytes = one.readableBytes(); - this.twoReadableBytes = two.readableBytes(); - - this.twoRelativeIndex = oneReadableBytes; - - this.freed = false; - } - - @Override - long calculateRelativeIndex(int index) { - checkIndex(index, 0); - - long relativeIndex; - long mask; - if (index >= twoRelativeIndex) { - relativeIndex = twoReadIndex + (index - oneReadableBytes); - mask = TWO_MASK; - } else { - relativeIndex = oneReadIndex + index; - mask = ONE_MASK; - } - - return relativeIndex | mask; - } - - @Override - ByteBuf getPart(int index) { - long ri = calculateRelativeIndex(index); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - return one; - case 0x2: - return two; - default: - throw new IllegalStateException(); - } - } - - @Override - public boolean isDirect() { - return one.isDirect() && two.isDirect(); - } - - @Override - public int nioBufferCount() { - return one.nioBufferCount() + two.nioBufferCount(); - } - - @Override - public ByteBuffer nioBuffer() { - ByteBuffer[] oneBuffers = one.nioBuffers(); - ByteBuffer[] twoBuffers = two.nioBuffers(); - - ByteBuffer merged = - BufferUtil.allocateDirectAligned(capacity, DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT) - .order(order()); - - for (ByteBuffer b : oneBuffers) { - merged.put(b); - } - - for (ByteBuffer b : twoBuffers) { - merged.put(b); - } - - merged.flip(); - return merged; - } - - @Override - public ByteBuffer[] _nioBuffers(int index, int length) { - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - ByteBuffer[] oneBuffer; - ByteBuffer[] twoBuffer; - int l = Math.min(oneReadableBytes - index, length); - oneBuffer = one.nioBuffers(index, l); - length -= l; - if (length != 0) { - twoBuffer = two.nioBuffers(twoReadIndex, length); - ByteBuffer[] results = new ByteBuffer[oneBuffer.length + twoBuffer.length]; - System.arraycopy(oneBuffer, 0, results, 0, oneBuffer.length); - System.arraycopy(twoBuffer, 0, results, oneBuffer.length, twoBuffer.length); - return results; - } else { - return oneBuffer; - } - case 0x2: - return two.nioBuffers(index, length); - default: - throw new IllegalStateException(); - } - } - - @Override - public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - one.getBytes(index, dst, dstIndex, l); - length -= l; - dstIndex += l; - - if (length != 0) { - two.getBytes(twoReadIndex, dst, dstIndex, length); - } - - break; - } - case 0x2: - { - two.getBytes(index, dst, dstIndex, length); - break; - } - default: - throw new IllegalStateException(); - } - - return this; - } - - @Override - public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { - ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.length, capacity); - return getBytes(0, dstBuf, index, min); - } - - @Override - public ByteBuf getBytes(int index, ByteBuffer dst) { - ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.limit(), capacity); - return getBytes(0, dstBuf, index, min); - } - - @Override - public ByteBuf getBytes(int index, final OutputStream out, int length) throws IOException { - checkIndex(index, length); - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - one.getBytes(index, out, l); - length -= l; - if (length != 0) { - two.getBytes(twoReadIndex, out, length); - } - break; - } - case 0x2: - { - two.getBytes(index, out, length); - break; - } - default: - throw new IllegalStateException(); - } - - return this; - } - - @Override - public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { - checkIndex(index, length); - int read = 0; - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - read += one.getBytes(index, out, l); - length -= l; - if (length != 0) { - read += two.getBytes(twoReadIndex, out, length); - } - break; - } - case 0x2: - { - read += two.getBytes(index, out, length); - break; - } - default: - throw new IllegalStateException(); - } - - return read; - } - - @Override - public int getBytes(int index, FileChannel out, long position, int length) throws IOException { - checkIndex(index, length); - int read = 0; - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - read += one.getBytes(index, out, position, l); - length -= l; - position += l; - if (length != 0) { - read += two.getBytes(twoReadIndex, out, position, length); - } - break; - } - case 0x2: - { - read += two.getBytes(index, out, position, length); - break; - } - default: - throw new IllegalStateException(); - } - - return read; - } - - @Override - public ByteBuf copy(int index, int length) { - checkIndex(index, length); - - ByteBuf buffer = allocator.buffer(length); - - if (index == 0 && length == capacity) { - buffer.writeBytes(one, oneReadIndex, oneReadableBytes); - buffer.writeBytes(two, twoReadIndex, twoReadableBytes); - - return buffer; - } - - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - buffer.writeBytes(one, index, l); - - length -= l; - - if (length != 0) { - buffer.writeBytes(two, twoReadIndex, length); - } - - return buffer; - } - case 0x2: - { - return buffer.writeBytes(two, index, length); - } - default: - throw new IllegalStateException(); - } - } - - @Override - public ByteBuf slice(final int readIndex, int length) { - checkIndex(readIndex, length); - - if (readIndex == 0 && length == capacity) { - return new Tuple2ByteBuf( - allocator, - one.slice(oneReadIndex, oneReadableBytes), - two.slice(twoReadIndex, twoReadableBytes)); - } - - long ri = calculateRelativeIndex(readIndex); - int index = (int) (ri & Integer.MAX_VALUE); - - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - ByteBuf oneSlice; - ByteBuf twoSlice; - - int l = Math.min(oneReadableBytes - index, length); - oneSlice = one.slice(index, l); - length -= l; - if (length != 0) { - twoSlice = two.slice(twoReadIndex, length); - return new Tuple2ByteBuf(allocator, oneSlice, twoSlice); - } else { - return oneSlice; - } - } - case 0x2: - { - return two.slice(index, length); - } - default: - throw new IllegalStateException(); - } - } - - @Override - protected void deallocate() { - if (freed) { - return; - } - - freed = true; - ReferenceCountUtil.safeRelease(one); - ReferenceCountUtil.safeRelease(two); - } - - @Override - public String toString(Charset charset) { - StringBuilder builder = new StringBuilder(3); - builder.append(one.toString(charset)); - builder.append(two.toString(charset)); - return builder.toString(); - } - - @Override - public String toString(int index, int length, Charset charset) { - // TODO - make this smarter - return toString(charset).substring(index, length); - } - - @Override - public String toString() { - return "Tuple2ByteBuf{" - + "capacity=" - + capacity - + ", one=" - + one - + ", two=" - + two - + ", allocator=" - + allocator - + ", oneReadIndex=" - + oneReadIndex - + ", twoReadIndex=" - + twoReadIndex - + ", oneReadableBytes=" - + oneReadableBytes - + ", twoReadableBytes=" - + twoReadableBytes - + ", twoRelativeIndex=" - + twoRelativeIndex - + '}'; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java deleted file mode 100644 index 1a0c1ec31..000000000 --- a/rsocket-core/src/main/java/io/rsocket/buffer/Tuple3ByteBuf.java +++ /dev/null @@ -1,579 +0,0 @@ -package io.rsocket.buffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.util.ReferenceCountUtil; -import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.channels.FileChannel; -import java.nio.channels.GatheringByteChannel; -import java.nio.charset.Charset; - -class Tuple3ByteBuf extends AbstractTupleByteBuf { - private static final long ONE_MASK = 0x100000000L; - private static final long TWO_MASK = 0x200000000L; - private static final long THREE_MASK = 0x400000000L; - private static final long MASK = 0x700000000L; - - private final ByteBuf one; - private final ByteBuf two; - private final ByteBuf three; - private final int oneReadIndex; - private final int twoReadIndex; - private final int threeReadIndex; - private final int oneReadableBytes; - private final int twoReadableBytes; - private final int threeReadableBytes; - private final int twoRelativeIndex; - private final int threeRelativeIndex; - - private boolean freed; - - Tuple3ByteBuf(ByteBufAllocator allocator, ByteBuf one, ByteBuf two, ByteBuf three) { - super(allocator, one.readableBytes() + two.readableBytes() + three.readableBytes()); - - this.one = one; - this.two = two; - this.three = three; - - this.oneReadIndex = one.readerIndex(); - this.twoReadIndex = two.readerIndex(); - this.threeReadIndex = three.readerIndex(); - - this.oneReadableBytes = one.readableBytes(); - this.twoReadableBytes = two.readableBytes(); - this.threeReadableBytes = three.readableBytes(); - - this.twoRelativeIndex = oneReadableBytes; - this.threeRelativeIndex = twoRelativeIndex + twoReadableBytes; - - this.freed = false; - } - - @Override - public boolean isDirect() { - return one.isDirect() && two.isDirect() && three.isDirect(); - } - - @Override - public long calculateRelativeIndex(int index) { - checkIndex(index, 0); - long relativeIndex; - long mask; - if (index >= threeRelativeIndex) { - relativeIndex = threeReadIndex + (index - twoReadableBytes - oneReadableBytes); - mask = THREE_MASK; - } else if (index >= twoRelativeIndex) { - relativeIndex = twoReadIndex + (index - oneReadableBytes); - mask = TWO_MASK; - } else { - relativeIndex = oneReadIndex + index; - mask = ONE_MASK; - } - - return relativeIndex | mask; - } - - @Override - public ByteBuf getPart(int index) { - long ri = calculateRelativeIndex(index); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - return one; - case 0x2: - return two; - case 0x4: - return three; - default: - throw new IllegalStateException(); - } - } - - @Override - public int nioBufferCount() { - return one.nioBufferCount() + two.nioBufferCount() + three.nioBufferCount(); - } - - @Override - public ByteBuffer nioBuffer() { - - ByteBuffer[] oneBuffers = one.nioBuffers(); - ByteBuffer[] twoBuffers = two.nioBuffers(); - ByteBuffer[] threeBuffers = three.nioBuffers(); - - ByteBuffer merged = - BufferUtil.allocateDirectAligned(capacity, DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT) - .order(order()); - - for (ByteBuffer b : oneBuffers) { - merged.put(b); - } - - for (ByteBuffer b : twoBuffers) { - merged.put(b); - } - - for (ByteBuffer b : threeBuffers) { - merged.put(b); - } - - merged.flip(); - return merged; - } - - @Override - public ByteBuffer[] _nioBuffers(int index, int length) { - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - ByteBuffer[] oneBuffer; - ByteBuffer[] twoBuffer; - ByteBuffer[] threeBuffer; - int l = Math.min(oneReadableBytes - index, length); - oneBuffer = one.nioBuffers(index, l); - length -= l; - if (length != 0) { - l = Math.min(twoReadableBytes, length); - twoBuffer = two.nioBuffers(twoReadIndex, l); - length -= l; - if (length != 0) { - threeBuffer = three.nioBuffers(threeReadIndex, length); - ByteBuffer[] results = - new ByteBuffer[oneBuffer.length + twoBuffer.length + threeBuffer.length]; - System.arraycopy(oneBuffer, 0, results, 0, oneBuffer.length); - System.arraycopy(twoBuffer, 0, results, oneBuffer.length, twoBuffer.length); - System.arraycopy( - threeBuffer, 0, results, oneBuffer.length + twoBuffer.length, threeBuffer.length); - return results; - } else { - ByteBuffer[] results = new ByteBuffer[oneBuffer.length + twoBuffer.length]; - System.arraycopy(oneBuffer, 0, results, 0, oneBuffer.length); - System.arraycopy(twoBuffer, 0, results, oneBuffer.length, twoBuffer.length); - return results; - } - } else { - return oneBuffer; - } - } - case 0x2: - { - ByteBuffer[] twoBuffer; - ByteBuffer[] threeBuffer; - int l = Math.min(twoReadableBytes - index, length); - twoBuffer = two.nioBuffers(index, l); - length -= l; - if (length != 0) { - threeBuffer = three.nioBuffers(threeReadIndex, length); - ByteBuffer[] results = new ByteBuffer[twoBuffer.length + threeBuffer.length]; - System.arraycopy(twoBuffer, 0, results, 0, twoBuffer.length); - System.arraycopy(threeBuffer, 0, results, twoBuffer.length, threeBuffer.length); - return results; - } else { - return twoBuffer; - } - } - case 0x4: - return three.nioBuffers(index, length); - default: - throw new IllegalStateException(); - } - } - - @Override - public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { - checkDstIndex(index, length, dstIndex, dst.capacity()); - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - one.getBytes(index, dst, dstIndex, l); - length -= l; - dstIndex += l; - - if (length != 0) { - l = Math.min(twoReadableBytes, length); - two.getBytes(twoReadIndex, dst, dstIndex, l); - length -= l; - dstIndex += l; - - if (length != 0) { - three.getBytes(threeReadIndex, dst, dstIndex, length); - } - } - break; - } - case 0x2: - { - int l = Math.min(twoReadableBytes - index, length); - two.getBytes(index, dst, dstIndex, l); - length -= l; - dstIndex += l; - - if (length != 0) { - three.getBytes(threeReadIndex, dst, dstIndex, length); - } - break; - } - case 0x4: - { - three.getBytes(index, dst, dstIndex, length); - break; - } - default: - throw new IllegalStateException(); - } - - return this; - } - - @Override - public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { - ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.length, capacity); - return getBytes(0, dstBuf, index, min); - } - - @Override - public ByteBuf getBytes(int index, ByteBuffer dst) { - ByteBuf dstBuf = Unpooled.wrappedBuffer(dst); - int min = Math.min(dst.limit(), capacity); - return getBytes(0, dstBuf, index, min); - } - - @Override - public ByteBuf getBytes(int index, final OutputStream out, int length) throws IOException { - checkIndex(index, length); - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - one.getBytes(index, out, l); - length -= l; - if (length != 0) { - l = Math.min(twoReadableBytes, length); - two.getBytes(twoReadIndex, out, l); - length -= l; - if (length != 0) { - three.getBytes(threeReadIndex, out, length); - } - } - break; - } - case 0x2: - { - int l = Math.min(twoReadableBytes - index, length); - two.getBytes(index, out, l); - length -= l; - - if (length != 0) { - three.getBytes(threeReadIndex, out, length); - } - break; - } - case 0x4: - { - three.getBytes(index, out, length); - - break; - } - default: - throw new IllegalStateException(); - } - - return this; - } - - @Override - public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { - checkIndex(index, length); - int read = 0; - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - read += one.getBytes(index, out, l); - length -= l; - if (length != 0) { - l = Math.min(twoReadableBytes, length); - read += two.getBytes(twoReadIndex, out, l); - length -= l; - if (length != 0) { - read += three.getBytes(threeReadIndex, out, length); - } - } - break; - } - case 0x2: - { - int l = Math.min(twoReadableBytes - index, length); - read += two.getBytes(index, out, l); - length -= l; - - if (length != 0) { - read += three.getBytes(threeReadIndex, out, length); - } - break; - } - case 0x4: - { - read += three.getBytes(index, out, length); - - break; - } - default: - throw new IllegalStateException(); - } - - return read; - } - - @Override - public int getBytes(int index, FileChannel out, long position, int length) throws IOException { - checkIndex(index, length); - int read = 0; - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - read += one.getBytes(index, out, position, l); - length -= l; - position += l; - - if (length != 0) { - l = Math.min(twoReadableBytes, length); - read += two.getBytes(twoReadIndex, out, position, l); - length -= l; - position += l; - - if (length != 0) { - read += three.getBytes(threeReadIndex, out, position, length); - } - } - break; - } - case 0x2: - { - int l = Math.min(twoReadableBytes - index, length); - read += two.getBytes(index, out, position, l); - length -= l; - position += l; - - if (length != 0) { - read += three.getBytes(threeReadIndex, out, position, length); - } - break; - } - case 0x4: - { - read += three.getBytes(index, out, position, length); - - break; - } - default: - throw new IllegalStateException(); - } - - return read; - } - - @Override - public ByteBuf copy(int index, int length) { - checkIndex(index, length); - - ByteBuf buffer = allocator.buffer(length); - - if (index == 0 && length == capacity) { - buffer.writeBytes(one, oneReadIndex, oneReadableBytes); - buffer.writeBytes(two, twoReadIndex, twoReadableBytes); - buffer.writeBytes(three, threeReadIndex, threeReadableBytes); - - return buffer; - } - - long ri = calculateRelativeIndex(index); - index = (int) (ri & Integer.MAX_VALUE); - - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - int l = Math.min(oneReadableBytes - index, length); - buffer.writeBytes(one, index, l); - length -= l; - - if (length != 0) { - l = Math.min(twoReadableBytes, length); - buffer.writeBytes(two, twoReadIndex, l); - length -= l; - if (length != 0) { - buffer.writeBytes(three, threeReadIndex, length); - } - } - - return buffer; - } - case 0x2: - { - int l = Math.min(twoReadableBytes - index, length); - buffer.writeBytes(two, index, l); - length -= l; - - if (length != 0) { - buffer.writeBytes(three, threeReadIndex, length); - } - - return buffer; - } - case 0x4: - { - buffer.writeBytes(three, index, length); - - return buffer; - } - default: - throw new IllegalStateException(); - } - } - - @Override - public ByteBuf retainedSlice() { - return new Tuple3ByteBuf( - allocator, - one.retainedSlice(oneReadIndex, oneReadableBytes), - two.retainedSlice(twoReadIndex, twoReadableBytes), - three.retainedSlice(threeReadIndex, threeReadableBytes)); - } - - @Override - public ByteBuf slice(final int readIndex, int length) { - checkIndex(readIndex, length); - - if (readIndex == 0 && length == capacity) { - return new Tuple3ByteBuf( - allocator, - one.slice(oneReadIndex, oneReadableBytes), - two.slice(twoReadIndex, twoReadableBytes), - three.slice(threeReadIndex, threeReadableBytes)); - } - - long ri = calculateRelativeIndex(readIndex); - int index = (int) (ri & Integer.MAX_VALUE); - switch ((int) ((ri & MASK) >>> 32L)) { - case 0x1: - { - ByteBuf oneSlice; - ByteBuf twoSlice; - ByteBuf threeSlice; - - int l = Math.min(oneReadableBytes - index, length); - oneSlice = one.slice(index, l); - length -= l; - if (length != 0) { - l = Math.min(twoReadableBytes, length); - twoSlice = two.slice(twoReadIndex, l); - length -= l; - if (length != 0) { - threeSlice = three.slice(threeReadIndex, length); - return new Tuple3ByteBuf(allocator, oneSlice, twoSlice, threeSlice); - } else { - return new Tuple2ByteBuf(allocator, oneSlice, twoSlice); - } - - } else { - return oneSlice; - } - } - case 0x2: - { - ByteBuf twoSlice; - ByteBuf threeSlice; - - int l = Math.min(twoReadableBytes - index, length); - twoSlice = two.slice(index, l); - length -= l; - if (length != 0) { - threeSlice = three.slice(threeReadIndex, length); - return new Tuple2ByteBuf(allocator, twoSlice, threeSlice); - } else { - return twoSlice; - } - } - case 0x4: - { - return three.slice(index, length); - } - default: - throw new IllegalStateException(); - } - } - - @Override - protected void deallocate() { - if (freed) { - return; - } - - freed = true; - ReferenceCountUtil.safeRelease(one); - ReferenceCountUtil.safeRelease(two); - ReferenceCountUtil.safeRelease(three); - } - - @Override - public String toString(Charset charset) { - StringBuilder builder = new StringBuilder(3); - builder.append(one.toString(charset)); - builder.append(two.toString(charset)); - builder.append(three.toString(charset)); - return builder.toString(); - } - - @Override - public String toString(int index, int length, Charset charset) { - // TODO - make this smarter - return toString(charset).substring(index, length); - } - - @Override - public String toString() { - return "Tuple3ByteBuf{" - + "capacity=" - + capacity - + ", one=" - + one - + ", two=" - + two - + ", three=" - + three - + ", allocator=" - + allocator - + ", oneReadIndex=" - + oneReadIndex - + ", twoReadIndex=" - + twoReadIndex - + ", threeReadIndex=" - + threeReadIndex - + ", oneReadableBytes=" - + oneReadableBytes - + ", twoReadableBytes=" - + twoReadableBytes - + ", threeReadableBytes=" - + threeReadableBytes - + ", twoRelativeIndex=" - + twoRelativeIndex - + ", threeRelativeIndex=" - + threeRelativeIndex - + '}'; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/buffer/TupleByteBuf.java b/rsocket-core/src/main/java/io/rsocket/buffer/TupleByteBuf.java deleted file mode 100644 index 8c8e2e7e4..000000000 --- a/rsocket-core/src/main/java/io/rsocket/buffer/TupleByteBuf.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.rsocket.buffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import java.util.Objects; - -public abstract class TupleByteBuf { - - private TupleByteBuf() {} - - public static ByteBuf of(ByteBuf one, ByteBuf two) { - return of(ByteBufAllocator.DEFAULT, one, two); - } - - public static ByteBuf of(ByteBufAllocator allocator, ByteBuf one, ByteBuf two) { - Objects.requireNonNull(allocator); - Objects.requireNonNull(one); - Objects.requireNonNull(two); - - return new Tuple2ByteBuf(allocator, one, two); - } - - public static ByteBuf of(ByteBuf one, ByteBuf two, ByteBuf three) { - return of(ByteBufAllocator.DEFAULT, one, two, three); - } - - public static ByteBuf of(ByteBufAllocator allocator, ByteBuf one, ByteBuf two, ByteBuf three) { - Objects.requireNonNull(allocator); - Objects.requireNonNull(one); - Objects.requireNonNull(two); - Objects.requireNonNull(three); - - return new Tuple3ByteBuf(allocator, one, two, three); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java new file mode 100644 index 000000000..9b5647c6f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; + +/** + * Default implementation of {@link ConnectionSetupPayload}. Primarily for internal use within + * RSocket Java but may be created in an application, e.g. for testing purposes. + */ +public class DefaultConnectionSetupPayload extends ConnectionSetupPayload { + + private final ByteBuf setupFrame; + + public DefaultConnectionSetupPayload(ByteBuf setupFrame) { + this.setupFrame = setupFrame; + } + + @Override + public boolean hasMetadata() { + return FrameHeaderCodec.hasMetadata(setupFrame); + } + + @Override + public ByteBuf sliceMetadata() { + final ByteBuf metadata = SetupFrameCodec.metadata(setupFrame); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return SetupFrameCodec.data(setupFrame); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public String metadataMimeType() { + return SetupFrameCodec.metadataMimeType(setupFrame); + } + + @Override + public String dataMimeType() { + return SetupFrameCodec.dataMimeType(setupFrame); + } + + @Override + public int keepAliveInterval() { + return SetupFrameCodec.keepAliveInterval(setupFrame); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameCodec.keepAliveMaxLifetime(setupFrame); + } + + @Override + public int getFlags() { + return FrameHeaderCodec.flags(setupFrame); + } + + @Override + public boolean willClientHonorLease() { + return SetupFrameCodec.honorLease(setupFrame); + } + + @Override + public boolean isResumeEnabled() { + return SetupFrameCodec.resumeEnabled(setupFrame); + } + + @Override + public ByteBuf resumeToken() { + return SetupFrameCodec.resumeToken(setupFrame); + } + + @Override + public ConnectionSetupPayload touch() { + setupFrame.touch(); + return this; + } + + @Override + public ConnectionSetupPayload touch(Object hint) { + setupFrame.touch(hint); + return this; + } + + @Override + protected void deallocate() { + setupFrame.release(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..2d2b96f7e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,32 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to send as a single frame with a 24-bit encoded length. Consider enabling fragmentation via RSocketFactory."; + + static boolean isValid(int mtu, Payload payload) { + if (mtu > 0) { + return true; + } + + if (payload.hasMetadata()) { + return (((FrameHeaderCodec.size() + + FrameLengthCodec.FRAME_LENGTH_SIZE + + FrameHeaderCodec.size() + + payload.data().readableBytes() + + payload.metadata().readableBytes()) + & ~FrameLengthCodec.FRAME_LENGTH_MASK) + == 0); + } else { + return (((FrameHeaderCodec.size() + + payload.data().readableBytes() + + FrameLengthCodec.FRAME_LENGTH_SIZE) + & ~FrameLengthCodec.FRAME_LENGTH_MASK) + == 0); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java new file mode 100644 index 000000000..02f1a51cc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -0,0 +1,580 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.lease.LeaseStats; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.util.annotation.Nullable; +import reactor.util.retry.Retry; + +/** + * The main class to use to establish a connection to an RSocket server. + * + *

To connect over TCP using default settings: + * + *

{@code
+ * import io.rsocket.transport.netty.client.TcpClientTransport;
+ *
+ * Mono rocketMono =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * }
+ * + *

To customize connection settings before connecting: + * + *

{@code
+ * Mono rocketMono =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * }
+ */ +public class RSocketConnector { + private static final String CLIENT_TAG = "client"; + + private static final BiConsumer INVALIDATE_FUNCTION = + (r, i) -> r.onClose().subscribe(null, __ -> i.invalidate(), i::invalidate); + + private Payload setupPayload = EmptyPayload.INSTANCE; + private String metadataMimeType = "application/binary"; + private String dataMimeType = "application/binary"; + private Duration keepAliveInterval = Duration.ofSeconds(20); + private Duration keepAliveMaxLifeTime = Duration.ofSeconds(90); + + @Nullable private SocketAcceptor acceptor; + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Retry retrySpec; + private Resume resume; + private Supplier> leasesSupplier; + + private int mtu = 0; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketConnector() {} + + /** + * Static factory method to create an {@code RSocketConnector} instance and customize default + * settings before connecting. To connect only, use {@link #connectWith(ClientTransport)}. + */ + public static RSocketConnector create() { + return new RSocketConnector(); + } + + /** + * Static factory method to connect with default settings, effectively a shortcut for: + * + *
+   * RSocketConnector.create().connectWith(transport);
+   * 
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public static Mono connectWith(ClientTransport transport) { + return RSocketConnector.create().connect(() -> transport); + } + + /** + * Provide a {@code Payload} with data and/or metadata for the initial {@code SETUP} frame. Data + * and metadata should be formatted according to the MIME types specified via {@link + * #dataMimeType(String)} and {@link #metadataMimeType(String)}. + * + * @param payload the payload containing data and/or metadata for the {@code SETUP} frame + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector setupPayload(Payload payload) { + this.setupPayload = Objects.requireNonNull(payload); + return this; + } + + /** + * Set the MIME type to use for formatting payload data on the established connection. This is set + * in the initial {@code SETUP} frame sent to the server. + * + *

By default this is set to {@code "application/binary"}. + * + * @param dataMimeType the MIME type to be used for payload data + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector dataMimeType(String dataMimeType) { + this.dataMimeType = Objects.requireNonNull(dataMimeType); + return this; + } + + /** + * Set the MIME type to use for formatting payload metadata on the established connection. This is + * set in the initial {@code SETUP} frame sent to the server. + * + *

For metadata encoding, consider using one of the following encoders: + * + *

    + *
  • {@link io.rsocket.metadata.CompositeMetadataFlyweight Composite Metadata} + *
  • {@link io.rsocket.metadata.TaggingMetadataFlyweight Routing} + *
  • {@link io.rsocket.metadata.security.AuthMetadataFlyweight Authentication} + *
+ * + *

For more on the above metadata formats, see the corresponding protocol extensions + * + *

By default this is set to {@code "application/binary"}. + * + * @param metadataMimeType the MIME type to be used for payload metadata + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector metadataMimeType(String metadataMimeType) { + this.metadataMimeType = Objects.requireNonNull(metadataMimeType); + return this; + } + + /** + * Set the "Time Between {@code KEEPALIVE} Frames" which is how frequently {@code KEEPALIVE} + * frames should be emitted, and the "Max Lifetime" which is how long to allow between {@code + * KEEPALIVE} frames from the remote end before concluding that connectivity is lost. Both + * settings are specified in the initial {@code SETUP} frame sent to the server. The spec mentions + * the following: + * + *

    + *
  • For server-to-server connections, a reasonable time interval between client {@code + * KEEPALIVE} frames is 500ms. + *
  • For mobile-to-server connections, the time interval between client {@code KEEPALIVE} + * frames is often > 30,000ms. + *
+ * + *

By default these are set to 20 seconds and 90 seconds respectively. + * + * @param interval how frequently to emit KEEPALIVE frames + * @param maxLifeTime how long to allow between {@code KEEPALIVE} frames from the remote end + * before assuming that connectivity is lost; the value should be generous and allow for + * multiple missed {@code KEEPALIVE} frames. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector keepAlive(Duration interval, Duration maxLifeTime) { + if (!interval.negated().isNegative()) { + throw new IllegalArgumentException("`interval` for keepAlive must be > 0"); + } + if (!maxLifeTime.negated().isNegative()) { + throw new IllegalArgumentException("`maxLifeTime` for keepAlive must be > 0"); + } + this.keepAliveInterval = interval; + this.keepAliveMaxLifeTime = maxLifeTime; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketConnector interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Configure a client-side {@link SocketAcceptor} for responding to requests from the server. + * + *

A full-form example with access to the {@code SETUP} frame and the "sending" RSocket (the + * same as the one returned from {@link #connect(ClientTransport)}): + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example with just the handling RSocket: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.with(new RSocket() {...})))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example handling only request-response: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.forRequestResponse(payload -> ...))
+   *             .connect(transport);
+   * }
+ * + *

By default, {@code new RSocket(){}} is used which rejects all requests from the server with + * {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to use for responding to server requests + * @return the same instance for method chaining + */ + public RSocketConnector acceptor(SocketAcceptor acceptor) { + this.acceptor = acceptor; + return this; + } + + /** + * When this is enabled, the connect methods of this class return a special {@code Mono} + * that maintains a single, shared {@code RSocket} for all subscribers: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   * }
+ * + *

The {@code RSocket} remains cached until the connection is lost and after that, new attempts + * to subscribe or re-subscribe trigger a reconnect and result in a new shared {@code RSocket}: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  r1.dispose();
+   *
+   *  RSocket r3 = rsocketMono.block();
+   *  RSocket r4 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   *  assert r3 == r4;
+   *  assert r1 != r3;
+   *
+   * }
+ * + *

Downstream subscribers for individual requests still need their own retry logic to determine + * if or when failed requests should be retried which in turn triggers the shared reconnect: + * + *

{@code
+   * Mono rocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  rsocketMono.flatMap(rsocket -> rsocket.requestResponse(...))
+   *           .retryWhen(Retry.fixedDelay(1, Duration.ofSeconds(5)))
+   *           .subscribe()
+   * }
+ * + *

Note: this feature is mutually exclusive with {@link #resume(Resume)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled in which case a new connection is obtained per subscriber. + * + * @param retry a retry spec that declares the rules for reconnecting + * @return the same instance for method chaining + */ + public RSocketConnector reconnect(Retry retry) { + this.retrySpec = Objects.requireNonNull(retry); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are resumed automatically. For this + * to work the server must also support and have the Resume capability enabled. + * + *

See {@link Resume} for settings to customize the Resume capability. + * + *

Note: this feature is mutually exclusive with {@link #reconnect(Retry)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketConnector resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create().lease(Leases::new).connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @param supplier supplier for a {@link Leases} + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * broken down into fragments to fit that size. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector fragment(int mtu) { + if (mtu > 0 && mtu < FragmentationDuplexConnection.MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", + FragmentationDuplexConnection.MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketConnector payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * The final step to connect with the transport to use as input and the resulting {@code + * Mono} as output. Each subscriber to the returned {@code Mono} starts a new connection + * if neither {@link #reconnect(Retry) reconnect} nor {@link #resume(Resume)} are enabled. + * + *

The following transports are available (via additional RSocket Java modules): + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpClientTransport TcpClientTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketClientTransport + * WebsocketClientTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalClientTransport LocalClientTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(ClientTransport transport) { + return connect(() -> transport); + } + + /** + * Variant of {@link #connect(ClientTransport)} with a {@link Supplier} for the {@code + * ClientTransport}. + * + *

// TODO: when to use? + * + * @param transportSupplier supplier for the transport to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(Supplier transportSupplier) { + Mono connectionMono = + Mono.fromSupplier(transportSupplier).flatMap(t -> t.connect(mtu)); + return connectionMono + .flatMap( + connection -> { + ByteBuf resumeToken; + KeepAliveHandler keepAliveHandler; + DuplexConnection wrappedConnection; + + if (resume != null) { + resumeToken = resume.getTokenSupplier().get(); + ClientRSocketSession session = + new ClientRSocketSession( + connection, + resume.getSessionDuration(), + resume.getRetry(), + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken), + resume.getStreamTimeout(), + resume.isCleanupStoreOnKeepAlive()) + .continueWith(connectionMono) + .resumeToken(resumeToken); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler(session.resumableConnection()); + wrappedConnection = session.resumableConnection(); + } else { + resumeToken = Unpooled.EMPTY_BUFFER; + keepAliveHandler = new KeepAliveHandler.DefaultKeepAliveHandler(connection); + wrappedConnection = connection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedConnection, interceptors, true); + + boolean leaseEnabled = leasesSupplier != null; + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(CLIENT_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + mtu, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + ByteBuf setupFrame = + SetupFrameCodec.encode( + wrappedConnection.alloc(), + leaseEnabled, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + resumeToken, + metadataMimeType, + dataMimeType, + setupPayload); + + SocketAcceptor acceptor = + this.acceptor != null ? this.acceptor : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .flatMap( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + CLIENT_TAG, + wrappedConnection.alloc(), + leases.sender(), + leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseHandler, + mtu); + + return wrappedConnection + .sendOne(setupFrame) + .thenReturn(wrappedRSocketRequester); + }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java new file mode 100644 index 000000000..f7fb161fd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -0,0 +1,773 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; +import static io.rsocket.keepalive.KeepAliveSupport.KeepAlive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.SynchronizedIntObjectHashMap; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.keepalive.KeepAliveFramesAcceptor; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.lease.RequesterLeaseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Processor; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.UnicastProcessor; +import reactor.core.scheduler.Scheduler; +import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; + +/** + * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer + */ +class RSocketRequester implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); + } + + private volatile Throwable terminationError; + + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketRequester.class, Throwable.class, "terminationError"); + + private final DuplexConnection connection; + private final PayloadDecoder payloadDecoder; + private final StreamIdSupplier streamIdSupplier; + private final IntObjectMap senders; + private final IntObjectMap> receivers; + private final UnboundedProcessor sendProcessor; + private final int mtu; + private final RequesterLeaseHandler leaseHandler; + private final ByteBufAllocator allocator; + private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; + private final MonoProcessor onClose; + private final Scheduler serialScheduler; + + RSocketRequester( + DuplexConnection connection, + PayloadDecoder payloadDecoder, + StreamIdSupplier streamIdSupplier, + int mtu, + int keepAliveTickPeriod, + int keepAliveAckTimeout, + @Nullable KeepAliveHandler keepAliveHandler, + RequesterLeaseHandler leaseHandler, + Scheduler serialScheduler) { + this.connection = connection; + this.allocator = connection.alloc(); + this.payloadDecoder = payloadDecoder; + this.streamIdSupplier = streamIdSupplier; + this.mtu = mtu; + this.leaseHandler = leaseHandler; + this.senders = new SynchronizedIntObjectHashMap<>(); + this.receivers = new SynchronizedIntObjectHashMap<>(); + this.onClose = MonoProcessor.create(); + this.serialScheduler = serialScheduler; + + // DO NOT Change the order here. The Send processor must be subscribed to before receiving + this.sendProcessor = new UnboundedProcessor<>(); + + connection.onClose().subscribe(null, this::tryTerminateOnConnectionError, this::tryShutdown); + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); + + connection.receive().subscribe(this::handleIncomingFrames, e -> {}); + + if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { + KeepAliveSupport keepAliveSupport = + new ClientKeepAliveSupport(this.allocator, keepAliveTickPeriod, keepAliveAckTimeout); + this.keepAliveFramesAcceptor = + keepAliveHandler.start( + keepAliveSupport, sendProcessor::onNextPrioritized, this::tryTerminateOnKeepAlive); + } else { + keepAliveFramesAcceptor = null; + } + } + + @Override + public Mono fireAndForget(Payload payload) { + return handleFireAndForget(payload); + } + + @Override + public Mono requestResponse(Payload payload) { + return handleRequestResponse(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return handleRequestStream(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return handleChannel(Flux.from(payloads)); + } + + @Override + public Mono metadataPush(Payload payload) { + return handleMetadataPush(payload); + } + + @Override + public double availability() { + return Math.min(connection.availability(), leaseHandler.availability()); + } + + @Override + public void dispose() { + tryShutdown(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + private Mono handleFireAndForget(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("FireAndForgetMono allows only a single subscriber")); + } + + final int streamId = streamIdSupplier.nextStreamId(receivers); + final ByteBuf requestFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(requestFrame); + + return Mono.empty(); + }) + .subscribeOn(serialScheduler); + } + + private Mono handleRequestResponse(final Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final UnboundedProcessor sendProcessor = this.sendProcessor; + final UnicastProcessor receiver = UnicastProcessor.create(Queues.one().get()); + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("RequestResponseMono allows only a single subscriber")); + } + + return receiver + .next() + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + @Override + void hookOnFirstRequest(long n) { + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload( + allocator, streamId, payload); + + receivers.put(streamId, receiver); + sendProcessor.onNext(requestResponseFrame); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + payload.release(); + } + } + + @Override + public void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId, receiver); + } + })) + .subscribeOn(serialScheduler) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + }); + } + + private Flux handleRequestStream(final Payload payload) { + if (payload.refCnt() <= 0) { + return Flux.error(new IllegalReferenceCountException()); + } + + Throwable err = checkAvailable(); + if (err != null) { + payload.release(); + return Flux.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Flux.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final UnboundedProcessor sendProcessor = this.sendProcessor; + final UnicastProcessor receiver = UnicastProcessor.create(); + final AtomicBoolean once = new AtomicBoolean(); + + return Flux.defer( + () -> { + if (once.getAndSet(true)) { + return Flux.error( + new IllegalStateException("RequestStreamFlux allows only a single subscriber")); + } + + return receiver + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + @Override + void hookOnFirstRequest(long n) { + int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, n, payload); + + receivers.put(streamId, receiver); + + sendProcessor.onNext(requestStreamFrame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext( + RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } else { + payload.release(); + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + receivers.remove(streamId); + } + })) + .subscribeOn(serialScheduler, false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + }); + } + + private Flux handleChannel(Flux request) { + Throwable err = checkAvailable(); + if (err != null) { + return Flux.error(err); + } + + return request + .switchOnFirst( + (s, flux) -> { + Payload payload = s.get(); + if (payload != null) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + return Mono.error(t); + } + return handleChannel(payload, flux); + } else { + return flux; + } + }, + false) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); + } + + private Flux handleChannel(Payload initialPayload, Flux inboundFlux) { + final UnboundedProcessor sendProcessor = this.sendProcessor; + + final UnicastProcessor receiver = UnicastProcessor.create(); + + return receiver + .transform( + Operators.lift( + (s, actual) -> + new RequestOperator(actual) { + + final BaseSubscriber upstreamSubscriber = + new BaseSubscriber() { + + boolean first = true; + + @Override + protected void hookOnSubscribe(Subscription subscription) { + // noops + } + + @Override + protected void hookOnNext(Payload payload) { + if (first) { + // need to skip first since we have already sent it + // no need to release it since it was released earlier on the + // request + // establishment + // phase + first = false; + request(1); + return; + } + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + // no need to send any errors. + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + receiver.onError(t); + return; + } + final ByteBuf frame = + PayloadFrameCodec.encodeNextReleasingPayload( + allocator, streamId, payload); + + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnComplete() { + ByteBuf frame = PayloadFrameCodec.encodeComplete(allocator, streamId); + sendProcessor.onNext(frame); + } + + @Override + protected void hookOnError(Throwable t) { + ByteBuf frame = ErrorFrameCodec.encode(allocator, streamId, t); + sendProcessor.onNext(frame); + receiver.onError(t); + } + + @Override + protected void hookFinally(SignalType type) { + senders.remove(streamId, this); + } + }; + + @Override + void hookOnFirstRequest(long n) { + final int streamId = streamIdSupplier.nextStreamId(receivers); + this.streamId = streamId; + + final ByteBuf frame = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, n, initialPayload); + + senders.put(streamId, upstreamSubscriber); + receivers.put(streamId, receiver); + + inboundFlux + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER) + .subscribe(upstreamSubscriber); + + sendProcessor.onNext(frame); + } + + @Override + void hookOnRemainingRequests(long n) { + if (receiver.isDisposed()) { + return; + } + + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + + @Override + void hookOnCancel() { + senders.remove(streamId, upstreamSubscriber); + if (receivers.remove(streamId, receiver)) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } + } + + @Override + void hookOnTerminal(SignalType signalType) { + if (signalType == SignalType.ON_ERROR) { + upstreamSubscriber.cancel(); + } + receivers.remove(streamId, receiver); + } + + @Override + public void cancel() { + upstreamSubscriber.cancel(); + super.cancel(); + } + })) + .subscribeOn(serialScheduler, false); + } + + private Mono handleMetadataPush(Payload payload) { + if (payload.refCnt() <= 0) { + return Mono.error(new IllegalReferenceCountException()); + } + + Throwable err = this.terminationError; + if (err != null) { + payload.release(); + return Mono.error(err); + } + + if (!PayloadValidationUtils.isValid(this.mtu, payload)) { + payload.release(); + return Mono.error(new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE)); + } + + final AtomicBoolean once = new AtomicBoolean(); + + return Mono.defer( + () -> { + if (once.getAndSet(true)) { + return Mono.error( + new IllegalStateException("MetadataPushMono allows only a single subscriber")); + } + + ByteBuf metadataPushFrame = + MetadataPushFrameCodec.encodeReleasingPayload(allocator, payload); + + sendProcessor.onNextPrioritized(metadataPushFrame); + + return Mono.empty(); + }); + } + + @Nullable + private Throwable checkAvailable() { + Throwable err = this.terminationError; + if (err != null) { + return err; + } + RequesterLeaseHandler lh = leaseHandler; + if (!lh.useLease()) { + return lh.leaseError(); + } + return null; + } + + private void handleIncomingFrames(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameType type = FrameHeaderCodec.frameType(frame); + if (streamId == 0) { + handleStreamZero(type, frame); + } else { + handleFrame(streamId, type, frame); + } + frame.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frame); + throw reactor.core.Exceptions.propagate(t); + } + } + + private void handleStreamZero(FrameType type, ByteBuf frame) { + switch (type) { + case ERROR: + tryTerminateOnZeroError(frame); + break; + case LEASE: + leaseHandler.receive(frame); + break; + case KEEPALIVE: + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.receive(frame); + } + break; + default: + // Ignore unknown frames. Throwing an error will close the socket. + if (LOGGER.isInfoEnabled()) { + LOGGER.info("Requester received unsupported frame on stream 0: " + frame.toString()); + } + } + } + + private void handleFrame(int streamId, FrameType type, ByteBuf frame) { + Subscriber receiver = receivers.get(streamId); + switch (type) { + case NEXT: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + break; + case NEXT_COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onNext(payloadDecoder.apply(frame)); + receiver.onComplete(); + break; + case COMPLETE: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onComplete(); + receivers.remove(streamId); + break; + case ERROR: + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + receiver.onError(Exceptions.from(streamId, frame)); + receivers.remove(streamId); + break; + case CANCEL: + { + Subscription sender = senders.remove(streamId); + if (sender != null) { + sender.cancel(); + } + break; + } + case REQUEST_N: + { + Subscription sender = senders.get(streamId); + if (sender != null) { + long n = RequestNFrameCodec.requestN(frame); + sender.request(n); + } + break; + } + default: + throw new IllegalStateException( + "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); + } + } + + private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { + if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (type == FrameType.ERROR) { + // message for stream that has never existed, we have a problem with + // the overall connection and must tear down + String errorMessage = ErrorFrameCodec.dataUtf8(frame); + + throw new IllegalStateException( + "Client received error for non-existent stream: " + + streamId + + " Message: " + + errorMessage); + } else { + throw new IllegalStateException( + "Client received message for non-existent stream: " + + streamId + + ", frame type: " + + type); + } + } + // receiving a frame after a given stream has been cancelled/completed, + // so ignore (cancellation is async so there is a race condition) + } + + private void tryTerminateOnKeepAlive(KeepAlive keepAlive) { + tryTerminate( + () -> + new ConnectionErrorException( + String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + } + + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } + + private void tryTerminateOnZeroError(ByteBuf errorFrame) { + tryTerminate(() -> Exceptions.from(0, errorFrame)); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(e); + } + } + } + + private void tryShutdown() { + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } + } + } + + private void terminate(Throwable e) { + connection.dispose(); + leaseHandler.dispose(); + + synchronized (receivers) { + receivers + .values() + .forEach( + receiver -> { + try { + receiver.onError(e); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + } + synchronized (senders) { + senders + .values() + .forEach( + sender -> { + try { + sender.cancel(); + } catch (Throwable t) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } + } + }); + } + senders.clear(); + receivers.clear(); + sendProcessor.dispose(); + if (e == CLOSED_CHANNEL_EXCEPTION) { + onClose.onComplete(); + } else { + onClose.onError(e); + } + } + + private void handleSendProcessorError(Throwable t) { + connection.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java similarity index 52% rename from rsocket-core/src/main/java/io/rsocket/RSocketResponder.java rename to rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java index 490b00967..d3860e5f2 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocketResponder.java +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -14,40 +14,75 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.frame.*; import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.internal.RateLimitableRequestPublisher; import io.rsocket.internal.SynchronizedIntObjectHashMap; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.lease.ResponderLeaseHandler; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.Consumer; +import java.util.function.LongConsumer; +import java.util.function.Supplier; import org.reactivestreams.Processor; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.Exceptions; import reactor.core.publisher.*; -import reactor.util.concurrent.Queues; +import reactor.util.annotation.Nullable; /** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ -class RSocketResponder implements ResponderRSocket { +class RSocketResponder implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); + + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + referenceCounted -> { + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + }; + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); private final DuplexConnection connection; private final RSocket requestHandler; - private final ResponderRSocket responderRSocket; + + @SuppressWarnings("deprecation") + private final io.rsocket.ResponderRSocket responderRSocket; + private final PayloadDecoder payloadDecoder; - private final Consumer errorConsumer; private final ResponderLeaseHandler leaseHandler; + private final Disposable leaseHandlerDisposable; + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketResponder.class, Throwable.class, "terminationError"); + + private final int mtu; - private final IntObjectMap sendingLimitableSubscriptions; private final IntObjectMap sendingSubscriptions; private final IntObjectMap> channelProcessors; @@ -55,23 +90,23 @@ class RSocketResponder implements ResponderRSocket { private final ByteBufAllocator allocator; RSocketResponder( - ByteBufAllocator allocator, DuplexConnection connection, RSocket requestHandler, PayloadDecoder payloadDecoder, - Consumer errorConsumer, - ResponderLeaseHandler leaseHandler) { - this.allocator = allocator; + ResponderLeaseHandler leaseHandler, + int mtu) { this.connection = connection; + this.allocator = connection.alloc(); + this.mtu = mtu; this.requestHandler = requestHandler; this.responderRSocket = - (requestHandler instanceof ResponderRSocket) ? (ResponderRSocket) requestHandler : null; + (requestHandler instanceof io.rsocket.ResponderRSocket) + ? (io.rsocket.ResponderRSocket) requestHandler + : null; this.payloadDecoder = payloadDecoder; - this.errorConsumer = errorConsumer; this.leaseHandler = leaseHandler; - this.sendingLimitableSubscriptions = new SynchronizedIntObjectHashMap<>(); this.sendingSubscriptions = new SynchronizedIntObjectHashMap<>(); this.channelProcessors = new SynchronizedIntObjectHashMap<>(); @@ -79,23 +114,14 @@ class RSocketResponder implements ResponderRSocket { // connections this.sendProcessor = new UnboundedProcessor<>(); - connection - .send(sendProcessor) - .doFinally(this::handleSendProcessorCancel) - .subscribe(null, this::handleSendProcessorError); + connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError); - Disposable receiveDisposable = connection.receive().subscribe(this::handleFrame, errorConsumer); - Disposable sendLeaseDisposable = leaseHandler.send(sendProcessor::onNext); + connection.receive().subscribe(this::handleFrame, e -> {}); + leaseHandlerDisposable = leaseHandler.send(sendProcessor::onNextPrioritized); this.connection .onClose() - .doFinally( - s -> { - cleanup(); - receiveDisposable.dispose(); - sendLeaseDisposable.dispose(); - }) - .subscribe(null, errorConsumer); + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); } private void handleSendProcessorError(Throwable t) { @@ -106,18 +132,9 @@ private void handleSendProcessorError(Throwable t) { try { subscription.cancel(); } catch (Throwable e) { - errorConsumer.accept(e); - } - }); - - sendingLimitableSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } } }); @@ -128,48 +145,28 @@ private void handleSendProcessorError(Throwable t) { try { subscription.onError(t); } catch (Throwable e) { - errorConsumer.accept(e); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("Dropped exception", t); + } } }); } - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - - sendingSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminateOnConnectionError(Throwable e) { + tryTerminate(() -> e); + } - sendingLimitableSubscriptions - .values() - .forEach( - subscription -> { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminateOnConnectionClose() { + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } - channelProcessors - .values() - .forEach( - subscription -> { - try { - subscription.onComplete(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - }); + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + cleanup(e); + } + } } @Override @@ -227,8 +224,7 @@ public Flux requestChannel(Publisher payloads) { } } - @Override - public Flux requestChannel(Payload payload, Publisher payloads) { + private Flux requestChannel(Payload payload, Publisher payloads) { try { if (leaseHandler.useLease()) { return responderRSocket.requestChannel(payload, payloads); @@ -252,7 +248,7 @@ public Mono metadataPush(Payload payload) { @Override public void dispose() { - connection.dispose(); + tryTerminate(() -> new CancellationException("Disposed")); } @Override @@ -265,10 +261,12 @@ public Mono onClose() { return connection.onClose(); } - private void cleanup() { + private void cleanup(Throwable e) { cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(); + cleanUpChannelProcessors(e); + connection.dispose(); + leaseHandlerDisposable.dispose(); requestHandler.dispose(); sendProcessor.dispose(); } @@ -276,21 +274,27 @@ private void cleanup() { private synchronized void cleanUpSendingSubscriptions() { sendingSubscriptions.values().forEach(Subscription::cancel); sendingSubscriptions.clear(); - - sendingLimitableSubscriptions.values().forEach(Subscription::cancel); - sendingLimitableSubscriptions.clear(); } - private synchronized void cleanUpChannelProcessors() { - channelProcessors.values().forEach(Processor::onComplete); + private synchronized void cleanUpChannelProcessors(Throwable e) { + channelProcessors + .values() + .forEach( + payloadPayloadProcessor -> { + try { + payloadPayloadProcessor.onError(e); + } catch (Throwable t) { + // noops + } + }); channelProcessors.clear(); } private void handleFrame(ByteBuf frame) { try { - int streamId = FrameHeaderFlyweight.streamId(frame); + int streamId = FrameHeaderCodec.streamId(frame); Subscriber receiver; - FrameType frameType = FrameHeaderFlyweight.frameType(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); switch (frameType) { case REQUEST_FNF: handleFireAndForget(streamId, fireAndForget(payloadDecoder.apply(frame))); @@ -305,12 +309,12 @@ private void handleFrame(ByteBuf frame) { handleRequestN(streamId, frame); break; case REQUEST_STREAM: - int streamInitialRequestN = RequestStreamFrameFlyweight.initialRequestN(frame); + long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); Payload streamPayload = payloadDecoder.apply(frame); - handleStream(streamId, requestStream(streamPayload), streamInitialRequestN); + handleStream(streamId, requestStream(streamPayload), streamInitialRequestN, null); break; case REQUEST_CHANNEL: - int channelInitialRequestN = RequestChannelFrameFlyweight.initialRequestN(frame); + long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); Payload channelPayload = payloadDecoder.apply(frame); handleChannel(streamId, channelPayload, channelInitialRequestN); break; @@ -335,7 +339,7 @@ private void handleFrame(ByteBuf frame) { case ERROR: receiver = channelProcessors.get(streamId); if (receiver != null) { - receiver.onError(new ApplicationErrorException(ErrorFrameFlyweight.dataUtf8(frame))); + receiver.onError(new ApplicationErrorException(ErrorFrameCodec.dataUtf8(frame))); } break; case NEXT_COMPLETE: @@ -372,9 +376,7 @@ protected void hookOnSubscribe(Subscription subscription) { } @Override - protected void hookOnError(Throwable throwable) { - errorConsumer.accept(throwable); - } + protected void hookOnError(Throwable throwable) {} @Override protected void hookFinally(SignalType type) { @@ -384,32 +386,27 @@ protected void hookFinally(SignalType type) { } private void handleRequestResponse(int streamId, Mono response) { - response.subscribe( + final BaseSubscriber subscriber = new BaseSubscriber() { private boolean isEmpty = true; - @Override - protected void hookOnSubscribe(Subscription subscription) { - sendingSubscriptions.put(streamId, subscription); - subscription.request(Long.MAX_VALUE); - } - @Override protected void hookOnNext(Payload payload) { if (isEmpty) { isEmpty = false; } - ByteBuf byteBuf; - try { - byteBuf = PayloadFrameFlyweight.encodeNextComplete(allocator, streamId, payload); - } catch (Throwable t) { + if (!PayloadValidationUtils.isValid(mtu, payload)) { payload.release(); - throw Exceptions.propagate(t); + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; } - payload.release(); - + ByteBuf byteBuf = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(allocator, streamId, payload); sendProcessor.onNext(byteBuf); } @@ -421,75 +418,148 @@ protected void hookOnError(Throwable throwable) { @Override protected void hookOnComplete() { if (isEmpty) { - sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId)); + sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); } } @Override protected void hookFinally(SignalType type) { - sendingSubscriptions.remove(streamId); + sendingSubscriptions.remove(streamId, this); } - }); + }; + + sendingSubscriptions.put(streamId, subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); } - private void handleStream(int streamId, Flux response, int initialRequestN) { - response - .transform( - frameFlux -> { - RateLimitableRequestPublisher payloads = - RateLimitableRequestPublisher.wrap(frameFlux, Queues.SMALL_BUFFER_SIZE); - sendingLimitableSubscriptions.put(streamId, payloads); - payloads.request( - initialRequestN >= Integer.MAX_VALUE ? Long.MAX_VALUE : initialRequestN); - return payloads; - }) - .subscribe( - new BaseSubscriber() { - - @Override - protected void hookOnNext(Payload payload) { - ByteBuf byteBuf; - try { - byteBuf = PayloadFrameFlyweight.encodeNext(allocator, streamId, payload); - } catch (Throwable t) { - payload.release(); - throw Exceptions.propagate(t); - } + private void handleStream( + int streamId, + Flux response, + long initialRequestN, + @Nullable UnicastProcessor requestChannel) { + final BaseSubscriber subscriber = + new BaseSubscriber() { - payload.release(); + @Override + protected void hookOnSubscribe(Subscription s) { + s.request(initialRequestN); + } - sendProcessor.onNext(byteBuf); + @Override + protected void hookOnNext(Payload payload) { + try { + if (!PayloadValidationUtils.isValid(mtu, payload)) { + payload.release(); + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream + // is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + } + cancel(); + final IllegalArgumentException t = + new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE); + handleError(streamId, t); + return; } - @Override - protected void hookOnComplete() { - sendProcessor.onNext(PayloadFrameFlyweight.encodeComplete(allocator, streamId)); + ByteBuf byteBuf = + PayloadFrameCodec.encodeNextReleasingPayload(allocator, streamId, payload); + sendProcessor.onNext(byteBuf); + } catch (Throwable e) { + // specifically for requestChannel case so when Payload is invalid we will not be + // sending CancelFrame and ErrorFrame + // Note: CancelFrame is redundant and due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both Requester and Responder. + // Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is + // terminated on both the Requester and Responder. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); } + cancel(); + handleError(streamId, e); + } + } - @Override - protected void hookOnError(Throwable throwable) { - handleError(streamId, throwable); - } + @Override + protected void hookOnComplete() { + sendProcessor.onNext(PayloadFrameCodec.encodeComplete(allocator, streamId)); + } + + @Override + protected void hookOnError(Throwable throwable) { + handleError(streamId, throwable); + } - @Override - protected void hookFinally(SignalType type) { - sendingLimitableSubscriptions.remove(streamId); + @Override + protected void hookOnCancel() { + // specifically for requestChannel case so when requester sends Cancel frame so the + // whole chain MUST be terminated + // Note: CancelFrame is redundant from the responder side due to spec + // (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel) + // Upon receiving a CANCEL, the stream is terminated on the Responder. + // Upon sending a CANCEL, the stream is terminated on the Requester. + if (requestChannel != null) { + channelProcessors.remove(streamId, requestChannel); + try { + requestChannel.dispose(); + } catch (Exception e) { + // might be thrown back if stream is cancelled } - }); + } + } + + @Override + protected void hookFinally(SignalType type) { + sendingSubscriptions.remove(streamId); + } + }; + + sendingSubscriptions.put(streamId, subscriber); + response.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER).subscribe(subscriber); } - private void handleChannel(int streamId, Payload payload, int initialRequestN) { + private void handleChannel(int streamId, Payload payload, long initialRequestN) { UnicastProcessor frames = UnicastProcessor.create(); channelProcessors.put(streamId, frames); Flux payloads = frames - .doOnCancel( - () -> sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId))) - .doOnError(t -> handleError(streamId, t)) .doOnRequest( - l -> sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, l))) - .doFinally(signalType -> channelProcessors.remove(streamId)); + new LongConsumer() { + boolean first = true; + + @Override + public void accept(long l) { + long n; + if (first) { + first = false; + n = l - 1L; + } else { + n = l; + } + if (n > 0) { + sendProcessor.onNext(RequestNFrameCodec.encode(allocator, streamId, n)); + } + } + }) + .doFinally( + signalType -> { + if (channelProcessors.remove(streamId, frames)) { + if (signalType == SignalType.CANCEL) { + sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId)); + } + } + }) + .doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER); // not chained, as the payload should be enqueued in the Unicast processor before this method // returns @@ -497,9 +567,9 @@ private void handleChannel(int streamId, Payload payload, int initialRequestN) { frames.onNext(payload); if (responderRSocket != null) { - handleStream(streamId, requestChannel(payload, payloads), initialRequestN); + handleStream(streamId, requestChannel(payload, payloads), initialRequestN, frames); } else { - handleStream(streamId, requestChannel(payloads), initialRequestN); + handleStream(streamId, requestChannel(payloads), initialRequestN, frames); } } @@ -512,18 +582,13 @@ protected void hookOnSubscribe(Subscription subscription) { } @Override - protected void hookOnError(Throwable throwable) { - errorConsumer.accept(throwable); - } + protected void hookOnError(Throwable throwable) {} }); } private void handleCancelFrame(int streamId) { Subscription subscription = sendingSubscriptions.remove(streamId); - - if (subscription == null) { - subscription = sendingLimitableSubscriptions.remove(streamId); - } + channelProcessors.remove(streamId); if (subscription != null) { subscription.cancel(); @@ -531,20 +596,15 @@ private void handleCancelFrame(int streamId) { } private void handleError(int streamId, Throwable t) { - errorConsumer.accept(t); - sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t)); + sendProcessor.onNext(ErrorFrameCodec.encode(allocator, streamId, t)); } private void handleRequestN(int streamId, ByteBuf frame) { Subscription subscription = sendingSubscriptions.get(streamId); - if (subscription == null) { - subscription = sendingLimitableSubscriptions.get(streamId); - } - if (subscription != null) { - int n = RequestNFrameFlyweight.requestN(frame); - subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); + long n = RequestNFrameCodec.requestN(frame); + subscription.request(n); } } } diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java new file mode 100644 index 000000000..c5734cecc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -0,0 +1,446 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.exceptions.InvalidSetupException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.ClientServerInputMultiplexer; +import io.rsocket.lease.Leases; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.resume.SessionManager; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + +/** + * The main class for starting an RSocket server. + * + *

For example: + * + *

{@code
+ * CloseableChannel closeable =
+ *         RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+ *                 .bind(TcpServerTransport.create("localhost", 7000))
+ *                 .block();
+ * }
+ */ +public final class RSocketServer { + private static final String SERVER_TAG = "server"; + + private SocketAcceptor acceptor = SocketAcceptor.with(new RSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Resume resume; + private Supplier> leasesSupplier = null; + + private int mtu = 0; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketServer() {} + + /** Static factory method to create an {@code RSocketServer}. */ + public static RSocketServer create() { + return new RSocketServer(); + } + + /** + * Static factory method to create an {@code RSocketServer} instance with the given {@code + * SocketAcceptor}. Effectively a shortcut for: + * + *
+   * RSocketServer.create().acceptor(...);
+   * 
+ * + * @param acceptor the acceptor to handle connections with + * @return the same instance for method chaining + * @see #acceptor(SocketAcceptor) + */ + public static RSocketServer create(SocketAcceptor acceptor) { + return RSocketServer.create().acceptor(acceptor); + } + + /** + * Set the acceptor to handle incoming connections and handle requests. + * + *

An example with access to the {@code SETUP} frame and sending RSocket for performing + * requests back to the client if needed: + * + *

{@code
+   * RSocketServer.create((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to provide the handling RSocket only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to handle request-response interactions only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.forRequestResponse(payload -> ...))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default, {@code new RSocket(){}} is used for handling which rejects requests from the + * client with {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to handle incoming connections and requests with + * @return the same instance for method chaining + */ + public RSocketServer acceptor(SocketAcceptor acceptor) { + Objects.requireNonNull(acceptor); + this.acceptor = acceptor; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketServer interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are transparently resumed. For this + * to work clients must also support and request to enable this when connecting. + * + *

Use the {@link Resume} argument to customize the Resume session duration, storage, retry + * logic, and others. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketServer resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. For + * this to work clients must also support and request to enable this when connecting. + * + *

Example usage: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .lease(Leases::new)
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default this is not enabled. + * + * @param supplier supplier for a {@link Leases} + * @return the same instance for method chaining + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketServer lease(Supplier> supplier) { + this.leasesSupplier = supplier; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * fragmented. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer fragment(int mtu) { + if (mtu > 0 && mtu < FragmentationDuplexConnection.MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", + FragmentationDuplexConnection.MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } + this.mtu = mtu; + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketServer payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Start the server on the given transport. + * + *

The following transports are available from additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpServerTransport TcpServerTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketServerTransport + * WebsocketServerTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalServerTransport LocalServerTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @param the type of {@code Closeable} for the given transport + * @return a {@code Mono} with a {@code Closeable} that can be used to obtain information about + * the server, stop it, or be notified of when it is stopped. + */ + public Mono bind(ServerTransport transport) { + return Mono.defer( + new Supplier>() { + ServerSetup serverSetup = serverSetup(); + + @Override + public Mono get() { + return transport + .start(duplexConnection -> acceptor(serverSetup, duplexConnection), mtu) + .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); + } + }); + } + + /** + * Start the server on the given transport. Effectively is a shortcut for {@code + * .bind(ServerTransport).block()} + */ + public T bindNow(ServerTransport transport) { + return bind(transport).block(); + } + + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { + return new ServerTransport.ConnectionAcceptor() { + private final ServerSetup serverSetup = serverSetup(); + + @Override + public Mono apply(DuplexConnection connection) { + return acceptor(serverSetup, connection); + } + }; + } + + private Mono acceptor(ServerSetup serverSetup, DuplexConnection connection) { + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, interceptors, false); + + return multiplexer + .asSetupConnection() + .receive() + .next() + .flatMap(startFrame -> accept(serverSetup, startFrame, multiplexer)); + } + + private Mono acceptResume( + ServerSetup serverSetup, ByteBuf resumeFrame, ClientServerInputMultiplexer multiplexer) { + return serverSetup.acceptRSocketResume(resumeFrame, multiplexer); + } + + private Mono accept( + ServerSetup serverSetup, ByteBuf startFrame, ClientServerInputMultiplexer multiplexer) { + switch (FrameHeaderCodec.frameType(startFrame)) { + case SETUP: + return acceptSetup(serverSetup, startFrame, multiplexer); + case RESUME: + return acceptResume(serverSetup, startFrame, multiplexer); + default: + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "invalid setup frame: " + FrameHeaderCodec.frameType(startFrame))) + .doFinally( + signalType -> { + startFrame.release(); + multiplexer.dispose(); + }); + } + } + + private Mono acceptSetup( + ServerSetup serverSetup, ByteBuf setupFrame, ClientServerInputMultiplexer multiplexer) { + + if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { + return serverSetup + .sendError( + multiplexer, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + boolean leaseEnabled = leasesSupplier != null; + if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { + return serverSetup + .sendError(multiplexer, new InvalidSetupException("lease is not supported")) + .doFinally( + signalType -> { + setupFrame.release(); + multiplexer.dispose(); + }); + } + + return serverSetup.acceptRSocketSetup( + setupFrame, + multiplexer, + (keepAliveHandler, wrappedMultiplexer) -> { + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(setupFrame); + + Leases leases = leaseEnabled ? leasesSupplier.get() : null; + RequesterLeaseHandler requesterLeaseHandler = + leaseEnabled + ? new RequesterLeaseHandler.Impl(SERVER_TAG, leases.receiver()) + : RequesterLeaseHandler.None; + + RSocket rSocketRequester = + new RSocketRequester( + wrappedMultiplexer.asServerConnection(), + payloadDecoder, + StreamIdSupplier.serverSupplier(), + mtu, + setupPayload.keepAliveInterval(), + setupPayload.keepAliveMaxLifetime(), + keepAliveHandler, + requesterLeaseHandler, + Schedulers.single(Schedulers.parallel())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setupPayload, wrappedRSocketRequester) + .onErrorResume( + err -> + serverSetup + .sendError(multiplexer, rejectedSetupError(err)) + .then(Mono.error(err))) + .doOnNext( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + DuplexConnection connection = wrappedMultiplexer.asClientConnection(); + + ResponderLeaseHandler responderLeaseHandler = + leaseEnabled + ? new ResponderLeaseHandler.Impl<>( + SERVER_TAG, connection.alloc(), leases.sender(), leases.stats()) + : ResponderLeaseHandler.None; + + RSocket rSocketResponder = + new RSocketResponder( + connection, + wrappedRSocketHandler, + payloadDecoder, + responderLeaseHandler, + mtu); + }) + .doFinally(signalType -> setupPayload.release()) + .then(); + }); + } + + private ServerSetup serverSetup() { + return resume != null ? createSetup() : new ServerSetup.DefaultServerSetup(); + } + + ServerSetup createSetup() { + return new ServerSetup.ResumableServerSetup( + new SessionManager(), + resume.getSessionDuration(), + resume.getStreamTimeout(), + resume.getStoreFactory(SERVER_TAG), + resume.isCleanupStoreOnKeepAlive()); + } + + private Exception rejectedSetupError(Throwable err) { + String msg = err.getMessage(); + return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java new file mode 100644 index 000000000..81f6625f0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Operators.MonoSubscriber; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class ReconnectMono extends Mono implements Invalidatable, Disposable, Scannable { + + final Mono source; + final BiConsumer onValueReceived; + final Consumer onValueExpired; + final ReconnectMainSubscriber mainSubscriber; + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ReconnectMono.class, "wip"); + + volatile ReconnectInner[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMono.class, ReconnectInner[].class, "subscribers"); + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_UNSUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] EMPTY_SUBSCRIBED = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] READY = new ReconnectInner[0]; + + @SuppressWarnings("rawtypes") + static final ReconnectInner[] TERMINATED = new ReconnectInner[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + ReconnectMono( + Mono source, + Consumer onValueExpired, + BiConsumer onValueReceived) { + this.source = source; + this.onValueExpired = onValueExpired; + this.onValueReceived = onValueReceived; + this.mainSubscriber = new ReconnectMainSubscriber<>(this); + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return source; + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + + final boolean isDisposed = isDisposed(); + if (key == Attr.TERMINATED) return isDisposed; + if (key == Attr.ERROR) return t; + + return null; + } + + @Override + public void dispose() { + this.terminate(new CancellationException("ReconnectMono has already been disposed")); + } + + @Override + public boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ReconnectInner inner = new ReconnectInner<>(actual, this); + actual.onSubscribe(inner); + + final int state = this.add(inner); + + if (state == READY_STATE) { + inner.complete(this.value); + } else if (state == TERMINATED_STATE) { + inner.onError(this.t); + } + } + + /** + * Block the calling thread indefinitely, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @return the value of this {@code ReconnectMono} + */ + @Override + @Nullable + public T block() { + return block(null); + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@code ReconnectMono} or {@code null} if the timeout is reached and + * the {@code ReconnectMono} has not completed + */ + @Override + @Nullable + @SuppressWarnings("uncheked") + public T block(@Nullable Duration timeout) { + try { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == READY) { + return this.value; + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.source.subscribe(this.mainSubscriber); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + ReconnectInner[] inners = this.subscribers; + + if (inners == READY) { + return this.value; + } + if (inners == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = + Exceptions.addSuppressed(re, new Exception("ReconnectMono terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + void terminate(Throwable t) { + if (isDisposed()) { + return; + } + + // writes happens before volatile write + this.t = t; + + final ReconnectInner[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.mainSubscriber.dispose(); + + this.doFinally(); + + for (CoreSubscriber consumer : subscribers) { + consumer.onError(t); + } + } + + void complete() { + ReconnectInner[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + return; + } + + final T value = this.value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.onValueReceived.accept(value, this); + + for (ReconnectInner consumer : subscribers) { + consumer.complete(value); + } + } + + void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + + if (value != null && isDisposed()) { + this.value = null; + this.onValueExpired.accept(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + // Check RSocket is not good + @Override + public void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final ReconnectInner[] subscribers = this.subscribers; + + if (subscribers == READY && SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED)) { + final T value = this.value; + this.value = null; + + if (value != null) { + this.onValueExpired.accept(value); + } + } + } + + int add(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + ReconnectInner[] b = new ReconnectInner[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.source.subscribe(this.mainSubscriber); + } + return ADDED_STATE; + } + } + } + + @SuppressWarnings("unchecked") + void remove(ReconnectInner ps) { + for (; ; ) { + ReconnectInner[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + ReconnectInner[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new ReconnectInner[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } + + static final class ReconnectMainSubscriber implements CoreSubscriber { + + final ReconnectMono parent; + + volatile Subscription s; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMainSubscriber.class, Subscription.class, "s"); + + ReconnectMainSubscriber(ReconnectMono parent) { + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + final T value = p.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + p.doFinally(); + return; + } + + if (value == null) { + p.terminate(new IllegalStateException("Unexpected Completion of the Upstream")); + } else { + p.complete(); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + final ReconnectMono p = this.parent; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + p.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // terminate upstream which means retryBackoff has exhausted + p.terminate(t); + } + + @Override + public void onNext(T value) { + if (this.s == Operators.cancelledSubscription()) { + this.parent.onValueExpired.accept(value); + return; + } + + final ReconnectMono p = this.parent; + + p.value = value; + // volatile write and check on racing + p.doFinally(); + } + + void dispose() { + Operators.terminate(S, this); + } + } + + static final class ReconnectInner extends MonoSubscriber { + final ReconnectMono parent; + + ReconnectInner(CoreSubscriber actual, ReconnectMono parent) { + super(actual); + this.parent = parent; + } + + @Override + public void cancel() { + if (!isCancelled()) { + super.cancel(); + this.parent.remove(this); + } + } + + @Override + public void onComplete() { + if (!isCancelled()) { + this.actual.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (isCancelled()) { + Operators.onErrorDropped(t, currentContext()); + } else { + this.actual.onError(t); + } + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return super.scanUnsafe(key); + } + } +} + +interface Invalidatable { + + void invalidate(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java new file mode 100644 index 000000000..05f8d6b3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestOperator.java @@ -0,0 +1,188 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.core.publisher.SignalType; +import reactor.util.context.Context; + +/** + * This is a support class for handling of request input, intended for use with {@link + * Operators#lift}. It ensures serial execution of cancellation vs first request signals and also + * provides hooks for separate handling of first vs subsequent {@link Subscription#request} + * invocations. + */ +abstract class RequestOperator + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + + Subscription s; + Fuseable.QueueSubscription qs; + + int streamId; + boolean firstRequest = true; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(RequestOperator.class, "wip"); + + RequestOperator(CoreSubscriber actual) { + this.actual = actual; + } + + /** + * Optional hook executed exactly once on the first {@link Subscription#request) invocation + * and right after the {@link Subscription#request} was propagated to the upstream subscription. + * + *

Note: this hook may not be invoked if cancellation happened before this invocation + */ + void hookOnFirstRequest(long n) {} + + /** + * Optional hook executed after the {@link Subscription#request} was propagated to the upstream + * subscription and excludes the first {@link Subscription#request} invocation. + */ + void hookOnRemainingRequests(long n) {} + + /** Optional hook executed after this {@link Subscription} cancelling. */ + void hookOnCancel() {} + + /** + * Optional hook executed after {@link org.reactivestreams.Subscriber} termination events + * (onError, onComplete). + * + * @param signalType the type of termination event that triggered the hook ({@link + * SignalType#ON_ERROR} or {@link SignalType#ON_COMPLETE}) + */ + void hookOnTerminal(SignalType signalType) {} + + @Override + public Context currentContext() { + return actual.currentContext(); + } + + @Override + public void request(long n) { + this.s.request(n); + if (!firstRequest) { + try { + this.hookOnRemainingRequests(n); + } catch (Throwable throwable) { + onError(throwable); + } + return; + } + this.firstRequest = false; + + if (WIP.getAndIncrement(this) != 0) { + return; + } + int missed = 1; + + boolean firstLoop = true; + for (; ; ) { + if (firstLoop) { + firstLoop = false; + try { + this.hookOnFirstRequest(n); + } catch (Throwable throwable) { + onError(throwable); + return; + } + } else { + try { + this.hookOnCancel(); + } catch (Throwable throwable) { + onError(throwable); + } + return; + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return; + } + } + } + + @Override + public void cancel() { + this.s.cancel(); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + hookOnCancel(); + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + } + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(Payload t) { + this.actual.onNext(t); + } + + @Override + public void onError(Throwable t) { + this.actual.onError(t); + try { + this.hookOnTerminal(SignalType.ON_ERROR); + } catch (Throwable throwable) { + Operators.onErrorDropped(throwable, currentContext()); + } + } + + @Override + public void onComplete() { + this.actual.onComplete(); + try { + this.hookOnTerminal(SignalType.ON_COMPLETE); + } catch (Throwable throwable) { + Operators.onErrorDropped(throwable, currentContext()); + } + } + + @Override + public int requestFusion(int requestedMode) { + if (this.qs != null) { + return this.qs.requestFusion(requestedMode); + } else { + return Fuseable.NONE; + } + } + + @Override + public Payload poll() { + return this.qs.poll(); + } + + @Override + public int size() { + return this.qs.size(); + } + + @Override + public boolean isEmpty() { + return this.qs.isEmpty(); + } + + @Override + public void clear() { + this.qs.clear(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java new file mode 100644 index 000000000..48133af98 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -0,0 +1,177 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableFramesStore; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + +/** + * Simple holder of configuration settings for the RSocket Resume capability. This can be used to + * configure an {@link RSocketConnector} or an {@link RSocketServer} except for {@link + * #retry(Retry)} and {@link #token(Supplier)} which apply only to the client side. + */ +public class Resume { + private static final Logger logger = LoggerFactory.getLogger(Resume.class); + + private Duration sessionDuration = Duration.ofMinutes(2); + + /* Storage */ + private boolean cleanupStoreOnKeepAlive; + private Function storeFactory; + private Duration streamTimeout = Duration.ofSeconds(10); + + /* Client only */ + private Supplier tokenSupplier = ResumeFrameCodec::generateResumeToken; + private Retry retry = + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(16)) + .jitter(1.0) + .doBeforeRetry(signal -> logger.debug("Connection error", signal.failure())); + + public Resume() {} + + /** + * The maximum time for a client to keep trying to reconnect. During this time client and server + * continue to store unsent frames to keep the session warm and ready to resume. + * + *

By default this is set to 2 minutes. + * + * @param sessionDuration the max duration for a session + * @return the same instance for method chaining + */ + public Resume sessionDuration(Duration sessionDuration) { + this.sessionDuration = Objects.requireNonNull(sessionDuration); + return this; + } + + /** + * When this property is enabled, hints from {@code KEEPALIVE} frames about how much data has been + * received by the other side, is used to proactively clean frames from the {@link + * #storeFactory(Function) store}. + * + *

By default this is set to {@code false} in which case information from {@code KEEPALIVE} is + * ignored and old frames from the store are removed only when the store runs out of space. + * + * @return the same instance for method chaining + */ + public Resume cleanupStoreOnKeepAlive() { + this.cleanupStoreOnKeepAlive = true; + return this; + } + + /** + * Configure a factory to create the storage for buffering (or persisting) a window of frames that + * may need to be sent again to resume after a dropped connection. + * + *

By default {@link InMemoryResumableFramesStore} is used with its cache size set to 100,000 + * bytes. When the cache fills up, the oldest frames are gradually removed to create space for new + * ones. + * + * @param storeFactory the factory to use to create the store + * @return the same instance for method chaining + */ + public Resume storeFactory( + Function storeFactory) { + this.storeFactory = storeFactory; + return this; + } + + /** + * A {@link reactor.core.publisher.Flux#timeout(Duration) timeout} value to apply to the resumed + * session stream obtained from the {@link #storeFactory(Function) store} after a reconnect. The + * resume stream must not take longer than the specified time to emit each frame. + * + *

By default this is set to 10 seconds. + * + * @param streamTimeout the timeout value for resuming a session stream + * @return the same instance for method chaining + */ + public Resume streamTimeout(Duration streamTimeout) { + this.streamTimeout = Objects.requireNonNull(streamTimeout); + return this; + } + + /** + * Configure the logic for reconnecting. This setting is for use with {@link + * RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is set to: + * + *

{@code
+   * Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1))
+   *     .maxBackoff(Duration.ofSeconds(16))
+   *     .jitter(1.0)
+   * }
+ * + * @param retry the {@code Retry} spec to use when attempting to reconnect + * @return the same instance for method chaining + */ + public Resume retry(Retry retry) { + this.retry = retry; + return this; + } + + /** + * Customize the generation of the resume identification token used to resume. This setting is for + * use with {@link RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is {@code ResumeFrameFlyweight::generateResumeToken}. + * + * @param supplier a custom generator for a resume identification token + * @return the same instance for method chaining + */ + public Resume token(Supplier supplier) { + this.tokenSupplier = supplier; + return this; + } + + // Package private accessors + + Duration getSessionDuration() { + return sessionDuration; + } + + boolean isCleanupStoreOnKeepAlive() { + return cleanupStoreOnKeepAlive; + } + + Function getStoreFactory(String tag) { + return storeFactory != null + ? storeFactory + : token -> new InMemoryResumableFramesStore(tag, 100_000); + } + + Duration getStreamTimeout() { + return streamTimeout; + } + + Retry getRetry() { + return retry; + } + + Supplier getTokenSupplier() { + return tokenSupplier; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java similarity index 75% rename from rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java rename to rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java index dbd8bc173..337d17c64 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ServerSetup.java +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,41 +14,44 @@ * limitations under the License. */ -package io.rsocket.internal; +package io.rsocket.core; import static io.rsocket.keepalive.KeepAliveHandler.*; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; -import io.rsocket.frame.ResumeFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.keepalive.KeepAliveHandler; import io.rsocket.resume.*; -import io.rsocket.util.ConnectionUtils; import java.time.Duration; import java.util.function.BiFunction; import java.util.function.Function; import reactor.core.publisher.Mono; -public interface ServerSetup { +abstract class ServerSetup { - Mono acceptRSocketSetup( + abstract Mono acceptRSocketSetup( ByteBuf frame, ClientServerInputMultiplexer multiplexer, BiFunction> then); - Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); + abstract Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer); - default void dispose() {} + void dispose() {} - class DefaultServerSetup implements ServerSetup { - private final ByteBufAllocator allocator; + Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { + DuplexConnection duplexConnection = multiplexer.asSetupConnection(); + return duplexConnection + .sendOne(ErrorFrameCodec.encode(duplexConnection.alloc(), 0, exception)) + .onErrorResume(err -> Mono.empty()); + } - public DefaultServerSetup(ByteBufAllocator allocator) { - this.allocator = allocator; - } + static class DefaultServerSetup extends ServerSetup { @Override public Mono acceptRSocketSetup( @@ -56,7 +59,7 @@ public Mono acceptRSocketSetup( ClientServerInputMultiplexer multiplexer, BiFunction> then) { - if (SetupFrameFlyweight.resumeEnabled(frame)) { + if (SetupFrameCodec.resumeEnabled(frame)) { return sendError(multiplexer, new UnsupportedSetupException("resume not supported")) .doFinally( signalType -> { @@ -78,28 +81,21 @@ public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexe multiplexer.dispose(); }); } - - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } } - class ResumableServerSetup implements ServerSetup { - private final ByteBufAllocator allocator; + static class ResumableServerSetup extends ServerSetup { private final SessionManager sessionManager; private final Duration resumeSessionDuration; private final Duration resumeStreamTimeout; private final Function resumeStoreFactory; private final boolean cleanupStoreOnKeepAlive; - public ResumableServerSetup( - ByteBufAllocator allocator, + ResumableServerSetup( SessionManager sessionManager, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; this.sessionManager = sessionManager; this.resumeSessionDuration = resumeSessionDuration; this.resumeStreamTimeout = resumeStreamTimeout; @@ -113,15 +109,14 @@ public Mono acceptRSocketSetup( ClientServerInputMultiplexer multiplexer, BiFunction> then) { - if (SetupFrameFlyweight.resumeEnabled(frame)) { - ByteBuf resumeToken = SetupFrameFlyweight.resumeToken(frame); + if (SetupFrameCodec.resumeEnabled(frame)) { + ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); ResumableDuplexConnection connection = sessionManager .save( new ServerRSocketSession( multiplexer.asClientServerConnection(), - allocator, resumeSessionDuration, resumeStreamTimeout, resumeStoreFactory, @@ -138,7 +133,7 @@ public Mono acceptRSocketSetup( @Override public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexer multiplexer) { - ServerRSocketSession session = sessionManager.get(ResumeFrameFlyweight.token(frame)); + ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); if (session != null) { return session .continueWith(multiplexer.asClientServerConnection()) @@ -155,10 +150,6 @@ public Mono acceptRSocketResume(ByteBuf frame, ClientServerInputMultiplexe } } - private Mono sendError(ClientServerInputMultiplexer multiplexer, Exception exception) { - return ConnectionUtils.sendError(allocator, multiplexer, exception); - } - @Override public void dispose() { sessionManager.dispose(); diff --git a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java similarity index 70% rename from rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java rename to rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java index af8c6b3d0..15d39c993 100644 --- a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java +++ b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,17 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import io.netty.util.collection.IntObjectMap; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; +/** This API is not thread-safe and must be strictly used in serialized fashion */ final class StreamIdSupplier { private static final int MASK = 0x7FFFFFFF; - private static final AtomicLongFieldUpdater STREAM_ID = - AtomicLongFieldUpdater.newUpdater(StreamIdSupplier.class, "streamId"); - private volatile long streamId; + private long streamId; // Visible for testing StreamIdSupplier(int streamId) { @@ -38,10 +36,18 @@ static StreamIdSupplier serverSupplier() { return new StreamIdSupplier(0); } + /** + * This methods provides new stream id and ensures there is no intersections with already running + * streams. This methods is not thread-safe. + * + * @param streamIds currently running streams store + * @return next stream id + */ int nextStreamId(IntObjectMap streamIds) { int streamId; do { - streamId = (int) STREAM_ID.addAndGet(this, 2) & MASK; + this.streamId += 2; + streamId = (int) (this.streamId & MASK); } while (streamId == 0 || streamIds.containsKey(streamId)); return streamId; } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java b/rsocket-core/src/main/java/io/rsocket/core/package-info.java similarity index 54% rename from rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java rename to rsocket-core/src/main/java/io/rsocket/core/package-info.java index ed8e6cd1d..29db3f205 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriHandlerTest.java +++ b/rsocket-core/src/main/java/io/rsocket/core/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,25 +14,15 @@ * limitations under the License. */ -package io.rsocket.transport.local; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class LocalUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "http://test"; - } - - @Override - public UriHandler getUriHandler() { - return new LocalUriHandler(); - } +/** + * Contains {@link io.rsocket.core.RSocketConnector RSocketConnector} and {@link + * io.rsocket.core.RSocketServer RSocketServer}, the main classes for connecting to or starting an + * RSocket server. + * + *

This package also contains a package private classes that implement support for the main + * RSocket interactions. + */ +@NonNullApi +package io.rsocket.core; - @Override - public String getValidUri() { - return "local:test"; - } -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java index e92534b2a..cd0d46754 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Application layer logic generating a Reactive Streams {@code onError} event. @@ -32,10 +33,9 @@ public final class ApplicationErrorException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ApplicationErrorException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public ApplicationErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ApplicationErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.APPLICATION_ERROR; + public ApplicationErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java index 984e8249b..d51ba0fb7 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The Responder canceled the request but may have started processing it (similar to REJECTED but @@ -33,10 +34,9 @@ public final class CanceledException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public CanceledException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public CanceledException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public CanceledException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CANCELED; + public CanceledException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CANCELED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java index 3f4f4309d..80324aa90 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MUST wait for outstanding @@ -33,10 +34,9 @@ public final class ConnectionCloseException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionCloseException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionCloseException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionCloseException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_CLOSE; + public ConnectionCloseException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_CLOSE, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java index beaa3d0d0..b44714f7e 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The connection is being terminated. Sender or Receiver of this frame MAY close the connection @@ -33,10 +34,9 @@ public final class ConnectionErrorException extends RSocketException implements * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public ConnectionErrorException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public ConnectionErrorException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public ConnectionErrorException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.CONNECTION_ERROR; + public ConnectionErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_ERROR, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java new file mode 100644 index 000000000..079b561f9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +public class CustomRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @param cause the cause of this exception + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + if (errorCode > ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]", this); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java index 97de65a96..5c6eee614 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,22 @@ package io.rsocket.exceptions; -import static io.rsocket.frame.ErrorFrameFlyweight.*; +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; import io.netty.buffer.ByteBuf; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import java.util.Objects; /** Utility class that generates an exception from a frame. */ @@ -28,42 +40,56 @@ public final class Exceptions { private Exceptions() {} /** - * Create a {@link RSocketException} from a Frame that matches the error code it contains. + * Create a {@link RSocketErrorException} from a Frame that matches the error code it contains. * * @param frame the frame to retrieve the error code and message from - * @return a {@link RSocketException} that matches the error code in the Frame + * @return a {@link RSocketErrorException} that matches the error code in the Frame * @throws NullPointerException if {@code frame} is {@code null} */ - public static RuntimeException from(ByteBuf frame) { + public static RuntimeException from(int streamId, ByteBuf frame) { Objects.requireNonNull(frame, "frame must not be null"); - int errorCode = ErrorFrameFlyweight.errorCode(frame); - String message = ErrorFrameFlyweight.dataUtf8(frame); + int errorCode = ErrorFrameCodec.errorCode(frame); + String message = ErrorFrameCodec.dataUtf8(frame); - switch (errorCode) { - case APPLICATION_ERROR: - return new ApplicationErrorException(message); - case CANCELED: - return new CanceledException(message); - case CONNECTION_CLOSE: - return new ConnectionCloseException(message); - case CONNECTION_ERROR: - return new ConnectionErrorException(message); - case INVALID: - return new InvalidException(message); - case INVALID_SETUP: - return new InvalidSetupException(message); - case REJECTED: - return new RejectedException(message); - case REJECTED_RESUME: - return new RejectedResumeException(message); - case REJECTED_SETUP: - return new RejectedSetupException(message); - case UNSUPPORTED_SETUP: - return new UnsupportedSetupException(message); - default: - return new IllegalArgumentException( - String.format("Invalid Error frame: %d '%s'", errorCode, message)); + if (streamId == 0) { + switch (errorCode) { + case INVALID_SETUP: + return new InvalidSetupException(message); + case UNSUPPORTED_SETUP: + return new UnsupportedSetupException(message); + case REJECTED_SETUP: + return new RejectedSetupException(message); + case REJECTED_RESUME: + return new RejectedResumeException(message); + case CONNECTION_ERROR: + return new ConnectionErrorException(message); + case CONNECTION_CLOSE: + return new ConnectionCloseException(message); + default: + return new IllegalArgumentException( + String.format("Invalid Error frame in Stream ID 0: 0x%08X '%s'", errorCode, message)); + } + } else { + switch (errorCode) { + case APPLICATION_ERROR: + return new ApplicationErrorException(message); + case REJECTED: + return new RejectedException(message); + case CANCELED: + return new CanceledException(message); + case INVALID: + return new InvalidException(message); + default: + if (errorCode >= MIN_USER_ALLOWED_ERROR_CODE + || errorCode <= MAX_USER_ALLOWED_ERROR_CODE) { + return new CustomRSocketException(errorCode, message); + } + return new IllegalArgumentException( + String.format( + "Invalid Error frame in Stream ID %d: 0x%08X '%s'", + streamId, errorCode, message)); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java index 4783b1590..a1b77b8dd 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The request is invalid. @@ -32,10 +33,9 @@ public final class InvalidException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public InvalidException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID; + public InvalidException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java index b3705d5b7..b0889c5a6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The Setup frame is invalid for the server (it could be that the client is too recent for the old @@ -33,10 +34,9 @@ public final class InvalidSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public InvalidSetupException(String message) { - super(message); + this(message, null); } /** @@ -44,14 +44,8 @@ public InvalidSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public InvalidSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.INVALID_SETUP; + public InvalidSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java index 7508a1ee3..2b137282f 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,41 +16,48 @@ package io.rsocket.exceptions; -import java.util.Objects; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; import reactor.util.annotation.Nullable; -/** The root of the RSocket exception hierarchy. */ -public abstract class RSocketException extends RuntimeException { +/** + * The root of the RSocket exception hierarchy. + * + * @deprecated please use {@link RSocketErrorException} instead + */ +@Deprecated +public abstract class RSocketException extends RSocketErrorException { private static final long serialVersionUID = 2912815394105575423L; /** - * Constructs a new exception with the specified message. + * Constructs a new exception with the specified message and error code 0x201 (Application error). * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message) { - super(Objects.requireNonNull(message, "message must not be null")); + this(message, null); } /** - * Constructs a new exception with the specified message and cause. + * Constructs a new exception with the specified message and cause and error code 0x201 + * (Application error). * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} is {@code null} */ public RSocketException(String message, @Nullable Throwable cause) { - super(Objects.requireNonNull(message, "message must not be null"), cause); + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); } /** - * Returns the RSocket error code - * represented by this exception + * Constructs a new exception with the specified error code, message and cause. * - * @return the RSocket error code + * @param errorCode the RSocket protocol error code + * @param message the message + * @param cause the cause of this exception */ - public abstract int errorCode(); + public RSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java index 4ab83182e..baed84e1b 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Despite being a valid request, the Responder decided to reject it. The Responder guarantees that @@ -34,10 +35,9 @@ public class RejectedException extends RSocketException implements Retryable { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedException(String message) { - super(message); + this(message, null); } /** @@ -45,14 +45,8 @@ public RejectedException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED; + public RejectedException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java index 0d4116538..8a99fcffb 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The server rejected the resume, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedResumeException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedResumeException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedResumeException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedResumeException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_RESUME; + public RejectedResumeException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_RESUME, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java index 1fa5f604e..c09a27e32 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * The server rejected the setup, it can specify the reason in the payload. @@ -32,10 +33,9 @@ public final class RejectedSetupException extends SetupException implements Retr * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public RejectedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public RejectedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public RejectedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.REJECTED_SETUP; + public RejectedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java index 2111a51b1..ed979c9e6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ package io.rsocket.exceptions; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + /** The root of the setup exception hierarchy. */ public abstract class SetupException extends RSocketException { @@ -25,10 +28,11 @@ public abstract class SetupException extends RSocketException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} */ + @Deprecated public SetupException(String message) { - super(message); + this(message, null); } /** @@ -36,9 +40,21 @@ public SetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @deprecated please use {@link #SetupException(int, String, Throwable)} + */ + @Deprecated + public SetupException(String message, @Nullable Throwable cause) { + this(ErrorFrameCodec.INVALID_SETUP, message, cause); + } + + /** + * Constructs a new exception with the specified error code, message and cause. + * + * @param errorCode the RSocket protocol code + * @param message the message + * @param cause the cause of this exception */ - public SetupException(String message, Throwable cause) { - super(message, cause); + public SetupException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java index 7d14bc5d2..7429ccd98 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ package io.rsocket.exceptions; -import io.rsocket.frame.ErrorType; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; /** * Some (or all) of the parameters specified by the client are unsupported by the server. @@ -32,10 +33,9 @@ public final class UnsupportedSetupException extends SetupException { * Constructs a new exception with the specified message. * * @param message the message - * @throws NullPointerException if {@code message} is {@code null} */ public UnsupportedSetupException(String message) { - super(message); + this(message, null); } /** @@ -43,14 +43,8 @@ public UnsupportedSetupException(String message) { * * @param message the message * @param cause the cause of this exception - * @throws NullPointerException if {@code message} or {@code cause} is {@code null} */ - public UnsupportedSetupException(String message, Throwable cause) { - super(message, cause); - } - - @Override - public int errorCode() { - return ErrorType.UNSUPPORTED_SETUP; + public UnsupportedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.UNSUPPORTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java index babf8194e..969aedded 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,7 +15,7 @@ */ /** - * The hierarchy of exceptions that can be returned by the API + * A hierarchy of exceptions that represent RSocket protocol error codes. * * @see Error * Codes diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java index cbe989d4b..5d89bb9ad 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,19 +19,18 @@ import static io.rsocket.fragmentation.FrameFragmenter.fragmentFrame; import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; import io.rsocket.frame.FrameType; import java.util.Objects; -import javax.annotation.Nullable; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; /** * A {@link DuplexConnection} implementation that fragments and reassembles {@link ByteBuf}s. @@ -40,29 +39,25 @@ * href="https://github.com/rsocket/rsocket/blob/master/Protocol.md#fragmentation-and-reassembly">Fragmentation * and Reassembly */ -public final class FragmentationDuplexConnection implements DuplexConnection { - private static final int MIN_MTU_SIZE = 64; +public final class FragmentationDuplexConnection extends ReassemblyDuplexConnection + implements DuplexConnection { + public static final int MIN_MTU_SIZE = 64; private static final Logger logger = LoggerFactory.getLogger(FragmentationDuplexConnection.class); private final DuplexConnection delegate; private final int mtu; - private final ByteBufAllocator allocator; private final FrameReassembler frameReassembler; private final boolean encodeLength; private final String type; public FragmentationDuplexConnection( - DuplexConnection delegate, - ByteBufAllocator allocator, - int mtu, - boolean encodeLength, - String type) { + DuplexConnection delegate, int mtu, boolean encodeAndEncodeLength, String type) { + super(delegate, encodeAndEncodeLength); + Objects.requireNonNull(delegate, "delegate must not be null"); - Objects.requireNonNull(allocator, "byteBufAllocator must not be null"); - this.encodeLength = encodeLength; - this.allocator = allocator; + this.encodeLength = encodeAndEncodeLength; this.delegate = delegate; this.mtu = assertMtu(mtu); - this.frameReassembler = new FrameReassembler(allocator); + this.frameReassembler = new FrameReassembler(delegate.alloc()); this.type = type; delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); @@ -105,25 +100,25 @@ public Mono send(Publisher frames) { @Override public Mono sendOne(ByteBuf frame) { - FrameType frameType = FrameHeaderFlyweight.frameType(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); int readableBytes = frame.readableBytes(); if (shouldFragment(frameType, readableBytes)) { if (logger.isDebugEnabled()) { return delegate.send( - Flux.from(fragmentFrame(allocator, mtu, frame, frameType, encodeLength)) + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength)) .doOnNext( byteBuf -> { - ByteBuf f = encodeLength ? FrameLengthFlyweight.frame(byteBuf) : byteBuf; + ByteBuf f = encodeLength ? FrameLengthCodec.frame(byteBuf) : byteBuf; logger.debug( "{} - stream id {} - frame type {} - \n {}", type, - FrameHeaderFlyweight.streamId(f), - FrameHeaderFlyweight.frameType(f), + FrameHeaderCodec.streamId(f), + FrameHeaderCodec.frameType(f), ByteBufUtil.prettyHexDump(f)); })); } else { return delegate.send( - Flux.from(fragmentFrame(allocator, mtu, frame, frameType, encodeLength))); + Flux.from(fragmentFrame(alloc(), mtu, frame, frameType, encodeLength))); } } else { return delegate.sendOne(encode(frame)); @@ -132,38 +127,9 @@ public Mono sendOne(ByteBuf frame) { private ByteBuf encode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); - } else { - return frame; - } - } - - private ByteBuf decode(ByteBuf frame) { - if (encodeLength) { - return FrameLengthFlyweight.frame(frame).retain(); + return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); } else { return frame; } } - - @Override - public Flux receive() { - return delegate - .receive() - .handle( - (byteBuf, sink) -> { - ByteBuf decode = decode(byteBuf); - frameReassembler.reassembleFrame(decode, sink); - }); - } - - @Override - public Mono onClose() { - return delegate.onClose(); - } - - @Override - public void dispose() { - delegate.dispose(); - } } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java index d634f7374..4b8fd36e9 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,14 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; @@ -42,7 +49,7 @@ static Publisher fragmentFrame( boolean encodeLength) { ByteBuf metadata = getMetadata(frame, frameType); ByteBuf data = getData(frame, frameType); - int streamId = FrameHeaderFlyweight.streamId(frame); + int streamId = FrameHeaderCodec.streamId(frame); return Flux.generate( new Consumer>() { boolean first = true; @@ -77,7 +84,7 @@ static ByteBuf encodeFirstFragment( ByteBuf metadata, ByteBuf data) { // subtract the header bytes - int remaining = mtu - FrameHeaderFlyweight.size(); + int remaining = mtu - FrameHeaderCodec.size(); // substract the initial request n switch (frameType) { @@ -105,40 +112,40 @@ static ByteBuf encodeFirstFragment( switch (frameType) { case REQUEST_FNF: - return RequestFireAndForgetFrameFlyweight.encode( + return RequestFireAndForgetFrameCodec.encode( allocator, streamId, true, metadataFragment, dataFragment); case REQUEST_STREAM: - return RequestStreamFrameFlyweight.encode( + return RequestStreamFrameCodec.encode( allocator, streamId, true, - RequestStreamFrameFlyweight.initialRequestN(frame), + RequestStreamFrameCodec.initialRequestN(frame), metadataFragment, dataFragment); case REQUEST_RESPONSE: - return RequestResponseFrameFlyweight.encode( + return RequestResponseFrameCodec.encode( allocator, streamId, true, metadataFragment, dataFragment); case REQUEST_CHANNEL: - return RequestChannelFrameFlyweight.encode( + return RequestChannelFrameCodec.encode( allocator, streamId, true, false, - RequestChannelFrameFlyweight.initialRequestN(frame), + RequestChannelFrameCodec.initialRequestN(frame), metadataFragment, dataFragment); // Payload and synthetic types case PAYLOAD: - return PayloadFrameFlyweight.encode( + return PayloadFrameCodec.encode( allocator, streamId, true, false, false, metadataFragment, dataFragment); case NEXT: - return PayloadFrameFlyweight.encode( + return PayloadFrameCodec.encode( allocator, streamId, true, false, true, metadataFragment, dataFragment); case NEXT_COMPLETE: - return PayloadFrameFlyweight.encode( + return PayloadFrameCodec.encode( allocator, streamId, true, true, true, metadataFragment, dataFragment); case COMPLETE: - return PayloadFrameFlyweight.encode( + return PayloadFrameCodec.encode( allocator, streamId, true, true, false, metadataFragment, dataFragment); default: throw new IllegalStateException("unsupported fragment type: " + frameType); @@ -148,7 +155,7 @@ static ByteBuf encodeFirstFragment( static ByteBuf encodeFollowsFragment( ByteBufAllocator allocator, int mtu, int streamId, ByteBuf metadata, ByteBuf data) { // subtract the header bytes - int remaining = mtu - FrameHeaderFlyweight.size(); + int remaining = mtu - FrameHeaderCodec.size(); ByteBuf metadataFragment = null; if (metadata.isReadable()) { @@ -166,33 +173,33 @@ static ByteBuf encodeFollowsFragment( } boolean follows = data.isReadable() || metadata.isReadable(); - return PayloadFrameFlyweight.encode( + return PayloadFrameCodec.encode( allocator, streamId, follows, false, true, metadataFragment, dataFragment); } static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(frame); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); if (hasMetadata) { ByteBuf metadata; switch (frameType) { case REQUEST_FNF: - metadata = RequestFireAndForgetFrameFlyweight.metadata(frame); + metadata = RequestFireAndForgetFrameCodec.metadata(frame); break; case REQUEST_STREAM: - metadata = RequestStreamFrameFlyweight.metadata(frame); + metadata = RequestStreamFrameCodec.metadata(frame); break; case REQUEST_RESPONSE: - metadata = RequestResponseFrameFlyweight.metadata(frame); + metadata = RequestResponseFrameCodec.metadata(frame); break; case REQUEST_CHANNEL: - metadata = RequestChannelFrameFlyweight.metadata(frame); + metadata = RequestChannelFrameCodec.metadata(frame); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - metadata = PayloadFrameFlyweight.metadata(frame); + metadata = PayloadFrameCodec.metadata(frame); break; default: throw new IllegalStateException("unsupported fragment type"); @@ -207,23 +214,23 @@ static ByteBuf getData(ByteBuf frame, FrameType frameType) { ByteBuf data; switch (frameType) { case REQUEST_FNF: - data = RequestFireAndForgetFrameFlyweight.data(frame); + data = RequestFireAndForgetFrameCodec.data(frame); break; case REQUEST_STREAM: - data = RequestStreamFrameFlyweight.data(frame); + data = RequestStreamFrameCodec.data(frame); break; case REQUEST_RESPONSE: - data = RequestResponseFrameFlyweight.data(frame); + data = RequestResponseFrameCodec.data(frame); break; case REQUEST_CHANNEL: - data = RequestChannelFrameFlyweight.data(frame); + data = RequestChannelFrameCodec.data(frame); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - data = PayloadFrameFlyweight.data(frame); + data = PayloadFrameCodec.data(frame); break; default: throw new IllegalStateException("unsupported fragment type"); @@ -233,7 +240,7 @@ static ByteBuf getData(ByteBuf frame, FrameType frameType) { static ByteBuf encode(ByteBufAllocator allocator, ByteBuf frame, boolean encodeLength) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); + return FrameLengthCodec.encode(allocator, frame.readableBytes(), frame); } else { return frame; } diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java index 0c446a7c4..52068e5de 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import reactor.core.Disposable; import reactor.core.publisher.SynchronousSink; +import reactor.util.annotation.Nullable; /** * The implementation of the RSocket reassembly behavior. @@ -83,6 +84,7 @@ public boolean isDisposed() { return get(); } + @Nullable synchronized ByteBuf getHeader(int streamId) { return headers.get(streamId); } @@ -109,14 +111,17 @@ synchronized CompositeByteBuf getData(int streamId) { return byteBuf; } + @Nullable synchronized ByteBuf removeHeader(int streamId) { return headers.remove(streamId); } + @Nullable synchronized CompositeByteBuf removeMetadata(int streamId) { return metadata.remove(streamId); } + @Nullable synchronized CompositeByteBuf removeData(int streamId) { return data.remove(streamId); } @@ -146,12 +151,12 @@ void cancelAssemble(int streamId) { void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int streamId) { ByteBuf header = removeHeader(streamId); if (header != null) { - if (FrameHeaderFlyweight.hasMetadata(header)) { + if (FrameHeaderCodec.hasMetadata(header)) { ByteBuf assembledFrame = assembleFrameWithMetadata(frame, streamId, header); sink.next(assembledFrame); } else { ByteBuf data = assembleData(frame, streamId); - ByteBuf assembledFrame = FragmentationFlyweight.encode(allocator, header, data); + ByteBuf assembledFrame = FragmentationCodec.encode(allocator, header, data); sink.next(assembledFrame); } frame.release(); @@ -163,36 +168,36 @@ void handleNoFollowsFlag(ByteBuf frame, SynchronousSink sink, int strea void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { ByteBuf header = getHeader(streamId); if (header == null) { - header = frame.copy(frame.readerIndex(), FrameHeaderFlyweight.size()); + header = frame.copy(frame.readerIndex(), FrameHeaderCodec.size()); if (frameType == FrameType.REQUEST_CHANNEL || frameType == FrameType.REQUEST_STREAM) { - int i = RequestChannelFrameFlyweight.initialRequestN(frame); - header.writeInt(i); + long i = RequestChannelFrameCodec.initialRequestN(frame); + header.writeInt(i > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) i); } putHeader(streamId, header); } - if (FrameHeaderFlyweight.hasMetadata(frame)) { + if (FrameHeaderCodec.hasMetadata(frame)) { CompositeByteBuf metadata = getMetadata(streamId); switch (frameType) { case REQUEST_FNF: - metadata.addComponents(true, RequestFireAndForgetFrameFlyweight.metadata(frame).retain()); + metadata.addComponents(true, RequestFireAndForgetFrameCodec.metadata(frame).retain()); break; case REQUEST_STREAM: - metadata.addComponents(true, RequestStreamFrameFlyweight.metadata(frame).retain()); + metadata.addComponents(true, RequestStreamFrameCodec.metadata(frame).retain()); break; case REQUEST_RESPONSE: - metadata.addComponents(true, RequestResponseFrameFlyweight.metadata(frame).retain()); + metadata.addComponents(true, RequestResponseFrameCodec.metadata(frame).retain()); break; case REQUEST_CHANNEL: - metadata.addComponents(true, RequestChannelFrameFlyweight.metadata(frame).retain()); + metadata.addComponents(true, RequestChannelFrameCodec.metadata(frame).retain()); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - metadata.addComponents(true, PayloadFrameFlyweight.metadata(frame).retain()); + metadata.addComponents(true, PayloadFrameCodec.metadata(frame).retain()); break; default: throw new IllegalStateException("unsupported fragment type"); @@ -202,23 +207,23 @@ void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { ByteBuf data; switch (frameType) { case REQUEST_FNF: - data = RequestFireAndForgetFrameFlyweight.data(frame).retain(); + data = RequestFireAndForgetFrameCodec.data(frame).retain(); break; case REQUEST_STREAM: - data = RequestStreamFrameFlyweight.data(frame).retain(); + data = RequestStreamFrameCodec.data(frame).retain(); break; case REQUEST_RESPONSE: - data = RequestResponseFrameFlyweight.data(frame).retain(); + data = RequestResponseFrameCodec.data(frame).retain(); break; case REQUEST_CHANNEL: - data = RequestChannelFrameFlyweight.data(frame).retain(); + data = RequestChannelFrameCodec.data(frame).retain(); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - data = PayloadFrameFlyweight.data(frame).retain(); + data = PayloadFrameCodec.data(frame).retain(); break; default: throw new IllegalStateException("unsupported fragment type"); @@ -230,13 +235,12 @@ void handleFollowsFlag(ByteBuf frame, int streamId, FrameType frameType) { void reassembleFrame(ByteBuf frame, SynchronousSink sink) { try { - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - int streamId = FrameHeaderFlyweight.streamId(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); switch (frameType) { case CANCEL: case ERROR: cancelAssemble(streamId); - default: } if (!frameType.isFragmentable()) { @@ -244,7 +248,7 @@ void reassembleFrame(ByteBuf frame, SynchronousSink sink) { return; } - boolean hasFollows = FrameHeaderFlyweight.hasFollows(frame); + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); if (hasFollows) { handleFollowsFlag(frame, streamId, frameType); @@ -261,22 +265,28 @@ void reassembleFrame(ByteBuf frame, SynchronousSink sink) { private ByteBuf assembleFrameWithMetadata(ByteBuf frame, int streamId, ByteBuf header) { ByteBuf metadata; CompositeByteBuf cm = removeMetadata(streamId); - if (cm != null) { - metadata = cm.addComponents(true, PayloadFrameFlyweight.metadata(frame).retain()); + + ByteBuf decodedMetadata = PayloadFrameCodec.metadata(frame); + if (decodedMetadata != null) { + if (cm != null) { + metadata = cm.addComponents(true, decodedMetadata.retain()); + } else { + metadata = PayloadFrameCodec.metadata(frame).retain(); + } } else { - metadata = PayloadFrameFlyweight.metadata(frame).retain(); + metadata = cm; } ByteBuf data = assembleData(frame, streamId); - return FragmentationFlyweight.encode(allocator, header, metadata, data); + return FragmentationCodec.encode(allocator, header, metadata, data); } private ByteBuf assembleData(ByteBuf frame, int streamId) { ByteBuf data; CompositeByteBuf cd = removeData(streamId); if (cd != null) { - cd.addComponents(true, PayloadFrameFlyweight.data(frame).retain()); + cd.addComponents(true, PayloadFrameCodec.data(frame).retain()); data = cd; } else { data = Unpooled.EMPTY_BUFFER; diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java new file mode 100644 index 000000000..6060c0c20 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/ReassemblyDuplexConnection.java @@ -0,0 +1,92 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameLengthCodec; +import java.util.Objects; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A {@link DuplexConnection} implementation that reassembles {@link ByteBuf}s. + * + * @see Fragmentation + * and Reassembly + */ +public class ReassemblyDuplexConnection implements DuplexConnection { + private final DuplexConnection delegate; + private final FrameReassembler frameReassembler; + private final boolean decodeLength; + + public ReassemblyDuplexConnection(DuplexConnection delegate, boolean decodeLength) { + Objects.requireNonNull(delegate, "delegate must not be null"); + this.decodeLength = decodeLength; + this.delegate = delegate; + this.frameReassembler = new FrameReassembler(delegate.alloc()); + + delegate.onClose().doFinally(s -> frameReassembler.dispose()).subscribe(); + } + + @Override + public Mono send(Publisher frames) { + return delegate.send(frames); + } + + @Override + public Mono sendOne(ByteBuf frame) { + return delegate.sendOne(frame); + } + + private ByteBuf decode(ByteBuf frame) { + if (decodeLength) { + return FrameLengthCodec.frame(frame).retain(); + } else { + return frame; + } + } + + @Override + public Flux receive() { + return delegate + .receive() + .handle( + (byteBuf, sink) -> { + ByteBuf decode = decode(byteBuf); + frameReassembler.reassembleFrame(decode, sink); + }); + } + + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + @Override + public void dispose() { + delegate.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java index 4431f98dd..8cc3fb41a 100644 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java similarity index 55% rename from rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java index 349a43c3a..d0d929f0f 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java @@ -3,10 +3,10 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -public class CancelFrameFlyweight { - private CancelFrameFlyweight() {} +public class CancelFrameCodec { + private CancelFrameCodec() {} public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId) { - return FrameHeaderFlyweight.encode(allocator, streamId, FrameType.CANCEL, 0); + return FrameHeaderCodec.encode(allocator, streamId, FrameType.CANCEL, 0); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java deleted file mode 100644 index e4b16fec7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/DataAndMetadataFlyweight.java +++ /dev/null @@ -1,88 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.buffer.TupleByteBuf; - -class DataAndMetadataFlyweight { - public static final int FRAME_LENGTH_MASK = 0xFFFFFF; - - private DataAndMetadataFlyweight() {} - - private static void encodeLength(final ByteBuf byteBuf, final int length) { - if ((length & ~FRAME_LENGTH_MASK) != 0) { - throw new IllegalArgumentException("Length is larger than 24 bits"); - } - // Write each byte separately in reverse order, this mean we can write 1 << 23 without - // overflowing. - byteBuf.writeByte(length >> 16); - byteBuf.writeByte(length >> 8); - byteBuf.writeByte(length); - } - - private static int decodeLength(final ByteBuf byteBuf) { - byte b = byteBuf.readByte(); - int length = (b & 0xFF) << 16; - byte b1 = byteBuf.readByte(); - length |= (b1 & 0xFF) << 8; - byte b2 = byteBuf.readByte(); - length |= b2 & 0xFF; - return length; - } - - static ByteBuf encodeOnlyMetadata( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata) { - return TupleByteBuf.of(allocator, header, metadata); - } - - static ByteBuf encodeOnlyData(ByteBufAllocator allocator, final ByteBuf header, ByteBuf data) { - return TupleByteBuf.of(allocator, header, data); - } - - static ByteBuf encode( - ByteBufAllocator allocator, final ByteBuf header, ByteBuf metadata, ByteBuf data) { - - int length = metadata.readableBytes(); - encodeLength(header, length); - return TupleByteBuf.of(allocator, header, metadata, data); - } - - static ByteBuf metadataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { - if (hasMetadata) { - int length = decodeLength(byteBuf); - return byteBuf.readSlice(length); - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf metadata(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - byteBuf.skipBytes(6); - ByteBuf metadata = metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { - if (hasMetadata) { - /*moves reader index*/ - int length = decodeLength(byteBuf); - byteBuf.skipBytes(length); - } - if (byteBuf.readableBytes() > 0) { - return byteBuf.readSlice(byteBuf.readableBytes()); - } else { - return Unpooled.EMPTY_BUFFER; - } - } - - static ByteBuf data(ByteBuf byteBuf, boolean hasMetadata) { - byteBuf.markReaderIndex(); - byteBuf.skipBytes(6); - ByteBuf data = dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java similarity index 70% rename from rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java index 55e23541e..dcacb57dc 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java @@ -3,28 +3,35 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; -import io.rsocket.exceptions.RSocketException; +import io.rsocket.RSocketErrorException; import java.nio.charset.StandardCharsets; -public class ErrorFrameFlyweight { +public class ErrorFrameCodec { - // defined error codes + // defined zero stream id error codes public static final int INVALID_SETUP = 0x00000001; public static final int UNSUPPORTED_SETUP = 0x00000002; public static final int REJECTED_SETUP = 0x00000003; public static final int REJECTED_RESUME = 0x00000004; public static final int CONNECTION_ERROR = 0x00000101; public static final int CONNECTION_CLOSE = 0x00000102; + // defined non-zero stream id error codes public static final int APPLICATION_ERROR = 0x00000201; public static final int REJECTED = 0x00000202; public static final int CANCELED = 0x00000203; public static final int INVALID = 0x00000204; + // defined user-allowed error codes range + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; public static ByteBuf encode( ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.ERROR, 0); + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.ERROR, 0); - int errorCode = errorCodeFromException(t); + int errorCode = + t instanceof RSocketErrorException + ? ((RSocketErrorException) t).errorCode() + : APPLICATION_ERROR; header.writeInt(errorCode); @@ -37,17 +44,9 @@ public static ByteBuf encode(ByteBufAllocator allocator, int streamId, Throwable return encode(allocator, streamId, t, data); } - public static int errorCodeFromException(Throwable t) { - if (t instanceof RSocketException) { - return ((RSocketException) t).errorCode(); - } - - return APPLICATION_ERROR; - } - public static int errorCode(ByteBuf byteBuf) { byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); + byteBuf.skipBytes(FrameHeaderCodec.size()); int i = byteBuf.readInt(); byteBuf.resetReaderIndex(); return i; @@ -55,7 +54,7 @@ public static int errorCode(ByteBuf byteBuf) { public static ByteBuf data(ByteBuf byteBuf) { byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); ByteBuf slice = byteBuf.slice(); byteBuf.resetReaderIndex(); return slice; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java deleted file mode 100644 index ccbff374e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorType.java +++ /dev/null @@ -1,74 +0,0 @@ -package io.rsocket.frame; - -/** - * The types of {@link Error} that can be set. - * - * @see Error - * Codes - */ -public final class ErrorType { - - /** - * Application layer logic generating a Reactive Streams onError event. Stream ID MUST be > 0. - */ - public static final int APPLICATION_ERROR = 0x00000201;; - - /** - * The Responder canceled the request but may have started processing it (similar to REJECTED but - * doesn't guarantee lack of side-effects). Stream ID MUST be > 0. - */ - public static final int CANCELED = 0x00000203; - - /** - * The connection is being terminated. Stream ID MUST be 0. Sender or Receiver of this frame MUST - * wait for outstanding streams to terminate before closing the connection. New requests MAY not - * be accepted. - */ - public static final int CONNECTION_CLOSE = 0x00000102; - - /** - * The connection is being terminated. Stream ID MUST be 0. Sender or Receiver of this frame MAY - * close the connection immediately without waiting for outstanding streams to terminate. - */ - public static final int CONNECTION_ERROR = 0x00000101; - - /** The request is invalid. Stream ID MUST be > 0. */ - public static final int INVALID = 0x00000204; - - /** - * The Setup frame is invalid for the server (it could be that the client is too recent for the - * old server). Stream ID MUST be 0. - */ - public static final int INVALID_SETUP = 0x00000001; - - /** - * Despite being a valid request, the Responder decided to reject it. The Responder guarantees - * that it didn't process the request. The reason for the rejection is explained in the Error Data - * section. Stream ID MUST be > 0. - */ - public static final int REJECTED = 0x00000202; - - /** - * The server rejected the resume, it can specify the reason in the payload. Stream ID MUST be 0. - */ - public static final int REJECTED_RESUME = 0x00000004; - - /** - * The server rejected the setup, it can specify the reason in the payload. Stream ID MUST be 0. - */ - public static final int REJECTED_SETUP = 0x00000003; - - /** Reserved. */ - public static final int RESERVED = 0x00000000; - - /** Reserved for Extension Use. */ - public static final int RESERVED_FOR_EXTENSION = 0xFFFFFFFF; - - /** - * Some (or all) of the parameters specified by the client are unsupported by the server. Stream - * ID MUST be 0. - */ - public static final int UNSUPPORTED_SETUP = 0x00000002; - - private ErrorType() {} -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java new file mode 100644 index 000000000..418926596 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java @@ -0,0 +1,67 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class ExtensionFrameCodec { + private ExtensionFrameCodec() {} + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + int extendedType, + @Nullable ByteBuf metadata, + ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = FrameHeaderCodec.FLAGS_I; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.EXT, flags); + header.writeInt(extendedType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int extendedType(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java deleted file mode 100644 index df8b308e9..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameFlyweight.java +++ /dev/null @@ -1,66 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import javax.annotation.Nullable; - -public class ExtensionFrameFlyweight { - private ExtensionFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - int extendedType, - @Nullable ByteBuf metadata, - ByteBuf data) { - - int flags = FrameHeaderFlyweight.FLAGS_I; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.EXT, flags); - header.writeInt(extendedType); - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } - } - - public static int extendedType(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int i = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return i; - } - - public static ByteBuf data(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - // Extended type - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.EXT, byteBuf); - - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - // Extended type - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java similarity index 60% rename from rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java index 06efeab6c..de228b271 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java @@ -5,7 +5,7 @@ import reactor.util.annotation.Nullable; /** FragmentationFlyweight is used to re-assemble frames */ -public class FragmentationFlyweight { +public class FragmentationCodec { public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, ByteBuf data) { return encode(allocator, header, null, data); } @@ -13,12 +13,7 @@ public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, B public static ByteBuf encode( final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + final boolean hasMetadata = metadata != null; + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java new file mode 100644 index 000000000..ea011e503 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java @@ -0,0 +1,103 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import reactor.util.annotation.Nullable; + +class FrameBodyCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + + private FrameBodyCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + byte b = byteBuf.readByte(); + int length = (b & 0xFF) << 16; + byte b1 = byteBuf.readByte(); + length |= (b1 & 0xFF) << 8; + byte b2 = byteBuf.readByte(); + length |= b2 & 0xFF; + return length; + } + + static ByteBuf encode( + ByteBufAllocator allocator, + final ByteBuf header, + @Nullable ByteBuf metadata, + boolean hasMetadata, + @Nullable ByteBuf data) { + + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (hasMetadata) { + int length = metadata.readableBytes(); + encodeLength(header, length); + } + + if (addMetadata && addData) { + return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); + } else if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else if (addData) { + return allocator.compositeBuffer(2).addComponents(true, header, data); + } else { + return header; + } + } + + static ByteBuf metadataWithoutMarking(ByteBuf byteBuf) { + int length = decodeLength(byteBuf); + return byteBuf.readSlice(length); + } + + static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { + if (hasMetadata) { + /*moves reader index*/ + int length = decodeLength(byteBuf); + byteBuf.skipBytes(length); + } + if (byteBuf.readableBytes() > 0) { + return byteBuf.readSlice(byteBuf.readableBytes()); + } else { + return Unpooled.EMPTY_BUFFER; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java similarity index 98% rename from rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java index cbc677444..28f39459d 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -12,7 +12,7 @@ * *

Not thread-safe. Assumed to be used single-threaded */ -public final class FrameHeaderFlyweight { +public final class FrameHeaderCodec { /** (I)gnore flag: a value of 0 indicates the protocol can't ignore this frame */ public static final int FLAGS_I = 0b10_0000_0000; /** (M)etadata flag: a value of 1 indicates the frame contains metadata */ @@ -38,7 +38,7 @@ public final class FrameHeaderFlyweight { disableFrameTypeCheck = Boolean.getBoolean(DISABLE_FRAME_TYPE_CHECK); } - private FrameHeaderFlyweight() {} + private FrameHeaderCodec() {} static ByteBuf encodeStreamZero( final ByteBufAllocator allocator, final FrameType frameType, int flags) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java similarity index 90% rename from rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java index 6011263fa..f6c19c8ee 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java @@ -2,17 +2,16 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.buffer.TupleByteBuf; /** * Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections * for transports that need to send length */ -public class FrameLengthFlyweight { +public class FrameLengthCodec { public static final int FRAME_LENGTH_MASK = 0xFFFFFF; public static final int FRAME_LENGTH_SIZE = 3; - private FrameLengthFlyweight() {} + private FrameLengthCodec() {} private static void encodeLength(final ByteBuf byteBuf, final int length) { if ((length & ~FRAME_LENGTH_MASK) != 0) { @@ -35,7 +34,7 @@ private static int decodeLength(final ByteBuf byteBuf) { public static ByteBuf encode(ByteBufAllocator allocator, int length, ByteBuf frame) { ByteBuf buffer = allocator.buffer(); encodeLength(buffer, length); - return TupleByteBuf.of(allocator, buffer, frame); + return allocator.compositeBuffer(2).addComponents(true, buffer, frame); } public static int length(ByteBuf byteBuf) { diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java index 0d2175fb6..66d18c8a7 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -9,8 +9,8 @@ public class FrameUtil { private FrameUtil() {} public static String toString(ByteBuf frame) { - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - int streamId = FrameHeaderFlyweight.streamId(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); StringBuilder payload = new StringBuilder(); payload @@ -19,10 +19,18 @@ public static String toString(ByteBuf frame) { .append(" Type: ") .append(frameType) .append(" Flags: 0b") - .append(Integer.toBinaryString(FrameHeaderFlyweight.flags(frame))) + .append(Integer.toBinaryString(FrameHeaderCodec.flags(frame))) .append(" Length: " + frame.readableBytes()); - if (FrameHeaderFlyweight.hasMetadata(frame)) { + if (frameType.hasInitialRequestN()) { + payload.append(" InitialRequestN: ").append(RequestStreamFrameCodec.initialRequestN(frame)); + } + + if (frameType == FrameType.REQUEST_N) { + payload.append(" RequestN: ").append(RequestNFrameCodec.requestN(frame)); + } + + if (FrameHeaderCodec.hasMetadata(frame)) { payload.append("\nMetadata:\n"); ByteBufUtil.appendPrettyHexDump(payload, getMetadata(frame, frameType)); @@ -35,37 +43,37 @@ public static String toString(ByteBuf frame) { } private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(frame); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); if (hasMetadata) { ByteBuf metadata; switch (frameType) { case REQUEST_FNF: - metadata = RequestFireAndForgetFrameFlyweight.metadata(frame); + metadata = RequestFireAndForgetFrameCodec.metadata(frame); break; case REQUEST_STREAM: - metadata = RequestStreamFrameFlyweight.metadata(frame); + metadata = RequestStreamFrameCodec.metadata(frame); break; case REQUEST_RESPONSE: - metadata = RequestResponseFrameFlyweight.metadata(frame); + metadata = RequestResponseFrameCodec.metadata(frame); break; case REQUEST_CHANNEL: - metadata = RequestChannelFrameFlyweight.metadata(frame); + metadata = RequestChannelFrameCodec.metadata(frame); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - metadata = PayloadFrameFlyweight.metadata(frame); + metadata = PayloadFrameCodec.metadata(frame); break; case METADATA_PUSH: - metadata = MetadataPushFrameFlyweight.metadata(frame); + metadata = MetadataPushFrameCodec.metadata(frame); break; case SETUP: - metadata = SetupFrameFlyweight.metadata(frame); + metadata = SetupFrameCodec.metadata(frame); break; case LEASE: - metadata = LeaseFrameFlyweight.metadata(frame); + metadata = LeaseFrameCodec.metadata(frame); break; default: return Unpooled.EMPTY_BUFFER; @@ -80,26 +88,26 @@ private static ByteBuf getData(ByteBuf frame, FrameType frameType) { ByteBuf data; switch (frameType) { case REQUEST_FNF: - data = RequestFireAndForgetFrameFlyweight.data(frame); + data = RequestFireAndForgetFrameCodec.data(frame); break; case REQUEST_STREAM: - data = RequestStreamFrameFlyweight.data(frame); + data = RequestStreamFrameCodec.data(frame); break; case REQUEST_RESPONSE: - data = RequestResponseFrameFlyweight.data(frame); + data = RequestResponseFrameCodec.data(frame); break; case REQUEST_CHANNEL: - data = RequestChannelFrameFlyweight.data(frame); + data = RequestChannelFrameCodec.data(frame); break; // Payload and synthetic types case PAYLOAD: case NEXT: case NEXT_COMPLETE: case COMPLETE: - data = PayloadFrameFlyweight.data(frame); + data = PayloadFrameCodec.data(frame); break; case SETUP: - data = SetupFrameFlyweight.data(frame); + data = SetupFrameCodec.data(frame); break; default: return Unpooled.EMPTY_BUFFER; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java new file mode 100644 index 000000000..56a93d869 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java @@ -0,0 +1,159 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +class GenericFrameCodec { + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + final Payload payload) { + return encodeReleasingPayload(allocator, frameType, streamId, complete, next, 0, payload); + } + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + int requestN, + final Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, frameType, streamId, false, complete, next, requestN, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return encode(allocator, frameType, streamId, fragmentFollows, false, false, 0, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + int requestN, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + if (fragmentFollows) { + flags |= FrameHeaderCodec.FLAGS_F; + } + + if (complete) { + flags |= FrameHeaderCodec.FLAGS_C; + } + + if (next) { + flags |= FrameHeaderCodec.FLAGS_N; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, frameType, flags); + + if (requestN > 0) { + header.writeInt(requestN); + } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + int idx = byteBuf.readerIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.readerIndex(idx); + return data; + } + + @Nullable + static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static ByteBuf dataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + static ByteBuf metadataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static int initialRequestN(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int i = byteBuf.skipBytes(FrameHeaderCodec.size()).readInt(); + byteBuf.resetReaderIndex(); + return i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java similarity index 61% rename from rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java index e4e6029b3..752d5b3eb 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java @@ -3,7 +3,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -public class KeepAliveFrameFlyweight { +public class KeepAliveFrameCodec { /** * (R)espond: Set by the sender of the KEEPALIVE, to which the responder MUST reply with a * KEEPALIVE without the R flag set @@ -12,7 +12,7 @@ public class KeepAliveFrameFlyweight { public static final long LAST_POSITION_MASK = 0x8000000000000000L; - private KeepAliveFrameFlyweight() {} + private KeepAliveFrameCodec() {} public static ByteBuf encode( final ByteBufAllocator allocator, @@ -20,7 +20,7 @@ public static ByteBuf encode( final long lastPosition, final ByteBuf data) { final int flags = respond ? FLAGS_KEEPALIVE_R : 0; - ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); + ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); long lp = 0; if (lastPosition > 0) { @@ -29,27 +29,27 @@ public static ByteBuf encode( header.writeLong(lp); - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); + return FrameBodyCodec.encode(allocator, header, null, false, data); } public static boolean respondFlag(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); - int flags = FrameHeaderFlyweight.flags(byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + int flags = FrameHeaderCodec.flags(byteBuf); return (flags & FLAGS_KEEPALIVE_R) == FLAGS_KEEPALIVE_R; } public static long lastPosition(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); byteBuf.markReaderIndex(); - long l = byteBuf.skipBytes(FrameHeaderFlyweight.size()).readLong(); + long l = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); byteBuf.resetReaderIndex(); return l; } public static ByteBuf data(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); byteBuf.markReaderIndex(); - ByteBuf slice = byteBuf.skipBytes(FrameHeaderFlyweight.size() + Long.BYTES).slice(); + ByteBuf slice = byteBuf.skipBytes(FrameHeaderCodec.size() + Long.BYTES).slice(); byteBuf.resetReaderIndex(); return slice; } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java new file mode 100644 index 000000000..f20c25d3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java @@ -0,0 +1,83 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class LeaseFrameCodec { + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final int ttl, + final int numRequests, + @Nullable final ByteBuf metadata) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = + FrameHeaderCodec.encodeStreamZero(allocator, FrameType.LEASE, flags) + .writeInt(ttl) + .writeInt(numRequests); + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else { + return header; + } + } + + public static int ttl(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int ttl = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return ttl; + } + + public static int numRequests(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + // Ttl + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + int numRequests = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return numRequests; + } + + @Nullable + public static ByteBuf metadata(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + if (FrameHeaderCodec.hasMetadata(byteBuf)) { + byteBuf.markReaderIndex(); + // Ttl + Num of requests + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES * 2); + ByteBuf metadata = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return metadata; + } else { + return null; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java deleted file mode 100644 index 4676f4c9d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java +++ /dev/null @@ -1,66 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import javax.annotation.Nullable; - -public class LeaseFrameFlyweight { - - public static ByteBuf encode( - final ByteBufAllocator allocator, - final int ttl, - final int numRequests, - @Nullable final ByteBuf metadata) { - - int flags = 0; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - ByteBuf header = - FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.LEASE, flags) - .writeInt(ttl) - .writeInt(numRequests); - - if (metadata == null) { - return header; - } else { - return DataAndMetadataFlyweight.encodeOnlyMetadata(allocator, header, metadata); - } - } - - public static int ttl(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int ttl = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return ttl; - } - - public static int numRequests(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - byteBuf.markReaderIndex(); - // Ttl - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - int numRequests = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return numRequests; - } - - public static ByteBuf metadata(final ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.LEASE, byteBuf); - if (FrameHeaderFlyweight.hasMetadata(byteBuf)) { - byteBuf.markReaderIndex(); - // Ttl + Num of requests - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES * 2); - ByteBuf metadata = byteBuf.slice(); - byteBuf.resetReaderIndex(); - return metadata; - } else { - return Unpooled.EMPTY_BUFFER; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java new file mode 100644 index 000000000..d8ffe3eef --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java @@ -0,0 +1,43 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; + +public class MetadataPushFrameCodec { + + public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { + if (!payload.hasMetadata()) { + throw new IllegalStateException( + "Metadata push requires to have metadata present" + " in the given Payload"); + } + final ByteBuf metadata = payload.metadata().retain(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } + return encode(allocator, metadata); + } + + public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { + ByteBuf header = + FrameHeaderCodec.encodeStreamZero( + allocator, FrameType.METADATA_PUSH, FrameHeaderCodec.FLAGS_M); + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } + + public static ByteBuf metadata(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int headerSize = FrameHeaderCodec.size(); + int metadataLength = byteBuf.readableBytes() - headerSize; + byteBuf.skipBytes(headerSize); + ByteBuf metadata = byteBuf.readSlice(metadataLength); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java deleted file mode 100644 index d37b573ba..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameFlyweight.java +++ /dev/null @@ -1,23 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class MetadataPushFrameFlyweight { - public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { - ByteBuf header = - FrameHeaderFlyweight.encodeStreamZero( - allocator, FrameType.METADATA_PUSH, FrameHeaderFlyweight.FLAGS_M); - return allocator.compositeBuffer(2).addComponents(true, header, metadata); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - byteBuf.markReaderIndex(); - int headerSize = FrameHeaderFlyweight.size(); - int metadataLength = byteBuf.readableBytes() - headerSize; - byteBuf.skipBytes(headerSize); - ByteBuf metadata = byteBuf.readSlice(metadataLength); - byteBuf.resetReaderIndex(); - return metadata; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java new file mode 100644 index 000000000..1ae9c6750 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class PayloadFrameCodec { + + private PayloadFrameCodec() {} + + public static ByteBuf encodeNextReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, false, payload); + } + + public static ByteBuf encodeNextCompleteReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, true, payload); + } + + static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.PAYLOAD, streamId, complete, true, payload); + } + + public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { + return encode(allocator, streamId, false, true, false, null, null); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.PAYLOAD, streamId, fragmentFollows, complete, next, 0, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java deleted file mode 100644 index 4f67d9c72..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameFlyweight.java +++ /dev/null @@ -1,78 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class PayloadFrameFlyweight { - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.PAYLOAD); - - private PayloadFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, complete, next, 0, metadata, data); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - fragmentFollows, - complete, - next, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeNextComplete( - ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - true, - true, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeNext(ByteBufAllocator allocator, int streamId, Payload payload) { - return FLYWEIGHT.encode( - allocator, - streamId, - false, - false, - true, - 0, - payload.hasMetadata() ? payload.metadata().retain() : null, - payload.data().retain()); - } - - public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { - return FLYWEIGHT.encode(allocator, streamId, false, true, false, 0, null, null); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java new file mode 100644 index 000000000..60906083d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java @@ -0,0 +1,69 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestChannelFrameCodec { + + private RequestChannelFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, + int streamId, + boolean complete, + long initialRequestN, + Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_CHANNEL, streamId, complete, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_CHANNEL, + streamId, + fragmentFollows, + complete, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java deleted file mode 100644 index 06ddcda03..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameFlyweight.java +++ /dev/null @@ -1,60 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestChannelFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_CHANNEL); - - private RequestChannelFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - long requestN, - Payload payload) { - - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - - return FLYWEIGHT.encode( - allocator, - streamId, - fragmentFollows, - complete, - false, - reqN, - payload.metadata(), - payload.data()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - boolean complete, - long requestN, - ByteBuf metadata, - ByteBuf data) { - - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, complete, false, reqN, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.dataWithRequestN(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadataWithRequestN(byteBuf); - } - - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java new file mode 100644 index 000000000..b91199179 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestFireAndForgetFrameCodec { + + private RequestFireAndForgetFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_FNF, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_FNF, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java deleted file mode 100644 index 5f2d606e4..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameFlyweight.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestFireAndForgetFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_FNF); - - private RequestFireAndForgetFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - ByteBuf metadata, - ByteBuf data) { - - return FLYWEIGHT.encode(allocator, streamId, fragmentFollows, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java deleted file mode 100644 index 98d862f36..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFlyweight.java +++ /dev/null @@ -1,107 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import javax.annotation.Nullable; - -class RequestFlyweight { - FrameType frameType; - - RequestFlyweight(FrameType frameType) { - this.frameType = frameType; - } - - ByteBuf encode( - final ByteBufAllocator allocator, - final int streamId, - boolean fragmentFollows, - @Nullable ByteBuf metadata, - ByteBuf data) { - return encode(allocator, streamId, fragmentFollows, false, false, 0, metadata, data); - } - - ByteBuf encode( - final ByteBufAllocator allocator, - final int streamId, - boolean fragmentFollows, - boolean complete, - boolean next, - int requestN, - @Nullable ByteBuf metadata, - ByteBuf data) { - int flags = 0; - - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; - } - - if (fragmentFollows) { - flags |= FrameHeaderFlyweight.FLAGS_F; - } - - if (complete) { - flags |= FrameHeaderFlyweight.FLAGS_C; - } - - if (next) { - flags |= FrameHeaderFlyweight.FLAGS_N; - } - - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, frameType, flags); - - if (requestN > 0) { - header.writeInt(requestN); - } - - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } - } - - ByteBuf data(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - int idx = byteBuf.readerIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.readerIndex(idx); - return data; - } - - ByteBuf metadata(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - ByteBuf dataWithRequestN(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return data; - } - - ByteBuf metadataWithRequestN(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size() + Integer.BYTES); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); - byteBuf.resetReaderIndex(); - return metadata; - } - - int initialRequestN(ByteBuf byteBuf) { - byteBuf.markReaderIndex(); - int i = byteBuf.skipBytes(FrameHeaderFlyweight.size()).readInt(); - byteBuf.resetReaderIndex(); - return i; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java new file mode 100644 index 000000000..66bdd46f4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java @@ -0,0 +1,30 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class RequestNFrameCodec { + private RequestNFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, long requestN) { + + if (requestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.REQUEST_N, 0); + return header.writeInt(reqN); + } + + public static long requestN(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.REQUEST_N, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i == Integer.MAX_VALUE ? Long.MAX_VALUE : i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java deleted file mode 100644 index 5a4c4c273..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; - -public class RequestNFrameFlyweight { - private RequestNFrameFlyweight() {} - - public static ByteBuf encode( - final ByteBufAllocator allocator, final int streamId, long requestN) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, reqN); - } - - public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId, int requestN) { - ByteBuf header = FrameHeaderFlyweight.encode(allocator, streamId, FrameType.REQUEST_N, 0); - - if (requestN < 1) { - throw new IllegalArgumentException("request n is less than 1"); - } - - return header.writeInt(requestN); - } - - public static int requestN(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.REQUEST_N, byteBuf); - byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); - int i = byteBuf.readInt(); - byteBuf.resetReaderIndex(); - return i; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java new file mode 100644 index 000000000..4a37acfd5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java @@ -0,0 +1,37 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestResponseFrameCodec { + + private RequestResponseFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_RESPONSE, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_RESPONSE, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java deleted file mode 100644 index 2e06c9b82..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameFlyweight.java +++ /dev/null @@ -1,34 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestResponseFrameFlyweight { - private static final RequestFlyweight FLYWEIGHT = - new RequestFlyweight(FrameType.REQUEST_RESPONSE); - - private RequestResponseFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, int streamId, boolean fragmentFollows, Payload payload) { - return encode(allocator, streamId, fragmentFollows, payload.metadata(), payload.data()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode(allocator, streamId, fragmentFollows, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.data(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadata(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java new file mode 100644 index 000000000..2f5dbf0d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java @@ -0,0 +1,64 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestStreamFrameCodec { + + private RequestStreamFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_STREAM, streamId, false, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_STREAM, + streamId, + fragmentFollows, + false, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java deleted file mode 100644 index 171c41990..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameFlyweight.java +++ /dev/null @@ -1,66 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.Payload; - -public class RequestStreamFrameFlyweight { - - private static final RequestFlyweight FLYWEIGHT = new RequestFlyweight(FrameType.REQUEST_STREAM); - - private RequestStreamFrameFlyweight() {} - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - Payload payload) { - return encode( - allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - int requestN, - Payload payload) { - return encode( - allocator, streamId, fragmentFollows, requestN, payload.metadata(), payload.data()); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - long requestN, - ByteBuf metadata, - ByteBuf data) { - int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return encode(allocator, streamId, fragmentFollows, reqN, metadata, data); - } - - public static ByteBuf encode( - ByteBufAllocator allocator, - int streamId, - boolean fragmentFollows, - int requestN, - ByteBuf metadata, - ByteBuf data) { - return FLYWEIGHT.encode( - allocator, streamId, fragmentFollows, false, false, requestN, metadata, data); - } - - public static ByteBuf data(ByteBuf byteBuf) { - return FLYWEIGHT.dataWithRequestN(byteBuf); - } - - public static ByteBuf metadata(ByteBuf byteBuf) { - return FLYWEIGHT.metadataWithRequestN(byteBuf); - } - - public static int initialRequestN(ByteBuf byteBuf) { - return FLYWEIGHT.initialRequestN(byteBuf); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java similarity index 79% rename from rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java index 06c9fc38c..aae89f7ab 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java @@ -21,8 +21,8 @@ import io.netty.buffer.Unpooled; import java.util.UUID; -public class ResumeFrameFlyweight { - static final int CURRENT_VERSION = SetupFrameFlyweight.CURRENT_VERSION; +public class ResumeFrameCodec { + static final int CURRENT_VERSION = SetupFrameCodec.CURRENT_VERSION; public static ByteBuf encode( ByteBufAllocator allocator, @@ -30,7 +30,7 @@ public static ByteBuf encode( long lastReceivedServerPos, long firstAvailableClientPos) { - ByteBuf byteBuf = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.RESUME, 0); + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME, 0); byteBuf.writeInt(CURRENT_VERSION); token.markReaderIndex(); byteBuf.writeShort(token.readableBytes()); @@ -43,10 +43,10 @@ public static ByteBuf encode( } public static int version(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.RESUME, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); byteBuf.markReaderIndex(); - byteBuf.skipBytes(FrameHeaderFlyweight.size()); + byteBuf.skipBytes(FrameHeaderCodec.size()); int version = byteBuf.readInt(); byteBuf.resetReaderIndex(); @@ -54,11 +54,11 @@ public static int version(ByteBuf byteBuf) { } public static ByteBuf token(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.RESUME, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); byteBuf.markReaderIndex(); // header + version - int tokenPos = FrameHeaderFlyweight.size() + Integer.BYTES; + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; byteBuf.skipBytes(tokenPos); // token int tokenLength = byteBuf.readShort() & 0xFFFF; @@ -69,11 +69,11 @@ public static ByteBuf token(ByteBuf byteBuf) { } public static long lastReceivedServerPos(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.RESUME, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); byteBuf.markReaderIndex(); // header + version - int tokenPos = FrameHeaderFlyweight.size() + Integer.BYTES; + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; byteBuf.skipBytes(tokenPos); // token int tokenLength = byteBuf.readShort() & 0xFFFF; @@ -85,11 +85,11 @@ public static long lastReceivedServerPos(ByteBuf byteBuf) { } public static long firstAvailableClientPos(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.RESUME, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); byteBuf.markReaderIndex(); // header + version - int tokenPos = FrameHeaderFlyweight.size() + Integer.BYTES; + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; byteBuf.skipBytes(tokenPos); // token int tokenLength = byteBuf.readShort() & 0xFFFF; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java similarity index 67% rename from rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java index dd1971603..2b6951e49 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java @@ -3,18 +3,18 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -public class ResumeOkFrameFlyweight { +public class ResumeOkFrameCodec { public static ByteBuf encode(final ByteBufAllocator allocator, long lastReceivedClientPos) { - ByteBuf byteBuf = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.RESUME_OK, 0); + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME_OK, 0); byteBuf.writeLong(lastReceivedClientPos); return byteBuf; } public static long lastReceivedClientPos(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.RESUME_OK, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.RESUME_OK, byteBuf); byteBuf.markReaderIndex(); - long lastReceivedClientPosition = byteBuf.skipBytes(FrameHeaderFlyweight.size()).readLong(); + long lastReceivedClientPosition = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); byteBuf.resetReaderIndex(); return lastReceivedClientPosition; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java similarity index 78% rename from rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java index 9f92e715f..547e2436e 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java @@ -6,8 +6,9 @@ import io.netty.buffer.Unpooled; import io.rsocket.Payload; import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; -public class SetupFrameFlyweight { +public class SetupFrameCodec { /** * A flag used to indicate that the client requires connection resumption, if possible (the frame * contains a Resume Identification Token) @@ -17,9 +18,9 @@ public class SetupFrameFlyweight { /** A flag used to indicate that the client will honor LEASE sent by the server */ public static final int FLAGS_WILL_HONOR_LEASE = 0b00_0100_0000; - public static final int CURRENT_VERSION = VersionFlyweight.encode(1, 0); + public static final int CURRENT_VERSION = VersionCodec.encode(1, 0); - private static final int VERSION_FIELD_OFFSET = FrameHeaderFlyweight.size(); + private static final int VERSION_FIELD_OFFSET = FrameHeaderCodec.size(); private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + Integer.BYTES; private static final int KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET = KEEPALIVE_INTERVAL_FIELD_OFFSET + Integer.BYTES; @@ -55,12 +56,13 @@ public static ByteBuf encode( final String dataMimeType, final Payload setupPayload) { - ByteBuf metadata = setupPayload.hasMetadata() ? setupPayload.sliceMetadata() : null; - ByteBuf data = setupPayload.sliceData(); + final ByteBuf data = setupPayload.sliceData(); + final boolean hasMetadata = setupPayload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? setupPayload.sliceMetadata() : null; int flags = 0; - if (resumeToken != null && resumeToken.readableBytes() > 0) { + if (resumeToken.readableBytes() > 0) { flags |= FLAGS_RESUME_ENABLE; } @@ -68,11 +70,11 @@ public static ByteBuf encode( flags |= FLAGS_WILL_HONOR_LEASE; } - if (metadata != null) { - flags |= FrameHeaderFlyweight.FLAGS_M; + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; } - ByteBuf header = FrameHeaderFlyweight.encodeStreamZero(allocator, FrameType.SETUP, flags); + final ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.SETUP, flags); header.writeInt(CURRENT_VERSION).writeInt(keepaliveInterval).writeInt(maxLifetime); @@ -91,17 +93,12 @@ public static ByteBuf encode( length = ByteBufUtil.utf8Bytes(dataMimeType); header.writeByte(length); ByteBufUtil.writeUtf8(header, dataMimeType); - if (data == null && metadata == null) { - return header; - } else if (metadata != null) { - return DataAndMetadataFlyweight.encode(allocator, header, metadata, data); - } else { - return DataAndMetadataFlyweight.encodeOnlyData(allocator, header, data); - } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); } public static int version(ByteBuf byteBuf) { - FrameHeaderFlyweight.ensureFrameType(FrameType.SETUP, byteBuf); + FrameHeaderCodec.ensureFrameType(FrameType.SETUP, byteBuf); byteBuf.markReaderIndex(); int version = byteBuf.skipBytes(VERSION_FIELD_OFFSET).readInt(); byteBuf.resetReaderIndex(); @@ -110,7 +107,7 @@ public static int version(ByteBuf byteBuf) { public static String humanReadableVersion(ByteBuf byteBuf) { int encodedVersion = version(byteBuf); - return VersionFlyweight.major(encodedVersion) + "." + VersionFlyweight.minor(encodedVersion); + return VersionCodec.major(encodedVersion) + "." + VersionCodec.minor(encodedVersion); } public static boolean isSupportedVersion(ByteBuf byteBuf) { @@ -139,11 +136,11 @@ public static int keepAliveMaxLifetime(ByteBuf byteBuf) { } public static boolean honorLease(ByteBuf byteBuf) { - return (FLAGS_WILL_HONOR_LEASE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; + return (FLAGS_WILL_HONOR_LEASE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; } public static boolean resumeEnabled(ByteBuf byteBuf) { - return (FLAGS_RESUME_ENABLE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_RESUME_ENABLE; + return (FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE; } public static ByteBuf resumeToken(ByteBuf byteBuf) { @@ -151,7 +148,7 @@ public static ByteBuf resumeToken(ByteBuf byteBuf) { byteBuf.markReaderIndex(); // header int resumePos = - FrameHeaderFlyweight.size() + FrameHeaderCodec.size() + // version Integer.BYTES @@ -167,7 +164,7 @@ public static ByteBuf resumeToken(ByteBuf byteBuf) { byteBuf.resetReaderIndex(); return resumeToken; } else { - return null; + return Unpooled.EMPTY_BUFFER; } } @@ -190,27 +187,31 @@ public static String dataMimeType(ByteBuf byteBuf) { return mimeType; } + @Nullable public static ByteBuf metadata(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } byteBuf.markReaderIndex(); skipToPayload(byteBuf); - ByteBuf metadata = DataAndMetadataFlyweight.metadataWithoutMarking(byteBuf, hasMetadata); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); byteBuf.resetReaderIndex(); return metadata; } public static ByteBuf data(ByteBuf byteBuf) { - boolean hasMetadata = FrameHeaderFlyweight.hasMetadata(byteBuf); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); byteBuf.markReaderIndex(); skipToPayload(byteBuf); - ByteBuf data = DataAndMetadataFlyweight.dataWithoutMarking(byteBuf, hasMetadata); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); byteBuf.resetReaderIndex(); return data; } private static int bytesToSkipToMimeType(ByteBuf byteBuf) { int bytesToSkip = VARIABLE_DATA_OFFSET; - if ((FLAGS_RESUME_ENABLE & FrameHeaderFlyweight.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { + if ((FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { bytesToSkip += resumeTokenLength(byteBuf) + Short.BYTES; } return bytesToSkip; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java similarity index 96% rename from rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java index e238b3fe2..35e4aa86a 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java @@ -16,7 +16,7 @@ package io.rsocket.frame; -public class VersionFlyweight { +public class VersionCodec { public static int encode(int major, int minor) { return (major << 16) | (minor & 0xFFFF); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java index 74186f1d1..e6874c097 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -3,56 +3,67 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; -import io.rsocket.util.ByteBufPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.DefaultPayload; import java.nio.ByteBuffer; /** Default Frame decoder that copies the frames contents for easy of use. */ class DefaultPayloadDecoder implements PayloadDecoder { @Override - public synchronized Payload apply(ByteBuf byteBuf) { + public Payload apply(ByteBuf byteBuf) { ByteBuf m; ByteBuf d; - FrameType type = FrameHeaderFlyweight.frameType(byteBuf); + FrameType type = FrameHeaderCodec.frameType(byteBuf); switch (type) { case REQUEST_FNF: - d = RequestFireAndForgetFrameFlyweight.data(byteBuf); - m = RequestFireAndForgetFrameFlyweight.metadata(byteBuf); + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); break; case REQUEST_RESPONSE: - d = RequestResponseFrameFlyweight.data(byteBuf); - m = RequestResponseFrameFlyweight.metadata(byteBuf); + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); break; case REQUEST_STREAM: - d = RequestStreamFrameFlyweight.data(byteBuf); - m = RequestStreamFrameFlyweight.metadata(byteBuf); + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); break; case REQUEST_CHANNEL: - d = RequestChannelFrameFlyweight.data(byteBuf); - m = RequestChannelFrameFlyweight.metadata(byteBuf); + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); break; case NEXT: case NEXT_COMPLETE: - d = PayloadFrameFlyweight.data(byteBuf); - m = PayloadFrameFlyweight.metadata(byteBuf); + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); break; case METADATA_PUSH: d = Unpooled.EMPTY_BUFFER; - m = MetadataPushFrameFlyweight.metadata(byteBuf); + m = MetadataPushFrameCodec.metadata(byteBuf); break; default: throw new IllegalArgumentException("unsupported frame type: " + type); } - ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); ByteBuffer data = ByteBuffer.allocateDirect(d.readableBytes()); - data.put(d.nioBuffer()); data.flip(); - metadata.put(m.nioBuffer()); - metadata.flip(); - return ByteBufPayload.create(data, metadata); + if (m != null) { + ByteBuffer metadata = ByteBuffer.allocateDirect(m.readableBytes()); + metadata.put(m.nioBuffer()); + metadata.flip(); + + return DefaultPayload.create(data, metadata); + } + + return DefaultPayload.create(data); } } diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java index 0b63590e8..3a0dc7bb5 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java @@ -3,7 +3,14 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import io.rsocket.frame.*; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; import io.rsocket.util.ByteBufPayload; /** @@ -15,37 +22,37 @@ public class ZeroCopyPayloadDecoder implements PayloadDecoder { public Payload apply(ByteBuf byteBuf) { ByteBuf m; ByteBuf d; - FrameType type = FrameHeaderFlyweight.frameType(byteBuf); + FrameType type = FrameHeaderCodec.frameType(byteBuf); switch (type) { case REQUEST_FNF: - d = RequestFireAndForgetFrameFlyweight.data(byteBuf); - m = RequestFireAndForgetFrameFlyweight.metadata(byteBuf); + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); break; case REQUEST_RESPONSE: - d = RequestResponseFrameFlyweight.data(byteBuf); - m = RequestResponseFrameFlyweight.metadata(byteBuf); + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); break; case REQUEST_STREAM: - d = RequestStreamFrameFlyweight.data(byteBuf); - m = RequestStreamFrameFlyweight.metadata(byteBuf); + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); break; case REQUEST_CHANNEL: - d = RequestChannelFrameFlyweight.data(byteBuf); - m = RequestChannelFrameFlyweight.metadata(byteBuf); + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); break; case NEXT: case NEXT_COMPLETE: - d = PayloadFrameFlyweight.data(byteBuf); - m = PayloadFrameFlyweight.metadata(byteBuf); + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); break; case METADATA_PUSH: d = Unpooled.EMPTY_BUFFER; - m = MetadataPushFrameFlyweight.metadata(byteBuf); + m = MetadataPushFrameCodec.metadata(byteBuf); break; default: throw new IllegalArgumentException("unsupported frame type: " + type); } - return ByteBufPayload.create(d.retain(), m.retain()); + return ByteBufPayload.create(d.retain(), m != null ? m.retain() : null); } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/OnceConsumer.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java similarity index 60% rename from rsocket-core/src/main/java/io/rsocket/util/OnceConsumer.java rename to rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java index af4c038cc..82e8acaf3 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/OnceConsumer.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,11 @@ * limitations under the License. */ -package io.rsocket.util; - -import java.util.function.Consumer; - -public abstract class OnceConsumer implements Consumer { - private boolean isFirst = true; - - @Override - public final void accept(T t) { - if (isFirst) { - isFirst = false; - acceptOnce(t); - } - } +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame.decoder; - public abstract void acceptOnce(T t); -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java new file mode 100644 index 000000000..69f6d6860 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BitUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/BitUtil.java deleted file mode 100644 index 79be9ccd5..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/BitUtil.java +++ /dev/null @@ -1,287 +0,0 @@ -/* - * Copyright 2014-2019 Real Logic Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import static java.nio.charset.StandardCharsets.UTF_8; - -import java.util.concurrent.ThreadLocalRandom; - -/** Miscellaneous useful functions for dealing with low level bits and bytes. */ -public class BitUtil { - /** Size of a byte in bytes */ - public static final int SIZE_OF_BYTE = 1; - - /** Size of a boolean in bytes */ - public static final int SIZE_OF_BOOLEAN = 1; - - /** Size of a char in bytes */ - public static final int SIZE_OF_CHAR = 2; - - /** Size of a short in bytes */ - public static final int SIZE_OF_SHORT = 2; - - /** Size of an int in bytes */ - public static final int SIZE_OF_INT = 4; - - /** Size of a float in bytes */ - public static final int SIZE_OF_FLOAT = 4; - - /** Size of a long in bytes */ - public static final int SIZE_OF_LONG = 8; - - /** Size of a double in bytes */ - public static final int SIZE_OF_DOUBLE = 8; - - /** Length of the data blocks used by the CPU cache sub-system in bytes. */ - public static final int CACHE_LINE_LENGTH = 64; - - private static final byte[] HEX_DIGIT_TABLE = { - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f' - }; - - private static final byte[] FROM_HEX_DIGIT_TABLE; - - static { - FROM_HEX_DIGIT_TABLE = new byte[128]; - - FROM_HEX_DIGIT_TABLE['0'] = 0x00; - FROM_HEX_DIGIT_TABLE['1'] = 0x01; - FROM_HEX_DIGIT_TABLE['2'] = 0x02; - FROM_HEX_DIGIT_TABLE['3'] = 0x03; - FROM_HEX_DIGIT_TABLE['4'] = 0x04; - FROM_HEX_DIGIT_TABLE['5'] = 0x05; - FROM_HEX_DIGIT_TABLE['6'] = 0x06; - FROM_HEX_DIGIT_TABLE['7'] = 0x07; - FROM_HEX_DIGIT_TABLE['8'] = 0x08; - FROM_HEX_DIGIT_TABLE['9'] = 0x09; - FROM_HEX_DIGIT_TABLE['a'] = 0x0a; - FROM_HEX_DIGIT_TABLE['A'] = 0x0a; - FROM_HEX_DIGIT_TABLE['b'] = 0x0b; - FROM_HEX_DIGIT_TABLE['B'] = 0x0b; - FROM_HEX_DIGIT_TABLE['c'] = 0x0c; - FROM_HEX_DIGIT_TABLE['C'] = 0x0c; - FROM_HEX_DIGIT_TABLE['d'] = 0x0d; - FROM_HEX_DIGIT_TABLE['D'] = 0x0d; - FROM_HEX_DIGIT_TABLE['e'] = 0x0e; - FROM_HEX_DIGIT_TABLE['E'] = 0x0e; - FROM_HEX_DIGIT_TABLE['f'] = 0x0f; - FROM_HEX_DIGIT_TABLE['F'] = 0x0f; - } - - private static final int LAST_DIGIT_MASK = 0b1; - - /** - * Fast method of finding the next power of 2 greater than or equal to the supplied value. - * - *

If the value is <= 0 then 1 will be returned. - * - *

This method is not suitable for {@link Integer#MIN_VALUE} or numbers greater than 2^30. - * - * @param value from which to search for next power of 2 - * @return The next power of 2 or the value itself if it is a power of 2 - */ - public static int findNextPositivePowerOfTwo(final int value) { - return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); - } - - /** - * Align a value to the next multiple up of alignment. If the value equals an alignment multiple - * then it is returned unchanged. - * - *

This method executes without branching. This code is designed to be use in the fast path and - * should not be used with negative numbers. Negative numbers will result in undefined behaviour. - * - * @param value to be aligned up. - * @param alignment to be used. - * @return the value aligned to the next boundary. - */ - public static int align(final int value, final int alignment) { - return (value + (alignment - 1)) & -alignment; - } - - /** - * Generate a byte array from the hex representation of the given byte array. - * - * @param buffer to convert from a hex representation (in Big Endian). - * @return new byte array that is decimal representation of the passed array. - */ - public static byte[] fromHexByteArray(final byte[] buffer) { - final byte[] outputBuffer = new byte[buffer.length >> 1]; - - for (int i = 0; i < buffer.length; i += 2) { - final int hi = FROM_HEX_DIGIT_TABLE[buffer[i]] << 4; - final int lo = FROM_HEX_DIGIT_TABLE[buffer[i + 1]]; // lgtm [java/index-out-of-bounds] - outputBuffer[i >> 1] = (byte) (hi | lo); - } - - return outputBuffer; - } - - /** - * Generate a byte array that is a hex representation of a given byte array. - * - * @param buffer to convert to a hex representation. - * @return new byte array that is hex representation (in Big Endian) of the passed array. - */ - public static byte[] toHexByteArray(final byte[] buffer) { - return toHexByteArray(buffer, 0, buffer.length); - } - - /** - * Generate a byte array that is a hex representation of a given byte array. - * - * @param buffer to convert to a hex representation. - * @param offset the offset into the buffer. - * @param length the number of bytes to convert. - * @return new byte array that is hex representation (in Big Endian) of the passed array. - */ - public static byte[] toHexByteArray(final byte[] buffer, final int offset, final int length) { - final byte[] outputBuffer = new byte[length << 1]; - - for (int i = 0; i < (length << 1); i += 2) { - final byte b = buffer[offset + (i >> 1)]; - - outputBuffer[i] = HEX_DIGIT_TABLE[(b >> 4) & 0x0F]; - outputBuffer[i + 1] = HEX_DIGIT_TABLE[b & 0x0F]; - } - - return outputBuffer; - } - - /** - * Generate a byte array from a string that is the hex representation of the given byte array. - * - * @param string to convert from a hex representation (in Big Endian). - * @return new byte array holding the decimal representation of the passed array. - */ - public static byte[] fromHex(final String string) { - return fromHexByteArray(string.getBytes(UTF_8)); - } - - /** - * Generate a string that is the hex representation of a given byte array. - * - * @param buffer to convert to a hex representation. - * @param offset the offset into the buffer. - * @param length the number of bytes to convert. - * @return new String holding the hex representation (in Big Endian) of the passed array. - */ - public static String toHex(final byte[] buffer, final int offset, final int length) { - return new String(toHexByteArray(buffer, offset, length), UTF_8); - } - - /** - * Generate a string that is the hex representation of a given byte array. - * - * @param buffer to convert to a hex representation. - * @return new String holding the hex representation (in Big Endian) of the passed array. - */ - public static String toHex(final byte[] buffer) { - return new String(toHexByteArray(buffer), UTF_8); - } - - /** - * Is a number even. - * - * @param value to check. - * @return true if the number is even otherwise false. - */ - public static boolean isEven(final int value) { - return (value & LAST_DIGIT_MASK) == 0; - } - - /** - * Is a value a positive power of 2. - * - * @param value to be checked. - * @return true if the number is a positive power of 2, otherwise false. - */ - public static boolean isPowerOfTwo(final int value) { - return value > 0 && ((value & (~value + 1)) == value); - } - - /** - * Cycles indices of an array one at a time in a forward fashion - * - * @param current value to be incremented. - * @param max value for the cycle. - * @return the next value, or zero if max is reached. - */ - public static int next(final int current, final int max) { - int next = current + 1; - if (next == max) { - next = 0; - } - - return next; - } - - /** - * Cycles indices of an array one at a time in a backwards fashion - * - * @param current value to be decremented. - * @param max value of the cycle. - * @return the next value, or max - 1 if current is zero. - */ - public static int previous(final int current, final int max) { - if (0 == current) { - return max - 1; - } - - return current - 1; - } - - /** - * Calculate the shift value to scale a number based on how refs are compressed or not. - * - * @param scale of the number reported by Unsafe. - * @return how many times the number needs to be shifted to the left. - */ - public static int calculateShiftForScale(final int scale) { - if (4 == scale) { - return 2; - } else if (8 == scale) { - return 3; - } - - throw new IllegalArgumentException("unknown pointer size for scale=" + scale); - } - - /** - * Generate a randomised integer over [{@link Integer#MIN_VALUE}, {@link Integer#MAX_VALUE}]. - * - * @return randomised integer suitable as an Id. - */ - public static int generateRandomisedId() { - return ThreadLocalRandom.current().nextInt(); - } - - /** - * Is an address aligned on a boundary. - * - * @param address to be tested. - * @param alignment boundary the address is tested against. - * @return true if the address is on the aligned boundary otherwise false. - * @throws IllegalArgumentException if the alignment is not a power of 2. - */ - public static boolean isAligned(final long address, final int alignment) { - if (!BitUtil.isPowerOfTwo(alignment)) { - throw new IllegalArgumentException("alignment must be a power of 2: alignment=" + alignment); - } - - return (address & (alignment - 1)) == 0; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 1e76b6898..038120efc 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -17,12 +17,13 @@ package io.rsocket.internal; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameUtil; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.plugins.InitializingInterceptorRegistry; import org.reactivestreams.Publisher; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,7 +46,8 @@ */ public class ClientServerInputMultiplexer implements Closeable { private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - private static final PluginRegistry emptyPluginRegistry = new PluginRegistry(); + private static final InitializingInterceptorRegistry emptyInterceptorRegistry = + new InitializingInterceptorRegistry(); private final DuplexConnection setupConnection; private final DuplexConnection serverConnection; @@ -54,33 +56,33 @@ public class ClientServerInputMultiplexer implements Closeable { private final DuplexConnection clientServerConnection; public ClientServerInputMultiplexer(DuplexConnection source) { - this(source, emptyPluginRegistry, false); + this(source, emptyInterceptorRegistry, false); } public ClientServerInputMultiplexer( - DuplexConnection source, PluginRegistry plugins, boolean isClient) { + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { this.source = source; final MonoProcessor> setup = MonoProcessor.create(); final MonoProcessor> server = MonoProcessor.create(); final MonoProcessor> client = MonoProcessor.create(); - source = plugins.applyConnection(Type.SOURCE, source); + source = registry.initConnection(Type.SOURCE, source); setupConnection = - plugins.applyConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); + registry.initConnection(Type.SETUP, new InternalDuplexConnection(source, setup)); serverConnection = - plugins.applyConnection(Type.SERVER, new InternalDuplexConnection(source, server)); + registry.initConnection(Type.SERVER, new InternalDuplexConnection(source, server)); clientConnection = - plugins.applyConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); + registry.initConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); clientServerConnection = new InternalDuplexConnection(source, client, server); source .receive() .groupBy( frame -> { - int streamId = FrameHeaderFlyweight.streamId(frame); + int streamId = FrameHeaderCodec.streamId(frame); final Type type; if (streamId == 0) { - switch (FrameHeaderFlyweight.frameType(frame)) { + switch (FrameHeaderCodec.frameType(frame)) { case SETUP: case RESUME: case RESUME_OK: @@ -117,10 +119,7 @@ public ClientServerInputMultiplexer( break; } }, - t -> { - LOGGER.error("Error receiving frame:", t); - dispose(); - }); + t -> {}); } public DuplexConnection asClientServerConnection() { @@ -159,6 +158,7 @@ private static class InternalDuplexConnection implements DuplexConnection { private final MonoProcessor>[] processors; private final boolean debugEnabled; + @SafeVarargs public InternalDuplexConnection( DuplexConnection source, MonoProcessor>... processors) { this.source = source; @@ -200,6 +200,11 @@ public Flux receive() { })); } + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + @Override public void dispose() { source.dispose(); diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java deleted file mode 100644 index 38217bdc2..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientSetup.java +++ /dev/null @@ -1,112 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import static io.rsocket.keepalive.KeepAliveHandler.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.DuplexConnection; -import io.rsocket.keepalive.KeepAliveHandler; -import io.rsocket.resume.ClientRSocketSession; -import io.rsocket.resume.ResumableDuplexConnection; -import io.rsocket.resume.ResumableFramesStore; -import io.rsocket.resume.ResumeStrategy; -import java.time.Duration; -import java.util.function.Supplier; -import reactor.core.publisher.Mono; - -public interface ClientSetup { - - DuplexConnection connection(); - - KeepAliveHandler keepAliveHandler(); - - ByteBuf resumeToken(); - - class DefaultClientSetup implements ClientSetup { - private final DuplexConnection connection; - - public DefaultClientSetup(DuplexConnection connection) { - this.connection = connection; - } - - @Override - public DuplexConnection connection() { - return connection; - } - - @Override - public KeepAliveHandler keepAliveHandler() { - return new DefaultKeepAliveHandler(connection); - } - - @Override - public ByteBuf resumeToken() { - return Unpooled.EMPTY_BUFFER; - } - } - - class ResumableClientSetup implements ClientSetup { - private final ByteBuf resumeToken; - private final ResumableDuplexConnection duplexConnection; - private final ResumableKeepAliveHandler keepAliveHandler; - - public ResumableClientSetup( - ByteBufAllocator allocator, - DuplexConnection connection, - Mono newConnectionFactory, - ByteBuf resumeToken, - ResumableFramesStore resumableFramesStore, - Duration resumeSessionDuration, - Duration resumeStreamTimeout, - Supplier resumeStrategySupplier, - boolean cleanupStoreOnKeepAlive) { - - ClientRSocketSession rSocketSession = - new ClientRSocketSession( - connection, - allocator, - resumeSessionDuration, - resumeStrategySupplier, - resumableFramesStore, - resumeStreamTimeout, - cleanupStoreOnKeepAlive) - .continueWith(newConnectionFactory) - .resumeToken(resumeToken); - this.duplexConnection = rSocketSession.resumableConnection(); - this.keepAliveHandler = new ResumableKeepAliveHandler(duplexConnection); - this.resumeToken = resumeToken; - } - - @Override - public DuplexConnection connection() { - return duplexConnection; - } - - @Override - public KeepAliveHandler keepAliveHandler() { - return keepAliveHandler; - } - - @Override - public ByteBuf resumeToken() { - return resumeToken; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/CollectionUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/CollectionUtil.java deleted file mode 100644 index 8d4526c36..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/CollectionUtil.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * Copyright 2014-2019 Real Logic Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import java.util.List; -import java.util.Map; -import java.util.function.Function; -import java.util.function.Predicate; -import java.util.function.ToIntFunction; - -/** Utility functions for collection objects. */ -public class CollectionUtil { - /** - * A getOrDefault that doesn't create garbage if its suppler is non-capturing. - * - * @param map to perform the lookup on. - * @param key on which the lookup is done. - * @param supplier of the default value if one is not found. - * @param type of the key - * @param type of the value - * @return the value if found or a new default which as been added to the map. - */ - public static V getOrDefault( - final Map map, final K key, final Function supplier) { - V value = map.get(key); - if (value == null) { - value = supplier.apply(key); - map.put(key, value); - } - - return value; - } - - /** - * Garbage free sum function. - * - *

Note: the list must implement {@link java.util.RandomAccess} to be efficient. - * - * @param values the list of input values - * @param function function that map each value to an int - * @param the value to add up - * @return the sum of all the int values returned for each member of the list. - */ - public static int sum(final List values, final ToIntFunction function) { - int total = 0; - - final int size = values.size(); - for (int i = 0; i < size; i++) { - final V value = values.get(i); - total += function.applyAsInt(value); - } - - return total; - } - - /** - * Validate that a load factor is in the range of 0.1 to 0.9. - * - *

Load factors in the range 0.5 - 0.7 are recommended for open-addressing with linear probing. - * - * @param loadFactor to be validated. - */ - public static void validateLoadFactor(final float loadFactor) { - if (loadFactor < 0.1f || loadFactor > 0.9f) { - throw new IllegalArgumentException( - "load factor must be in the range of 0.1 to 0.9: " + loadFactor); - } - } - - /** - * Validate that a number is a power of two. - * - * @param value to be validated. - */ - public static void validatePositivePowerOfTwo(final int value) { - if (value > 0 && 1 == (value & (value - 1))) { - throw new IllegalStateException("value must be a positive power of two"); - } - } - - /** - * Remove element from a list if it matches a predicate. - * - *

Note: the list must implement {@link java.util.RandomAccess} to be efficient. - * - * @param values to be iterated over. - * @param predicate to test the value against - * @param type of the value. - * @return the number of items remove. - */ - public static int removeIf(final List values, final Predicate predicate) { - int size = values.size(); - int total = 0; - - for (int i = 0; i < size; ) { - final T value = values.get(i); - if (predicate.test(value)) { - values.remove(i); - total++; - size--; - } else { - i++; - } - } - - return total; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/Hashing.java b/rsocket-core/src/main/java/io/rsocket/internal/Hashing.java deleted file mode 100644 index 613dce209..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/Hashing.java +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Copyright 2014-2019 Real Logic Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -/** Hashing functions for applying to integers. */ -public class Hashing { - /** Default load factor to be used in open addressing hashed data structures. */ - public static final float DEFAULT_LOAD_FACTOR = 0.55f; - - /** - * Generate a hash for an int value. This is a no op. - * - * @param value to be hashed. - * @return the hashed value. - */ - public static int hash(final int value) { - return value * 31; - } - - /** - * Generate a hash for an long value. - * - * @param value to be hashed. - * @return the hashed value. - */ - public static int hash(final long value) { - long hash = value * 31; - hash = (int) hash ^ (int) (hash >>> 32); - - return (int) hash; - } - - /** - * Generate a hash for a int value. - * - * @param value to be hashed. - * @param mask mask to be applied that must be a power of 2 - 1. - * @return the hash of the value. - */ - public static int hash(final int value, final int mask) { - final int hash = value * 31; - - return hash & mask; - } - - /** - * Generate a hash for a K value. - * - * @param is the type of value - * @param value to be hashed. - * @param mask mask to be applied that must be a power of 2 - 1. - * @return the hash of the value. - */ - public static int hash(final K value, final int mask) { - final int hash = value.hashCode(); - - return hash & mask; - } - - /** - * Generate a hash for a long value. - * - * @param value to be hashed. - * @param mask mask to be applied that must be a power of 2 - 1. - * @return the hash of the value. - */ - public static int hash(final long value, final int mask) { - long hash = value * 31; - hash = (int) hash ^ (int) (hash >>> 32); - - return (int) hash & mask; - } - - /** - * Generate an even hash for a int value. - * - * @param value to be hashed. - * @param mask mask to be applied that must be a power of 2 - 1. - * @return the hash of the value which is always even. - */ - public static int evenHash(final int value, final int mask) { - final int hash = (value << 1) - (value << 8); - - return hash & mask; - } - - /** - * Generate an even hash for a long value. - * - * @param value to be hashed. - * @param mask mask to be applied that must be a power of 2 - 1. - * @return the hash of the value which is always even. - */ - public static int evenHash(final long value, final int mask) { - int hash = (int) value ^ (int) (value >>> 32); - hash = (hash << 1) - (hash << 8); - - return hash & mask; - } - - /** - * Combined two 32 bit keys into a 64-bit compound. - * - * @param keyPartA to make the upper bits - * @param keyPartB to make the lower bits. - * @return the compound key - */ - public static long compoundKey(final int keyPartA, final int keyPartB) { - return ((long) keyPartA << 32) | (keyPartB & 0xFFFF_FFFFL); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java deleted file mode 100755 index 8adb7542a..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java +++ /dev/null @@ -1,204 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; - -/** */ -public class LimitableRequestPublisher extends Flux implements Subscription { - - private static final int NOT_CANCELED_STATE = 0; - private static final int CANCELED_STATE = 1; - - private final Publisher source; - - private volatile int canceled; - private static final AtomicIntegerFieldUpdater CANCELED = - AtomicIntegerFieldUpdater.newUpdater(LimitableRequestPublisher.class, "canceled"); - - private final long prefetch; - - private long internalRequested; - - private long externalRequested; - - private boolean subscribed; - - private @Nullable Subscription internalSubscription; - - private LimitableRequestPublisher(Publisher source, long prefetch) { - this.source = source; - this.prefetch = prefetch; - } - - public static LimitableRequestPublisher wrap(Publisher source, long prefetch) { - return new LimitableRequestPublisher<>(source, prefetch); - } - - public static LimitableRequestPublisher wrap(Publisher source) { - return wrap(source, Long.MAX_VALUE); - } - - @Override - public void subscribe(CoreSubscriber destination) { - synchronized (this) { - if (subscribed) { - throw new IllegalStateException("only one subscriber at a time"); - } - - subscribed = true; - } - final InnerOperator s = new InnerOperator(destination); - - destination.onSubscribe(s); - source.subscribe(s); - increaseInternalLimit(prefetch); - } - - public void increaseInternalLimit(long n) { - synchronized (this) { - long requested = internalRequested; - if (requested == Long.MAX_VALUE) { - return; - } - internalRequested = Operators.addCap(n, requested); - } - - requestN(); - } - - @Override - public void request(long n) { - synchronized (this) { - long requested = externalRequested; - if (requested == Long.MAX_VALUE) { - return; - } - externalRequested = Operators.addCap(n, requested); - } - - requestN(); - } - - private void requestN() { - long r; - final Subscription s; - - synchronized (this) { - s = internalSubscription; - if (s == null) { - return; - } - - long er = externalRequested; - long ir = internalRequested; - - if (er != Long.MAX_VALUE || ir != Long.MAX_VALUE) { - r = Math.min(ir, er); - if (er != Long.MAX_VALUE) { - externalRequested -= r; - } - if (ir != Long.MAX_VALUE) { - internalRequested -= r; - } - } else { - r = Long.MAX_VALUE; - } - } - - if (r > 0) { - s.request(r); - } - } - - public void cancel() { - if (!isCanceled() && CANCELED.compareAndSet(this, NOT_CANCELED_STATE, CANCELED_STATE)) { - Subscription s; - - synchronized (this) { - s = internalSubscription; - internalSubscription = null; - subscribed = false; - } - - if (s != null) { - s.cancel(); - } - } - } - - private boolean isCanceled() { - return canceled == 1; - } - - private class InnerOperator implements CoreSubscriber, Subscription { - final Subscriber destination; - - private InnerOperator(Subscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - synchronized (LimitableRequestPublisher.this) { - LimitableRequestPublisher.this.internalSubscription = s; - - if (isCanceled()) { - s.cancel(); - subscribed = false; - LimitableRequestPublisher.this.internalSubscription = null; - } - } - - requestN(); - } - - @Override - public void onNext(T t) { - try { - destination.onNext(t); - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - destination.onError(t); - } - - @Override - public void onComplete() { - destination.onComplete(); - } - - @Override - public void request(long n) {} - - @Override - public void cancel() { - LimitableRequestPublisher.this.cancel(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java deleted file mode 100755 index cdb0d0c0c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/RateLimitableRequestPublisher.java +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; - -/** */ -public class RateLimitableRequestPublisher extends Flux implements Subscription { - - private static final int NOT_CANCELED_STATE = 0; - private static final int CANCELED_STATE = 1; - - private final Publisher source; - - private volatile int canceled; - private static final AtomicIntegerFieldUpdater CANCELED = - AtomicIntegerFieldUpdater.newUpdater(RateLimitableRequestPublisher.class, "canceled"); - - private final long prefetch; - private final long limit; - - private long externalRequested; // need sync - private int pendingToFulfil; // need sync since should be checked/zerroed in onNext - // and increased in request - private int deliveredElements; // no need to sync since increased zerroed only in - // the request method - - private boolean subscribed; - - private @Nullable Subscription internalSubscription; - - private RateLimitableRequestPublisher(Publisher source, long prefetch) { - this.source = source; - this.prefetch = prefetch; - this.limit = prefetch == Integer.MAX_VALUE ? Integer.MAX_VALUE : (prefetch - (prefetch >> 2)); - } - - public static RateLimitableRequestPublisher wrap(Publisher source, long prefetch) { - return new RateLimitableRequestPublisher<>(source, prefetch); - } - - @Override - public void subscribe(CoreSubscriber destination) { - synchronized (this) { - if (subscribed) { - throw new IllegalStateException("only one subscriber at a time"); - } - - subscribed = true; - } - final InnerOperator s = new InnerOperator(destination); - - source.subscribe(s); - destination.onSubscribe(s); - } - - @Override - public void request(long n) { - synchronized (this) { - long requested = externalRequested; - if (requested == Long.MAX_VALUE) { - return; - } - externalRequested = Operators.addCap(n, requested); - } - - requestN(); - } - - private void requestN() { - final long r; - final Subscription s; - - synchronized (this) { - s = internalSubscription; - if (s == null) { - return; - } - - final long er = externalRequested; - final long p = prefetch; - final int pendingFulfil = pendingToFulfil; - - if (er != Long.MAX_VALUE || p != Integer.MAX_VALUE) { - // shortcut - if (pendingFulfil == p) { - return; - } - - r = Math.min(p - pendingFulfil, er); - if (er != Long.MAX_VALUE) { - externalRequested -= r; - } - if (p != Integer.MAX_VALUE) { - pendingToFulfil += r; - } - } else { - r = Long.MAX_VALUE; - } - } - - if (r > 0) { - s.request(r); - } - } - - public void cancel() { - if (!isCanceled() && CANCELED.compareAndSet(this, NOT_CANCELED_STATE, CANCELED_STATE)) { - Subscription s; - - synchronized (this) { - s = internalSubscription; - internalSubscription = null; - subscribed = false; - } - - if (s != null) { - s.cancel(); - } - } - } - - private boolean isCanceled() { - return canceled == CANCELED_STATE; - } - - private class InnerOperator implements CoreSubscriber, Subscription { - final Subscriber destination; - - private InnerOperator(Subscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - synchronized (RateLimitableRequestPublisher.this) { - RateLimitableRequestPublisher.this.internalSubscription = s; - - if (isCanceled()) { - s.cancel(); - subscribed = false; - RateLimitableRequestPublisher.this.internalSubscription = null; - } - } - - requestN(); - } - - @Override - public void onNext(T t) { - try { - destination.onNext(t); - - if (prefetch == Integer.MAX_VALUE) { - return; - } - - final long l = limit; - int d = deliveredElements + 1; - - if (d == l) { - d = 0; - final long r; - final Subscription s; - - synchronized (RateLimitableRequestPublisher.this) { - long er = externalRequested; - s = internalSubscription; - - if (s == null) { - return; - } - - if (er >= l) { - er -= l; - // keep pendingToFulfil as is since it is eq to prefetch - r = l; - } else { - pendingToFulfil -= l; - if (er > 0) { - r = er; - er = 0; - pendingToFulfil += r; - } else { - r = 0; - } - } - - externalRequested = er; - } - - if (r > 0) { - s.request(r); - } - } - - deliveredElements = d; - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - destination.onError(t); - } - - @Override - public void onComplete() { - destination.onComplete(); - } - - @Override - public void request(long n) {} - - @Override - public void cancel() { - RateLimitableRequestPublisher.this.cancel(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java b/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java deleted file mode 100644 index 0d2e5988e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/SwitchTransformFlux.java +++ /dev/null @@ -1,581 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.BiFunction; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Fuseable; -import reactor.core.Scannable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; -import reactor.util.annotation.Nullable; -import reactor.util.context.Context; - -/** @deprecated in favour of {@link Flux#switchOnFirst(BiFunction)} */ -@Deprecated -public final class SwitchTransformFlux extends Flux { - - final Publisher source; - final BiFunction, Publisher> transformer; - - public SwitchTransformFlux( - Publisher source, BiFunction, Publisher> transformer) { - this.source = Objects.requireNonNull(source, "source"); - this.transformer = Objects.requireNonNull(transformer, "transformer"); - } - - @Override - public int getPrefetch() { - return 1; - } - - @Override - @SuppressWarnings("unchecked") - public void subscribe(CoreSubscriber actual) { - if (actual instanceof Fuseable.ConditionalSubscriber) { - source.subscribe( - new SwitchTransformConditionalOperator<>( - (Fuseable.ConditionalSubscriber) actual, transformer)); - return; - } - source.subscribe(new SwitchTransformOperator<>(actual, transformer)); - } - - static final class SwitchTransformOperator extends Flux - implements CoreSubscriber, Subscription, Scannable { - - final CoreSubscriber outer; - final BiFunction, Publisher> transformer; - - Subscription s; - Throwable throwable; - - volatile boolean done; - volatile T first; - - volatile CoreSubscriber inner; - - @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater INNER = - AtomicReferenceFieldUpdater.newUpdater( - SwitchTransformOperator.class, CoreSubscriber.class, "inner"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "wip"); - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformOperator.class, "once"); - - SwitchTransformOperator( - CoreSubscriber outer, - BiFunction, Publisher> transformer) { - this.outer = outer; - this.transformer = transformer; - } - - @Override - @Nullable - public Object scanUnsafe(Attr key) { - if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); - if (key == Attr.PREFETCH) return 1; - - return null; - } - - @Override - public Context currentContext() { - CoreSubscriber actual = inner; - - if (actual != null) { - return actual.currentContext(); - } - - return outer.currentContext(); - } - - @Override - public void cancel() { - if (s != Operators.cancelledSubscription()) { - Subscription s = this.s; - this.s = Operators.cancelledSubscription(); - - if (WIP.getAndIncrement(this) == 0) { - INNER.lazySet(this, null); - - T f = first; - if (f != null) { - first = null; - Operators.onDiscard(f, currentContext()); - } - } - - s.cancel(); - } - } - - @Override - public void subscribe(CoreSubscriber actual) { - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - INNER.lazySet(this, actual); - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); - } - } - - @Override - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - s.request(1); - } - } - - @Override - public void onNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return; - } - - CoreSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return; - } - } - - i.onNext(t); - } - - @Override - public void onError(Throwable t) { - if (done) { - Operators.onErrorDropped(t, currentContext()); - return; - } - - throwable = t; - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.error(outer, t); - } - } - - @Override - public void onComplete() { - if (done) { - return; - } - - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.complete(outer); - } - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - if (--n > 0) { - s.request(n); - } - } else { - s.request(n); - } - } - } - - boolean drainRegular() { - if (WIP.getAndIncrement(this) != 0) { - return false; - } - - T f = first; - int m = 1; - boolean sent = false; - Subscription s = this.s; - CoreSubscriber a = inner; - - for (; ; ) { - if (f != null) { - first = null; - - if (s == Operators.cancelledSubscription()) { - Operators.onDiscard(f, a.currentContext()); - return true; - } - - a.onNext(f); - f = null; - sent = true; - } - - if (s == Operators.cancelledSubscription()) { - return sent; - } - - if (done) { - Throwable t = throwable; - if (t != null) { - a.onError(t); - } else { - a.onComplete(); - } - return sent; - } - - m = WIP.addAndGet(this, -m); - - if (m == 0) { - return sent; - } - } - } - } - - static final class SwitchTransformConditionalOperator extends Flux - implements Fuseable.ConditionalSubscriber, Subscription, Scannable { - - final Fuseable.ConditionalSubscriber outer; - final BiFunction, Publisher> transformer; - - Subscription s; - Throwable throwable; - - volatile boolean done; - volatile T first; - - volatile Fuseable.ConditionalSubscriber inner; - - @SuppressWarnings("rawtypes") - static final AtomicReferenceFieldUpdater< - SwitchTransformConditionalOperator, Fuseable.ConditionalSubscriber> - INNER = - AtomicReferenceFieldUpdater.newUpdater( - SwitchTransformConditionalOperator.class, - Fuseable.ConditionalSubscriber.class, - "inner"); - - volatile int wip; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater WIP = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "wip"); - - volatile int once; - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(SwitchTransformConditionalOperator.class, "once"); - - SwitchTransformConditionalOperator( - Fuseable.ConditionalSubscriber outer, - BiFunction, Publisher> transformer) { - this.outer = outer; - this.transformer = transformer; - } - - @Override - @Nullable - public Object scanUnsafe(Attr key) { - if (key == Attr.CANCELLED) return s == Operators.cancelledSubscription(); - if (key == Attr.PREFETCH) return 1; - - return null; - } - - @Override - public Context currentContext() { - CoreSubscriber actual = inner; - - if (actual != null) { - return actual.currentContext(); - } - - return outer.currentContext(); - } - - @Override - public void cancel() { - if (s != Operators.cancelledSubscription()) { - Subscription s = this.s; - this.s = Operators.cancelledSubscription(); - - if (WIP.getAndIncrement(this) == 0) { - INNER.lazySet(this, null); - - T f = first; - if (f != null) { - first = null; - Operators.onDiscard(f, currentContext()); - } - } - - s.cancel(); - } - } - - @Override - @SuppressWarnings("unchecked") - public void subscribe(CoreSubscriber actual) { - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - if (actual instanceof Fuseable.ConditionalSubscriber) { - INNER.lazySet(this, (Fuseable.ConditionalSubscriber) actual); - } else { - INNER.lazySet(this, new ConditionalSubscriberAdapter<>(actual)); - } - actual.onSubscribe(this); - } else { - Operators.error( - actual, new IllegalStateException("SwitchTransform allows only one Subscriber")); - } - } - - @Override - public void onSubscribe(Subscription s) { - if (Operators.validate(this.s, s)) { - this.s = s; - s.request(1); - } - } - - @Override - public void onNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return; - } - - CoreSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return; - } - } - - i.onNext(t); - } - - @Override - public boolean tryOnNext(T t) { - if (done) { - Operators.onNextDropped(t, currentContext()); - return false; - } - - Fuseable.ConditionalSubscriber i = inner; - - if (i == null) { - try { - first = t; - Publisher result = - Objects.requireNonNull( - transformer.apply(t, this), "The transformer returned a null value"); - result.subscribe(outer); - return true; - } catch (Throwable e) { - onError(Operators.onOperatorError(s, e, t, currentContext())); - return false; - } - } - - return i.tryOnNext(t); - } - - @Override - public void onError(Throwable t) { - if (done) { - Operators.onErrorDropped(t, currentContext()); - return; - } - - throwable = t; - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.error(outer, t); - } - } - - @Override - public void onComplete() { - if (done) { - return; - } - - done = true; - CoreSubscriber i = inner; - - if (i != null) { - if (first == null) { - drainRegular(); - } - } else { - Operators.complete(outer); - } - } - - @Override - public void request(long n) { - if (Operators.validate(n)) { - if (first != null && drainRegular() && n != Long.MAX_VALUE) { - if (--n > 0) { - s.request(n); - } - } else { - s.request(n); - } - } - } - - boolean drainRegular() { - if (WIP.getAndIncrement(this) != 0) { - return false; - } - - T f = first; - int m = 1; - boolean sent = false; - Subscription s = this.s; - CoreSubscriber a = inner; - - for (; ; ) { - if (f != null) { - first = null; - - if (s == Operators.cancelledSubscription()) { - Operators.onDiscard(f, a.currentContext()); - return true; - } - - a.onNext(f); - f = null; - sent = true; - } - - if (s == Operators.cancelledSubscription()) { - return sent; - } - - if (done) { - Throwable t = throwable; - if (t != null) { - a.onError(t); - } else { - a.onComplete(); - } - return sent; - } - - m = WIP.addAndGet(this, -m); - - if (m == 0) { - return sent; - } - } - } - } - - static final class ConditionalSubscriberAdapter implements Fuseable.ConditionalSubscriber { - - final CoreSubscriber delegate; - - ConditionalSubscriberAdapter(CoreSubscriber delegate) { - this.delegate = delegate; - } - - @Override - public Context currentContext() { - return delegate.currentContext(); - } - - @Override - public void onSubscribe(Subscription s) { - delegate.onSubscribe(s); - } - - @Override - public void onNext(T t) { - delegate.onNext(t); - } - - @Override - public void onError(Throwable t) { - delegate.onError(t); - } - - @Override - public void onComplete() { - delegate.onComplete(); - } - - @Override - public boolean tryOnNext(T t) { - delegate.onNext(t); - return true; - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java index dfcc13a64..cb8b5d63d 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -43,40 +43,58 @@ public final class UnboundedProcessor extends FluxProcessor implements Fuseable.QueueSubscription, Fuseable { + final Queue queue; + final Queue priorityQueue; + + volatile boolean done; + Throwable error; + // important to not loose the downstream too early and miss discard hook, while + // having relevant hasDownstreams() + boolean hasDownstream; + volatile CoreSubscriber actual; + + volatile boolean cancelled; + + volatile int once; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater ONCE = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "once"); + volatile int wip; + @SuppressWarnings("rawtypes") static final AtomicIntegerFieldUpdater WIP = AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "wip"); + volatile int discardGuard; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater DISCARD_GUARD = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); + + volatile long requested; + @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater REQUESTED = AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); - final Queue queue; - volatile boolean done; - Throwable error; - volatile CoreSubscriber actual; - volatile boolean cancelled; - volatile int once; - volatile int wip; - volatile long requested; - volatile boolean outputFused; + boolean outputFused; public UnboundedProcessor() { this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); } @Override public int getBufferSize() { - return Queues.capacity(this.queue); + return Integer.MAX_VALUE; } @Override public Object scanUnsafe(Attr key) { if (Attr.BUFFERED == key) return queue.size(); + if (Attr.PREFETCH == key) return Integer.MAX_VALUE; return super.scanUnsafe(key); } @@ -84,6 +102,7 @@ void drainRegular(Subscriber a) { int missed = 1; final Queue q = queue; + final Queue pq = priorityQueue; for (; ; ) { @@ -93,10 +112,18 @@ void drainRegular(Subscriber a) { while (r != e) { boolean d = done; - T t = q.poll(); - boolean empty = t == null; + T t; + boolean empty; + + if (!pq.isEmpty()) { + t = pq.poll(); + empty = false; + } else { + t = q.poll(); + empty = t == null; + } - if (checkTerminated(d, empty, a, q)) { + if (checkTerminated(d, empty, a)) { return; } @@ -110,7 +137,7 @@ void drainRegular(Subscriber a) { } if (r == e) { - if (checkTerminated(done, q.isEmpty(), a, q)) { + if (checkTerminated(done, q.isEmpty() && pq.isEmpty(), a)) { return; } } @@ -129,13 +156,11 @@ void drainRegular(Subscriber a) { void drainFused(Subscriber a) { int missed = 1; - final Queue q = queue; - for (; ; ) { if (cancelled) { - q.clear(); - actual = null; + this.clear(); + hasDownstream = false; return; } @@ -144,7 +169,7 @@ void drainFused(Subscriber a) { a.onNext(null); if (d) { - actual = null; + hasDownstream = false; Throwable ex = error; if (ex != null) { @@ -164,6 +189,9 @@ void drainFused(Subscriber a) { public void drain() { if (WIP.getAndIncrement(this) != 0) { + if (cancelled) { + this.clear(); + } return; } @@ -188,20 +216,15 @@ public void drain() { } } - boolean checkTerminated(boolean d, boolean empty, Subscriber a, Queue q) { + boolean checkTerminated(boolean d, boolean empty, Subscriber a) { if (cancelled) { - while (!q.isEmpty()) { - T t = q.poll(); - if (t != null) { - release(t); - } - } - actual = null; + this.clear(); + hasDownstream = false; return true; } if (d && empty) { Throwable e = error; - actual = null; + hasDownstream = false; if (e != null) { a.onError(e); } else { @@ -222,10 +245,6 @@ public void onSubscribe(Subscription s) { } } - public long available() { - return requested; - } - @Override public int getPrefetch() { return Integer.MAX_VALUE; @@ -237,6 +256,23 @@ public Context currentContext() { return actual != null ? actual.currentContext() : Context.empty(); } + public void onNextPrioritized(T t) { + if (done || cancelled) { + Operators.onNextDropped(t, currentContext()); + release(t); + return; + } + + if (!priorityQueue.offer(t)) { + Throwable ex = + Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext()); + onError(Operators.onOperatorError(null, ex, t, currentContext())); + release(t); + return; + } + drain(); + } + @Override public void onNext(T t) { if (done || cancelled) { @@ -287,7 +323,7 @@ public void subscribe(CoreSubscriber actual) { actual.onSubscribe(this); this.actual = actual; if (cancelled) { - this.actual = null; + this.hasDownstream = false; } else { drain(); } @@ -314,38 +350,56 @@ public void cancel() { cancelled = true; if (WIP.getAndIncrement(this) == 0) { - clear(); - actual = null; + this.clear(); + hasDownstream = false; } } - @Override - public T peek() { - return queue.peek(); - } - @Override @Nullable public T poll() { + Queue pq = this.priorityQueue; + if (!pq.isEmpty()) { + return pq.poll(); + } return queue.poll(); } @Override public int size() { - return queue.size(); + return priorityQueue.size() + queue.size(); } @Override public boolean isEmpty() { - return queue.isEmpty(); + return priorityQueue.isEmpty() && queue.isEmpty(); } @Override public void clear() { - while (!queue.isEmpty()) { - T t = queue.poll(); - if (t != null) { - release(t); + if (DISCARD_GUARD.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + + for (; ; ) { + while (!queue.isEmpty()) { + T t = queue.poll(); + if (t != null) { + release(t); + } + } + while (!priorityQueue.isEmpty()) { + T t = priorityQueue.poll(); + if (t != null) { + release(t); + } + } + + missed = DISCARD_GUARD.addAndGet(this, -missed); + if (missed == 0) { + break; } } } @@ -387,14 +441,18 @@ public long downstreamCount() { @Override public boolean hasDownstreams() { - return actual != null; + return hasDownstream; } void release(T t) { if (t instanceof ReferenceCounted) { ReferenceCounted refCounted = (ReferenceCounted) t; if (refCounted.refCnt() > 0) { - refCounted.release(); + try { + refCounted.release(); + } catch (Throwable ex) { + // no ops + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java deleted file mode 100644 index 35d4906ec..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/UnicastMonoProcessor.java +++ /dev/null @@ -1,179 +0,0 @@ -package io.rsocket.internal; - -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; -import java.util.stream.Stream; -import org.reactivestreams.Processor; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.Disposable; -import reactor.core.Scannable; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.Operators; -import reactor.util.annotation.Nullable; -import reactor.util.context.Context; -import reactor.util.function.Tuple2; - -public class UnicastMonoProcessor extends Mono - implements Processor, CoreSubscriber, Disposable, Subscription, Scannable { - - @SuppressWarnings("rawtypes") - static final AtomicIntegerFieldUpdater ONCE = - AtomicIntegerFieldUpdater.newUpdater(UnicastMonoProcessor.class, "once"); - - private final MonoProcessor processor; - - @SuppressWarnings("unused") - private volatile int once; - - private UnicastMonoProcessor() { - this.processor = MonoProcessor.create(); - } - - public static UnicastMonoProcessor create() { - return new UnicastMonoProcessor<>(); - } - - @Override - public Stream actuals() { - return processor.actuals(); - } - - @Override - public boolean isScanAvailable() { - return processor.isScanAvailable(); - } - - @Override - public String name() { - return processor.name(); - } - - @Override - public String stepName() { - return processor.stepName(); - } - - @Override - public Stream steps() { - return processor.steps(); - } - - @Override - public Stream parents() { - return processor.parents(); - } - - @Override - @Nullable - public T scan(Attr key) { - return processor.scan(key); - } - - @Override - public T scanOrDefault(Attr key, T defaultValue) { - return processor.scanOrDefault(key, defaultValue); - } - - @Override - public Stream> tags() { - return processor.tags(); - } - - @Override - public void onSubscribe(Subscription s) { - processor.onSubscribe(s); - } - - @Override - public void onNext(O o) { - processor.onNext(o); - } - - @Override - public void onError(Throwable t) { - processor.onError(t); - } - - @Nullable - public Throwable getError() { - return processor.getError(); - } - - public boolean isCancelled() { - return processor.isCancelled(); - } - - public boolean isError() { - return processor.isError(); - } - - public boolean isSuccess() { - return processor.isSuccess(); - } - - public boolean isTerminated() { - return processor.isTerminated(); - } - - @Nullable - public O peek() { - return processor.peek(); - } - - public long downstreamCount() { - return processor.downstreamCount(); - } - - public boolean hasDownstreams() { - return processor.hasDownstreams(); - } - - @Override - public void onComplete() { - processor.onComplete(); - } - - @Override - public void request(long n) { - processor.request(n); - } - - @Override - public void cancel() { - processor.cancel(); - } - - @Override - public void dispose() { - processor.dispose(); - } - - @Override - public Context currentContext() { - return processor.currentContext(); - } - - @Override - public boolean isDisposed() { - return processor.isDisposed(); - } - - @Override - public Object scanUnsafe(Attr key) { - return processor.scanUnsafe(key); - } - - @Override - public void subscribe(CoreSubscriber actual) { - Objects.requireNonNull(actual, "subscribe"); - if (once == 0 && ONCE.compareAndSet(this, 0, 1)) { - processor.subscribe(actual); - } else { - Operators.error( - actual, - new IllegalStateException("UnicastMonoProcessor allows only a single Subscriber")); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java index 09918f3d1..07ddfab41 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java @@ -18,5 +18,7 @@ * Internal package and must not be used outside this project. There are no guarantees for * API compatibility. */ -@javax.annotation.ParametersAreNonnullByDefault +@NonNullApi package io.rsocket.internal; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java index ea8a0de22..db29d8030 100644 --- a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -19,7 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; -import io.rsocket.frame.KeepAliveFrameFlyweight; +import io.rsocket.frame.KeepAliveFrameCodec; import io.rsocket.resume.ResumeStateHolder; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; @@ -69,14 +69,14 @@ public void receive(ByteBuf keepAliveFrame) { long remoteLastReceivedPos = remoteLastReceivedPosition(keepAliveFrame); resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); } - if (KeepAliveFrameFlyweight.respondFlag(keepAliveFrame)) { + if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { long localLastReceivedPos = localLastReceivedPosition(); send( - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( allocator, false, localLastReceivedPos, - KeepAliveFrameFlyweight.data(keepAliveFrame).retain())); + KeepAliveFrameCodec.data(keepAliveFrame).retain())); } } @@ -118,7 +118,7 @@ long localLastReceivedPosition() { } long remoteLastReceivedPosition(ByteBuf keepAliveFrame) { - return KeepAliveFrameFlyweight.lastPosition(keepAliveFrame); + return KeepAliveFrameCodec.lastPosition(keepAliveFrame); } public static final class ServerKeepAliveSupport extends KeepAliveSupport { @@ -145,7 +145,7 @@ public ClientKeepAliveSupport( void onIntervalTick() { tryTimeout(); send( - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( allocator, true, localLastReceivedPosition(), Unpooled.EMPTY_BUFFER)); } } diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java new file mode 100644 index 000000000..d94a93cad --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support classes for sending and keeping track of KEEPALIVE frames from the remote. */ +@NonNullApi +package io.rsocket.keepalive; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java index b9d99f88a..673b4a480 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -19,8 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Availability; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; /** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ public interface Lease extends Availability { @@ -78,7 +77,6 @@ default int getRemainingTimeToLiveMillis(long now) { * * @return Metadata for the lease. */ - @Nonnull ByteBuf getMetadata(); /** diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java index 63b0433cb..7abb8aab9 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java @@ -19,8 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import java.util.concurrent.atomic.AtomicInteger; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; public class LeaseImpl implements Lease { private final int timeToLiveMillis; @@ -60,7 +59,6 @@ public int getStartingAllowedRequests() { return startingAllowedRequests; } - @Nonnull @Override public ByteBuf getMetadata() { return metadata; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java similarity index 51% rename from rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java rename to rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java index 4bd6ffb99..3b6cec62c 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/MissingLeaseException.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -1,18 +1,32 @@ -package io.rsocket.exceptions; +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.lease; -import io.rsocket.lease.Lease; +import io.rsocket.exceptions.RejectedException; import java.util.Objects; -import javax.annotation.Nonnull; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; public class MissingLeaseException extends RejectedException { private static final long serialVersionUID = -6169748673403858959L; - public MissingLeaseException(@Nonnull Lease lease, @Nonnull String tag) { + public MissingLeaseException(Lease lease, String tag) { super(leaseMessage(Objects.requireNonNull(lease), Objects.requireNonNull(tag))); } - public MissingLeaseException(@Nonnull String tag) { + public MissingLeaseException(String tag) { super(leaseMessage(null, Objects.requireNonNull(tag))); } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java index ca2111e87..fd569a2c8 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/RequesterLeaseHandler.java @@ -18,8 +18,7 @@ import io.netty.buffer.ByteBuf; import io.rsocket.Availability; -import io.rsocket.exceptions.MissingLeaseException; -import io.rsocket.frame.LeaseFrameFlyweight; +import io.rsocket.frame.LeaseFrameCodec; import java.util.function.Consumer; import reactor.core.Disposable; import reactor.core.publisher.Flux; @@ -64,9 +63,9 @@ public Exception leaseError() { @Override public void receive(ByteBuf leaseFrame) { - int numberOfRequests = LeaseFrameFlyweight.numRequests(leaseFrame); - int timeToLiveMillis = LeaseFrameFlyweight.ttl(leaseFrame); - ByteBuf metadata = LeaseFrameFlyweight.metadata(leaseFrame); + int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); LeaseImpl lease = LeaseImpl.create(timeToLiveMillis, numberOfRequests, metadata); currentLease = lease; receivedLease.onNext(lease); diff --git a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java index c517a55c4..df8787cb7 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/ResponderLeaseHandler.java @@ -19,15 +19,14 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.Availability; -import io.rsocket.exceptions.MissingLeaseException; -import io.rsocket.frame.LeaseFrameFlyweight; +import io.rsocket.frame.LeaseFrameCodec; import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; -import javax.annotation.Nullable; import reactor.core.Disposable; import reactor.core.Disposables; import reactor.core.publisher.Flux; +import reactor.util.annotation.Nullable; public interface ResponderLeaseHandler extends Availability { @@ -42,7 +41,6 @@ final class Impl implements ResponderLeaseHandler { private final String tag; private final ByteBufAllocator allocator; private final Function, Flux> leaseSender; - private final Consumer errorConsumer; private final Optional leaseStatsOption; private final T leaseStats; @@ -50,12 +48,10 @@ public Impl( String tag, ByteBufAllocator allocator, Function, Flux> leaseSender, - Consumer errorConsumer, Optional leaseStatsOption) { this.tag = tag; this.allocator = allocator; this.leaseSender = leaseSender; - this.errorConsumer = errorConsumer; this.leaseStatsOption = leaseStatsOption; this.leaseStats = leaseStatsOption.orElse(null); } @@ -87,8 +83,7 @@ public Disposable send(Consumer leaseFrameSender) { lease -> { currentLease = create(lease); leaseFrameSender.accept(createLeaseFrame(lease)); - }, - errorConsumer); + }); } @Override @@ -97,7 +92,7 @@ public double availability() { } private ByteBuf createLeaseFrame(Lease lease) { - return LeaseFrameFlyweight.encode( + return LeaseFrameCodec.encode( allocator, lease.getTimeToLiveMillis(), lease.getAllowedRequests(), lease.getMetadata()); } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java index 6700c10d9..342ab27f7 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,14 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * Contains support classes for the Lease feature of the RSocket protocol. + * + * @see Resuming + * Operation + */ +@NonNullApi package io.rsocket.lease; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java new file mode 100644 index 000000000..d908abb3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -0,0 +1,335 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.util.CharByteBufUtil; + +public class AuthMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + static final int USERNAME_BYTES_LENGTH = 1; + static final int AUTH_TYPE_ID_LENGTH = 1; + + static final char[] EMPTY_CHARS_ARRAY = new char[0]; + + private AuthMetadataCodec() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + int actualASCIILength = ByteBufUtil.utf8Bytes(customAuthType); + if (actualASCIILength != customAuthType.length()) { + throw new IllegalArgumentException("custom auth type must be US_ASCII characters only"); + } + if (actualASCIILength < 1 || actualASCIILength > 128) { + throw new IllegalArgumentException( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + int capacity = 1 + actualASCIILength; + ByteBuf headerBuffer = allocator.buffer(capacity, capacity); + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + headerBuffer.writeByte(actualASCIILength - 1); + + ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + if (authType == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE + || authType == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE) { + throw new IllegalArgumentException("only allowed AuthType should be used"); + } + + int capacity = AUTH_TYPE_ID_LENGTH; + ByteBuf headerBuffer = + allocator + .buffer(capacity, capacity) + .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 255 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + + int usernameLength = CharByteBufUtil.utf8Bytes(username); + if (usernameLength > 255) { + throw new IllegalArgumentException( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + int passwordLength = CharByteBufUtil.utf8Bytes(password); + int capacity = AUTH_TYPE_ID_LENGTH + USERNAME_BYTES_LENGTH + usernameLength + passwordLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) + .writeByte(usernameLength); + + CharByteBufUtil.writeUtf8(buffer, username); + CharByteBufUtil.writeUtf8(buffer, password); + + return buffer; + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + + int tokenLength = CharByteBufUtil.utf8Bytes(token); + int capacity = AUTH_TYPE_ID_LENGTH + tokenLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.BEARER.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + CharByteBufUtil.writeUtf8(buffer, token); + + return buffer; + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + WellKnownAuthType wkn = WellKnownAuthType.fromString(authType); + if (wkn == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE) { + return AuthMetadataCodec.encodeMetadata(allocator, authType, metadata); + } else { + return AuthMetadataCodec.encodeMetadata(allocator, wkn, metadata); + } + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + byte lengthOrId = metadata.getByte(0); + return (lengthOrId & STREAM_METADATA_LENGTH_MASK) != lengthOrId; + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType readWellKnownAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode Well Know Auth type. Not enough readable bytes"); + } + byte lengthOrId = metadata.readByte(); + int normalizedId = (byte) (lengthOrId & STREAM_METADATA_LENGTH_MASK); + + if (normalizedId != lengthOrId) { + return WellKnownAuthType.fromIdentifier(normalizedId); + } + + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence readCustomAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Not enough readable bytes"); + } + + byte encodedLength = metadata.readByte(); + if (encodedLength < 0) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Incorrect auth type length"); + } + + // encoded length is realLength - 1 in order to avoid intersection with 0x00 authtype + int realLength = encodedLength + 1; + if (metadata.readableBytes() < realLength) { + throw new IllegalArgumentException( + "Unable to decode custom Auth type. Malformed length or auth type string"); + } + + return metadata.readCharSequence(realLength, CharsetUtil.US_ASCII); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf readPayload(ByteBuf metadata) { + if (metadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return metadata.readSlice(metadata.readableBytes()); + } + + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { + short usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); + } + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + short usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, simpleAuthMetadata.readableBytes()); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is + * username length and the subsequent number of bytes equal to decoded length + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + if (bearerAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); + } + + private static short readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode custom username. Not enough readable bytes"); + } + + short usernameLength = simpleAuthMetadata.readUnsignedByte(); + + if (simpleAuthMetadata.readableBytes() < usernameLength) { + throw new IllegalArgumentException( + "Unable to decode username. Malformed username length or content"); + } + + return usernameLength; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java new file mode 100644 index 000000000..5e00abba8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java @@ -0,0 +1,385 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.util.NumberUtils; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + */ +public class CompositeMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private CompositeMetadataCodec() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return currentEntryIndex + + headerSlice.readableBytes() // this includes the mime length byte + + 3 // 3 bytes of the content length, which are excluded from the slice + + contentSlice.readableBytes(); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + compositeMetadata.markReaderIndex(); + compositeMetadata.readerIndex(entryIndex); + + if (compositeMetadata.isReadable()) { + ByteBuf mime; + int ridx = compositeMetadata.readerIndex(); + byte mimeIdOrLength = compositeMetadata.readByte(); + if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + mime = + retainSlices + ? compositeMetadata.retainedSlice(ridx, 1) + : compositeMetadata.slice(ridx, 1); + } else { + // M flag unset, remaining 7 bits are the length of the mime + int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; + + if (compositeMetadata.isReadable( + mimeLength)) { // need to be able to read an extra mimeLength bytes + // here we need a way for the returned ByteBuf to differentiate between a + // 1-byte length mime type and a 1 byte encoded mime id, preferably without + // re-applying the byte mask. The easiest way is to include the initial byte + // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte + // buffer == full mime string. + mime = + retainSlices + ? + // we accommodate that we don't read from current readerIndex, but + // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 + compositeMetadata.retainedSlice(ridx, mimeLength + 1) + : compositeMetadata.slice(ridx, mimeLength + 1); + // we thus need to skip the bytes we just sliced, but not the flag/length byte + // which was already skipped in initial read + compositeMetadata.skipBytes(mimeLength); + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + + if (compositeMetadata.isReadable(3)) { + // ensures the length medium can be read + final int metadataLength = compositeMetadata.readUnsignedMedium(); + if (compositeMetadata.isReadable(metadataLength)) { + ByteBuf metadata = + retainSlices + ? compositeMetadata.readRetainedSlice(metadataLength) + : compositeMetadata.readSlice(metadataLength); + compositeMetadata.resetReaderIndex(); + return new ByteBuf[] {mime, metadata}; + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + compositeMetadata.resetReaderIndex(); + throw new IllegalArgumentException( + String.format("entry index %d is larger than buffer size", entryIndex)); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & + * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + if (mimeBuffer.readableBytes() != 1) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); + } + return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + if (flyweightMimeBuffer.readableBytes() < 2) { + throw new IllegalStateException("unable to decode explicit MIME type"); + } + // the encoded length is assumed to be kept at the start of the buffer + // but also assumed to be irrelevant because the rest of the slice length + // actually already matches _decoded_length + flyweightMimeBuffer.skipBytes(1); + int mimeStringLength = flyweightMimeBuffer.readableBytes(); + return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), + metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); + } else { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), + metadata); + } + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return compositeMetadata.writerIndex() - entryIndex > 0; + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return header.readableBytes() == 1; + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), + metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); + // reserve 1 byte for the customMime length + // /!\ careful not to read that first byte, which is random at this point + int writerIndexInitial = metadataHeader.writerIndex(); + metadataHeader.writerIndex(writerIndexInitial + 1); + + // write the custom mime in UTF8 but validate it is all ASCII-compatible + // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) + int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); + if (!ByteBufUtil.isText( + metadataHeader, metadataHeader.readerIndex() + 1, customMimeLength, CharsetUtil.US_ASCII)) { + metadataHeader.release(); + throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); + } + if (customMimeLength < 1 || customMimeLength > 128) { + metadataHeader.release(); + throw new IllegalArgumentException( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + metadataHeader.markWriterIndex(); + + // go back to beginning and write the length + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + metadataHeader.writerIndex(writerIndexInitial); + metadataHeader.writeByte(customMimeLength - 1); + + // go back to post-mime type and write the metadata content length + metadataHeader.resetWriterIndex(); + NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); + + return metadataHeader; + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); + + NumberUtils.encodeUnsignedMedium(buffer, metadataLength); + + return buffer; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java index 0520285c2..9916dfd3b 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataFlyweight.java @@ -18,31 +18,25 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; import io.netty.buffer.CompositeByteBuf; -import io.netty.util.CharsetUtil; -import io.rsocket.util.NumberUtils; import reactor.util.annotation.Nullable; /** * A flyweight class that can be used to encode/decode composite metadata information to/from {@link * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link * CompositeMetadata} for an Iterator-like approach to decoding entries. + * + * @deprecated in favor of {@link CompositeMetadataCodec} */ +@Deprecated public class CompositeMetadataFlyweight { - static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 - - static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 - private CompositeMetadataFlyweight() {} public static int computeNextEntryIndex( int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { - return currentEntryIndex - + headerSlice.readableBytes() // this includes the mime length byte - + 3 // 3 bytes of the content length, which are excluded from the slice - + contentSlice.readableBytes(); + return CompositeMetadataCodec.computeNextEntryIndex( + currentEntryIndex, headerSlice, contentSlice); } /** @@ -77,67 +71,8 @@ public static int computeNextEntryIndex( */ public static ByteBuf[] decodeMimeAndContentBuffersSlices( ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { - compositeMetadata.markReaderIndex(); - compositeMetadata.readerIndex(entryIndex); - - if (compositeMetadata.isReadable()) { - ByteBuf mime; - int ridx = compositeMetadata.readerIndex(); - byte mimeIdOrLength = compositeMetadata.readByte(); - if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { - mime = - retainSlices - ? compositeMetadata.retainedSlice(ridx, 1) - : compositeMetadata.slice(ridx, 1); - } else { - // M flag unset, remaining 7 bits are the length of the mime - int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; - - if (compositeMetadata.isReadable( - mimeLength)) { // need to be able to read an extra mimeLength bytes - // here we need a way for the returned ByteBuf to differentiate between a - // 1-byte length mime type and a 1 byte encoded mime id, preferably without - // re-applying the byte mask. The easiest way is to include the initial byte - // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte - // buffer == full mime string. - mime = - retainSlices - ? - // we accommodate that we don't read from current readerIndex, but - // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 - compositeMetadata.retainedSlice(ridx, mimeLength + 1) - : compositeMetadata.slice(ridx, mimeLength + 1); - // we thus need to skip the bytes we just sliced, but not the flag/length byte - // which was already skipped in initial read - compositeMetadata.skipBytes(mimeLength); - } else { - compositeMetadata.resetReaderIndex(); - throw new IllegalStateException("metadata is malformed"); - } - } - - if (compositeMetadata.isReadable(3)) { - // ensures the length medium can be read - final int metadataLength = compositeMetadata.readUnsignedMedium(); - if (compositeMetadata.isReadable(metadataLength)) { - ByteBuf metadata = - retainSlices - ? compositeMetadata.readRetainedSlice(metadataLength) - : compositeMetadata.readSlice(metadataLength); - compositeMetadata.resetReaderIndex(); - return new ByteBuf[] {mime, metadata}; - } else { - compositeMetadata.resetReaderIndex(); - throw new IllegalStateException("metadata is malformed"); - } - } else { - compositeMetadata.resetReaderIndex(); - throw new IllegalStateException("metadata is malformed"); - } - } - compositeMetadata.resetReaderIndex(); - throw new IllegalArgumentException( - String.format("entry index %d is larger than buffer size", entryIndex)); + return CompositeMetadataCodec.decodeMimeAndContentBuffersSlices( + compositeMetadata, entryIndex, retainSlices); } /** @@ -145,8 +80,8 @@ public static ByteBuf[] decodeMimeAndContentBuffersSlices( * contains such an id. * *

The buffer must have exactly one readable byte, which is assumed to have been tested for - * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & - * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * mime id encoding via the {@link CompositeMetadataCodec#STREAM_METADATA_KNOWN_MASK} mask ({@code + * firstByte & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). * *

If there is no readable byte, the negative identifier of {@link * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. @@ -156,10 +91,7 @@ public static ByteBuf[] decodeMimeAndContentBuffersSlices( * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) */ public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { - if (mimeBuffer.readableBytes() != 1) { - return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); - } - return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + return CompositeMetadataCodec.decodeMimeIdFromMimeBuffer(mimeBuffer); } /** @@ -182,15 +114,7 @@ public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { */ @Nullable public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { - if (flyweightMimeBuffer.readableBytes() < 2) { - throw new IllegalStateException("unable to decode explicit MIME type"); - } - // the encoded length is assumed to be kept at the start of the buffer - // but also assumed to be irrelevant because the rest of the slice length - // actually already matches _decoded_length - flyweightMimeBuffer.skipBytes(1); - int mimeStringLength = flyweightMimeBuffer.readableBytes(); - return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + return CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(flyweightMimeBuffer); } /** @@ -212,8 +136,8 @@ public static void encodeAndAddMetadata( ByteBufAllocator allocator, String customMimeType, ByteBuf metadata) { - compositeMetaData.addComponents( - true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, customMimeType, metadata); } /** @@ -231,10 +155,8 @@ public static void encodeAndAddMetadata( ByteBufAllocator allocator, WellKnownMimeType knownMimeType, ByteBuf metadata) { - compositeMetaData.addComponents( - true, - encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), - metadata); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, knownMimeType, metadata); } /** @@ -258,16 +180,8 @@ public static void encodeAndAddMetadataWithCompression( ByteBufAllocator allocator, String mimeType, ByteBuf metadata) { - WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); - if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { - compositeMetaData.addComponents( - true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); - } else { - compositeMetaData.addComponents( - true, - encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), - metadata); - } + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + compositeMetaData, allocator, mimeType, metadata); } /** @@ -278,7 +192,7 @@ public static void encodeAndAddMetadataWithCompression( * @return whether there is another entry available at a given index */ public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { - return compositeMetadata.writerIndex() - entryIndex > 0; + return CompositeMetadataCodec.hasEntry(compositeMetadata, entryIndex); } /** @@ -288,7 +202,7 @@ public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { * @return whether the header represents a well-known MIME type */ public static boolean isWellKnownMimeType(ByteBuf header) { - return header.readableBytes() == 1; + return CompositeMetadataCodec.isWellKnownMimeType(header); } /** @@ -307,10 +221,8 @@ static void encodeAndAddMetadata( ByteBufAllocator allocator, byte unknownCompressedMimeType, ByteBuf metadata) { - compositeMetaData.addComponents( - true, - encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), - metadata); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetaData, allocator, unknownCompressedMimeType, metadata); } /** @@ -327,38 +239,7 @@ static void encodeAndAddMetadata( */ static ByteBuf encodeMetadataHeader( ByteBufAllocator allocator, String customMime, int metadataLength) { - ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); - // reserve 1 byte for the customMime length - // /!\ careful not to read that first byte, which is random at this point - int writerIndexInitial = metadataHeader.writerIndex(); - metadataHeader.writerIndex(writerIndexInitial + 1); - - // write the custom mime in UTF8 but validate it is all ASCII-compatible - // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) - int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); - if (!ByteBufUtil.isText( - metadataHeader, metadataHeader.readerIndex() + 1, customMimeLength, CharsetUtil.US_ASCII)) { - metadataHeader.release(); - throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); - } - if (customMimeLength < 1 || customMimeLength > 128) { - metadataHeader.release(); - throw new IllegalArgumentException( - "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); - } - metadataHeader.markWriterIndex(); - - // go back to beginning and write the length - // encoded length is one less than actual length, since 0 is never a valid length, which gives - // wider representation range - metadataHeader.writerIndex(writerIndexInitial); - metadataHeader.writeByte(customMimeLength - 1); - - // go back to post-mime type and write the metadata content length - metadataHeader.resetWriterIndex(); - NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); - - return metadataHeader; + return CompositeMetadataCodec.encodeMetadataHeader(allocator, customMime, metadataLength); } /** @@ -376,10 +257,6 @@ static ByteBuf encodeMetadataHeader( */ static ByteBuf encodeMetadataHeader( ByteBufAllocator allocator, byte mimeType, int metadataLength) { - ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); - - NumberUtils.encodeUnsignedMedium(buffer, metadataLength); - - return buffer; + return CompositeMetadataCodec.encodeMetadataHeader(allocator, mimeType, metadataLength); } } diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java new file mode 100644 index 000000000..d766cf59f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java @@ -0,0 +1,76 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +/** + * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * TaggingMetadata} for an Iterator-like approach to decoding entries. + * + * @author linux_china + */ +public class TaggingMetadataCodec { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + /** + * create routing metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return routing metadata + */ + public static RoutingMetadata createRoutingMetadata( + ByteBufAllocator allocator, Collection tags) { + return new RoutingMetadata(createTaggingContent(allocator, tags)); + } + + /** + * create tagging metadata from composite metadata entry + * + * @param entry composite metadata entry + * @return tagging metadata + */ + public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { + return new TaggingMetadata(entry.getMimeType(), entry.getContent()); + } + + /** + * create tagging metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param tags tag values + * @return Tagging Metadata + */ + public static TaggingMetadata createTaggingMetadata( + ByteBufAllocator allocator, String knownMimeType, Collection tags) { + return new TaggingMetadata(knownMimeType, createTaggingContent(allocator, tags)); + } + + /** + * create tagging content + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return tagging content + */ + public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { + CompositeByteBuf taggingContent = allocator.compositeBuffer(); + for (String key : tags) { + int length = ByteBufUtil.utf8Bytes(key); + if (length == 0 || length > TAG_LENGTH_MAX) { + continue; + } + ByteBuf byteBuf = allocator.buffer().writeByte(length); + byteBuf.writeCharSequence(key, StandardCharsets.UTF_8); + taggingContent.addComponent(true, byteBuf); + } + return taggingContent; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java index c7870bf0d..718528358 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataFlyweight.java @@ -2,9 +2,6 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.CompositeByteBuf; -import java.nio.charset.StandardCharsets; import java.util.Collection; /** @@ -12,12 +9,11 @@ * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link * TaggingMetadata} for an Iterator-like approach to decoding entries. * + * @deprecated in favor of {@link TaggingMetadataCodec} * @author linux_china */ +@Deprecated public class TaggingMetadataFlyweight { - /** Tag max length in bytes */ - private static int TAG_LENGTH_MAX = 0xFF; - /** * create routing metadata * @@ -27,7 +23,7 @@ public class TaggingMetadataFlyweight { */ public static RoutingMetadata createRoutingMetadata( ByteBufAllocator allocator, Collection tags) { - return new RoutingMetadata(createTaggingContent(allocator, tags)); + return TaggingMetadataCodec.createRoutingMetadata(allocator, tags); } /** @@ -37,7 +33,7 @@ public static RoutingMetadata createRoutingMetadata( * @return tagging metadata */ public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { - return new TaggingMetadata(entry.getMimeType(), entry.getContent()); + return TaggingMetadataCodec.createTaggingMetadata(entry); } /** @@ -50,7 +46,7 @@ public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entr */ public static TaggingMetadata createTaggingMetadata( ByteBufAllocator allocator, String knownMimeType, Collection tags) { - return new TaggingMetadata(knownMimeType, createTaggingContent(allocator, tags)); + return TaggingMetadataCodec.createTaggingMetadata(allocator, knownMimeType, tags); } /** @@ -61,16 +57,6 @@ public static TaggingMetadata createTaggingMetadata( * @return tagging content */ public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { - CompositeByteBuf taggingContent = allocator.compositeBuffer(); - for (String key : tags) { - int length = ByteBufUtil.utf8Bytes(key); - if (length == 0 || length > TAG_LENGTH_MAX) { - continue; - } - ByteBuf byteBuf = allocator.buffer().writeByte(length); - byteBuf.writeCharSequence(key, StandardCharsets.UTF_8); - taggingContent.addComponent(true, byteBuf); - } - return taggingContent; + return TaggingMetadataCodec.createTaggingContent(allocator, tags); } } diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java new file mode 100644 index 000000000..d276a9436 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +/** + * Represents decoded tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public final class TracingMetadata { + + final long traceIdHigh; + final long traceId; + private final boolean hasParentId; + final long parentId; + final long spanId; + final boolean isEmpty; + final boolean isNotSampled; + final boolean isSampled; + final boolean isDebug; + + TracingMetadata( + long traceIdHigh, + long traceId, + long spanId, + boolean hasParentId, + long parentId, + boolean isEmpty, + boolean isNotSampled, + boolean isSampled, + boolean isDebug) { + this.traceIdHigh = traceIdHigh; + this.traceId = traceId; + this.spanId = spanId; + this.hasParentId = hasParentId; + this.parentId = parentId; + this.isEmpty = isEmpty; + this.isNotSampled = isNotSampled; + this.isSampled = isSampled; + this.isDebug = isDebug; + } + + /** When non-zero, the trace containing this span uses 128-bit trace identifiers. */ + public long traceIdHigh() { + return traceIdHigh; + } + + /** Unique 8-byte identifier for a trace, set on all spans within it. */ + public long traceId() { + return traceId; + } + + /** Indicates if the parent's {@link #spanId} or if this the root span in a trace. */ + public final boolean hasParent() { + return hasParentId; + } + + /** Returns the parent's {@link #spanId} where zero implies absent. */ + public long parentId() { + return parentId; + } + + /** + * Unique 8-byte identifier of this span within a trace. + * + *

A span is uniquely identified in storage by ({@linkplain #traceId}, {@linkplain #spanId}). + */ + public long spanId() { + return spanId; + } + + /** Indicates that trace IDs should be accepted for tracing. */ + public boolean isSampled() { + return isSampled; + } + + /** Indicates that trace IDs should be force traced. */ + public boolean isDebug() { + return isDebug; + } + + /** Includes that there is sampling information and no trace IDs. */ + public boolean isEmpty() { + return isEmpty; + } + + /** + * Indicated that sampling decision is present. If {@code false} means that decision is unknown + * and says explicitly that {@link #isDebug()} and {@link #isSampled()} also returns {@code + * false}. + */ + public boolean isDecided() { + return isNotSampled || isDebug || isSampled; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java new file mode 100644 index 000000000..eb44956f6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java @@ -0,0 +1,172 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Represents codes for tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public class TracingMetadataCodec { + + static final int FLAG_EXTENDED_TRACE_ID_SIZE = 0b0000_1000; + static final int FLAG_INCLUDE_PARENT_ID = 0b0000_0100; + static final int FLAG_NOT_SAMPLED = 0b0001_0000; + static final int FLAG_SAMPLED = 0b0010_0000; + static final int FLAG_DEBUG = 0b0100_0000; + static final int FLAG_IDS_SET = 0b1000_0000; + + public static ByteBuf encodeEmpty(ByteBufAllocator allocator, Flags flag) { + + return encode(allocator, true, 0, 0, false, 0, 0, false, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, + long traceIdHigh, + long traceId, + long spanId, + long parentId, + Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, parentId, true, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, long traceIdHigh, long traceId, long spanId, Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, 0, false, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, long parentId, Flags flag) { + + return encode(allocator, false, 0, traceId, false, spanId, parentId, true, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, Flags flag) { + return encode(allocator, false, 0, traceId, false, spanId, 0, false, flag); + } + + static ByteBuf encode( + ByteBufAllocator allocator, + boolean isEmpty, + long traceIdHigh, + long traceId, + boolean extendedTraceId, + long spanId, + long parentId, + boolean includesParent, + Flags flag) { + int size = + 1 + + (isEmpty + ? 0 + : (Long.BYTES + + Long.BYTES + + (extendedTraceId ? Long.BYTES : 0) + + (includesParent ? Long.BYTES : 0))); + final ByteBuf buffer = allocator.buffer(size); + + int byteFlags = 0; + switch (flag) { + case NOT_SAMPLE: + byteFlags |= FLAG_NOT_SAMPLED; + break; + case SAMPLE: + byteFlags |= FLAG_SAMPLED; + break; + case DEBUG: + byteFlags |= FLAG_DEBUG; + break; + } + + if (isEmpty) { + return buffer.writeByte(byteFlags); + } + + byteFlags |= FLAG_IDS_SET; + + if (extendedTraceId) { + byteFlags |= FLAG_EXTENDED_TRACE_ID_SIZE; + } + + if (includesParent) { + byteFlags |= FLAG_INCLUDE_PARENT_ID; + } + + buffer.writeByte(byteFlags); + + if (extendedTraceId) { + buffer.writeLong(traceIdHigh); + } + + buffer.writeLong(traceId).writeLong(spanId); + + if (includesParent) { + buffer.writeLong(parentId); + } + + return buffer; + } + + public static TracingMetadata decode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + try { + byte flags = byteBuf.readByte(); + boolean isNotSampled = (flags & FLAG_NOT_SAMPLED) == FLAG_NOT_SAMPLED; + boolean isSampled = (flags & FLAG_SAMPLED) == FLAG_SAMPLED; + boolean isDebug = (flags & FLAG_DEBUG) == FLAG_DEBUG; + boolean isIDSet = (flags & FLAG_IDS_SET) == FLAG_IDS_SET; + + if (!isIDSet) { + return new TracingMetadata(0, 0, 0, false, 0, true, isNotSampled, isSampled, isDebug); + } + + boolean extendedTraceId = + (flags & FLAG_EXTENDED_TRACE_ID_SIZE) == FLAG_EXTENDED_TRACE_ID_SIZE; + + long traceIdHigh; + if (extendedTraceId) { + traceIdHigh = byteBuf.readLong(); + } else { + traceIdHigh = 0; + } + + long traceId = byteBuf.readLong(); + long spanId = byteBuf.readLong(); + + boolean includesParent = (flags & FLAG_INCLUDE_PARENT_ID) == FLAG_INCLUDE_PARENT_ID; + + long parentId; + if (includesParent) { + parentId = byteBuf.readLong(); + } else { + parentId = 0; + } + + return new TracingMetadata( + traceIdHigh, + traceId, + spanId, + includesParent, + parentId, + false, + isNotSampled, + isSampled, + isDebug); + } finally { + byteBuf.resetReaderIndex(); + } + } + + public enum Flags { + UNDECIDED, + NOT_SAMPLE, + SAMPLE, + DEBUG + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java new file mode 100644 index 000000000..66c98701c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java index 82cce54a0..e78e87629 100644 --- a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -72,7 +72,9 @@ public enum WellKnownMimeType { APPLICATION_CLOUDEVENTS_JSON("application/cloudevents+json", (byte) 0x28), // ... reserved for future use ... - + MESSAGE_RSOCKET_MIMETYPE("message/x.rsocket.mime-type.v0", (byte) 0x7A), + MESSAGE_RSOCKET_ACCEPT_MIMETYPES("message/x.rsocket.accept-mime-types.v0", (byte) 0x7B), + MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java new file mode 100644 index 000000000..3fb9ae1d6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains implementations of RSocket protocol extensions related + * to the use of metadata. + */ +@NonNullApi +package io.rsocket.metadata; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java new file mode 100644 index 000000000..e1a8ba449 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/AuthMetadataFlyweight.java @@ -0,0 +1,194 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.metadata.AuthMetadataCodec; + +/** @deprecated in favor of {@link io.rsocket.metadata.AuthMetadataCodec} */ +@Deprecated +public class AuthMetadataFlyweight { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private AuthMetadataFlyweight() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + return AuthMetadataCodec.encodeMetadata(allocator, customAuthType, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + return AuthMetadataCodec.encodeMetadata(allocator, WellKnownAuthType.cast(authType), metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 255 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + return AuthMetadataCodec.encodeSimpleMetadata(allocator, username, password); + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + return AuthMetadataCodec.encodeBearerMetadata(allocator, token); + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + return AuthMetadataCodec.encodeMetadataWithCompression(allocator, authType, metadata); + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + return AuthMetadataCodec.isWellKnownAuthType(metadata); + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType decodeWellKnownAuthType(ByteBuf metadata) { + return WellKnownAuthType.cast(AuthMetadataCodec.readWellKnownAuthType(metadata)); + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence decodeCustomAuthType(ByteBuf metadata) { + return AuthMetadataCodec.readCustomAuthType(metadata); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf decodePayload(ByteBuf metadata) { + return AuthMetadataCodec.readPayload(metadata); + } + + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf decodeUsername(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readUsername(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf decodePassword(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readPassword(simpleAuthMetadata); + } + /** + * Read up to 257 {@code bytes} from the given {@link ByteBuf} where the first byte is username + * length and the subsequent number of bytes equal to decoded length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] decodeUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readUsernameAsCharArray(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodePasswordAsCharArray(ByteBuf simpleAuthMetadata) { + return AuthMetadataCodec.readPasswordAsCharArray(simpleAuthMetadata); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} where the first byte is + * username length and the subsequent number of bytes equal to decoded length + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] decodeBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + return AuthMetadataCodec.readBearerTokenAsCharArray(bearerAuthMetadata); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java new file mode 100644 index 000000000..24e5ff0db --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/security/WellKnownAuthType.java @@ -0,0 +1,147 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata.security; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + * + * @deprecated in favor of {@link io.rsocket.metadata.WellKnownAuthType} + */ +@Deprecated +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + static io.rsocket.metadata.WellKnownAuthType cast(WellKnownAuthType wellKnownAuthType) { + byte identifier = wellKnownAuthType.identifier; + if (identifier == io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.getIdentifier()) { + return io.rsocket.metadata.WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } else if (identifier + == io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.getIdentifier()) { + return io.rsocket.metadata.WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + } else { + return io.rsocket.metadata.WellKnownAuthType.fromIdentifier(identifier); + } + } + + static WellKnownAuthType cast(io.rsocket.metadata.WellKnownAuthType wellKnownAuthType) { + byte identifier = wellKnownAuthType.getIdentifier(); + if (identifier == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE.identifier) { + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } else if (identifier == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE.identifier) { + return WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + } else { + return TYPES_BY_AUTH_ID[identifier]; + } + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/package-info.java b/rsocket-core/src/main/java/io/rsocket/package-info.java index 243c1ab52..6fe74fb38 100644 --- a/rsocket-core/src/main/java/io/rsocket/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,4 +14,16 @@ * limitations under the License. */ +/** + * Contains key contracts of the RSocket programming model including {@link io.rsocket.RSocket + * RSocket} for performing or handling RSocket interactions, {@link io.rsocket.SocketAcceptor + * SocketAcceptor} for declaring responders, {@link io.rsocket.Payload Payload} for access to the + * content of a payload, and others. + * + *

To connect to or start a server see {@link io.rsocket.core.RSocketConnector RSocketConnector} + * and {@link io.rsocket.core.RSocketServer RSocketServer} in {@link io.rsocket.core}. + */ +@NonNullApi package io.rsocket; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java index 056ded0cd..6b2a7a71b 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,13 @@ import io.rsocket.DuplexConnection; import java.util.function.BiFunction; -/** */ +/** + * Contract to decorate a {@link DuplexConnection} and intercept the sending and receiving of + * RSocket frames at the transport level. + */ public @FunctionalInterface interface DuplexConnectionInterceptor extends BiFunction { + enum Type { SETUP, CLIENT, diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java new file mode 100644 index 000000000..fc032847c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -0,0 +1,56 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; + +/** + * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. + * This is not intended for direct use by applications. + */ +public class InitializingInterceptorRegistry extends InterceptorRegistry { + + public DuplexConnection initConnection( + DuplexConnectionInterceptor.Type type, DuplexConnection connection) { + for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { + connection = interceptor.apply(type, connection); + } + return connection; + } + + public RSocket initRequester(RSocket rsocket) { + for (RSocketInterceptor interceptor : getRequesterInteceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public RSocket initResponder(RSocket rsocket) { + for (RSocketInterceptor interceptor : getResponderInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public SocketAcceptor initSocketAcceptor(SocketAcceptor acceptor) { + for (SocketAcceptorInterceptor interceptor : getSocketAcceptorInterceptors()) { + acceptor = interceptor.apply(acceptor); + } + return acceptor; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java new file mode 100644 index 000000000..427fa15ae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -0,0 +1,120 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +/** + * Provides support for registering interceptors at the following levels: + * + *

    + *
  • {@link #forConnection(DuplexConnectionInterceptor)} -- transport level + *
  • {@link #forSocketAcceptor(SocketAcceptorInterceptor)} -- for accepting new connections + *
  • {@link #forRequester(RSocketInterceptor)} -- for performing of requests + *
  • {@link #forResponder(RSocketInterceptor)} -- for responding to requests + *
+ */ +public class InterceptorRegistry { + private List requesterInteceptors = new ArrayList<>(); + private List responderInterceptors = new ArrayList<>(); + private List socketAcceptorInterceptors = new ArrayList<>(); + private List connectionInterceptors = new ArrayList<>(); + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. + */ + public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { + requesterInteceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forRequester(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forRequester(Consumer> consumer) { + consumer.accept(requesterInteceptors); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for resonding to + * requests. + */ + public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { + responderInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forResponder(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forResponder(Consumer> consumer) { + consumer.accept(responderInterceptors); + return this; + } + + /** + * Add a {@link SocketAcceptorInterceptor} that will intercept the accepting of new connections. + */ + public InterceptorRegistry forSocketAcceptor(SocketAcceptorInterceptor interceptor) { + socketAcceptorInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forSocketAcceptor(SocketAcceptorInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forSocketAcceptor(Consumer> consumer) { + consumer.accept(socketAcceptorInterceptors); + return this; + } + + /** Add a {@link DuplexConnectionInterceptor}. */ + public InterceptorRegistry forConnection(DuplexConnectionInterceptor interceptor) { + connectionInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forConnection(DuplexConnectionInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forConnection(Consumer> consumer) { + consumer.accept(connectionInterceptors); + return this; + } + + List getRequesterInteceptors() { + return requesterInteceptors; + } + + List getResponderInterceptors() { + return responderInterceptors; + } + + List getConnectionInterceptors() { + return connectionInterceptors; + } + + List getSocketAcceptorInterceptors() { + return socketAcceptorInterceptors; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java new file mode 100644 index 000000000..d7d9742d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +/** + * Interceptor that adds {@link Flux#limitRate(int, int)} to publishers of outbound streams that + * breaks down or aggregates demand values from the remote end (i.e. {@code REQUEST_N} frames) into + * batches of a uniform size. For example the remote may request {@code Long.MAXVALUE} or it may + * start requesting one at a time, in both cases with the limit set to 64, the publisher will see a + * demand of 64 to start and subsequent batches of 48, i.e. continuing to prefetch and refill an + * internal queue when it falls to 75% full. The high and low tide marks are configurable. + * + *

See static factory methods to create an instance for a requester or for a responder. + * + *

Note: keep in mind that the {@code limitRate} operator always uses requests + * the same request values, even if the remote requests less than the limit. For example given a + * limit of 64, if the remote requests 4, 64 will be prefetched of which 4 will be sent and 60 will + * be cached. + * + * @since 1.0 + */ +public class LimitRateInterceptor implements RSocketInterceptor { + + private final int highTide; + private final int lowTide; + private final boolean requesterProxy; + + private LimitRateInterceptor(int highTide, int lowTide, boolean requesterProxy) { + this.highTide = highTide; + this.lowTide = lowTide; + this.requesterProxy = requesterProxy; + } + + @Override + public RSocket apply(RSocket socket) { + return requesterProxy ? new RequesterProxy(socket) : new ResponderProxy(socket); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int prefetchRate) { + return forResponder(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions with more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, false); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int prefetchRate) { + return forRequester(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions with + * more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, true); + } + + /** Responder side proxy, limits response streams. */ + private class ResponderProxy extends RSocketProxy { + + ResponderProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestStream(Payload payload) { + return super.requestStream(payload).limitRate(highTide, lowTide); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(payloads).limitRate(highTide, lowTide); + } + } + + /** Requester side proxy, limits channel request stream. */ + private class RequesterProxy extends RSocketProxy { + + RequesterProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(Flux.from(payloads).limitRate(highTide, lowTide)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java deleted file mode 100644 index e3a19367c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -import io.rsocket.DuplexConnection; -import io.rsocket.RSocket; -import io.rsocket.SocketAcceptor; -import java.util.ArrayList; -import java.util.List; - -public class PluginRegistry { - private List connections = new ArrayList<>(); - private List requesters = new ArrayList<>(); - private List responders = new ArrayList<>(); - private List socketAcceptorInterceptors = new ArrayList<>(); - - public PluginRegistry() {} - - public PluginRegistry(PluginRegistry defaults) { - this.connections.addAll(defaults.connections); - this.requesters.addAll(defaults.requesters); - this.responders.addAll(defaults.responders); - } - - public void addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - connections.add(interceptor); - } - - /** Deprecated. Use {@link #addRequesterPlugin(RSocketInterceptor)} instead */ - @Deprecated - public void addClientPlugin(RSocketInterceptor interceptor) { - addRequesterPlugin(interceptor); - } - - public void addRequesterPlugin(RSocketInterceptor interceptor) { - requesters.add(interceptor); - } - - /** Deprecated. Use {@link #addResponderPlugin(RSocketInterceptor)} instead */ - @Deprecated - public void addServerPlugin(RSocketInterceptor interceptor) { - addResponderPlugin(interceptor); - } - - public void addResponderPlugin(RSocketInterceptor interceptor) { - responders.add(interceptor); - } - - public void addSocketAcceptorPlugin(SocketAcceptorInterceptor interceptor) { - socketAcceptorInterceptors.add(interceptor); - } - - /** Deprecated. Use {@link #applyRequester(RSocket)} instead */ - @Deprecated - public RSocket applyClient(RSocket rSocket) { - return applyRequester(rSocket); - } - - public RSocket applyRequester(RSocket rSocket) { - for (RSocketInterceptor i : requesters) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - /** Deprecated. Use {@link #applyResponder(RSocket)} instead */ - @Deprecated - public RSocket applyServer(RSocket rSocket) { - return applyResponder(rSocket); - } - - public RSocket applyResponder(RSocket rSocket) { - for (RSocketInterceptor i : responders) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public SocketAcceptor applySocketAcceptorInterceptor(SocketAcceptor acceptor) { - for (SocketAcceptorInterceptor i : socketAcceptorInterceptors) { - acceptor = i.apply(acceptor); - } - - return acceptor; - } - - public DuplexConnection applyConnection( - DuplexConnectionInterceptor.Type type, DuplexConnection connection) { - for (DuplexConnectionInterceptor i : connections) { - connection = i.apply(type, connection); - } - - return connection; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java b/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java deleted file mode 100644 index 1ac147687..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -/** JVM wide plugins for RSocket */ -public class Plugins { - private static PluginRegistry DEFAULT = new PluginRegistry(); - - private Plugins() {} - - public static void interceptConnection(DuplexConnectionInterceptor interceptor) { - DEFAULT.addConnectionPlugin(interceptor); - } - - public static void interceptClient(RSocketInterceptor interceptor) { - DEFAULT.addClientPlugin(interceptor); - } - - public static void interceptServer(RSocketInterceptor interceptor) { - DEFAULT.addServerPlugin(interceptor); - } - - public static PluginRegistry defaultPlugins() { - return DEFAULT; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java index 0bad0faed..0cd4bb8f6 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,5 +19,10 @@ import io.rsocket.RSocket; import java.util.function.Function; -/** */ +/** + * Contract to decorate an {@link RSocket}, providing a way to intercept interactions. This can be + * applied to a {@link InterceptorRegistry#forRequester(RSocketInterceptor) requester} or {@link + * InterceptorRegistry#forResponder(RSocketInterceptor) responder} {@code RSocket} of a client or + * server. + */ public @FunctionalInterface interface RSocketInterceptor extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java index c9201ca5b..6dd850ba9 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,8 +22,8 @@ * Contract to decorate a {@link SocketAcceptor}, providing access to connection {@code setup} * information and the ability to also decorate the sockets for requesting and responding. * - *

This can be used as an alternative to individual requester and responder {@link - * RSocketInterceptor} plugins. + *

This could be used as an alternative to registering an individual "requester" {@code + * RSocketInterceptor} and "responder" {@code RSocketInterceptor}. */ public @FunctionalInterface interface SocketAcceptorInterceptor extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java new file mode 100644 index 000000000..fd9e1f01a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contracts for interception of transports, connections, and requests in in RSocket Java. */ +@NonNullApi +package io.rsocket.plugins; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java index b347642e3..ed9450357 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,16 +20,17 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.ResumeFrameFlyweight; -import io.rsocket.frame.ResumeOkFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; import io.rsocket.internal.ClientServerInputMultiplexer; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; public class ClientRSocketSession implements RSocketSession> { private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); @@ -41,13 +42,12 @@ public class ClientRSocketSession implements RSocketSession resumeStrategy, + Retry retry, ResumableFramesStore resumableFramesStore, Duration resumeStreamTimeout, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; + this.allocator = duplexConnection.alloc(); this.resumableConnection = new ResumableDuplexConnection( "client", @@ -64,24 +64,13 @@ public ClientRSocketSession( .flatMap( err -> { logger.debug("Client session connection error. Starting new connection"); - ResumeStrategy reconnectOnError = resumeStrategy.get(); - ClientResume clientResume = new ClientResume(resumeSessionDuration, resumeToken); AtomicBoolean once = new AtomicBoolean(); return newConnection .delaySubscription( once.compareAndSet(false, true) - ? reconnectOnError.apply(clientResume, err) + ? retry.generateCompanion(Flux.just(new RetrySignal(err))) : Mono.empty()) - .retryWhen( - errors -> - errors - .doOnNext( - retryErr -> - logger.debug("Resumption reconnection error", retryErr)) - .flatMap( - retryErr -> - Mono.from(reconnectOnError.apply(clientResume, retryErr)) - .doOnNext(v -> logger.debug("Retrying with: {}", v)))) + .retryWhen(retry) .timeout(resumeSessionDuration); }) .map(ClientServerInputMultiplexer::new) @@ -97,7 +86,7 @@ public ClientRSocketSession( position); /*Connection is established again: send RESUME frame to server, listen for RESUME_OK*/ sendFrame( - ResumeFrameFlyweight.encode( + ResumeFrameCodec.encode( allocator, /*retain so token is not released once sent as part of resume frame*/ resumeToken.retain(), @@ -134,7 +123,7 @@ public ClientRSocketSession resumeWith(ByteBuf resumeOkFrame) { .onErrorResume( err -> sendFrame( - ErrorFrameFlyweight.encode( + ErrorFrameCodec.encode( allocator, 0, errorFrameThrowable(remoteImpliedPos))) .then(Mono.fromRunnable(resumableConnection::dispose)) /*Resumption is impossible: no need to return control to ResumableConnection*/ @@ -168,7 +157,7 @@ private Mono sendFrame(ByteBuf frame) { } private static long remoteImpliedPos(ByteBuf resumeOkFrame) { - return ResumeOkFrameFlyweight.lastReceivedClientPos(resumeOkFrame); + return ResumeOkFrameCodec.lastReceivedClientPos(resumeOkFrame); } private static long remotePos(ByteBuf resumeOkFrame) { @@ -178,4 +167,28 @@ private static long remotePos(ByteBuf resumeOkFrame) { private static ConnectionErrorException errorFrameThrowable(long impliedPos) { return new ConnectionErrorException("resumption_server_pos=[" + impliedPos + "]"); } + + private static class RetrySignal implements Retry.RetrySignal { + + private final Throwable ex; + + RetrySignal(Throwable ex) { + this.ex = ex; + } + + @Override + public long totalRetries() { + return 0; + } + + @Override + public long totalRetriesInARow() { + return 0; + } + + @Override + public Throwable failure() { + return ex; + } + } } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java index b46ac864b..461be02d2 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ExponentialBackoffResumeStrategy.java @@ -21,7 +21,13 @@ import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#backoff(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated public class ExponentialBackoffResumeStrategy implements ResumeStrategy { private volatile Duration next; private final Duration firstBackoff; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java index abfefe0b1..bd447c8a9 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/PeriodicResumeStrategy.java @@ -19,7 +19,13 @@ import java.time.Duration; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of passing {@link Retry#fixedDelay(long, Duration)} to {@link + * io.rsocket.core.Resume#retry(Retry)}. + */ +@Deprecated public class PeriodicResumeStrategy implements ResumeStrategy { private final Duration interval; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java index 49401d560..461d71228 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -17,9 +17,10 @@ package io.rsocket.resume; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Queue; @@ -105,6 +106,11 @@ public ResumableDuplexConnection( reconnect(duplexConnection); } + @Override + public ByteBufAllocator alloc() { + return curConnection.alloc(); + } + public void disconnect() { DuplexConnection c = this.curConnection; if (c != null) { @@ -217,6 +223,10 @@ public boolean isDisposed() { } private void sendFrame(ByteBuf f) { + if (disposed.get()) { + f.release(); + return; + } /*resuming from store so no need to save again*/ if (state != State.RESUME && isResumableFrame(f)) { resumeSaveFrames.onNext(f); @@ -363,7 +373,7 @@ private void releaseFramesToPosition(long remoteImpliedPos) { } static boolean isResumableFrame(ByteBuf frame) { - switch (FrameHeaderFlyweight.nativeFrameType(frame)) { + switch (FrameHeaderCodec.nativeFrameType(frame)) { case REQUEST_CHANNEL: case REQUEST_STREAM: case REQUEST_RESPONSE: diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java index 903431192..d9dec9f54 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStrategy.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,12 @@ import java.util.function.BiFunction; import org.reactivestreams.Publisher; +import reactor.util.retry.Retry; +/** + * @deprecated as of 1.0 RC7 in favor of using {@link io.rsocket.core.Resume#retry(Retry)} via + * {@link io.rsocket.core.RSocketConnector} or {@link io.rsocket.core.RSocketServer}. + */ +@Deprecated @FunctionalInterface public interface ResumeStrategy extends BiFunction> {} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java index 1a0605497..b54ce644f 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -20,9 +20,9 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.exceptions.RejectedResumeException; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.ResumeFrameFlyweight; -import io.rsocket.frame.ResumeOkFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; import java.time.Duration; import java.util.function.Function; import org.slf4j.Logger; @@ -43,13 +43,12 @@ public class ServerRSocketSession implements RSocketSession { public ServerRSocketSession( DuplexConnection duplexConnection, - ByteBufAllocator allocator, Duration resumeSessionDuration, Duration resumeStreamTimeout, Function resumeStoreFactory, ByteBuf resumeToken, boolean cleanupStoreOnKeepAlive) { - this.allocator = allocator; + this.allocator = duplexConnection.alloc(); this.resumeToken = resumeToken; this.resumableConnection = new ResumableDuplexConnection( @@ -104,12 +103,10 @@ public ServerRSocketSession resumeWith(ByteBuf resumeFrame) { remotePos, remoteImpliedPos, pos -> - pos.flatMap( - impliedPos -> sendFrame(ResumeOkFrameFlyweight.encode(allocator, impliedPos))) + pos.flatMap(impliedPos -> sendFrame(ResumeOkFrameCodec.encode(allocator, impliedPos))) .onErrorResume( err -> - sendFrame( - ErrorFrameFlyweight.encode(allocator, 0, errorFrameThrowable(err))) + sendFrame(ErrorFrameCodec.encode(allocator, 0, errorFrameThrowable(err))) .then(Mono.fromRunnable(resumableConnection::dispose)) /*Resumption is impossible: no need to return control to ResumableConnection*/ .then(Mono.never()))); @@ -137,11 +134,11 @@ private Mono sendFrame(ByteBuf frame) { } private static long remotePos(ByteBuf resumeFrame) { - return ResumeFrameFlyweight.firstAvailableClientPos(resumeFrame); + return ResumeFrameCodec.firstAvailableClientPos(resumeFrame); } private static long remoteImpliedPos(ByteBuf resumeFrame) { - return ResumeFrameFlyweight.lastReceivedServerPos(resumeFrame); + return ResumeFrameCodec.lastReceivedServerPos(resumeFrame); } private static RejectedResumeException errorFrameThrowable(Throwable err) { diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java index 3882103a0..1d5c23bd6 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -19,7 +19,7 @@ import io.netty.buffer.ByteBuf; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; public class SessionManager { private volatile boolean isDisposed; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/package-info.java b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java index 57027bee2..98744386a 100644 --- a/rsocket-core/src/main/java/io/rsocket/resume/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java @@ -14,5 +14,14 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * Contains support classes for the RSocket resume capability. + * + * @see Resuming + * Operation + */ +@NonNullApi package io.rsocket.resume; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java index 86e7c311a..00536122a 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,8 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Client and server transport contracts for pluggable transports. */ +@NonNullApi package io.rsocket.transport; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java b/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java deleted file mode 100644 index ec3d4ab3c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Optional; -import java.util.ServiceLoader; - -/** Maps a {@link URI} to a {@link ClientTransport} or {@link ServerTransport}. */ -public interface UriHandler { - - /** - * Load all registered instances of {@code UriHandler}. - * - * @return all registered instances of {@code UriHandler} - */ - static ServiceLoader loadServices() { - return ServiceLoader.load(UriHandler.class); - } - - /** - * Returns an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ClientTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildClient(URI uri); - - /** - * Returns an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, - * otherwise {@link Optional#EMPTY}. - * - * @param uri the uri to map - * @return an implementation of {@link ServerTransport} unambiguously mapped to a {@link URI}, * - * otherwise {@link Optional#EMPTY} - * @throws NullPointerException if {@code uri} is {@code null} - */ - Optional buildServer(URI uri); -} diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java deleted file mode 100644 index 204c5d1ea..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static io.rsocket.uri.UriHandler.loadServices; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.ServiceLoader; -import reactor.core.publisher.Mono; - -/** - * Registry for looking up transports by URI. - * - *

Uses the Jar Services mechanism with services defined by {@link UriHandler}. - */ -public class UriTransportRegistry { - private static final ClientTransport FAILED_CLIENT_LOOKUP = - (mtu) -> Mono.error(new UnsupportedOperationException()); - private static final ServerTransport FAILED_SERVER_LOOKUP = - (acceptor, mtu) -> Mono.error(new UnsupportedOperationException()); - - private List handlers; - - public UriTransportRegistry(ServiceLoader services) { - handlers = new ArrayList<>(); - services.forEach(handlers::add); - } - - public static UriTransportRegistry fromServices() { - ServiceLoader services = loadServices(); - - return new UriTransportRegistry(services); - } - - public static ClientTransport clientForUri(String uri) { - return UriTransportRegistry.fromServices().findClient(uri); - } - - public static ServerTransport serverForUri(String uri) { - return UriTransportRegistry.fromServices().findServer(uri); - } - - private ClientTransport findClient(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildClient(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_CLIENT_LOOKUP; - } - - private ServerTransport findServer(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildServer(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_SERVER_LOOKUP; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java index b91cf8ac6..4cf33fa86 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -21,13 +21,14 @@ import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; import io.netty.util.Recycler; import io.netty.util.Recycler.Handle; import io.rsocket.Payload; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.Charset; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; public final class ByteBufPayload extends AbstractReferenceCounted implements Payload { private static final Recycler RECYCLER = @@ -112,9 +113,10 @@ public static Payload create(ByteBuf data) { public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { ByteBufPayload payload = RECYCLER.get(); - payload.setRefCnt(1); payload.data = data; payload.metadata = metadata; + // unsure data and metadata is set before refCnt change + payload.setRefCnt(1); return payload; } @@ -126,26 +128,31 @@ public static Payload create(Payload payload) { @Override public boolean hasMetadata() { + ensureAccessible(); return metadata != null; } @Override public ByteBuf sliceMetadata() { + ensureAccessible(); return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); } @Override public ByteBuf data() { + ensureAccessible(); return data; } @Override public ByteBuf metadata() { + ensureAccessible(); return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; } @Override public ByteBuf sliceData() { + ensureAccessible(); return data.slice(); } @@ -163,6 +170,7 @@ public ByteBufPayload retain(int increment) { @Override public ByteBufPayload touch() { + ensureAccessible(); data.touch(); if (metadata != null) { metadata.touch(); @@ -172,6 +180,7 @@ public ByteBufPayload touch() { @Override public ByteBufPayload touch(Object hint) { + ensureAccessible(); data.touch(hint); if (metadata != null) { metadata.touch(hint); @@ -189,4 +198,22 @@ protected void deallocate() { } handle.recycle(this); } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link ByteBufPayload#ensureAccessible()} to try to guard against using the + * buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java new file mode 100644 index 000000000..328fb8435 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -0,0 +1,210 @@ +package io.rsocket.util; + +import static io.netty.util.internal.StringUtil.isSurrogate; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.MathUtil; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.util.Arrays; + +public class CharByteBufUtil { + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + + private CharByteBufUtil() {} + + /** + * Returns the exact bytes length of UTF8 character sequence. + * + *

This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[])}. + */ + public static int utf8Bytes(final char[] seq) { + return utf8ByteCount(seq, 0, seq.length); + } + + /** + * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[], int, + * int)}. + */ + public static int utf8Bytes(final char[] seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final char[] seq, int start, int end) { + int i = start; + // ASCII fast path + while (i < end && seq[i] < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final char[] seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq[i]; + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + final char c2; + try { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. + c2 = seq[++i]; + } catch (IndexOutOfBoundsException ignored) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(c2)) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length)) { + throw new IndexOutOfBoundsException( + "expected: 0 <= start(" + + start + + ") <= end (" + + end + + ") <= seq.length(" + + seq.length + + ')'); + } + return seq; + } + + /** + * Encode a {@code char[]} in UTF-8 and write it + * into {@link ByteBuf}. + * + *

This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, char[] seq) { + return writeUtf8(buf, seq, 0, seq.length); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. + * + * @return actual number of bytes written + */ + public static int writeUtf8(ByteBuf buf, char[] seq, int start, int end) { + return writeUtf8(buf, buf.writerIndex(), checkCharSequenceBounds(seq, start, end), start, end); + } + + // Fast-Path implementation + static int writeUtf8(ByteBuf buffer, int writerIndex, char[] seq, int start, int end) { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq[i]; + if (c < 0x80) { + buffer.setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + final char c2; + if (seq.length > ++i) { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we + // will + // re-throw a more informative exception describing the problem. + c2 = seq[i]; + } else { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); + } else { + buffer.setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + buffer.writerIndex(writerIndex); + return writerIndex - oldWriterIndex; + } + + private static int writeUtf8Surrogate(ByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer.setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + + public static char[] readUtf8(ByteBuf byteBuf, int length) { + CharsetDecoder charsetDecoder = CharsetUtil.UTF_8.newDecoder(); + int en = (int) (length * (double) charsetDecoder.maxCharsPerByte()); + char[] ca = new char[en]; + + CharBuffer charBuffer = CharBuffer.wrap(ca); + ByteBuffer byteBuffer = + byteBuf.nioBufferCount() == 1 + ? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length) + : byteBuf.nioBuffer(byteBuf.readerIndex(), length); + byteBuffer.mark(); + try { + CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); + if (!cr.isUnderflow()) cr.throwException(); + cr = charsetDecoder.flush(charBuffer); + if (!cr.isUnderflow()) cr.throwException(); + + byteBuffer.reset(); + byteBuf.skipBytes(length); + + return safeTrim(charBuffer.array(), charBuffer.position()); + } catch (CharacterCodingException x) { + // Substitution is always enabled, + // so this shouldn't happen + throw new IllegalStateException("unable to decode char array from the given buffer", x); + } + } + + private static char[] safeTrim(char[] ca, int len) { + if (len == ca.length) return ca; + else return Arrays.copyOf(ca, len); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java b/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java deleted file mode 100644 index dd8bbf907..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/ConnectionUtils.java +++ /dev/null @@ -1,17 +0,0 @@ -package io.rsocket.util; - -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.internal.ClientServerInputMultiplexer; -import reactor.core.publisher.Mono; - -public class ConnectionUtils { - - public static Mono sendError( - ByteBufAllocator allocator, ClientServerInputMultiplexer multiplexer, Exception exception) { - return multiplexer - .asSetupConnection() - .sendOne(ErrorFrameFlyweight.encode(allocator, 0, exception)) - .onErrorResume(err -> Mono.empty()); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java index ec73399f1..58f282110 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -23,7 +23,7 @@ import java.nio.CharBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import javax.annotation.Nullable; +import reactor.util.annotation.Nullable; /** * An implementation of {@link Payload}. This implementation is not thread-safe, and hence diff --git a/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java b/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java deleted file mode 100644 index c87a08220..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/DisposableUtils.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.util; - -import java.util.Arrays; -import reactor.core.Disposable; - -/** Utilities for working with the {@link Disposable} type. */ -public final class DisposableUtils { - - private DisposableUtils() {} - - /** - * Calls the {@link Disposable#dispose()} method if the instance is not null. If any exceptions - * are thrown during disposal, suppress them. - * - * @param disposables the {@link Disposable}s to dispose - */ - public static void disposeQuietly(Disposable... disposables) { - Arrays.stream(disposables) - .forEach( - disposable -> { - try { - if (disposable != null) { - disposable.dispose(); - } - } catch (RuntimeException e) { - // Suppress any exceptions during disposal - } - }); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java b/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java deleted file mode 100644 index fa19553a7..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/DuplexConnectionProxy.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.util; - -import io.netty.buffer.ByteBuf; -import io.rsocket.DuplexConnection; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public class DuplexConnectionProxy implements DuplexConnection { - private final DuplexConnection connection; - - public DuplexConnectionProxy(DuplexConnection connection) { - this.connection = connection; - } - - @Override - public Mono send(Publisher frames) { - return connection.send(frames); - } - - @Override - public Flux receive() { - return connection.receive(); - } - - @Override - public double availability() { - return connection.availability(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - @Override - public void dispose() { - connection.dispose(); - } - - @Override - public boolean isDisposed() { - return connection.isDisposed(); - } - - public DuplexConnection delegate() { - return connection; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/MultiSubscriberRSocket.java b/rsocket-core/src/main/java/io/rsocket/util/MultiSubscriberRSocket.java deleted file mode 100644 index c2db6c238..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/MultiSubscriberRSocket.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2015-2019 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.util; - -import io.rsocket.Payload; -import io.rsocket.RSocket; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public class MultiSubscriberRSocket extends RSocketProxy { - public MultiSubscriberRSocket(RSocket source) { - super(source); - } - - @Override - public Mono fireAndForget(Payload payload) { - return Mono.defer(() -> super.fireAndForget(payload)); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.defer(() -> super.requestResponse(payload)); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.defer(() -> super.requestStream(payload)); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.defer(() -> super.requestChannel(payloads)); - } - - @Override - public Mono metadataPush(Payload payload) { - return Mono.defer(() -> super.metadataPush(payload)); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java b/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java deleted file mode 100644 index 30385195c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/RecyclerFactory.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.util; - -import io.netty.util.Recycler; -import io.netty.util.Recycler.Handle; -import java.util.Objects; -import java.util.function.Function; - -/** A factory for creating {@link Recycler}s. */ -public final class RecyclerFactory { - - /** - * Creates a new {@link Recycler}. - * - * @param newObjectCreator the {@link Function} to create a new object - * @param the type being recycled. - * @return the {@link Recycler} - * @throws NullPointerException if {@code newObjectCreator} is {@code null} - */ - public static Recycler createRecycler(Function, T> newObjectCreator) { - Objects.requireNonNull(newObjectCreator, "newObjectCreator must not be null"); - - return new Recycler() { - - @Override - protected T newObject(Handle handle) { - return newObjectCreator.apply(handle); - } - }; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/package-info.java b/rsocket-core/src/main/java/io/rsocket/util/package-info.java index 79123d3b2..2fac3327f 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/util/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,5 +14,8 @@ * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Shared utility classes and {@link io.rsocket.Payload} implementations. */ +@NonNullApi package io.rsocket.util; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java deleted file mode 100644 index a739f2e67..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTest.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static io.rsocket.frame.FrameType.*; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.*; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import io.rsocket.util.MultiSubscriberRSocket; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import org.assertj.core.api.Assertions; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; - -public class RSocketRequesterTest { - - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - - @Test(timeout = 2_000) - public void testInvalidFrameOnStream0() { - rule.connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 0, 10)); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(IllegalStateException.class))); - } - - @Test(timeout = 2_000) - public void testStreamInitialN() { - Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); - - BaseSubscriber subscriber = - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - // don't request here - // subscription.request(3); - } - }; - stream.subscribe(subscriber); - - subscriber.request(5); - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat("sent frame count", sent.size(), is(1)); - - ByteBuf f = sent.get(0); - - assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); - assertThat("initial request n", RequestStreamFrameFlyweight.initialRequestN(f), is(5)); - } - - @Test(timeout = 2_000) - public void testHandleSetupException() { - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("boom"))); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(RejectedSetupException.class))); - } - - @Test(timeout = 2_000) - public void testHandleApplicationException() { - rule.connection.clearSendReceiveBuffers(); - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(); - response.subscribe(responseSub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, new ApplicationErrorException("error"))); - - verify(responseSub).onError(any(ApplicationErrorException.class)); - } - - @Test(timeout = 2_000) - public void testHandleValidFrame() { - Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNext( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - - verify(sub).onComplete(); - } - - @Test(timeout = 2_000) - public void testRequestReplyWithCancel() { - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - - try { - response.block(Duration.ofMillis(100)); - } catch (IllegalStateException ise) { - } - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> frameType(f) != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat( - "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); - } - - @Test(timeout = 2_000) - public void testRequestReplyErrorOnSend() { - rule.connection.setAvailability(0); // Fails send - Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); - Subscriber responseSub = TestSubscriber.create(10); - response.subscribe(responseSub); - - this.rule.assertNoConnectionErrors(); - - verify(responseSub).onSubscribe(any(Subscription.class)); - - // TODO this should get the error reported through the response subscription - // verify(responseSub).onError(any(RuntimeException.class)); - } - - @Test(timeout = 2_000) - public void testLazyRequestResponse() { - Publisher response = - new MultiSubscriberRSocket(rule.socket).requestResponse(EmptyPayload.INSTANCE); - int streamId = sendRequestResponse(response); - rule.connection.clearSendReceiveBuffers(); - int streamId2 = sendRequestResponse(response); - assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); - } - - @Test - public void testChannelRequestCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - Flux request = Flux.never().doOnCancel(cancelled::onComplete); - rule.socket.requestChannel(request).subscribe().dispose(); - Flux.first( - cancelled, - Flux.error(new IllegalStateException("Channel request not cancelled")) - .delaySubscription(Duration.ofSeconds(1))) - .blockFirst(); - } - - @Test - public void testChannelRequestServerSideCancellation() { - MonoProcessor cancelled = MonoProcessor.create(); - UnicastProcessor request = UnicastProcessor.create(); - request.onNext(EmptyPayload.INSTANCE); - rule.socket.requestChannel(request).subscribe(cancelled); - int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeComplete(ByteBufAllocator.DEFAULT, streamId)); - Flux.first( - cancelled, - Flux.error(new IllegalStateException("Channel request not cancelled")) - .delaySubscription(Duration.ofSeconds(1))) - .blockFirst(); - - Assertions.assertThat(request.isDisposed()).isTrue(); - } - - public int sendRequestResponse(Publisher response) { - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - PayloadFrameFlyweight.encodeNextComplete( - ByteBufAllocator.DEFAULT, streamId, EmptyPayload.INSTANCE)); - verify(sub).onNext(any(Payload.class)); - verify(sub).onComplete(); - return streamId; - } - - public static class ClientSocketRule extends AbstractSocketRule { - @Override - protected RSocketRequester newRSocket() { - return new RSocketRequester( - ByteBufAllocator.DEFAULT, - connection, - DefaultPayload::create, - throwable -> errors.add(throwable), - StreamIdSupplier.clientSupplier(), - 0, - 0, - null, - RequesterLeaseHandler.None); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (ByteBuf frame : connection.getSent()) { - FrameType frameType = frameType(frame); - if (frameType == expectedFrameType) { - return FrameHeaderFlyweight.streamId(frame); - } - framesFound.add(frameType); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java deleted file mode 100644 index b6281414d..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketResponderTest.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.frameType; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.frame.*; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.Mono; - -public class RSocketResponderTest { - - @Rule public final ServerSocketRule rule = new ServerSocketRule(); - - @Test(timeout = 2000) - @Ignore - public void testHandleKeepAlive() throws Exception { - rule.connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER)); - ByteBuf sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); - /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - KeepAliveFrameFlyweight.respondFlag(sent), - is(false)); - } - - @Test(timeout = 2000) - @Ignore - public void testHandleResponseFrameNoError() throws Exception { - final int streamId = 4; - rule.connection.clearSendReceiveBuffers(); - - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - Collection> sendSubscribers = rule.connection.getSendSubscribers(); - assertThat("Request not sent.", sendSubscribers, hasSize(1)); - assertThat("Unexpected error.", rule.errors, is(empty())); - Subscriber sendSub = sendSubscribers.iterator().next(); - assertThat( - "Unexpected frame sent.", - frameType(rule.connection.awaitSend()), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); - } - - @Test(timeout = 2000) - @Ignore - public void testHandlerEmitsError() throws Exception { - final int streamId = 4; - rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat( - "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); - } - - @Test(timeout = 2_0000) - public void testCancel() { - final int streamId = 4; - final AtomicBoolean cancelled = new AtomicBoolean(); - rule.setAcceptingSocket( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.never().doOnCancel(() -> cancelled.set(true)); - } - }); - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - - rule.connection.addToReceivedBuffer( - CancelFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId)); - - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); - } - - public static class ServerSocketRule extends AbstractSocketRule { - - private RSocket acceptingSocket; - - @Override - protected void init() { - acceptingSocket = - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - }; - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connection.setInitialSendRequestN(prefetch); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - @Override - protected RSocketResponder newRSocket() { - return new RSocketResponder( - ByteBufAllocator.DEFAULT, - connection, - acceptingSocket, - DefaultPayload::create, - throwable -> errors.add(throwable), - ResponderLeaseHandler.None); - } - - private void sendRequest(int streamId, FrameType frameType) { - ByteBuf request; - - switch (frameType) { - case REQUEST_CHANNEL: - request = - RequestChannelFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, false, 1, EmptyPayload.INSTANCE); - break; - case REQUEST_STREAM: - request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, 1, EmptyPayload.INSTANCE); - break; - case REQUEST_RESPONSE: - request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, streamId, false, EmptyPayload.INSTANCE); - break; - default: - throw new IllegalArgumentException("unsupported type: " + frameType); - } - - connection.addToReceivedBuffer(request); - connection.addToReceivedBuffer( - RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, streamId, 2)); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java deleted file mode 100644 index 5d9672fb9..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.lease.RequesterLeaseHandler; -import io.rsocket.lease.ResponderLeaseHandler; -import io.rsocket.test.util.LocalDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.DefaultPayload; -import io.rsocket.util.EmptyPayload; -import java.util.ArrayList; -import org.hamcrest.MatcherAssert; -import org.junit.Assert; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -public class RSocketTest { - - @Rule public final SocketRule rule = new SocketRule(); - - public static void assertError(String s, String mode, ArrayList errors) { - for (Throwable t : errors) { - if (t.toString().equals(s)) { - return; - } - } - - Assert.fail("Expected " + mode + " connection error: " + s + " other errors " + errors.size()); - } - - @Test(timeout = 2_000) - public void testRequestReplyNoError() { - StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) - .expectNextCount(1) - .expectComplete() - .verify(); - } - - @Test(timeout = 2000) - public void testHandlerEmitsError() { - rule.setRequestAcceptor( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.error(new NullPointerException("Deliberate exception.")); - } - }); - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(EmptyPayload.INSTANCE).subscribe(subscriber); - verify(subscriber).onError(any(ApplicationErrorException.class)); - - // Client sees error through normal API - rule.assertNoClientErrors(); - - rule.assertServerError("java.lang.NullPointerException: Deliberate exception."); - } - - @Test(timeout = 2000) - public void testStream() throws Exception { - Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); - StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); - } - - @Test(timeout = 2000) - public void testChannel() throws Exception { - Flux requests = - Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); - Flux responses = rule.crs.requestChannel(requests); - StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); - } - - public static class SocketRule extends ExternalResource { - - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; - private RSocketRequester crs; - - @SuppressWarnings("unused") - private RSocketResponder srs; - - private RSocket requestAcceptor; - private ArrayList clientErrors = new ArrayList<>(); - private ArrayList serverErrors = new ArrayList<>(); - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } - - protected void init() { - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); - - LocalDuplexConnection serverConnection = - new LocalDuplexConnection("server", clientProcessor, serverProcessor); - LocalDuplexConnection clientConnection = - new LocalDuplexConnection("client", serverProcessor, clientProcessor); - - requestAcceptor = - null != requestAcceptor - ? requestAcceptor - : new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10) - .map( - i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); - } - - @Override - public Flux requestChannel(Publisher payloads) { - Flux.from(payloads) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")) - .subscribe(); - - return Flux.range(1, 10) - .map( - payload -> - DefaultPayload.create("server got -> [" + payload.toString() + "]")); - } - }; - - srs = - new RSocketResponder( - ByteBufAllocator.DEFAULT, - serverConnection, - requestAcceptor, - DefaultPayload::create, - throwable -> serverErrors.add(throwable), - ResponderLeaseHandler.None); - - crs = - new RSocketRequester( - ByteBufAllocator.DEFAULT, - clientConnection, - DefaultPayload::create, - throwable -> clientErrors.add(throwable), - StreamIdSupplier.clientSupplier(), - 0, - 0, - null, - RequesterLeaseHandler.None); - } - - public void setRequestAcceptor(RSocket requestAcceptor) { - this.requestAcceptor = requestAcceptor; - init(); - } - - public void assertNoErrors() { - assertNoClientErrors(); - assertNoServerErrors(); - } - - public void assertNoClientErrors() { - MatcherAssert.assertThat( - "Unexpected error on the client connection.", clientErrors, is(empty())); - } - - public void assertNoServerErrors() { - MatcherAssert.assertThat( - "Unexpected error on the server connection.", serverErrors, is(empty())); - } - - public void assertClientError(String s) { - assertError(s, "client", this.clientErrors); - } - - public void assertServerError(String s) { - assertError(s, "server", this.serverErrors); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/TestScheduler.java b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java new file mode 100644 index 000000000..7bc98d45d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java @@ -0,0 +1,80 @@ +package io.rsocket; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.scheduler.Scheduler; +import reactor.util.concurrent.Queues; + +/** + * This is an implementation of scheduler which allows task execution on the caller thread or + * scheduling it for thread which are currently working (with "work stealing" behaviour) + */ +public final class TestScheduler implements Scheduler { + + public static final Scheduler INSTANCE = new TestScheduler(); + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(TestScheduler.class, "wip"); + + final Worker sharedWorker = new TestWorker(this); + final Queue tasks = Queues.unboundedMultiproducer().get(); + + private TestScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + tasks.offer(task); + if (WIP.getAndIncrement(this) != 0) { + return Disposables.never(); + } + + int missed = 1; + + for (; ; ) { + for (; ; ) { + Runnable runnable = tasks.poll(); + + if (runnable == null) { + break; + } + + try { + runnable.run(); + } catch (Throwable t) { + Exceptions.throwIfFatal(t); + } + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return Disposables.never(); + } + } + } + + @Override + public Worker createWorker() { + return sharedWorker; + } + + static class TestWorker implements Worker { + + final TestScheduler parent; + + TestWorker(TestScheduler parent) { + this.parent = parent; + } + + @Override + public Disposable schedule(Runnable task) { + return parent.schedule(task); + } + + @Override + public void dispose() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..800e5d678 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,153 @@ +package io.rsocket.buffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + private LeaksTrackingByteBufAllocator(ByteBufAllocator delegate) { + this.delegate = delegate; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + Assertions.assertThat(tracker) + .allSatisfy( + buf -> + Assertions.assertThat(buf) + .matches(bb -> bb.refCnt() == 0, "buffer should be released")); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/Tuple3ByteBufTest.java b/rsocket-core/src/test/java/io/rsocket/buffer/Tuple3ByteBufTest.java deleted file mode 100644 index 4515fb29b..000000000 --- a/rsocket-core/src/test/java/io/rsocket/buffer/Tuple3ByteBufTest.java +++ /dev/null @@ -1,98 +0,0 @@ -package io.rsocket.buffer; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -class Tuple3ByteBufTest { - @Test - void testTupleBufferGet() { - ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - ByteBuf one = allocator.directBuffer(9); - - byte[] bytes = new byte[9]; - ThreadLocalRandom.current().nextBytes(bytes); - one.writeBytes(bytes); - - bytes = new byte[8]; - ThreadLocalRandom.current().nextBytes(bytes); - ByteBuf two = Unpooled.wrappedBuffer(bytes); - - bytes = new byte[9]; - ThreadLocalRandom.current().nextBytes(bytes); - ByteBuf three = Unpooled.wrappedBuffer(bytes); - - ByteBuf tuple = TupleByteBuf.of(one, two, three); - - int anInt = tuple.getInt(16); - - long aLong = tuple.getLong(15); - - short aShort = tuple.getShort(8); - - int medium = tuple.getMedium(8); - } - - @Test - void testTuple3BufferSlicing() { - ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - ByteBuf one = allocator.directBuffer(); - ByteBufUtil.writeUtf8(one, "foo"); - - ByteBuf two = allocator.directBuffer(); - ByteBufUtil.writeUtf8(two, "bar"); - - ByteBuf three = allocator.directBuffer(); - ByteBufUtil.writeUtf8(three, "bar"); - - ByteBuf buf = TupleByteBuf.of(one, two, three); - - String s = buf.slice(0, 6).toString(Charset.defaultCharset()); - Assert.assertEquals("foobar", s); - - String s1 = buf.slice(3, 6).toString(Charset.defaultCharset()); - Assert.assertEquals("barbar", s1); - - String s2 = buf.slice(4, 4).toString(Charset.defaultCharset()); - Assert.assertEquals("arba", s2); - } - - @Test - void testTuple3ToNioBuffers() throws Exception { - ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; - ByteBuf one = allocator.directBuffer(); - ByteBufUtil.writeUtf8(one, "one"); - - ByteBuf two = allocator.directBuffer(); - ByteBufUtil.writeUtf8(two, "two"); - - ByteBuf three = allocator.directBuffer(); - ByteBufUtil.writeUtf8(three, "three"); - - ByteBuf buf = TupleByteBuf.of(one, two, three); - ByteBuffer[] byteBuffers = buf.nioBuffers(); - - Assert.assertEquals(3, byteBuffers.length); - - ByteBuffer bb = byteBuffers[0]; - byte[] dst = new byte[bb.remaining()]; - bb.get(dst); - Assert.assertEquals("one", new String(dst, "UTF-8")); - - bb = byteBuffers[1]; - dst = new byte[bb.remaining()]; - bb.get(dst); - Assert.assertEquals("two", new String(dst, "UTF-8")); - - bb = byteBuffers[2]; - dst = new byte[bb.remaining()]; - bb.get(dst); - Assert.assertEquals("three", new String(dst, "UTF-8")); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java similarity index 75% rename from rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java rename to rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java index 22568bfcc..ac5832aaa 100644 --- a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -14,12 +14,13 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestSubscriber; -import java.util.concurrent.ConcurrentLinkedQueue; -import org.junit.Assert; import org.junit.rules.ExternalResource; import org.junit.runner.Description; import org.junit.runners.model.Statement; @@ -30,16 +31,16 @@ public abstract class AbstractSocketRule extends ExternalReso protected TestDuplexConnection connection; protected Subscriber connectSub; protected T socket; - protected ConcurrentLinkedQueue errors; + protected LeaksTrackingByteBufAllocator allocator; @Override public Statement apply(final Statement base, Description description) { return new Statement() { @Override public void evaluate() throws Throwable { - connection = new TestDuplexConnection(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); init(); base.evaluate(); } @@ -52,9 +53,11 @@ protected void init() { protected abstract T newRSocket(); - public void assertNoConnectionErrors() { - if (errors.size() > 1) { - Assert.fail("No connection errors expected: " + errors.peek().toString()); - } + public ByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java similarity index 82% rename from rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java rename to rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java index 16e0f2ec7..8eb5dee09 100644 --- a/rsocket-core/src/test/java/io/rsocket/ConnectionSetupPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java @@ -1,4 +1,4 @@ -package io.rsocket; +package io.rsocket.core; import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -6,7 +6,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.frame.SetupFrameCodec; import io.rsocket.util.DefaultPayload; import org.junit.jupiter.api.Test; @@ -24,13 +26,13 @@ void testSetupPayloadWithDataMetadata() { boolean leaseEnabled = true; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertTrue(setupPayload.willClientHonorLease()); assertEquals(KEEP_ALIVE_INTERVAL, setupPayload.keepAliveInterval()); assertEquals(KEEP_ALIVE_MAX_LIFETIME, setupPayload.keepAliveMaxLifetime()); - assertEquals(METADATA_TYPE, SetupFrameFlyweight.metadataMimeType(frame)); - assertEquals(DATA_TYPE, SetupFrameFlyweight.dataMimeType(frame)); + assertEquals(METADATA_TYPE, SetupFrameCodec.metadataMimeType(frame)); + assertEquals(DATA_TYPE, SetupFrameCodec.dataMimeType(frame)); assertTrue(setupPayload.hasMetadata()); assertNotNull(setupPayload.metadata()); assertEquals(payload.metadata(), setupPayload.metadata()); @@ -46,7 +48,7 @@ void testSetupPayloadWithNoMetadata() { boolean leaseEnabled = false; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertFalse(setupPayload.willClientHonorLease()); assertFalse(setupPayload.hasMetadata()); @@ -64,7 +66,7 @@ void testSetupPayloadWithEmptyMetadata() { boolean leaseEnabled = false; ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(frame); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); assertFalse(setupPayload.willClientHonorLease()); assertTrue(setupPayload.hasMetadata()); @@ -75,12 +77,12 @@ void testSetupPayloadWithEmptyMetadata() { } private static ByteBuf encodeSetupFrame(boolean leaseEnabled, Payload setupPayload) { - return SetupFrameFlyweight.encode( + return SetupFrameCodec.encode( ByteBufAllocator.DEFAULT, leaseEnabled, KEEP_ALIVE_INTERVAL, KEEP_ALIVE_MAX_LIFETIME, - null, + Unpooled.EMPTY_BUFFER, METADATA_TYPE, DATA_TYPE, setupPayload); diff --git a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java similarity index 78% rename from rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java rename to rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java index b275ccc33..d98f86113 100644 --- a/rsocket-core/src/test/java/io/rsocket/KeepAliveTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.keepalive.KeepAliveHandler.DefaultKeepAliveHandler; import static io.rsocket.keepalive.KeepAliveHandler.ResumableKeepAliveHandler; @@ -22,19 +22,19 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.ConnectionErrorException; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.KeepAliveFrameFlyweight; +import io.rsocket.frame.KeepAliveFrameCodec; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.resume.InMemoryResumableFramesStore; import io.rsocket.resume.ResumableDuplexConnection; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -51,24 +51,27 @@ public class KeepAliveTest { private ResumableRSocketState resumableRequesterState; static RSocketState requester(int tickPeriod, int timeout) { - TestDuplexConnection connection = new TestDuplexConnection(); - Errors errors = new Errors(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, connection, DefaultPayload::create, - errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new DefaultKeepAliveHandler(connection), - RequesterLeaseHandler.None); - return new RSocketState(rSocket, errors, connection); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + return new RSocketState(rSocket, allocator, connection); } static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { - TestDuplexConnection connection = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); ResumableDuplexConnection resumableConnection = new ResumableDuplexConnection( "test", @@ -77,19 +80,18 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { Duration.ofSeconds(10), false); - Errors errors = new Errors(); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, resumableConnection, DefaultPayload::create, - errors, StreamIdSupplier.clientSupplier(), + 0, tickPeriod, timeout, new ResumableKeepAliveHandler(resumableConnection), - RequesterLeaseHandler.None); - return new ResumableRSocketState(rSocket, errors, connection, resumableConnection); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + return new ResumableRSocketState(rSocket, connection, resumableConnection, allocator); } @BeforeEach @@ -106,16 +108,14 @@ void rSocketNotDisposedOnPresentKeepAlives() { .subscribe( n -> connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); Mono.delay(Duration.ofMillis(2000)).block(); RSocket rSocket = requesterState.rSocket(); - List errors = requesterState.errors().errors(); Assertions.assertThat(rSocket.isDisposed()).isFalse(); - Assertions.assertThat(errors).isEmpty(); } @Test @@ -134,11 +134,12 @@ void rSocketDisposedOnMissingKeepAlives() { Mono.delay(Duration.ofMillis(2000)).block(); - List errors = requesterState.errors().errors(); Assertions.assertThat(rSocket.isDisposed()).isTrue(); - Assertions.assertThat(errors).hasSize(1); - Throwable throwable = errors.get(0); - Assertions.assertThat(throwable).isInstanceOf(ConnectionErrorException.class); + rSocket + .onClose() + .as(StepVerifier::create) + .expectError(ConnectionErrorException.class) + .verify(Duration.ofMillis(100)); } @Test @@ -162,7 +163,7 @@ void requesterRespondsToKeepAlives() { .subscribe( l -> connection.addToReceivedBuffer( - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( ByteBufAllocator.DEFAULT, true, 0, Unpooled.EMPTY_BUFFER))); StepVerifier.create(Flux.from(connection.getSentAsPublisher()).take(1)) @@ -191,7 +192,7 @@ void resumableRequesterKeepAlivesAfterReconnect() { resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); resumableDuplexConnection.disconnect(); - TestDuplexConnection newTestConnection = new TestDuplexConnection(); + TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); resumableDuplexConnection.reconnect(newTestConnection); resumableDuplexConnection.resume(0, 0, ignored -> Mono.empty()); @@ -215,37 +216,36 @@ void resumableRequesterNoKeepAlivesAfterDispose() { @Test void resumableRSocketsNotDisposedOnMissingKeepAlives() { RSocket rSocket = resumableRequesterState.rSocket(); - List errors = resumableRequesterState.errors().errors(); TestDuplexConnection connection = resumableRequesterState.connection(); Mono.delay(Duration.ofMillis(500)).block(); Assertions.assertThat(rSocket.isDisposed()).isFalse(); - Assertions.assertThat(errors).hasSize(0); Assertions.assertThat(connection.isDisposed()).isTrue(); } private boolean keepAliveFrame(ByteBuf frame) { - return FrameHeaderFlyweight.frameType(frame) == FrameType.KEEPALIVE; + return FrameHeaderCodec.frameType(frame) == FrameType.KEEPALIVE; } private boolean keepAliveFrameWithRespondFlag(ByteBuf frame) { - return keepAliveFrame(frame) && KeepAliveFrameFlyweight.respondFlag(frame); + return keepAliveFrame(frame) && KeepAliveFrameCodec.respondFlag(frame); } private boolean keepAliveFrameWithoutRespondFlag(ByteBuf frame) { - return keepAliveFrame(frame) && !KeepAliveFrameFlyweight.respondFlag(frame); + return keepAliveFrame(frame) && !KeepAliveFrameCodec.respondFlag(frame); } static class RSocketState { private final RSocket rSocket; - private final Errors errors; private final TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator; - public RSocketState(RSocket rSocket, Errors errors, TestDuplexConnection connection) { + public RSocketState( + RSocket rSocket, LeaksTrackingByteBufAllocator allocator, TestDuplexConnection connection) { this.rSocket = rSocket; - this.errors = errors; this.connection = connection; + this.allocator = allocator; } public TestDuplexConnection connection() { @@ -256,26 +256,26 @@ public RSocket rSocket() { return rSocket; } - public Errors errors() { - return errors; + public LeaksTrackingByteBufAllocator alloc() { + return allocator; } } static class ResumableRSocketState { private final RSocket rSocket; - private final Errors errors; private final TestDuplexConnection connection; private final ResumableDuplexConnection resumableDuplexConnection; + private final LeaksTrackingByteBufAllocator allocator; public ResumableRSocketState( RSocket rSocket, - Errors errors, TestDuplexConnection connection, - ResumableDuplexConnection resumableDuplexConnection) { + ResumableDuplexConnection resumableDuplexConnection, + LeaksTrackingByteBufAllocator allocator) { this.rSocket = rSocket; - this.errors = errors; this.connection = connection; this.resumableDuplexConnection = resumableDuplexConnection; + this.allocator = allocator; } public TestDuplexConnection connection() { @@ -290,21 +290,8 @@ public RSocket rSocket() { return rSocket; } - public Errors errors() { - return errors; - } - } - - static class Errors implements Consumer { - private final List errors = new ArrayList<>(); - - @Override - public void accept(Throwable throwable) { - errors.add(throwable); - } - - public List errors() { - return new ArrayList<>(errors); + public LeaksTrackingByteBufAllocator alloc() { + return allocator; } } } diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..ed9f1ec4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,99 @@ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK / 2]; + byte[] data = + new byte + [FrameLengthCodec.FRAME_LENGTH_MASK / 2 + - FrameLengthCodec.FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, payload)).isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java similarity index 84% rename from rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java index 2a2567843..ab336b8cd 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketLeaseTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,27 +14,32 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.frame.FrameType.ERROR; import static io.rsocket.frame.FrameType.SETUP; import static org.assertj.core.data.Offset.offset; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; -import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; -import io.rsocket.exceptions.MissingLeaseException; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.LeaseFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.internal.ClientServerInputMultiplexer; import io.rsocket.lease.*; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.lease.MissingLeaseException; +import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestClientTransport; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.test.util.TestServerTransport; @@ -73,28 +78,28 @@ class RSocketLeaseTest { @BeforeEach void setUp() { - connection = new TestDuplexConnection(); PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; - byteBufAllocator = UnpooledByteBufAllocator.DEFAULT; + byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(byteBufAllocator); requesterLeaseHandler = new RequesterLeaseHandler.Impl(TAG, leases -> leaseReceiver = leases); responderLeaseHandler = new ResponderLeaseHandler.Impl<>( - TAG, byteBufAllocator, stats -> leaseSender, err -> {}, Optional.empty()); + TAG, byteBufAllocator, stats -> leaseSender, Optional.empty()); ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, new PluginRegistry(), true); + new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); rSocketRequester = new RSocketRequester( - byteBufAllocator, multiplexer.asClientConnection(), payloadDecoder, - err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, - requesterLeaseHandler); + requesterLeaseHandler, + TestScheduler.INSTANCE); RSocket mockRSocketHandler = mock(RSocket.class); when(mockRSocketHandler.metadataPush(any())).thenReturn(Mono.empty()); @@ -105,19 +110,18 @@ void setUp() { rSocketResponder = new RSocketResponder( - byteBufAllocator, multiplexer.asServerConnection(), mockRSocketHandler, payloadDecoder, - err -> {}, - responderLeaseHandler); + responderLeaseHandler, + 0); } @Test public void serverRSocketFactoryRejectsUnsupportedLease() { Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); ByteBuf setupFrame = - SetupFrameFlyweight.encode( + SetupFrameCodec.encode( ByteBufAllocator.DEFAULT, true, 1000, @@ -127,12 +131,7 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { payload); TestServerTransport transport = new TestServerTransport(); - Closeable server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new AbstractRSocket() {})) - .transport(transport) - .start() - .block(); + RSocketServer.create().bind(transport).block(); TestDuplexConnection connection = transport.connect(); connection.addToReceivedBuffer(setupFrame); @@ -140,20 +139,21 @@ public void serverRSocketFactoryRejectsUnsupportedLease() { Collection sent = connection.getSent(); Assertions.assertThat(sent).hasSize(1); ByteBuf error = sent.iterator().next(); - Assertions.assertThat(FrameHeaderFlyweight.frameType(error)).isEqualTo(ERROR); - Assertions.assertThat(Exceptions.from(error).getMessage()).isEqualTo("lease is not supported"); + Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); + Assertions.assertThat(Exceptions.from(0, error).getMessage()) + .isEqualTo("lease is not supported"); } @Test public void clientRSocketFactorySetsLeaseFlag() { TestClientTransport clientTransport = new TestClientTransport(); - RSocketFactory.connect().lease().transport(clientTransport).start().block(); + RSocketConnector.create().lease(Leases::new).connect(clientTransport).block(); Collection sent = clientTransport.testConnection().getSent(); Assertions.assertThat(sent).hasSize(1); ByteBuf setup = sent.iterator().next(); - Assertions.assertThat(FrameHeaderFlyweight.frameType(setup)).isEqualTo(SETUP); - Assertions.assertThat(SetupFrameFlyweight.honorLease(setup)).isTrue(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); } @ParameterizedTest @@ -276,13 +276,13 @@ void sendLease() { connection .getSent() .stream() - .filter(f -> FrameHeaderFlyweight.frameType(f) == FrameType.LEASE) + .filter(f -> FrameHeaderCodec.frameType(f) == FrameType.LEASE) .findFirst() .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); - Assertions.assertThat(LeaseFrameFlyweight.ttl(leaseFrame)).isEqualTo(ttl); - Assertions.assertThat(LeaseFrameFlyweight.numRequests(leaseFrame)).isEqualTo(numberOfRequests); - Assertions.assertThat(LeaseFrameFlyweight.metadata(leaseFrame).toString(utf8)) + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) .isEqualTo(metadataContent); } @@ -310,7 +310,7 @@ void receiveLease() { } ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { - return LeaseFrameFlyweight.encode(byteBufAllocator, ttl, requests, metadata); + return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); } static Stream>> interactions() { diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java new file mode 100644 index 000000000..dc76b5450 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import io.rsocket.RSocket; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.Iterator; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Exceptions; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketReconnectTest { + + private Queue retries = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldBeASharedReconnectableInstanceOfRSocketMono() { + TestClientTransport[] testClientTransport = + new TestClientTransport[] {new TestClientTransport()}; + Mono rSocketMono = + RSocketConnector.create() + .reconnect(Retry.indefinitely()) + .connect(() -> testClientTransport[0]); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + testClientTransport[0] = new TestClientTransport(); + + RSocket rSocket3 = rSocketMono.block(); + RSocket rSocket4 = rSocketMono.block(); + + Assertions.assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isEqualTo(rSocket2); + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + @SuppressWarnings({"rawtype", "unchecked"}) + public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + Mockito.when(transport.connect(0)) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(new TestClientTransport().connect(0)); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + } + + @Test + public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { + + Mono rSocketMono = RSocketConnector.connectWith(new TestClientTransport()); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + Assertions.assertThat(rSocket1).isNotEqualTo(rSocket2); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + Consumer onRetry() { + return context -> retries.add(context); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java similarity index 57% rename from rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java index b49dbe809..4cd3a3a26 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterSubscribersTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -14,18 +14,21 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; -import io.rsocket.util.MultiSubscriberRSocket; -import java.time.Duration; import java.util.Arrays; import java.util.Collection; import java.util.HashSet; @@ -38,8 +41,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; class RSocketRequesterSubscribersTest { @@ -52,46 +54,71 @@ class RSocketRequesterSubscribersTest { FrameType.REQUEST_STREAM, FrameType.REQUEST_CHANNEL)); + private LeaksTrackingByteBufAllocator allocator; private RSocket rSocketRequester; private TestDuplexConnection connection; @BeforeEach void setUp() { - connection = new TestDuplexConnection(); + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); rSocketRequester = new RSocketRequester( - ByteBufAllocator.DEFAULT, connection, PayloadDecoder.DEFAULT, - err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); } @ParameterizedTest @MethodSource("allInteractions") - void multiSubscriber(Function> interaction) { - RSocket multiSubsRSocket = new MultiSubscriberRSocket(rSocketRequester); - Flux response = Flux.from(interaction.apply(multiSubsRSocket)).take(Duration.ofMillis(10)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); + void singleSubscriber(Function> interaction) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + response.subscribe(assertSubscriberA); + response.subscribe(assertSubscriberB); + + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); - Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(2); + Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); } @ParameterizedTest @MethodSource("allInteractions") - void singleSubscriber(Function> interaction) { - Flux response = Flux.from(interaction.apply(rSocketRequester)).take(Duration.ofMillis(10)); - StepVerifier.create(response).expectComplete().verify(Duration.ofSeconds(5)); - StepVerifier.create(response) - .expectError(IllegalStateException.class) - .verify(Duration.ofSeconds(5)); + void singleSubscriberInCaseOfRacing(Function> interaction) { + for (int i = 1; i < 20000; i += 2) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); - Assertions.assertThat(requestFramesCount(connection.getSent())).isEqualTo(1); + RaceTestUtils.race( + () -> response.subscribe(assertSubscriberA), () -> response.subscribe(assertSubscriberB)); + + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) + .anySatisfy(as -> as.assertError(IllegalStateException.class)); + + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))); + connection.clearSendReceiveBuffers(); + } } @ParameterizedTest @@ -105,7 +132,7 @@ void singleSubscriberInteractionsAreLazy(Function> interac static long requestFramesCount(Collection frames) { return frames .stream() - .filter(frame -> REQUEST_TYPES.contains(FrameHeaderFlyweight.frameType(frame))) + .filter(frame -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(frame))) .count(); } @@ -114,7 +141,7 @@ static Stream>> allInteractions() { rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), rSocket -> rSocket.requestStream(DefaultPayload.create("test")), - rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), - rSocket -> rSocket.metadataPush(DefaultPayload.create("test"))); + // rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + rSocket -> rSocket.metadataPush(DefaultPayload.create("", "test"))); } } diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java similarity index 93% rename from rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java rename to rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java index a2c17cf95..de6f86c57 100644 --- a/rsocket-core/src/test/java/io/rsocket/RSocketRequesterTerminationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -1,6 +1,8 @@ -package io.rsocket; +package io.rsocket.core; -import io.rsocket.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; import io.rsocket.util.EmptyPayload; import java.nio.channels.ClosedChannelException; import java.time.Duration; diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java new file mode 100644 index 000000000..1ba75f75a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -0,0 +1,1031 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.UnicastProcessor; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketRequesterTest { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { + rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); + Assertions.assertThat(rule.socket.isDisposed()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testStreamInitialN() { + Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); + + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + // don't request here + } + }; + stream.subscribe(subscriber); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + subscriber.request(5); + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat("sent frame count", sent.size(), is(1)); + + ByteBuf f = sent.get(0); + + assertThat("initial frame", frameType(f), is(REQUEST_STREAM)); + assertThat("initial request n", RequestStreamFrameCodec.initialRequestN(f), is(5L)); + assertThat("should be released", f.release(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleSetupException() { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); + Assertions.assertThatThrownBy(() -> rule.socket.onClose().block()) + .isInstanceOf(RejectedSetupException.class); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleApplicationException() { + rule.connection.clearSendReceiveBuffers(); + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(); + response.subscribe(responseSub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); + + verify(responseSub).onError(any(ApplicationErrorException.class)); + + Assertions.assertThat(rule.connection.getSent()) + // requestResponseFrame + .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleValidFrame() { + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + + verify(sub).onComplete(); + Assertions.assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyWithCancel() { + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + + try { + response.block(Duration.ofMillis(100)); + } catch (IllegalStateException ise) { + } + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat( + "Unexpected frame sent on the connection.", frameType(sent.get(0)), is(REQUEST_RESPONSE)); + assertThat("Unexpected frame sent on the connection.", frameType(sent.get(1)), is(CANCEL)); + Assertions.assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("invalid") + @Timeout(2_000) + public void testRequestReplyErrorOnSend() { + rule.connection.setAvailability(0); // Fails send + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(10); + response.subscribe(responseSub); + + this.rule + .socket + .onClose() + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofMillis(100)); + + verify(responseSub).onSubscribe(any(Subscription.class)); + + rule.assertHasNoLeaks(); + // TODO this should get the error reported through the response subscription + // verify(responseSub).onError(any(RuntimeException.class)); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = Flux.never().doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation2() { + MonoProcessor cancelled = MonoProcessor.create(); + Flux request = + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::onComplete); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testChannelRequestServerSideCancellation() { + MonoProcessor cancelled = MonoProcessor.create(); + UnicastProcessor request = UnicastProcessor.create(); + request.onNext(EmptyPayload.INSTANCE); + rule.socket.requestChannel(request).subscribe(cancelled); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); + Flux.first( + cancelled, + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + + Assertions.assertThat(request.isDisposed()).isTrue(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testCorrectFrameOrder() { + MonoProcessor delayer = MonoProcessor.create(); + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) {} + }; + rule.socket + .requestChannel( + Flux.concat(Flux.just(0).delayUntil(i -> delayer), Flux.range(1, 999)) + .map(i -> DefaultPayload.create(i + ""))) + .subscribe(subscriber); + + subscriber.request(1); + subscriber.request(Long.MAX_VALUE); + delayer.onComplete(); + + Iterator iterator = rule.connection.getSent().iterator(); + + ByteBuf initialFrame = iterator.next(); + + Assertions.assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + Assertions.assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)) + .isEqualTo(Long.MAX_VALUE); + Assertions.assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + .isEqualTo("0"); + Assertions.assertThat(initialFrame.release()).isTrue(); + + Assertions.assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + rule.assertHasNoLeaks(); + }); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + + @Test + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() { + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata)))) + .expectSubscription() + .then( + () -> + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(INVALID_PAYLOAD_ERROR_MESSAGE)) + .verify(); + Assertions.assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("racingCases") + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < 10000; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + try { + clientSocketRule + .apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } catch (Throwable throwable) { + throwable.printStackTrace(); + } + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + Assertions.assertThat(clientSocketRule.connection.getSent()) + .allMatch(ReferenceCounted::release); + + clientSocketRule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + // + // Assertions.assertThat(rule.connection.getSent()).hasSize(2); + Assertions.assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameCodec.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameCodec.encode(allocator, streamId, new RuntimeException("test")); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestResponse(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next( + ByteBufPayload.create("d", "m"), + ByteBufPayload.create("d1", "m1"), + ByteBufPayload.create("d2", "m2")); + + assertSubscriber.cancel(); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + TestPublisher testPublisher = TestPublisher.create(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.next(ByteBufPayload.create("d", "m")); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.next(ByteBufPayload.create("d1", "m1"), ByteBufPayload.create("d2", "m2")); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(responsesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + Publisher response; + + switch (frameType) { + case REQUEST_FNF: + response = + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p).then(Mono.empty())); + break; + case REQUEST_RESPONSE: + response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); + break; + case REQUEST_STREAM: + response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p)); + break; + case REQUEST_CHANNEL: + response = rule.socket.requestChannel(testPublisher.flux()); + break; + default: + throw new UnsupportedOperationException("illegal case"); + } + + response.subscribe(assertSubscriber); + testPublisher.next(ByteBufPayload.create("d")); + + int streamId = rule.getStreamIdForRequestType(frameType); + + if (responsesCnt > 0) { + for (int i = 0; i < responsesCnt - 1; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes()))); + } + + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + true, + true, + null, + Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes()))); + } + + if (framesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, streamId, framesCnt)); + } + + for (int i = 1; i < framesCnt; i++) { + testPublisher.next(ByteBufPayload.create("d" + i)); + } + + Assertions.assertThat(rule.connection.getSent()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) + .hasSize(framesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) + .allMatch(ByteBuf::release); + + Assertions.assertThat(assertSubscriber.isTerminated()) + .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) + .isTrue(); + + Assertions.assertThat(assertSubscriber.values()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames received", + frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(p -> p.release()); + + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload( + BiFunction> sourceProducer) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + + Publisher source = sourceProducer.apply(invalidPayload, rule.socket); + + StepVerifier.create(source, 0) + .expectError(IllegalReferenceCountException.class) + .verify(Duration.ofMillis(100)); + } + + private static Stream>> refCntCases() { + return Stream.of( + (p, r) -> r.fireAndForget(p), + (p, r) -> r.requestResponse(p), + (p, r) -> r.requestStream(p), + (p, r) -> r.requestChannel(Mono.just(p)), + (p, r) -> + r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1)))); + } + + @Test + public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { + Payload payload1 = ByteBufPayload.create("abc1"); + Mono fnf1 = rule.socket.fireAndForget(payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Mono fnf2 = rule.socket.fireAndForget(payload2); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + // checks that fnf2 should have id 1 even though it was generated later than fnf1 + AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber2.assertTerminated().assertNoError(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 1) + // ensures that this is fnf1 with abc2 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes()))) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + // checks that fnf1 should have id 3 even though it was generated earlier + AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber1.assertTerminated().assertNoError(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 3) + // ensures that this is fnf1 with abc1 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes()))) + .matches(ReferenceCounted::release); + } + + @ParameterizedTest + @MethodSource("requestNInteractions") + public void ensuresThatNoOpsMustHappenUntilFirstRequestN( + FrameType frameType, BiFunction> interaction) { + Payload payload1 = ByteBufPayload.create("abc1"); + Publisher interaction1 = interaction.apply(rule, payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Publisher interaction2 = interaction.apply(rule, payload2); + + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); + interaction1.subscribe(assertSubscriber1); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(0); + interaction2.subscribe(assertSubscriber2); + assertSubscriber1.assertNotTerminated().assertNoError(); + assertSubscriber2.assertNotTerminated().assertNoError(); + // even though we subscribed, nothing should happen until the first requestN + Assertions.assertThat(rule.connection.getSent()).isEmpty(); + + // first request on the second interaction to ensure that stream id issuing on the first request + assertSubscriber2.request(1); + + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + assertSubscriber1.request(1); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + } + + private static Stream requestNInteractions() { + return Stream.of( + Arguments.of( + REQUEST_RESPONSE, + (BiFunction>) + (rule, payload) -> rule.socket.requestResponse(payload)), + Arguments.of( + REQUEST_STREAM, + (BiFunction>) + (rule, payload) -> rule.socket.requestStream(payload)), + Arguments.of( + REQUEST_CHANNEL, + (BiFunction>) + (rule, payload) -> rule.socket.requestChannel(Flux.just(payload)))); + } + + @ParameterizedTest + @MethodSource("streamIdRacingCases") + public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( + BiFunction> interaction1, + BiFunction> interaction2) { + for (int i = 1; i < 10000; i += 4) { + Payload payload = DefaultPayload.create("test"); + Publisher publisher1 = interaction1.apply(rule, payload); + Publisher publisher2 = interaction2.apply(rule, payload); + RaceTestUtils.race( + () -> publisher1.subscribe(AssertSubscriber.create()), + () -> publisher2.subscribe(AssertSubscriber.create())); + + Assertions.assertThat(rule.connection.getSent()) + .extracting(FrameHeaderCodec::streamId) + .containsExactly(i, i + 2); + rule.connection.getSent().clear(); + } + } + + public static Stream streamIdRacingCases() { + return Stream.of( + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + (BiFunction>) + (r, p) -> r.socket.requestResponse(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + (BiFunction>) + (r, p) -> r.socket.requestStream(p)), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p))), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestChannel(Flux.just(p)), + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p))); + } + + public int sendRequestResponse(Publisher response) { + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextCompleteReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + verify(sub).onNext(any(Payload.class)); + verify(sub).onComplete(); + return streamId; + } + + public static class ClientSocketRule extends AbstractSocketRule { + @Override + protected RSocketRequester newRSocket() { + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + } + + public int getStreamIdForRequestType(FrameType expectedFrameType) { + assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); + List framesFound = new ArrayList<>(); + for (ByteBuf frame : connection.getSent()) { + FrameType frameType = frameType(frame); + if (frameType == expectedFrameType) { + return FrameHeaderCodec.streamId(frame); + } + framesFound.add(frameType); + } + throw new AssertionError( + "No frames sent with frame type: " + + expectedFrameType + + ", frames found: " + + framesFound); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java new file mode 100644 index 000000000..036dc2eef --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -0,0 +1,823 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.util.Collection; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketResponderTest { + + ServerSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); + rule = new ServerSocketRule(); + rule.apply( + new Statement() { + @Override + public void evaluate() {} + }, + null) + .evaluate(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleKeepAlive() throws Exception { + rule.connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); + ByteBuf sent = rule.connection.awaitSend(); + assertThat("Unexpected frame sent.", frameType(sent), is(FrameType.KEEPALIVE)); + /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ + assertThat( + "Unexpected keep-alive frame respond flag.", + KeepAliveFrameCodec.respondFlag(sent), + is(false)); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleResponseFrameNoError() throws Exception { + final int streamId = 4; + rule.connection.clearSendReceiveBuffers(); + + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + Collection> sendSubscribers = rule.connection.getSendSubscribers(); + assertThat("Request not sent.", sendSubscribers, hasSize(1)); + Subscriber sendSub = sendSubscribers.iterator().next(); + assertThat( + "Unexpected frame sent.", + frameType(rule.connection.awaitSend()), + anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandlerEmitsError() throws Exception { + final int streamId = 4; + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + assertThat( + "Unexpected frame sent.", frameType(rule.connection.awaitSend()), is(FrameType.ERROR)); + } + + @Test + @Timeout(20_000) + public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.never().doOnCancel(() -> cancelled.set(true)); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); + + assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() { + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + byte[] data = new byte[FrameLengthCodec.FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final RSocket acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload p) { + p.release(); + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + p.release(); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .doOnNext(Payload::release) + .subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + } + }); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); + runnable.run(); + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) + .matches(bb -> ErrorFrameCodec.dataUtf8(bb).contains(INVALID_PAYLOAD_ERROR_MESSAGE)) + .matches(ReferenceCounted::release); + + assertThat("Subscription not cancelled.", cancelled.get(), is(true)); + } + + rule.assertHasNoLeaks(); + } + + @Test + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return Flux.never(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + }, + assertSubscriber::cancel); + + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + parallel), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.error(new RuntimeException()); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Scheduler parallel = Schedulers.parallel(); + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + for (int i = 0; i < 10000; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }, + parallel); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(framesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return testPublisher.flux(); + } + }, + 1); + + rule.sendRequest(1, frameType, ByteBufPayload.create("d")); + + // if responses number is bigger than 1 we have to send one extra requestN + if (responsesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, 1, responsesCnt - 1)); + } + + // respond with specific number of elements + for (int i = 0; i < responsesCnt; i++) { + testPublisher.next(ByteBufPayload.create("rd" + i)); + } + + // Listen to incoming frames. Valid for RequestChannel case only + if (framesCnt > 1) { + for (int i = 1; i < responsesCnt; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("d" + (i + 1)).getBytes()))); + } + } + + if (responsesCnt > 0) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)); + } + + if (framesCnt > 1) { + Assertions.assertThat( + rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", + frameType, framesCnt - 1) + .hasSize(1) + .first() + .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); + } + + Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + Assertions.assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + .hasSize(framesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) { + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Mono.just(invalidPayload); + } + + @Override + public Flux requestStream(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + }); + + rule.sendRequest(1, frameType); + + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches( + bb -> frameType(bb) == ERROR, + "Expect frame type to be {" + + ERROR + + "} but was {" + + frameType(rule.connection.getSent().iterator().next()) + + "}"); + } + + private static Stream refCntCases() { + return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + public static class ServerSocketRule extends AbstractSocketRule { + + private RSocket acceptingSocket; + private volatile int prefetch; + + @Override + protected void init() { + acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + }; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = Integer.MAX_VALUE; + super.init(); + } + + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = prefetch; + super.init(); + } + + @Override + protected RSocketResponder newRSocket() { + return new RSocketResponder( + connection, acceptingSocket, PayloadDecoder.ZERO_COPY, ResponderLeaseHandler.None, 0); + } + + private void sendRequest(int streamId, FrameType frameType) { + sendRequest(streamId, frameType, EmptyPayload.INSTANCE); + } + + private void sendRequest(int streamId, FrameType frameType, Payload payload) { + ByteBuf request; + + switch (frameType) { + case REQUEST_CHANNEL: + request = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, prefetch, payload); + break; + case REQUEST_STREAM: + request = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, prefetch, payload); + break; + case REQUEST_RESPONSE: + request = RequestResponseFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + case REQUEST_FNF: + request = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + default: + throw new IllegalArgumentException("unsupported type: " + frameType); + } + + connection.addToReceivedBuffer(request); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java new file mode 100644 index 000000000..073ebfd06 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -0,0 +1,43 @@ +package io.rsocket.core; + +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestServerTransport; +import org.assertj.core.api.Assertions; +import org.junit.Test; + +public class RSocketServerFragmentationTest { + + @Test + public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketServer.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketServer.create().fragment(100).bind(new TestServerTransport()).block(); + } + + @Test + public void serverSucceedsWithDisabledFragmentation() { + RSocketServer.create().bind(new TestServerTransport()).block(); + } + + @Test + public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketConnector.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { + RSocketConnector.create().fragment(100).connect(new TestClientTransport()).block(); + } + + @Test + public void clientSucceedsWithDisabledFragmentation() { + RSocketConnector.connectWith(new TestClientTransport()).block(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java new file mode 100644 index 000000000..692894fd6 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -0,0 +1,506 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.TestScheduler; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.RequesterLeaseHandler; +import io.rsocket.lease.ResponderLeaseHandler; +import io.rsocket.test.util.LocalDuplexConnection; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExternalResource; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.DirectProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +public class RSocketTest { + + @Rule public final SocketRule rule = new SocketRule(); + + @Test + public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { + RSocket requestHandlingRSocket = + new RSocket() { + final Disposable disposable = Disposables.single(); + + @Override + public void dispose() { + disposable.dispose(); + } + + @Override + public boolean isDisposed() { + return disposable.isDisposed(); + } + }; + rule.setRequestAcceptor(requestHandlingRSocket); + rule.crs + .onClose() + .as(StepVerifier::create) + .expectSubscription() + .then(rule.crs::dispose) + .expectComplete() + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); + } + + @Test(timeout = 2_000) + public void testRequestReplyNoError() { + StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) + .expectNextCount(1) + .expectComplete() + .verify(); + } + + @Test(timeout = 2000) + public void testHandlerEmitsError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new NullPointerException("Deliberate exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("Deliberate exception.")) + .verify(Duration.ofMillis(100)); + } + + @Test(timeout = 2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("Deliberate Custom exception.") + .hasFieldOrPropertyWithValue("errorCode", 0x00000501)) + .verify(); + } + + @Test(timeout = 2000) + public void testRequestPropagatesCorrectlyForRequestChannel() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + // specifically limits request to 3 in order to prevent 256 request from limitRate + // hidden on the responder side + .limitRequest(3); + } + }); + + Flux.range(0, 3) + .map(i -> DefaultPayload.create("" + i)) + .as(rule.crs::requestChannel) + .as(publisher -> StepVerifier.create(publisher, 3)) + .expectSubscription() + .expectNextCount(3) + .expectComplete() + .verify(Duration.ofMillis(5000)); + } + + @Test(timeout = 2000) + public void testStream() throws Exception { + Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testChannel() throws Exception { + Flux requests = + Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test(timeout = 2000) + public void testErrorPropagatesCorrectly() { + AtomicReference error = new AtomicReference<>(); + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).doOnError(error::set); + } + }); + Flux requests = Flux.error(new RuntimeException("test")); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectErrorMessage("test").verify(); + Assertions.assertThat(error.get()).isNull(); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void + requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + cancelFromRequesterSubscriber( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + void initRequestChannelCase( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(responderSubscriber); + return responderPublisher.flux(); + } + }); + + rule.crs.requestChannel(requesterPublisher).subscribe(requesterSubscriber); + + requesterPublisher.assertWasSubscribed(); + requesterSubscriber.assertSubscribed(); + + responderSubscriber.assertNotSubscribed(); + responderPublisher.assertWasNotSubscribed(); + + // firstRequest + requesterSubscriber.request(1); + requesterPublisher.assertMaxRequested(1); + requesterPublisher.next(DefaultPayload.create("initialData", "initialMetadata")); + + responderSubscriber.assertSubscribed(); + responderPublisher.assertWasSubscribed(); + } + + void nextFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that outerUpstream and innerSubscriber is not terminated so the requestChannel + requesterPublisher.assertSubscribers(1); + responderSubscriber.assertNotTerminated(); + + responderSubscriber.request(6); + requesterPublisher.next( + DefaultPayload.create("d1", "m1"), + DefaultPayload.create("d2"), + DefaultPayload.create("d3", "m3"), + DefaultPayload.create("d4"), + DefaultPayload.create("d5", "m5")); + + List innerPayloads = responderSubscriber.awaitAndAssertNextValueCount(6).values(); + Assertions.assertThat(innerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("initialData", "d1", "d2", "d3", "d4", "d5"); + Assertions.assertThat(innerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, true, false, true, false, true); + Assertions.assertThat(innerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("initialMetadata", "m1", "", "m3", "", "m5"); + } + + void completeFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + requesterPublisher.complete(); + responderSubscriber.assertTerminated(); + requesterPublisher.assertNoSubscribers(); + } + + void cancelFromResponderSubscriber( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + responderSubscriber.cancel(); + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void nextFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that downstream is not terminated so the requestChannel state is half-closed + responderPublisher.assertSubscribers(1); + requesterSubscriber.assertNotTerminated(); + + // ensures responderPublisher can send messages and outerSubscriber can receive them + requesterSubscriber.request(5); + responderPublisher.next( + DefaultPayload.create("rd1", "rm1"), + DefaultPayload.create("rd2"), + DefaultPayload.create("rd3", "rm3"), + DefaultPayload.create("rd4"), + DefaultPayload.create("rd5", "rm5")); + + List outerPayloads = requesterSubscriber.awaitAndAssertNextValueCount(5).values(); + Assertions.assertThat(outerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("rd1", "rd2", "rd3", "rd4", "rd5"); + Assertions.assertThat(outerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, false, true, false, true); + Assertions.assertThat(outerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("rm1", "", "rm3", "", "rm5"); + } + + void completeFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that after sending complete inner upstream is closed + responderPublisher.complete(); + requesterSubscriber.assertTerminated(); + responderPublisher.assertNoSubscribers(); + } + + void cancelFromRequesterSubscriber( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterSubscriber.cancel(); + // error should be propagated + responderSubscriber.assertTerminated(); + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + public static class SocketRule extends ExternalResource { + + DirectProcessor serverProcessor; + DirectProcessor clientProcessor; + private RSocketRequester crs; + + @SuppressWarnings("unused") + private RSocketResponder srs; + + private RSocket requestAcceptor; + + private LeaksTrackingByteBufAllocator allocator; + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + init(); + base.evaluate(); + } + }; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + protected void init() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + serverProcessor = DirectProcessor.create(); + clientProcessor = DirectProcessor.create(); + + LocalDuplexConnection serverConnection = + new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); + LocalDuplexConnection clientConnection = + new LocalDuplexConnection("client", allocator, serverProcessor, clientProcessor); + + clientConnection.onClose().doFinally(__ -> serverConnection.dispose()).subscribe(); + serverConnection.onClose().doFinally(__ -> clientConnection.dispose()).subscribe(); + + requestAcceptor = + null != requestAcceptor + ? requestAcceptor + : new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10) + .map( + i -> DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")) + .subscribe(); + + return Flux.range(1, 10) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + }; + + srs = + new RSocketResponder( + serverConnection, + requestAcceptor, + PayloadDecoder.DEFAULT, + ResponderLeaseHandler.None, + 0); + + crs = + new RSocketRequester( + clientConnection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + 0, + 0, + null, + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); + } + + public void setRequestAcceptor(RSocket requestAcceptor) { + this.requestAcceptor = requestAcceptor; + init(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java new file mode 100644 index 000000000..968a1a793 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -0,0 +1,868 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import org.assertj.core.api.Assertions; +import org.junit.Test; +import org.mockito.Mockito; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ReconnectMonoTests { + + private Queue retries = new ConcurrentLinkedQueue<>(); + private Queue> received = new ConcurrentLinkedQueue<>(); + private Queue expired = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldExpireValueOnRacingDisposeAndNext() { + Hooks.onErrorDropped(t -> {}); + Hooks.onNextDropped(System.out::println); + for (int i = 0; i < 100000; i++) { + final int index = i; + final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; + Subscription mockSubscription = Mockito.mock(Subscription.class); + final Mono stringMono = + new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + actual.onSubscribe(mockSubscription); + monoSubscribers[0] = actual; + } + }; + + final ReconnectMono reconnectMono = + stringMono + .doOnDiscard(Object.class, System.out::println) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Mockito.verify(mockSubscription).cancel(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + + Assertions.assertThat(expired).containsOnly("value" + i); + } else { + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + final MonoProcessor racerProcessor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + RaceTestUtils.race( + () -> reconnectMono.subscribe(processor), () -> reconnectMono.subscribe(racerProcessor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(racerProcessor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(racerProcessor.peek()).isEqualTo("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = MonoProcessor.create(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values = new String[1]; + + RaceTestUtils.race( + () -> values[0] = reconnectMono.block(timeout), () -> reconnectMono.subscribe(processor)); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + Assertions.assertThat(values).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add(new ReconnectMono.ReconnectInner<>(processor, reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + Assertions.assertThat(cold.subscribeCount()).isZero(); + + String[] values1 = new String[1]; + String[] values2 = new String[1]; + + RaceTestUtils.race( + () -> values1[0] = reconnectMono.block(timeout), + () -> values2[0] = reconnectMono.block(timeout)); + + Assertions.assertThat(values2).containsExactly("value" + i); + Assertions.assertThat(values1).containsExactly("value" + i); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.READY); + + Assertions.assertThat(cold.subscribeCount()).isOne(); + + Assertions.assertThat( + reconnectMono.add( + new ReconnectMono.ReconnectInner<>(MonoProcessor.create(), reconnectMono))) + .isEqualTo(ReconnectMono.READY_STATE); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + Throwable error = processor.getError(); + + if (error instanceof CancellationException) { + Assertions.assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(error) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Unexpected Completion of the Upstream"); + } + + Assertions.assertThat(expired).isEmpty(); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .isInstanceOf(RuntimeException.class) + .hasMessage("test"); + } + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < 100000; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono() + .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + Assertions.assertThat(processor.isTerminated()).isTrue(); + + if (processor.isError()) { + + if (processor.getError() instanceof CancellationException) { + Assertions.assertThat(processor.getError()) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + Assertions.assertThat(processor.getError()) + .matches(t -> Exceptions.isRetryExhausted(t)) + .hasCause(runtimeException); + } + + Assertions.assertThat(expired).hasSize(1).containsOnly("value" + i); + } else { + Assertions.assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value" + i, reconnectMono)); + Assertions.assertThat(processor.peek()).isEqualTo("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldThrowOnBlocking() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read"); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + publisher.error(new RuntimeException("test")); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("ReconnectMono terminated with an error")); + } + + @Test + public void shouldBeScannable() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final Mono parent = publisher.mono(); + final ReconnectMono reconnectMono = + parent.as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final Scannable scannableOfReconnect = Scannable.from(reconnectMono); + + Assertions.assertThat( + (List) + scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) + .hasSize(1) + .containsExactly(publisher.mono().getClass()); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(false); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + + final MonoProcessor processor = reconnectMono.subscribeWith(MonoProcessor.create()); + + final Scannable scannableOfMonoProcessor = Scannable.from(processor); + + Assertions.assertThat( + (List) + scannableOfMonoProcessor + .parents() + .map(s -> s.getClass()) + .collect(Collectors.toList())) + .hasSize(3) + .containsExactly( + ReconnectMono.ReconnectInner.class, ReconnectMono.class, publisher.mono().getClass()); + + reconnectMono.dispose(); + + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)) + .isEqualTo(true); + Assertions.assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + .isInstanceOf(CancellationException.class); + } + + @Test + public void shouldNotExpiredIfNotCompleted() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + reconnectMono.invalidate(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + publisher.assertSubscribers(1); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + + publisher.assertSubscribers(0); + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldNotEmitUntilCompletion() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isTrue(); + Assertions.assertThat(processor.peek()).isEqualTo("test"); + } + + @Test + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + MonoProcessor processor = MonoProcessor.create(); + + reconnectMono.subscribe(processor); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + publisher.next("test"); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(processor.isTerminated()).isFalse(); + + processor.cancel(); + + Assertions.assertThat(reconnectMono.subscribers).isEqualTo(ReconnectMono.EMPTY_SUBSCRIBED); + + publisher.complete(); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(processor.isTerminated()).isFalse(); + Assertions.assertThat(processor.peek()).isNull(); + } + + @Test + public void shouldExpireValueOnDispose() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono) + .expectSubscription() + .then(() -> publisher.next("value")) + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1); + + reconnectMono.dispose(); + + Assertions.assertThat(expired).hasSize(1); + Assertions.assertThat(received).hasSize(1); + Assertions.assertThat(reconnectMono.isDisposed()).isTrue(); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void shouldNotifyAllTheSubscribers() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final MonoProcessor sub1 = MonoProcessor.create(); + final MonoProcessor sub2 = MonoProcessor.create(); + final MonoProcessor sub3 = MonoProcessor.create(); + final MonoProcessor sub4 = MonoProcessor.create(); + + reconnectMono.subscribe(sub1); + reconnectMono.subscribe(sub2); + reconnectMono.subscribe(sub3); + reconnectMono.subscribe(sub4); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(4); + + final ArrayList> processors = new ArrayList<>(200); + + for (int i = 0; i < 100; i++) { + final MonoProcessor subA = MonoProcessor.create(); + final MonoProcessor subB = MonoProcessor.create(); + processors.add(subA); + processors.add(subB); + RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); + } + + Assertions.assertThat(reconnectMono.subscribers).hasSize(204); + + sub1.dispose(); + + Assertions.assertThat(reconnectMono.subscribers).hasSize(203); + + publisher.next("value"); + + Assertions.assertThatThrownBy(sub1::peek).isInstanceOf(CancellationException.class); + Assertions.assertThat(sub2.peek()).isEqualTo("value"); + Assertions.assertThat(sub3.peek()).isEqualTo("value"); + Assertions.assertThat(sub4.peek()).isEqualTo("value"); + + for (MonoProcessor sub : processors) { + Assertions.assertThat(sub.peek()).isEqualTo("value"); + Assertions.assertThat(sub.isTerminated()).isTrue(); + } + + Assertions.assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldExpireValueExactlyOnce() { + for (int i = 0; i < 1000; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).isEmpty(); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.elastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(expired).hasSize(1).containsOnly("value"); + Assertions.assertThat(received) + .hasSize(2) + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value", reconnectMono)); + + Assertions.assertThat(cold.subscribeCount()).isEqualTo(2); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldTimeoutRetryWithVirtualTime() { + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + // then + StepVerifier.withVirtualTime( + () -> + Mono.error(new RuntimeException("Something went wrong")) + .retryWhen( + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(minBackoff)) + .doAfterRetry(onRetry()) + .maxBackoff(Duration.ofSeconds(maxBackoff))) + .timeout(Duration.ofSeconds(timeout)) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) + .subscribeOn(Schedulers.elastic())) + .expectSubscription() + .thenAwait(Duration.ofSeconds(timeout)) + .expectError(TimeoutException.class) + .verify(Duration.ofSeconds(timeout)); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryNoBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.max(2).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); + assertRetries(IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryFixedBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.fixedDelay(1, Duration.ofMillis(500)).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .expectNoEvent(Duration.ofMillis(300)) + .thenAwait(Duration.ofMillis(300)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryExponentialBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .jitter(0.0d) + .doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .thenAwait(Duration.ofMillis(100)) + .thenAwait(Duration.ofMillis(200)) + .thenAwait(Duration.ofMillis(400)) + .thenAwait(Duration.ofMillis(500)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); + + Assertions.assertThat(received).isEmpty(); + Assertions.assertThat(expired).isEmpty(); + } + + Consumer onRetry() { + return context -> retries.add(context); + } + + BiConsumer onValue() { + return (v, __) -> received.add(Tuples.of(v, __)); + } + + Consumer onExpire() { + return (v) -> expired.add(v); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertEquals(exceptions.length, retries.size()); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertEquals(index, retryContext.totalRetries()); + assertEquals(exceptions[index], retryContext.failure().getClass()); + index++; + } + } + + static boolean isRetryExhausted(Throwable e, Class cause) { + return Exceptions.isRetryExhausted(e) && cause.isInstance(e.getCause()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java similarity index 75% rename from rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java rename to rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java index cce53a2f2..2957a051e 100644 --- a/rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -1,23 +1,24 @@ -package io.rsocket; +package io.rsocket.core; import static io.rsocket.transport.ServerTransport.ConnectionAcceptor; import static org.assertj.core.api.Assertions.assertThat; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; +import io.rsocket.*; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.exceptions.Exceptions; import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; -import io.rsocket.frame.SetupFrameFlyweight; +import io.rsocket.frame.SetupFrameCodec; import io.rsocket.lease.RequesterLeaseHandler; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import java.util.ArrayList; -import java.util.List; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; import reactor.core.publisher.UnicastProcessor; @@ -31,13 +32,13 @@ void responderRejectSetup() { String errorMsg = "error"; RejectingAcceptor acceptor = new RejectingAcceptor(errorMsg); - RSocketFactory.receive().acceptor(acceptor).transport(transport).start().block(); + RSocketServer.create().acceptor(acceptor).bind(transport).block(); transport.connect(); ByteBuf sentFrame = transport.awaitSent(); - assertThat(FrameHeaderFlyweight.frameType(sentFrame)).isEqualTo(FrameType.ERROR); - RuntimeException error = Exceptions.from(sentFrame); + assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); + RuntimeException error = Exceptions.from(0, sentFrame); assertThat(errorMsg).isEqualTo(error.getMessage()); assertThat(error).isInstanceOf(RejectedSetupException.class); RSocket acceptorSender = acceptor.senderRSocket().block(); @@ -45,20 +46,22 @@ void responderRejectSetup() { } @Test + @Disabled("FIXME: needs to be revised") void requesterStreamsTerminatedOnZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); - List errors = new ArrayList<>(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, - errors::add, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); String errorMsg = "error"; @@ -68,7 +71,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { .doOnRequest( ignored -> conn.addToReceivedBuffer( - ErrorFrameFlyweight.encode( + ErrorFrameCodec.encode( ByteBufAllocator.DEFAULT, 0, new RejectedSetupException(errorMsg))))) @@ -76,28 +79,28 @@ void requesterStreamsTerminatedOnZeroErrorFrame() { err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage())) .verify(Duration.ofSeconds(5)); - assertThat(errors).hasSize(1); assertThat(rSocket.isDisposed()).isTrue(); } @Test void requesterNewStreamsTerminatedAfterZeroErrorFrame() { - TestDuplexConnection conn = new TestDuplexConnection(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); RSocketRequester rSocket = new RSocketRequester( - ByteBufAllocator.DEFAULT, conn, DefaultPayload::create, - err -> {}, StreamIdSupplier.clientSupplier(), 0, 0, + 0, null, - RequesterLeaseHandler.None); + RequesterLeaseHandler.None, + TestScheduler.INSTANCE); conn.addToReceivedBuffer( - ErrorFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); StepVerifier.create( rSocket @@ -129,7 +132,9 @@ public Mono senderRSocket() { private static class SingleConnectionTransport implements ServerTransport { - private final TestDuplexConnection conn = new TestDuplexConnection(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection conn = new TestDuplexConnection(allocator); @Override public Mono start(ConnectionAcceptor acceptor, int mtu) { @@ -146,9 +151,7 @@ public ByteBuf awaitSent() { public void connect() { Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); - ByteBuf setup = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, false, 0, 42, "mdMime", "dMime", payload); + ByteBuf setup = SetupFrameCodec.encode(allocator, false, 0, 42, "mdMime", "dMime", payload); conn.addToReceivedBuffer(setup); } diff --git a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java similarity index 99% rename from rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java rename to rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java index 766a6aaf7..00248b6d8 100644 --- a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.rsocket; +package io.rsocket.core; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; diff --git a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java similarity index 97% rename from rsocket-core/src/test/java/io/rsocket/TestingStuff.java rename to rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java index 64c790053..e0ebf5064 100644 --- a/rsocket-core/src/test/java/io/rsocket/TestingStuff.java +++ b/rsocket-core/src/test/java/io/rsocket/core/TestingStuff.java @@ -1,4 +1,4 @@ -package io.rsocket; +package io.rsocket.core; import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; @@ -16,6 +16,6 @@ public void testStuff() { ByteBuf byteBuf = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(f1)); System.out.println(ByteBufUtil.prettyHexDump(byteBuf)); - ConnectionSetupPayload.create(byteBuf); + new DefaultConnectionSetupPayload(byteBuf); } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java index c7bbfadf6..b3f596a37 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -16,131 +16,220 @@ package io.rsocket.exceptions; +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + final class ExceptionsTest { - /* @DisplayName("from returns ApplicationErrorException") @Test void fromApplicationException() { - ByteBuf byteBuf = createErrorFrame(APPLICATION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(ApplicationErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); } @DisplayName("from returns CanceledException") @Test void fromCanceledException() { - ByteBuf byteBuf = createErrorFrame(CANCELED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(CanceledException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); } @DisplayName("from returns ConnectionCloseException") @Test void fromConnectionCloseException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_CLOSE, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionCloseException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); } @DisplayName("from returns ConnectionErrorException") @Test void fromConnectionErrorException() { - ByteBuf byteBuf = createErrorFrame(CONNECTION_ERROR, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(ConnectionErrorException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); } @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") @Test void fromIllegalErrorFrame() { - ByteBuf byteBuf = createErrorFrame(0x00000000, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) - .isInstanceOf(IllegalArgumentException.class) - .withFailMessage("Invalid Error frame: %d, '%s'", 0, "test-message"); + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidException") @Test void fromInvalidException() { - ByteBuf byteBuf = createErrorFrame(INVALID, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(InvalidException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns InvalidSetupException") @Test void fromInvalidSetupException() { - ByteBuf byteBuf = createErrorFrame(INVALID_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(InvalidSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedException") @Test void fromRejectedException() { - ByteBuf byteBuf = createErrorFrame(REJECTED, "test-message"); + ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(1, byteBuf)) .isInstanceOf(RejectedException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedResumeException") @Test void fromRejectedResumeException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_RESUME, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedResumeException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns RejectedSetupException") @Test void fromRejectedSetupException() { - ByteBuf byteBuf = createErrorFrame(REJECTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(RejectedSetupException.class) .withFailMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); } @DisplayName("from returns UnsupportedSetupException") @Test void fromUnsupportedSetupException() { - ByteBuf byteBuf = createErrorFrame(UNSUPPORTED_SETUP, "test-message"); + ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); - assertThat(Exceptions.from(Frame.from(byteBuf))) + assertThat(Exceptions.from(0, byteBuf)) .isInstanceOf(UnsupportedSetupException.class) - .withFailMessage("test-message"); + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } + + @DisplayName("from returns CustomRSocketException") + @Test + void fromCustomRSocketException() { + for (int i = 0; i < 1000; i++) { + int randomCode = + ThreadLocalRandom.current().nextBoolean() + ? ThreadLocalRandom.current() + .nextInt(Integer.MIN_VALUE, ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE) + : ThreadLocalRandom.current() + .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); + ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } } @DisplayName("from throws NullPointerException with null frame") @Test void fromWithNullFrame() { assertThatNullPointerException() - .isThrownBy(() -> Exceptions.from(null)) + .isThrownBy(() -> Exceptions.from(0, null)) .withMessage("frame must not be null"); } - private ByteBuf createErrorFrame(int errorCode, String message) { - ByteBuf byteBuf = Unpooled.buffer(); - - ErrorFrameFlyweight.encode(byteBuf, 0, errorCode, Unpooled.copiedBuffer(message, UTF_8)); - - return byteBuf; - }*/ + private ByteBuf createErrorFrame(int streamId, int errorCode, String message) { + return ErrorFrameCodec.encode( + UnpooledByteBufAllocator.DEFAULT, streamId, new TestRSocketException(errorCode, message)); + } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java index 8c39e8250..ccf7649d2 100644 --- a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -17,27 +17,22 @@ package io.rsocket.exceptions; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; interface RSocketExceptionTest { - @DisplayName("constructor throws NullPointerException with null message") + @DisplayName("constructor does not throw NullPointerException with null message") @Test default void constructorWithNullMessage() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null)) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } - @DisplayName("constructor throws NullPointerException with null message and cause") + @DisplayName("constructor does not throw NullPointerException with null message and cause") @Test default void constructorWithNullMessageAndCause() { - assertThatNullPointerException() - .isThrownBy(() -> getException(null, new Exception())) - .withMessage("message must not be null"); + assertThat(getException(null)).hasMessage(null); } @DisplayName("errorCode returns specified value") diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java new file mode 100644 index 000000000..6c2e63730 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -0,0 +1,39 @@ +package io.rsocket.exceptions; + +public class TestRSocketException extends RSocketException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message) { + super(message); + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message, Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java index 3d96bfd12..932df4283 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,18 +22,16 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; -import java.util.Arrays; -import java.util.List; import java.util.concurrent.ThreadLocalRandom; import org.junit.Assert; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -50,20 +48,22 @@ final class FragmentationDuplexConnectionTest { private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + { + Mockito.when(delegate.onClose()).thenReturn(Mono.never()); + } + @SuppressWarnings("unchecked") private final ArgumentCaptor> publishers = ArgumentCaptor.forClass(Publisher.class); - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @DisplayName("constructor throws IllegalArgumentException with negative maxFragmentLength") @Test void constructorInvalidMaxFragmentSize() { assertThatIllegalArgumentException() - .isThrownBy( - () -> - new FragmentationDuplexConnection( - delegate, allocator, Integer.MIN_VALUE, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, Integer.MIN_VALUE, false, "")) .withMessage("smallest allowed mtu size is 64 bytes, provided: -2147483648"); } @@ -71,246 +71,29 @@ void constructorInvalidMaxFragmentSize() { @Test void constructorMtuLessThanMin() { assertThatIllegalArgumentException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, allocator, 2, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(delegate, 2, false, "")) .withMessage("smallest allowed mtu size is 64 bytes, provided: 2"); } - @DisplayName("constructor throws NullPointerException with null byteBufAllocator") - @Test - void constructorNullByteBufAllocator() { - assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(delegate, null, 64, false, "")) - .withMessage("byteBufAllocator must not be null"); - } - @DisplayName("constructor throws NullPointerException with null delegate") @Test void constructorNullDelegate() { assertThatNullPointerException() - .isThrownBy(() -> new FragmentationDuplexConnection(null, allocator, 64, false, "")) + .isThrownBy(() -> new FragmentationDuplexConnection(null, 64, false, "")) .withMessage("delegate must not be null"); } - @DisplayName("reassembles data") - @Test - void reassembleData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata") - @Test - void reassembleMetadata() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - false, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); - Assert.assertEquals(metadata, m); - }) - .verifyComplete(); - } - - @DisplayName("reassembles metadata and data") - @Test - void reassembleMetadataAndData() { - List byteBufs = - Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, - 1, - true, - false, - true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); - - CompositeByteBuf data = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.data)); - - CompositeByteBuf metadata = - allocator - .compositeDirectBuffer() - .addComponents( - true, - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata), - Unpooled.wrappedBuffer(FragmentationDuplexConnectionTest.metadata)); - - when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble a non-fragment frame") - @Test - void reassembleNonFragment() { - ByteBuf encode = - RequestResponseFrameFlyweight.encode( - allocator, 1, false, DefaultPayload.create(Unpooled.wrappedBuffer(data))); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals( - Unpooled.wrappedBuffer(data), RequestResponseFrameFlyweight.data(byteBuf)); - }) - .verifyComplete(); - } - - @DisplayName("does not reassemble non fragmentable frame") - @Test - void reassembleNonFragmentableFrame() { - ByteBuf encode = CancelFrameFlyweight.encode(allocator, 2); - - when(delegate.receive()).thenReturn(Flux.just(encode)); - when(delegate.onClose()).thenReturn(Mono.never()); - - new FragmentationDuplexConnection(delegate, allocator, 1030, false, "") - .receive() - .as(StepVerifier::create) - .assertNext( - byteBuf -> { - Assert.assertEquals(FrameType.CANCEL, FrameHeaderFlyweight.frameType(byteBuf)); - }) - .verifyComplete(); - } - @DisplayName("fragments data") @Test void sendData() { ByteBuf encode = - RequestResponseFrameFlyweight.encode( + RequestResponseFrameCodec.encode( allocator, 1, false, Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(data)); when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); - new FragmentationDuplexConnection(delegate, allocator, 64, false, "").sendOne(encode.retain()); + new FragmentationDuplexConnection(delegate, 64, false, "").sendOne(encode.retain()); verify(delegate).send(publishers.capture()); @@ -318,8 +101,8 @@ void sendData() { .expectNextCount(17) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); }) .verifyComplete(); } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java index 984207936..ff62b56f2 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationIntegrationTest.java @@ -2,9 +2,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameUtil; -import io.rsocket.frame.PayloadFrameFlyweight; +import io.rsocket.frame.PayloadFrameCodec; import io.rsocket.util.DefaultPayload; import java.util.concurrent.ThreadLocalRandom; import org.junit.Assert; @@ -28,14 +28,15 @@ public class FragmentationIntegrationTest { @Test void fragmentAndReassembleData() { ByteBuf frame = - PayloadFrameFlyweight.encodeNextComplete(allocator, 2, DefaultPayload.create(data)); + PayloadFrameCodec.encodeNextCompleteReleasingPayload( + allocator, 2, DefaultPayload.create(data)); System.out.println(FrameUtil.toString(frame)); frame.retain(); Publisher fragments = FrameFragmenter.fragmentFrame( - allocator, 64, frame, FrameHeaderFlyweight.frameType(frame), false); + allocator, 64, frame, FrameHeaderCodec.frameType(frame), false); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -49,9 +50,8 @@ void fragmentAndReassembleData() { String s = FrameUtil.toString(assembled); System.out.println(s); - Assert.assertEquals( - FrameHeaderFlyweight.frameType(frame), FrameHeaderFlyweight.frameType(assembled)); + Assert.assertEquals(FrameHeaderCodec.frameType(frame), FrameHeaderCodec.frameType(assembled)); Assert.assertEquals(frame.readableBytes(), assembled.readableBytes()); - Assert.assertEquals(PayloadFrameFlyweight.data(frame), PayloadFrameFlyweight.data(assembled)); + Assert.assertEquals(PayloadFrameCodec.data(frame), PayloadFrameCodec.data(assembled)); } } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java index f5a013357..60dbef74b 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java @@ -20,7 +20,6 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; import java.util.concurrent.ThreadLocalRandom; import org.junit.Assert; import org.junit.jupiter.api.DisplayName; @@ -43,14 +42,15 @@ final class FrameFragmenterTest { @Test void testGettingData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf fnf = - RequestFireAndForgetFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestFireAndForgetFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf rs = - RequestStreamFrameFlyweight.encode(allocator, 1, true, 1, DefaultPayload.create(data)); + RequestStreamFrameCodec.encode(allocator, 1, true, 1, null, Unpooled.wrappedBuffer(data)); ByteBuf rc = - RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 1, DefaultPayload.create(data)); + RequestChannelFrameCodec.encode( + allocator, 1, true, false, 1, null, Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getData(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.wrappedBuffer(data)); @@ -72,17 +72,23 @@ void testGettingData() { @Test void testGettingMetadata() { ByteBuf rr = - RequestResponseFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf fnf = - RequestFireAndForgetFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + RequestFireAndForgetFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf rs = - RequestStreamFrameFlyweight.encode( - allocator, 1, true, 1, DefaultPayload.create(data, metadata)); + RequestStreamFrameCodec.encode( + allocator, 1, true, 1, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf rc = - RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 1, DefaultPayload.create(data, metadata)); + RequestChannelFrameCodec.encode( + allocator, + 1, + true, + false, + 1, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.wrappedBuffer(metadata)); @@ -104,7 +110,7 @@ void testGettingMetadata() { @Test void returnEmptBufferWhenNoMetadataPresent() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf data = FrameFragmenter.getMetadata(rr, FrameType.REQUEST_RESPONSE); Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); @@ -115,7 +121,7 @@ void returnEmptBufferWhenNoMetadataPresent() { @Test void encodeFirstFrameWithData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -128,23 +134,23 @@ void encodeFirstFrameWithData() { Unpooled.wrappedBuffer(data)); Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - ByteBuf data = RequestResponseFrameFlyweight.data(fragment); + ByteBuf data = RequestResponseFrameCodec.data(fragment); ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); Assert.assertEquals(byteBuf, data); - Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); } @DisplayName("encode first channel frame") @Test void encodeFirstWithDataChannel() { ByteBuf rc = - RequestChannelFrameFlyweight.encode( - allocator, 1, true, false, 10, DefaultPayload.create(data)); + RequestChannelFrameCodec.encode( + allocator, 1, true, false, 10, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -157,23 +163,23 @@ void encodeFirstWithDataChannel() { Unpooled.wrappedBuffer(data)); Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderFlyweight.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); - Assert.assertEquals(10, RequestChannelFrameFlyweight.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + Assert.assertEquals(FrameType.REQUEST_CHANNEL, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(10, RequestChannelFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - ByteBuf data = RequestChannelFrameFlyweight.data(fragment); + ByteBuf data = RequestChannelFrameCodec.data(fragment); ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); Assert.assertEquals(byteBuf, data); - Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); } @DisplayName("encode first stream frame") @Test void encodeFirstWithDataStream() { ByteBuf rc = - RequestStreamFrameFlyweight.encode(allocator, 1, true, 50, DefaultPayload.create(data)); + RequestStreamFrameCodec.encode(allocator, 1, true, 50, null, Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -186,27 +192,24 @@ void encodeFirstWithDataStream() { Unpooled.wrappedBuffer(data)); Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderFlyweight.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameFlyweight.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - ByteBuf data = RequestStreamFrameFlyweight.data(fragment); + ByteBuf data = RequestStreamFrameCodec.data(fragment); ByteBuf byteBuf = Unpooled.wrappedBuffer(this.data).readSlice(data.readableBytes()); Assert.assertEquals(byteBuf, data); - Assert.assertFalse(FrameHeaderFlyweight.hasMetadata(fragment)); + Assert.assertFalse(FrameHeaderCodec.hasMetadata(fragment)); } @DisplayName("encode first frame with only metadata") @Test void encodeFirstFrameWithMetadata() { ByteBuf rr = - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -219,22 +222,22 @@ void encodeFirstFrameWithMetadata() { Unpooled.EMPTY_BUFFER); Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - ByteBuf data = RequestResponseFrameFlyweight.data(fragment); + ByteBuf data = RequestResponseFrameCodec.data(fragment); Assert.assertEquals(data, Unpooled.EMPTY_BUFFER); - Assert.assertTrue(FrameHeaderFlyweight.hasMetadata(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); } @DisplayName("encode first stream frame with data and metadata") @Test void encodeFirstWithDataAndMetadataStream() { ByteBuf rc = - RequestStreamFrameFlyweight.encode( - allocator, 1, true, 50, DefaultPayload.create(data, metadata)); + RequestStreamFrameCodec.encode( + allocator, 1, true, 50, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); ByteBuf fragment = FrameFragmenter.encodeFirstFragment( @@ -247,26 +250,26 @@ void encodeFirstWithDataAndMetadataStream() { Unpooled.wrappedBuffer(data)); Assert.assertEquals(256, fragment.readableBytes()); - Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderFlyweight.frameType(fragment)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(fragment)); - Assert.assertEquals(50, RequestStreamFrameFlyweight.initialRequestN(fragment)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(fragment)); + Assert.assertEquals(FrameType.REQUEST_STREAM, FrameHeaderCodec.frameType(fragment)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(fragment)); + Assert.assertEquals(50, RequestStreamFrameCodec.initialRequestN(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(fragment)); - ByteBuf data = RequestStreamFrameFlyweight.data(fragment); + ByteBuf data = RequestStreamFrameCodec.data(fragment); Assert.assertEquals(0, data.readableBytes()); - ByteBuf metadata = RequestStreamFrameFlyweight.metadata(fragment); + ByteBuf metadata = RequestStreamFrameCodec.metadata(fragment); ByteBuf byteBuf = Unpooled.wrappedBuffer(this.metadata).readSlice(metadata.readableBytes()); Assert.assertEquals(byteBuf, metadata); - Assert.assertTrue(FrameHeaderFlyweight.hasMetadata(fragment)); + Assert.assertTrue(FrameHeaderCodec.hasMetadata(fragment)); } @DisplayName("fragments frame with only data") @Test void fragmentData() { ByteBuf rr = - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)); + RequestResponseFrameCodec.encode(allocator, 1, true, null, Unpooled.wrappedBuffer(data)); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); @@ -275,15 +278,15 @@ void fragmentData() { .expectNextCount(1) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); }) .expectNextCount(2) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); }) .verifyComplete(); } @@ -292,12 +295,8 @@ void fragmentData() { @Test void fragmentMetadata() { ByteBuf rr = - RequestStreamFrameFlyweight.encode( - allocator, - 1, - true, - 10, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))); + RequestStreamFrameCodec.encode( + allocator, 1, true, 10, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_STREAM, false); @@ -306,15 +305,15 @@ void fragmentMetadata() { .expectNextCount(1) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertEquals(1, FrameHeaderFlyweight.streamId(byteBuf)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertEquals(1, FrameHeaderCodec.streamId(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); }) .expectNextCount(2) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); }) .verifyComplete(); } @@ -323,8 +322,8 @@ void fragmentMetadata() { @Test void fragmentDataAndMetadata() { ByteBuf rr = - RequestResponseFrameFlyweight.encode( - allocator, 1, true, DefaultPayload.create(data, metadata)); + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.wrappedBuffer(data)); Publisher fragments = FrameFragmenter.fragmentFrame(allocator, 1024, rr, FrameType.REQUEST_RESPONSE, false); @@ -332,20 +331,19 @@ void fragmentDataAndMetadata() { StepVerifier.create(Flux.from(fragments).doOnError(Throwable::printStackTrace)) .assertNext( byteBuf -> { - Assert.assertEquals( - FrameType.REQUEST_RESPONSE, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.REQUEST_RESPONSE, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); }) .expectNextCount(6) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertTrue(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertTrue(FrameHeaderCodec.hasFollows(byteBuf)); }) .assertNext( byteBuf -> { - Assert.assertEquals(FrameType.NEXT, FrameHeaderFlyweight.frameType(byteBuf)); - Assert.assertFalse(FrameHeaderFlyweight.hasFollows(byteBuf)); + Assert.assertEquals(FrameType.NEXT, FrameHeaderCodec.frameType(byteBuf)); + Assert.assertFalse(FrameHeaderCodec.hasFollows(byteBuf)); }) .verifyComplete(); } diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java index 6e0d0dc1b..56f7fcf90 100644 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java @@ -22,7 +22,6 @@ import io.netty.buffer.Unpooled; import io.netty.util.ReferenceCountUtil; import io.rsocket.frame.*; -import io.rsocket.util.DefaultPayload; import java.util.Arrays; import java.util.List; import java.util.concurrent.ThreadLocalRandom; @@ -48,15 +47,16 @@ final class FrameReassemblerTest { void reassembleData() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, true, false, true, DefaultPayload.create(data)), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); + RequestResponseFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -76,7 +76,7 @@ void reassembleData() { StepVerifier.create(assembled) .assertNext( byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); ReferenceCountUtil.safeRelease(byteBuf); }) .verifyComplete(); @@ -88,7 +88,8 @@ void reassembleData() { void passthrough() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode(allocator, 1, false, DefaultPayload.create(data))); + RequestResponseFrameCodec.encode( + allocator, 1, false, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -102,7 +103,7 @@ void passthrough() { StepVerifier.create(assembled) .assertNext( byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); ReferenceCountUtil.safeRelease(byteBuf); }) .verifyComplete(); @@ -114,39 +115,40 @@ void passthrough() { void reassembleMetadata() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -167,7 +169,7 @@ void reassembleMetadata() { .assertNext( byteBuf -> { System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestResponseFrameFlyweight.metadata(byteBuf); + ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); Assert.assertEquals(metadata, m); }) .verifyComplete(); @@ -178,41 +180,46 @@ void reassembleMetadata() { void reassembleMetadataChannel() { List byteBufs = Arrays.asList( - RequestChannelFrameFlyweight.encode( + RequestChannelFrameCodec.encode( allocator, 1, true, false, 100, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -233,9 +240,9 @@ void reassembleMetadataChannel() { .assertNext( byteBuf -> { System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestChannelFrameFlyweight.metadata(byteBuf); + ByteBuf m = RequestChannelFrameCodec.metadata(byteBuf); Assert.assertEquals(metadata, m); - Assert.assertEquals(100, RequestChannelFrameFlyweight.initialRequestN(byteBuf)); + Assert.assertEquals(100, RequestChannelFrameCodec.initialRequestN(byteBuf)); ReferenceCountUtil.safeRelease(byteBuf); }) .verifyComplete(); @@ -248,40 +255,40 @@ void reassembleMetadataChannel() { void reassembleMetadataStream() { List byteBufs = Arrays.asList( - RequestStreamFrameFlyweight.encode( - allocator, - 1, - true, - 250, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + RequestStreamFrameCodec.encode( + allocator, 1, true, 250, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, false, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -302,9 +309,9 @@ void reassembleMetadataStream() { .assertNext( byteBuf -> { System.out.println(byteBuf.readableBytes()); - ByteBuf m = RequestStreamFrameFlyweight.metadata(byteBuf); + ByteBuf m = RequestStreamFrameCodec.metadata(byteBuf); Assert.assertEquals(metadata, m); - Assert.assertEquals(250, RequestChannelFrameFlyweight.initialRequestN(byteBuf)); + Assert.assertEquals(250, RequestChannelFrameCodec.initialRequestN(byteBuf)); ReferenceCountUtil.safeRelease(byteBuf); }) .verifyComplete(); @@ -318,35 +325,34 @@ void reassembleMetadataAndData() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( - allocator, 1, false, false, true, DefaultPayload.create(data))); + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); @@ -373,8 +379,8 @@ void reassembleMetadataAndData() { StepVerifier.create(assembled) .assertNext( byteBuf -> { - Assert.assertEquals(data, RequestResponseFrameFlyweight.data(byteBuf)); - Assert.assertEquals(metadata, RequestResponseFrameFlyweight.metadata(byteBuf)); + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); }) .verifyComplete(); ReferenceCountUtil.safeRelease(data); @@ -386,33 +392,32 @@ void reassembleMetadataAndData() { public void cancelBeforeAssembling() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); @@ -421,7 +426,7 @@ public void cancelBeforeAssembling() { Assert.assertTrue(reassembler.metadata.containsKey(1)); Assert.assertTrue(reassembler.data.containsKey(1)); - Flux.just(CancelFrameFlyweight.encode(allocator, 1)) + Flux.just(CancelFrameCodec.encode(allocator, 1)) .handle(reassembler::reassembleFrame) .blockLast(); @@ -435,33 +440,32 @@ public void cancelBeforeAssembling() { public void dispose() { List byteBufs = Arrays.asList( - RequestResponseFrameFlyweight.encode( - allocator, - 1, - true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create(Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer(metadata))), - PayloadFrameFlyweight.encode( + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( allocator, 1, true, false, true, - DefaultPayload.create( - Unpooled.wrappedBuffer(data), Unpooled.wrappedBuffer(metadata)))); + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data))); FrameReassembler reassembler = new FrameReassembler(allocator); Flux.fromIterable(byteBufs).handle(reassembler::reassembleFrame).blockLast(); diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java new file mode 100644 index 000000000..b083d6841 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/fragmentation/ReassembleDuplexConnectionTest.java @@ -0,0 +1,272 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.fragmentation; + +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.Assert; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +final class ReassembleDuplexConnectionTest { + private static byte[] data = new byte[1024]; + private static byte[] metadata = new byte[1024]; + + static { + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + } + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @DisplayName("reassembles data") + @Test + void reassembleData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, true, false, true, null, Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata") + @Test + void reassembleMetadata() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + System.out.println(byteBuf.readableBytes()); + ByteBuf m = RequestResponseFrameCodec.metadata(byteBuf); + Assert.assertEquals(metadata, m); + }) + .verifyComplete(); + } + + @DisplayName("reassembles metadata and data") + @Test + void reassembleMetadataAndData() { + List byteBufs = + Arrays.asList( + RequestResponseFrameCodec.encode( + allocator, 1, true, Unpooled.wrappedBuffer(metadata), Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.EMPTY_BUFFER), + PayloadFrameCodec.encode( + allocator, + 1, + true, + false, + true, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)), + PayloadFrameCodec.encode( + allocator, 1, false, false, true, null, Unpooled.wrappedBuffer(data))); + + CompositeByteBuf data = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.data)); + + CompositeByteBuf metadata = + allocator + .compositeDirectBuffer() + .addComponents( + true, + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata), + Unpooled.wrappedBuffer(ReassembleDuplexConnectionTest.metadata)); + + when(delegate.receive()).thenReturn(Flux.fromIterable(byteBufs)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(data, RequestResponseFrameCodec.data(byteBuf)); + Assert.assertEquals(metadata, RequestResponseFrameCodec.metadata(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble a non-fragment frame") + @Test + void reassembleNonFragment() { + ByteBuf encode = + RequestResponseFrameCodec.encode(allocator, 1, false, null, Unpooled.wrappedBuffer(data)); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals( + Unpooled.wrappedBuffer(data), RequestResponseFrameCodec.data(byteBuf)); + }) + .verifyComplete(); + } + + @DisplayName("does not reassemble non fragmentable frame") + @Test + void reassembleNonFragmentableFrame() { + ByteBuf encode = CancelFrameCodec.encode(allocator, 2); + + when(delegate.receive()).thenReturn(Flux.just(encode)); + when(delegate.onClose()).thenReturn(Mono.never()); + when(delegate.alloc()).thenReturn(allocator); + + new ReassemblyDuplexConnection(delegate, false) + .receive() + .as(StepVerifier::create) + .assertNext( + byteBuf -> { + Assert.assertEquals(FrameType.CANCEL, FrameHeaderCodec.frameType(byteBuf)); + }) + .verifyComplete(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java index b22a95c0b..63300c718 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -17,6 +17,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; import org.assertj.core.presentation.StandardRepresentation; public final class ByteBufRepresentation extends StandardRepresentation { @@ -24,7 +25,17 @@ public final class ByteBufRepresentation extends StandardRepresentation { @Override protected String fallbackToStringOf(Object object) { if (object instanceof ByteBuf) { - return ByteBufUtil.prettyHexDump((ByteBuf) object); + try { + String normalBufferString = object.toString(); + String prettyHexDump = ByteBufUtil.prettyHexDump((ByteBuf) object); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } catch (IllegalReferenceCountException e) { + // noops + } } return super.fallbackToStringOf(object); diff --git a/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java deleted file mode 100644 index 6f9113d73..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/DataAndMetadataFlyweightTest.java +++ /dev/null @@ -1,51 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.*; -import org.junit.jupiter.api.Test; - -class DataAndMetadataFlyweightTest { - @Test - void testEncodeData() { - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.PAYLOAD, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf frame = DataAndMetadataFlyweight.encodeOnlyData(ByteBufAllocator.DEFAULT, header, data); - ByteBuf d = DataAndMetadataFlyweight.data(frame, false); - String s = ByteBufUtil.prettyHexDump(d); - System.out.println(s); - } - - @Test - void testEncodeMetadata() { - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.PAYLOAD, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf frame = - DataAndMetadataFlyweight.encodeOnlyMetadata(ByteBufAllocator.DEFAULT, header, data); - ByteBuf d = DataAndMetadataFlyweight.data(frame, false); - String s = ByteBufUtil.prettyHexDump(d); - System.out.println(s); - } - - @Test - void testEncodeDataAndMetadata() { - ByteBuf header = - FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.REQUEST_RESPONSE, 0); - ByteBuf data = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf metadata = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf frame = - DataAndMetadataFlyweight.encode(ByteBufAllocator.DEFAULT, header, metadata, data); - ByteBuf m = DataAndMetadataFlyweight.metadata(frame, true); - String s = ByteBufUtil.prettyHexDump(m); - System.out.println(s); - FrameType frameType = FrameHeaderFlyweight.frameType(frame); - System.out.println(frameType); - - for (int i = 0; i < 10_000_000; i++) { - ByteBuf d1 = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm data_"); - ByteBuf m1 = ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "_I'm metadata_"); - ByteBuf h1 = - FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 1, FrameType.REQUEST_RESPONSE, 0); - ByteBuf f1 = DataAndMetadataFlyweight.encode(ByteBufAllocator.DEFAULT, h1, m1, d1); - f1.release(); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java similarity index 65% rename from rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java index fa663432c..dc04c1141 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java @@ -8,13 +8,13 @@ import io.rsocket.exceptions.ApplicationErrorException; import org.junit.jupiter.api.Test; -class ErrorFrameFlyweightTest { +class ErrorFrameCodecTest { @Test void testEncode() { ByteBuf frame = - ErrorFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000b000000012c000000020164", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java new file mode 100644 index 000000000..28209393e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java @@ -0,0 +1,62 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtensionFrameCodecTest { + + @Test + void extensionDataMetadata() { + ByteBuf metadata = bytebuf("md"); + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionData() { + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertNull(ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionMetadata() { + ByteBuf metadata = bytebuf("md"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode( + ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(0, ExtensionFrameCodec.data(extension).readableBytes()); + extension.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java deleted file mode 100644 index e337d4332..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameFlyweightTest.java +++ /dev/null @@ -1,62 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class ExtensionFrameFlyweightTest { - - @Test - void extensionDataMetadata() { - ByteBuf metadata = bytebuf("md"); - ByteBuf data = bytebuf("d"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(metadata, ExtensionFrameFlyweight.metadata(extension)); - Assertions.assertEquals(data, ExtensionFrameFlyweight.data(extension)); - extension.release(); - } - - @Test - void extensionData() { - ByteBuf data = bytebuf("d"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); - - Assertions.assertFalse(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(0, ExtensionFrameFlyweight.metadata(extension).readableBytes()); - Assertions.assertEquals(data, ExtensionFrameFlyweight.data(extension)); - extension.release(); - } - - @Test - void extensionMetadata() { - ByteBuf metadata = bytebuf("md"); - int extendedType = 1; - - ByteBuf extension = - ExtensionFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(extension)); - Assertions.assertEquals(extendedType, ExtensionFrameFlyweight.extendedType(extension)); - Assertions.assertEquals(metadata, ExtensionFrameFlyweight.metadata(extension)); - Assertions.assertEquals(0, ExtensionFrameFlyweight.data(extension).readableBytes()); - extension.release(); - } - - private static ByteBuf bytebuf(String str) { - return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java similarity index 52% rename from rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java index a17fcc205..15788e631 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java @@ -7,7 +7,7 @@ import io.netty.buffer.ByteBufAllocator; import org.junit.jupiter.api.Test; -class FrameHeaderFlyweightTest { +class FrameHeaderCodecTest { // Taken from spec private static final int FRAME_MAX_SIZE = 16_777_215; @@ -15,10 +15,10 @@ class FrameHeaderFlyweightTest { void typeAndFlag() { FrameType frameType = FrameType.REQUEST_FNF; int flags = 0b1110110111; - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); - assertEquals(flags, FrameHeaderFlyweight.flags(header)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(header)); + assertEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); header.release(); } @@ -26,11 +26,11 @@ void typeAndFlag() { void typeAndFlagTruncated() { FrameType frameType = FrameType.SETUP; int flags = 0b11110110111; // 1 bit too many - ByteBuf header = FrameHeaderFlyweight.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); - assertNotEquals(flags, FrameHeaderFlyweight.flags(header)); - assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderFlyweight.flags(header)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(header)); + assertNotEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); header.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java new file mode 100644 index 000000000..ac19dc754 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java @@ -0,0 +1,264 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class GenericFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length + // | | | | ⌌Encoded Metadata + // | | | | | ⌌Encoded Data + // __|________|_________|______|____|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "000010000000011900000000010000026d6464"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithEmptyMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length (0) + // | | | | ⌌Encoded Data + // __|________|_________|_______|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "00000e0000000119000000000100000064"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithNullMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Data + // __|________|_________|_____| + // ↓<-> ↓↓ <-> ↓↓ <-> ↓↓↓ + String expected = "00000b0000000118000000000164"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void requestResponseDataMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestResponseData() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestResponseFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestResponseMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestResponseFrameCodec.data(request); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestStreamDataMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Integer.MAX_VALUE + 1L, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(Long.MAX_VALUE, actualRequest); + assertEquals("md", metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamData() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestStreamFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertNull(metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + ByteBuf data = RequestStreamFrameCodec.data(request); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfDataAndMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfData() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestFireAndForgetFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestFnfMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestFireAndForgetFrameCodec.data(request); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("md", metadata); + assertTrue(data.readableBytes() == 0); + request.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java index eb55e89cd..bc013e024 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java @@ -14,18 +14,18 @@ class KeepaliveFrameFlyweightTest { @Test void canReadData() { ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - ByteBuf frame = KeepAliveFrameFlyweight.encode(ByteBufAllocator.DEFAULT, true, 0, data); - assertTrue(KeepAliveFrameFlyweight.respondFlag(frame)); - assertEquals(data, KeepAliveFrameFlyweight.data(frame)); + ByteBuf frame = KeepAliveFrameCodec.encode(ByteBufAllocator.DEFAULT, true, 0, data); + assertTrue(KeepAliveFrameCodec.respondFlag(frame)); + assertEquals(data, KeepAliveFrameCodec.data(frame)); frame.release(); } @Test void testEncoding() { ByteBuf frame = - KeepAliveFrameFlyweight.encode( + KeepAliveFrameCodec.encode( ByteBufAllocator.DEFAULT, true, 0, Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000f000000000c80000000000000000064", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java new file mode 100644 index 000000000..73c3bde5e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java @@ -0,0 +1,42 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class LeaseFrameCodecTest { + + @Test + void leaseMetadata() { + ByteBuf metadata = bytebuf("md"); + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertEquals(metadata, LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + @Test + void leaseAbsentMetadata() { + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertNull(LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java deleted file mode 100644 index 0fc0c112b..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java +++ /dev/null @@ -1,43 +0,0 @@ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class LeaseFrameFlyweightTest { - - @Test - void leaseMetadata() { - ByteBuf metadata = bytebuf("md"); - int ttl = 1; - int numRequests = 42; - ByteBuf lease = - LeaseFrameFlyweight.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); - - Assertions.assertTrue(FrameHeaderFlyweight.hasMetadata(lease)); - Assertions.assertEquals(ttl, LeaseFrameFlyweight.ttl(lease)); - Assertions.assertEquals(numRequests, LeaseFrameFlyweight.numRequests(lease)); - Assertions.assertEquals(metadata, LeaseFrameFlyweight.metadata(lease)); - lease.release(); - } - - @Test - void leaseAbsentMetadata() { - int ttl = 1; - int numRequests = 42; - ByteBuf lease = LeaseFrameFlyweight.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); - - Assertions.assertFalse(FrameHeaderFlyweight.hasMetadata(lease)); - Assertions.assertEquals(ttl, LeaseFrameFlyweight.ttl(lease)); - Assertions.assertEquals(numRequests, LeaseFrameFlyweight.numRequests(lease)); - Assertions.assertEquals(0, LeaseFrameFlyweight.metadata(lease).readableBytes()); - lease.release(); - } - - private static ByteBuf bytebuf(String str) { - return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java index 9ef89326a..aecbb31ce 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java @@ -15,9 +15,9 @@ public class PayloadFlyweightTest { void nextCompleteDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); - String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); Assertions.assertEquals("md", metadata); nextComplete.release(); @@ -27,11 +27,11 @@ void nextCompleteDataMetadata() { void nextCompleteData() { Payload payload = DefaultPayload.create("d"); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(nextComplete).toString(StandardCharsets.UTF_8); - ByteBuf metadata = PayloadFrameFlyweight.metadata(nextComplete); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(nextComplete); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); nextComplete.release(); } @@ -42,9 +42,9 @@ void nextCompleteMetaData() { Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer("md".getBytes(StandardCharsets.UTF_8))); ByteBuf nextComplete = - PayloadFrameFlyweight.encodeNextComplete(ByteBufAllocator.DEFAULT, 1, payload); - ByteBuf data = PayloadFrameFlyweight.data(nextComplete); - String metadata = PayloadFrameFlyweight.metadata(nextComplete).toString(StandardCharsets.UTF_8); + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf data = PayloadFrameCodec.data(nextComplete); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); Assertions.assertTrue(data.readableBytes() == 0); Assertions.assertEquals("md", metadata); nextComplete.release(); @@ -53,9 +53,10 @@ void nextCompleteMetaData() { @Test void nextDataMetadata() { Payload payload = DefaultPayload.create("d", "md"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); - String metadata = PayloadFrameFlyweight.metadata(next).toString(StandardCharsets.UTF_8); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(next).toString(StandardCharsets.UTF_8); Assertions.assertEquals("d", data); Assertions.assertEquals("md", metadata); next.release(); @@ -64,11 +65,24 @@ void nextDataMetadata() { @Test void nextData() { Payload payload = DefaultPayload.create("d"); - ByteBuf next = PayloadFrameFlyweight.encodeNext(ByteBufAllocator.DEFAULT, 1, payload); - String data = PayloadFrameFlyweight.data(next).toString(StandardCharsets.UTF_8); - ByteBuf metadata = PayloadFrameFlyweight.metadata(next); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); Assertions.assertEquals("d", data); - Assertions.assertTrue(metadata.readableBytes() == 0); + Assertions.assertNull(metadata); + next.release(); + } + + @Test + void nextDataEmptyMetadata() { + Payload payload = DefaultPayload.create("d".getBytes(), new byte[0]); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertEquals(metadata.readableBytes(), 0); next.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java deleted file mode 100644 index 9acec2c81..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestFlyweightTest.java +++ /dev/null @@ -1,249 +0,0 @@ -package io.rsocket.frame; - -import static org.junit.jupiter.api.Assertions.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.jupiter.api.Test; - -class RequestFlyweightTest { - @Test - void testEncoding() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("000010000000011900000000010000026d6464", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void testEncodingWithEmptyMetadata() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - Unpooled.EMPTY_BUFFER, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("00000e0000000119000000000100000064", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void testEncodingWithNullMetadata() { - ByteBuf frame = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 1, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); - - assertEquals("00000b0000000118000000000164", ByteBufUtil.hexDump(frame)); - frame.release(); - } - - @Test - void requestResponseDataMetadata() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestResponseFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestResponseFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestResponseData() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestResponseFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestResponseFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); - request.release(); - } - - @Test - void requestResponseMetadata() { - ByteBuf request = - RequestResponseFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - ByteBuf data = RequestResponseFrameFlyweight.data(request); - String metadata = - RequestResponseFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertTrue(data.readableBytes() == 0); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestStreamDataMetadata() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Integer.MAX_VALUE + 1L, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(Integer.MAX_VALUE, actualRequest); - assertEquals("md", metadata); - assertEquals("d", data); - request.release(); - } - - @Test - void requestStreamData() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 42, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - String data = RequestStreamFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestStreamFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); - assertTrue(metadata.readableBytes() == 0); - assertEquals("d", data); - request.release(); - } - - @Test - void requestStreamMetadata() { - ByteBuf request = - RequestStreamFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - 42, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - int actualRequest = RequestStreamFrameFlyweight.initialRequestN(request); - ByteBuf data = RequestStreamFrameFlyweight.data(request); - String metadata = - RequestStreamFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals(42, actualRequest); - assertTrue(data.readableBytes() == 0); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestFnfDataAndMetadata() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestFireAndForgetFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - String metadata = - RequestFireAndForgetFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertEquals("md", metadata); - request.release(); - } - - @Test - void requestFnfData() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - - String data = RequestFireAndForgetFrameFlyweight.data(request).toString(StandardCharsets.UTF_8); - ByteBuf metadata = RequestFireAndForgetFrameFlyweight.metadata(request); - - assertFalse(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("d", data); - assertTrue(metadata.readableBytes() == 0); - request.release(); - } - - @Test - void requestFnfMetadata() { - ByteBuf request = - RequestFireAndForgetFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, - 1, - false, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.EMPTY_BUFFER); - - ByteBuf data = RequestFireAndForgetFrameFlyweight.data(request); - String metadata = - RequestFireAndForgetFrameFlyweight.metadata(request).toString(StandardCharsets.UTF_8); - - assertTrue(FrameHeaderFlyweight.hasMetadata(request)); - assertEquals("md", metadata); - assertTrue(data.readableBytes() == 0); - request.release(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java similarity index 63% rename from rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java index 4411b98c9..e38258040 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java @@ -7,12 +7,12 @@ import io.netty.buffer.ByteBufUtil; import org.junit.jupiter.api.Test; -class RequestNFrameFlyweightTest { +class RequestNFrameCodecTest { @Test void testEncoding() { - ByteBuf frame = RequestNFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 1, 5); + ByteBuf frame = RequestNFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, 5); - frame = FrameLengthFlyweight.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); assertEquals("00000a00000001200000000005", ByteBufUtil.hexDump(frame)); frame.release(); } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java similarity index 68% rename from rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java index f8b481f05..fe05335d2 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -23,19 +23,18 @@ import org.junit.Assert; import org.junit.jupiter.api.Test; -public class ResumeFrameFlyweightTest { +public class ResumeFrameCodecTest { @Test void testEncoding() { byte[] tokenBytes = new byte[65000]; Arrays.fill(tokenBytes, (byte) 1); ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); - ByteBuf byteBuf = ResumeFrameFlyweight.encode(ByteBufAllocator.DEFAULT, token, 21, 12); - Assert.assertEquals( - ResumeFrameFlyweight.CURRENT_VERSION, ResumeFrameFlyweight.version(byteBuf)); - Assert.assertEquals(token, ResumeFrameFlyweight.token(byteBuf)); - Assert.assertEquals(21, ResumeFrameFlyweight.lastReceivedServerPos(byteBuf)); - Assert.assertEquals(12, ResumeFrameFlyweight.firstAvailableClientPos(byteBuf)); + ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); + Assert.assertEquals(ResumeFrameCodec.CURRENT_VERSION, ResumeFrameCodec.version(byteBuf)); + Assert.assertEquals(token, ResumeFrameCodec.token(byteBuf)); + Assert.assertEquals(21, ResumeFrameCodec.lastReceivedServerPos(byteBuf)); + Assert.assertEquals(12, ResumeFrameCodec.firstAvailableClientPos(byteBuf)); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java similarity index 51% rename from rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java index c73409ffa..33dd8eb70 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -5,12 +5,12 @@ import org.junit.Assert; import org.junit.Test; -public class ResumeOkFrameFlyweightTest { +public class ResumeOkFrameCodecTest { @Test public void testEncoding() { - ByteBuf byteBuf = ResumeOkFrameFlyweight.encode(ByteBufAllocator.DEFAULT, 42); - Assert.assertEquals(42, ResumeOkFrameFlyweight.lastReceivedClientPos(byteBuf)); + ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); + Assert.assertEquals(42, ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)); byteBuf.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java new file mode 100644 index 000000000..9607ad327 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java @@ -0,0 +1,57 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +class SetupFrameCodecTest { + @Test + void testEncodingNoResume() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, false, 5, 500, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertFalse(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(0, SetupFrameCodec.resumeToken(frame).readableBytes()); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(metadata, SetupFrameCodec.metadata(frame)); + assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } + + @Test + void testEncodingResume() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 5, 500, token, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertTrue(SetupFrameCodec.honorLease(frame)); + assertTrue(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(token, SetupFrameCodec.resumeToken(frame)); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(metadata, SetupFrameCodec.metadata(frame)); + assertEquals(data, SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java deleted file mode 100644 index 128b3ff84..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java +++ /dev/null @@ -1,57 +0,0 @@ -package io.rsocket.frame; - -import static org.junit.jupiter.api.Assertions.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.rsocket.Payload; -import io.rsocket.util.DefaultPayload; -import java.util.Arrays; -import org.junit.jupiter.api.Test; - -class SetupFrameFlyweightTest { - @Test - void testEncodingNoResume() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - Payload payload = DefaultPayload.create(data, metadata); - ByteBuf frame = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, false, 5, 500, "metadata_type", "data_type", payload); - - assertEquals(FrameType.SETUP, FrameHeaderFlyweight.frameType(frame)); - assertFalse(SetupFrameFlyweight.resumeEnabled(frame)); - assertNull(SetupFrameFlyweight.resumeToken(frame)); - assertEquals("metadata_type", SetupFrameFlyweight.metadataMimeType(frame)); - assertEquals("data_type", SetupFrameFlyweight.dataMimeType(frame)); - assertEquals(metadata, SetupFrameFlyweight.metadata(frame)); - assertEquals(data, SetupFrameFlyweight.data(frame)); - assertEquals(SetupFrameFlyweight.CURRENT_VERSION, SetupFrameFlyweight.version(frame)); - frame.release(); - } - - @Test - void testEncodingResume() { - byte[] tokenBytes = new byte[65000]; - Arrays.fill(tokenBytes, (byte) 1); - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - Payload payload = DefaultPayload.create(data, metadata); - ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); - ByteBuf frame = - SetupFrameFlyweight.encode( - ByteBufAllocator.DEFAULT, true, 5, 500, token, "metadata_type", "data_type", payload); - - assertEquals(FrameType.SETUP, FrameHeaderFlyweight.frameType(frame)); - assertTrue(SetupFrameFlyweight.honorLease(frame)); - assertTrue(SetupFrameFlyweight.resumeEnabled(frame)); - assertEquals(token, SetupFrameFlyweight.resumeToken(frame)); - assertEquals("metadata_type", SetupFrameFlyweight.metadataMimeType(frame)); - assertEquals("data_type", SetupFrameFlyweight.dataMimeType(frame)); - assertEquals(metadata, SetupFrameFlyweight.metadata(frame)); - assertEquals(data, SetupFrameFlyweight.data(frame)); - assertEquals(SetupFrameFlyweight.CURRENT_VERSION, SetupFrameFlyweight.version(frame)); - frame.release(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java similarity index 58% rename from rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java rename to rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java index 3f311c7ef..be7fb837b 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java @@ -20,29 +20,29 @@ import org.junit.jupiter.api.Test; -public class VersionFlyweightTest { +public class VersionCodecTest { @Test public void simple() { - int version = VersionFlyweight.encode(1, 0); - assertEquals(1, VersionFlyweight.major(version)); - assertEquals(0, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(1, 0); + assertEquals(1, VersionCodec.major(version)); + assertEquals(0, VersionCodec.minor(version)); assertEquals(0x00010000, version); - assertEquals("1.0", VersionFlyweight.toString(version)); + assertEquals("1.0", VersionCodec.toString(version)); } @Test public void complex() { - int version = VersionFlyweight.encode(0x1234, 0x5678); - assertEquals(0x1234, VersionFlyweight.major(version)); - assertEquals(0x5678, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(0x1234, 0x5678); + assertEquals(0x1234, VersionCodec.major(version)); + assertEquals(0x5678, VersionCodec.minor(version)); assertEquals(0x12345678, version); - assertEquals("4660.22136", VersionFlyweight.toString(version)); + assertEquals("4660.22136", VersionCodec.toString(version)); } @Test public void noShortOverflow() { - int version = VersionFlyweight.encode(43210, 43211); - assertEquals(43210, VersionFlyweight.major(version)); - assertEquals(43211, VersionFlyweight.minor(version)); + int version = VersionCodec.encode(43210, 43211); + assertEquals(43210, VersionCodec.major(version)); + assertEquals(43211, VersionCodec.minor(version)); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java index 8f56608d8..63acc40aa 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java @@ -21,8 +21,9 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.frame.*; -import io.rsocket.plugins.PluginRegistry; +import io.rsocket.plugins.InitializingInterceptorRegistry; import io.rsocket.test.util.TestDuplexConnection; import io.rsocket.util.DefaultPayload; import java.util.concurrent.atomic.AtomicInteger; @@ -32,14 +33,17 @@ public class ClientServerInputMultiplexerTest { private TestDuplexConnection source; private ClientServerInputMultiplexer clientMultiplexer; - private ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); private ClientServerInputMultiplexer serverMultiplexer; @Before public void setup() { - source = new TestDuplexConnection(); - clientMultiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry(), true); - serverMultiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry(), false); + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); } @Test @@ -189,11 +193,11 @@ public void serverSplits() { } private ByteBuf resumeFrame() { - return ResumeFrameFlyweight.encode(allocator, Unpooled.EMPTY_BUFFER, 0, 0); + return ResumeFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER, 0, 0); } private ByteBuf setupFrame() { - return SetupFrameFlyweight.encode( + return SetupFrameCodec.encode( ByteBufAllocator.DEFAULT, false, 0, @@ -204,22 +208,22 @@ private ByteBuf setupFrame() { } private ByteBuf leaseFrame() { - return LeaseFrameFlyweight.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); } private ByteBuf errorFrame(int i) { - return ErrorFrameFlyweight.encode(allocator, i, new Exception()); + return ErrorFrameCodec.encode(allocator, i, new Exception()); } private ByteBuf resumeOkFrame() { - return ResumeOkFrameFlyweight.encode(allocator, 0); + return ResumeOkFrameCodec.encode(allocator, 0); } private ByteBuf keepAliveFrame() { - return KeepAliveFrameFlyweight.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); } private ByteBuf metadataPushFrame() { - return MetadataPushFrameFlyweight.encode(allocator, Unpooled.EMPTY_BUFFER); + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); } } diff --git a/rsocket-core/src/test/java/io/rsocket/internal/LimitableRequestPublisherTest.java b/rsocket-core/src/test/java/io/rsocket/internal/LimitableRequestPublisherTest.java deleted file mode 100644 index 8c51c123e..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/LimitableRequestPublisherTest.java +++ /dev/null @@ -1,33 +0,0 @@ -package io.rsocket.internal; - -import java.util.ArrayDeque; -import java.util.Queue; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.RepeatedTest; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.DirectProcessor; -import reactor.test.util.RaceTestUtils; - -class LimitableRequestPublisherTest { - - @Test - @RepeatedTest(2) - public void requestLimitRacingTest() throws InterruptedException { - Queue requests = new ArrayDeque<>(10000); - LimitableRequestPublisher limitableRequestPublisher = - LimitableRequestPublisher.wrap(DirectProcessor.create().doOnRequest(requests::add), 0); - - Runnable request1 = () -> limitableRequestPublisher.request(1); - Runnable request2 = () -> limitableRequestPublisher.increaseInternalLimit(2); - - limitableRequestPublisher.subscribe(); - - for (int i = 0; i < 10000; i++) { - RaceTestUtils.race(request1, request2); - } - - Thread.sleep(1000); - - Assertions.assertThat(requests.stream().mapToLong(l -> l).sum()).isEqualTo(10000); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java b/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java deleted file mode 100644 index af4c528e9..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/RateLimitableRequestPublisherTest.java +++ /dev/null @@ -1,140 +0,0 @@ -package io.rsocket.internal; - -import static org.junit.jupiter.api.Assertions.*; - -import java.time.Duration; -import java.util.concurrent.ThreadLocalRandom; -import java.util.function.Consumer; -import org.assertj.core.api.Assertions; -import org.junit.jupiter.api.Test; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; - -class RateLimitableRequestPublisherTest { - - @Test - public void testThatRequest1WillBePropagatedUpstream() { - Flux source = - Flux.just(1) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(1)) - .expectNext(1) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRate() { - Flux source = - Flux.range(0, 256) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(256)) - .expectNextCount(256) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequest256WillBePropagatedToUpstreamWithLimitedRateInFewSteps() { - Flux source = - Flux.range(0, 256) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(10)) - .expectNextCount(5) - .then(() -> rateLimitableRequestPublisher.request(128)) - .expectNextCount(133) - .expectNoEvent(Duration.ofMillis(10)) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(118) - .expectComplete() - .verify(Duration.ofMillis(1000)); - } - - @Test - public void testThatRequestInRandomFashionWillBePropagatedToUpstreamWithLimitedRateInFewSteps() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then( - () -> - Flux.interval(Duration.ofMillis(1000)) - .onBackpressureDrop() - .subscribe( - new Consumer() { - int count = 10000000; - - @Override - public void accept(Long __) { - int random = ThreadLocalRandom.current().nextInt(1, 512); - - long request = Math.min(random, count); - - count -= request; - - rateLimitableRequestPublisher.request(count); - } - })) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } - - @Test - public void testThatRequestLongMaxValueWillBeDeliveredInSeparateChunks() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isLessThanOrEqualTo(128)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, 128); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } - - @Test - public void testThatRequestLongMaxWithIntegerMaxValuePrefetchWillBeDeliveredAsLongMaxValue() { - Flux source = - Flux.range(0, 10000000) - .subscribeOn(Schedulers.parallel()) - .doOnRequest(r -> Assertions.assertThat(r).isEqualTo(Long.MAX_VALUE)); - - RateLimitableRequestPublisher rateLimitableRequestPublisher = - RateLimitableRequestPublisher.wrap(source, Integer.MAX_VALUE); - - StepVerifier.create(rateLimitableRequestPublisher) - .then(() -> rateLimitableRequestPublisher.request(Long.MAX_VALUE)) - .expectNextCount(10000000) - .expectComplete() - .verify(Duration.ofMillis(30000)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java new file mode 100644 index 000000000..d73f92b85 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java @@ -0,0 +1,23 @@ +package io.rsocket.internal; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import reactor.core.scheduler.Scheduler; + +public class SchedulerUtils { + + public static void warmup(Scheduler scheduler) throws InterruptedException { + warmup(scheduler, 10000); + } + + public static void warmup(Scheduler scheduler, int times) throws InterruptedException { + scheduler.start(); + + // warm up + CountDownLatch latch = new CountDownLatch(times); + for (int i = 0; i < times; i++) { + scheduler.schedule(latch::countDown); + } + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java b/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java deleted file mode 100644 index 07fbf695f..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/SwitchTransformFluxTest.java +++ /dev/null @@ -1,446 +0,0 @@ -package io.rsocket.internal; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasItem; - -import java.time.Duration; -import java.util.ArrayList; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.Ignore; -import org.junit.Test; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; -import reactor.test.publisher.TestPublisher; -import reactor.test.util.RaceTestUtils; -import reactor.util.context.Context; - -@Ignore -public class SwitchTransformFluxTest { - - @Test - public void shouldBeAbleToCancelSubscription() throws InterruptedException { - for (int j = 0; j < 10; j++) { - ArrayList capturedElements = new ArrayList<>(); - ArrayList capturedCompletions = new ArrayList<>(); - for (int i = 0; i < 1000; i++) { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong captureElement = new AtomicLong(0L); - AtomicBoolean captureCompletion = new AtomicBoolean(false); - AtomicLong requested = new AtomicLong(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .doOnCancel(latch::countDown) - .transform( - flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); - - publisher.next(1L); - - switchTransformed.subscribe( - captureElement::set, - __ -> {}, - () -> captureCompletion.set(true), - s -> - new Thread( - () -> - RaceTestUtils.race( - publisher::complete, - () -> - RaceTestUtils.race( - s::cancel, () -> s.request(1), Schedulers.parallel()), - Schedulers.parallel())) - .start()); - - Assert.assertTrue(latch.await(5, TimeUnit.SECONDS)); - Assert.assertEquals(requested.get(), 1L); - capturedElements.add(captureElement.get()); - capturedCompletions.add(captureCompletion.get()); - } - - Assume.assumeThat(capturedElements, hasItem(equalTo(0L))); - Assume.assumeThat(capturedCompletions, hasItem(equalTo(false))); - } - } - - @Test - public void shouldRequestExpectedAmountOfElements() throws InterruptedException { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong capture = new AtomicLong(); - AtomicLong requested = new AtomicLong(); - CountDownLatch latch = new CountDownLatch(1); - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)); - - publisher.next(1L); - - switchTransformed.subscribe( - capture::set, - __ -> {}, - latch::countDown, - s -> { - for (int i = 0; i < 10000; i++) { - RaceTestUtils.race(() -> s.request(1), () -> s.request(1)); - } - RaceTestUtils.race(publisher::complete, publisher::complete); - }); - - latch.await(5, TimeUnit.SECONDS); - - Assert.assertEquals(capture.get(), 1L); - Assert.assertEquals(requested.get(), 20000L); - } - - @Test - public void shouldReturnCorrectContextOnEmptySource() { - Flux switchTransformed = - Flux.empty() - .transform(flux -> new SwitchTransformFlux<>(flux, (first, innerFlux) -> innerFlux)) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - StepVerifier.create(switchTransformed, 0) - .expectSubscription() - .thenRequest(1) - .expectAccessibleContext() - .contains("a", "c") - .contains("c", "d") - .then() - .expectComplete() - .verify(); - } - - @Test - public void shouldNotFailOnIncorrectPublisherBehavior() { - TestPublisher publisher = - TestPublisher.createNoncompliant(TestPublisher.Violation.CLEANUP_ON_TERMINATE); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> innerFlux.subscriberContext(Context.of("a", "b")))); - - StepVerifier.create( - new Flux() { - @Override - public void subscribe(CoreSubscriber actual) { - switchTransformed.subscribe(actual); - publisher.next(1L); - } - }, - 0) - .thenRequest(1) - .expectNext(1L) - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext(2L) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .then(() -> publisher.error(new RuntimeException())) - .expectError() - .verifyThenAssertThat() - .hasDroppedErrors(3) - .tookLessThan(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - // @Test - // public void shouldNotFailOnIncorrePu - - @Test - public void shouldBeAbleToAccessUpstreamContext() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext("2") - .expectAccessibleContext() - .contains("a", "b") - .contains("c", "d") - .then() - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldNotHangWhenOneElementUpstream() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> - innerFlux.map(String::valueOf).subscriberContext(Context.of("a", "b")))) - .subscriberContext(Context.of("a", "c")) - .subscriberContext(Context.of("c", "d")); - - publisher.next(1L); - publisher.complete(); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void backpressureTest() { - TestPublisher publisher = TestPublisher.createCold(); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .flux() - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .thenRequest(1) - .then(() -> publisher.next(2L)) - .expectNext("2") - .then(publisher::complete) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - - Assert.assertEquals(2L, requested.get()); - } - - @Test - public void backpressureConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(2L, requested.get()); - } - - @Test - public void backpressureHiddenConditionalTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf).hide())) - .filter(e -> false); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(10001L, requested.get()); - } - - @Test - public void backpressureDrawbackOnConditionalInTransformTest() { - Flux publisher = Flux.range(0, 10000); - AtomicLong requested = new AtomicLong(); - - Flux switchTransformed = - publisher - .doOnRequest(requested::addAndGet) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, - (first, innerFlux) -> innerFlux.map(String::valueOf).filter(e -> false))); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectComplete() - .verify(Duration.ofSeconds(10)); - - Assert.assertEquals(10001L, requested.get()); - } - - @Test - public void shouldErrorOnOverflowTest() { - TestPublisher publisher = TestPublisher.createCold(); - - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1L); - - StepVerifier.create(switchTransformed, 0) - .thenRequest(1) - .expectNext("1") - .then(() -> publisher.next(2L)) - .expectError() - .verify(Duration.ofSeconds(10)); - - publisher.assertWasRequested(); - publisher.assertNoRequestOverflow(); - } - - @Test - public void shouldPropagateonCompleteCorrectly() { - Flux switchTransformed = - Flux.empty() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - StepVerifier.create(switchTransformed).expectComplete().verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldPropagateErrorCorrectly() { - Flux switchTransformed = - Flux.error(new RuntimeException("hello")) - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - StepVerifier.create(switchTransformed) - .expectErrorMessage("hello") - .verify(Duration.ofSeconds(10)); - } - - @Test - public void shouldBeAbleToBeCancelledProperly() { - TestPublisher publisher = TestPublisher.createCold(); - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - } - - @Test - public void shouldBeAbleToCatchDiscardedElement() { - TestPublisher publisher = TestPublisher.createCold(); - Integer[] discarded = new Integer[1]; - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .doOnDiscard(Integer.class, e -> discarded[0] = e); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - - Assert.assertArrayEquals(new Integer[] {1}, discarded); - } - - @Test - public void shouldBeAbleToCatchDiscardedElementInCaseOfConditional() { - TestPublisher publisher = TestPublisher.createCold(); - Integer[] discarded = new Integer[1]; - Flux switchTransformed = - publisher - .flux() - .transform( - flux -> - new SwitchTransformFlux<>( - flux, (first, innerFlux) -> innerFlux.map(String::valueOf))) - .filter(t -> true) - .doOnDiscard(Integer.class, e -> discarded[0] = e); - - publisher.next(1); - - StepVerifier.create(switchTransformed, 0).thenCancel().verify(Duration.ofSeconds(10)); - - publisher.assertCancelled(); - publisher.assertWasRequested(); - - Assert.assertArrayEquals(new Integer[] {1}, discarded); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java index 0dc7d9090..7bf975543 100644 --- a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -17,6 +17,7 @@ package io.rsocket.internal; import io.rsocket.Payload; +import io.rsocket.util.ByteBufPayload; import io.rsocket.util.EmptyPayload; import java.util.concurrent.CountDownLatch; import org.junit.Assert; @@ -82,6 +83,36 @@ public void testOnNextAfterSubscribe_1000() throws Exception { testOnNextAfterSubscribeN(1000); } + @Test + public void testPrioritizedSending() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.next().block(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + + @Test + public void testPrioritizedFused() { + UnboundedProcessor processor = new UnboundedProcessor<>(); + + for (int i = 0; i < 1000; i++) { + processor.onNext(EmptyPayload.INSTANCE); + } + + processor.onNextPrioritized(ByteBufPayload.create("test")); + + Payload closestPayload = processor.poll(); + + Assert.assertEquals(closestPayload.getDataUtf8(), "test"); + } + public void testOnNextAfterSubscribeN(int n) throws Exception { CountDownLatch latch = new CountDownLatch(n); UnboundedProcessor processor = new UnboundedProcessor<>(); diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java new file mode 100644 index 000000000..84a589a8d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -0,0 +1,1154 @@ +/* + * Copyright (c) 2011-2017 Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.subscriber; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.context.Context; + +/** + * A Subscriber implementation that hosts assertion tests for its state and allows asynchronous + * cancellation and requesting. + * + *

To create a new instance of {@link AssertSubscriber}, you have the choice between these static + * methods: + * + *

    + *
  • {@link AssertSubscriber#create()}: create a new {@link AssertSubscriber} and requests an + * unbounded number of elements. + *
  • {@link AssertSubscriber#create(long)}: create a new {@link AssertSubscriber} and requests + * {@code n} elements (can be 0 if you want no initial demand). + *
+ * + *

If you are testing asynchronous publishers, don't forget to use one of the {@code await*()} + * methods to wait for the data to assert. + * + *

You can extend this class but only the onNext, onError and onComplete can be overridden. You + * can call {@link #request(long)} and {@link #cancel()} from any thread or from within the + * overridable methods but you should avoid calling the assertXXX methods asynchronously. + * + *

Usage: + * + *

{@code
+ * AssertSubscriber
+ *   .subscribe(publisher)
+ *   .await()
+ *   .assertValues("ABC", "DEF");
+ * }
+ * + * @param the value type. + * @author Sebastien Deleuze + * @author David Karnok + * @author Anatoly Kadyshev + * @author Stephane Maldini + * @author Brian Clozel + */ +public class AssertSubscriber implements CoreSubscriber, Subscription { + + /** Default timeout for waiting next values to be received */ + public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater NEXT_VALUES = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, Subscription.class, "s"); + + private final Context context; + + private final List errors = new LinkedList<>(); + + private final CountDownLatch cdl = new CountDownLatch(1); + + volatile Subscription s; + + volatile long requested; + + volatile List values = new LinkedList<>(); + + /** The fusion mode to request. */ + private int requestedFusionMode = -1; + + /** The established fusion mode. */ + private volatile int establishedFusionMode = -1; + + /** The fuseable QueueSubscription in case a fusion mode was specified. */ + private Fuseable.QueueSubscription qs; + + private int subscriptionCount = 0; + + private int completionCount = 0; + + private volatile long valueCount = 0L; + + private volatile long nextValueAssertedCount = 0L; + + private Duration valuesTimeout = DEFAULT_VALUES_TIMEOUT; + + private boolean valuesStorage = true; + + // + // ============================================================================================================== + // Static methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throws an {@link AssertionError} with the specified error message + * supplier. + * + * @param timeout the timeout duration + * @param errorMessageSupplier the error message supplier + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, Supplier errorMessageSupplier, BooleanSupplier conditionSupplier) { + + Objects.requireNonNull(errorMessageSupplier); + Objects.requireNonNull(conditionSupplier); + Objects.requireNonNull(timeout); + + long timeoutNs = timeout.toNanos(); + long startTime = System.nanoTime(); + do { + if (conditionSupplier.getAsBoolean()) { + return; + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } while (System.nanoTime() - startTime < timeoutNs); + throw new AssertionError(errorMessageSupplier.get()); + } + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throw an {@link AssertionError} with the specified error message. + * + * @param timeout the timeout duration + * @param errorMessage the error message + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, final String errorMessage, BooleanSupplier conditionSupplier) { + await( + timeout, + new Supplier() { + @Override + public String get() { + return errorMessage; + } + }, + conditionSupplier); + } + + /** + * Create a new {@link AssertSubscriber} that requests an unbounded number of elements. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create() { + return new AssertSubscriber<>(); + } + + /** + * Create a new {@link AssertSubscriber} that requests initially {@code n} elements. You can then + * manage the demand with {@link Subscription#request(long)}. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param n Number of elements to request (can be 0 if you want no initial demand). + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create(long n) { + return new AssertSubscriber<>(n); + } + + // + // ============================================================================================================== + // constructors + // + // ============================================================================================================== + + public AssertSubscriber() { + this(Context.empty(), Long.MAX_VALUE); + } + + public AssertSubscriber(long n) { + this(Context.empty(), n); + } + + public AssertSubscriber(Context context) { + this(context, Long.MAX_VALUE); + } + + public AssertSubscriber(Context context, long n) { + if (n < 0) { + throw new IllegalArgumentException("initialRequest >= required but it was " + n); + } + this.context = context; + REQUESTED.lazySet(this, n); + } + + // + // ============================================================================================================== + // Configuration + // + // ============================================================================================================== + + /** + * Enable or disabled the values storage. It is enabled by default, and can be disable in order to + * be able to perform performance benchmarks or tests with a huge amount values. + * + * @param enabled enable value storage? + * @return this + */ + public final AssertSubscriber configureValuesStorage(boolean enabled) { + this.valuesStorage = enabled; + return this; + } + + /** + * Configure the timeout in seconds for waiting next values to be received (3 seconds by default). + * + * @param timeout the new default value timeout duration + * @return this + */ + public final AssertSubscriber configureValuesTimeout(Duration timeout) { + this.valuesTimeout = timeout; + return this; + } + + /** + * Returns the established fusion mode or -1 if it was not enabled + * + * @return the fusion mode, see Fuseable constants + */ + public final int establishedFusionMode() { + return establishedFusionMode; + } + + // + // ============================================================================================================== + // Assertions + // + // ============================================================================================================== + + /** + * Assert a complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertComplete() { + assertNoError(); + int c = completionCount; + if (c == 0) { + throw new AssertionError("Not completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert the specified values have been received. Values storage should be enabled to use this + * method. + * + * @param expectedValues the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertContainValues(Set expectedValues) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + if (expectedValues.size() > values.size()) { + throw new AssertionError("Actual contains fewer elements" + values, null); + } + + Iterator expected = expectedValues.iterator(); + + for (; ; ) { + boolean n2 = expected.hasNext(); + if (n2) { + T t2 = expected.next(); + if (!values.contains(t2)) { + throw new AssertionError( + "The element is not contained in the " + + "received results" + + " = " + + valueAndClass(t2), + null); + } + } else { + break; + } + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertError() { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @param clazz The class of the exception contained in the error signal + * @return this + */ + public final AssertSubscriber assertError(Class clazz) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + Throwable e = errors.get(0); + if (!clazz.isInstance(e)) { + throw new AssertionError( + "Error class incompatible: expected = " + clazz + ", actual = " + e, null); + } + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + public final AssertSubscriber assertErrorMessage(String message) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + assertionError("No error", null); + } + if (s == 1) { + if (!Objects.equals(message, errors.get(0).getMessage())) { + assertionError( + "Error class incompatible: expected = \"" + + message + + "\", actual = \"" + + errors.get(0).getMessage() + + "\"", + null); + } + } + if (s > 1) { + assertionError("Multiple errors: " + s, null); + } + + return this; + } + + /** + * Assert an error signal has been received. + * + * @param expectation A method that can verify the exception contained in the error signal and + * throw an exception (like an {@link AssertionError}) if the exception is not valid. + * @return this + */ + public final AssertSubscriber assertErrorWith(Consumer expectation) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + expectation.accept(errors.get(0)); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert that the upstream was a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertFuseableSource() { + if (qs == null) { + throw new AssertionError("Upstream was not Fuseable"); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionEnabled() { + if (establishedFusionMode != Fuseable.SYNC && establishedFusionMode != Fuseable.ASYNC) { + throw new AssertionError("Fusion was not enabled"); + } + return this; + } + + public final AssertSubscriber assertFusionMode(int expectedMode) { + if (establishedFusionMode != expectedMode) { + throw new AssertionError( + "Wrong fusion mode: expected: " + + fusionModeName(expectedMode) + + ", actual: " + + fusionModeName(establishedFusionMode)); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionRejected() { + if (establishedFusionMode != Fuseable.NONE) { + throw new AssertionError("Fusion was granted"); + } + return this; + } + + /** + * Assert no error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNoError() { + int s = errors.size(); + if (s == 1) { + Throwable e = errors.get(0); + String valueAndClass = e == null ? null : e + " (" + e.getClass().getSimpleName() + ")"; + throw new AssertionError("Error present: " + valueAndClass, null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert no values have been received. + * + * @return this + */ + public final AssertSubscriber assertNoValues() { + if (valueCount != 0) { + throw new AssertionError( + "No values expected but received: [length = " + values.size() + "] " + values, null); + } + return this; + } + + /** + * Assert that the upstream was not a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertNonFuseableSource() { + if (qs != null) { + throw new AssertionError("Upstream was Fuseable"); + } + return this; + } + + /** + * Assert no complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotComplete() { + int c = completionCount; + if (c == 1) { + throw new AssertionError("Completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert no subscription occurred. + * + * @return this + */ + public final AssertSubscriber assertNotSubscribed() { + int s = subscriptionCount; + + if (s == 1) { + throw new AssertionError("OnSubscribe called once", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert no complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotTerminated() { + if (cdl.getCount() == 0) { + throw new AssertionError("Terminated", null); + } + return this; + } + + /** + * Assert subscription occurred (once). + * + * @return this + */ + public final AssertSubscriber assertSubscribed() { + int s = subscriptionCount; + + if (s == 0) { + throw new AssertionError("OnSubscribe not called", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert either complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertTerminated() { + if (cdl.getCount() != 0) { + throw new AssertionError("Not terminated", null); + } + return this; + } + + /** + * Assert {@code n} values has been received. + * + * @param n the expected value count + * @return this + */ + public final AssertSubscriber assertValueCount(long n) { + if (valueCount != n) { + throw new AssertionError( + "Different value count: expected = " + n + ", actual = " + valueCount, null); + } + return this; + } + + /** + * Assert the specified values have been received in the same order read by the passed {@link + * Iterable}. Values storage should be enabled to use this method. + * + * @param expectedSequence the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertValueSequence(Iterable expectedSequence) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + Iterator actual = values.iterator(); + Iterator expected = expectedSequence.iterator(); + int i = 0; + for (; ; ) { + boolean n1 = actual.hasNext(); + boolean n2 = expected.hasNext(); + if (n1 && n2) { + T t1 = actual.next(); + T t2 = expected.next(); + if (!Objects.equals(t1, t2)) { + throw new AssertionError( + "The element with index " + + i + + " does not match: expected = " + + valueAndClass(t2) + + ", actual = " + + valueAndClass(t1), + null); + } + i++; + } else if (n1 && !n2) { + throw new AssertionError("Actual contains more elements" + values, null); + } else if (!n1 && n2) { + throw new AssertionError("Actual contains fewer elements: " + values, null); + } else { + break; + } + } + return this; + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectedValues the values to assert + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValues(T... expectedValues) { + return assertValueSequence(Arrays.asList(expectedValues)); + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValuesWith(Consumer... expectations) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + final int expectedValueCount = expectations.length; + if (expectedValueCount != values.size()) { + throw new AssertionError( + "Different value count: expected = " + expectedValueCount + ", actual = " + valueCount, + null); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = values.get(i); + consumer.accept(actualValue); + } + return this; + } + + // + // ============================================================================================================== + // Await methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until a complete successfully or error signal is received. + * + * @return this + */ + public final AssertSubscriber await() { + if (cdl.getCount() == 0) { + return this; + } + try { + cdl.await(); + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + return this; + } + + /** + * Blocking method that waits until a complete successfully or error signal is received or until a + * timeout occurs. + * + * @param timeout The timeout value + * @return this + */ + public final AssertSubscriber await(Duration timeout) { + if (cdl.getCount() == 0) { + return this; + } + try { + if (!cdl.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new AssertionError("No complete or error signal before timeout"); + } + return this; + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + } + + /** + * Blocking method that waits until {@code n} next values have been received. + * + * @param n the value count to assert + * @return this + */ + public final AssertSubscriber awaitAndAssertNextValueCount(final long n) { + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + n, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d", + valueCount - nextValueAssertedCount, n, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + n)); + nextValueAssertedCount += n; + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * values provided) to assert them. + * + * @param values the values to assert + * @return this + */ + @SafeVarargs + @SuppressWarnings("unchecked") + public final AssertSubscriber awaitAndAssertNextValues(T... values) { + final int expectedNum = values.length; + final List> expectations = new ArrayList<>(); + for (int i = 0; i < expectedNum; i++) { + final T expectedValue = values[i]; + expectations.add( + actualValue -> { + if (!actualValue.equals(expectedValue)) { + throw new AssertionError( + String.format( + "Expected Next signal: %s, but got: %s", expectedValue, actualValue)); + } + }); + } + awaitAndAssertNextValuesWith(expectations.toArray((Consumer[]) new Consumer[0])); + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * expectations provided) to assert them. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + */ + @SafeVarargs + public final AssertSubscriber awaitAndAssertNextValuesWith(Consumer... expectations) { + valuesStorage = true; + final int expectedValueCount = expectations.length; + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + expectedValueCount, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d ms", + valueCount - nextValueAssertedCount, expectedValueCount, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + expectedValueCount)); + List nextValuesSnapshot; + List empty = new ArrayList<>(); + for (; ; ) { + nextValuesSnapshot = values; + if (NEXT_VALUES.compareAndSet(this, values, empty)) { + break; + } + } + if (nextValuesSnapshot.size() < expectedValueCount) { + throw new AssertionError( + String.format( + "Expected %d number of signals but received %d", + expectedValueCount, nextValuesSnapshot.size())); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = nextValuesSnapshot.get(i); + consumer.accept(actualValue); + } + nextValueAssertedCount += expectedValueCount; + return this; + } + + // + // ============================================================================================================== + // Overrides + // + // ============================================================================================================== + + @Override + public void cancel() { + Subscription a = s; + if (a != Operators.cancelledSubscription()) { + a = S.getAndSet(this, Operators.cancelledSubscription()); + if (a != null && a != Operators.cancelledSubscription()) { + a.cancel(); + } + } + } + + final boolean isCancelled() { + return s == Operators.cancelledSubscription(); + } + + public final boolean isTerminated() { + return cdl.getCount() == 0; + } + + @Override + public void onComplete() { + completionCount++; + cdl.countDown(); + } + + @Override + public void onError(Throwable t) { + errors.add(t); + cdl.countDown(); + } + + @Override + public void onNext(T t) { + if (establishedFusionMode == Fuseable.ASYNC) { + for (; ; ) { + t = qs.poll(); + if (t == null) { + break; + } + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + subscriptionCount++; + int requestMode = requestedFusionMode; + if (requestMode >= 0) { + if (!setWithoutRequesting(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } else { + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; + + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; + } + + onNext(v); + } + } else { + requestDeferred(); + } + } else { + requestDeferred(); + } + } + } else { + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (establishedFusionMode != Fuseable.SYNC) { + normalRequest(n); + } + } + } + + @Override + @NonNull + public Context currentContext() { + return context; + } + + /** + * Setup what fusion mode should be requested from the incoming Subscription if it happens to be + * QueueSubscription + * + * @param requestMode the mode to request, see Fuseable constants + * @return this + */ + public final AssertSubscriber requestedFusionMode(int requestMode) { + this.requestedFusionMode = requestMode; + return this; + } + + public Subscription upstream() { + return s; + } + + // + // ============================================================================================================== + // Non public methods + // + // ============================================================================================================== + + protected final void normalRequest(long n) { + Subscription a = s; + if (a != null) { + a.request(n); + } else { + Operators.addCap(REQUESTED, this, n); + + a = s; + + if (a != null) { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + a.request(r); + } + } + } + } + + /** Requests the deferred amount if not zero. */ + protected final void requestDeferred() { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + } + + /** + * Atomically sets the single subscription and requests the missed amount from it. + * + * @param s + * @return false if this arbiter is cancelled or there was a subscription already set + */ + protected final boolean set(Subscription s) { + Objects.requireNonNull(s, "s"); + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + + return true; + } + + a = this.s; + + if (a != Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + + Operators.reportSubscriptionSet(); + return false; + } + + /** + * Sets the Subscription once but does not request anything. + * + * @param s the Subscription to set + * @return true if successful, false if the current subscription is not null + */ + protected final boolean setWithoutRequesting(Subscription s) { + Objects.requireNonNull(s, "s"); + for (; ; ) { + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + return true; + } + } + } + + /** + * Prepares and throws an AssertionError exception based on the message, cause, the active state + * and the potential errors so far. + * + * @param message the message + * @param cause the optional Throwable cause + * @throws AssertionError as expected + */ + protected final void assertionError(String message, Throwable cause) { + StringBuilder b = new StringBuilder(); + + if (cdl.getCount() != 0) { + b.append("(active) "); + } + b.append(message); + + List err = errors; + if (!err.isEmpty()) { + b.append(" (+ ").append(err.size()).append(" errors)"); + } + AssertionError e = new AssertionError(b.toString(), cause); + + for (Throwable t : err) { + e.addSuppressed(t); + } + + throw e; + } + + protected final String fusionModeName(int mode) { + switch (mode) { + case -1: + return "Disabled"; + case Fuseable.NONE: + return "None"; + case Fuseable.SYNC: + return "Sync"; + case Fuseable.ASYNC: + return "Async"; + default: + return "Unknown(" + mode + ")"; + } + } + + protected final String valueAndClass(Object o) { + if (o == null) { + return null; + } + return o + " (" + o.getClass().getSimpleName() + ")"; + } + + public List values() { + return values; + } + + public final AssertSubscriber assertNoEvents() { + return assertNoValues().assertNoError().assertNotComplete(); + } + + @SafeVarargs + public final AssertSubscriber assertIncomplete(T... values) { + return assertValues(values).assertNotComplete().assertNoError(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java new file mode 100644 index 000000000..cb8478c13 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java @@ -0,0 +1,209 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCounted; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +public class TracingMetadataCodecTest { + + private static Stream flags() { + return Stream.of(TracingMetadataCodec.Flags.values()); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeEmptyTrace(TracingMetadataCodec.Flags expectedFlag) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encodeEmpty(allocator, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(TracingMetadata::isEmpty) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode64(allocator, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encode64(allocator, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128( + allocator, traceIdHigh, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128(allocator, traceIdHigh, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java new file mode 100644 index 000000000..13d910e15 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/security/AuthMetadataFlyweightTest.java @@ -0,0 +1,470 @@ +package io.rsocket.metadata.security; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class AuthMetadataFlyweightTest { + + public static final int AUTH_TYPE_ID_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 1; + public static final String TEST_BEARER_TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; + + @Test + void shouldCorrectlyEncodeData() { + String username = "test"; + String password = "tset1234"; + + int usernameLength = username.length(); + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData1() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData2() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎1234567#4? "; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + private static void checkSimpleAuthMetadataEncoding( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) + .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); + Assertions.assertThat(byteBuf.readUnsignedByte()).isEqualTo((short) usernameLength); + + Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(byteBuf.readCharSequence(passwordLength, CharsetUtil.UTF_8)) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + private static void checkSimpleAuthMetadataEncodingUsingDecoders( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.SIMPLE); + byteBuf.markReaderIndex(); + Assertions.assertThat(AuthMetadataFlyweight.decodeUsername(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(AuthMetadataFlyweight.decodePassword(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(password); + byteBuf.resetReaderIndex(); + + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeUsernameAsCharArray(byteBuf))) + .isEqualTo(username); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodePasswordAsCharArray(byteBuf))) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + @Test + void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { + String username = + "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + String password = "tset1234"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray())) + .hasMessage( + "Username should be shorter than or equal to 255 bytes length in UTF-8 encoding"); + } + + @Test + void shouldEncodeBearerMetadata() { + String testToken = TEST_BEARER_TOKEN; + + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeBearerMetadata( + ByteBufAllocator.DEFAULT, testToken.toCharArray()); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(testToken, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(testToken, byteBuf); + } + + private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat( + byteBuf.readUnsignedByte() & ~AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK) + .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); + Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) + .isEqualTo(testToken); + } + + private static void checkBearerAuthMetadataEncodingUsingDecoders( + String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataFlyweight.decodeWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.BEARER); + byteBuf.markReaderIndex(); + Assertions.assertThat(new String(AuthMetadataFlyweight.decodeBearerTokenAsCharArray(byteBuf))) + .isEqualTo(testToken); + byteBuf.resetReaderIndex(); + Assertions.assertThat( + AuthMetadataFlyweight.decodePayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + .isEqualTo(testToken); + } + + @Test + void shouldEncodeCustomAuth() { + String payloadAsAText = "testsecuritybuffer"; + ByteBuf testSecurityPayload = + Unpooled.wrappedBuffer(payloadAsAText.getBytes(CharsetUtil.UTF_8)); + + String customAuthType = "myownauthtype"; + ByteBuf buffer = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); + + checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); + } + + private static void checkCustomAuthMetadataEncoding( + ByteBuf testSecurityPayload, String customAuthType, ByteBuf buffer) { + Assertions.assertThat(buffer.capacity()) + .isEqualTo(1 + customAuthType.length() + testSecurityPayload.capacity()); + Assertions.assertThat(buffer.readUnsignedByte()) + .isEqualTo((short) (customAuthType.length() - 1)); + Assertions.assertThat( + buffer.readCharSequence(customAuthType.length(), CharsetUtil.US_ASCII).toString()) + .isEqualTo(customAuthType); + Assertions.assertThat(buffer.readSlice(testSecurityPayload.capacity())) + .isEqualTo(testSecurityPayload); + + ReferenceCountUtil.release(buffer); + } + + @Test + void shouldThrowOnNonASCIIChars() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = "1234567#4? 𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage("custom auth type must be US_ASCII characters only"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + // 130 chars + String customAuthType = + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType1() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = ""; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldEncodeUsingWellKnownAuthType() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer(3, 3).writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType2() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + buffer.release(); + } + + @Test + void shouldCompressMetadata() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "simple", + ByteBufAllocator.DEFAULT.buffer().writeByte(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldCompressMetadata1() { + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "bearer", + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldNotCompressMetadata() { + ByteBuf testMetadataPayload = + Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); + String customAuthType = "testauthtype"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); + + checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); + } + + @Test + void shouldConfirmWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadSimpleWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.BEARER; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType2() { + ByteBuf metadata = + ByteBufAllocator.DEFAULT + .buffer() + .writeByte(3 | AuthMetadataFlyweight.STREAM_METADATA_KNOWN_MASK); + WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { + ByteBuf metadata = ByteBufAllocator.DEFAULT.buffer().writeByte(3); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { + ByteBuf metadata = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldThrowExceptionIsNotEnoughReadableBytes() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); + } + + private static void checkDecodeWellKnowAuthTypeCorrectly( + ByteBuf metadata, WellKnownAuthType expectedType) { + int initialReaderIndex = metadata.readerIndex(); + + WellKnownAuthType wellKnownAuthType = AuthMetadataFlyweight.decodeWellKnownAuthType(metadata); + + Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); + Assertions.assertThat(metadata.readerIndex()) + .isNotEqualTo(initialReaderIndex) + .isEqualTo(initialReaderIndex + 1); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadCustomEncodedAuthType() { + String testAuthType = "TestAuthType"; + ByteBuf byteBuf = + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); + checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); + } + + @Test + void shouldThrowExceptionOnEmptyMetadata() { + Assertions.assertThatThrownBy( + () -> AuthMetadataFlyweight.decodeCustomAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + AuthMetadataFlyweight.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) + .hasMessage("Unable to decode custom Auth type. Incorrect auth type length"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_length() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataFlyweight.decodeCustomAuthType( + ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) + .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); + } + + private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { + int initialReaderIndex = byteBuf.readerIndex(); + + Assertions.assertThat(AuthMetadataFlyweight.decodeCustomAuthType(byteBuf).toString()) + .isEqualTo(testAuthType); + Assertions.assertThat(byteBuf.readerIndex()) + .isEqualTo(initialReaderIndex + testAuthType.length() + 1); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index d945dd45d..58323c066 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import org.reactivestreams.Publisher; import reactor.core.publisher.DirectProcessor; @@ -25,17 +26,22 @@ import reactor.core.publisher.MonoProcessor; public class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final DirectProcessor send; private final DirectProcessor receive; private final MonoProcessor onClose; private final String name; public LocalDuplexConnection( - String name, DirectProcessor send, DirectProcessor receive) { + String name, + ByteBufAllocator allocator, + DirectProcessor send, + DirectProcessor receive) { this.name = name; + this.allocator = allocator; this.send = send; this.receive = receive; - onClose = MonoProcessor.create(); + this.onClose = MonoProcessor.create(); } @Override @@ -52,6 +58,11 @@ public Flux receive() { return receive.doOnNext(f -> System.out.println(name + " - " + f.toString())); } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public void dispose() { onClose.onComplete(); diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java index 37ad8ee5b..a30e75875 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -1,12 +1,15 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ClientTransport; import reactor.core.publisher.Mono; public class TestClientTransport implements ClientTransport { - - private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection testDuplexConnection = new TestDuplexConnection(allocator); @Override public Mono connect(int mtu) { @@ -16,4 +19,8 @@ public Mono connect(int mtu) { public TestDuplexConnection testConnection() { return testDuplexConnection; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index 6298b0c3a..17a19b8c9 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.test.util; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Collection; import java.util.concurrent.ConcurrentLinkedQueue; @@ -46,17 +47,19 @@ public class TestDuplexConnection implements DuplexConnection { private final FluxSink receivedSink; private final MonoProcessor onClose; private final ConcurrentLinkedQueue> sendSubscribers; + private final ByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection() { - sent = new LinkedBlockingQueue<>(); - received = DirectProcessor.create(); - receivedSink = received.sink(); - sentPublisher = DirectProcessor.create(); - sendSink = sentPublisher.sink(); - sendSubscribers = new ConcurrentLinkedQueue<>(); - onClose = MonoProcessor.create(); + public TestDuplexConnection(ByteBufAllocator allocator) { + this.allocator = allocator; + this.sent = new LinkedBlockingQueue<>(); + this.received = DirectProcessor.create(); + this.receivedSink = received.sink(); + this.sentPublisher = DirectProcessor.create(); + this.sendSink = sentPublisher.sink(); + this.sendSubscribers = new ConcurrentLinkedQueue<>(); + this.onClose = MonoProcessor.create(); } @Override @@ -83,6 +86,11 @@ public Flux receive() { return received; } + @Override + public ByteBufAllocator alloc() { + return allocator; + } + @Override public double availability() { return availability; diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java index 5cebf0da1..325496148 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -1,12 +1,16 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; import io.rsocket.transport.ServerTransport; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; public class TestServerTransport implements ServerTransport { private final MonoProcessor conn = MonoProcessor.create(); + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); @Override public Mono start(ConnectionAcceptor acceptor, int mtu) { @@ -39,8 +43,12 @@ private void disposeConnection() { } public TestDuplexConnection connect() { - TestDuplexConnection c = new TestDuplexConnection(); + TestDuplexConnection c = new TestDuplexConnection(allocator); conn.onNext(c); return c; } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java deleted file mode 100644 index 526757fbe..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.core.publisher.Mono; - -public final class TestUriHandler implements UriHandler { - - private static final String SCHEME = "test"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of((mtu) -> Mono.just(new TestDuplexConnection())); - } - - @Override - public Optional buildServer(URI uri) { - return Optional.empty(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java deleted file mode 100644 index 7aeef708f..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.DuplexConnection; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import org.junit.Test; - -public class UriTransportRegistryTest { - @Test - public void testTestRegistered() { - ClientTransport test = UriTransportRegistry.clientForUri("test://test"); - - DuplexConnection duplexConnection = test.connect(0).block(); - - assertTrue(duplexConnection instanceof TestDuplexConnection); - } - - @Test(expected = UnsupportedOperationException.class) - public void testTestUnregistered() { - ClientTransport test = UriTransportRegistry.clientForUri("mailto://bonson@baulsupp.net"); - - test.connect(0).block(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java new file mode 100644 index 000000000..2ad944d09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java @@ -0,0 +1,64 @@ +package io.rsocket.util; + +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ByteBufPayloadTest { + + @Test + public void shouldIndicateThatItHasMetadata() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = ByteBufPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + ByteBufPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldThrowExceptionIfAccessAfterRelease() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.release()).isTrue(); + + Assertions.assertThatThrownBy(payload::hasMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::data).isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::metadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::touch) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(() -> payload.touch("test")) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getDataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java index 45ee4eacb..6bae0886b 100644 --- a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -16,10 +16,13 @@ package io.rsocket.util; -import static org.hamcrest.MatcherAssert.*; -import static org.hamcrest.Matchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import io.netty.buffer.Unpooled; import io.rsocket.Payload; +import java.nio.ByteBuffer; +import org.assertj.core.api.Assertions; import org.junit.Test; public class DefaultPayloadTest { @@ -48,4 +51,27 @@ public void staticMethods() { assertDataAndMetadata(DefaultPayload.create(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); assertDataAndMetadata(DefaultPayload.create(DATA_VAL), DATA_VAL, null); } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = DefaultPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata2() { + Payload payload = + DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + } } diff --git a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 068667aa7..000000000 --- a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.uri.TestUriHandler diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index 5f63b0761..01e80cfa1 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -22,13 +22,13 @@ dependencies { implementation project(':rsocket-core') implementation project(':rsocket-transport-local') implementation project(':rsocket-transport-netty') + runtimeOnly 'ch.qos.logback:logback-classic' testImplementation project(':rsocket-test') testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.mockito:mockito-core' testImplementation 'org.assertj:assertj-core' testImplementation 'io.projectreactor:reactor-test' - testImplementation 'ch.qos.logback:logback-classic' // TODO: Remove after JUnit5 migration testCompileOnly 'junit:junit' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index ac889ecfc..b532c0140 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,73 +16,48 @@ package io.rsocket.examples.transport.tcp.channel; -import io.rsocket.AbstractRSocket; -import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; -import io.rsocket.frame.decoder.PayloadDecoder; -import io.rsocket.transport.local.LocalClientTransport; -import io.rsocket.transport.local.LocalServerTransport; -import io.rsocket.util.ByteBufPayload; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; public final class ChannelEchoClient { - static final Payload payload1 = ByteBufPayload.create("Hello "); - public static void main(String[] args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new SocketAcceptorImpl()) - .transport(LocalServerTransport.create("localhost")) - .start() - .subscribe(); + private static final Logger logger = LoggerFactory.getLogger(ChannelEchoClient.class); - RSocket socket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("localhost")) - .start() - .block(); - - Flux.range(0, 100000000) - .concatMap(i -> socket.fireAndForget(payload1.retain())) - // .doOnNext(p -> { - //// System.out.println(p.getDataUtf8()); - // p.release(); - // }) - .blockLast(); - } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { + public static void main(String[] args) { - @Override - public Mono fireAndForget(Payload payload) { - // System.out.println(payload.getDataUtf8()); - payload.release(); - return Mono.empty(); - } + SocketAcceptor echoAcceptor = + SocketAcceptor.forRequestChannel( + payloads -> + Flux.from(payloads) + .map(Payload::getDataUtf8) + .map(s -> "Echo: " + s) + .map(DefaultPayload::create)); - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } + RSocketServer.create(echoAcceptor) + .bind(TcpServerTransport.create("localhost", 7000)) + .subscribe(); - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads).subscribeOn(Schedulers.single()); - } - }); - } + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); + + socket + .requestChannel( + Flux.interval(Duration.ofMillis(1000)).map(i -> DefaultPayload.create("Hello"))) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .doFinally(signalType -> socket.dispose()) + .then() + .block(); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java index c0a271d66..3eba5a800 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,12 @@ package io.rsocket.examples.transport.tcp.duplex; -import io.rsocket.AbstractRSocket; +import static io.rsocket.SocketAcceptor.forRequestStream; + import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; @@ -30,36 +32,30 @@ public final class DuplexClient { public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setup, reactiveSocket) -> { - reactiveSocket + + RSocketServer.create( + (setup, rsocket) -> { + rsocket .requestStream(DefaultPayload.create("Hello-Bidi")) .map(Payload::getDataUtf8) .log() .subscribe(); - return Mono.just(new AbstractRSocket() {}); + return Mono.just(new RSocket() {}); }) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); - RSocket socket = - RSocketFactory.connect() + RSocket rsocket = + RSocketConnector.create() .acceptor( - rSocket -> - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofSeconds(1)) - .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)); - } - }) - .transport(TcpClientTransport.create("localhost", 7000)) - .start() + forRequestStream( + payload -> + Flux.interval(Duration.ofSeconds(1)) + .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)))) + .connect(TcpClientTransport.create("localhost", 7000)) .block(); - socket.onClose().block(); + rsocket.onClose().block(); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java index 7482c7d1a..3eaebd89a 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/LeaseExample.java @@ -1,11 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.examples.transport.tcp.lease; import static java.time.Duration.ofSeconds; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.lease.Lease; import io.rsocket.lease.LeaseStats; import io.rsocket.lease.Leases; @@ -27,28 +44,28 @@ public class LeaseExample { public static void main(String[] args) { CloseableChannel server = - RSocketFactory.receive() + RSocketServer.create( + (setup, sendingRSocket) -> Mono.just(new ServerRSocket(sendingRSocket))) .lease( () -> Leases.create() .sender(new LeaseSender(SERVER_TAG, 7_000, 5)) .receiver(new LeaseReceiver(SERVER_TAG)) .stats(new NoopStats())) - .acceptor((setup, sendingRSocket) -> Mono.just(new ServerAcceptor(sendingRSocket))) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + .bind(TcpServerTransport.create("localhost", 7000)) .block(); RSocket clientRSocket = - RSocketFactory.connect() + RSocketConnector.create() .lease( () -> Leases.create() .sender(new LeaseSender(CLIENT_TAG, 3_000, 5)) .receiver(new LeaseReceiver(CLIENT_TAG))) - .acceptor(rSocket -> new ClientAcceptor()) - .transport(TcpClientTransport.create(server.address())) - .start() + .acceptor( + SocketAcceptor.forRequestResponse( + payload -> Mono.just(DefaultPayload.create("Client Response " + new Date())))) + .connect(TcpClientTransport.create(server.address())) .block(); Flux.interval(ofSeconds(1)) @@ -118,17 +135,10 @@ private static class NoopStats implements LeaseStats { public void onEvent(EventType eventType) {} } - private static class ClientAcceptor extends AbstractRSocket { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(DefaultPayload.create("Client Response " + new Date())); - } - } - - private static class ServerAcceptor extends AbstractRSocket { + private static class ServerRSocket implements RSocket { private final RSocket senderRSocket; - public ServerAcceptor(RSocket senderRSocket) { + public ServerRSocket(RSocket senderRSocket) { this.senderRSocket = senderRSocket; } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java new file mode 100644 index 000000000..67a85b67f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -0,0 +1,84 @@ +package io.rsocket.examples.transport.tcp.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.plugins.LimitRateInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class LimitRateInterceptorExample { + + private static final Logger logger = LoggerFactory.getLogger(LimitRateInterceptorExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.interval(Duration.ofMillis(100)) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)); + } + })) + .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) + .bind(TcpServerTransport.create("localhost", 7000)) + .subscribe(); + + RSocket socket = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(LimitRateInterceptor.forRequester(64))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + logger.debug( + "\n\nStart of requestStream interaction\n" + "----------------------------------\n"); + + socket + .requestStream(DefaultPayload.create("Hello")) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .block(); + + logger.debug( + "\n\nStart of requestChannel interaction\n" + "-----------------------------------\n"); + + socket + .requestChannel( + Flux.generate( + () -> 1L, + (s, sink) -> { + sink.next(DefaultPayload.create("Next " + s)); + return ++s; + }) + .doOnRequest(e -> logger.debug("Client publisher receives request for " + e))) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 537485fa4..85faeee82 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,65 +16,54 @@ package io.rsocket.examples.transport.tcp.requestresponse; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; public final class HelloWorldClient { + private static final Logger logger = LoggerFactory.getLogger(HelloWorldClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setupPayload, reactiveSocket) -> - Mono.just( - new AbstractRSocket() { - boolean fail = true; - @Override - public Mono requestResponse(Payload p) { - if (fail) { - fail = false; - return Mono.error(new Throwable()); - } else { - return Mono.just(p); - } - } - })) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); + RSocket rsocket = + new RSocket() { + boolean fail = true; - RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + @Override + public Mono requestResponse(Payload p) { + if (fail) { + fail = false; + return Mono.error(new Throwable("Simulated error")); + } else { + return Mono.just(p); + } + } + }; - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocketServer.create(SocketAcceptor.with(rsocket)) + .bind(TcpServerTransport.create("localhost", 7000)) + .subscribe(); - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); - socket - .requestResponse(DefaultPayload.create("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + for (int i = 0; i < 3; i++) { + socket + .requestResponse(DefaultPayload.create("Hello")) + .map(Payload::getDataUtf8) + .onErrorReturn("error") + .doOnNext(logger::debug) + .block(); + } socket.dispose(); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java index e6867f8b5..6724ca93f 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java @@ -3,13 +3,21 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.rsocket.Payload; -import java.io.*; +import java.io.BufferedInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.SynchronousSink; class Files { + private static final Logger logger = LoggerFactory.getLogger(Files.class); public static Flux fileSource(String fileName, int chunkSizeBytes) { return Flux.generate( @@ -35,8 +43,7 @@ public void onNext(Payload payload) { ByteBuf data = payload.data(); receivedBytes += data.readableBytes(); receivedCount += 1; - System.out.println( - "Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); + logger.debug("Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); if (outputStream == null) { outputStream = open(fileName); } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java index ca115d281..93b54e146 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -1,91 +1,87 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.examples.transport.tcp.resume; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; -import io.rsocket.resume.ClientResume; -import io.rsocket.resume.PeriodicResumeStrategy; -import io.rsocket.resume.ResumeStrategy; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; public class ResumeFileTransfer { + /*amount of file chunks requested by subscriber: n, refilled on n/2 of received items*/ private static final int PREFETCH_WINDOW_SIZE = 4; + private static final Logger logger = LoggerFactory.getLogger(ResumeFileTransfer.class); public static void main(String[] args) { - RequestCodec requestCodec = new RequestCodec(); + + Resume resume = + new Resume() + .sessionDuration(Duration.ofMinutes(5)) + .retry( + Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)) + .doBeforeRetry(s -> logger.debug("Disconnected. Trying to resume..."))); + + RequestCodec codec = new RequestCodec(); CloseableChannel server = - RSocketFactory.receive() - .resume() - .resumeSessionDuration(Duration.ofMinutes(5)) - .acceptor((setup, rSocket) -> Mono.just(new FileServer(requestCodec))) - .transport(TcpServerTransport.create("localhost", 8000)) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> { + Request request = codec.decode(payload); + payload.release(); + String fileName = request.getFileName(); + int chunkSize = request.getChunkSize(); + + Flux ticks = Flux.interval(Duration.ofMillis(500)).onBackpressureDrop(); + + return Files.fileSource(fileName, chunkSize) + .map(DefaultPayload::create) + .zipWith(ticks, (p, tick) -> p); + })) + .resume(resume) + .bind(TcpServerTransport.create("localhost", 8000)) .block(); RSocket client = - RSocketFactory.connect() - .resume() - .resumeStrategy( - () -> new VerboseResumeStrategy(new PeriodicResumeStrategy(Duration.ofSeconds(1)))) - .resumeSessionDuration(Duration.ofMinutes(5)) - .transport(TcpClientTransport.create("localhost", 8001)) - .start() + RSocketConnector.create() + .resume(resume) + .connect(TcpClientTransport.create("localhost", 8001)) .block(); client - .requestStream(requestCodec.encode(new Request(16, "lorem.txt"))) + .requestStream(codec.encode(new Request(16, "lorem.txt"))) .doFinally(s -> server.dispose()) .subscribe(Files.fileSink("rsocket-examples/out/lorem_output.txt", PREFETCH_WINDOW_SIZE)); server.onClose().block(); } - private static class FileServer extends AbstractRSocket { - private final RequestCodec requestCodec; - - public FileServer(RequestCodec requestCodec) { - this.requestCodec = requestCodec; - } - - @Override - public Flux requestStream(Payload payload) { - Request request = requestCodec.decode(payload); - payload.release(); - String fileName = request.getFileName(); - int chunkSize = request.getChunkSize(); - - Flux ticks = Flux.interval(Duration.ofMillis(500)).onBackpressureDrop(); - - return Files.fileSource(fileName, chunkSize) - .map(DefaultPayload::create) - .zipWith(ticks, (p, tick) -> p); - } - } - - private static class VerboseResumeStrategy implements ResumeStrategy { - private final ResumeStrategy resumeStrategy; - - public VerboseResumeStrategy(ResumeStrategy resumeStrategy) { - this.resumeStrategy = resumeStrategy; - } - - @Override - public Publisher apply(ClientResume clientResume, Throwable throwable) { - return Flux.from(resumeStrategy.apply(clientResume, throwable)) - .doOnNext(v -> System.out.println("Disconnected. Trying to resume connection...")); - } - } - private static class RequestCodec { public Payload encode(Request request) { diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java index 57a659c1d..6ac329d56 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,51 +16,43 @@ package io.rsocket.examples.transport.tcp.stream; -import io.rsocket.*; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public final class StreamingClient { + private static final Logger logger = LoggerFactory.getLogger(StreamingClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.interval(Duration.ofMillis(100)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) + .bind(TcpServerTransport.create("localhost", 7000)) .subscribe(); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); socket .requestStream(DefaultPayload.create("Hello")) .map(Payload::getDataUtf8) - .doOnNext(System.out::println) + .doOnNext(logger::debug) .take(10) .then() .doFinally(signalType -> socket.dispose()) .then() .block(); } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofMillis(100)) - .map(aLong -> DefaultPayload.create("Interval: " + aLong)); - } - }); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java index d3865c01b..2ab73116d 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,13 +17,13 @@ package io.rsocket.examples.transport.ws; import io.netty.handler.codec.http.HttpResponseStatus; -import io.rsocket.AbstractRSocket; -import io.rsocket.ConnectionSetupPayload; import io.rsocket.DuplexConnection; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; @@ -45,10 +45,9 @@ public class WebSocketHeadersSample { public static void main(String[] args) { ServerTransport.ConnectionAcceptor acceptor = - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new SocketAcceptorImpl()) - .toConnectionAcceptor(); + RSocketServer.create(SocketAcceptor.with(new ServerRSocket())) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); DisposableServer disposableServer = HttpServer.create() @@ -61,7 +60,8 @@ public static void main(String[] args) { (in, out) -> { if (in.headers().containsValue("Authorization", "test", true)) { DuplexConnection connection = - new WebsocketDuplexConnection((Connection) in); + new ReassemblyDuplexConnection( + new WebsocketDuplexConnection((Connection) in), false); return acceptor.apply(connection).then(out.neverComplete()); } @@ -82,11 +82,10 @@ public static void main(String[] args) { }); RSocket socket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(clientTransport) - .start() + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport) .block(); Flux.range(0, 100) @@ -102,40 +101,33 @@ public static void main(String[] args) { WebsocketClientTransport.create(disposableServer.host(), disposableServer.port()); RSocket rSocket = - RSocketFactory.connect() - .keepAliveAckTimeout(Duration.ofMinutes(10)) - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(clientTransport2) - .start() + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport2) .block(); // expect error here because of closed channel rSocket.requestResponse(payload1).block(); } - private static class SocketAcceptorImpl implements SocketAcceptor { + private static class ServerRSocket implements RSocket { + + @Override + public Mono fireAndForget(Payload payload) { + // System.out.println(payload.getDataUtf8()); + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - - @Override - public Mono fireAndForget(Payload payload) { - // System.out.println(payload.getDataUtf8()); - payload.release(); - return Mono.empty(); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads).subscribeOn(Schedulers.single()); - } - }); + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).subscribeOn(Schedulers.single()); } } } diff --git a/rsocket-examples/src/main/resources/log4j.properties b/rsocket-examples/src/main/resources/log4j.properties deleted file mode 100644 index 035f18ebd..000000000 --- a/rsocket-examples/src/main/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=DEBUG, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n \ No newline at end of file diff --git a/rsocket-examples/src/main/resources/logback.xml b/rsocket-examples/src/main/resources/logback.xml new file mode 100644 index 000000000..17dd8b5e3 --- /dev/null +++ b/rsocket-examples/src/main/resources/logback.xml @@ -0,0 +1,35 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + + + diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index 19c29061b..e2471f2fc 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,10 +23,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.RSocketInterceptor; import io.rsocket.plugins.SocketAcceptorInterceptor; @@ -39,7 +39,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; -import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.reactivestreams.Publisher; @@ -49,19 +48,20 @@ public class IntegrationTest { - private static final RSocketInterceptor requesterPlugin; - private static final RSocketInterceptor responderPlugin; - private static final SocketAcceptorInterceptor clientAcceptorPlugin; - private static final SocketAcceptorInterceptor serverAcceptorPlugin; - private static final DuplexConnectionInterceptor connectionPlugin; - public static volatile boolean calledRequester = false; - public static volatile boolean calledResponder = false; - public static volatile boolean calledClientAcceptor = false; - public static volatile boolean calledServerAcceptor = false; - public static volatile boolean calledFrame = false; + private static final RSocketInterceptor requesterInterceptor; + private static final RSocketInterceptor responderInterceptor; + private static final SocketAcceptorInterceptor clientAcceptorInterceptor; + private static final SocketAcceptorInterceptor serverAcceptorInterceptor; + private static final DuplexConnectionInterceptor connectionInterceptor; + + private static volatile boolean calledRequester = false; + private static volatile boolean calledResponder = false; + private static volatile boolean calledClientAcceptor = false; + private static volatile boolean calledServerAcceptor = false; + private static volatile boolean calledFrame = false; static { - requesterPlugin = + requesterInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override @@ -71,7 +71,7 @@ public Mono requestResponse(Payload payload) { } }; - responderPlugin = + responderInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override @@ -81,21 +81,21 @@ public Mono requestResponse(Payload payload) { } }; - clientAcceptorPlugin = + clientAcceptorInterceptor = acceptor -> (setup, sendingSocket) -> { calledClientAcceptor = true; return acceptor.accept(setup, sendingSocket); }; - serverAcceptorPlugin = + serverAcceptorInterceptor = acceptor -> (setup, sendingSocket) -> { calledServerAcceptor = true; return acceptor.accept(setup, sendingSocket); }; - connectionPlugin = + connectionInterceptor = (type, connection) -> { calledFrame = true; return connection; @@ -114,18 +114,8 @@ public void startup() { requestCount = new AtomicInteger(); disconnectionCounter = new CountDownLatch(1); - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); - server = - RSocketFactory.receive() - .addResponderPlugin(responderPlugin) - .addSocketAcceptorPlugin(serverAcceptorPlugin) - .addConnectionPlugin(connectionPlugin) - .errorConsumer( - t -> { - errorCount.incrementAndGet(); - }) - .acceptor( + RSocketServer.create( (setup, sendingSocket) -> { sendingSocket .onClose() @@ -133,7 +123,7 @@ public void startup() { .subscribe(); return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")) @@ -152,17 +142,24 @@ public Flux requestChannel(Publisher payloads) { } }); }) - .transport(serverTransport) - .start() + .interceptors( + registry -> + registry + .forResponder(responderInterceptor) + .forSocketAcceptor(serverAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) .block(); client = - RSocketFactory.connect() - .addRequesterPlugin(requesterPlugin) - .addSocketAcceptorPlugin(clientAcceptorPlugin) - .addConnectionPlugin(connectionPlugin) - .transport(TcpClientTransport.create(server.address())) - .start() + RSocketConnector.create() + .interceptors( + registry -> + registry + .forRequester(requesterInterceptor) + .forSocketAcceptor(clientAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .connect(TcpClientTransport.create(server.address())) .block(); } @@ -204,8 +201,6 @@ public void testCallRequestWithErrorAndThenRequest() { } catch (Throwable t) { } - Assert.assertEquals(1, errorCount.incrementAndGet()); - testRequest(); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java index 7a30a7fd1..48e5baaa7 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -1,9 +1,10 @@ package io.rsocket.integration; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.test.SlowTest; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; @@ -14,32 +15,26 @@ import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public class InteractionsLoadTest { @Test @SlowTest public void channel() { - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); - CloseableChannel server = - RSocketFactory.receive() - .acceptor((setup, rsocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() + RSocketServer.create(SocketAcceptor.with(new EchoRSocket())) + .bind(TcpServerTransport.create("localhost", 0)) .block(Duration.ofSeconds(10)); - TcpClientTransport transport = TcpClientTransport.create(server.address()); - - RSocket client = - RSocketFactory.connect().transport(transport).start().block(Duration.ofSeconds(10)); + RSocket clientRSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .block(Duration.ofSeconds(10)); int concurrency = 16; Flux.range(1, concurrency) .flatMap( v -> - client + clientRSocket .requestChannel( input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) .limitRate(10000), @@ -70,7 +65,8 @@ private static Flux input() { return interval; } - private static class EchoRSocket extends AbstractRSocket { + private static class EchoRSocket implements RSocket { + @Override public Flux requestChannel(Publisher payloads) { return Flux.from(payloads) diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index 9e7f5b0a7..de27bcb9b 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -19,10 +19,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -40,26 +40,20 @@ import reactor.core.scheduler.Schedulers; public class TcpIntegrationTest { - private AbstractRSocket handler; + private RSocket handler; private CloseableChannel server; @Before public void startup() { - TcpServerTransport serverTransport = TcpServerTransport.create("localhost", 0); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .bind(TcpServerTransport.create("localhost", 0)) .block(); } private RSocket buildClient() { - return RSocketFactory.connect() - .transport(TcpClientTransport.create(server.address())) - .start() - .block(); + return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } @After @@ -70,7 +64,7 @@ public void cleanup() { @Test(timeout = 15_000L) public void testCompleteWithoutNext() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.empty(); @@ -86,7 +80,7 @@ public Flux requestStream(Payload payload) { @Test(timeout = 15_000L) public void testSingleStream() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.just(DefaultPayload.create("RESPONSE", "METADATA")); @@ -103,7 +97,7 @@ public Flux requestStream(Payload payload) { @Test(timeout = 15_000L) public void testZeroPayload() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.just(EmptyPayload.INSTANCE); @@ -120,7 +114,7 @@ public Flux requestStream(Payload payload) { @Test(timeout = 15_000L) public void testRequestResponseErrors() { handler = - new AbstractRSocket() { + new RSocket() { boolean first = true; @Override @@ -160,7 +154,7 @@ public void testTwoConcurrentStreams() throws InterruptedException { map.put("REQUEST2", processor2); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return map.get(payload.getDataUtf8()); diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index ec1d41bf9..7d34ba478 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,60 +16,41 @@ package io.rsocket.integration; -import io.rsocket.*; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.ApplicationErrorException; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; import org.junit.Test; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public class TestingStreaming { - private Supplier> serverSupplier = - () -> LocalServerTransport.create("test"); - - private Supplier clientSupplier = () -> LocalClientTransport.create("test"); + LocalServerTransport serverTransport = LocalServerTransport.create("test"); @Test(expected = ApplicationErrorException.class) public void testRangeButThrowException() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .doOnNext( - i -> { - if (i > 3) { - throw new RuntimeException("BOOM!"); - } - }) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .doOnNext( + i -> { + if (i > 3) { + throw new RuntimeException("BOOM!"); + } + }) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -85,29 +66,13 @@ public void testRangeOfConsumers() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); @@ -119,10 +84,7 @@ public Flux requestStream(Payload payload) { } private Flux consumer(String s) { - return RSocketFactory.connect() - .errorConsumer(Throwable::printStackTrace) - .transport(clientSupplier) - .start() + return RSocketConnector.connectWith(LocalClientTransport.create("test")) .flatMapMany( rSocket -> { AtomicInteger count = new AtomicInteger(); @@ -135,31 +97,15 @@ private Flux consumer(String s) { @Test public void testSingleConsumer() { Closeable server = null; - try { server = - RSocketFactory.receive() - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000) - .map(l -> DefaultPayload.create("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 10_000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); consumer("1").blockLast(); diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java index 009d0d8db..b2dad0022 100644 --- a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,12 @@ package io.rsocket.resume; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.exceptions.RejectedResumeException; import io.rsocket.exceptions.UnsupportedSetupException; import io.rsocket.test.SlowTest; @@ -33,15 +35,14 @@ import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; import org.assertj.core.api.Assertions; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.ReplayProcessor; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; +import reactor.util.retry.Retry; @SlowTest public class ResumeIntegrationTest { @@ -103,11 +104,9 @@ public void reconnectOnMissingSession() { DisconnectableClientTransport clientTransport = new DisconnectableClientTransport(clientTransport(closeable.address())); - ErrorConsumer errorConsumer = new ErrorConsumer(); int clientSessionDurationSeconds = 10; - RSocket rSocket = - newClientRSocket(clientTransport, clientSessionDurationSeconds, errorConsumer).block(); + RSocket rSocket = newClientRSocket(clientTransport, clientSessionDurationSeconds).block(); Mono.delay(Duration.ofSeconds(1)) .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(3))); @@ -117,43 +116,34 @@ public void reconnectOnMissingSession() { .expectError() .verify(Duration.ofSeconds(5)); - StepVerifier.create(errorConsumer.errors().next()) - .expectNextMatches( + StepVerifier.create(rSocket.onClose()) + .expectErrorMatches( err -> err instanceof RejectedResumeException && "unknown resume token".equals(err.getMessage())) - .expectComplete() .verify(Duration.ofSeconds(5)); } @Test void serverMissingResume() { CloseableChannel closeableChannel = - RSocketFactory.receive() - .acceptor((setupPayload, rSocket) -> Mono.just(new TestResponderRSocket())) - .transport(serverTransport(SERVER_HOST, SERVER_PORT)) - .start() + RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)) .block(); - ErrorConsumer errorConsumer = new ErrorConsumer(); - RSocket rSocket = - RSocketFactory.connect() - .resume() - .errorConsumer(errorConsumer) - .transport(clientTransport(closeableChannel.address())) - .start() + RSocketConnector.create() + .resume(new Resume()) + .connect(clientTransport(closeableChannel.address())) .block(); - StepVerifier.create(errorConsumer.errors().next().doFinally(s -> closeableChannel.dispose())) - .expectNextMatches( + StepVerifier.create(rSocket.onClose().doFinally(s -> closeableChannel.dispose())) + .expectErrorMatches( err -> err instanceof UnsupportedSetupException && "resume not supported".equals(err.getMessage())) - .expectComplete() .verify(Duration.ofSeconds(5)); - StepVerifier.create(rSocket.onClose()).expectComplete().verify(Duration.ofSeconds(5)); Assertions.assertThat(rSocket.isDisposed()).isTrue(); } @@ -165,21 +155,8 @@ static ServerTransport serverTransport(String host, int port) return TcpServerTransport.create(host, port); } - private static class ErrorConsumer implements Consumer { - private final ReplayProcessor errors = ReplayProcessor.create(); - - public Flux errors() { - return errors; - } - - @Override - public void accept(Throwable throwable) { - errors.onNext(throwable); - } - } - private static Flux testRequest() { - return Flux.interval(Duration.ofMillis(50)) + return Flux.interval(Duration.ofMillis(500)) .map(v -> DefaultPayload.create("client_request")) .onBackpressureDrop(); } @@ -201,24 +178,15 @@ private void throwOnNonContinuous(AtomicInteger counter, String x) { private static Mono newClientRSocket( DisconnectableClientTransport clientTransport, int sessionDurationSeconds) { - return newClientRSocket(clientTransport, sessionDurationSeconds, err -> {}); - } - - private static Mono newClientRSocket( - DisconnectableClientTransport clientTransport, - int sessionDurationSeconds, - Consumer errConsumer) { - return RSocketFactory.connect() - .resume() - .resumeSessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .resumeStore(t -> new InMemoryResumableFramesStore("client", 500_000)) - .resumeCleanupOnKeepAlive() - .keepAliveTickPeriod(Duration.ofSeconds(5)) - .keepAliveAckTimeout(Duration.ofMinutes(5)) - .errorConsumer(errConsumer) - .resumeStrategy(() -> new PeriodicResumeStrategy(Duration.ofSeconds(1))) - .transport(clientTransport) - .start(); + return RSocketConnector.create() + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", 500_000)) + .cleanupStoreOnKeepAlive() + .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) + .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) + .connect(clientTransport); } private static Mono newServerRSocket() { @@ -226,17 +194,16 @@ private static Mono newServerRSocket() { } private static Mono newServerRSocket(int sessionDurationSeconds) { - return RSocketFactory.receive() - .resume() - .resumeStore(t -> new InMemoryResumableFramesStore("server", 500_000)) - .resumeSessionDuration(Duration.ofSeconds(sessionDurationSeconds)) - .resumeCleanupOnKeepAlive() - .acceptor((setupPayload, rSocket) -> Mono.just(new TestResponderRSocket())) - .transport(serverTransport(SERVER_HOST, SERVER_PORT)) - .start(); + return RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .cleanupStoreOnKeepAlive() + .storeFactory(t -> new InMemoryResumableFramesStore("server", 500_000))) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)); } - private static class TestResponderRSocket extends AbstractRSocket { + private static class TestResponderRSocket implements RSocket { AtomicInteger counter = new AtomicInteger(); diff --git a/rsocket-examples/src/test/resources/log4j.properties b/rsocket-examples/src/test/resources/log4j.properties deleted file mode 100644 index 51731fc15..000000000 --- a/rsocket-examples/src/test/resources/log4j.properties +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss,SSS} %5p [%t] (%F) - %m%n -#log4j.logger.io.rsocket.FrameLogger=Debug \ No newline at end of file diff --git a/rsocket-examples/src/test/resources/logback-test.xml b/rsocket-examples/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-examples/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index a2c8b73c7..748f95de6 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -34,6 +34,7 @@ dependencies { testCompileOnly 'junit:junit' testImplementation 'org.hamcrest:hamcrest-library' testRuntimeOnly 'org.junit.vintage:junit-vintage-engine' + testRuntimeOnly 'ch.qos.logback:logback-classic' } description = 'Transparent Load Balancer for RSocket' diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index ed7550233..65ce80934 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -536,7 +536,7 @@ public Mono onClose() { * Wrapper of a RSocket, it computes statistics about the req/resp calls and update availability * accordingly. */ - private class WeightedSocket extends AbstractRSocket implements LoadBalancerSocketMetrics { + private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; private final Quantile lowerQuantile; @@ -554,6 +554,7 @@ private class WeightedSocket extends AbstractRSocket implements LoadBalancerSock private AtomicLong pendingStreams; // number of active streams private volatile double availability = 0.0; + private final MonoProcessor onClose = MonoProcessor.create(); WeightedSocket( RSocketSupplier factory, @@ -791,6 +792,21 @@ public double availability() { return availability; } + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + @Override public String toString() { return "WeightedSocket(" diff --git a/rsocket-core/src/main/java/io/rsocket/util/Function3.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java similarity index 79% rename from rsocket-core/src/main/java/io/rsocket/util/Function3.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java index 5783665ae..55ce5646c 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/Function3.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2019 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,7 @@ * limitations under the License. */ -package io.rsocket.util; +@NonNullApi +package io.rsocket.client.filter; -public interface Function3 { - - R apply(T t, U u, V v); -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java similarity index 68% rename from rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java index 50d2a326f..ec21dee96 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/SupportsIterator.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java @@ -1,9 +1,11 @@ /* + * Copyright 2015-2020 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -11,10 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.internal.jctools.queues; -import io.rsocket.internal.jctools.util.InternalAPI; +@NonNullApi +package io.rsocket.client; -/** Tagging interface to help testing */ -@InternalAPI -public interface SupportsIterator {} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java new file mode 100644 index 000000000..cfb071175 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +package io.rsocket.stat; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/test/resources/log4j.properties b/rsocket-load-balancer/src/test/resources/log4j.properties deleted file mode 100644 index 8fc3a9cdd..000000000 --- a/rsocket-load-balancer/src/test/resources/log4j.properties +++ /dev/null @@ -1,20 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-load-balancer/src/test/resources/logback-test.xml b/rsocket-load-balancer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-load-balancer/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle index 5f2aeb16f..4be616623 100644 --- a/rsocket-micrometer/build.gradle +++ b/rsocket-micrometer/build.gradle @@ -27,8 +27,6 @@ dependencies { implementation 'org.slf4j:slf4j-api' - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java index 20d58dcb7..c8b22382a 100644 --- a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -20,8 +20,9 @@ import io.micrometer.core.instrument.*; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameHeaderFlyweight; +import io.rsocket.frame.FrameHeaderCodec; import io.rsocket.frame.FrameType; import io.rsocket.plugins.DuplexConnectionInterceptor.Type; import java.util.Objects; @@ -82,6 +83,11 @@ final class MicrometerDuplexConnection implements DuplexConnection { this.frameCounters = new FrameCounters(connectionType, meterRegistry, tags); } + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + @Override public void dispose() { delegate.dispose(); @@ -185,7 +191,7 @@ private static Counter counter( @Override public void accept(ByteBuf frame) { - FrameType frameType = FrameHeaderFlyweight.frameType(frame); + FrameType frameType = FrameHeaderCodec.frameType(frame); switch (frameType) { case SETUP: diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle index 3009b5135..5ec1a8061 100644 --- a/rsocket-test/build.gradle +++ b/rsocket-test/build.gradle @@ -26,8 +26,6 @@ dependencies { api 'org.hdrhistogram:HdrHistogram' api 'org.junit.jupiter:junit-jupiter-api' - compileOnly 'com.google.code.findbugs:jsr305' - implementation 'io.projectreactor:reactor-test' implementation 'org.assertj:assertj-core' implementation 'org.mockito:mockito-core' diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index ec143b7ab..6f562875f 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,8 @@ import io.rsocket.Closeable; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.function.BiFunction; @@ -47,17 +48,13 @@ public ClientSetupRule( this.serverInit = address -> - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, rsocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); this.clientConnector = (address, server) -> - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java index 902014e7f..47f40a59d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java @@ -16,7 +16,6 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -43,7 +42,7 @@ public PingHandler(byte[] data) { @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { payload.release(); diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java index 2651b14ec..1e66abc5e 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java @@ -33,71 +33,69 @@ private TestFrames() {} /** @return {@link ByteBuf} representing test instance of Cancel frame */ public static ByteBuf createTestCancelFrame() { - return CancelFrameFlyweight.encode(allocator, 1); + return CancelFrameCodec.encode(allocator, 1); } /** @return {@link ByteBuf} representing test instance of Error frame */ public static ByteBuf createTestErrorFrame() { - return ErrorFrameFlyweight.encode(allocator, 1, new RuntimeException()); + return ErrorFrameCodec.encode(allocator, 1, new RuntimeException()); } /** @return {@link ByteBuf} representing test instance of Extension frame */ public static ByteBuf createTestExtensionFrame() { - return ExtensionFrameFlyweight.encode( + return ExtensionFrameCodec.encode( allocator, 1, 1, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Keep-Alive frame */ public static ByteBuf createTestKeepaliveFrame() { - return KeepAliveFrameFlyweight.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); + return KeepAliveFrameCodec.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Lease frame */ public static ByteBuf createTestLeaseFrame() { - return LeaseFrameFlyweight.encode(allocator, 1, 1, null); + return LeaseFrameCodec.encode(allocator, 1, 1, null); } /** @return {@link ByteBuf} representing test instance of Metadata-Push frame */ public static ByteBuf createTestMetadataPushFrame() { - return MetadataPushFrameFlyweight.encode(allocator, Unpooled.EMPTY_BUFFER); + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Payload frame */ public static ByteBuf createTestPayloadFrame() { - return PayloadFrameFlyweight.encode( - allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); + return PayloadFrameCodec.encode(allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Request-Channel frame */ public static ByteBuf createTestRequestChannelFrame() { - return RequestChannelFrameFlyweight.encode( + return RequestChannelFrameCodec.encode( allocator, 1, false, false, 1, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Fire-and-Forget frame */ public static ByteBuf createTestRequestFireAndForgetFrame() { - return RequestFireAndForgetFrameFlyweight.encode( - allocator, 1, false, null, Unpooled.EMPTY_BUFFER); + return RequestFireAndForgetFrameCodec.encode(allocator, 1, false, null, Unpooled.EMPTY_BUFFER); } /** @return {@link ByteBuf} representing test instance of Request-N frame */ public static ByteBuf createTestRequestNFrame() { - return RequestNFrameFlyweight.encode(allocator, 1, 1); + return RequestNFrameCodec.encode(allocator, 1, 1); } /** @return {@link ByteBuf} representing test instance of Request-Response frame */ public static ByteBuf createTestRequestResponseFrame() { - return RequestResponseFrameFlyweight.encode(allocator, 1, false, emptyPayload); + return RequestResponseFrameCodec.encodeReleasingPayload(allocator, 1, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Request-Stream frame */ public static ByteBuf createTestRequestStreamFrame() { - return RequestStreamFrameFlyweight.encode(allocator, 1, false, 1L, emptyPayload); + return RequestStreamFrameCodec.encodeReleasingPayload(allocator, 1, 1L, emptyPayload); } /** @return {@link ByteBuf} representing test instance of Setup frame */ public static ByteBuf createTestSetupFrame() { - return SetupFrameFlyweight.encode( + return SetupFrameCodec.encode( allocator, false, 1, diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index 57a2e5c3c..d48700445 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -16,14 +16,14 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; +import io.rsocket.RSocket; import io.rsocket.util.DefaultPayload; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -public class TestRSocket extends AbstractRSocket { +public class TestRSocket implements RSocket { private final String data; private final String metadata; @@ -55,6 +55,6 @@ public Mono fireAndForget(Payload payload) { @Override public Flux requestChannel(Publisher payloads) { // TODO is defensive copy neccesary? - return Flux.from(payloads).map(DefaultPayload::create); + return Flux.from(payloads).map(Payload::retain); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java index fc6301d7d..fc059c7d1 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,28 +19,62 @@ import io.rsocket.Closeable; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.util.DefaultPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import reactor.core.Disposable; import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; public interface TransportTest { + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = DefaultPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @BeforeEach + default void setUp() { + Hooks.onOperatorDebug(); + } + @AfterEach default void close() { getTransportPair().dispose(); + Hooks.resetOnOperatorDebug(); } default Payload createTestPayload(int metadataPresent) { @@ -54,12 +88,12 @@ default Payload createTestPayload(int metadataPresent) { metadata1 = ""; break; default: - metadata1 = "metadata"; + metadata1 = MOCK_METADATA; break; } String metadata = metadata1; - return DefaultPayload.create("test-data", metadata); + return DefaultPayload.create(MOCK_DATA, metadata); } @DisplayName("makes 10 fireAndForget requests") @@ -73,6 +107,17 @@ default void fireAndForget10() { .verify(getTimeout()); } + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD)) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + default RSocket getClient() { return getTransportPair().getClient(); } @@ -92,6 +137,17 @@ default void metadataPush10() { .verify(getTimeout()); } + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(DefaultPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectNextCount(0) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 0 payloads") @Test default void requestChannel0() { @@ -127,6 +183,19 @@ default void requestChannel200_000() { .verify(getTimeout()); } + @DisplayName("makes 1 requestChannel request with 200 large payloads") + @Test + default void largePayloadRequestChannel200() { + Flux payloads = Flux.range(0, 200).map(__ -> LARGE_PAYLOAD); + + getClient() + .requestChannel(payloads) + .as(StepVerifier::create) + .expectNextCount(200) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 1 requestChannel request with 20,000 payloads") @Test default void requestChannel20_000() { @@ -157,14 +226,18 @@ default void requestChannel2_000_000() { @DisplayName("makes 1 requestChannel request with 3 payloads") @Test default void requestChannel3() { - Flux payloads = Flux.range(0, 3).map(this::createTestPayload); + AtomicLong requested = new AtomicLong(); + Flux payloads = + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); getClient() .requestChannel(payloads) - .as(StepVerifier::create) + .as(publisher -> StepVerifier.create(publisher, 3)) .expectNextCount(3) .expectComplete() .verify(getTimeout()); + + Assertions.assertThat(requested.get()).isEqualTo(3L); } @DisplayName("makes 1 requestChannel request with 512 payloads") @@ -223,6 +296,17 @@ default void requestResponse100() { .verify(getTimeout()); } + @DisplayName("makes 100 requestResponse requests") + @Test + default void largePayloadRequestResponse100() { + Flux.range(1, 100) + .flatMap(i -> getClient().requestResponse(LARGE_PAYLOAD).map(Payload::getDataUtf8)) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + @DisplayName("makes 10,000 requestResponse requests") @Test default void requestResponse10_000() { @@ -283,7 +367,7 @@ default void assertPayload(Payload p) { } default void assertChannelPayload(Payload p) { - if (!"test-data".equals(p.getDataUtf8()) || !"metadata".equals(p.getMetadataUtf8())) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { throw new IllegalStateException("Unexpected payload"); } } @@ -304,16 +388,12 @@ public TransportPair( T address = addressSupplier.get(); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); client = - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java b/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java deleted file mode 100644 index ad45e106a..000000000 --- a/rsocket-test/src/main/java/io/rsocket/test/UriHandlerTest.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.test; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; - -import io.rsocket.uri.UriHandler; -import java.net.URI; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -public interface UriHandlerTest { - - @DisplayName("returns empty Optional client with invalid URI") - @Test - default void buildClientInvalidUri() { - assertThat(getUriHandler().buildClient(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildClient throws NullPointerException with null uri") - @Test - default void buildClientNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildClient(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns client with value URI") - @Test - default void buildClientValidUri() { - assertThat(getUriHandler().buildClient(URI.create(getValidUri()))).isNotEmpty(); - } - - @DisplayName("returns empty Optional server with invalid URI") - @Test - default void buildServerInvalidUri() { - assertThat(getUriHandler().buildServer(URI.create(getInvalidUri()))).isEmpty(); - } - - @DisplayName("buildServer throws NullPointerException with null uri") - @Test - default void buildServerNullUri() { - assertThatNullPointerException() - .isThrownBy(() -> getUriHandler().buildServer(null)) - .withMessage("uri must not be null"); - } - - @DisplayName("returns server with value URI") - @Test - default void buildServerValidUri() { - assertThat(getUriHandler().buildServer(URI.create(getValidUri()))).isNotEmpty(); - } - - String getInvalidUri(); - - UriHandler getUriHandler(); - - String getValidUri(); -} diff --git a/rsocket-test/src/main/resources/words.shakespeare.txt.gz b/rsocket-test/src/main/resources/words.shakespeare.txt.gz new file mode 100644 index 000000000..422a4b331 Binary files /dev/null and b/rsocket-test/src/main/resources/words.shakespeare.txt.gz differ diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle index 8c3226065..a5ba84d5c 100644 --- a/rsocket-transport-local/build.gradle +++ b/rsocket-transport-local/build.gradle @@ -24,8 +24,6 @@ plugins { dependencies { api project(':rsocket-core') - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index 990acddfe..d69bd65e8 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -20,6 +20,7 @@ import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.internal.UnboundedProcessor; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; @@ -36,21 +37,39 @@ public final class LocalClientTransport implements ClientTransport { private final String name; - private LocalClientTransport(String name) { + private final ByteBufAllocator allocator; + + private LocalClientTransport(String name, ByteBufAllocator allocator) { this.name = name; + this.allocator = allocator; } /** * Creates a new instance. * - * @param name the name of the {@link ServerTransport} instance to connect to + * @param name the name of the {@link ClientTransport} instance to connect to * @return a new instance * @throws NullPointerException if {@code name} is {@code null} */ public static LocalClientTransport create(String name) { Objects.requireNonNull(name, "name must not be null"); - return new LocalClientTransport(name); + return create(name, ByteBufAllocator.DEFAULT); + } + + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @param allocator the allocator used by {@link ClientTransport} instance + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name, ByteBufAllocator allocator) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(allocator, "allocator must not be null"); + + return new LocalClientTransport(name, allocator); } private Mono connect() { @@ -65,9 +84,10 @@ private Mono connect() { UnboundedProcessor out = new UnboundedProcessor<>(); MonoProcessor closeNotifier = MonoProcessor.create(); - server.accept(new LocalDuplexConnection(out, in, closeNotifier)); + server.accept(new LocalDuplexConnection(allocator, out, in, closeNotifier)); - return Mono.just((DuplexConnection) new LocalDuplexConnection(in, out, closeNotifier)); + return Mono.just( + (DuplexConnection) new LocalDuplexConnection(allocator, in, out, closeNotifier)); }); } @@ -75,13 +95,14 @@ private Mono connect() { public Mono connect(int mtu) { Mono isError = FragmentationDuplexConnection.checkMtu(mtu); Mono connect = isError != null ? isError : connect(); - if (mtu > 0) { - return connect.map( - duplexConnection -> - new FragmentationDuplexConnection( - duplexConnection, ByteBufAllocator.DEFAULT, mtu, false, "client")); - } else { - return connect; - } + + return connect.map( + duplexConnection -> { + if (mtu > 0) { + return new FragmentationDuplexConnection(duplexConnection, mtu, false, "client"); + } else { + return new ReassemblyDuplexConnection(duplexConnection, false); + } + }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index f9501717c..afaa14f95 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -17,6 +17,7 @@ package io.rsocket.transport.local; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; @@ -28,6 +29,7 @@ /** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ final class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; private final Flux in; private final MonoProcessor onClose; @@ -42,7 +44,12 @@ final class LocalDuplexConnection implements DuplexConnection { * @param onClose the closing notifier * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} */ - LocalDuplexConnection(Flux in, Subscriber out, MonoProcessor onClose) { + LocalDuplexConnection( + ByteBufAllocator allocator, + Flux in, + Subscriber out, + MonoProcessor onClose) { + this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); this.in = Objects.requireNonNull(in, "in must not be null"); this.out = Objects.requireNonNull(out, "out must not be null"); this.onClose = Objects.requireNonNull(onClose, "onClose must not be null"); @@ -82,4 +89,9 @@ public Mono sendOne(ByteBuf frame) { out.onNext(frame); return Mono.empty(); } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index d755859d2..382b4533a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -16,10 +16,10 @@ package io.rsocket.transport.local; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.Objects; @@ -166,8 +166,9 @@ public void accept(DuplexConnection duplexConnection) { if (mtu > 0) { duplexConnection = - new FragmentationDuplexConnection( - duplexConnection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + new FragmentationDuplexConnection(duplexConnection, mtu, false, "server"); + } else { + duplexConnection = new ReassemblyDuplexConnection(duplexConnection, false); } acceptor.apply(duplexConnection).subscribe(); diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java deleted file mode 100644 index 89c816d7a..000000000 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link LocalClientTransport}s and {@link - * LocalServerTransport}s. - */ -public final class LocalUriHandler implements UriHandler { - - private static final String SCHEME = "local"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalClientTransport.create(uri.getSchemeSpecificPart())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(LocalServerTransport.create(uri.getSchemeSpecificPart())); - } -} diff --git a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 6ff8ffb50..000000000 --- a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.local.LocalUriHandler diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java index 2e4f93ac4..9228e2d05 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ package io.rsocket.transport.local; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.test.PingHandler; @@ -28,18 +29,15 @@ public final class LocalPingPong { public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(LocalServerTransport.create("test-local-server")) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("test-local-server")) .block(); Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(LocalClientTransport.create("test-local-server")) - .start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("test-local-server")); PingClient pingClient = new PingClient(client); diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java deleted file mode 100644 index f6b5cda7e..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class LocalUriTransportRegistryTest { - - @DisplayName("local URI returns LocalClientTransport") - @Test - void clientForUri() { - assertThat(UriTransportRegistry.clientForUri("local:test1")) - .isInstanceOf(LocalClientTransport.class); - } - - @DisplayName("non-local URI does not return LocalClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("http://localhost")) - .isNotInstanceOf(LocalClientTransport.class); - } - - @DisplayName("local URI returns LocalServerTransport") - @Test - void serverForUri() { - assertThat(UriTransportRegistry.serverForUri("local:test1")) - .isInstanceOf(LocalServerTransport.class); - } - - @DisplayName("non-local URI does not return LocalServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("http://localhost")) - .isNotInstanceOf(LocalServerTransport.class); - } -} diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle index 0aac12d5c..64e483c90 100644 --- a/rsocket-transport-netty/build.gradle +++ b/rsocket-transport-netty/build.gradle @@ -32,8 +32,6 @@ dependencies { api 'io.projectreactor.netty:reactor-netty' api 'org.slf4j:slf4j-api' - compileOnly 'com.google.code.findbugs:jsr305' - testImplementation project(':rsocket-test') testImplementation 'io.projectreactor:reactor-test' testImplementation 'org.assertj:assertj-core' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java index 68d7ab175..e2c134653 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java @@ -16,8 +16,8 @@ package io.rsocket.transport.netty; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_SIZE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java index c9c29f0a9..b7081593c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -19,7 +19,7 @@ import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.frame.FrameLengthFlyweight; +import io.rsocket.frame.FrameLengthCodec; import io.rsocket.internal.BaseDuplexConnection; import java.util.Objects; import org.reactivestreams.Publisher; @@ -31,7 +31,6 @@ public final class TcpDuplexConnection extends BaseDuplexConnection { private final Connection connection; - private final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; private final boolean encodeLength; /** @@ -62,6 +61,11 @@ public TcpDuplexConnection(Connection connection, boolean encodeLength) { }); } + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + @Override protected void doOnClose() { if (!connection.isDisposed()) { @@ -84,7 +88,7 @@ public Mono send(Publisher frames) { private ByteBuf encode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.encode(allocator, frame.readableBytes(), frame); + return FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame); } else { return frame; } @@ -92,7 +96,7 @@ private ByteBuf encode(ByteBuf frame) { private ByteBuf decode(ByteBuf frame) { if (encodeLength) { - return FrameLengthFlyweight.frame(frame).retain(); + return FrameLengthCodec.frame(frame).retain(); } else { return frame; } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java deleted file mode 100644 index d4ebd57b7..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; -import reactor.netty.tcp.TcpServer; - -/** - * An implementation of {@link UriHandler} that creates {@link TcpClientTransport}s and {@link - * TcpServerTransport}s. - */ -public final class TcpUriHandler implements UriHandler { - - private static final String SCHEME = "tcp"; - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of(TcpClientTransport.create(uri.getHost(), uri.getPort())); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (!SCHEME.equals(uri.getScheme())) { - return Optional.empty(); - } - - return Optional.of( - TcpServerTransport.create(TcpServer.create().host(uri.getHost()).port(uri.getPort()))); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index ead297928..0183ef19d 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -16,6 +16,7 @@ package io.rsocket.transport.netty; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; import io.rsocket.internal.BaseDuplexConnection; @@ -53,6 +54,11 @@ public WebsocketDuplexConnection(Connection connection) { }); } + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + @Override protected void doOnClose() { if (!connection.isDisposed()) { diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java deleted file mode 100644 index 6438c4e28..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static io.rsocket.transport.netty.UriUtils.getPort; -import static io.rsocket.transport.netty.UriUtils.isSecure; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Arrays; -import java.util.List; -import java.util.Objects; -import java.util.Optional; - -/** - * An implementation of {@link UriHandler} that creates {@link WebsocketClientTransport}s and {@link - * WebsocketServerTransport}s. - */ -public final class WebsocketUriHandler implements UriHandler { - - private static final List SCHEME = Arrays.asList("ws", "wss", "http", "https"); - - @Override - public Optional buildClient(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - return Optional.of(WebsocketClientTransport.create(uri)); - } - - @Override - public Optional buildServer(URI uri) { - Objects.requireNonNull(uri, "uri must not be null"); - - if (SCHEME.stream().noneMatch(scheme -> scheme.equals(uri.getScheme()))) { - return Optional.empty(); - } - - int port = isSecure(uri) ? getPort(uri, 443) : getPort(uri, 80); - - return Optional.of(WebsocketServerTransport.create(uri.getHost(), port)); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index f5e79e9bf..22f139310 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -16,9 +16,9 @@ package io.rsocket.transport.netty.client; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -75,7 +75,7 @@ public static TcpClientTransport create(String bindAddress, int port) { public static TcpClientTransport create(InetSocketAddress address) { Objects.requireNonNull(address, "address must not be null"); - TcpClient tcpClient = TcpClient.create().addressSupplier(() -> address); + TcpClient tcpClient = TcpClient.create().remoteAddress(() -> address); return create(tcpClient); } @@ -104,13 +104,9 @@ public Mono connect(int mtu) { c -> { if (mtu > 0) { return new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "client"); + new TcpDuplexConnection(c, false), mtu, true, "client"); } else { - return new TcpDuplexConnection(c); + return new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); } }); } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 5049119a5..747401210 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -16,13 +16,13 @@ package io.rsocket.transport.netty.client; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static io.rsocket.transport.netty.UriUtils.getPort; import static io.rsocket.transport.netty.UriUtils.isSecure; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -35,6 +35,7 @@ import java.util.function.Supplier; import reactor.core.publisher.Mono; import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.WebsocketClientSpec; import reactor.netty.tcp.TcpClient; /** @@ -43,12 +44,11 @@ */ public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware { - private static final int DEFAULT_FRAME_SIZE = 65536; private static final String DEFAULT_PATH = "/"; private final HttpClient client; - private String path; + private final String path; private Supplier> transportHeaders = Collections::emptyMap; @@ -93,7 +93,7 @@ public static WebsocketClientTransport create(String bindAddress, int port) { public static WebsocketClientTransport create(InetSocketAddress address) { Objects.requireNonNull(address, "address must not be null"); - TcpClient client = TcpClient.create().addressSupplier(() -> address); + TcpClient client = TcpClient.create().remoteAddress(() -> address); return create(client); } @@ -156,7 +156,8 @@ public Mono connect(int mtu) { ? isError : client .headers(headers -> transportHeaders.get().forEach(headers::set)) - .websocket(FRAME_LENGTH_MASK) + .websocket( + WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()) .uri(path) .connect() .map( @@ -164,8 +165,9 @@ public Mono connect(int mtu) { DuplexConnection connection = new WebsocketDuplexConnection(c); if (mtu > 0) { connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "client"); + new FragmentationDuplexConnection(connection, mtu, false, "client"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return connection; }); diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java index f6e83bc36..c0340c7a2 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -28,7 +28,7 @@ */ public final class CloseableChannel implements Closeable { - private DisposableChannel channel; + private final DisposableChannel channel; /** * Creates a new instance diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index 54ef016c0..56dd59d45 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -16,9 +16,9 @@ package io.rsocket.transport.netty.server; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; @@ -105,13 +105,9 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { if (mtu > 0) { connection = new FragmentationDuplexConnection( - new TcpDuplexConnection(c, false), - ByteBufAllocator.DEFAULT, - mtu, - true, - "server"); + new TcpDuplexConnection(c, false), mtu, true, "server"); } else { - connection = new TcpDuplexConnection(c); + connection = new ReassemblyDuplexConnection(new TcpDuplexConnection(c), false); } acceptor .apply(connection) diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 30aa0fa96..bd19f18b0 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -16,29 +16,23 @@ package io.rsocket.transport.netty.server; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import io.netty.buffer.ByteBufAllocator; -import io.netty.handler.codec.http.HttpMethod; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.function.BiFunction; import java.util.function.Consumer; -import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; import reactor.netty.http.server.HttpServerRoutes; +import reactor.netty.http.server.WebsocketServerSpec; import reactor.netty.http.websocket.WebsocketInbound; import reactor.netty.http.websocket.WebsocketOutbound; @@ -48,7 +42,7 @@ */ public final class WebsocketRouteTransport extends BaseWebsocketServerTransport { - private final UriPathTemplate template; + private final String path; private final Consumer routesBuilder; @@ -65,7 +59,7 @@ public WebsocketRouteTransport( HttpServer server, Consumer routesBuilder, String path) { this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); - this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null")); + this.path = Objects.requireNonNull(path, "path must not be null"); } @Override @@ -77,10 +71,9 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { routes -> { routesBuilder.accept(routes); routes.ws( - hsr -> hsr.method().equals(HttpMethod.GET) && template.matches(hsr.uri()), + path, newHandler(acceptor, mtu), - null, - FRAME_LENGTH_MASK); + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK).build()); }) .bind() .map(CloseableChannel::new); @@ -111,128 +104,11 @@ public static BiFunction> n return (in, out) -> { DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); if (mtu > 0) { - connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + connection = new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return acceptor.apply(connection).then(out.neverComplete()); }; } - - static final class UriPathTemplate { - - private static final Pattern FULL_SPLAT_PATTERN = Pattern.compile("[\\*][\\*]"); - private static final String FULL_SPLAT_REPLACEMENT = ".*"; - - private static final Pattern NAME_SPLAT_PATTERN = Pattern.compile("\\{([^/]+?)\\}[\\*][\\*]"); - private static final String NAME_SPLAT_REPLACEMENT = "(?<%NAME%>.*)"; - - private static final Pattern NAME_PATTERN = Pattern.compile("\\{([^/]+?)\\}"); - private static final String NAME_REPLACEMENT = "(?<%NAME%>[^\\/]*)"; - - private final List pathVariables = new ArrayList<>(); - private final HashMap matchers = new HashMap<>(); - private final HashMap> vars = new HashMap<>(); - - private final Pattern uriPattern; - - static String filterQueryParams(String uri) { - int hasQuery = uri.lastIndexOf("?"); - if (hasQuery != -1) { - return uri.substring(0, hasQuery); - } else { - return uri; - } - } - - /** - * Creates a new {@code UriPathTemplate} from the given {@code uriPattern}. - * - * @param uriPattern The pattern to be used by the template - */ - UriPathTemplate(String uriPattern) { - String s = "^" + filterQueryParams(uriPattern); - - Matcher m = NAME_SPLAT_PATTERN.matcher(s); - while (m.find()) { - for (int i = 1; i <= m.groupCount(); i++) { - String name = m.group(i); - pathVariables.add(name); - s = m.replaceFirst(NAME_SPLAT_REPLACEMENT.replaceAll("%NAME%", name)); - m.reset(s); - } - } - - m = NAME_PATTERN.matcher(s); - while (m.find()) { - for (int i = 1; i <= m.groupCount(); i++) { - String name = m.group(i); - pathVariables.add(name); - s = m.replaceFirst(NAME_REPLACEMENT.replaceAll("%NAME%", name)); - m.reset(s); - } - } - - m = FULL_SPLAT_PATTERN.matcher(s); - while (m.find()) { - s = m.replaceAll(FULL_SPLAT_REPLACEMENT); - m.reset(s); - } - - this.uriPattern = Pattern.compile(s + "$"); - } - - /** - * Tests the given {@code uri} against this template, returning {@code true} if the uri matches - * the template, {@code false} otherwise. - * - * @param uri The uri to match - * @return {@code true} if there's a match, {@code false} otherwise - */ - public boolean matches(String uri) { - return matcher(uri).matches(); - } - - /** - * Matches the template against the given {@code uri} returning a map of path parameters - * extracted from the uri, keyed by the names in the template. If the uri does not match, or - * there are no path parameters, an empty map is returned. - * - * @param uri The uri to match - * @return the path parameters from the uri. Never {@code null}. - */ - final Map match(String uri) { - Map pathParameters = vars.get(uri); - if (null != pathParameters) { - return pathParameters; - } - - pathParameters = new HashMap<>(); - Matcher m = matcher(uri); - if (m.matches()) { - int i = 1; - for (String name : pathVariables) { - String val = m.group(i++); - pathParameters.put(name, val); - } - } - synchronized (vars) { - vars.put(uri, pathParameters); - } - - return pathParameters; - } - - private Matcher matcher(String uri) { - uri = filterQueryParams(uri); - Matcher m = matchers.get(uri); - if (null == m) { - m = uriPattern.matcher(uri); - synchronized (matchers) { - matchers.put(uri, m); - } - } - return m; - } - } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index 948d6f573..1a0b32cf0 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -16,11 +16,11 @@ package io.rsocket.transport.netty.server; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; -import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; import io.rsocket.fragmentation.FragmentationDuplexConnection; +import io.rsocket.fragmentation.ReassemblyDuplexConnection; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.TransportHeaderAware; @@ -35,6 +35,7 @@ import reactor.core.publisher.Mono; import reactor.netty.Connection; import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.WebsocketServerSpec; /** * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a @@ -122,18 +123,20 @@ public Mono start(ConnectionAcceptor acceptor, int mtu) { (request, response) -> { transportHeaders.get().forEach(response::addHeader); return response.sendWebsocket( - null, - FRAME_LENGTH_MASK, (in, out) -> { DuplexConnection connection = new WebsocketDuplexConnection((Connection) in); if (mtu > 0) { connection = - new FragmentationDuplexConnection( - connection, ByteBufAllocator.DEFAULT, mtu, false, "server"); + new FragmentationDuplexConnection(connection, mtu, false, "server"); + } else { + connection = new ReassemblyDuplexConnection(connection, false); } return acceptor.apply(connection).then(out.neverComplete()); - }); + }, + WebsocketServerSpec.builder() + .maxFramePayloadLength(FRAME_LENGTH_MASK) + .build()); }) .bind() .map(CloseableChannel::new); diff --git a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index ec7ddcb80..000000000 --- a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2015-2018 the original author or authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.netty.TcpUriHandler -io.rsocket.transport.netty.WebsocketUriHandler diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java index 575993c18..23041ec65 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,32 +18,36 @@ import static org.assertj.core.api.Assertions.assertThat; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class FragmentTest { - private static final int frameSize = 64; - private AbstractRSocket handler; + private RSocket handler; private CloseableChannel server; private String message = null; private String metaData = null; private String responseMessage = null; - @BeforeEach - public void startup() { + private static Stream cases() { + return Stream.of(Arguments.of(0, 64), Arguments.of(64, 0), Arguments.of(64, 64)); + } + + public void startup(int frameSize) { int randomPort = ThreadLocalRandom.current().nextInt(10_000, 20_000); StringBuilder message = new StringBuilder(); StringBuilder responseMessage = new StringBuilder(); @@ -59,19 +63,16 @@ public void startup() { TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); server = - RSocketFactory.receive() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) .fragment(frameSize) - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + .bind(serverTransport) .block(); } - private RSocket buildClient() { - return RSocketFactory.connect() + private RSocket buildClient(int frameSize) { + return RSocketConnector.create() .fragment(frameSize) - .transport(TcpClientTransport.create(server.address())) - .start() + .connect(TcpClientTransport.create(server.address())) .block(); } @@ -80,12 +81,14 @@ public void cleanup() { server.dispose(); } - @Test - void testFragmentNoMetaData() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentNoMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); System.out.println( "-------------------------------------------------testFragmentNoMetaData-------------------------------------------------"); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { String request = payload.getDataUtf8(); @@ -97,7 +100,7 @@ public Flux requestStream(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); @@ -108,12 +111,14 @@ public Flux requestStream(Payload payload) { assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); } - @Test - void testFragmentRequestMetaDataOnly() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentRequestMetaDataOnly(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); System.out.println( "-------------------------------------------------testFragmentRequestMetaDataOnly-------------------------------------------------"); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { String request = payload.getDataUtf8(); @@ -125,7 +130,7 @@ public Flux requestStream(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); @@ -136,13 +141,15 @@ public Flux requestStream(Payload payload) { assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); } - @Test - void testFragmentBothMetaData() { + @ParameterizedTest + @MethodSource("cases") + void testFragmentBothMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); Payload responsePayload = DefaultPayload.create(responseMessage); System.out.println( "-------------------------------------------------testFragmentBothMetaData-------------------------------------------------"); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { String request = payload.getDataUtf8(); @@ -164,7 +171,7 @@ public Mono requestResponse(Payload payload) { } }; - RSocket client = buildClient(); + RSocket client = buildClient(clientFrameSize); System.out.println("original message: " + message); System.out.println("original metadata: " + metaData); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java index 07e9378fa..b9c0d4f60 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -1,18 +1,15 @@ package io.rsocket.transport.netty; -import static io.rsocket.RSocketFactory.*; - import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; -import io.rsocket.transport.ClientTransport; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; import io.rsocket.transport.netty.server.WebsocketServerTransport; import java.time.Duration; -import java.util.function.Function; import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; @@ -22,104 +19,62 @@ class RSocketFactoryNettyTransportFragmentationTest { - @ParameterizedTest - @MethodSource("serverTransportProvider") - void serverErrorsWithEnabledFragmentationOnInsufficientMtu( - ServerTransport serverTransport) { - Mono server = createServer(serverTransport, f -> f.fragment(2)); - - StepVerifier.create(server) - .expectErrorMatches( - err -> - err instanceof IllegalArgumentException - && "smallest allowed mtu size is 64 bytes, provided: 2" - .equals(err.getMessage())) - .verify(Duration.ofSeconds(5)); + static Stream> arguments() { + return Stream.of(TcpServerTransport.create(0), WebsocketServerTransport.create(0)); } @ParameterizedTest - @MethodSource("serverTransportProvider") + @MethodSource("arguments") void serverSucceedsWithEnabledFragmentationOnSufficientMtu( ServerTransport serverTransport) { Mono server = - createServer(serverTransport, f -> f.fragment(100)).doOnNext(CloseableChannel::dispose); + RSocketServer.create(mockAcceptor()) + .fragment(100) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void serverSucceedsWithDisabledFragmentation() { + @MethodSource("arguments") + void serverSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { Mono server = - createServer(TcpServerTransport.create("localhost", 0), Function.identity()) + RSocketServer.create(mockAcceptor()) + .bind(serverTransport) .doOnNext(CloseableChannel::dispose); StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void clientErrorsWithEnabledFragmentationOnInsufficientMtu( - ServerTransport serverTransport) { - CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); - - Mono rSocket = - createClient(TcpClientTransport.create(server.address()), f -> f.fragment(2)) - .doFinally(s -> server.dispose()); - - StepVerifier.create(rSocket) - .expectErrorMatches( - err -> - err instanceof IllegalArgumentException - && "smallest allowed mtu size is 64 bytes, provided: 2" - .equals(err.getMessage())) - .verify(Duration.ofSeconds(5)); - } - - @ParameterizedTest - @MethodSource("serverTransportProvider") + @MethodSource("arguments") void clientSucceedsWithEnabledFragmentationOnSufficientMtu( ServerTransport serverTransport) { - CloseableChannel server = createServer(serverTransport, f -> f.fragment(100)).block(); + CloseableChannel server = + RSocketServer.create(mockAcceptor()).fragment(100).bind(serverTransport).block(); Mono rSocket = - createClient(TcpClientTransport.create(server.address()), f -> f.fragment(100)) + RSocketConnector.create() + .fragment(100) + .connect(TcpClientTransport.create(server.address())) .doFinally(s -> server.dispose()); StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } @ParameterizedTest - @MethodSource("serverTransportProvider") - void clientSucceedsWithDisabledFragmentation() { - CloseableChannel server = - createServer(TcpServerTransport.create("localhost", 0), Function.identity()).block(); + @MethodSource("arguments") + void clientSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + CloseableChannel server = RSocketServer.create(mockAcceptor()).bind(serverTransport).block(); Mono rSocket = - createClient(TcpClientTransport.create(server.address()), Function.identity()) + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) .doFinally(s -> server.dispose()); StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); } - private Mono createClient( - ClientTransport transport, Function f) { - return f.apply(RSocketFactory.connect()).transport(transport).start(); - } - - private Mono createServer( - ServerTransport transport, - Function f) { - return f.apply(receive()).acceptor(mockAcceptor()).transport(transport).start(); - } - private SocketAcceptor mockAcceptor() { SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); Mockito.when(mock.accept(Mockito.any(), Mockito.any())) .thenReturn(Mono.just(Mockito.mock(RSocket.class))); return mock; } - - static Stream> serverTransportProvider() { - String host = "localhost"; - int port = 0; - return Stream.of( - TcpServerTransport.create(host, port), WebsocketServerTransport.create(host, port)); - } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java index f32d28a0b..6fd3de791 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -2,8 +2,9 @@ import io.rsocket.ConnectionSetupPayload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.exceptions.RejectedSetupException; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; @@ -41,19 +42,15 @@ void rejectSetupTcp( Mono serverRequester = acceptor.requesterRSocket(); CloseableChannel channel = - RSocketFactory.receive() - .acceptor(acceptor) - .transport(serverTransport.apply(new InetSocketAddress("localhost", 0))) - .start() + RSocketServer.create(acceptor) + .bind(serverTransport.apply(new InetSocketAddress("localhost", 0))) .block(Duration.ofSeconds(5)); ErrorConsumer errorConsumer = new ErrorConsumer(); RSocket clientRequester = - RSocketFactory.connect() - .errorConsumer(errorConsumer) - .transport(clientTransport.apply(channel.address())) - .start() + RSocketConnector.connectWith(clientTransport.apply(channel.address())) + .doOnError(errorConsumer) .block(Duration.ofSeconds(5)); StepVerifier.create(errorConsumer.errors().next()) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java index c2e136635..88c64648c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java @@ -17,7 +17,8 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PerfTest; import io.rsocket.test.PingClient; @@ -81,16 +82,15 @@ private static PingClient newResumablePingClient() { } private static PingClient newPingClient(boolean isResumable) { - RSocketFactory.ClientRSocketFactory clientRSocketFactory = RSocketFactory.connect(); + RSocketConnector connector = RSocketConnector.create(); if (isResumable) { - clientRSocketFactory.resume(); + connector.resume(new Resume()); } Mono rSocket = - clientRSocketFactory - .frameDecoder(PayloadDecoder.ZERO_COPY) - .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30), 3) - .transport(TcpClientTransport.create(port)) - .start(); + connector + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30)) + .connect(TcpClientTransport.create(port)); return new PingClient(rSocket); } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java index b40f35e51..338868470 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -16,7 +16,8 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.TcpServerTransport; @@ -31,15 +32,13 @@ public static void main(String... args) { System.out.println("port: " + port); System.out.println("resume enabled: " + isResume); - RSocketFactory.ServerRSocketFactory serverRSocketFactory = RSocketFactory.receive(); + RSocketServer server = RSocketServer.create(new PingHandler()); if (isResume) { - serverRSocketFactory.resume(); + server.resume(new Resume()); } - serverRSocketFactory - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create("localhost", port)) - .start() + server + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create("localhost", port)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java new file mode 100644 index 000000000..95bebd6aa --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -0,0 +1,55 @@ +package io.rsocket.transport.netty; + +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import reactor.core.Exceptions; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +public class TcpSecureTransportTest implements TransportTest { + private final TransportPair transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + address -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(10); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java deleted file mode 100644 index 25b443dd6..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class TcpUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "http://test"; - } - - @Override - public UriHandler getUriHandler() { - return new TcpUriHandler(); - } - - @Override - public String getValidUri() { - return "tcp://test:9898"; - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java deleted file mode 100644 index a71cc27f9..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class TcpUriTransportRegistryTest { - - @DisplayName("non-tcp URI does not return TcpClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("tcp URI returns TcpClientTransport") - @Test - void clientForUriTcp() { - assertThat(UriTransportRegistry.clientForUri("tcp://test:9898")) - .isInstanceOf(TcpClientTransport.class); - } - - @DisplayName("non-tcp URI does not return TcpServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("tcp URI returns TcpServerTransport") - @Test - void serverForUriTcp() { - assertThat(UriTransportRegistry.serverForUri("tcp://test:9898")) - .isInstanceOf(TcpServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java index 4fe40d232..c418dea0f 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java @@ -1,9 +1,9 @@ package io.rsocket.transport.netty; -import io.rsocket.AbstractRSocket; -import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketRouteTransport; @@ -13,7 +13,6 @@ import java.time.Duration; import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; import reactor.netty.DisposableServer; import reactor.netty.http.server.HttpServer; import reactor.test.StepVerifier; @@ -23,19 +22,11 @@ public class WebSocketTransportIntegrationTest { @Test public void sendStreamOfDataWithExternalHttpServerTest() { ServerTransport.ConnectionAcceptor acceptor = - RSocketFactory.receive() - .acceptor( - (setupPayload, sendingRSocket) -> { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.range(0, 10) - .map(i -> DefaultPayload.create(String.valueOf(i))); - } - }); - }) - .toConnectionAcceptor(); + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(0, 10).map(i -> DefaultPayload.create(String.valueOf(i))))) + .asConnectionAcceptor(); DisposableServer server = HttpServer.create() @@ -44,11 +35,9 @@ public Flux requestStream(Payload payload) { .bindNow(); RSocket rsocket = - RSocketFactory.connect() - .transport( + RSocketConnector.connectWith( WebsocketClientTransport.create( URI.create("ws://" + server.host() + ":" + server.port() + "/test"))) - .start() .block(); StepVerifier.create(rsocket.requestStream(EmptyPayload.INSTANCE)) diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java index 306be4e43..a784a43c0 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.WebsocketClientTransport; @@ -29,10 +29,9 @@ public final class WebsocketPing { public static void main(String... args) { Mono client = - RSocketFactory.connect() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .transport(WebsocketClientTransport.create(7878)) - .start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(7878)); PingClient pingClient = new PingClient(client); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java index eac091dd8..e2ee9e521 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -8,7 +8,11 @@ import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocketFrame; import io.netty.util.ReferenceCountUtil; -import io.rsocket.*; +import io.rsocket.Closeable; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.server.WebsocketRouteTransport; @@ -42,10 +46,8 @@ void tearDown() { @MethodSource("provideServerTransport") void webSocketPingPong(ServerTransport serverTransport) { server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new EchoRSocket())) - .transport(serverTransport) - .start() + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .bind(serverTransport) .block(); String expectedData = "data"; @@ -63,10 +65,7 @@ void webSocketPingPong(ServerTransport serverTransport) { .port(port)); RSocket rSocket = - RSocketFactory.connect() - .transport(WebsocketClientTransport.create(httpClient, "/")) - .start() - .block(); + RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")).block(); rSocket .requestResponse(DefaultPayload.create(expectedData)) @@ -100,13 +99,6 @@ private static Stream provideServerTransport() { HttpServer.create().host(host).port(port), routes -> {}, "/"))); } - private static class EchoRSocket extends AbstractRSocket { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - } - private static class PingSender extends ChannelInboundHandlerAdapter { private final MonoProcessor channel = MonoProcessor.create(); private final MonoProcessor pong = MonoProcessor.create(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java index 7fdb1813a..84dc816be 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java @@ -1,5 +1,5 @@ /* - * Copyright 2015-2018 the original author or authors. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.WebsocketServerTransport; @@ -24,11 +24,9 @@ public final class WebsocketPongServer { public static void main(String... args) { - RSocketFactory.receive() - .frameDecoder(PayloadDecoder.ZERO_COPY) - .acceptor(new PingHandler()) - .transport(WebsocketServerTransport.create(7878)) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(WebsocketServerTransport.create(7878)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java index c1d608979..ec33060b2 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -38,7 +38,7 @@ final class WebsocketSecureTransportTest implements TransportTest { (address, server) -> WebsocketClientTransport.create( HttpClient.create() - .addressSupplier(server::address) + .remoteAddress(server::address) .secure( ssl -> ssl.sslContext( @@ -53,7 +53,7 @@ final class WebsocketSecureTransportTest implements TransportTest { HttpServer server = HttpServer.from( TcpServer.create() - .addressSupplier(() -> address) + .bindAddress(() -> address) .secure( ssl -> ssl.sslContext( diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java deleted file mode 100644 index 72a700b0e..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriHandlerTest.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.UriHandlerTest; -import io.rsocket.uri.UriHandler; - -final class WebsocketUriHandlerTest implements UriHandlerTest { - - @Override - public String getInvalidUri() { - return "amqp://test"; - } - - @Override - public UriHandler getUriHandler() { - return new WebsocketUriHandler(); - } - - @Override - public String getValidUri() { - return "ws://test:9898"; - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java deleted file mode 100644 index 5688f14ed..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketUriTransportRegistryTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright 2015-2018 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.assertj.core.api.Assertions.assertThat; - -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; - -final class WebsocketUriTransportRegistryTest { - - @DisplayName("non-ws URI does not return WebsocketClientTransport") - @Test - void clientForUriInvalid() { - assertThat(UriTransportRegistry.clientForUri("amqp://localhost")) - .isNotInstanceOf(TcpClientTransport.class) - .isNotInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("ws URI returns WebsocketClientTransport") - @Test - void clientForUriWebsocket() { - assertThat(UriTransportRegistry.clientForUri("ws://test:9898")) - .isInstanceOf(WebsocketClientTransport.class); - } - - @DisplayName("non-ws URI does not return WebsocketServerTransport") - @Test - void serverForUriInvalid() { - assertThat(UriTransportRegistry.serverForUri("amqp://localhost")) - .isNotInstanceOf(TcpServerTransport.class) - .isNotInstanceOf(WebsocketServerTransport.class); - } - - @DisplayName("ws URI returns WebsocketServerTransport") - @Test - void serverForUriWebsocket() { - assertThat(UriTransportRegistry.serverForUri("ws://test:9898")) - .isInstanceOf(WebsocketServerTransport.class); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java index 905f022f2..fc035c536 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -16,7 +16,7 @@ package io.rsocket.transport.netty.client; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java index 5a2986485..249a3e12a 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -16,7 +16,7 @@ package io.rsocket.transport.netty.server; -import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNullPointerException; diff --git a/settings.gradle b/settings.gradle index 625633774..25c3feee5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -13,14 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +plugins { + id 'com.gradle.enterprise' version '3.1' +} rootProject.name = 'rsocket-java' include 'rsocket-core' -include 'rsocket-examples' include 'rsocket-load-balancer' include 'rsocket-micrometer' include 'rsocket-test' include 'rsocket-transport-local' include 'rsocket-transport-netty' include 'rsocket-bom' + +include 'rsocket-examples' +include 'benchmarks' + + + +gradleEnterprise { + buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' + } +} +