Skip to content

Commit 38a2e97

Browse files
authored
Merge pull request #13 from launchableinc/update-lib-versions
Update lib versions
2 parents 744de95 + fb1f268 commit 38a2e97

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

api/src/main/java/com/launchableinc/openai/utils/TikTokensUtil.java

+16-6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.knuddels.jtokkit.api.Encoding;
55
import com.knuddels.jtokkit.api.EncodingRegistry;
66
import com.knuddels.jtokkit.api.EncodingType;
7+
import com.knuddels.jtokkit.api.IntArrayList;
78
import com.knuddels.jtokkit.api.ModelType;
89
import com.launchableinc.openai.completion.chat.ChatMessage;
910
import lombok.AllArgsConstructor;
@@ -46,7 +47,7 @@ public class TikTokensUtil {
4647
* @return Encoding array
4748
*/
4849
public static List<Integer> encode(Encoding enc, String text) {
49-
return isBlank(text) ? new ArrayList<>() : enc.encode(text);
50+
return isBlank(text) ? new ArrayList<>() : enc.encode(text).boxed();
5051
}
5152

5253
/**
@@ -69,7 +70,7 @@ public static int tokens(Encoding enc, String text) {
6970
* @return Text information corresponding to the encoding array.
7071
*/
7172
public static String decode(Encoding enc, List<Integer> encoded) {
72-
return enc.decode(encoded);
73+
return enc.decode(toIntArrayList(encoded));
7374
}
7475

7576
/**
@@ -94,7 +95,7 @@ public static List<Integer> encode(EncodingType encodingType, String text) {
9495
return new ArrayList<>();
9596
}
9697
Encoding enc = getEncoding(encodingType);
97-
List<Integer> encoded = enc.encode(text);
98+
List<Integer> encoded = enc.encode(text).boxed();
9899
return encoded;
99100
}
100101

@@ -119,7 +120,7 @@ public static int tokens(EncodingType encodingType, String text) {
119120
*/
120121
public static String decode(EncodingType encodingType, List<Integer> encoded) {
121122
Encoding enc = getEncoding(encodingType);
122-
return enc.decode(encoded);
123+
return enc.decode(toIntArrayList(encoded));
123124
}
124125

125126

@@ -147,7 +148,7 @@ public static List<Integer> encode(String modelName, String text) {
147148
if (Objects.isNull(enc)) {
148149
return new ArrayList<>();
149150
}
150-
List<Integer> encoded = enc.encode(text);
151+
List<Integer> encoded = enc.encode(text).boxed();
151152
return encoded;
152153
}
153154

@@ -209,7 +210,16 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
209210
*/
210211
public static String decode(String modelName, List<Integer> encoded) {
211212
Encoding enc = getEncoding(modelName);
212-
return enc.decode(encoded);
213+
return enc.decode(toIntArrayList(encoded));
214+
}
215+
216+
private static IntArrayList toIntArrayList(List<Integer> encoded) {
217+
IntArrayList intArrayList = new IntArrayList(encoded.size());
218+
for (Integer e : encoded) {
219+
intArrayList.add(e);
220+
}
221+
222+
return intArrayList;
213223
}
214224

215225

gradle/libs.versions.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ retrofit = { module = "com.squareup.retrofit2:retrofit", version.ref = "retrofit
1212
retrofitJackson = { module = "com.squareup.retrofit2:converter-jackson", version.ref = "retrofit" }
1313
retrofitRxJava2 = { module = "com.squareup.retrofit2:adapter-rxjava2", version.ref = "retrofit" }
1414
retrofitMock = { module = "com.squareup.retrofit2:retrofit-mock", version.ref = "retrofit" }
15-
jtokkit = { module = "com.knuddels:jtokkit", version = "0.6.1" }
15+
jtokkit = { module = "com.knuddels:jtokkit", version = "1.1.0" }

service/src/test/java/com/launchableinc/openai/service/ChatCompletionTest.java

+21
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,27 @@ void createChatCompletion_with_json_mode() {
101101
assertTrue(isValidJson(choices.getMessage().getContent()));
102102
}
103103

104+
@Test
105+
void createChatCompletion_with_gpt4o() {
106+
final List<ChatMessage> messages = new ArrayList<>();
107+
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(),
108+
"You are a cat and will speak as such.");
109+
messages.add(systemMessage);
110+
111+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
112+
.builder()
113+
.model("gpt-4o")
114+
.messages(messages)
115+
.n(5)
116+
.maxTokens(50)
117+
.logitBias(new HashMap<>())
118+
.build();
119+
120+
List<ChatCompletionChoice> choices = service.createChatCompletion(chatCompletionRequest)
121+
.getChoices();
122+
assertEquals(5, choices.size());
123+
}
124+
104125
private boolean isValidJson(String jsonString) {
105126
ObjectMapper objectMapper = new ObjectMapper();
106127
try {

0 commit comments

Comments
 (0)