Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 7dc5b5b

Browse filesBrowse files
authored
Streaming support (TheoKanning#195)
Utilize retrofit2.http.Streaming and retrofit2.Call<ResponseBody> in additional OpenAiApi methods to enable a streamable ResponseBody. Utilize retrofit2.Callback to get the streamable ResponseBody, parse Server-Sent Events (SSE), and emit them using io.reactivex.FlowableEmitter.
Enable:
- Streaming of raw bytes
- Streaming of Java objects
- Shutdown of the OkHttp ExecutorService
Fixes: TheoKanning#51, TheoKanning#83, TheoKanning#182, TheoKanning#184
1 parent a44d79b commit 7dc5b5b
Copy full SHA for 7dc5b5b

File tree

Expand file treeCollapse file tree

9 files changed

+403
-8
lines changed
Filter options
Expand file treeCollapse file tree

9 files changed

+403
-8
lines changed
+37Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.theokanning.openai.completion;
2+
3+
import lombok.Data;
4+
import java.util.List;
5+
6+
/**
7+
* Object containing a response chunk from the completions streaming api.
8+
*
9+
* https://beta.openai.com/docs/api-reference/completions/create
10+
*/
11+
@Data
12+
public class CompletionChunk {
13+
/**
14+
* A unique id assigned to this completion.
15+
*/
16+
String id;
17+
18+
/**
19+
* The type of object returned, should be "text_completion"; see https://beta.openai.com/docs/api-reference/create-completion
20+
*/
21+
String object;
22+
23+
/**
24+
* The creation time in epoch seconds.
25+
*/
26+
long created;
27+
28+
/**
29+
* The GPT-3 model used.
30+
*/
31+
String model;
32+
33+
/**
34+
* A list of generated completions.
35+
*/
36+
List<CompletionChoice> choices;
37+
}

‎api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java

Copy file name to clipboardExpand all lines: api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionChoice.java
+3-1Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
package com.theokanning.openai.completion.chat;
2+
import com.fasterxml.jackson.annotation.JsonAlias;
23
import com.fasterxml.jackson.annotation.JsonProperty;
34
import lombok.Data;
45

@@ -14,8 +15,9 @@ public class ChatCompletionChoice {
1415
Integer index;
1516

1617
/**
17-
* The {@link ChatMessageRole#assistant} message which was generated.
18+
* The {@link ChatMessageRole#assistant} message, or the delta when streaming, that was generated.
1819
*/
20+
@JsonAlias("delta")
1921
ChatMessage message;
2022

2123
/**
+35Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.theokanning.openai.completion.chat;
2+
import lombok.Data;
3+
4+
import java.util.List;
5+
6+
/**
7+
* Object containing a response chunk from the chat completions streaming api.
8+
*/
9+
@Data
10+
public class ChatCompletionChunk {
11+
/**
12+
* Unique id assigned to this chat completion.
13+
*/
14+
String id;
15+
16+
/**
17+
* The type of object returned, should be "chat.completion.chunk"
18+
*/
19+
String object;
20+
21+
/**
22+
* The creation time in epoch seconds.
23+
*/
24+
long created;
25+
26+
/**
27+
* The GPT-3.5 model used.
28+
*/
29+
String model;
30+
31+
/**
32+
* A list of all generated completions.
33+
*/
34+
List<ChatCompletionChoice> choices;
35+
}

‎client/src/main/java/com/theokanning/openai/OpenAiApi.java

Copy file name to clipboardExpand all lines: client/src/main/java/com/theokanning/openai/OpenAiApi.java
+10Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import io.reactivex.Single;
2323
import okhttp3.MultipartBody;
2424
import okhttp3.RequestBody;
25+
import okhttp3.ResponseBody;
26+
import retrofit2.Call;
2527
import retrofit2.http.*;
2628

2729
public interface OpenAiApi {
@@ -34,10 +36,18 @@ public interface OpenAiApi {
3436

3537
@POST("/v1/completions")
3638
Single<CompletionResult> createCompletion(@Body CompletionRequest request);
39+
40+
@Streaming
41+
@POST("/v1/completions")
42+
Call<ResponseBody> createCompletionStream(@Body CompletionRequest request);
3743

3844
@POST("/v1/chat/completions")
3945
Single<ChatCompletionResult> createChatCompletion(@Body ChatCompletionRequest request);
4046

47+
@Streaming
48+
@POST("/v1/chat/completions")
49+
Call<ResponseBody> createChatCompletionStream(@Body ChatCompletionRequest request);
50+
4151
@Deprecated
4252
@POST("/v1/engines/{engine_id}/completions")
4353
Single<CompletionResult> createCompletion(@Path("engine_id") String engineId, @Body CompletionRequest request);
+79Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package example;
2+
3+
import com.theokanning.openai.service.OpenAiService;
4+
5+
import java.util.ArrayList;
6+
import java.util.HashMap;
7+
import java.util.List;
8+
9+
import com.theokanning.openai.completion.CompletionRequest;
10+
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
11+
import com.theokanning.openai.completion.chat.ChatMessage;
12+
import com.theokanning.openai.completion.chat.ChatMessageRole;
13+
14+
public class OpenAiApiStreamExample {
15+
public static void main(String... args) {
16+
String token = System.getenv("OPENAI_TOKEN");
17+
OpenAiService service = new OpenAiService(token);
18+
19+
System.out.println("\nCreating completion...");
20+
CompletionRequest completionRequest = CompletionRequest.builder()
21+
.model("ada")
22+
.prompt("Somebody once told me the world is gonna roll me")
23+
.echo(true)
24+
.user("testing")
25+
.n(3)
26+
.build();
27+
28+
/*
29+
Note: when using blockingForEach the calling Thread waits until the loop finishes.
30+
Use forEach instead of blockingForEach if you don't want the calling Thread to wait.
31+
*/
32+
33+
// stream raw bytes
34+
service
35+
.streamCompletionBytes(completionRequest)
36+
.doOnError( e -> {
37+
e.printStackTrace();
38+
})
39+
.blockingForEach( bytes -> {
40+
System.out.print(new String(bytes));
41+
});
42+
43+
// stream CompletionChunks
44+
service
45+
.streamCompletion(completionRequest)
46+
.doOnError( e -> {
47+
e.printStackTrace();
48+
})
49+
.blockingForEach(System.out::println);
50+
51+
52+
final List<ChatMessage> messages = new ArrayList<>();
53+
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such.");
54+
messages.add(systemMessage);
55+
56+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
57+
.builder()
58+
.model("gpt-3.5-turbo")
59+
.messages(messages)
60+
.n(5)
61+
.maxTokens(50)
62+
.logitBias(new HashMap<>())
63+
.build();
64+
65+
// stream ChatCompletionChunks
66+
service
67+
.streamChatCompletion(chatCompletionRequest)
68+
.doOnError( e -> {
69+
e.printStackTrace();
70+
})
71+
.blockingForEach(System.out::println);
72+
73+
/*
74+
* Shut down the OkHttp ExecutorService to
75+
* exit immediately after the loops have finished
76+
*/
77+
service.shutdownExecutor();
78+
}
79+
}

‎service/src/main/java/com/theokanning/openai/service/OpenAiService.java

Copy file name to clipboardExpand all lines: service/src/main/java/com/theokanning/openai/service/OpenAiService.java
+100-7Lines changed: 100 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
import com.theokanning.openai.OpenAiApi;
99
import com.theokanning.openai.OpenAiError;
1010
import com.theokanning.openai.OpenAiHttpException;
11+
import com.theokanning.openai.completion.CompletionChunk;
1112
import com.theokanning.openai.completion.CompletionRequest;
1213
import com.theokanning.openai.completion.CompletionResult;
14+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
1315
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
1416
import com.theokanning.openai.completion.chat.ChatCompletionResult;
1517
import com.theokanning.openai.edit.EditRequest;
@@ -27,17 +29,22 @@
2729
import com.theokanning.openai.model.Model;
2830
import com.theokanning.openai.moderation.ModerationRequest;
2931
import com.theokanning.openai.moderation.ModerationResult;
32+
33+
import io.reactivex.BackpressureStrategy;
34+
import io.reactivex.Flowable;
3035
import io.reactivex.Single;
3136
import okhttp3.*;
3237
import retrofit2.HttpException;
3338
import retrofit2.Retrofit;
39+
import retrofit2.Call;
3440
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
3541
import retrofit2.converter.jackson.JacksonConverterFactory;
3642

3743
import java.io.IOException;
3844
import java.time.Duration;
3945
import java.util.List;
4046
import java.util.Objects;
47+
import java.util.concurrent.ExecutorService;
4148
import java.util.concurrent.TimeUnit;
4249

4350
public class OpenAiService {
@@ -47,6 +54,7 @@ public class OpenAiService {
4754
private static final ObjectMapper errorMapper = defaultObjectMapper();
4855

4956
private final OpenAiApi api;
57+
private final ExecutorService executorService;
5058

5159
/**
5260
* Creates a new OpenAiService that wraps OpenAiApi
@@ -64,17 +72,29 @@ public OpenAiService(final String token) {
6472
* @param timeout http read timeout, Duration.ZERO means no timeout
6573
*/
6674
public OpenAiService(final String token, final Duration timeout) {
67-
this(buildApi(token, timeout));
75+
this(defaultClient(token, timeout));
76+
}
77+
78+
/**
79+
* Creates a new OpenAiService that wraps OpenAiApi
80+
*
81+
* @param client OkHttpClient to be used for api calls
82+
*/
83+
public OpenAiService(OkHttpClient client){
84+
this(buildApi(client), client.dispatcher().executorService());
6885
}
6986

7087
/**
7188
* Creates a new OpenAiService that wraps OpenAiApi.
72-
* Use this if you need more customization.
89+
* The ExecutorService must be the one you get from the client you created the api with,
90+
* otherwise shutdownExecutor() won't work. Use this if you need more customization.
7391
*
7492
* @param api OpenAiApi instance to use for all methods
93+
* @param executorService the ExecutorService from client.dispatcher().executorService()
7594
*/
76-
public OpenAiService(final OpenAiApi api) {
95+
public OpenAiService(final OpenAiApi api, final ExecutorService executorService) {
7796
this.api = api;
97+
this.executorService = executorService;
7898
}
7999

80100
public List<Model> listModels() {
@@ -88,11 +108,39 @@ public Model getModel(String modelId) {
88108
public CompletionResult createCompletion(CompletionRequest request) {
89109
return execute(api.createCompletion(request));
90110
}
111+
112+
public Flowable<byte[]> streamCompletionBytes(CompletionRequest request) {
113+
request.setStream(true);
114+
115+
return stream(api.createCompletionStream(request), true).map(sse -> {
116+
return sse.toBytes();
117+
});
118+
}
119+
120+
public Flowable<CompletionChunk> streamCompletion(CompletionRequest request) {
121+
request.setStream(true);
122+
123+
return stream(api.createCompletionStream(request), CompletionChunk.class);
124+
}
91125

92126
public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) {
93127
return execute(api.createChatCompletion(request));
94128
}
95129

130+
public Flowable<byte[]> streamChatCompletionBytes(ChatCompletionRequest request) {
131+
request.setStream(true);
132+
133+
return stream(api.createChatCompletionStream(request), true).map(sse -> {
134+
return sse.toBytes();
135+
});
136+
}
137+
138+
public Flowable<ChatCompletionChunk> streamChatCompletion(ChatCompletionRequest request) {
139+
request.setStream(true);
140+
141+
return stream(api.createChatCompletionStream(request), ChatCompletionChunk.class);
142+
}
143+
96144
public EditResult createEdit(EditRequest request) {
97145
return execute(api.createEdit(request));
98146
}
@@ -232,12 +280,55 @@ public static <T> T execute(Single<T> apiCall) {
232280
}
233281
}
234282

235-
public static OpenAiApi buildApi(String token, Duration timeout) {
236-
Objects.requireNonNull(token, "OpenAI token required");
283+
/**
284+
* Calls the Open AI api and returns a Flowable of SSE for streaming
285+
* omitting the last message.
286+
*
287+
* @param apiCall The api call
288+
*/
289+
public static Flowable<SSE> stream(Call<ResponseBody> apiCall) {
290+
return stream(apiCall, false);
291+
}
292+
293+
/**
294+
* Calls the Open AI api and returns a Flowable of SSE for streaming.
295+
*
296+
* @param apiCall The api call
297+
* @param emitDone If true the last message ([DONE]) is emitted
298+
*/
299+
public static Flowable<SSE> stream(Call<ResponseBody> apiCall, boolean emitDone) {
300+
return Flowable.create(emitter -> {
301+
apiCall.enqueue(new ResponseBodyCallback(emitter, emitDone));
302+
}, BackpressureStrategy.BUFFER);
303+
}
304+
305+
/**
306+
* Calls the Open AI api and returns a Flowable of type T for streaming
307+
* omitting the last message.
308+
*
309+
* @param apiCall The api call
310+
* @param cl Class of type T to return
311+
*/
312+
public static <T> Flowable<T> stream(Call<ResponseBody> apiCall, Class<T> cl) {
313+
return stream(apiCall).map(sse -> {
314+
return errorMapper.readValue(sse.getData(), cl);
315+
});
316+
}
317+
318+
/**
319+
* Shuts down the OkHttp ExecutorService.
320+
* The default behaviour of OkHttp's ExecutorService (ConnectionPool)
321+
* is to shut down after an idle timeout of 60s.
322+
* Call this method to shut down the ExecutorService immediately.
323+
*/
324+
public void shutdownExecutor(){
325+
this.executorService.shutdown();
326+
}
327+
328+
public static OpenAiApi buildApi(OkHttpClient client) {
237329
ObjectMapper mapper = defaultObjectMapper();
238-
OkHttpClient client = defaultClient(token, timeout);
239330
Retrofit retrofit = defaultRetrofit(client, mapper);
240-
331+
241332
return retrofit.create(OpenAiApi.class);
242333
}
243334

@@ -250,6 +341,8 @@ public static ObjectMapper defaultObjectMapper() {
250341
}
251342

252343
public static OkHttpClient defaultClient(String token, Duration timeout) {
344+
Objects.requireNonNull(token, "OpenAI token required");
345+
253346
return new OkHttpClient.Builder()
254347
.addInterceptor(new AuthenticationInterceptor(token))
255348
.connectionPool(new ConnectionPool(5, 1, TimeUnit.SECONDS))

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.