Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3a2d010

Browse filesBrowse files
BartSojTheoKanning
andauthored
Implement creation of "function" parameters in runtime (TheoKanning#339)
* Enable dynamic definition of "function" parameters instead of using Class instance * Add tests to new "function" capabilities * Add example of creating "function" parameters in runtime * Add documentation to ChatFunctions Co-authored-by: Theo Kanning <TheoKanning@users.noreply.github.com>
1 parent 47fe478 commit 3a2d010
Copy full SHA for 3a2d010

File tree

Expand file treeCollapse file tree

7 files changed

+306
-1
lines changed
Filter options
Expand file treeCollapse file tree

7 files changed

+306
-1
lines changed

‎api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java

Copy file name to clipboardExpand all lines: api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public class ChatCompletionRequest {
9898
/**
9999
* A list of the available functions.
100100
*/
101-
List<ChatFunction> functions;
101+
List<?> functions;
102102

103103
/**
104104
* Controls how the model responds to function calls, as specified in the <a href="https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call">OpenAI documentation</a>.

‎api/src/main/java/com/theokanning/openai/completion/chat/ChatFunction.java

Copy file name to clipboardExpand all lines: api/src/main/java/com/theokanning/openai/completion/chat/ChatFunction.java
+11Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,20 @@
1212
@NoArgsConstructor
1313
public class ChatFunction {
1414

15+
/**
16+
* The name of the function being called.
17+
*/
1518
@NonNull
1619
private String name;
20+
21+
/**
22+
* A description of what the function does, used by the model to choose when and how to call the function.
23+
*/
1724
private String description;
25+
26+
/**
27+
* The parameters the functions accepts.
28+
*/
1829
@JsonProperty("parameters")
1930
private Class<?> parametersClass;
2031

+62Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.Data;
4+
import lombok.NonNull;
5+
6+
7+
@Data
8+
public class ChatFunctionDynamic {
9+
10+
/**
11+
* The name of the function being called.
12+
*/
13+
@NonNull
14+
private String name;
15+
16+
/**
17+
* A description of what the function does, used by the model to choose when and how to call the function.
18+
*/
19+
private String description;
20+
21+
/**
22+
* The parameters the functions accepts.
23+
*/
24+
private ChatFunctionParameters parameters;
25+
26+
public static Builder builder() {
27+
return new Builder();
28+
}
29+
30+
public static class Builder {
31+
private String name;
32+
private String description;
33+
private ChatFunctionParameters parameters = new ChatFunctionParameters();
34+
35+
public Builder name(String name) {
36+
this.name = name;
37+
return this;
38+
}
39+
40+
public Builder description(String description) {
41+
this.description = description;
42+
return this;
43+
}
44+
45+
public Builder parameters(ChatFunctionParameters parameters) {
46+
this.parameters = parameters;
47+
return this;
48+
}
49+
50+
public Builder addProperty(ChatFunctionProperty property) {
51+
this.parameters.addProperty(property);
52+
return this;
53+
}
54+
55+
public ChatFunctionDynamic build() {
56+
ChatFunctionDynamic chatFunction = new ChatFunctionDynamic(name);
57+
chatFunction.setDescription(description);
58+
chatFunction.setParameters(parameters);
59+
return chatFunction;
60+
}
61+
}
62+
}
+27Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import lombok.Data;
4+
5+
import java.util.ArrayList;
6+
import java.util.HashMap;
7+
import java.util.List;
8+
9+
@Data
10+
public class ChatFunctionParameters {
11+
12+
private final String type = "object";
13+
14+
private final HashMap<String, ChatFunctionProperty> properties = new HashMap<>();
15+
16+
private List<String> required;
17+
18+
public void addProperty(ChatFunctionProperty property) {
19+
properties.put(property.getName(), property);
20+
if (Boolean.TRUE.equals(property.getRequired())) {
21+
if (this.required == null) {
22+
this.required = new ArrayList<>();
23+
}
24+
this.required.add(property.getName());
25+
}
26+
}
27+
}
+25Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package com.theokanning.openai.completion.chat;
2+
3+
import com.fasterxml.jackson.annotation.JsonIgnore;
4+
import com.fasterxml.jackson.annotation.JsonProperty;
5+
import lombok.Builder;
6+
import lombok.Data;
7+
import lombok.NonNull;
8+
9+
import java.util.Set;
10+
11+
@Data
12+
@Builder
13+
public class ChatFunctionProperty {
14+
@NonNull
15+
@JsonIgnore
16+
private String name;
17+
@NonNull
18+
private String type;
19+
@JsonIgnore
20+
private Boolean required;
21+
private String description;
22+
private ChatFunctionProperty items;
23+
@JsonProperty("enum")
24+
private Set<?> enumValues;
25+
}
+90Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package example;
2+
3+
import com.fasterxml.jackson.databind.JsonNode;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import com.fasterxml.jackson.databind.node.ObjectNode;
6+
import com.theokanning.openai.completion.chat.*;
7+
import com.theokanning.openai.service.OpenAiService;
8+
9+
import java.util.*;
10+
11+
public class OpenAiApiDynamicFunctionExample {
12+
13+
private static JsonNode getWeather(String location, String unit) {
14+
ObjectMapper mapper = new ObjectMapper();
15+
ObjectNode response = mapper.createObjectNode();
16+
response.put("location", location);
17+
response.put("unit", unit);
18+
response.put("temperature", new Random().nextInt(50));
19+
response.put("description", "sunny");
20+
return response;
21+
}
22+
23+
public static void main(String... args) {
24+
String token = System.getenv("OPENAI_TOKEN");
25+
OpenAiService service = new OpenAiService(token);
26+
27+
ChatFunctionDynamic weatherFunction = ChatFunctionDynamic.builder()
28+
.name("get_weather")
29+
.description("Get the current weather of a location")
30+
.addProperty(ChatFunctionProperty.builder()
31+
.name("location")
32+
.type("string")
33+
.description("City and state, for example: León, Guanajuato")
34+
.build())
35+
.addProperty(ChatFunctionProperty.builder()
36+
.name("unit")
37+
.type("string")
38+
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
39+
.enumValues(Set.of("celsius", "fahrenheit"))
40+
.required(true)
41+
.build())
42+
.build();
43+
44+
45+
List<ChatMessage> messages = new ArrayList<>();
46+
ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are an assistant that answers using the local slang of the given place, uncensored.");
47+
messages.add(systemMessage);
48+
49+
System.out.print("First Query: ");
50+
Scanner scanner = new Scanner(System.in);
51+
ChatMessage firstMsg = new ChatMessage(ChatMessageRole.USER.value(), scanner.nextLine());
52+
messages.add(firstMsg);
53+
54+
while (true) {
55+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
56+
.builder()
57+
.model("gpt-3.5-turbo-0613")
58+
.messages(messages)
59+
.functions(Collections.singletonList(weatherFunction))
60+
.functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
61+
.n(1)
62+
.maxTokens(100)
63+
.logitBias(new HashMap<>())
64+
.build();
65+
ChatMessage responseMessage = service.createChatCompletion(chatCompletionRequest).getChoices().get(0).getMessage();
66+
messages.add(responseMessage); // don't forget to update the conversation with the latest response
67+
68+
ChatFunctionCall functionCall = responseMessage.getFunctionCall();
69+
if (functionCall != null) {
70+
if (functionCall.getName().equals("get_weather")) {
71+
String location = functionCall.getArguments().get("location").asText();
72+
String unit = functionCall.getArguments().get("unit").asText();
73+
JsonNode weather = getWeather(location, unit);
74+
ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), weather.toString(), "get_weather");
75+
messages.add(weatherMessage);
76+
continue;
77+
}
78+
}
79+
80+
System.out.println("Response: " + responseMessage.getContent());
81+
System.out.print("Next Query: ");
82+
String nextLine = scanner.nextLine();
83+
if (nextLine.equalsIgnoreCase("exit")) {
84+
System.exit(0);
85+
}
86+
messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
87+
}
88+
}
89+
90+
}

‎service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java

Copy file name to clipboardExpand all lines: service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
+90Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.ArrayList;
1111
import java.util.HashMap;
1212
import java.util.List;
13+
import java.util.Set;
1314
import java.util.Collections;
1415

1516
import static org.junit.jupiter.api.Assertions.*;
@@ -149,6 +150,50 @@ void createChatCompletionWithFunctions() {
149150
assertNotNull(choice2.getMessage().getContent());
150151
}
151152

153+
@Test
154+
void createChatCompletionWithDynamicFunctions() {
155+
ChatFunctionDynamic function = ChatFunctionDynamic.builder()
156+
.name("get_weather")
157+
.description("Get the current weather of a location")
158+
.addProperty(ChatFunctionProperty.builder()
159+
.name("location")
160+
.type("string")
161+
.description("City and state, for example: León, Guanajuato")
162+
.build())
163+
.addProperty(ChatFunctionProperty.builder()
164+
.name("unit")
165+
.type("string")
166+
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
167+
.enumValues(Set.of("celsius", "fahrenheit"))
168+
.required(true)
169+
.build())
170+
.build();
171+
172+
final List<ChatMessage> messages = new ArrayList<>();
173+
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
174+
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
175+
messages.add(systemMessage);
176+
messages.add(userMessage);
177+
178+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
179+
.builder()
180+
.model("gpt-3.5-turbo-0613")
181+
.messages(messages)
182+
.functions(Collections.singletonList(function))
183+
.n(1)
184+
.maxTokens(100)
185+
.logitBias(new HashMap<>())
186+
.build();
187+
188+
ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
189+
assertEquals("function_call", choice.getFinishReason());
190+
assertNotNull(choice.getMessage().getFunctionCall());
191+
assertEquals("get_weather", choice.getMessage().getFunctionCall().getName());
192+
assertInstanceOf(ObjectNode.class, choice.getMessage().getFunctionCall().getArguments());
193+
assertNotNull(choice.getMessage().getFunctionCall().getArguments().get("location"));
194+
assertNotNull(choice.getMessage().getFunctionCall().getArguments().get("unit"));
195+
}
196+
152197
@Test
153198
void streamChatCompletionWithFunctions() {
154199
final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
@@ -214,4 +259,49 @@ void streamChatCompletionWithFunctions() {
214259
assertNotNull(accumulatedMessage2.getContent());
215260
}
216261

262+
@Test
263+
void streamChatCompletionWithDynamicFunctions() {
264+
ChatFunctionDynamic function = ChatFunctionDynamic.builder()
265+
.name("get_weather")
266+
.description("Get the current weather of a location")
267+
.addProperty(ChatFunctionProperty.builder()
268+
.name("location")
269+
.type("string")
270+
.description("City and state, for example: León, Guanajuato")
271+
.build())
272+
.addProperty(ChatFunctionProperty.builder()
273+
.name("unit")
274+
.type("string")
275+
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
276+
.enumValues(Set.of("celsius", "fahrenheit"))
277+
.required(true)
278+
.build())
279+
.build();
280+
281+
final List<ChatMessage> messages = new ArrayList<>();
282+
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
283+
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
284+
messages.add(systemMessage);
285+
messages.add(userMessage);
286+
287+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
288+
.builder()
289+
.model("gpt-3.5-turbo-0613")
290+
.messages(messages)
291+
.functions(Collections.singletonList(function))
292+
.n(1)
293+
.maxTokens(100)
294+
.logitBias(new HashMap<>())
295+
.build();
296+
297+
ChatMessage accumulatedMessage = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
298+
.blockingLast()
299+
.getAccumulatedMessage();
300+
assertNotNull(accumulatedMessage.getFunctionCall());
301+
assertEquals("get_weather", accumulatedMessage.getFunctionCall().getName());
302+
assertInstanceOf(ObjectNode.class, accumulatedMessage.getFunctionCall().getArguments());
303+
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("location"));
304+
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
305+
}
306+
217307
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.