diff --git a/build.gradle b/build.gradle index 301d49e..683a171 100644 --- a/build.gradle +++ b/build.gradle @@ -50,7 +50,7 @@ publishing { maven(MavenPublication) { groupId = 'com.cohere' artifactId = 'cohere-java' - version = '1.3.1' + version = '1.3.2' from components.java pom { name = 'cohere' diff --git a/src/main/java/com/cohere/api/Cohere.java b/src/main/java/com/cohere/api/Cohere.java index 362a79c..df115f5 100644 --- a/src/main/java/com/cohere/api/Cohere.java +++ b/src/main/java/com/cohere/api/Cohere.java @@ -55,6 +55,8 @@ import com.cohere.api.types.UnprocessableEntityErrorBody; import com.fasterxml.jackson.core.JsonProcessingException; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import okhttp3.Headers; import okhttp3.HttpUrl; @@ -103,19 +105,100 @@ public Iterable chatStream(ChatStreamRequest request, Requ .newBuilder() .addPathSegments("v1/chat") .build(); + Map properties = new HashMap<>(); + properties.put("message", request.getMessage()); + if (request.getModel().isPresent()) { + properties.put("model", request.getModel()); + } + properties.put("stream", request.getStream()); + if (request.getPreamble().isPresent()) { + properties.put("preamble", request.getPreamble()); + } + if (request.getChatHistory().isPresent()) { + properties.put("chat_history", request.getChatHistory()); + } + if (request.getConversationId().isPresent()) { + properties.put("conversation_id", request.getConversationId()); + } + if (request.getPromptTruncation().isPresent()) { + properties.put("prompt_truncation", request.getPromptTruncation()); + } + if (request.getConnectors().isPresent()) { + properties.put("connectors", request.getConnectors()); + } + if (request.getSearchQueriesOnly().isPresent()) { + properties.put("search_queries_only", request.getSearchQueriesOnly()); + } + if (request.getDocuments().isPresent()) { + properties.put("documents", request.getDocuments()); + } + if (request.getCitationQuality().isPresent()) { + properties.put("citation_quality", request.getCitationQuality()); + } + if (request.getTemperature().isPresent()) { + properties.put("temperature", request.getTemperature()); + } + if (request.getMaxTokens().isPresent()) { + properties.put("max_tokens", request.getMaxTokens()); + } + if (request.getMaxInputTokens().isPresent()) { + properties.put("max_input_tokens", request.getMaxInputTokens()); + } + if (request.getK().isPresent()) { + properties.put("k", request.getK()); + } + if (request.getP().isPresent()) { + properties.put("p", request.getP()); + } + if (request.getSeed().isPresent()) { + properties.put("seed", request.getSeed()); + } + if (request.getStopSequences().isPresent()) { + properties.put("stop_sequences", request.getStopSequences()); + } + if (request.getFrequencyPenalty().isPresent()) { + properties.put("frequency_penalty", request.getFrequencyPenalty()); + } + if (request.getPresencePenalty().isPresent()) { + properties.put("presence_penalty", request.getPresencePenalty()); + } + if (request.getRawPrompting().isPresent()) { + properties.put("raw_prompting", request.getRawPrompting()); + } + if (request.getReturnPrompt().isPresent()) { + properties.put("return_prompt", request.getReturnPrompt()); + } + if (request.getTools().isPresent()) { + properties.put("tools", request.getTools()); + } + if (request.getToolResults().isPresent()) { + properties.put("tool_results", request.getToolResults()); + } + if (request.getForceSingleStep().isPresent()) { + properties.put("force_single_step", request.getForceSingleStep()); + } + if (request.getResponseFormat().isPresent()) { + properties.put("response_format", request.getResponseFormat()); + } + if (request.getSafetyMode().isPresent()) { + properties.put("safety_mode", request.getSafetyMode()); + } RequestBody body; try { body = RequestBody.create( - ObjectMappers.JSON_MAPPER.writeValueAsBytes(request), MediaTypes.APPLICATION_JSON); - } catch (JsonProcessingException e) { - throw new CohereApiError("Failed to serialize request", e); + ObjectMappers.JSON_MAPPER.writeValueAsBytes(properties), MediaTypes.APPLICATION_JSON); + } catch (Exception e) { + throw new RuntimeException(e); } - Request okhttpRequest = new Request.Builder() + Request.Builder _requestBuilder = new Request.Builder() .url(httpUrl) .method("POST", body) .headers(Headers.of(clientOptions.headers(requestOptions))) - .addHeader("Content-Type", "application/json") - .build(); + .addHeader("Content-Type", "application/json"); + if (request.getAccepts().isPresent()) { + _requestBuilder.addHeader("Accepts", request.getAccepts().get()); + } + Request okhttpRequest = _requestBuilder.build(); OkHttpClient client = clientOptions.httpClient(); if (requestOptions != null && requestOptions.getTimeout().isPresent()) { client = clientOptions.httpClientWithTimeout(requestOptions); @@ -191,19 +274,100 @@ public NonStreamedChatResponse chat(ChatRequest request, RequestOptions requestO .newBuilder() .addPathSegments("v1/chat") .build(); + Map properties = new HashMap<>(); + properties.put("message", request.getMessage()); + if (request.getModel().isPresent()) { + properties.put("model", request.getModel()); + } + properties.put("stream", request.getStream()); + if (request.getPreamble().isPresent()) { + properties.put("preamble", request.getPreamble()); + } + if (request.getChatHistory().isPresent()) { + properties.put("chat_history", request.getChatHistory()); + } + if (request.getConversationId().isPresent()) { + properties.put("conversation_id", request.getConversationId()); + } + if (request.getPromptTruncation().isPresent()) { + properties.put("prompt_truncation", request.getPromptTruncation()); + } + if (request.getConnectors().isPresent()) { + properties.put("connectors", request.getConnectors()); + } + if (request.getSearchQueriesOnly().isPresent()) { + properties.put("search_queries_only", request.getSearchQueriesOnly()); + } + if (request.getDocuments().isPresent()) { + properties.put("documents", request.getDocuments()); + } + if (request.getCitationQuality().isPresent()) { + properties.put("citation_quality", request.getCitationQuality()); + } + if (request.getTemperature().isPresent()) { + properties.put("temperature", request.getTemperature()); + } + if (request.getMaxTokens().isPresent()) { + properties.put("max_tokens", request.getMaxTokens()); + } + if (request.getMaxInputTokens().isPresent()) { + properties.put("max_input_tokens", request.getMaxInputTokens()); + } + if (request.getK().isPresent()) { + properties.put("k", request.getK()); + } + if (request.getP().isPresent()) { + properties.put("p", request.getP()); + } + if (request.getSeed().isPresent()) { + properties.put("seed", request.getSeed()); + } + if (request.getStopSequences().isPresent()) { + properties.put("stop_sequences", request.getStopSequences()); + } + if (request.getFrequencyPenalty().isPresent()) { + properties.put("frequency_penalty", request.getFrequencyPenalty()); + } + if (request.getPresencePenalty().isPresent()) { + properties.put("presence_penalty", request.getPresencePenalty()); + } + if (request.getRawPrompting().isPresent()) { + properties.put("raw_prompting", request.getRawPrompting()); + } + if (request.getReturnPrompt().isPresent()) { + properties.put("return_prompt", request.getReturnPrompt()); + } + if (request.getTools().isPresent()) { + properties.put("tools", request.getTools()); + } + if (request.getToolResults().isPresent()) { + properties.put("tool_results", request.getToolResults()); + } + if (request.getForceSingleStep().isPresent()) { + properties.put("force_single_step", request.getForceSingleStep()); + } + if (request.getResponseFormat().isPresent()) { + properties.put("response_format", request.getResponseFormat()); + } + if (request.getSafetyMode().isPresent()) { + properties.put("safety_mode", request.getSafetyMode()); + } RequestBody body; try { body = RequestBody.create( - ObjectMappers.JSON_MAPPER.writeValueAsBytes(request), MediaTypes.APPLICATION_JSON); - } catch (JsonProcessingException e) { - throw new CohereApiError("Failed to serialize request", e); + ObjectMappers.JSON_MAPPER.writeValueAsBytes(properties), MediaTypes.APPLICATION_JSON); + } catch (Exception e) { + throw new RuntimeException(e); } - Request okhttpRequest = new Request.Builder() + Request.Builder _requestBuilder = new Request.Builder() .url(httpUrl) .method("POST", body) .headers(Headers.of(clientOptions.headers(requestOptions))) - .addHeader("Content-Type", "application/json") - .build(); + .addHeader("Content-Type", "application/json"); + if (request.getAccepts().isPresent()) { + _requestBuilder.addHeader("Accepts", request.getAccepts().get()); + } + Request okhttpRequest = _requestBuilder.build(); OkHttpClient client = clientOptions.httpClient(); if (requestOptions != null && requestOptions.getTimeout().isPresent()) { client = clientOptions.httpClientWithTimeout(requestOptions); @@ -448,6 +612,15 @@ public Generation generate(GenerateRequest request, RequestOptions requestOption } } + /** + * This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. + *

Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

+ *

If you want to learn more how to use the embedding model, have a look at the Semantic Search Guide.

+ */ + public EmbedResponse embed() { + return embed(EmbedRequest.builder().build()); + } + /** * This endpoint returns text embeddings. An embedding is a list of floating point numbers that captures semantic information about the text that it represents. *

Embeddings can be used to create text classifiers as well as empower semantic search. To learn more about embeddings, see the embedding page.

diff --git a/src/main/java/com/cohere/api/core/ClientOptions.java b/src/main/java/com/cohere/api/core/ClientOptions.java index f51ecd7..943764d 100644 --- a/src/main/java/com/cohere/api/core/ClientOptions.java +++ b/src/main/java/com/cohere/api/core/ClientOptions.java @@ -30,7 +30,7 @@ private ClientOptions( { put("X-Fern-Language", "JAVA"); put("X-Fern-SDK-Name", "com.cohere.fern:api-sdk"); - put("X-Fern-SDK-Version", "1.3.1"); + put("X-Fern-SDK-Version", "1.3.2"); } }); this.headerSuppliers = headerSuppliers; diff --git a/src/main/java/com/cohere/api/requests/ChatRequest.java b/src/main/java/com/cohere/api/requests/ChatRequest.java index ccae20d..09b1ec5 100644 --- a/src/main/java/com/cohere/api/requests/ChatRequest.java +++ b/src/main/java/com/cohere/api/requests/ChatRequest.java @@ -29,6 +29,8 @@ @JsonInclude(JsonInclude.Include.NON_EMPTY) @JsonDeserialize(builder = ChatRequest.Builder.class) public final class ChatRequest { + private final Optional accepts; + private final String message; private final Optional model; @@ -84,6 +86,7 @@ public final class ChatRequest { private final Map additionalProperties; private ChatRequest( + Optional accepts, String message, Optional model, Optional preamble, @@ -111,6 +114,7 @@ private ChatRequest( Optional responseFormat, Optional safetyMode, Map additionalProperties) { + this.accepts = accepts; this.message = message; this.model = model; this.preamble = preamble; @@ -140,6 +144,14 @@ private ChatRequest( this.additionalProperties = additionalProperties; } + /** + * @return Pass text/event-stream to receive the streamed response as server-sent events. The default is \n delimited events. + */ + @JsonProperty("Accepts") + public Optional getAccepts() { + return accepts; + } + /** * @return Text input for the model to respond to. *

Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

@@ -451,7 +463,8 @@ public Map getAdditionalProperties() { } private boolean equalTo(ChatRequest other) { - return message.equals(other.message) + return accepts.equals(other.accepts) + && message.equals(other.message) && model.equals(other.model) && preamble.equals(other.preamble) && chatHistory.equals(other.chatHistory) @@ -482,6 +495,7 @@ private boolean equalTo(ChatRequest other) { @java.lang.Override public int hashCode() { return Objects.hash( + this.accepts, this.message, this.model, this.preamble, @@ -528,6 +542,10 @@ public interface MessageStage { public interface _FinalStage { ChatRequest build(); + _FinalStage accepts(Optional accepts); + + _FinalStage accepts(String accepts); + _FinalStage model(Optional model); _FinalStage model(String model); @@ -683,6 +701,8 @@ public static final class Builder implements MessageStage, _FinalStage { private Optional model = Optional.empty(); + private Optional accepts = Optional.empty(); + @JsonAnySetter private Map additionalProperties = new HashMap<>(); @@ -690,6 +710,7 @@ private Builder() {} @java.lang.Override public Builder from(ChatRequest other) { + accepts(other.getAccepts()); message(other.getMessage()); model(other.getModel()); preamble(other.getPreamble()); @@ -1234,9 +1255,27 @@ public _FinalStage model(Optional model) { return this; } + /** + *

Pass text/event-stream to receive the streamed response as server-sent events. The default is \n delimited events.

+ * @return Reference to {@code this} so that method calls can be chained together. + */ + @java.lang.Override + public _FinalStage accepts(String accepts) { + this.accepts = Optional.of(accepts); + return this; + } + + @java.lang.Override + @JsonSetter(value = "Accepts", nulls = Nulls.SKIP) + public _FinalStage accepts(Optional accepts) { + this.accepts = accepts; + return this; + } + @java.lang.Override public ChatRequest build() { return new ChatRequest( + accepts, message, model, preamble, diff --git a/src/main/java/com/cohere/api/requests/ChatStreamRequest.java b/src/main/java/com/cohere/api/requests/ChatStreamRequest.java index c1d1eb3..54ae350 100644 --- a/src/main/java/com/cohere/api/requests/ChatStreamRequest.java +++ b/src/main/java/com/cohere/api/requests/ChatStreamRequest.java @@ -29,6 +29,8 @@ @JsonInclude(JsonInclude.Include.NON_EMPTY) @JsonDeserialize(builder = ChatStreamRequest.Builder.class) public final class ChatStreamRequest { + private final Optional accepts; + private final String message; private final Optional model; @@ -84,6 +86,7 @@ public final class ChatStreamRequest { private final Map additionalProperties; private ChatStreamRequest( + Optional accepts, String message, Optional model, Optional preamble, @@ -111,6 +114,7 @@ private ChatStreamRequest( Optional responseFormat, Optional safetyMode, Map additionalProperties) { + this.accepts = accepts; this.message = message; this.model = model; this.preamble = preamble; @@ -140,6 +144,14 @@ private ChatStreamRequest( this.additionalProperties = additionalProperties; } + /** + * @return Pass text/event-stream to receive the streamed response as server-sent events. The default is \n delimited events. + */ + @JsonProperty("Accepts") + public Optional getAccepts() { + return accepts; + } + /** * @return Text input for the model to respond to. *

Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private Deployments

@@ -451,7 +463,8 @@ public Map getAdditionalProperties() { } private boolean equalTo(ChatStreamRequest other) { - return message.equals(other.message) + return accepts.equals(other.accepts) + && message.equals(other.message) && model.equals(other.model) && preamble.equals(other.preamble) && chatHistory.equals(other.chatHistory) @@ -482,6 +495,7 @@ private boolean equalTo(ChatStreamRequest other) { @java.lang.Override public int hashCode() { return Objects.hash( + this.accepts, this.message, this.model, this.preamble, @@ -528,6 +542,10 @@ public interface MessageStage { public interface _FinalStage { ChatStreamRequest build(); + _FinalStage accepts(Optional accepts); + + _FinalStage accepts(String accepts); + _FinalStage model(Optional model); _FinalStage model(String model); @@ -683,6 +701,8 @@ public static final class Builder implements MessageStage, _FinalStage { private Optional model = Optional.empty(); + private Optional accepts = Optional.empty(); + @JsonAnySetter private Map additionalProperties = new HashMap<>(); @@ -690,6 +710,7 @@ private Builder() {} @java.lang.Override public Builder from(ChatStreamRequest other) { + accepts(other.getAccepts()); message(other.getMessage()); model(other.getModel()); preamble(other.getPreamble()); @@ -1234,9 +1255,27 @@ public _FinalStage model(Optional model) { return this; } + /** + *

Pass text/event-stream to receive the streamed response as server-sent events. The default is \n delimited events.

+ * @return Reference to {@code this} so that method calls can be chained together. + */ + @java.lang.Override + public _FinalStage accepts(String accepts) { + this.accepts = Optional.of(accepts); + return this; + } + + @java.lang.Override + @JsonSetter(value = "Accepts", nulls = Nulls.SKIP) + public _FinalStage accepts(Optional accepts) { + this.accepts = accepts; + return this; + } + @java.lang.Override public ChatStreamRequest build() { return new ChatStreamRequest( + accepts, message, model, preamble, diff --git a/src/main/java/com/cohere/api/requests/EmbedRequest.java b/src/main/java/com/cohere/api/requests/EmbedRequest.java index 3a90448..c01ffd7 100644 --- a/src/main/java/com/cohere/api/requests/EmbedRequest.java +++ b/src/main/java/com/cohere/api/requests/EmbedRequest.java @@ -15,7 +15,6 @@ import com.fasterxml.jackson.annotation.JsonSetter; import com.fasterxml.jackson.annotation.Nulls; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -25,7 +24,9 @@ @JsonInclude(JsonInclude.Include.NON_EMPTY) @JsonDeserialize(builder = EmbedRequest.Builder.class) public final class EmbedRequest { - private final List texts; + private final Optional> texts; + + private final Optional> images; private final Optional model; @@ -38,13 +39,15 @@ public final class EmbedRequest { private final Map additionalProperties; private EmbedRequest( - List texts, + Optional> texts, + Optional> images, Optional model, Optional inputType, Optional> embeddingTypes, Optional truncate, Map additionalProperties) { this.texts = texts; + this.images = images; this.model = model; this.inputType = inputType; this.embeddingTypes = embeddingTypes; @@ -56,10 +59,19 @@ private EmbedRequest( * @return An array of strings for the model to embed. Maximum number of texts per call is 96. We recommend reducing the length of each text to be under 512 tokens for optimal quality. */ @JsonProperty("texts") - public List getTexts() { + public Optional> getTexts() { return texts; } + /** + * @return An array of image data URIs for the model to embed. Maximum number of images per call is 1. + *

The image must be a valid data URI. The image must be in either image/jpeg or image/png format and has a maximum size of 5MB.

+ */ + @JsonProperty("images") + public Optional> getImages() { + return images; + } + /** * @return Defaults to embed-english-v2.0 *

The identifier of the model. Smaller "light" models are faster, while larger models will perform better. Custom models can also be supplied with their full ID.

@@ -136,6 +148,7 @@ public Map getAdditionalProperties() { private boolean equalTo(EmbedRequest other) { return texts.equals(other.texts) + && images.equals(other.images) && model.equals(other.model) && inputType.equals(other.inputType) && embeddingTypes.equals(other.embeddingTypes) @@ -144,7 +157,7 @@ private boolean equalTo(EmbedRequest other) { @java.lang.Override public int hashCode() { - return Objects.hash(this.texts, this.model, this.inputType, this.embeddingTypes, this.truncate); + return Objects.hash(this.texts, this.images, this.model, this.inputType, this.embeddingTypes, this.truncate); } @java.lang.Override @@ -158,7 +171,9 @@ public static Builder builder() { @JsonIgnoreProperties(ignoreUnknown = true) public static final class Builder { - private List texts = new ArrayList<>(); + private Optional> texts = Optional.empty(); + + private Optional> images = Optional.empty(); private Optional model = Optional.empty(); @@ -175,6 +190,7 @@ private Builder() {} public Builder from(EmbedRequest other) { texts(other.getTexts()); + images(other.getImages()); model(other.getModel()); inputType(other.getInputType()); embeddingTypes(other.getEmbeddingTypes()); @@ -183,19 +199,24 @@ public Builder from(EmbedRequest other) { } @JsonSetter(value = "texts", nulls = Nulls.SKIP) + public Builder texts(Optional> texts) { + this.texts = texts; + return this; + } + public Builder texts(List texts) { - this.texts.clear(); - this.texts.addAll(texts); + this.texts = Optional.of(texts); return this; } - public Builder addTexts(String texts) { - this.texts.add(texts); + @JsonSetter(value = "images", nulls = Nulls.SKIP) + public Builder images(Optional> images) { + this.images = images; return this; } - public Builder addAllTexts(List texts) { - this.texts.addAll(texts); + public Builder images(List images) { + this.images = Optional.of(images); return this; } @@ -244,7 +265,7 @@ public Builder truncate(EmbedRequestTruncate truncate) { } public EmbedRequest build() { - return new EmbedRequest(texts, model, inputType, embeddingTypes, truncate, additionalProperties); + return new EmbedRequest(texts, images, model, inputType, embeddingTypes, truncate, additionalProperties); } } } diff --git a/src/main/java/com/cohere/api/resources/finetuning/finetuning/types/Settings.java b/src/main/java/com/cohere/api/resources/finetuning/finetuning/types/Settings.java index a63112a..c751d56 100644 --- a/src/main/java/com/cohere/api/resources/finetuning/finetuning/types/Settings.java +++ b/src/main/java/com/cohere/api/resources/finetuning/finetuning/types/Settings.java @@ -80,7 +80,7 @@ public Optional getMultiLabel() { } /** - * @return The Weights & Biases configuration. + * @return The Weights & Biases configuration (Chat fine-tuning only). */ @JsonProperty("wandb") public Optional getWandb() { @@ -196,7 +196,7 @@ public _FinalStage datasetId(String datasetId) { } /** - *

The Weights & Biases configuration.

+ *

The Weights & Biases configuration (Chat fine-tuning only).

* @return Reference to {@code this} so that method calls can be chained together. */ @java.lang.Override