Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added initial Trimming support. #21

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions OpenAI.sln
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI", "src\OpenAI.csproj
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "OpenAI.Examples", "examples\OpenAI.Examples.csproj", "{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TrimmingHelper", "helpers\TrimmingHelper\TrimmingHelper.csproj", "{4C6C7FB5-DD4E-44A8-9CBF-D739284FDC23}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want the TrimmingHelper project in this repo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will delete this after the review, right now it is still needed for re-checking

EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand All @@ -20,6 +22,10 @@ Global
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1F1CD1D4-9932-4B73-99D8-C252A67D4B46}.Release|Any CPU.Build.0 = Release|Any CPU
{4C6C7FB5-DD4E-44A8-9CBF-D739284FDC23}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{4C6C7FB5-DD4E-44A8-9CBF-D739284FDC23}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4C6C7FB5-DD4E-44A8-9CBF-D739284FDC23}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4C6C7FB5-DD4E-44A8-9CBF-D739284FDC23}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
1 change: 1 addition & 0 deletions helpers/TrimmingHelper/Program.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
// Placeholder entry point: the real value of this project is the trim analysis
// that runs when it is built/published (see ProduceTrimmingWarnings target).
const string Notice = "Build, rebuild or publish this app to see trimming warnings.";
Console.WriteLine(Notice);
30 changes: 30 additions & 0 deletions helpers/TrimmingHelper/TrimmingHelper.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<Project Sdk="Microsoft.NET.Sdk">

  <PropertyGroup>
    <OutputType>Exe</OutputType>
    <TargetFramework>net8.0</TargetFramework>
    <ImplicitUsings>enable</ImplicitUsings>

    <!-- Publishing with trimming enabled is what surfaces IL2xxx trim warnings
         from the referenced OpenAI library. -->
    <PublishTrimmed>true</PublishTrimmed>
  </PropertyGroup>

  <ItemGroup>
    <ProjectReference Include="..\..\src\OpenAI.csproj" />
  </ItemGroup>

  <ItemGroup>
    <!-- Root the whole OpenAI assembly so the trimmer analyzes all of it,
         not just the members this stub app happens to touch. -->
    <TrimmerRootAssembly Include="OpenAI" />
  </ItemGroup>

  <PropertyGroup Label="Publish">
    <!-- Trimmed, self-contained publishing requires a concrete RuntimeIdentifier.
         FIX: previously every non-Windows host fell through to osx-arm64, which is
         wrong on Linux CI/dev machines; pick an RID per host OS instead. -->
    <RuntimeIdentifier Condition="$([MSBuild]::IsOSPlatform('windows'))">win-x64</RuntimeIdentifier>
    <RuntimeIdentifier Condition="$([MSBuild]::IsOSPlatform('osx'))">osx-arm64</RuntimeIdentifier>
    <RuntimeIdentifier Condition="$([MSBuild]::IsOSPlatform('linux'))">linux-x64</RuntimeIdentifier>

    <SelfContained>true</SelfContained>
  </PropertyGroup>

  <!-- Trim warnings are only produced on publish; chain Publish onto Build so a
       plain build (or IDE rebuild) shows them too. -->
  <Target Name="ProduceTrimmingWarnings" AfterTargets="Build">
    <CallTarget Targets="Publish"/>
  </Target>

</Project>
31 changes: 22 additions & 9 deletions src/Custom/Common/InternalListHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.ClientModel;
using System.ClientModel.Primitives;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

Expand All @@ -10,28 +11,40 @@ internal static class InternalListHelpers
internal delegate Task<ClientResult> AsyncListResponseFunc(string continuationToken, int? pageSize);
internal delegate ClientResult ListResponseFunc(string continuationToken, int? pageSize);

/// <summary>
/// Creates an <see cref="AsyncPageableCollection{T}"/> that lazily fetches pages through
/// <paramref name="listResponseFunc"/> and deserializes each raw response as
/// <typeparamref name="TInternalList"/>.
/// </summary>
/// <remarks>
/// The <see cref="DynamicallyAccessedMembersAttribute"/> on the type parameter keeps the
/// model's constructors alive under trimming — presumably required by
/// <c>ModelReaderWriter.Read</c> in <c>GetPageFromProtocol</c>; TODO confirm.
/// </remarks>
internal static AsyncPageableCollection<T> CreateAsyncPageable<T,
#if NET6_0_OR_GREATER
    [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
#endif
    TInternalList>(AsyncListResponseFunc listResponseFunc)
    where TInternalList : IJsonModel<TInternalList>, IInternalListResponse<T>
{
    // Bridges the protocol-level call to a typed result page.
    async Task<ResultPage<T>> pageFunc(string continuationToken, int? pageSize)
        => GetPageFromProtocol<T, TInternalList>(await listResponseFunc(continuationToken, pageSize).ConfigureAwait(false));
    // First page starts from a null continuation token; later pages reuse pageFunc.
    return PageableResultHelpers.Create((pageSize) => pageFunc(null, pageSize), pageFunc);
}

/// <summary>
/// Creates a synchronous <see cref="PageableCollection{T}"/> that fetches pages through
/// <paramref name="listResponseFunc"/> and deserializes each raw response as
/// <typeparamref name="TInternalList"/>.
/// </summary>
/// <remarks>
/// Mirrors <c>CreateAsyncPageable</c>; the <see cref="DynamicallyAccessedMembersAttribute"/>
/// keeps the model's constructors alive under trimming — presumably required by
/// <c>ModelReaderWriter.Read</c>; TODO confirm.
/// </remarks>
internal static PageableCollection<T> CreatePageable<T,
#if NET6_0_OR_GREATER
    [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
#endif
    TInternalList>(ListResponseFunc listResponseFunc)
    where TInternalList : IJsonModel<TInternalList>, IInternalListResponse<T>
{
    // Bridges the protocol-level call to a typed result page.
    ResultPage<T> pageFunc(string continuationToken, int? pageSize)
        => GetPageFromProtocol<T, TInternalList>(listResponseFunc(continuationToken, pageSize));
    // First page starts from a null continuation token; later pages reuse pageFunc.
    return PageableResultHelpers.Create((pageSize) => pageFunc(null, pageSize), pageFunc);
}

/// <summary>
/// Deserializes the raw response of a protocol-level list call into a typed
/// <see cref="ResultPage{TInstance}"/>.
/// </summary>
/// <remarks>
/// The trimming annotation on <typeparamref name="TInternalList"/> preserves its
/// constructors, which <see cref="ModelReaderWriter"/> presumably instantiates via
/// reflection — TODO confirm against ModelReaderWriter's trim requirements.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ResultPage<TInstance> GetPageFromProtocol<TInstance,
#if NET6_0_OR_GREATER
    [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)]
#endif
    TInternalList>(ClientResult protocolResult)
    where TInternalList : IJsonModel<TInternalList>, IInternalListResponse<TInstance>
{
    PipelineResponse response = protocolResult.GetRawResponse();
    IInternalListResponse<TInstance> values = ModelReaderWriter.Read<TInternalList>(response.Content);
    // LastId doubles as the continuation token; null signals the final page.
    return ResultPage<TInstance>.Create(values.Data, values.HasMore ? values.LastId : null, response);
}
}
12 changes: 6 additions & 6 deletions src/Custom/Embeddings/EmbeddingClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public virtual async Task<ClientResult<Embedding>> GenerateEmbeddingAsync(string
Argument.AssertNotNullOrEmpty(input, nameof(input));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(input), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(input, SourceGenerationContext.Default.String), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await GenerateEmbeddingsAsync(content, (RequestOptions)null).ConfigureAwait(false);
Expand All @@ -99,7 +99,7 @@ public virtual ClientResult<Embedding> GenerateEmbedding(string input, Embedding
Argument.AssertNotNullOrEmpty(input, nameof(input));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(input), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(input, SourceGenerationContext.Default.String), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = GenerateEmbeddings(content, (RequestOptions)null);
Expand All @@ -117,7 +117,7 @@ public virtual async Task<ClientResult<EmbeddingCollection>> GenerateEmbeddingsA
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableString), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await GenerateEmbeddingsAsync(content, (RequestOptions)null).ConfigureAwait(false);
Expand All @@ -136,7 +136,7 @@ public virtual ClientResult<EmbeddingCollection> GenerateEmbeddings(IEnumerable<
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableString), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = GenerateEmbeddings(content, (RequestOptions)null);
Expand All @@ -154,7 +154,7 @@ public virtual async Task<ClientResult<EmbeddingCollection>> GenerateEmbeddingsA
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableIEnumerableInt32), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await GenerateEmbeddingsAsync(content, (RequestOptions)null).ConfigureAwait(false);
Expand All @@ -172,7 +172,7 @@ public virtual ClientResult<EmbeddingCollection> GenerateEmbeddings(IEnumerable<
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

options ??= new();
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateEmbeddingGenerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableIEnumerableInt32), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = GenerateEmbeddings(content, (RequestOptions)null);
Expand Down
8 changes: 4 additions & 4 deletions src/Custom/Moderations/ModerationClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public virtual async Task<ClientResult<ModerationResult>> ClassifyTextInputAsync
Argument.AssertNotNullOrEmpty(input, nameof(input));

ModerationOptions options = new();
CreateModerationOptions(BinaryData.FromObjectAsJson(input), ref options);
CreateModerationOptions(BinaryData.FromObjectAsJson(input, SourceGenerationContext.Default.String), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await ClassifyTextInputsAsync(content, (RequestOptions)null).ConfigureAwait(false);
Expand All @@ -90,7 +90,7 @@ public virtual ClientResult<ModerationResult> ClassifyTextInput(string input)
Argument.AssertNotNullOrEmpty(input, nameof(input));

ModerationOptions options = new();
CreateModerationOptions(BinaryData.FromObjectAsJson(input), ref options);
CreateModerationOptions(BinaryData.FromObjectAsJson(input, SourceGenerationContext.Default.String), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = ClassifyTextInputs(content, (RequestOptions)null);
Expand All @@ -107,7 +107,7 @@ public virtual async Task<ClientResult<ModerationCollection>> ClassifyTextInputs
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

ModerationOptions options = new();
CreateModerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateModerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableString), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = await ClassifyTextInputsAsync(content, (RequestOptions)null).ConfigureAwait(false);
Expand All @@ -123,7 +123,7 @@ public virtual ClientResult<ModerationCollection> ClassifyTextInputs(IEnumerable
Argument.AssertNotNullOrEmpty(inputs, nameof(inputs));

ModerationOptions options = new();
CreateModerationOptions(BinaryData.FromObjectAsJson(inputs), ref options);
CreateModerationOptions(BinaryData.FromObjectAsJson(inputs, SourceGenerationContext.Default.IEnumerableString), ref options);

using BinaryContent content = options.ToBinaryContent();
ClientResult result = ClassifyTextInputs(content, (RequestOptions)null);
Expand Down
2 changes: 1 addition & 1 deletion src/Custom/OpenAIModelFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public static EmbeddingCollection EmbeddingCollection(IEnumerable<Embedding> dat
/// <summary>
/// Model-factory helper that creates an <see cref="Embedding"/> from a raw vector and index.
/// </summary>
/// <param name="vector"> The embedding values. </param>
/// <param name="index"> The embedding's position in the generated collection. </param>
public static Embedding Embedding(ReadOnlyMemory<float> vector = default, int index = default)
{
    // TODO: Vector must be converted to base64-encoded string.
    // Serializes via the source-generated JSON context (SourceGenerationContext) so this
    // path stays trim-safe instead of falling back to reflection-based serialization.
    return new Embedding(index, BinaryData.FromObjectAsJson(vector, SourceGenerationContext.Default.ReadOnlyMemorySingle), InternalEmbeddingObject.Embedding, serializedAdditionalRawData: null);
}

}
4 changes: 2 additions & 2 deletions src/OpenAI.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<VersionPrefix>2.0.0</VersionPrefix>
<VersionSuffix>beta.2</VersionSuffix>

<TargetFrameworks>netstandard2.0;net6.0</TargetFrameworks>
<TargetFrameworks>netstandard2.0;net6.0;net8.0</TargetFrameworks>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the reason for adding this explicit target?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that attributes only appeared in net7, I'm not sure that net6 issues all the appropriate warnings for trimming, so I always try to add net8 to see actual problems.
Also, adding net8 allows you to avoid using polyfills for it and opens up access to other optimizations, as well as up-to-date diagnostics / in some cases more complete nullability markup
I can remove this, it can be added later if needed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stephentoub, what are your thoughts on this? We have been trying to minimize the number of cross-targets, as the matrix can become super messy after a while, but maybe adding net 8 is not too bad?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll want net8.0 at least by Nov when net6.0 goes out of support. Might as well add it now, and gain the additional benefits.

<LangVersion>latest</LangVersion>

<!-- Sign the assembly with the specified key file. -->
Expand Down Expand Up @@ -46,6 +46,6 @@
<ItemGroup>
<PackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" PrivateAssets="All" />
<PackageReference Include="System.ClientModel" Version="1.1.0-beta.4" />
<PackageReference Include="System.Text.Json" Version="8.0.2" />
HavenDV marked this conversation as resolved.
Show resolved Hide resolved
<PackageReference Include="System.Memory.Data" Version="9.0.0-preview.4.24266.19" />
HavenDV marked this conversation as resolved.
Show resolved Hide resolved
</ItemGroup>
</Project>
12 changes: 12 additions & 0 deletions src/SourceGenerationContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text.Json.Serialization;

namespace OpenAI;

// Source-generated System.Text.Json context. The clients pass members of
// SourceGenerationContext.Default as explicit JsonTypeInfo to
// BinaryData.FromObjectAsJson (see EmbeddingClient / ModerationClient /
// OpenAIModelFactory) so serialization of these payload shapes avoids the
// reflection-based path — presumably to keep the library trim-safe; the
// generator emits the implementation into this partial class at build time.
[JsonSourceGenerationOptions]
[JsonSerializable(typeof(string))]                       // single text input (embeddings, moderations)
[JsonSerializable(typeof(IEnumerable<string>))]          // batched text inputs
[JsonSerializable(typeof(IEnumerable<IEnumerable<int>>))] // batched token-id inputs
[JsonSerializable(typeof(ReadOnlyMemory<float>))]        // embedding vector
internal sealed partial class SourceGenerationContext : JsonSerializerContext;