Skip to content

Commit

Permalink
fix: Fix AI dialogue history (#1289)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszsocha2 authored Feb 19, 2025
1 parent 9119adc commit 29b6519
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 56 deletions.
62 changes: 36 additions & 26 deletions src/intTest/java/com/box/sdk/BoxAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;

import com.eclipsesource.json.Json;
import com.eclipsesource.json.JsonObject;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -111,8 +111,8 @@ public void askAIMultipleItems() throws InterruptedException {
public void askAITextGenItemWithDialogueHistory() throws ParseException, InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
String fileName = "[askAITextGenItemWithDialogueHistory] Test File.txt";
Date date1 = BoxDateFormat.parse("2013-05-16T15:27:57-07:00");
Date date2 = BoxDateFormat.parse("2013-05-16T15:26:57-07:00");
Date date1 = BoxDateFormat.parse("2021-01-01T00:00:00Z");
Date date2 = BoxDateFormat.parse("2021-02-01T00:00:00Z");

BoxFile uploadedFile = uploadFileToUniqueFolder(api, fileName, "Test file");
try {
Expand Down Expand Up @@ -148,28 +148,25 @@ public void askAITextGenItemWithDialogueHistory() throws ParseException, Interru
@Test
public void getAIAgentDefaultConfiguration() {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK,
"en", "openai__gpt_3_5_turbo");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK);
BoxAIAgentAsk askAgent = (BoxAIAgentAsk) agent;

assertThat(askAgent.getType(), is(equalTo(BoxAIAgentAsk.TYPE)));
assertThat(askAgent.getBasicText().getModel(), is(equalTo("openai__gpt_3_5_turbo")));
assertThat(askAgent.getBasicText().getModel(), is(notNullValue()));

BoxAIAgent agent2 = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.TEXT_GEN,
"en", "openai__gpt_3_5_turbo");
BoxAIAgent agent2 = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.TEXT_GEN);
BoxAIAgentTextGen textGenAgent = (BoxAIAgentTextGen) agent2;

assertThat(textGenAgent.getType(), is(equalTo(BoxAIAgentTextGen.TYPE)));
assertThat(textGenAgent.getBasicGen().getModel(), is(equalTo("openai__gpt_3_5_turbo")));
assertThat(textGenAgent.getBasicGen().getModel(), is(notNullValue()));
}

@Test
public void askAISingleItemWithAgent() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
String fileName = "[askAISingleItem] Test File.txt";
BoxFile uploadedFile = uploadFileToUniqueFolder(api, fileName, "Test file");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK,
"en", "openai__gpt_3_5_turbo_16k");
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.ASK);
BoxAIAgentAsk askAgent = (BoxAIAgentAsk) agent;

try {
Expand Down Expand Up @@ -199,8 +196,10 @@ public void askAISingleItemWithAgent() throws InterruptedException {
@Test
public void aiExtract() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT);
BoxAIAgentExtract agentExtract = (BoxAIAgentExtract) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtract.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtract] Test File.txt",
"My name is John Doe. I live in San Francisco. I was born in 1990. I work at Box.");
Expand All @@ -224,8 +223,10 @@ public void aiExtract() throws InterruptedException {
@Test
public void aiExtractStructuredWithFields() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED);
BoxAIAgentExtractStructured agentExtractStructured = (BoxAIAgentExtractStructured) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtractStructured.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtractStructuredWithFields] Test File.txt",
"My name is John Doe. I was born in 4th July 1990. I am 34 years old. My hobby is guitar.");
Expand Down Expand Up @@ -259,12 +260,16 @@ public void aiExtractStructuredWithFields() throws InterruptedException {
"What is your hobby?")
),
agentExtractStructured);
JsonObject sourceJson = response.getSourceJson();
assertThat(sourceJson.get("firstName").asString(), is(equalTo("John")));
assertThat(sourceJson.get("lastName").asString(), is(equalTo("Doe")));
assertThat(sourceJson.get("dateOfBirth").asString(), is(equalTo("1990-07-04")));
assertThat(sourceJson.get("age").asInt(), is(equalTo(34)));
assertThat(sourceJson.get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getSourceJson().get("answer"), is(equalTo(response.getAnswer())));

assertThat(response.getAnswer().get("firstName").asString(), is(equalTo("John")));
assertThat(response.getAnswer().get("lastName").asString(), is(equalTo("Doe")));
assertThat(response.getAnswer().get("dateOfBirth").asString(), is(equalTo("1990-07-04")));
assertThat(response.getAnswer().get("age").asInt(), is(equalTo(34)));
assertThat(response.getAnswer().get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));

assertThat(response.getCompletionReason(), equalTo("done"));
assertThat(response.getCreatedAt(), is(notNullValue()));
}, 2, 2000);
} finally {
deleteFile(uploadedFile);
Expand All @@ -274,8 +279,10 @@ public void aiExtractStructuredWithFields() throws InterruptedException {
@Test
public void aiExtractStructuredWithMetadataTemplate() throws InterruptedException {
BoxAPIConnection api = jwtApiForServiceAccount();
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED, "en-US", null);
BoxAIAgent agent = BoxAI.getAiAgentDefaultConfig(api, BoxAIAgent.Mode.EXTRACT_STRUCTURED);
BoxAIAgentExtractStructured agentExtractStructured = (BoxAIAgentExtractStructured) agent;
// AI team is going to move away from supporting overriding embeddings model
agentExtractStructured.getLongText().setEmbeddings(null);

BoxFile uploadedFile = uploadFileToUniqueFolder(api, "[aiExtractStructuredWithMetadataTemplate] Test File.txt",
"My name is John Doe. I was born in 4th July 1990. I am 34 years old. My hobby is guitar.");
Expand Down Expand Up @@ -312,12 +319,15 @@ public void aiExtractStructuredWithMetadataTemplate() throws InterruptedExceptio
Collections.singletonList(new BoxAIItem(uploadedFile.getID(), BoxAIItem.Type.FILE)),
new BoxAIExtractMetadataTemplate(templateKey, "enterprise"),
agentExtractStructured);
JsonObject sourceJson = response.getSourceJson();
assertThat(sourceJson.get("firstName").asString(), is(equalTo("John")));
assertThat(sourceJson.get("lastName").asString(), is(equalTo("Doe")));
assertThat(sourceJson.get("dateOfBirth").asString(), is(equalTo("1990-07-04T00:00:00Z")));
assertThat(sourceJson.get("age").asInt(), is(equalTo(34)));
assertThat(sourceJson.get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getSourceJson().get("answer"), is(equalTo(response.getAnswer())));

assertThat(response.getAnswer().get("firstName").asString(), is(equalTo("John")));
assertThat(response.getAnswer().get("lastName").asString(), is(equalTo("Doe")));
assertThat(response.getAnswer().get("dateOfBirth").asString(), is(equalTo("1990-07-04T00:00:00Z")));
assertThat(response.getAnswer().get("age").asInt(), is(equalTo(34)));
assertThat(response.getAnswer().get("hobby").asArray().get(0).asString(), is(equalTo("guitar")));
assertThat(response.getCompletionReason(), equalTo("done"));
assertThat(response.getCreatedAt(), is(notNullValue()));
}, 2, 2000);
} finally {
deleteFile(uploadedFile);
Expand Down
16 changes: 12 additions & 4 deletions src/main/java/com/box/sdk/BoxAIAgentAsk.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,18 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "basic_text_multi", this.basicTextMulti.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text_multi", this.longTextMulti.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.basicTextMulti != null) {
jsonObject.add("basic_text_multi", this.basicTextMulti.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
if (this.longTextMulti != null) {
jsonObject.add("long_text_multi", this.longTextMulti.getJSONObject());
}
return jsonObject;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentAskBasicText.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,9 @@ void parseJSONMember(JsonObject.Member member) {

public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentAskLongText.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,12 @@ void parseJSONMember(JsonObject.Member member) {

public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "embeddings", this.embeddings.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.embeddings != null) {
jsonObject.add("embeddings", this.embeddings.getJSONObject());
}
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentEmbeddings.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "strategy", this.strategy.getJSONObject());
if (this.strategy != null) {
jsonObject.add("strategy", this.strategy.getJSONObject());
}
return jsonObject;
}
}
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentExtract.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
return jsonObject;
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentExtractStructured.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_text", this.basicText.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "long_text", this.longText.getJSONObject());
if (this.basicText != null) {
jsonObject.add("basic_text", this.basicText.getJSONObject());
}
if (this.longText != null) {
jsonObject.add("long_text", this.longText.getJSONObject());
}
return jsonObject;
}
}
Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/box/sdk/BoxAIAgentTextGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "type", this.getType());
JsonUtils.addIfNotNull(jsonObject, "basic_gen", this.basicGen.getJSONObject());
if (this.basicGen != null) {
jsonObject.add("basic_gen", this.basicGen.getJSONObject());
}
return jsonObject;
}
}
8 changes: 6 additions & 2 deletions src/main/java/com/box/sdk/BoxAIAgentTextGenBasicGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,12 @@ void parseJSONMember(JsonObject.Member member) {
public JsonObject getJSONObject() {
JsonObject jsonObject = new JsonObject();
JsonUtils.addIfNotNull(jsonObject, "content_template", this.contentTemplate);
JsonUtils.addIfNotNull(jsonObject, "embeddings", this.embeddings.getJSONObject());
JsonUtils.addIfNotNull(jsonObject, "llm_endpoint_params", this.llmEndpointParams.getJSONObject());
if (this.embeddings != null) {
jsonObject.add("embeddings", this.embeddings.getJSONObject());
}
if (this.llmEndpointParams != null) {
jsonObject.add("llm_endpoint_params", this.llmEndpointParams.getJSONObject());
}
JsonUtils.addIfNotNull(jsonObject, "model", this.model);
JsonUtils.addIfNotNull(jsonObject, "num_tokens_for_completion", this.numTokensForCompletion);
JsonUtils.addIfNotNull(jsonObject, "prompt_template", this.promptTemplate);
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/box/sdk/BoxAIDialogueEntry.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ public void setCreatedAt(Date createdAt) {
*/
public JsonObject getJSONObject() {
JsonObject itemJSON = new JsonObject()
.add("id", this.prompt)
.add("type", this.answer);
.add("prompt", this.prompt)
.add("answer", this.answer);

if (this.createdAt != null) {
itemJSON.add("content", this.createdAt.toString());
itemJSON.add("created_at", BoxDateFormat.format(this.createdAt));
}

return itemJSON;
Expand Down
54 changes: 54 additions & 0 deletions src/main/java/com/box/sdk/BoxAIExtractStructuredResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

import com.eclipsesource.json.Json;
import com.eclipsesource.json.JsonObject;
import com.eclipsesource.json.JsonValue;
import java.text.ParseException;
import java.util.Date;

/**
* AI response to a user request.
*/
public class BoxAIExtractStructuredResponse extends BoxJSONObject {
private final JsonObject sourceJson;
private JsonObject answer;
private String completionReason;
private Date createdAt;

/**
* Constructs a BoxAIResponse object.
Expand Down Expand Up @@ -35,4 +41,52 @@ public BoxAIExtractStructuredResponse(String json) {
public JsonObject getSourceJson() {
return sourceJson;
}

/**
* Gets the answer of the AI.
*
* @return the answer of the AI.
*/
public JsonObject getAnswer() {
return answer;
}

/**
* Gets reason the response finishes.
*
* @return the reason the response finishes.
*/
public String getCompletionReason() {
return completionReason;
}

/**
* Gets the ISO date formatted timestamp of when the answer to the prompt was created.
*
* @return The ISO date formatted timestamp of when the answer to the prompt was created.
*/
public Date getCreatedAt() {
return createdAt;
}

/**
* {@inheritDoc}
*/
@Override
void parseJSONMember(JsonObject.Member member) {
JsonValue value = member.getValue();
String memberName = member.getName();
try {
if (memberName.equals("answer")) {
this.answer = value.asObject();
} else if (memberName.equals("completion_reason")) {
this.completionReason = value.asString();
} else if (memberName.equals("created_at")) {
this.createdAt = BoxDateFormat.parse(value.asString());
}
} catch (ParseException e) {
assert false : "A ParseException indicates a bug in the SDK.";
}
}

}
18 changes: 11 additions & 7 deletions src/test/Fixtures/BoxAI/ExtractMetadataStructured200.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
{
"firstName": "John",
"lastName": "Doe",
"age": 25,
"hobbies": [
"reading",
"travelling"
]
"answer": {
"firstName": "John",
"lastName": "Doe",
"age": 25,
"hobbies": [
"reading",
"travelling"
]
},
"completion_reason": "done",
"created_at": "2012-12-12T10:53:43.123-08:00"
}
Loading

0 comments on commit 29b6519

Please sign in to comment.