|
| 1 | +package com.theokanning.openai.service; |
| 2 | + |
| 3 | +import com.fasterxml.jackson.annotation.JsonInclude; |
| 4 | +import com.fasterxml.jackson.core.JsonProcessingException; |
| 5 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 6 | +import com.fasterxml.jackson.databind.DeserializationFeature; |
| 7 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 8 | +import com.fasterxml.jackson.databind.PropertyNamingStrategy; |
| 9 | +import com.theokanning.openai.ListSearchParameters; |
| 10 | +import com.theokanning.openai.OpenAiResponse; |
| 11 | +import com.theokanning.openai.assistants.Assistant; |
| 12 | +import com.theokanning.openai.assistants.AssistantFunction; |
| 13 | +import com.theokanning.openai.assistants.AssistantRequest; |
| 14 | +import com.theokanning.openai.assistants.AssistantToolsEnum; |
| 15 | +import com.theokanning.openai.assistants.Tool; |
| 16 | +import com.theokanning.openai.completion.chat.ChatCompletionRequest; |
| 17 | +import com.theokanning.openai.completion.chat.ChatFunction; |
| 18 | +import com.theokanning.openai.completion.chat.ChatFunctionCall; |
| 19 | +import com.theokanning.openai.messages.Message; |
| 20 | +import com.theokanning.openai.messages.MessageRequest; |
| 21 | +import com.theokanning.openai.runs.RequiredAction; |
| 22 | +import com.theokanning.openai.runs.Run; |
| 23 | +import com.theokanning.openai.runs.RunCreateRequest; |
| 24 | +import com.theokanning.openai.runs.RunStep; |
| 25 | +import com.theokanning.openai.runs.SubmitToolOutputRequestItem; |
| 26 | +import com.theokanning.openai.runs.SubmitToolOutputs; |
| 27 | +import com.theokanning.openai.runs.SubmitToolOutputsRequest; |
| 28 | +import com.theokanning.openai.runs.ToolCall; |
| 29 | +import com.theokanning.openai.threads.Thread; |
| 30 | +import com.theokanning.openai.threads.ThreadRequest; |
| 31 | +import com.theokanning.openai.utils.TikTokensUtil; |
| 32 | +import org.junit.jupiter.api.Test; |
| 33 | + |
| 34 | +import java.time.Duration; |
| 35 | +import java.util.ArrayList; |
| 36 | +import java.util.List; |
| 37 | +import java.util.Map; |
| 38 | +import java.util.Objects; |
| 39 | + |
| 40 | +import static org.junit.jupiter.api.Assertions.assertEquals; |
| 41 | +import static org.junit.jupiter.api.Assertions.assertNotNull; |
| 42 | + |
| 43 | +class AssistantFunctionTest { |
| 44 | + String token = System.getenv("OPENAI_TOKEN"); |
| 45 | + OpenAiService service = new OpenAiService(token, Duration.ofMinutes(1)); |
| 46 | + |
| 47 | + @Test |
| 48 | + void createRetrieveRun() throws JsonProcessingException { |
| 49 | + |
| 50 | + ObjectMapper mapper = new ObjectMapper(); |
| 51 | + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); |
| 52 | + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); |
| 53 | + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); |
| 54 | + mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class); |
| 55 | + mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class); |
| 56 | + mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class); |
| 57 | + |
| 58 | + String funcDef = "{\n" + |
| 59 | + " \"type\": \"object\",\n" + |
| 60 | + " \"properties\": {\n" + |
| 61 | + " \"location\": {\n" + |
| 62 | + " \"type\": \"string\",\n" + |
| 63 | + " \"description\": \"The city and state, e.g. San Francisco, CA\"\n" + |
| 64 | + " },\n" + |
| 65 | + " \"unit\": {\n" + |
| 66 | + " \"type\": \"string\",\n" + |
| 67 | + " \"enum\": [\"celsius\", \"fahrenheit\"]\n" + |
| 68 | + " }\n" + |
| 69 | + " },\n" + |
| 70 | + " \"required\": [\"location\"]\n" + |
| 71 | + "}"; |
| 72 | + Map<String, Object> funcParameters = mapper.readValue(funcDef, new TypeReference<Map<String, Object>>() {}); |
| 73 | + AssistantFunction function = AssistantFunction.builder() |
| 74 | + .name("weather_reporter") |
| 75 | + .description("Get the current weather of a location") |
| 76 | + .parameters(funcParameters) |
| 77 | + .build(); |
| 78 | + |
| 79 | + List<Tool> toolList = new ArrayList<>(); |
| 80 | + Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function); |
| 81 | + toolList.add(funcTool); |
| 82 | + |
| 83 | + |
| 84 | + AssistantRequest assistantRequest = AssistantRequest.builder() |
| 85 | + .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) |
| 86 | + .name("MATH_TUTOR") |
| 87 | + .instructions("You are a personal Math Tutor.") |
| 88 | + .tools(toolList) |
| 89 | + .build(); |
| 90 | + Assistant assistant = service.createAssistant(assistantRequest); |
| 91 | + |
| 92 | + ThreadRequest threadRequest = ThreadRequest.builder() |
| 93 | + .build(); |
| 94 | + Thread thread = service.createThread(threadRequest); |
| 95 | + |
| 96 | + MessageRequest messageRequest = MessageRequest.builder() |
| 97 | + .content("What's the weather of Xiamen?") |
| 98 | + .build(); |
| 99 | + |
| 100 | + Message message = service.createMessage(thread.getId(), messageRequest); |
| 101 | + |
| 102 | + RunCreateRequest runCreateRequest = RunCreateRequest.builder() |
| 103 | + .assistantId(assistant.getId()) |
| 104 | + .build(); |
| 105 | + |
| 106 | + Run run = service.createRun(thread.getId(), runCreateRequest); |
| 107 | + assertNotNull(run); |
| 108 | + |
| 109 | + Run retrievedRun = service.retrieveRun(thread.getId(), run.getId()); |
| 110 | + while (!(retrievedRun.getStatus().equals("completed")) |
| 111 | + && !(retrievedRun.getStatus().equals("failed")) |
| 112 | + && !(retrievedRun.getStatus().equals("requires_action"))){ |
| 113 | + retrievedRun = service.retrieveRun(thread.getId(), run.getId()); |
| 114 | + } |
| 115 | + if (retrievedRun.getStatus().equals("requires_action")) { |
| 116 | + RequiredAction requiredAction = retrievedRun.getRequiredAction(); |
| 117 | + System.out.println("requiredAction"); |
| 118 | + System.out.println(mapper.writeValueAsString(requiredAction)); |
| 119 | + List<ToolCall> toolCalls = requiredAction.getSubmitToolOutputs().getToolCalls(); |
| 120 | + ToolCall toolCall = toolCalls.get(0); |
| 121 | + String toolCallId = toolCall.getId(); |
| 122 | + |
| 123 | + SubmitToolOutputRequestItem toolOutputRequestItem = SubmitToolOutputRequestItem.builder() |
| 124 | + .toolCallId(toolCallId) |
| 125 | + .output("sunny") |
| 126 | + .build(); |
| 127 | + List<SubmitToolOutputRequestItem> toolOutputRequestItems = new ArrayList<>(); |
| 128 | + toolOutputRequestItems.add(toolOutputRequestItem); |
| 129 | + SubmitToolOutputsRequest submitToolOutputsRequest = SubmitToolOutputsRequest.builder() |
| 130 | + .toolOutputs(toolOutputRequestItems) |
| 131 | + .build(); |
| 132 | + retrievedRun = service.submitToolOutputs(retrievedRun.getThreadId(), retrievedRun.getId(), submitToolOutputsRequest); |
| 133 | + |
| 134 | + while (!(retrievedRun.getStatus().equals("completed")) |
| 135 | + && !(retrievedRun.getStatus().equals("failed")) |
| 136 | + && !(retrievedRun.getStatus().equals("requires_action"))){ |
| 137 | + retrievedRun = service.retrieveRun(thread.getId(), run.getId()); |
| 138 | + } |
| 139 | + |
| 140 | + OpenAiResponse<Message> response = service.listMessages(thread.getId()); |
| 141 | + |
| 142 | + List<Message> messages = response.getData(); |
| 143 | + |
| 144 | + System.out.println(mapper.writeValueAsString(messages)); |
| 145 | + |
| 146 | + } |
| 147 | + } |
| 148 | +} |
0 commit comments