Skip to content

Commit 7339e5b

Browse files
author
Aandreba
committed
adding stuff to stream
1 parent 58f512b commit 7339e5b

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

src/completion.rs

+28-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use super::{
55
};
66
use crate::{error::FallibleResponse, Client, OpenAiStream};
77
use chrono::{DateTime, Utc};
8-
use futures::{future::ready, Stream, TryStreamExt};
8+
use futures::{channel::mpsc::unbounded, StreamExt};
9+
use futures::{channel::mpsc::UnboundedSender, future::ready, Stream, TryStreamExt};
910
use reqwest::Response;
1011
use serde::{Deserialize, Serialize};
1112
use std::{borrow::Cow, collections::HashMap, marker::PhantomData, ops::RangeInclusive};
@@ -376,6 +377,32 @@ impl CompletionStream {
376377
}
377378

378379
impl CompletionStream {
380+
pub fn completions(self) -> impl Stream<Item = Result<impl Stream<Item = Choice>>> {
381+
let mut this = self.into_choice_stream();
382+
tokio::spawn(async move {
383+
let mut choices = HashMap::<u64, UnboundedSender<Choice>>::new();
384+
while let Some(choice) = this.next().await {
385+
if let Ok(choice) = choice {
386+
match choices.entry(choice.index) {
387+
std::collections::hash_map::Entry::Occupied(mut entry) => {
388+
let _ = entry.get_mut().unbounded_send(choice);
389+
return ready(Ok(None));
390+
}
391+
std::collections::hash_map::Entry::Vacant(entry) => {
392+
let (send, recv) = unbounded();
393+
let _ = send.unbounded_send(choice);
394+
entry.insert(send);
395+
return ready(Ok(Some(recv)));
396+
}
397+
}
398+
}
399+
}
400+
todo!()
401+
});
402+
403+
return this;
404+
}
405+
379406
/// Converts [`Stream<Item = Result<Completion>>`] into [`Stream<Item = Result<Choice>>`]
380407
pub fn into_choice_stream(self) -> impl Stream<Item = Result<Choice>> {
381408
return self.try_filter_map(|x| ready(Ok(x.choices.into_iter().next())));

src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use serde::{
1010
Deserialize, Deserializer,
1111
};
1212
use std::{
13-
backtrace::Backtrace,
1413
borrow::Cow,
1514
marker::PhantomData,
1615
ops::{Deref, DerefMut},

tests/completion.rs

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use futures::TryStreamExt;
1+
use futures::{StreamExt, TryStreamExt};
22
use libopenai::{
33
error::Result,
44
file::TemporaryFile,
@@ -24,24 +24,29 @@ async fn basic() -> Result<()> {
2424
return Ok(());
2525
}
2626

27-
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
27+
#[tokio::test]
2828
async fn stream() -> Result<()> {
2929
dotenv::dotenv().unwrap();
3030
tracing_subscriber::fmt::init();
3131
let client = Client::new(None, None)?;
3232

3333
let mut stream = Completion::builder(
34-
"text-ada-001",
34+
"text-davinci-003",
3535
"Whats' the best way to calculate a factorial?",
3636
)
37+
.n(2)
3738
.max_tokens(256)
3839
.build_stream(&client)
40+
.await?
41+
.completions()
42+
.try_for_each(|mut entry| async move {
43+
while let Some(entry) = entry.next().await {
44+
println!("{:?}", entry.text);
45+
}
46+
return Ok(());
47+
})
3948
.await?;
4049

41-
while let Some(entry) = stream.try_next().await? {
42-
println!("{entry:#?}");
43-
}
44-
4550
return Ok(());
4651
}
4752

0 commit comments

Comments
 (0)