|
2 | 2 |
|
3 | 3 | import com.fasterxml.jackson.annotation.JsonProperty;
|
4 | 4 | import com.fasterxml.jackson.annotation.JsonPropertyDescription;
|
| 5 | +import com.fasterxml.jackson.core.JsonProcessingException; |
5 | 6 | import com.fasterxml.jackson.databind.JsonNode;
|
| 7 | +import com.fasterxml.jackson.databind.ObjectMapper; |
6 | 8 | import com.fasterxml.jackson.databind.node.ObjectNode;
|
7 | 9 | import com.launchableinc.openai.completion.chat.*;
|
| 10 | +import com.launchableinc.openai.completion.chat.ChatResponseFormat.ResponseFormat; |
8 | 11 | import org.junit.jupiter.api.Assumptions;
|
9 | 12 | import org.junit.jupiter.api.BeforeAll;
|
10 | 13 | import org.junit.jupiter.api.Test;
|
@@ -77,6 +80,37 @@ void createChatCompletion() {
|
77 | 80 | assertEquals(5, choices.size());
|
78 | 81 | }
|
79 | 82 |
|
| 83 | + @Test |
| 84 | + void createChatCompletion_with_json_mode() { |
| 85 | + final List<ChatMessage> messages = new ArrayList<>(); |
| 86 | + final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), |
| 87 | + "Generate a random name and age json object. name field is a object that has first and last fields. age is a number."); |
| 88 | + messages.add(systemMessage); |
| 89 | + |
| 90 | + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest |
| 91 | + .builder() |
| 92 | + .model("gpt-3.5-turbo-1106") |
| 93 | + .messages(messages) |
| 94 | + .maxTokens(50) |
| 95 | + .logitBias(new HashMap<>()) |
| 96 | + .responseFormat(ChatResponseFormat.builder().type(ResponseFormat.JSON).build()) |
| 97 | + .build(); |
| 98 | + |
| 99 | + ChatCompletionChoice choices = service.createChatCompletion(chatCompletionRequest) |
| 100 | + .getChoices().get(0); |
| 101 | + assertTrue(isValidJson(choices.getMessage().getContent())); |
| 102 | + } |
| 103 | + |
| 104 | + private boolean isValidJson(String jsonString) { |
| 105 | + ObjectMapper objectMapper = new ObjectMapper(); |
| 106 | + try { |
| 107 | + objectMapper.readTree(jsonString); |
| 108 | + return true; |
| 109 | + } catch (JsonProcessingException e) { |
| 110 | + return false; |
| 111 | + } |
| 112 | + } |
| 113 | + |
80 | 114 | @Test
|
81 | 115 | void streamChatCompletion() {
|
82 | 116 | final List<ChatMessage> messages = new ArrayList<>();
|
|
0 commit comments