src/csharp/ChatClient.cs
@@ -0,0 +1,232 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;

    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;

    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;

    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    public ChatClient(string modelPath, ChatClientConfiguration configuration)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        // Store the configuration; it supplies the prompt formatter and stop sequences used below.
        _config = configuration;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="model">The model to employ.</param>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    public ChatClient(Model model, ChatClientConfiguration configuration, bool ownsModel = true)
    {
        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        _config = configuration;
        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        StringBuilder text = new();
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                // Decode only the newest token; earlier tokens were already accumulated.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = Guid.NewGuid().ToString(),
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // max_length covers the prompt tokens plus the generated tokens.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
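
For orientation, a minimal usage sketch follows. It is not part of this commit: the model directory path is a placeholder, and CreateForPhi3() stands in for a ChatClientConfiguration factory like the one illustrated in the ChatClientConfiguration remarks below.

using System;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical setup: "path/to/model" is a placeholder for a local ONNX Runtime
// GenAI model folder, and CreateForPhi3() is the factory sketched in the
// ChatClientConfiguration remarks.
ChatClientConfiguration config = CreateForPhi3();
using IChatClient client = new ChatClient("path/to/model", config);

// Non-streaming: returns one ChatCompletion containing the full assistant reply.
ChatCompletion completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "What is ONNX Runtime GenAI?")],
    new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });
Console.WriteLine(completion.Message.Text);

// Streaming: updates arrive one decoded token at a time.
await foreach (var update in client.CompleteStreamingAsync(
    [new ChatMessage(ChatRole.User, "Summarize that in one sentence.")]))
{
    Console.Write(update.Text);
}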

@@ -0,0 +1,73 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =>
///     new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
///         (IEnumerable<ChatMessage> messages) =>
///         {
///             StringBuilder prompt = new();
///
///             foreach (var message in messages)
///                 foreach (var content in message.Contents.OfType<TextContent>())
///                     prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
///
///             return prompt.Append("<|assistant|>\n").ToString();
///         });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
    private string[] _stopSequences;
    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
    /// <param name="stopSequences">The stop sequences used by the model.</param>
    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
    public ChatClientConfiguration(
        string[] stopSequences,
        Func<IEnumerable<ChatMessage>, string> promptFormatter)
    {
        if (stopSequences is null)
        {
            throw new ArgumentNullException(nameof(stopSequences));
        }

        if (promptFormatter is null)
        {
            throw new ArgumentNullException(nameof(promptFormatter));
        }

        StopSequences = stopSequences;
        PromptFormatter = promptFormatter;
    }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    public string[] StopSequences
    {
        get => _stopSequences;
        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
    {
        get => _promptFormatter;
        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
    }
}
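
To complement the remarks above, here is the Phi3-style factory written out as compilable code. It is only a sketch of the doc-comment example, and the <|...|> template markers should be verified against the chat template that ships with the model actually being loaded.

using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

static class Phi3ChatClientConfiguration
{
    // A compilable version of the formatter shown in the remarks above.
    // The stop sequences and <|...|> markers follow the Phi-3 chat template.
    public static ChatClientConfiguration CreateForPhi3() =>
        new(["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
            messages =>
            {
                StringBuilder prompt = new();

                foreach (ChatMessage message in messages)
                {
                    // Only text content participates in the prompt.
                    foreach (TextContent content in message.Contents.OfType<TextContent>())
                    {
                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
                              .Append(content.Text).Append("<|end|>\n");
                    }
                }

                // The trailing assistant tag cues the model to generate the reply.
                return prompt.Append("<|assistant|>\n").ToString();
            });
}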