Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 103c34d

Browse files
authored
Add classification support (TheoKanning#11)
1 parent ed2f115 commit 103c34d
Copy full SHA for 103c34d

File tree

Expand file tree / Collapse file tree

9 files changed

+249
-6
lines changed
Filter options
Expand file tree / Collapse file tree

9 files changed

+249
-6
lines changed

‎.github/workflows/test.yml

Copy file name to clipboardExpand all lines: .github/workflows/test.yml
+1-1Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: Publish
1+
name: Test
22

33
on:
44
push:
+120Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
package com.theokanning.openai.classification;
2+
3+
import lombok.*;
4+
5+
import java.util.List;
6+
import java.util.Map;
7+
8+
/**
9+
* A request for OpenAi to classify text based on provided examples.
10+
* All fields are nullable except model and query, which are required.
11+
*
12+
* Documentation taken from
13+
* https://beta.openai.com/docs/api-reference/classifications/create
14+
*/
15+
@Builder
16+
@NoArgsConstructor
17+
@AllArgsConstructor
18+
@Data
19+
public class ClassificationRequest {
20+
21+
/**
22+
* ID of the engine to use for completion
23+
*/
24+
@NonNull
25+
String model;
26+
27+
/**
28+
* Query to be classified
29+
*/
30+
@NonNull
31+
String query;
32+
33+
/**
34+
* A list of examples with labels, in the following format:
35+
*
36+
* [["The movie is so interesting.", "Positive"], ["It is quite boring.", "Negative"], ...]
37+
*
38+
* All the label strings will be normalized to be capitalized.
39+
*
40+
* You should specify either examples or file, but not both.
41+
*/
42+
List<List<String>> examples;
43+
44+
/**
45+
* The ID of the uploaded file that contains training examples.
46+
* See upload file for how to upload a file of the desired format and purpose.
47+
*
48+
* You should specify either examples or file, but not both.
49+
*/
50+
String file;
51+
52+
/**
53+
* The set of categories being classified.
54+
* If not specified, candidate labels will be automatically collected from the examples you provide.
55+
* All the label strings will be normalized to be capitalized.
56+
*/
57+
List<String> labels;
58+
59+
/**
60+
* ID of the engine to use for Search. You can select one of ada, babbage, curie, or davinci.
61+
*/
62+
String searchModel;
63+
64+
/**
65+
* What sampling temperature to use. Higher values mean the model will take more risks.
66+
* Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
67+
*
68+
* We generally recommend using this or {@code top_p} but not both.
69+
*/
70+
Double temperature;
71+
72+
/**
73+
* Include the log probabilities on the logprobs most likely tokens, as well as the chosen tokens.
74+
* For example, if logprobs is 10, the API will return a list of the 10 most likely tokens.
75+
* The API will always return the logprob of the sampled token,
76+
* so there may be up to logprobs+1 elements in the response.
77+
*/
78+
Integer logprobs;
79+
80+
/**
81+
* The maximum number of examples to be ranked by Search when using file.
82+
* Setting it to a higher value leads to improved accuracy but with increased latency and cost.
83+
*/
84+
Integer maxExamples;
85+
86+
/**
87+
* Modify the likelihood of specified tokens appearing in the completion.
88+
*
89+
* Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an
90+
* associated bias value from -100 to 100.
91+
*/
92+
Map<String, Double> logitBias;
93+
94+
/**
95+
* If set to true, the returned JSON will include a "prompt" field containing the final prompt that was
96+
* used to request a completion. This is mainly useful for debugging purposes.
97+
*/
98+
Boolean returnPrompt;
99+
100+
/**
101+
* A special boolean flag for showing metadata.
102+
* If set to true, each document entry in the returned JSON will contain a "metadata" field.
103+
*
104+
* This flag only takes effect when file is set.
105+
*/
106+
Boolean returnMetadata;
107+
108+
/**
109+
* If an object name is in the list, we provide the full information of the object;
110+
* otherwise, we only provide the object ID.
111+
*
112+
* Currently we support completion and file objects for expansion.
113+
*/
114+
List<String> expand;
115+
116+
/**
117+
* A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
118+
*/
119+
String user;
120+
}
+45Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.theokanning.openai.classification;
2+
3+
import com.theokanning.openai.completion.CompletionChoice;
4+
import lombok.Data;
5+
6+
import java.util.List;
7+
8+
/**
9+
* An object containing a response from the classification api
10+
* <p>
11+
* https://beta.openai.com/docs/api-reference/classifications/create
12+
*/
13+
@Data
14+
public class ClassificationResult {
15+
16+
/**
17+
* A unique id assigned to this completion
18+
*/
19+
String completion;
20+
21+
/**
22+
* The predicted label for the query text.
23+
*/
24+
String label;
25+
26+
/**
27+
* The GPT-3 model used for completion
28+
*/
29+
String model;
30+
31+
/**
32+
* The type of object returned, should be "classification"
33+
*/
34+
String object;
35+
36+
/**
37+
* The GPT-3 model used for search
38+
*/
39+
String searchModel;
40+
41+
/**
42+
* A list of the most relevant examples for the query text.
43+
*/
44+
List<Example> selectedExamples;
45+
}
+26Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.theokanning.openai.classification;
2+
3+
import lombok.Data;
4+
5+
/**
6+
* Represents an example returned by the classification api
7+
*
8+
* https://beta.openai.com/docs/api-reference/classifications/create
9+
*/
10+
@Data
11+
public class Example {
12+
/**
13+
* The position of this example in the example list
14+
*/
15+
Integer document;
16+
17+
/**
18+
* The label of the example
19+
*/
20+
String label;
21+
22+
/**
23+
* The text of the example
24+
*/
25+
String text;
26+
}

‎api/src/main/java/com/theokanning/openai/finetune/FineTuneRequest.java

Copy file name to clipboardExpand all lines: api/src/main/java/com/theokanning/openai/finetune/FineTuneRequest.java
+2-4Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package com.theokanning.openai.finetune;
22

3-
import lombok.AllArgsConstructor;
4-
import lombok.Builder;
5-
import lombok.Data;
6-
import lombok.NoArgsConstructor;
3+
import lombok.*;
74

85
import java.util.List;
96

@@ -22,6 +19,7 @@ public class FineTuneRequest {
2219
/**
2320
* The ID of an uploaded file that contains training data.
2421
*/
22+
@NonNull
2523
String trainingFile;
2624

2725
/**

‎client/src/main/java/com/theokanning/openai/OpenAiApi.java

Copy file name to clipboardExpand all lines: client/src/main/java/com/theokanning/openai/OpenAiApi.java
+5Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package com.theokanning.openai;
22

3+
import com.theokanning.openai.classification.ClassificationRequest;
4+
import com.theokanning.openai.classification.ClassificationResult;
35
import com.theokanning.openai.engine.Engine;
46
import com.theokanning.openai.file.File;
57
import com.theokanning.openai.finetune.FineTuneRequest;
@@ -28,6 +30,9 @@ public interface OpenAiApi {
2830
@POST("/v1/engines/{engine_id}/search")
2931
Single<OpenAiResponse<SearchResult>> search(@Path("engine_id") String engineId, @Body SearchRequest request);
3032

33+
@POST("v1/classifications")
34+
Single<ClassificationResult> createClassification(@Body ClassificationRequest request);
35+
3136
@GET("/v1/files")
3237
Single<OpenAiResponse<File>> listFiles();
3338

‎client/src/main/java/com/theokanning/openai/OpenAiService.java

Copy file name to clipboardExpand all lines: client/src/main/java/com/theokanning/openai/OpenAiService.java
+6Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import com.fasterxml.jackson.databind.DeserializationFeature;
55
import com.fasterxml.jackson.databind.ObjectMapper;
66
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
7+
import com.theokanning.openai.classification.ClassificationRequest;
8+
import com.theokanning.openai.classification.ClassificationResult;
79
import com.theokanning.openai.file.File;
810
import com.theokanning.openai.finetune.FineTuneRequest;
911
import com.theokanning.openai.finetune.FineTuneEvent;
@@ -62,6 +64,10 @@ public List<SearchResult> search(String engineId, SearchRequest request) {
6264
return api.search(engineId, request).blockingGet().data;
6365
}
6466

67+
public ClassificationResult createClassification(ClassificationRequest request) {
68+
return api.createClassification(request).blockingGet();
69+
}
70+
6571
public List<File> listFiles() {
6672
return api.listFiles().blockingGet().data;
6773
}
+39Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package com.theokanning.openai;
2+
3+
import com.theokanning.openai.classification.ClassificationRequest;
4+
import com.theokanning.openai.classification.ClassificationResult;
5+
import com.theokanning.openai.completion.CompletionChoice;
6+
import com.theokanning.openai.completion.CompletionRequest;
7+
import org.junit.jupiter.api.Test;
8+
9+
import java.util.Arrays;
10+
import java.util.List;
11+
12+
import static org.junit.jupiter.api.Assertions.assertFalse;
13+
import static org.junit.jupiter.api.Assertions.assertNotNull;
14+
15+
16+
public class ClassificationTest {
17+
18+
String token = System.getenv("OPENAI_TOKEN");
19+
OpenAiService service = new OpenAiService(token);
20+
21+
@Test
22+
void createCompletion() {
23+
ClassificationRequest classificationRequest = ClassificationRequest.builder()
24+
.examples(Arrays.asList(
25+
Arrays.asList("A happy moment", "Positive"),
26+
Arrays.asList("I am sad.", "Negative"),
27+
Arrays.asList("I am feeling awesome", "Positive")
28+
))
29+
.query("It is a raining day :(")
30+
.model("curie")
31+
.searchModel("ada")
32+
.labels(Arrays.asList("Positive", "Negative", "Neutral"))
33+
.build();
34+
35+
ClassificationResult result = service.createClassification(classificationRequest);
36+
37+
assertNotNull(result.getCompletion());
38+
}
39+
}

‎client/src/test/java/com/theokanning/openai/FineTuneTest.java

Copy file name to clipboardExpand all lines: client/src/test/java/com/theokanning/openai/FineTuneTest.java
+5-1Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.junit.jupiter.api.*;
77

88
import java.util.List;
9+
import java.util.concurrent.TimeUnit;
910

1011
import static org.junit.jupiter.api.Assertions.*;
1112

@@ -17,10 +18,13 @@ public class FineTuneTest {
1718

1819

1920
@BeforeAll
20-
static void setup() {
21+
static void setup() throws Exception {
2122
String token = System.getenv("OPENAI_TOKEN");
2223
service = new OpenAiService(token);
2324
fileId = service.uploadFile("fine-tune", "src/test/resources/fine-tuning-data.jsonl").getId();
25+
26+
// wait for file to be processed
27+
TimeUnit.SECONDS.sleep(10);
2428
}
2529

2630
@AfterAll

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.