diff --git a/ChatExample/app/build.gradle b/ChatExample/app/build.gradle index 83d08b6..6ddf2db 100644 --- a/ChatExample/app/build.gradle +++ b/ChatExample/app/build.gradle @@ -1,15 +1,12 @@ apply plugin: 'com.android.application' - apply plugin: 'kotlin-android' -apply plugin: 'kotlin-android-extensions' - android { - compileSdkVersion 28 + compileSdkVersion 33 defaultConfig { applicationId "com.github.dsrees.chatexample" minSdkVersion 19 - targetSdkVersion 28 + targetSdkVersion 33 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" @@ -21,6 +18,10 @@ android { } } + buildFeatures { + viewBinding true + } + compileOptions { targetCompatibility = "8" sourceCompatibility = "8" @@ -29,25 +30,21 @@ android { dependencies { /* - To update the JavaPhoenixClient, either use the latest dependency from jcenter + To update the JavaPhoenixClient, either use the latest dependency from mavenCentral() OR run `./gradlew jar` and copy `/build/lib/*.jar` to `/ChatExample/app/libs` - and comment out the jcenter dependency + and comment out the mavenCentral() dependency */ implementation fileTree(dir: 'libs', include: ['*.jar']) -// implementation 'com.github.dsrees:JavaPhoenixClient:0.2.3' - - - implementation "com.google.code.gson:gson:2.8.5" - implementation "com.squareup.okhttp3:okhttp:3.12.2" - +// implementation 'com.github.dsrees:JavaPhoenixClient:0.3.4' - implementation"org.jetbrains.kotlin:kotlin-stdlib-jdk7:$kotlin_version" - implementation 'androidx.appcompat:appcompat:1.0.2' - implementation 'androidx.recyclerview:recyclerview:1.0.0' - implementation 'androidx.constraintlayout:constraintlayout:1.1.3' + implementation "com.google.code.gson:gson:2.10.1" + implementation "com.squareup.okhttp3:okhttp:4.11.0" + implementation 'androidx.appcompat:appcompat:1.6.1' + implementation 'androidx.recyclerview:recyclerview:1.3.1' + implementation 'androidx.constraintlayout:constraintlayout:2.1.4' } diff --git a/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar b/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar deleted file mode 100644 index 4aa7d32..0000000 Binary files a/ChatExample/app/libs/JavaPhoenixClient-0.3.0.jar and /dev/null differ diff --git a/ChatExample/app/libs/JavaPhoenixClient-1.1.5.jar b/ChatExample/app/libs/JavaPhoenixClient-1.1.5.jar new file mode 100644 index 0000000..dae2106 Binary files /dev/null and b/ChatExample/app/libs/JavaPhoenixClient-1.1.5.jar differ diff --git a/ChatExample/app/src/main/AndroidManifest.xml b/ChatExample/app/src/main/AndroidManifest.xml index 5dbb22a..5ed4835 100644 --- a/ChatExample/app/src/main/AndroidManifest.xml +++ b/ChatExample/app/src/main/AndroidManifest.xml @@ -1,24 +1,28 @@ + xmlns:tools="http://schemas.android.com/tools" + package="com.github.dsrees.chatexample"> - + - - - - + + + + - - - - + + + + \ No newline at end of file diff --git a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt index 3acdcb5..9a4caa5 100644 --- a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt +++ b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MainActivity.kt @@ -3,11 +3,8 @@ package com.github.dsrees.chatexample import androidx.appcompat.app.AppCompatActivity import android.os.Bundle import android.util.Log -import android.widget.ArrayAdapter -import android.widget.Button -import android.widget.EditText import androidx.recyclerview.widget.LinearLayoutManager -import kotlinx.android.synthetic.main.activity_main.* +import com.github.dsrees.chatexample.databinding.ActivityMainBinding import org.phoenixframework.Channel import org.phoenixframework.Socket @@ -17,44 +14,47 @@ class MainActivity : AppCompatActivity() { const val TAG = "MainActivity" } + private lateinit var binding: ActivityMainBinding + private val messagesAdapter = MessagesAdapter() private val layoutManager = LinearLayoutManager(this) // Use when connecting to https://github.com/dwyl/phoenix-chat-example - private val socket = Socket("https://phxchat.herokuapp.com/socket/websocket") + private val socket = Socket("https://phoenix-chat.fly.dev/socket/websocket") private val topic = "room:lobby" // Use when connecting to local server -// private val socket = Socket("ws://10.0.2.2:4000/socket/websocket") -// private val topic = "rooms:lobby" + // private val socket = Socket("ws://10.0.2.2:4000/socket/websocket") + // private val topic = "rooms:lobby" private var lobbyChannel: Channel? = null private val username: String - get() = username_input.text.toString() + get() = binding.usernameInput.text.toString() private val message: String - get() = message_input.text.toString() + get() = binding.messageInput.text.toString() override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) - setContentView(R.layout.activity_main) + this.binding = ActivityMainBinding.inflate(layoutInflater) + setContentView(binding.root) layoutManager.stackFromEnd = true - messages_recycler_view.layoutManager = layoutManager - messages_recycler_view.adapter = messagesAdapter + binding.messagesRecyclerView.layoutManager = layoutManager + binding.messagesRecyclerView.adapter = messagesAdapter socket.onOpen { this.addText("Socket Opened") - runOnUiThread { connect_button.text = "Disconnect" } + runOnUiThread { binding.connectButton.text = "Disconnect" } } socket.onClose { this.addText("Socket Closed") - runOnUiThread { connect_button.text = "Connect" } + runOnUiThread { binding.connectButton.text = "Connect" } } socket.onError { throwable, response -> @@ -67,7 +67,7 @@ class MainActivity : AppCompatActivity() { } - connect_button.setOnClickListener { + binding.connectButton.setOnClickListener { if (socket.isConnected) { this.disconnectAndLeave() } else { @@ -76,7 +76,7 @@ class MainActivity : AppCompatActivity() { } } - send_button.setOnClickListener { sendMessage() } + binding.sendButton.setOnClickListener { sendMessage() } } private fun sendMessage() { @@ -85,7 +85,7 @@ class MainActivity : AppCompatActivity() { ?.receive("ok") { Log.d(TAG, "success $it") } ?.receive("error") { Log.d(TAG, "error $it") } - message_input.text.clear() + binding.messageInput.text.clear() } private fun disconnectAndLeave() { @@ -132,9 +132,7 @@ class MainActivity : AppCompatActivity() { private fun addText(message: String) { runOnUiThread { this.messagesAdapter.add(message) - layoutManager.smoothScrollToPosition(messages_recycler_view, null, messagesAdapter.itemCount) + layoutManager.smoothScrollToPosition(binding.messagesRecyclerView, null, messagesAdapter.itemCount) } - } - } diff --git a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt index e99b294..143df58 100644 --- a/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt +++ b/ChatExample/app/src/main/java/com/github/dsrees/chatexample/MessagesAdapter.kt @@ -5,6 +5,7 @@ import android.view.View import android.view.ViewGroup import android.widget.TextView import androidx.recyclerview.widget.RecyclerView +import com.github.dsrees.chatexample.databinding.ItemMessageBinding class MessagesAdapter : RecyclerView.Adapter() { @@ -17,8 +18,8 @@ class MessagesAdapter : RecyclerView.Adapter() { override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): ViewHolder { - val view = LayoutInflater.from(parent.context).inflate(R.layout.item_message, parent, false) - return ViewHolder(view) + val binding = ItemMessageBinding.inflate(LayoutInflater.from(parent.context), parent, false) + return ViewHolder(binding) } override fun getItemCount(): Int = messages.size @@ -27,8 +28,8 @@ class MessagesAdapter : RecyclerView.Adapter() { holder.label.text = messages[position] } - inner class ViewHolder(itemView: View) : RecyclerView.ViewHolder(itemView) { - val label: TextView = itemView.findViewById(R.id.item_message_label) + inner class ViewHolder(binding: ItemMessageBinding) : RecyclerView.ViewHolder(binding.root) { + val label: TextView = binding.itemMessageLabel } } \ No newline at end of file diff --git a/ChatExample/build.gradle b/ChatExample/build.gradle index b39ac10..08b070b 100644 --- a/ChatExample/build.gradle +++ b/ChatExample/build.gradle @@ -1,24 +1,22 @@ // Top-level build file where you can add configuration options common to all sub-projects/modules. buildscript { - ext.kotlin_version = '1.3.31' + ext.kotlin_version = '1.8.0' repositories { google() - jcenter() + mavenCentral() } dependencies { - classpath 'com.android.tools.build:gradle:3.4.1' + classpath 'com.android.tools.build:gradle:7.4.2' classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" - // NOTE: Do not place your application dependencies here; they belong - // in the individual module build.gradle files } } allprojects { repositories { google() - jcenter() + mavenCentral() } } diff --git a/ChatExample/gradle/wrapper/gradle-wrapper.properties b/ChatExample/gradle/wrapper/gradle-wrapper.properties index 3c4d55d..5fa6080 100644 --- a/ChatExample/gradle/wrapper/gradle-wrapper.properties +++ b/ChatExample/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Tue May 14 11:28:07 EDT 2019 +#Wed Oct 04 12:35:15 EDT 2023 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.1.1-all.zip diff --git a/README.md b/README.md index 9960308..ecdb961 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # JavaPhoenixClient -[![Download](https://api.bintray.com/packages/drees/java-phoenix-client/JavaPhoenixClient/images/download.svg) ](https://bintray.com/drees/java-phoenix-client/JavaPhoenixClient/_latestVersion) +[![Maven Central](https://img.shields.io/maven-central/v/com.github.dsrees/JavaPhoenixClient.svg?label=Maven%20Central)](https://search.maven.org/search?q=g:%22com.github.dsrees%22%20AND%20a:%22JavaPhoenixClient%22) [![Build Status](https://travis-ci.com/dsrees/JavaPhoenixClient.svg?branch=master)](https://travis-ci.com/dsrees/JavaPhoenixClient) [![codecov](https://codecov.io/gh/dsrees/JavaPhoenixClient/branch/master/graph/badge.svg)](https://codecov.io/gh/dsrees/JavaPhoenixClient) @@ -39,40 +39,86 @@ fun connectToChatRoom() { } ``` + +If you need to provide dynamic parameters that can change between calls to `connect()`, then you can pass a closure to the constructor + +```kotlin + +// Create the Socket +var authToken = "abc" +val socket = Socket("http://localhost:4000/socket/websocket", { mapOf("token" to authToken) }) + +// Connect with query parameters "?token=abc" +socket.connect() + + +// later in time, connect with query parameters "?token=xyz" +authToken = "xyz" +socket.connect() // or internal reconnect logic kicks in +``` + + You can also inject your own OkHttp Client into the Socket to provide your own configuration ```kotlin -// Create the Socket with a pre-configured OkHttp Client +// Configure your own OkHttp Client val client = OkHttpClient.Builder() .connectTimeout(1000, TimeUnit.MILLISECONDS) .build() +// Create Socket with your custom instances val params = hashMapOf("token" to "abc123") val socket = Socket("http://localhost:4000/socket/websocket", - params, - client) + params, + client) ``` +By default, the client use GSON to encode and decode JSON. If you prefer to manage this yourself, you +can provide custom encode/decode functions in the `Socket` constructor. + +```kotlin + +// Configure your own GSON instance +val gson = Gson.Builder().create() +val encoder: EncodeClosure = { + // Encode a Map into JSON using your custom GSON instance or another JSON library + // of your choice (Moshi, etc) +} +val decoder: DecodeClosure = { + // Decode a JSON String into a `Message` object using your custom JSON library +} + +// Create Socket with your custom instances +val params = hashMapOf("token" to "abc123") +val socket = Socket("http://localhost:4000/socket/websocket", + params, + encoder, + decoder) +``` + + + + ### Installation -JavaPhoenixClient is hosted on JCenter. You'll need to make sure you declare `jcenter()` as one of your repositories +JavaPhoenixClient is hosted on MavenCentral. You'll need to make sure you declare `mavenCentral()` as one of your repositories ``` repositories { - jcenter() + mavenCentral() } ``` and then add the library. See [releases](https://github.com/dsrees/JavaPhoenixClient/releases) for the latest version ```$xslt dependencies { - implementation 'com.github.dsrees:JavaPhoenixClient:0.3.4' + implementation 'com.github.dsrees:JavaPhoenixClient:1.3.1' } ``` ### Feedback -Please submit in issue if you have any problems! +Please submit in issue if you have any problems or questions! PRs are also welcome. This library is built to mirror the [phoenix.js](https://hexdocs.pm/phoenix/js/) and [SwiftPhoenixClient](https://github.com/davidstump/SwiftPhoenixClient) libraries. diff --git a/RELEASING.md b/RELEASING.md index 75a44b4..891816d 100644 --- a/RELEASING.md +++ b/RELEASING.md @@ -7,3 +7,6 @@ Release Process 4. Tag: `git tag -a X.Y.Z -m "Version X.Y.Z"` 5. Push: `git push && git push --tags` 6. Add the new release with notes (https://github.com/dsrees/JavaPhoenixClient/releases). + 7. Publish to Maven Central by running `./gradlew clean publish`. Can only be done by dsrees until CI setup + 8. Close the staging repo here: https://s01.oss.sonatype.org/#stagingRepositories + 9. Release the closed repo diff --git a/build.gradle b/build.gradle index c0e2779..78e6a8a 100644 --- a/build.gradle +++ b/build.gradle @@ -1,19 +1,36 @@ -buildscript { repositories { jcenter() } } +buildscript { + repositories { + jcenter() + mavenCentral() + } + + dependencies { + classpath 'com.vanniktech:gradle-maven-publish-plugin:0.14.2' + classpath 'org.jetbrains.dokka:dokka-gradle-plugin:1.4.30' + } +} + plugins { id 'java' id 'jacoco' - id 'org.jetbrains.kotlin.jvm' version '1.3.31' - id 'nebula.project' version '6.0.3' - id "nebula.maven-publish" version '9.5.4' - id 'nebula.nebula-bintray' version '5.0.0' + id 'org.jetbrains.kotlin.jvm' version '1.8.0' +} + +ext { + RELEASE_REPOSITORY_URL = "https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/" + SNAPSHOT_REPOSITORY_URL = "https://s01.oss.sonatype.org/content/repositories/snapshots/" } +apply plugin: "org.jetbrains.dokka" +apply plugin: "com.vanniktech.maven.publish" + group 'com.github.dsrees' -version '0.3.4' +version '1.3.1' sourceCompatibility = 1.8 repositories { + jcenter() mavenCentral() } @@ -30,16 +47,16 @@ test { } dependencies { - compile "org.jetbrains.kotlin:kotlin-stdlib-jdk8" - compile "com.google.code.gson:gson:2.8.5" - compile "com.squareup.okhttp3:okhttp:3.12.2" + implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8" + implementation "com.google.code.gson:gson:2.10.1" + implementation "com.squareup.okhttp3:okhttp:4.11.0" testImplementation 'org.junit.jupiter:junit-jupiter-api:5.3.1' testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.3.1' - testCompile group: 'com.google.truth', name: 'truth', version: '0.44' - testCompile group: 'org.mockito', name: 'mockito-core', version: '2.27.0' - testCompile group: 'com.nhaarman.mockitokotlin2', name: 'mockito-kotlin', version: '2.1.0' + testImplementation group: 'com.google.truth', name: 'truth', version: '1.1.3' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.0.0' + testImplementation group: 'org.mockito.kotlin', name: 'mockito-kotlin', version: '4.0.0' } jacocoTestReport { @@ -49,8 +66,6 @@ jacocoTestReport { } } - - compileKotlin { kotlinOptions.jvmTarget = "1.8" } @@ -58,23 +73,3 @@ compileTestKotlin { kotlinOptions.jvmTarget = "1.8" } -bintray { - user = System.getenv('bintrayUser') - key = System.getenv('bintrayApiKey') - dryRun = false - publish = true - pkg { - repo = 'java-phoenix-client' - name = 'JavaPhoenixClient' - userOrg = user - websiteUrl = 'https://github.com/dsrees/JavaPhoenixClient' - issueTrackerUrl = 'https://github.com/dsrees/JavaPhoenixClient/issues' - vcsUrl = 'https://github.com/dsrees/JavaPhoenixClient.git' - licenses = ['MIT'] - version { - name = project.version - vcsTag = project.version - } - } - publications = ['nebula'] -} diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 0000000..65bd7af --- /dev/null +++ b/gradle.properties @@ -0,0 +1,20 @@ +GROUP=com.github.dsrees +POM_ARTIFACT_ID=JavaPhoenixClient +VERSION_NAME=0.3.4 + +POM_NAME=JavaPhoenixClient +POM_DESCRIPTION=A phoenix channels client built for the JVM +POM_INCEPTION_YEAR=2018 + +POM_URL=https://github.com/dsrees/JavaPhoenixClient +POM_SCM_URL=https://github.com/dsrees/JavaPhoenixClient.git +POM_SCM_CONNECTION=scm:git:git://github.com/dsrees/JavaPhoenixClient.git +POM_SCM_DEV_CONNECTION=scm:git:ssh://git@github.com/dsrees/JavaPhoenixClient.git + +POM_LICENCE_NAME=MIT License +POM_LICENCE_URL=https://github.com/dsrees/JavaPhoenixClient/blob/master/LICENSE.md +POM_LICENCE_DIST=repo + +POM_DEVELOPER_ID=dsrees +POM_DEVELOPER_NAME=Daniel Rees +POM_DEVELOPER_URL=https://github.com/dsrees/ \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 05e0e01..aa98239 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ -#Thu Jun 13 12:25:29 EDT 2019 +#Wed Oct 04 12:21:20 EDT 2023 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.2-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-5.4.1-all.zip diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt index 4c97999..d6103ab 100644 --- a/src/main/kotlin/org/phoenixframework/Channel.kt +++ b/src/main/kotlin/org/phoenixframework/Channel.kt @@ -38,7 +38,7 @@ data class Binding( */ class Channel( val topic: String, - params: Payload, + paramsClosure: PayloadClosure, internal val socket: Socket ) { @@ -94,10 +94,10 @@ class Channel( internal var timeout: Long /** Params passed in through constructions and provided to the JoinPush */ - var params: Payload = params + var params: Payload + get() = joinPush.payload set(value) { joinPush.payload = value - field = value } /** Set to true once the channel has attempted to join */ @@ -112,12 +112,21 @@ class Channel( /** Timer to attempt rejoins */ internal var rejoinTimer: TimeoutTimer + /** Refs if stateChange hooks */ + internal var stateChangeRefs: MutableList + /** * Optional onMessage hook that can be provided. Receives all event messages for specialized * handling before dispatching to the Channel event callbacks. */ internal var onMessage: (Message) -> Message = { it } + constructor( + topic: String, + params: Payload, + socket: Socket + ) : this(topic, { params }, socket) + init { this.state = State.CLOSED this.bindings = ConcurrentLinkedQueue() @@ -125,6 +134,7 @@ class Channel( this.timeout = socket.timeout this.joinedOnce = false this.pushBuffer = mutableListOf() + this.stateChangeRefs = mutableListOf() this.rejoinTimer = TimeoutTimer( dispatchQueue = socket.dispatchQueue, timerCalculation = socket.rejoinAfterMs, @@ -133,17 +143,18 @@ class Channel( // Respond to socket events this.socket.onError { _, _-> this.rejoinTimer.reset() } + .apply { stateChangeRefs.add(this) } this.socket.onOpen { this.rejoinTimer.reset() if (this.isErrored) { this.rejoin() } - } + }.apply { stateChangeRefs.add(this) } // Setup Push to be sent when joining this.joinPush = Push( channel = this, event = Event.JOIN.value, - payload = params, + payloadClosure = paramsClosure, timeout = timeout) // Perform once the Channel has joined @@ -203,7 +214,14 @@ class Channel( this.socket.logItems("Channel: error $topic ${it.payload}") // If error was received while joining, then reset the Push - if (isJoining) { this.joinPush.reset() } + if (isJoining) { + // Make sure that the "phx_join" isn't buffered to send once the socket + // reconnects. The channel will send a new join event when the socket connects. + this.joinRef?.let { this.socket.removeFromSendBuffer(it) } + + // Reset the push to be used again later + this.joinPush.reset() + } // Mark the channel as errored and attempt to rejoin if socket is currently connected this.state = State.ERRORED @@ -212,7 +230,7 @@ class Channel( // Perform when the join reply is received this.on(Event.REPLY) { message -> - this.trigger(replyEventName(message.ref), message.payload, message.ref, message.joinRef) + this.trigger(replyEventName(message.ref), message.rawPayload, message.ref, message.joinRef, message.payloadJson) } } @@ -371,18 +389,20 @@ class Channel( event: Event, payload: Payload = hashMapOf(), ref: String = "", - joinRef: String? = null + joinRef: String? = null, + payloadJson: String = "" ) { - this.trigger(event.value, payload, ref, joinRef) + this.trigger(event.value, payload, ref, joinRef, payloadJson) } internal fun trigger( event: String, payload: Payload = hashMapOf(), ref: String = "", - joinRef: String? = null + joinRef: String? = null, + payloadJson: String = "" ) { - this.trigger(Message(ref, topic, event, payload, joinRef)) + this.trigger(Message(joinRef, ref, topic, event, payload, payloadJson)) } internal fun trigger(message: Message) { @@ -414,6 +434,9 @@ class Channel( // Do not attempt to rejoin if the channel is in the process of leaving if (isLeaving) return + // Leave potentially duplicated channels + this.socket.leaveOpenTopic(this.topic) + // Send the joinPush this.sendJoin(timeout) } diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt index b6c12be..06ecc9f 100644 --- a/src/main/kotlin/org/phoenixframework/Defaults.kt +++ b/src/main/kotlin/org/phoenixframework/Defaults.kt @@ -25,6 +25,10 @@ package org.phoenixframework import com.google.gson.FieldNamingPolicy import com.google.gson.Gson import com.google.gson.GsonBuilder +import com.google.gson.JsonObject +import com.google.gson.reflect.TypeToken +import okhttp3.HttpUrl.Companion.toHttpUrlOrNull +import java.net.URL object Defaults { @@ -34,9 +38,14 @@ object Defaults { /** Default heartbeat interval of 30s */ const val HEARTBEAT: Long = 30_000 + /** Default JSON Serializer Version set to 2.0.0 */ + const val VSN: String = "2.0.0" + /** Default reconnect algorithm for the socket */ val reconnectSteppedBackOff: (Int) -> Long = { tries -> - if (tries > 9) 5_000 else listOf(10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L)[tries - 1] + if (tries > 9) 5_000 else listOf( + 10L, 50L, 100L, 150L, 200L, 250L, 500L, 1_000L, 2_000L + )[tries - 1] } /** Default rejoin algorithm for individual channels */ @@ -44,11 +53,110 @@ object Defaults { if (tries > 3) 10_000 else listOf(1_000L, 2_000L, 5_000L)[tries - 1] } - /** The default Gson configuration to use when parsing messages */ val gson: Gson get() = GsonBuilder() - .setLenient() - .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) - .create() + .setLenient() + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .create() + + /** + * Default JSON decoder, backed by GSON, that takes JSON and converts it + * into a Message object. + */ + @Suppress("UNCHECKED_CAST") + val decode: DecodeClosure = { rawMessage -> + + val parseValue: (String) -> String? = { value -> + when(value) { + "null" -> null + else -> value.replace("\"", "") + } + } + + var message = rawMessage + message = message.removeRange(0, 1) // remove '[' + + val joinRef = message.takeWhile { it != ',' } // take "join ref", "null" or "\"5\"" + message = message.removeRange(0, joinRef.length) // remove join ref + message = message.removeRange(0, 1) // remove ',' + + val ref = message.takeWhile { it != ',' } // take ref, "null" or "\"5\"" + message = message.removeRange(0, ref.length) // remove ref + message = message.removeRange(0, 1) // remove ',' + + val topic = message.takeWhile { it != ',' } // take topic, "\"topic\"" + message = message.removeRange(0, topic.length) + message = message.removeRange(0, 1) // remove ',' + + val event = message.takeWhile { it != ',' } // take event, "\"phx_reply\"" + message = message.removeRange(0, event.length) + message = message.removeRange(0, 1) // remove ',' + + var remaining = message.removeRange(message.length - 1, message.length) // remove ']' + + // Payload should now just be "{"message":"hello","from":"user_1"}" or + // "{"response": {"message":"hello","from":"user_1"}},"status":"ok"}", flatten. + val jsonObj = gson.fromJson(remaining, JsonObject::class.java) + val response = jsonObj.get("response") + val payload = response?.let { gson.toJson(response) } ?: remaining + + val anyType = object : TypeToken>() {}.type + val result = gson.fromJson>(remaining, anyType) + + // vsn=2.0.0 message structure + // [join_ref, ref, topic, event, payload] + Message( + joinRef = parseValue(joinRef), + ref = parseValue(ref) ?: "", + topic = parseValue(topic) ?: "", + event = parseValue(event) ?: "", + rawPayload = result, + payloadJson = payload + ) + } + + /** + * Default JSON encoder, backed by GSON, that takes a Map and + * converts it into a JSON String. + */ + val encode: EncodeClosure = { payload -> + gson.toJson(payload) + } + + /** + * Takes an endpoint and a params closure given by the User and constructs a URL that + * is ready to be sent to the Socket connection. + * + * Will convert "ws://" and "wss://" to http/s which is what OkHttp expects. + * + * @throws IllegalArgumentException if [endpoint] is not a valid URL endpoint. + */ + internal fun buildEndpointUrl( + endpoint: String, + paramsClosure: PayloadClosure, + vsn: String + ): URL { + var mutableUrl = endpoint + // Silently replace web socket URLs with HTTP URLs. + if (endpoint.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) { + mutableUrl = "http:" + endpoint.substring(3) + } else if (endpoint.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) { + mutableUrl = "https:" + endpoint.substring(4) + } + + // Add the VSN query parameter + var httpUrl = mutableUrl.toHttpUrlOrNull() + ?: throw IllegalArgumentException("invalid url: $endpoint") + val httpBuilder = httpUrl.newBuilder() + httpBuilder.addQueryParameter("vsn", vsn) + + // Append any additional query params + paramsClosure.invoke().forEach { (key, value) -> + httpBuilder.addQueryParameter(key, value.toString()) + } + + // Return the [URL] that will be used to establish a connection + return httpBuilder.build().toUrl() + } } \ No newline at end of file diff --git a/src/main/kotlin/org/phoenixframework/Message.kt b/src/main/kotlin/org/phoenixframework/Message.kt index 386a5e2..ed177b0 100644 --- a/src/main/kotlin/org/phoenixframework/Message.kt +++ b/src/main/kotlin/org/phoenixframework/Message.kt @@ -22,34 +22,36 @@ package org.phoenixframework -import com.google.gson.annotations.SerializedName -class Message( +data class Message( + /** The ref sent during a join event. Empty if not present. */ + val joinRef: String? = null, + /** The unique string ref. Empty if not present */ - @SerializedName("ref") val ref: String = "", /** The message topic */ - @SerializedName("topic") val topic: String = "", /** The message event name, for example "phx_join" or any other custom name */ - @SerializedName("event") val event: String = "", - /** The payload of the message */ - @SerializedName("payload") - val payload: Payload = HashMap(), + /** The raw payload of the message. It is recommended that you use `payload` instead. */ + internal val rawPayload: Payload = HashMap(), - /** The ref sent during a join event. Empty if not present. */ - @SerializedName("join_ref") - val joinRef: String? = null) { + /** The payload, as a json string */ + val payloadJson: String = "" +) { + /** The payload of the message */ + @Suppress("UNCHECKED_CAST") + val payload: Payload + get() = rawPayload["response"] as? Payload ?: rawPayload /** * Convenience var to access the message's payload's status. Equivalent * to checking message.payload["status"] yourself */ val status: String? - get() = payload["status"] as? String + get() = rawPayload["status"] as? String } diff --git a/src/main/kotlin/org/phoenixframework/Presence.kt b/src/main/kotlin/org/phoenixframework/Presence.kt index 3e4b9db..3a1a9b8 100644 --- a/src/main/kotlin/org/phoenixframework/Presence.kt +++ b/src/main/kotlin/org/phoenixframework/Presence.kt @@ -126,7 +126,7 @@ class Presence(channel: Channel, opts: Options = Options.defaults) { if (stateEvent != null && diffEvent != null) { this.channel.on(stateEvent) { message -> - val newState = message.payload.toMutableMap() as PresenceState + val newState = message.rawPayload.toMutableMap() as PresenceState this.joinRef = this.channel.joinRef this.state = @@ -142,7 +142,7 @@ class Presence(channel: Channel, opts: Options = Options.defaults) { } this.channel.on(diffEvent) { message -> - val diff = message.payload.toMutableMap() as PresenceDiff + val diff = message.rawPayload.toMutableMap() as PresenceDiff if (isPendingSyncState) { this.pendingDiffs.add(diff) } else { diff --git a/src/main/kotlin/org/phoenixframework/Push.kt b/src/main/kotlin/org/phoenixframework/Push.kt index 5234205..75c75f1 100644 --- a/src/main/kotlin/org/phoenixframework/Push.kt +++ b/src/main/kotlin/org/phoenixframework/Push.kt @@ -32,8 +32,8 @@ class Push( val channel: Channel, /** The event the Push is targeting */ val event: String, - /** The message to be sent */ - var payload: Payload = mapOf(), + /** Closure that allows changing parameters sent during push */ + var payloadClosure: PayloadClosure, /** Duration before the message is considered timed out and failed to send */ var timeout: Long = Defaults.TIMEOUT ) { @@ -47,6 +47,9 @@ class Push( /** Hooks into a Push. Where .receive("ok", callback(Payload)) are stored */ var receiveHooks: MutableMap Unit)>> = HashMap() + /** Hooks into a Push. Where .receiveAll(callback(status, message)) are stored */ + private var receiveAllHooks: MutableList<(status: String, message: Message) -> Unit> = mutableListOf() + /** True if the Push has been sent */ var sent: Boolean = false @@ -56,6 +59,23 @@ class Push( /** The event that is associated with the reference ID of the Push */ var refEvent: String? = null + var payload: Payload + get() = payloadClosure.invoke() + set(value) { + payloadClosure = { value } + } + + constructor( + /** The channel the Push is being sent through */ + channel: Channel, + /** The event the Push is targeting */ + event: String, + /** The message to be sent */ + payload: Payload = mapOf(), + /** Duration before the message is considered timed out and failed to send */ + timeout: Long = Defaults.TIMEOUT + ) : this(channel, event, { payload }, timeout) + //------------------------------------------------------------------------------ // Public //------------------------------------------------------------------------------ @@ -100,6 +120,28 @@ class Push( return this } + /** + * Receives any event that was a response to an outbound message. + * + * Example: + * channel + * .send("event", mPayload) + * .receive { status, message -> + * print(status) // "ok" + * } + */ + fun receive(callback: (status: String, message: Message) -> Unit): Push { + // If the message has already been received, pass it to the callback. + receivedMessage?.let { + val status = it.status + if (status != null) { + callback(status, it) + } + } + receiveAllHooks.add(callback) + return this + } + //------------------------------------------------------------------------------ // Internal //------------------------------------------------------------------------------ @@ -165,6 +207,7 @@ class Push( */ private fun matchReceive(status: String, message: Message) { receiveHooks[status]?.forEach { it(message) } + receiveAllHooks.forEach { it(status, message) } } /** Removes receive hook from Channel regarding this Push */ diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index 66bb7fc..4d3ddef 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -22,8 +22,6 @@ package org.phoenixframework -import com.google.gson.Gson -import okhttp3.HttpUrl import okhttp3.OkHttpClient import okhttp3.Response import java.net.URL @@ -35,33 +33,53 @@ typealias Payload = Map /** Data class that holds callbacks assigned to the socket */ internal class StateChangeCallbacks { - var open: List<() -> Unit> = ArrayList() + var open: List Unit>> = ArrayList() private set - var close: List<() -> Unit> = ArrayList() + var close: List Unit>> = ArrayList() private set - var error: List<(Throwable, Response?) -> Unit> = ArrayList() + var error: List Unit>> = ArrayList() private set - var message: List<(Message) -> Unit> = ArrayList() + var message: List Unit>> = ArrayList() private set /** Safely adds an onOpen callback */ - fun onOpen(callback: () -> Unit) { - this.open = this.open + callback + fun onOpen( + ref: String, + callback: () -> Unit + ) { + this.open = this.open + Pair(ref, callback) } /** Safely adds an onClose callback */ - fun onClose(callback: () -> Unit) { - this.close = this.close + callback + fun onClose( + ref: String, + callback: () -> Unit + ) { + this.close = this.close + Pair(ref, callback) } /** Safely adds an onError callback */ - fun onError(callback: (Throwable, Response?) -> Unit) { - this.error = this.error + callback + fun onError( + ref: String, + callback: (Throwable, Response?) -> Unit + ) { + this.error = this.error + Pair(ref, callback) } /** Safely adds an onMessage callback */ - fun onMessage(callback: (Message) -> Unit) { - this.message = this.message + callback + fun onMessage( + ref: String, + callback: (Message) -> Unit + ) { + this.message = this.message + Pair(ref, callback) + } + + /** Clears any callbacks with the matching refs */ + fun release(refs: List) { + open = open.filterNot { refs.contains(it.first) } + close = close.filterNot { refs.contains(it.first) } + error = error.filterNot { refs.contains(it.first) } + message = message.filterNot { refs.contains(it.first) } } /** Clears all stored callbacks */ @@ -79,13 +97,43 @@ const val WS_CLOSE_NORMAL = 1000 /** RFC 6455: indicates that the connection was closed abnormally */ const val WS_CLOSE_ABNORMAL = 1006 +/** + * A closure that will return an optional Payload + */ +typealias PayloadClosure = () -> Payload + +/** A closure that will encode a Map into a JSON String */ +typealias EncodeClosure = (Any) -> String + +/** A closure that will decode a JSON String into a [Message] */ +typealias DecodeClosure = (String) -> Message + + /** * Connects to a Phoenix Server */ + +/** + * A [Socket] which connects to a Phoenix Server. Takes a closure to allow for changing parameters + * to be sent to the server when connecting. + * + * ## Example + * ``` + * val socket = Socket("https://example.com/socket", { mapOf("token" to mAuthToken) }) + * ``` + * @param url Url to connect to such as https://example.com/socket + * @param paramsClosure Closure which allows to change parameters sent during connection. + * @param vsn JSON Serializer version to use. Defaults to 2.0.0 + * @param encode Optional. Provide a custom JSON encoding implementation + * @param decode Optional. Provide a custom JSON decoding implementation + * @param client Default OkHttpClient to connect with. You can provide your own if needed. + */ class Socket( url: String, - params: Payload? = null, - private val gson: Gson = Defaults.gson, + val paramsClosure: PayloadClosure, + val vsn: String = Defaults.VSN, + private val encode: EncodeClosure = Defaults.encode, + private val decode: DecodeClosure = Defaults.decode, private val client: OkHttpClient = OkHttpClient.Builder().build() ) { @@ -101,13 +149,8 @@ class Socket( val endpoint: String /** The fully qualified socket URL */ - val endpointUrl: URL - - /** - * The optional params to pass when connecting. Must be set when - * initializing the Socket. These will be appended to the URL. - */ - val params: Payload? = params + var endpointUrl: URL + private set /** Timeout to use when opening a connection */ var timeout: Long = Defaults.TIMEOUT @@ -151,8 +194,11 @@ class Socket( /** Collection of unclosed channels created by the Socket */ internal var channels: List = ArrayList() - /** Buffers messages that need to be sent once the socket has connected */ - internal var sendBuffer: MutableList<() -> Unit> = ArrayList() + /** + * Buffers messages that need to be sent once the socket has connected. It is an array of Pairs + * that contain the ref of the message to send and the callback that will send the message. + */ + internal var sendBuffer: MutableList Unit>> = ArrayList() /** Ref counter for messages */ internal var ref: Int = 0 @@ -178,6 +224,31 @@ class Socket( //------------------------------------------------------------------------------ // Initialization //------------------------------------------------------------------------------ + /** + * A [Socket] which connects to a Phoenix Server. Takes a constant parameter to be sent to the + * server when connecting. Defaults to null if excluded. + * + * ## Example + * ``` + * val socket = Socket("https://example.com/socket", mapOf("token" to mAuthToken)) + * ``` + * + * @param url Url to connect to such as https://example.com/socket + * @param params Constant parameters to send when connecting. Defaults to null + * @param vsn JSON Serializer version to use. Defaults to 2.0.0 + * @param encode Optional. Provide a custom JSON encoding implementation + * @param decode Optional. Provide a custom JSON decoding implementation + * @param client Default OkHttpClient to connect with. You can provide your own if needed. + */ + constructor( + url: String, + params: Payload = mapOf(), + vsn: String = Defaults.VSN, + encode: EncodeClosure = Defaults.encode, + decode: DecodeClosure = Defaults.decode, + client: OkHttpClient = OkHttpClient.Builder().build() + ) : this(url, { params }, vsn, encode, decode, client) + init { var mutableUrl = url @@ -195,35 +266,18 @@ class Socket( // Store the endpoint before changing the protocol this.endpoint = mutableUrl - // Silently replace web socket URLs with HTTP URLs. - if (url.regionMatches(0, "ws:", 0, 3, ignoreCase = true)) { - mutableUrl = "http:" + url.substring(3) - } else if (url.regionMatches(0, "wss:", 0, 4, ignoreCase = true)) { - mutableUrl = "https:" + url.substring(4) - } - - // If there are query params, append them now - var httpUrl = HttpUrl.parse(mutableUrl) ?: throw IllegalArgumentException("invalid url: $url") - params?.let { - val httpBuilder = httpUrl.newBuilder() - it.forEach { (key, value) -> - httpBuilder.addQueryParameter(key, value.toString()) - } - - httpUrl = httpBuilder.build() - } - - // Store the URL that will be used to establish a connection - this.endpointUrl = httpUrl.url() + // Store the URL that will be used to establish a connection. Could potentially be + // different at the time connect() is called based on a changing params closure. + this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure, this.vsn) // Create reconnect timer this.reconnectTimer = TimeoutTimer( - dispatchQueue = dispatchQueue, - timerCalculation = reconnectAfterMs, - callback = { - this.logItems("Socket attempting to reconnect") - this.teardown { this.connect() } - }) + dispatchQueue = dispatchQueue, + timerCalculation = { reconnectAfterMs(it) }, + callback = { + this.logItems("Socket attempting to reconnect") + this.teardown { this.connect() } + }) } //------------------------------------------------------------------------------ @@ -239,7 +293,11 @@ class Socket( /** @return True if the connection exists and is open */ val isConnected: Boolean - get() = this.connection?.readyState == Transport.ReadyState.OPEN + get() = this.connectionState == Transport.ReadyState.OPEN + + /** @return The ready state of the connection. */ + val connectionState: Transport.ReadyState + get() = this.connection?.readyState ?: Transport.ReadyState.CLOSED //------------------------------------------------------------------------------ // Public @@ -251,6 +309,11 @@ class Socket( // Reset the clean close flag when attempting to connect this.closeWasClean = false + // Build the new endpointUrl with the params closure. The payload returned + // from the closure could be different such as a changing authToken. + this.endpointUrl = Defaults.buildEndpointUrl(this.endpoint, this.paramsClosure, this.vsn) + + // Now create the connection transport and attempt to connect this.connection = this.transport(endpointUrl) this.connection?.onOpen = { onConnectionOpened() } this.connection?.onClose = { code -> onConnectionClosed(code) } @@ -270,42 +333,60 @@ class Socket( // Reset any reconnects and teardown the socket connection this.reconnectTimer.reset() this.teardown(code, reason, callback) - } - fun onOpen(callback: (() -> Unit)) { - this.stateChangeCallbacks.onOpen(callback) + fun onOpen(callback: (() -> Unit)): String { + return makeRef().apply { stateChangeCallbacks.onOpen(this, callback) } } - fun onClose(callback: () -> Unit) { - this.stateChangeCallbacks.onClose(callback) + fun onClose(callback: () -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onClose(this, callback) } } - fun onError(callback: (Throwable, Response?) -> Unit) { - this.stateChangeCallbacks.onError(callback) + fun onError(callback: (Throwable, Response?) -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onError(this, callback) } } - fun onMessage(callback: (Message) -> Unit) { - this.stateChangeCallbacks.onMessage(callback) + fun onMessage(callback: (Message) -> Unit): String { + return makeRef().apply { stateChangeCallbacks.onMessage(this, callback) } } fun removeAllCallbacks() { this.stateChangeCallbacks.release() } - fun channel(topic: String, params: Payload = mapOf()): Channel { - val channel = Channel(topic, params, this) - this.channels = this.channels + channel + fun channel( + topic: String, + params: Payload = mapOf() + ): Channel = this.channel(topic) { params } + + fun channel( + topic: String, + paramsClosure: PayloadClosure + ): Channel { + val channel = Channel(topic, paramsClosure, this) + this.channels += channel return channel } fun remove(channel: Channel) { + this.off(channel.stateChangeRefs) + // To avoid a ConcurrentModificationException, filter out the channels to be // removed instead of calling .remove() on the list, thus returning a new list // that does not contain the channel that was removed. this.channels = channels - .filter { it.joinRef != channel.joinRef } + .filter { it.joinRef != channel.joinRef } + } + + /** + * Removes [onOpen], [onClose], [onError], and [onMessage] registrations by their [ref] value. + * + * @param refs List of refs to remove + */ + fun off(refs: List) { + this.stateChangeCallbacks.release(refs) } //------------------------------------------------------------------------------ @@ -320,15 +401,8 @@ class Socket( ) { val callback: (() -> Unit) = { - val body = mutableMapOf() - body["topic"] = topic - body["event"] = event - body["payload"] = payload - - ref?.let { body["ref"] = it } - joinRef?.let { body["join_ref"] = it } - - val data = gson.toJson(body) + val body = listOf(joinRef, ref, topic, event, payload) + val data = this.encode(body) connection?.let { transport -> this.logItems("Push: Sending $data") transport.send(data) @@ -341,7 +415,7 @@ class Socket( } else { // If the socket is not connected, add the push to a buffer which will // be sent immediately upon connection. - sendBuffer.add(callback) + sendBuffer.add(Pair(ref, callback)) } } @@ -374,7 +448,7 @@ class Socket( // Since the connections onClose was null'd out, inform all state callbacks // that the Socket has closed - this.stateChangeCallbacks.close.forEach { it.invoke() } + this.stateChangeCallbacks.close.forEach { it.second.invoke() } callback?.invoke() } @@ -391,11 +465,27 @@ class Socket( /** Send all messages that were buffered before the socket opened */ internal fun flushSendBuffer() { if (isConnected && sendBuffer.isNotEmpty()) { - this.sendBuffer.forEach { it.invoke() } + this.sendBuffer.forEach { it.second.invoke() } this.sendBuffer.clear() } } + /** Removes an item from the send buffer with the matching ref */ + internal fun removeFromSendBuffer(ref: String) { + this.sendBuffer = this.sendBuffer + .filter { it.first != ref } + .toMutableList() + } + + internal fun leaveOpenTopic(topic: String) { + this.channels + .firstOrNull { it.topic == topic && (it.isJoined || it.isJoining) } + ?.let { + logItems("Transport: Leaving duplicate topic: [$topic]") + it.leave() + } + } + //------------------------------------------------------------------------------ // Heartbeat //------------------------------------------------------------------------------ @@ -411,7 +501,7 @@ class Socket( val period = heartbeatIntervalMs heartbeatTask = - dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() } + dispatchQueue.queueAtFixedRate(delay, period, TimeUnit.MILLISECONDS) { sendHeartbeat() } } internal fun sendHeartbeat() { @@ -433,10 +523,11 @@ class Socket( // The last heartbeat was acknowledged by the server. Send another one this.pendingHeartbeatRef = this.makeRef() this.push( - topic = "phoenix", - event = Channel.Event.HEARTBEAT.value, - payload = mapOf(), - ref = pendingHeartbeatRef) + topic = "phoenix", + event = Channel.Event.HEARTBEAT.value, + payload = mapOf(), + ref = pendingHeartbeatRef + ) } private fun abnormalClose(reason: String) { @@ -469,7 +560,7 @@ class Socket( this.resetHeartbeat() // Inform all onOpen callbacks that the Socket has opened - this.stateChangeCallbacks.open.forEach { it.invoke() } + this.stateChangeCallbacks.open.forEach { it.second.invoke() } } internal fun onConnectionClosed(code: Int) { @@ -486,35 +577,37 @@ class Socket( } // Inform callbacks the socket closed - this.stateChangeCallbacks.close.forEach { it.invoke() } + this.stateChangeCallbacks.close.forEach { it.second.invoke() } } internal fun onConnectionMessage(rawMessage: String) { this.logItems("Receive: $rawMessage") // Parse the message as JSON - val message = gson.fromJson(rawMessage, Message::class.java) + val message = this.decode(rawMessage) // Clear heartbeat ref, preventing a heartbeat timeout disconnect if (message.ref == pendingHeartbeatRef) pendingHeartbeatRef = null // Dispatch the message to all channels that belong to the topic this.channels - .filter { it.isMember(message) } - .forEach { it.trigger(message) } + .filter { it.isMember(message) } + .forEach { it.trigger(message) } // Inform all onMessage callbacks of the message - this.stateChangeCallbacks.message.forEach { it.invoke(message) } + this.stateChangeCallbacks.message.forEach { it.second.invoke(message) } } - internal fun onConnectionError(t: Throwable, response: Response?) { + internal fun onConnectionError( + t: Throwable, + response: Response? + ) { this.logItems("Transport: error $t") // Send an error to all channels this.triggerChannelError() // Inform any state callbacks of the error - this.stateChangeCallbacks.error.forEach { it.invoke(t, response) } + this.stateChangeCallbacks.error.forEach { it.second.invoke(t, response) } } - -} \ No newline at end of file +} diff --git a/src/main/kotlin/org/phoenixframework/Transport.kt b/src/main/kotlin/org/phoenixframework/Transport.kt index 25c8818..738fad6 100644 --- a/src/main/kotlin/org/phoenixframework/Transport.kt +++ b/src/main/kotlin/org/phoenixframework/Transport.kt @@ -155,6 +155,7 @@ class WebSocketTransport( override fun onClosing(webSocket: WebSocket, code: Int, reason: String) { this.readyState = Transport.ReadyState.CLOSING + webSocket.close(code, reason) } override fun onMessage(webSocket: WebSocket, text: String) { diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt index 4963bfd..9965caa 100644 --- a/src/test/kotlin/org/phoenixframework/ChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt @@ -1,14 +1,14 @@ package org.phoenixframework import com.google.common.truth.Truth.assertThat -import com.nhaarman.mockitokotlin2.any -import com.nhaarman.mockitokotlin2.eq -import com.nhaarman.mockitokotlin2.mock -import com.nhaarman.mockitokotlin2.never -import com.nhaarman.mockitokotlin2.spy -import com.nhaarman.mockitokotlin2.times -import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.whenever +import org.mockito.kotlin.any +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.spy +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever import okhttp3.OkHttpClient import org.junit.jupiter.api.AfterEach import org.junit.jupiter.api.BeforeEach @@ -16,7 +16,7 @@ import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import org.mockito.Mock -import org.mockito.Mockito.verifyZeroInteractions +import org.mockito.Mockito.verifyNoInteractions import org.mockito.MockitoAnnotations import org.phoenixframework.queue.ManualDispatchQueue import org.phoenixframework.utilities.getBindings @@ -28,6 +28,7 @@ class ChannelTest { @Mock lateinit var socket: Socket @Mock lateinit var mockCallback: ((Message) -> Unit) + @Mock lateinit var mockStatusCallback: ((String, Message) -> Unit) private val kDefaultRef = "1" private val kDefaultTimeout = 10_000L @@ -149,13 +150,39 @@ class ChannelTest { /* End JoinParams */ } + + @Nested + @DisplayName("join paramsClosure") + inner class JoinParamsClosure { + @Test + internal fun `updating join params closure`() { + val paramsClosure = { mapOf("value" to 1) } + val change = mapOf("value" to 2) + + channel = Channel("topic", paramsClosure, socket) + val joinPush = channel.joinPush + + assertThat(joinPush.channel).isEqualTo(channel) + assertThat(joinPush.payload["value"]).isEqualTo(1) + assertThat(joinPush.event).isEqualTo("phx_join") + assertThat(joinPush.timeout).isEqualTo(10_000L) + + channel.params = change + assertThat(joinPush.channel).isEqualTo(channel) + assertThat(joinPush.payload["value"]).isEqualTo(2) + assertThat(channel.params["value"]).isEqualTo(2) + assertThat(joinPush.event).isEqualTo("phx_join") + assertThat(joinPush.timeout).isEqualTo(10_000L) + } + } + @Nested @DisplayName("join") inner class Join { @BeforeEach internal fun setUp() { - socket = spy(Socket(url ="https://localhost:4000/socket", client = okHttpClient)) + socket = spy(Socket(url = "https://localhost:4000/socket", client = okHttpClient)) socket.dispatchQueue = fakeClock channel = Channel("topic", kDefaultPayload, socket) } @@ -193,7 +220,7 @@ class ChannelTest { @Test internal fun `triggers socket push with channel params`() { channel.join() - verify(socket).push("topic", "phx_join", kDefaultPayload, kDefaultRef, channel.joinRef) + verify(socket).push("topic", "phx_join", kDefaultPayload, "3", channel.joinRef) } @Test @@ -206,6 +233,22 @@ class ChannelTest { assertThat(joinPush.timeout).isEqualTo(newTimeout) } + @Test + internal fun `leaves existing duplicate topic on new join`() { + val socket = spy(Socket("wss://localhost:4000/socket")) + val channel = socket.channel("topic") + + channel.join().receive("ok") { + val newChannel = socket.channel("topic") + assertThat(channel.isJoined).isTrue() + newChannel.join() + + assertThat(channel.isJoined).isFalse() + } + + channel.joinPush.trigger("ok", kEmptyPayload) + } + @Nested @DisplayName("timeout behavior") inner class TimeoutBehavior { @@ -368,6 +411,11 @@ class ChannelTest { joinPush.trigger("error", mapOf("a" to "b")) } + private fun receivesApproved() { + fakeClock.tick(joinPush.timeout / 2) + joinPush.trigger("approved", mapOf("a" to "b")) + } + @Nested @DisplayName("receives 'ok'") inner class ReceivesOk { @@ -429,11 +477,11 @@ class ChannelTest { @Test internal fun `removes channel binding`() { - var bindings = channel.getBindings("chan_reply_1") + var bindings = channel.getBindings("chan_reply_3") assertThat(bindings).hasSize(1) receivesOk() - bindings = channel.getBindings("chan_reply_1") + bindings = channel.getBindings("chan_reply_3") assertThat(bindings).isEmpty() } @@ -493,8 +541,8 @@ class ChannelTest { .receive("ok", mockOk) .receive("error", mockError) .receive("timeout") { - verifyZeroInteractions(mockOk) - verifyZeroInteractions(mockError) + verifyNoInteractions(mockOk) + verifyNoInteractions(mockError) timeoutReceived = true } @@ -555,8 +603,8 @@ class ChannelTest { receivesTimeout() verify(mockError, times(1)).invoke(any()) - verifyZeroInteractions(mockOk) - verifyZeroInteractions(mockTimeout) + verifyNoInteractions(mockOk) + verifyNoInteractions(mockTimeout) } @Test @@ -583,7 +631,7 @@ class ChannelTest { @Test internal fun `removes channel binding`() { - var bindings = channel.getBindings("chan_reply_1") + var bindings = channel.getBindings("chan_reply_3") assertThat(bindings).hasSize(1) receivesError() @@ -603,13 +651,59 @@ class ChannelTest { channel.pushBuffer.add(mockPush) receivesError() - verifyZeroInteractions(mockPush) + verifyNoInteractions(mockPush) assertThat(channel.pushBuffer).hasSize(1) } /* End ReceivesError */ } + + @Nested + @DisplayName("receives 'all status'") + inner class ReceivesAllStatus { + @Test + internal fun `triggers receive('error') callback after error response`() { + assertThat(channel.state).isEqualTo(Channel.State.JOINING) + joinPush.receive(mockStatusCallback) + + receivesError() + joinPush.trigger("error", kEmptyPayload) + verify(mockStatusCallback, times(1)).invoke(any(), any()) + } + + @Test + internal fun `triggers receive('error') callback if error response already received`() { + receivesError() + + joinPush.receive(mockStatusCallback) + + verify(mockStatusCallback).invoke(any(), any()) + } + + @Test + internal fun `triggers receive('approved') callback after approved response`() { + assertThat(channel.state).isEqualTo(Channel.State.JOINING) + joinPush.receive(mockStatusCallback) + + receivesApproved() + joinPush.trigger("approved", kEmptyPayload) + verify(mockStatusCallback, times(1)).invoke(any(), any()) + + } + + @Test + internal fun `triggers receive('approved') callback if approved response already received`() { + receivesApproved() + + joinPush.receive(mockStatusCallback) + + verify(mockStatusCallback).invoke(any(), any()) + } + + /* End ReceivesAllStatus */ + } + /* End JoinPush */ } @@ -656,12 +750,25 @@ class ChannelTest { fakeClock.tick(1000) verify(joinPush, times(1)).send() - channel.trigger("error") + channel.trigger(Channel.Event.ERROR) fakeClock.tick(1000) verify(joinPush, times(1)).send() } + @Test + internal fun `removes the joinPush message from sendBuffer`() { + val channel = Channel("topic", kDefaultPayload, socket) + val push = mock() + whenever(push.ref).thenReturn("10") + channel.joinPush = push + channel.state = Channel.State.JOINING + + channel.trigger(Channel.Event.ERROR) + verify(socket).removeFromSendBuffer("10") + verify(push).reset() + } + @Test internal fun `tries to rejoin with backoff`() { val mockTimer = mock() @@ -715,7 +822,7 @@ class ChannelTest { joinPush.trigger("ok", kEmptyPayload) assertThat(channel.state).isEqualTo(Channel.State.JOINED) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) channel.trigger(Channel.Event.ERROR) verify(mockCallback, times(1)).invoke(any()) @@ -786,7 +893,7 @@ class ChannelTest { @Test internal fun `triggers additional callbacks`() { channel.onClose(mockCallback) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) channel.trigger(Channel.Event.CLOSE) verify(mockCallback, times(1)).invoke(any()) @@ -878,8 +985,8 @@ class ChannelTest { channel.trigger(event = "event", ref = kDefaultRef) channel.trigger(event = "other", ref = kDefaultRef) - verifyZeroInteractions(callback1) - verifyZeroInteractions(callback2) + verifyNoInteractions(callback1) + verifyNoInteractions(callback2) verify(callback3, times(1)).invoke(any()) } @@ -894,7 +1001,7 @@ class ChannelTest { channel.off("event", ref1) channel.trigger(event = "event", ref = kDefaultRef) - verifyZeroInteractions(callback1) + verifyNoInteractions(callback1) verify(callback2, times(1)).invoke(any()) } @@ -952,7 +1059,7 @@ class ChannelTest { .receive("timeout", mockCallback) fakeClock.tick(channel.timeout / 2) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) fakeClock.tick(channel.timeout) verify(mockCallback).invoke(any()) @@ -966,7 +1073,7 @@ class ChannelTest { .receive("timeout", mockCallback) fakeClock.tick(channel.timeout) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) fakeClock.tick(channel.timeout * 2) verify(mockCallback).invoke(any()) @@ -980,12 +1087,12 @@ class ChannelTest { .receive("timeout", mockCallback) fakeClock.tick(channel.timeout / 2) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) push.trigger("ok", kEmptyPayload) fakeClock.tick(channel.timeout) - verifyZeroInteractions(mockCallback) + verifyNoInteractions(mockCallback) } @Test diff --git a/src/test/kotlin/org/phoenixframework/DefaultsTest.kt b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt index eca6d0b..bf2da9b 100644 --- a/src/test/kotlin/org/phoenixframework/DefaultsTest.kt +++ b/src/test/kotlin/org/phoenixframework/DefaultsTest.kt @@ -42,4 +42,101 @@ internal class DefaultsTest { assertThat(reconnect(4)).isEqualTo(10_000) assertThat(reconnect(5)).isEqualTo(10_000) } + + @Test + internal fun `decoder converts json array into message`() { + val v2Json = """ + [null,null,"room:lobby","shout",{"message":"Hi","name":"Tester"}] + """.trimIndent() + + val message = Defaults.decode(v2Json) + assertThat(message.joinRef).isNull() + assertThat(message.ref).isEqualTo("") + assertThat(message.topic).isEqualTo("room:lobby") + assertThat(message.event).isEqualTo("shout") + assertThat(message.payload).isEqualTo(mapOf("message" to "Hi", "name" to "Tester")) + } + + @Test + internal fun `decoder provides raw json payload`() { + val v2Json = """ + ["1","2","room:lobby","shout",{"message":"Hi","name":"Tester","count":15,"ratio":0.2}] + """.trimIndent() + + val message = Defaults.decode(v2Json) + assertThat(message.joinRef).isEqualTo("1") + assertThat(message.ref).isEqualTo("2") + assertThat(message.topic).isEqualTo("room:lobby") + assertThat(message.event).isEqualTo("shout") + assertThat(message.payloadJson).isEqualTo("{\"message\":\"Hi\",\"name\":\"Tester\",\"count\":15,\"ratio\":0.2}") + assertThat(message.payload).isEqualTo(mapOf( + "message" to "Hi", + "name" to "Tester", + "count" to 15.0, // Note that this is a bug and should eventually be removed + "ratio" to 0.2 + )) + } + + @Test + internal fun `decoder decodes a status`() { + val v2Json = """ + ["1","2","room:lobby","phx_reply",{"response":{"message":"Hi","name":"Tester","count":15,"ratio":0.2},"status":"ok"}] + """.trimIndent() + + val message = Defaults.decode(v2Json) + assertThat(message.joinRef).isEqualTo("1") + assertThat(message.ref).isEqualTo("2") + assertThat(message.topic).isEqualTo("room:lobby") + assertThat(message.event).isEqualTo("phx_reply") + assertThat(message.payloadJson).isEqualTo("{\"message\":\"Hi\",\"name\":\"Tester\",\"count\":15,\"ratio\":0.2}") + assertThat(message.payload).isEqualTo(mapOf( + "message" to "Hi", + "name" to "Tester", + "count" to 15.0, // Note that this is a bug and should eventually be removed + "ratio" to 0.2 + )) + } + + + + @Test + internal fun `decoder decodes an error`() { + val v2Json = """ + ["6","8","drivers:self","phx_reply",{"response":{"details":"invalid code specified"},"status":"error"}] + """.trimIndent() + + val message = Defaults.decode(v2Json) + assertThat(message.payloadJson).isEqualTo("{\"details\":\"invalid code specified\"}") + assertThat(message.rawPayload).isEqualTo(mapOf( + "response" to mapOf( + "details" to "invalid code specified" + ), + "status" to "error" + )) + assertThat(message.payload).isEqualTo(mapOf( + "details" to "invalid code specified" + )) + + } + + @Test + internal fun `decoder decodes a non-json payload`() { + val v2Json = """ + ["1","2","room:lobby","phx_reply",{"response":"hello","status":"ok"}] + """.trimIndent() + + val message = Defaults.decode(v2Json) + assertThat(message.payloadJson).isEqualTo("\"hello\"") + assertThat(message.payload).isEqualTo(mapOf( + "response" to "hello", + "status" to "ok" + )) + } + + @Test + internal fun `encode converts message into json`() { + val body = listOf(null, null, "topic", "event", mapOf("one" to "two")) + assertThat(Defaults.encode(body)) + .isEqualTo("[null,null,\"topic\",\"event\",{\"one\":\"two\"}]") + } } \ No newline at end of file diff --git a/src/test/kotlin/org/phoenixframework/MessageTest.kt b/src/test/kotlin/org/phoenixframework/MessageTest.kt index f949b70..2dcd321 100644 --- a/src/test/kotlin/org/phoenixframework/MessageTest.kt +++ b/src/test/kotlin/org/phoenixframework/MessageTest.kt @@ -8,18 +8,37 @@ import org.junit.jupiter.api.Test class MessageTest { @Nested - @DisplayName("status") - inner class Status { + @DisplayName("json parsing") + inner class JsonParsing { @Test - internal fun `returns the status from the payload`() { - val payload = mapOf("one" to "two", "status" to "ok") - val message = Message("ref", "topic", "event", payload, null) - - assertThat(message.ref).isEqualTo("ref") - assertThat(message.topic).isEqualTo("topic") - assertThat(message.event).isEqualTo("event") - assertThat(message.payload).isEqualTo(payload) + internal fun `jsonParsing parses normal message`() { + val json = """ + [null,"6","my-topic","update",{"user":"James S.","message":"This is a test"}] + """.trimIndent() + + val message = Defaults.decode.invoke(json) + + assertThat(message.ref).isEqualTo("6") + assertThat(message.topic).isEqualTo("my-topic") + assertThat(message.event).isEqualTo("update") + assertThat(message.payload).isEqualTo(mapOf("user" to "James S.", "message" to "This is a test")) + assertThat(message.joinRef).isNull() + assertThat(message.status).isNull() + } + + @Test + internal fun `jsonParsing parses a reply`() { + val json = """ + [null,"6","my-topic","phx_reply",{"response":{"user":"James S.","message":"This is a test"},"status": "ok"}] + """.trimIndent() + + val message = Defaults.decode.invoke(json) + + assertThat(message.ref).isEqualTo("6") + assertThat(message.topic).isEqualTo("my-topic") + assertThat(message.event).isEqualTo("phx_reply") + assertThat(message.payload).isEqualTo(mapOf("user" to "James S.", "message" to "This is a test")) assertThat(message.joinRef).isNull() assertThat(message.status).isEqualTo("ok") } diff --git a/src/test/kotlin/org/phoenixframework/PresenceTest.kt b/src/test/kotlin/org/phoenixframework/PresenceTest.kt index 780f897..2874e1f 100644 --- a/src/test/kotlin/org/phoenixframework/PresenceTest.kt +++ b/src/test/kotlin/org/phoenixframework/PresenceTest.kt @@ -1,8 +1,8 @@ package org.phoenixframework import com.google.common.truth.Truth.assertThat -import com.nhaarman.mockitokotlin2.mock -import com.nhaarman.mockitokotlin2.whenever +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.DisplayName import org.junit.jupiter.api.Nested @@ -196,7 +196,7 @@ class PresenceTest { @Test internal fun `onJoins new presences and onLeaves left presences`() { val newState = fixState - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u4" to mutableMapOf("metas" to listOf(mapOf("id" to 4, "phx_ref" to "4")))) val joined: PresenceDiff = mutableMapOf() @@ -245,9 +245,9 @@ class PresenceTest { @Test internal fun `onJoins only newly added metas`() { - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))) - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.new") @@ -285,9 +285,9 @@ class PresenceTest { @Test internal fun `onLeaves only newly removed metas`() { - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))) - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.left") @@ -326,13 +326,13 @@ class PresenceTest { @Test internal fun `syncs both joined and left metas`() { - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.new") ))) - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.left") @@ -421,13 +421,13 @@ class PresenceTest { @Test internal fun `removes meta while leaving key if other metas exist`() { - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u1" to mutableMapOf("metas" to listOf( mapOf("id" to 1, "phx_ref" to "1"), mapOf("id" to 1, "phx_ref" to "1.2") ))) - val leaves = mutableMapOf( + val leaves: MutableMap>>> = mutableMapOf( "u1" to mutableMapOf("metas" to listOf( mapOf("id" to 1, "phx_ref" to "1") ))) diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt index 3150211..6bf44a4 100644 --- a/src/test/kotlin/org/phoenixframework/SocketTest.kt +++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt @@ -1,16 +1,16 @@ package org.phoenixframework import com.google.common.truth.Truth.assertThat -import com.nhaarman.mockitokotlin2.any -import com.nhaarman.mockitokotlin2.argumentCaptor -import com.nhaarman.mockitokotlin2.eq -import com.nhaarman.mockitokotlin2.mock -import com.nhaarman.mockitokotlin2.never -import com.nhaarman.mockitokotlin2.spy -import com.nhaarman.mockitokotlin2.times -import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.verifyZeroInteractions -import com.nhaarman.mockitokotlin2.whenever +import org.mockito.kotlin.any +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.eq +import org.mockito.kotlin.mock +import org.mockito.kotlin.never +import org.mockito.kotlin.spy +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.verifyNoInteractions +import org.mockito.kotlin.whenever import okhttp3.OkHttpClient import okhttp3.Response import org.junit.jupiter.api.BeforeEach @@ -48,11 +48,12 @@ class SocketTest { internal fun `sets defaults`() { val socket = Socket("wss://localhost:4000/socket") - assertThat(socket.params).isNull() + assertThat(socket.paramsClosure.invoke()).isEmpty() assertThat(socket.channels).isEmpty() assertThat(socket.sendBuffer).isEmpty() assertThat(socket.ref).isEqualTo(0) assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket") + assertThat(socket.vsn).isEqualTo(Defaults.VSN) assertThat(socket.stateChangeCallbacks.open).isEmpty() assertThat(socket.stateChangeCallbacks.close).isEmpty() assertThat(socket.stateChangeCallbacks.error).isEmpty() @@ -81,7 +82,7 @@ class SocketTest { socket.logger = { } socket.reconnectAfterMs = { 10 } - assertThat(socket.params).isEqualTo(mapOf("one" to 2)) + assertThat(socket.paramsClosure?.invoke()).isEqualTo(mapOf("one" to 2)) assertThat(socket.endpoint).isEqualTo("wss://localhost:4000/socket/websocket") assertThat(socket.timeout).isEqualTo(40_000) assertThat(socket.heartbeatIntervalMs).isEqualTo(60_000) @@ -94,32 +95,34 @@ class SocketTest { internal fun `constructs with a valid URL`() { // Test different schemes assertThat(Socket("http://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket") + .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0") assertThat(Socket("https://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket") + .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0") assertThat(Socket("ws://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket") + .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0") assertThat(Socket("wss://localhost:4000/socket/websocket").endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket") + .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0") // test params val singleParam = hashMapOf("token" to "abc123") assertThat(Socket("ws://localhost:4000/socket/websocket", singleParam).endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket?token=abc123") + .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0&token=abc123") val multipleParams = hashMapOf("token" to "abc123", "user_id" to 1) assertThat( - Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString()) - .isEqualTo("http://localhost:4000/socket/websocket?user_id=1&token=abc123") + Socket("http://localhost:4000/socket/websocket", multipleParams).endpointUrl.toString() + ) + .isEqualTo("http://localhost:4000/socket/websocket?vsn=2.0.0&user_id=1&token=abc123") // test params with spaces val spacesParams = hashMapOf("token" to "abc 123", "user_id" to 1) assertThat( - Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString()) - .isEqualTo("https://localhost:4000/socket/websocket?user_id=1&token=abc%20123") + Socket("wss://localhost:4000/socket/websocket", spacesParams).endpointUrl.toString() + ) + .isEqualTo("https://localhost:4000/socket/websocket?vsn=2.0.0&user_id=1&token=abc%20123") } /* End Constructor */ @@ -185,6 +188,28 @@ class SocketTest { assertThat(socket.connection).isNotNull() } + @Test + internal fun `accounts for changing parameters`() { + val transport = mock<(URL) -> Transport>() + whenever(transport.invoke(any())).thenReturn(connection) + + var token = "a" + val socket = Socket("wss://localhost:4000/socket", { mapOf("token" to token) }) + socket.transport = transport + + socket.connect() + argumentCaptor { + verify(transport).invoke(capture()) + assertThat(firstValue.query).isEqualTo("vsn=2.0.0&token=a") + + token = "b" + socket.disconnect() + socket.connect() + verify(transport, times(2)).invoke(capture()) + assertThat(lastValue.query).isEqualTo("vsn=2.0.0&token=b") + } + } + @Test internal fun `sets callbacks for connection`() { var open = 0 @@ -215,12 +240,7 @@ class SocketTest { assertThat(lastError).isNotNull() assertThat(lastResponse).isNull() - val data = mapOf( - "topic" to "topic", - "event" to "event", - "payload" to mapOf("go" to true), - "status" to "status" - ) + val data = listOf(null, null, "topic", "event", mapOf("go" to true)) val json = Defaults.gson.toJson(data) socket.connection?.onMessage?.invoke(json) @@ -258,12 +278,7 @@ class SocketTest { assertThat(lastError).isNull() assertThat(lastResponse).isNull() - val data = mapOf( - "topic" to "topic", - "event" to "event", - "payload" to mapOf("go" to true), - "status" to "status" - ) + val data = listOf(null, null, "topic", "event", mapOf("go" to true)) val json = Defaults.gson.toJson(data) socket.connection?.onMessage?.invoke(json) @@ -339,7 +354,7 @@ class SocketTest { @Test internal fun `does nothing if not connected`() { socket.disconnect() - verifyZeroInteractions(connection) + verifyNoInteractions(connection) } /* End Disconnect */ @@ -423,6 +438,29 @@ class SocketTest { /* End Remove */ } + @Nested + @DisplayName("release") + inner class Release { + @Test + internal fun `Clears any callbacks with the matching refs`() { + socket.stateChangeCallbacks.onOpen("1") {} + socket.stateChangeCallbacks.onOpen("2") {} + socket.stateChangeCallbacks.onClose("1") {} + socket.stateChangeCallbacks.onClose("2") {} + socket.stateChangeCallbacks.onError("1") { _: Throwable, _: Response? -> } + socket.stateChangeCallbacks.onError("2") { _: Throwable, _: Response? -> } + socket.stateChangeCallbacks.onMessage("1") { } + socket.stateChangeCallbacks.onMessage("2") { } + + socket.stateChangeCallbacks.release(listOf("1")) + + assertThat(socket.stateChangeCallbacks.open).doesNotContain("1") + assertThat(socket.stateChangeCallbacks.close).doesNotContain("1") + assertThat(socket.stateChangeCallbacks.error).doesNotContain("1") + assertThat(socket.stateChangeCallbacks.message).doesNotContain("1") + } + } + @Nested @DisplayName("push") inner class Push { @@ -433,9 +471,8 @@ class SocketTest { socket.connect() socket.push("topic", "event", mapOf("one" to "two"), "ref", "join-ref") - val expect = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"ref\",\"join_ref\":\"join-ref\"}" - verify(connection).send(expect) + val expected = "[\"join-ref\",\"ref\",\"topic\",\"event\",{\"one\":\"two\"}]" + verify(connection).send(expected) } @Test @@ -445,8 +482,8 @@ class SocketTest { socket.connect() socket.push("topic", "event", mapOf("one" to "two")) - val expect = "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"}}" - verify(connection).send(expect) + val expected = "[null,null,\"topic\",\"event\",{\"one\":\"two\"}]" + verify(connection).send(expected) } @Test @@ -462,7 +499,7 @@ class SocketTest { verify(connection, never()).send(any()) assertThat(socket.sendBuffer).hasSize(2) - socket.sendBuffer.forEach { it.invoke() } + socket.sendBuffer.forEach { it.second.invoke() } verify(connection, times(2)).send(any()) } @@ -517,7 +554,7 @@ class SocketTest { internal fun `pushes heartbeat data when connected`() { socket.sendHeartbeat() - val expected = "{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}" + val expected = "[null,\"1\",\"phoenix\",\"heartbeat\",{}]" assertThat(socket.pendingHeartbeatRef).isEqualTo(socket.ref.toString()) verify(connection).send(expected) } @@ -540,9 +577,9 @@ class SocketTest { @Test internal fun `invokes callbacks in buffer when connected`() { var oneCalled = 0 - socket.sendBuffer.add { oneCalled += 1 } + socket.sendBuffer.add(Pair("0", { oneCalled += 1 })) var twoCalled = 0 - socket.sendBuffer.add { twoCalled += 1 } + socket.sendBuffer.add(Pair("1", { twoCalled += 1 })) val threeCalled = 0 whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) @@ -563,7 +600,7 @@ class SocketTest { @Test internal fun `empties send buffer`() { - socket.sendBuffer.add { } + socket.sendBuffer.add(Pair(null, {})) whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) socket.connect() @@ -577,6 +614,30 @@ class SocketTest { /* End FlushSendBuffer */ } + @Nested + @DisplayName("removeFromSendBuffer") + inner class RemoveFromSendBuffer { + @Test + internal fun `removes a callback with matching ref`() { + var oneCalled = 0 + socket.sendBuffer.add(Pair("0", { oneCalled += 1 })) + var twoCalled = 0 + socket.sendBuffer.add(Pair("1", { twoCalled += 1 })) + + whenever(connection.readyState).thenReturn(Transport.ReadyState.OPEN) + + // connect + socket.connect() + + socket.removeFromSendBuffer("0") + + // sends once connected + socket.flushSendBuffer() + assertThat(oneCalled).isEqualTo(0) + assertThat(twoCalled).isEqualTo(1) + } + } + @Nested @DisplayName("resetHeartbeat") inner class ResetHeartbeat { @@ -593,7 +654,7 @@ class SocketTest { socket.skipHeartbeat = true socket.resetHeartbeat() - verifyZeroInteractions(mockDispatchQueue) + verifyNoInteractions(mockDispatchQueue) } @Test @@ -610,14 +671,15 @@ class SocketTest { assertThat(socket.heartbeatTask).isNotNull() argumentCaptor<() -> Unit> { - verify(mockDispatchQueue).queueAtFixedRate(eq(5_000L), eq(5_000L), - eq(TimeUnit.MILLISECONDS), capture()) + verify(mockDispatchQueue).queueAtFixedRate( + eq(5_000L), eq(5_000L), + eq(TimeUnit.MILLISECONDS), capture() + ) // fire the task allValues.first().invoke() - val expected = - "{\"topic\":\"phoenix\",\"event\":\"heartbeat\",\"payload\":{},\"ref\":\"1\"}" + val expected = "[null,\"1\",\"phoenix\",\"heartbeat\",{}]" verify(connection).send(expected) } } @@ -638,7 +700,7 @@ class SocketTest { @Test internal fun `flushes the send buffer`() { var oneCalled = 0 - socket.sendBuffer.add { oneCalled += 1 } + socket.sendBuffer.add(Pair("1", { oneCalled += 1 })) socket.onConnectionOpened() assertThat(oneCalled).isEqualTo(1) @@ -878,6 +940,8 @@ class SocketTest { @Nested @DisplayName("onConnectionMessage") inner class OnConnectionMessage { + + @Test internal fun `parses raw messages and triggers channel event`() { val targetChannel = mock() @@ -888,8 +952,7 @@ class SocketTest { socket.channels = socket.channels.plus(targetChannel) socket.channels = socket.channels.minus(otherChannel) - val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" + val rawMessage = "[null,null,\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]" socket.onConnectionMessage(rawMessage) verify(targetChannel).trigger(message = any()) @@ -901,8 +964,7 @@ class SocketTest { var message: Message? = null socket.onMessage { message = it } - val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"status\":\"ok\"}" + val rawMessage = "[null,null,\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]" socket.onConnectionMessage(rawMessage) assertThat(message?.topic).isEqualTo("topic") @@ -913,8 +975,7 @@ class SocketTest { internal fun `clears pending heartbeat`() { socket.pendingHeartbeatRef = "5" - val rawMessage = - "{\"topic\":\"topic\",\"event\":\"event\",\"payload\":{\"one\":\"two\"},\"ref\":\"5\"}" + val rawMessage = "[null,\"5\",\"topic\",\"event\",{\"one\":\"two\",\"status\":\"ok\"}]" socket.onConnectionMessage(rawMessage) assertThat(socket.pendingHeartbeatRef).isNull() } @@ -922,7 +983,6 @@ class SocketTest { /* End OnConnectionMessage */ } - @Nested @DisplayName("ConcurrentModificationException") inner class ConcurrentModificationExceptionTests { @@ -967,7 +1027,7 @@ class SocketTest { internal fun `onError does not throw`() { var oneCalled = 0 var twoCalled = 0 - socket.onError { _, _-> + socket.onError { _, _ -> socket.onError { _, _ -> twoCalled += 1 } oneCalled += 1 } @@ -990,11 +1050,12 @@ class SocketTest { oneCalled += 1 } - socket.onConnectionMessage("{\"status\":\"ok\"}") + val message = "[null,null,\"room:lobby\",\"shout\",{\"message\":\"Hi\",\"name\":\"Tester\"}]" + socket.onConnectionMessage(message) assertThat(oneCalled).isEqualTo(1) assertThat(twoCalled).isEqualTo(0) - socket.onConnectionMessage("{\"status\":\"ok\"}") + socket.onConnectionMessage(message) assertThat(oneCalled).isEqualTo(2) assertThat(twoCalled).isEqualTo(1) } @@ -1013,6 +1074,4 @@ class SocketTest { /* End ConcurrentModificationExceptionTests */ } - - -} \ No newline at end of file +} diff --git a/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt b/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt index df6792b..29ac4ea 100644 --- a/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt +++ b/src/test/kotlin/org/phoenixframework/TimeoutTimerTest.kt @@ -1,12 +1,12 @@ package org.phoenixframework import com.google.common.truth.Truth -import com.nhaarman.mockitokotlin2.any -import com.nhaarman.mockitokotlin2.argumentCaptor -import com.nhaarman.mockitokotlin2.eq -import com.nhaarman.mockitokotlin2.times -import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.whenever +import org.mockito.kotlin.any +import org.mockito.kotlin.argumentCaptor +import org.mockito.kotlin.eq +import org.mockito.kotlin.times +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.Mock diff --git a/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt b/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt index 37ed0dc..b9a7846 100644 --- a/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt +++ b/src/test/kotlin/org/phoenixframework/WebSocketTransportTest.kt @@ -1,10 +1,10 @@ package org.phoenixframework import com.google.common.truth.Truth.assertThat -import com.nhaarman.mockitokotlin2.any -import com.nhaarman.mockitokotlin2.mock -import com.nhaarman.mockitokotlin2.verify -import com.nhaarman.mockitokotlin2.whenever +import org.mockito.kotlin.any +import org.mockito.kotlin.mock +import org.mockito.kotlin.verify +import org.mockito.kotlin.whenever import okhttp3.OkHttpClient import okhttp3.Response import okhttp3.WebSocket @@ -136,6 +136,7 @@ class WebSocketTransportTest { transport.readyState = Transport.ReadyState.OPEN transport.onClosing(mockWebSocket, 10, "reason") + verify(mockWebSocket).close(10, "reason") assertThat(transport.readyState).isEqualTo(Transport.ReadyState.CLOSING) }