diff --git a/Cargo.toml b/Cargo.toml
index 544b85b9..c7557d06 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,7 +1,11 @@
 [workspace]
+resolver = "2"
 members = [
     "example-service",
     "tarpc",
     "plugins",
 ]
+
+[profile.dev]
+split-debuginfo = "unpacked"
diff --git a/README.md b/README.md
index ed04e6bf..ebf9223b 100644
--- a/README.md
+++ b/README.md
@@ -42,7 +42,6 @@ process, and no context switching between different languages.
 
 Some other features of tarpc:
 - Pluggable transport: any type impling `Stream + Sink` can be used as a transport to connect the client and server.
-- `Send + 'static` optional: if the transport doesn't require it, neither does tarpc!
 - Cascading cancellation: dropping a request will send a cancellation message to the server.
   The server will cease any unfinished work on the request, subsequently cancelling any of its
   own requests, repeating for the entire chain of transitive dependencies.
@@ -51,6 +50,14 @@ Some other features of tarpc:
   requests sent by the server that use the request context will propagate the request deadline.
   For example, if a server is handling a request with a 10s deadline, does 2s of work, then sends
   a request to another server, that server will see an 8s deadline.
+- Distributed tracing: tarpc is instrumented with [tracing](https://github.com/tokio-rs/tracing)
+  primitives extended with [OpenTelemetry](https://opentelemetry.io/) traces. Using a compatible
+  tracing subscriber like
+  [Jaeger](https://github.com/open-telemetry/opentelemetry-rust/tree/main/opentelemetry-jaeger),
+  each RPC can be traced through the client, server, and other dependencies downstream of the
+  server. Even for applications not connected to a distributed tracing collector, the
+  instrumentation can also be ingested by regular loggers like
+  [env_logger](https://github.com/env-logger-rs/env_logger/).
 - Serde serialization: enabling the `serde1` Cargo feature will make service requests and
   responses `Serialize + Deserialize`. It's entirely optional, though: in-memory transports can
   be used, as well, so the price of serialization doesn't have to be paid when it's not needed.
diff --git a/example-service/Cargo.toml b/example-service/Cargo.toml
index 9cf715dd..08485740 100644
--- a/example-service/Cargo.toml
+++ b/example-service/Cargo.toml
@@ -13,13 +13,21 @@ readme = "../README.md"
 description = "An example server built on tarpc."
 
 [dependencies]
-clap = "2.33"
-env_logger = "0.8"
+anyhow = "1.0"
+clap = "3.0.0-beta.2"
+log = "0.4"
 futures = "0.3"
+opentelemetry = { version = "0.13", features = ["rt-tokio"] }
+opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
+rand = "0.8"
 serde = { version = "1.0" }
 tarpc = { path = "../tarpc", features = ["full"] }
 tokio = { version = "1", features = ["macros", "net", "rt-multi-thread"] }
 tokio-serde = { version = "0.8", features = ["json"] }
+tracing = { version = "0.1" }
+tracing-appender = "0.1"
+tracing-opentelemetry = "0.12"
+tracing-subscriber = "0.2"
 
 [lib]
 name = "service"
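Note: the README's "pluggable transport" bullet above means nothing in these examples actually requires TCP + JSON; any `Stream + Sink` type works, including tarpc's bundled in-memory channel (the tests later in this patch use it as `transport::channel::UnboundedChannel`). A minimal sketch, not part of this patch — `World`/`WorldClient` come from example-service/src/lib.rs, and `Hello` is illustrative:

    use service::{World, WorldClient};
    use tarpc::server::{BaseChannel, Channel};

    #[derive(Clone)]
    struct Hello;

    #[tarpc::server]
    impl World for Hello {
        async fn hello(self, _: tarpc::context::Context, name: String) -> String {
            format!("Hello, {}!", name)
        }
    }

    async fn hello_in_memory() -> anyhow::Result<String> {
        // Client and server halves of an in-memory transport: no serialization,
        // no sockets, same request/response semantics.
        let (client_transport, server_transport) = tarpc::transport::channel::unbounded();
        tokio::spawn(BaseChannel::with_defaults(server_transport).execute(Hello.serve()));
        let client =
            WorldClient::new(tarpc::client::Config::default(), client_transport).spawn()?;
        Ok(client.hello(tarpc::context::current(), "world".into()).await?)
    }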
diff --git a/example-service/src/client.rs b/example-service/src/client.rs
index 3f313dbc..efaae4fa 100644
--- a/example-service/src/client.rs
+++ b/example-service/src/client.rs
@@ -4,57 +4,49 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
 
-use clap::{App, Arg};
-use std::{io, net::SocketAddr};
+use clap::Clap;
+use service::{init_tracing, WorldClient};
+use std::{net::SocketAddr, time::Duration};
 use tarpc::{client, context, tokio_serde::formats::Json};
+use tokio::time::sleep;
+use tracing::Instrument;
+
+#[derive(Clap)]
+struct Flags {
+    /// Sets the server address to connect to.
+    #[clap(long)]
+    server_addr: SocketAddr,
+    /// Sets the name to say hello to.
+    #[clap(long)]
+    name: String,
+}
 
 #[tokio::main]
-async fn main() -> io::Result<()> {
-    env_logger::init();
-
-    let flags = App::new("Hello Client")
-        .version("0.1")
-        .author("Tim ")
-        .about("Say hello!")
-        .arg(
-            Arg::with_name("server_addr")
-                .long("server_addr")
-                .value_name("ADDRESS")
-                .help("Sets the server address to connect to.")
-                .required(true)
-                .takes_value(true),
-        )
-        .arg(
-            Arg::with_name("name")
-                .short("n")
-                .long("name")
-                .value_name("STRING")
-                .help("Sets the name to say hello to.")
-                .required(true)
-                .takes_value(true),
-        )
-        .get_matches();
-
-    let server_addr = flags.value_of("server_addr").unwrap();
-    let server_addr = server_addr
-        .parse::<SocketAddr>()
-        .unwrap_or_else(|e| panic!(r#"--server_addr value "{}" invalid: {}"#, server_addr, e));
-
-    let name = flags.value_of("name").unwrap().into();
-
-    let mut transport = tarpc::serde_transport::tcp::connect(server_addr, Json::default);
-    transport.config_mut().max_frame_length(usize::MAX);
+async fn main() -> anyhow::Result<()> {
+    let flags = Flags::parse();
+    let _uninstall = init_tracing("Tarpc Example Client")?;
+
+    let transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
 
     // WorldClient is generated by the service attribute. It has a constructor `new` that takes a
     // config and any Transport as input.
-    let client = service::WorldClient::new(client::Config::default(), transport.await?).spawn()?;
-
-    // The client has an RPC method for each RPC defined in the annotated trait. It takes the same
-    // args as defined, with the addition of a Context, which is always the first arg. The Context
-    // specifies a deadline and trace information which can be helpful in debugging requests.
-    let hello = client.hello(context::current(), name).await?;
-
-    println!("{}", hello);
+    let client = WorldClient::new(client::Config::default(), transport.await?).spawn()?;
+
+    let hello = async move {
+        // Send the request twice, just to be safe! ;)
+        tokio::select! {
+            hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
+            hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
+        }
+    }
+    .instrument(tracing::info_span!("Two Hellos"))
+    .await;
+
+    tracing::info!("{:?}", hello);
+
+    // Let the background span processor finish.
+    sleep(Duration::from_micros(1)).await;
+    opentelemetry::global::shutdown_tracer_provider();
 
     Ok(())
 }
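Note: the `tokio::select!` above also exercises cascading cancellation: whichever `hello` loses the race is dropped mid-flight, and (per the `ResponseGuard` changes to tarpc/src/client.rs later in this patch) dropping an in-flight response future closes its response channel and enqueues a cancel message for that request id. The same effect can be triggered deliberately, e.g. with a client-side timeout. A sketch, assuming a `client` like the one above is in scope:

    use std::time::Duration;
    use tarpc::context;

    // If the timer wins, the RPC future is dropped mid-flight and the dispatch
    // task sends ClientMessage::Cancel, so the server stops working on it.
    tokio::select! {
        res = client.hello(context::current(), "Alice".into()) => println!("{:?}", res),
        _ = tokio::time::sleep(Duration::from_millis(5)) => println!("gave up waiting"),
    }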
diff --git a/example-service/src/lib.rs b/example-service/src/lib.rs
index ddd75e0a..e6f4639e 100644
--- a/example-service/src/lib.rs
+++ b/example-service/src/lib.rs
@@ -4,6 +4,9 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
 
+use std::env;
+use tracing_subscriber::{fmt::format::FmtSpan, prelude::*};
+
 /// This is the service definition. It looks a lot like a trait definition.
 /// It defines one RPC, hello, which takes one arg, name, and returns a String.
 #[tarpc::service]
@@ -11,3 +14,28 @@ pub trait World {
     /// Returns a greeting for name.
     async fn hello(name: String) -> String;
 }
+
+/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
+pub fn init_tracing(
+    service_name: &str,
+) -> anyhow::Result<tracing_appender::non_blocking::WorkerGuard> {
+    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
+
+    let tracer = opentelemetry_jaeger::new_pipeline()
+        .with_service_name(service_name)
+        .with_max_packet_size(2usize.pow(13))
+        .install_batch(opentelemetry::runtime::Tokio)?;
+    let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());
+
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::from_default_env())
+        .with(
+            tracing_subscriber::fmt::layer()
+                .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
+                .with_writer(non_blocking),
+        )
+        .with(tracing_opentelemetry::layer().with_tracer(tracer))
+        .try_init()?;
+
+    Ok(guard)
+}
diff --git a/example-service/src/server.rs b/example-service/src/server.rs
index 0faadc8f..f681690c 100644
--- a/example-service/src/server.rs
+++ b/example-service/src/server.rs
@@ -4,18 +4,30 @@
 // license that can be found in the LICENSE file or at
 // https://opensource.org/licenses/MIT.
 
-use clap::{App, Arg};
+use clap::Clap;
 use futures::{future, prelude::*};
-use service::World;
+use rand::{
+    distributions::{Distribution, Uniform},
+    thread_rng,
+};
+use service::{init_tracing, World};
 use std::{
-    io,
     net::{IpAddr, SocketAddr},
+    time::Duration,
 };
 use tarpc::{
     context,
     server::{self, Channel, Incoming},
     tokio_serde::formats::Json,
 };
+use tokio::time;
+
+#[derive(Clap)]
+struct Flags {
+    /// Sets the port number to listen on.
+    #[clap(long)]
+    port: u16,
+}
 
 // This is the type that implements the generated World trait. It is the business logic
 // and is used to start the server.
@@ -25,35 +37,19 @@ struct HelloServer(SocketAddr);
 
 #[tarpc::server]
 impl World for HelloServer {
     async fn hello(self, _: context::Context, name: String) -> String {
+        let sleep_time =
+            Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
+        time::sleep(sleep_time).await;
         format!("Hello, {}! You are connected from {:?}.", name, self.0)
     }
 }
 
 #[tokio::main]
-async fn main() -> io::Result<()> {
-    env_logger::init();
-
-    let flags = App::new("Hello Server")
-        .version("0.1")
-        .author("Tim ")
-        .about("Say hello!")
-        .arg(
-            Arg::with_name("port")
-                .short("p")
-                .long("port")
-                .value_name("NUMBER")
-                .help("Sets the port number to listen on")
-                .required(true)
-                .takes_value(true),
-        )
-        .get_matches();
-
-    let port = flags.value_of("port").unwrap();
-    let port = port
-        .parse()
-        .unwrap_or_else(|e| panic!(r#"--port value "{}" invalid: {}"#, port, e));
+async fn main() -> anyhow::Result<()> {
+    let flags = Flags::parse();
+    let _uninstall = init_tracing("Tarpc Example Server")?;
 
-    let server_addr = (IpAddr::from([0, 0, 0, 0]), port);
+    let server_addr = (IpAddr::from([0, 0, 0, 0]), flags.port);
 
     // JSON transport is provided by the json_transport tarpc module. It makes it easy
     // to start up a serde-powered json serialization strategy over TCP.
@@ -64,12 +60,12 @@
         .filter_map(|r| future::ready(r.ok()))
         .map(server::BaseChannel::with_defaults)
         // Limit channels to 1 per IP.
-        .max_channels_per_key(1, |t| t.as_ref().peer_addr().unwrap().ip())
+        .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
         // serve is generated by the service attribute. It takes as input any type implementing
         // the generated World trait.
         .map(|channel| {
-            let server = HelloServer(channel.as_ref().as_ref().peer_addr().unwrap());
-            channel.requests().execute(server.serve())
+            let server = HelloServer(channel.transport().peer_addr().unwrap());
+            channel.execute(server.serve())
         })
         // Max 10 channels.
         .buffer_unordered(10)
diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml
index 30964cef..314eae51 100644
--- a/plugins/Cargo.toml
+++ b/plugins/Cargo.toml
@@ -30,4 +30,4 @@ proc-macro = true
 assert-type-eq = "0.1.0"
 futures = "0.3"
 serde = { version = "1.0", features = ["derive"] }
-tarpc = { path = "../tarpc" }
+tarpc = { path = "../tarpc", features = ["serde1"] }
diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs
index b6f5b9e5..3ec2844e 100644
--- a/plugins/src/lib.rs
+++ b/plugins/src/lib.rs
@@ -267,6 +267,12 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
         None
     };
 
+    let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>();
+    let request_names = methods
+        .iter()
+        .map(|m| format!("{}.{}", ident, m))
+        .collect::<Vec<_>>();
+
     ServiceGenerator {
         response_fut_name,
         service_ident: ident,
@@ -278,7 +284,8 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream {
         vis,
         args,
         method_attrs: &rpcs.iter().map(|rpc| &*rpc.attrs).collect::<Vec<_>>(),
-        method_idents: &rpcs.iter().map(|rpc| &rpc.ident).collect::<Vec<_>>(),
+        method_idents: &methods,
+        request_names: &*request_names,
         attrs,
         rpcs,
         return_types: &rpcs
@@ -441,6 +448,7 @@ struct ServiceGenerator<'a> {
     camel_case_idents: &'a [Ident],
     future_types: &'a [Type],
     method_idents: &'a [&'a Ident],
+    request_names: &'a [String],
     method_attrs: &'a [&'a [Attribute]],
     args: &'a [&'a [PatType]],
     return_types: &'a [&'a Type],
@@ -524,6 +532,7 @@ impl<'a> ServiceGenerator<'a> {
             camel_case_idents,
             arg_pats,
             method_idents,
+            request_names,
             ..
         } = self;
 
@@ -534,6 +543,16 @@ impl<'a> ServiceGenerator<'a> {
             type Resp = #response_ident;
             type Fut = #response_fut_ident;
 
+            fn method(&self, req: &#request_ident) -> Option<&'static str> {
+                Some(match req {
+                    #(
+                        #request_ident::#camel_case_idents{..} => {
+                            #request_names
+                        }
+                    )*
+                })
+            }
+
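Note: for the example `World` service, the `method` hook generated by the quote block above expands to roughly the following; `request_names` is built as `"{Service}.{method}"`, so it feeds the `otel.name` span field with values like `"World.hello"`. A sketch of the macro output, where `WorldRequest` is the generated request enum:

    fn method(&self, req: &WorldRequest) -> Option<&'static str> {
        Some(match req {
            WorldRequest::Hello { .. } => "World.hello",
        })
    }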
             fn serve(self, ctx: tarpc::context::Context, req: #request_ident) -> Self::Fut {
                 match req {
                     #(
@@ -714,6 +733,7 @@ impl<'a> ServiceGenerator<'a> {
             method_attrs,
             vis,
             method_idents,
+            request_names,
             args,
             return_types,
             arg_pats,
@@ -729,7 +749,7 @@ impl<'a> ServiceGenerator<'a> {
                 #vis fn #method_idents(&self, ctx: tarpc::context::Context, #( #args ),*)
                     -> impl std::future::Future<Output = std::io::Result<#return_types>> + '_ {
                     let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
-                    let resp = self.0.call(ctx, request);
+                    let resp = self.0.call(ctx, #request_names, request);
                     async move {
                         match resp.await? {
                             #response_ident::#camel_case_idents(msg) => std::result::Result::Ok(msg),
diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml
index 31f4e0db..904f1ae6 100644
--- a/tarpc/Cargo.toml
+++ b/tarpc/Cargo.toml
@@ -30,26 +30,31 @@ anyhow = "1.0"
 fnv = "1.0"
 futures = "0.3"
 humantime = "2.0"
-log = "0.4"
 pin-project = "1.0"
-rand = "0.7"
+rand = "0.8"
 serde = { optional = true, version = "1.0", features = ["derive"] }
 static_assertions = "1.1.0"
 tarpc-plugins = { path = "../plugins", version = "0.10" }
 tokio = { version = "1", features = ["time"] }
 tokio-util = { version = "0.6.3", features = ["time"] }
 tokio-serde = { optional = true, version = "0.8" }
+tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
+tracing-opentelemetry = { version = "0.12", default-features = false }
+opentelemetry = { version = "0.13", default-features = false }
+
 [dev-dependencies]
 assert_matches = "1.4"
 bincode = "1.3"
 bytes = { version = "1", features = ["serde"] }
-env_logger = "0.8"
 flate2 = "1.0"
 futures-test = "0.3"
-log = "0.4"
+opentelemetry = { version = "0.13", default-features = false, features = ["rt-tokio"] }
+opentelemetry-jaeger = { version = "0.12", features = ["tokio"] }
 pin-utils = "0.1.0-alpha"
 serde_bytes = "0.11"
+tracing-appender = "0.1"
+tracing-subscriber = "0.2"
 tokio = { version = "1", features = ["full", "test-util"] }
 tokio-serde = { version = "0.8", features = ["json", "bincode"] }
 trybuild = "1.0"
@@ -63,7 +68,7 @@ name = "compression"
 required-features = ["serde-transport", "tcp"]
 
 [[example]]
-name = "server_calling_server"
+name = "tracing"
 required-features = ["full"]
 
 [[example]]
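Note: the pubsub changes below are mostly mechanical `log!` → `tracing!` conversions, but notice the shift from interpolated strings to structured key/value fields with a short, stable event name. Subscribers (the fmt layer, the OpenTelemetry layer) record the fields individually instead of one opaque string; `%` records a field via `Display` and `?` via `Debug`. A sketch, where `addr` stands in for the subscriber's local address:

    // Before: one preformatted string.
    log::info!("[{}] received message on topic '{}': {}", addr, topic, message);
    // After: indexable key/value fields plus a stable event name.
    tracing::info!(local_addr = %addr, %topic, %message, "ReceivedMessage");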
diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs
index 3d94c751..8f9ad8c0 100644
--- a/tarpc/examples/pubsub.rs
+++ b/tarpc/examples/pubsub.rs
@@ -38,11 +38,10 @@ use futures::{
     future::{self, AbortHandle},
     prelude::*,
 };
-use log::info;
 use publisher::Publisher as _;
 use std::{
     collections::HashMap,
-    io,
+    env, io,
     net::SocketAddr,
     sync::{Arc, Mutex, RwLock},
 };
@@ -54,6 +53,8 @@ use tarpc::{
 };
 use tokio::net::ToSocketAddrs;
 use tokio_serde::formats::Json;
+use tracing::info;
+use tracing_subscriber::prelude::*;
 
 pub mod subscriber {
     #[tarpc::service]
@@ -83,10 +84,7 @@ impl subscriber::Subscriber for Subscriber {
     }
 
     async fn receive(self, _: context::Context, topic: String, message: String) {
-        info!(
-            "[{}] received message on topic '{}': {}",
-            self.local_addr, topic, message
-        );
+        info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage")
     }
 }
 
@@ -120,7 +118,7 @@ impl Subscriber {
         let (handler, abort_handle) = future::abortable(handler.execute(subscriber.serve()));
         tokio::spawn(async move {
             match handler.await {
-                Ok(()) | Err(future::Aborted) => info!("[{}] subscriber shutdown.", local_addr),
+                Ok(()) | Err(future::Aborted) => info!(?local_addr, "subscriber shutdown."),
             }
         });
         Ok(SubscriberHandle(abort_handle))
@@ -153,13 +151,13 @@ impl Publisher {
             subscriptions: self.clone().start_subscription_manager().await?,
         };
 
-        info!("[{}] listening for publishers.", publisher_addrs.publisher);
+        info!(publisher_addr = %publisher_addrs.publisher, "listening for publishers.",);
         tokio::spawn(async move {
             // Because this is just an example, we know there will only be one publisher. In more
             // realistic code, this would be a loop to continually accept new publisher
             // connections.
             let publisher = connecting_publishers.next().await.unwrap().unwrap();
-            info!("[{}] publisher connected.", publisher.peer_addr().unwrap());
+            info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected.");
 
             server::BaseChannel::with_defaults(publisher)
                 .execute(self.serve())
@@ -174,7 +172,7 @@ impl Publisher {
             .await?
             .filter_map(|r| future::ready(r.ok()));
         let new_subscriber_addr = connecting_subscribers.get_ref().local_addr();
-        info!("[{}] listening for subscribers.", new_subscriber_addr);
+        info!(?new_subscriber_addr, "listening for subscribers.");
 
         tokio::spawn(async move {
             while let Some(conn) = connecting_subscribers.next().await {
@@ -215,7 +213,7 @@ impl Publisher {
             },
         );
 
-        info!("[{}] subscribed to topics: {:?}", subscriber_addr, topics);
+        info!(%subscriber_addr, ?topics, "subscribed to new topics");
         let mut subscriptions = self.subscriptions.write().unwrap();
         for topic in topics {
             subscriptions
@@ -235,9 +233,9 @@ impl Publisher {
         tokio::spawn(async move {
             if let Err(e) = client_dispatch.await {
                 info!(
-                    "[{}] subscriber connection broken: {:?}",
-                    subscriber_addr, e
-                )
+                    %subscriber_addr,
+                    error = %e,
+                    "subscriber connection broken");
             }
             // Don't clean up the subscriber until initialization is done.
             let _ = subscriber_ready.await;
@@ -281,13 +279,31 @@ impl publisher::Publisher for Publisher {
     }
 }
 
+/// Initializes an OpenTelemetry tracing subscriber with a Jaeger backend.
+fn init_tracing(service_name: &str) -> anyhow::Result<tracing_appender::non_blocking::WorkerGuard> {
+    env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
+    let tracer = opentelemetry_jaeger::new_pipeline()
+        .with_service_name(service_name)
+        .with_max_packet_size(2usize.pow(13))
+        .install_batch(opentelemetry::runtime::Tokio)?;
+
+    let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout());
+
+    tracing_subscriber::registry()
+        .with(tracing_subscriber::EnvFilter::from_default_env())
+        .with(tracing_subscriber::fmt::layer().with_writer(non_blocking))
+        .with(tracing_opentelemetry::layer().with_tracer(tracer))
+        .try_init()?;
+
+    Ok(guard)
+}
+
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    env_logger::init();
+    let _uninstall = init_tracing("Pub/Sub")?;
 
-    let clients = Arc::new(Mutex::new(HashMap::new()));
     let addrs = Publisher {
-        clients,
+        clients: Arc::new(Mutex::new(HashMap::new())),
         subscriptions: Arc::new(RwLock::new(HashMap::new())),
     }
     .start()
@@ -337,6 +353,7 @@ async fn main() -> anyhow::Result<()> {
     )
     .await?;
 
+    opentelemetry::global::shutdown_tracer_provider();
     info!("done.");
 
     Ok(())
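Note: every `init_tracing` in this patch uses the same two constants: `OTEL_BSP_MAX_EXPORT_BATCH_SIZE=12` caps how many spans the batch span processor exports at once, and `with_max_packet_size(2usize.pow(13))` caps the Jaeger payload at 8 KiB. The numbers are presumably related: opentelemetry-jaeger 0.12 ships spans to the agent over UDP, so an export batch has to fit in a single datagram (this reading is my inference; the patch doesn't state it). The shared recipe, condensed:

    // Keep export batches small enough to fit one UDP datagram to the agent.
    std::env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12");
    let tracer = opentelemetry_jaeger::new_pipeline()
        .with_service_name("Pub/Sub")
        .with_max_packet_size(2usize.pow(13)) // 8 KiB
        .install_batch(opentelemetry::runtime::Tokio)?;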
log::info!("DoubleService {:#?}", ctx); self.add_client - .add(ctx, x, x) + .add(context::current(), x, x) .await .map_err(|e| e.to_string()) } } +fn init_tracing(service_name: &str) -> anyhow::Result { + env::set_var("OTEL_BSP_MAX_EXPORT_BATCH_SIZE", "12"); + let tracer = opentelemetry_jaeger::new_pipeline() + .with_service_name(service_name) + .with_max_packet_size(2usize.pow(13)) + .install_batch(opentelemetry::runtime::Tokio)?; + + let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); + + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::from_default_env()) + .with(tracing_subscriber::fmt::layer().with_writer(non_blocking)) + .with(tracing_opentelemetry::layer().with_tracer(tracer)) + .try_init()?; + + Ok(guard) +} + #[tokio::main] -async fn main() -> io::Result<()> { - env_logger::init(); +async fn main() -> anyhow::Result<()> { + let _uninstall = init_tracing("tarpc_tracing_example")?; let add_listener = tarpc::serde_transport::tcp::listen("localhost:0", Json::default) .await? @@ -89,9 +105,11 @@ async fn main() -> io::Result<()> { double::DoubleClient::new(client::Config::default(), to_double_server).spawn()?; let ctx = context::current(); - log::info!("Client {:#?}", ctx); - for i in 1..=5 { - eprintln!("{:?}", double_client.double(ctx, i).await?); + for _ in 1..=5 { + tracing::info!("{:?}", double_client.double(ctx, 1).await?); } + + opentelemetry::global::shutdown_tracer_provider(); + Ok(()) } diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index f068bf2a..93214ba2 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -8,16 +8,13 @@ mod in_flight_requests; -use crate::{ - context, trace::SpanId, ClientMessage, PollContext, PollIo, Request, Response, Transport, -}; +use crate::{context, trace, ClientMessage, PollContext, PollIo, Request, Response, Transport}; use futures::{prelude::*, ready, stream::Fuse, task::*}; use in_flight_requests::InFlightRequests; -use log::{info, trace}; use pin_project::pin_project; use std::{ convert::TryFrom, - fmt, io, + fmt, io, mem, pin::Pin, sync::{ atomic::{AtomicUsize, Ordering}, @@ -25,6 +22,7 @@ use std::{ }, }; use tokio::sync::{mpsc, oneshot}; +use tracing::Span; /// Settings that control the behavior of the client. #[derive(Clone, Debug)] @@ -67,11 +65,9 @@ where #[cfg(feature = "tokio1")] #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))] pub fn spawn(self) -> io::Result { - use log::warn; - let dispatch = self .dispatch - .unwrap_or_else(move |e| warn!("Connection broken: {}", e)); + .unwrap_or_else(move |e| tracing::warn!("Connection broken: {}", e)); tokio::spawn(dispatch); Ok(self.client) } @@ -114,72 +110,70 @@ impl Clone for Channel { impl Channel { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that - /// resolves when the request is sent (not when the response is received). - fn send( + /// resolves to the response. + #[tracing::instrument( + name = "RPC", + skip(self, ctx, request_name, request), + fields( + rpc.trace_id = tracing::field::Empty, + otel.kind = "client", + otel.name = request_name) + )] + pub async fn call( &self, mut ctx: context::Context, + request_name: &str, request: Req, - ) -> impl Future>> + '_ { - // Convert the context to the call context. 
-        ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
-        ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
-
-        let (response_completion, response) = oneshot::channel();
-        let cancellation = self.cancellation.clone();
+    ) -> io::Result<Resp> {
+        let span = Span::current();
+        ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
+            tracing::warn!(
+                "OpenTelemetry subscriber not installed; making unsampled child context."
+            );
+            ctx.trace_context.new_child()
+        });
+        span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
+        let (response_completion, mut response) = oneshot::channel();
         let request_id =
             u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
-        // DispatchResponse impls Drop to cancel in-flight requests. It should be created before
+        // ResponseGuard impls Drop to cancel in-flight requests. It should be created before
         // sending out the request; otherwise, the response future could be dropped after the
-        // request is sent out but before DispatchResponse is created, rendering the cancellation
+        // request is sent out but before ResponseGuard is created, rendering the cancellation
        // logic inactive.
-        let response = DispatchResponse {
-            response,
+        let response_guard = ResponseGuard {
+            response: &mut response,
             request_id,
-            cancellation: Some(cancellation),
-            ctx,
+            cancellation: &self.cancellation,
         };
-        async move {
-            self.to_dispatch
-                .send(DispatchRequest {
-                    ctx,
-                    request_id,
-                    request,
-                    response_completion,
-                })
-                .await
-                .map_err(|mpsc::error::SendError(_)| {
-                    io::Error::from(io::ErrorKind::ConnectionReset)
-                })?;
-            Ok(response)
-        }
-    }
-
-    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
-    /// resolves to the response.
-    pub async fn call(&self, ctx: context::Context, request: Req) -> io::Result<Resp> {
-        let dispatch_response = self.send(ctx, request).await?;
-        dispatch_response.await
+        self.to_dispatch
+            .send(DispatchRequest {
+                ctx,
+                span,
+                request_id,
+                request,
+                response_completion,
+            })
+            .await
+            .map_err(|mpsc::error::SendError(_)| io::Error::from(io::ErrorKind::ConnectionReset))?;
+        response_guard.response().await
     }
 }
 
 /// A server response that is completed by request dispatch when the corresponding response
 /// arrives off the wire.
-#[derive(Debug)]
-struct DispatchResponse<Resp> {
-    response: oneshot::Receiver<Response<Resp>>,
-    ctx: context::Context,
-    cancellation: Option<RequestCancellation>,
+struct ResponseGuard<'a, Resp> {
+    response: &'a mut oneshot::Receiver<Response<Resp>>,
+    cancellation: &'a RequestCancellation,
     request_id: u64,
 }
 
-impl<Resp> Future for DispatchResponse<Resp> {
-    type Output = io::Result<Resp>;
-
-    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let resp = ready!(self.response.poll_unpin(cx));
-        self.cancellation.take();
-        Poll::Ready(match resp {
+impl<Resp> ResponseGuard<'_, Resp> {
+    async fn response(mut self) -> io::Result<Resp> {
+        let response = (&mut self.response).await;
+        // Cancel drop logic once a response has been received.
+        mem::forget(self);
+        match response {
             Ok(resp) => Ok(resp.message?),
             Err(oneshot::error::RecvError { .. }) => {
                 // The oneshot is Canceled when the dispatch task ends. In that case,
@@ -187,27 +181,25 @@ impl<Resp> Future for DispatchResponse<Resp> {
                 // propagating cancellation.
                 Err(io::Error::from(io::ErrorKind::ConnectionReset))
             }
-        })
+        }
     }
 }
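Note: `ResponseGuard::response` consumes the guard with `mem::forget` once a response has arrived, so the `Drop` impl below only fires for requests that are still outstanding. A self-contained sketch of this disarm-on-success pattern:

    use std::mem;

    struct CancelOnDrop {
        request_id: u64,
    }

    impl Drop for CancelOnDrop {
        fn drop(&mut self) {
            // Stand-in for RequestCancellation::cancel(request_id).
            println!("cancel request {}", self.request_id);
        }
    }

    fn request_completed(guard: CancelOnDrop) {
        // Response received: disarm the guard so the cancellation never runs.
        mem::forget(guard);
    }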
 // Cancels the request when dropped, if not already complete.
-impl<Resp> Drop for DispatchResponse<Resp> {
+impl<Resp> Drop for ResponseGuard<'_, Resp> {
     fn drop(&mut self) {
-        if let Some(cancellation) = &mut self.cancellation {
-            // The receiver needs to be closed to handle the edge case that the request has not
-            // yet been received by the dispatch task. It is possible for the cancel message to
-            // arrive before the request itself, in which case the request could get stuck in the
-            // dispatch map forever if the server never responds (e.g. if the server dies while
-            // responding). Even if the server does respond, it will have unnecessarily done work
-            // for a client no longer waiting for a response. To avoid this, the dispatch task
-            // checks if the receiver is closed before inserting the request in the map. By
-            // closing the receiver before sending the cancel message, it is guaranteed that if the
-            // dispatch task misses an early-arriving cancellation message, then it will see the
-            // receiver as closed.
-            self.response.close();
-            cancellation.cancel(self.request_id);
-        }
+        // The receiver needs to be closed to handle the edge case that the request has not
+        // yet been received by the dispatch task. It is possible for the cancel message to
+        // arrive before the request itself, in which case the request could get stuck in the
+        // dispatch map forever if the server never responds (e.g. if the server dies while
+        // responding). Even if the server does respond, it will have unnecessarily done work
+        // for a client no longer waiting for a response. To avoid this, the dispatch task
+        // checks if the receiver is closed before inserting the request in the map. By
+        // closing the receiver before sending the cancel message, it is guaranteed that if the
+        // dispatch task misses an early-arriving cancellation message, then it will see the
+        // receiver as closed.
+        self.response.close();
+        self.cancellation.cancel(self.request_id);
     }
 }
 
@@ -343,7 +335,7 @@ where
         cx: &mut Context<'_>,
     ) -> PollIo<DispatchRequest<Req, Resp>> {
         if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
-            info!(
+            tracing::info!(
                 "At in-flight request capacity ({}/{}).",
                 self.in_flight_requests().len(),
                 self.config.max_in_flight_requests
@@ -360,10 +352,8 @@ where
         match ready!(self.pending_requests_mut().poll_recv(cx)) {
             Some(request) => {
                 if request.response_completion.is_closed() {
-                    trace!(
-                        "[{}] Request canceled before being sent.",
-                        request.ctx.trace_id()
-                    );
+                    let _entered = request.span.enter();
+                    tracing::info!("AbortRequest");
                     continue;
                 }
 
@@ -381,14 +371,15 @@ where
     fn poll_next_cancellation(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-    ) -> PollIo<(context::Context, u64)> {
+    ) -> PollIo<(context::Context, Span, u64)> {
         ready!(self.ensure_writeable(cx)?);
 
         loop {
             match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                 Some(request_id) => {
-                    if let Some(ctx) = self.in_flight_requests().cancel_request(request_id) {
-                        return Poll::Ready(Some(Ok((ctx, request_id))));
+                    if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
+                    {
+                        return Poll::Ready(Some(Ok((ctx, span, request_id))));
                     }
                 }
                 None => return Poll::Ready(None),
@@ -407,46 +398,56 @@ where
     }
 
     fn poll_write_request<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
-        let dispatch_request = match ready!(self.as_mut().poll_next_request(cx)?) {
+        let DispatchRequest {
+            ctx,
+            span,
+            request_id,
+            request,
+            response_completion,
+        } = match ready!(self.as_mut().poll_next_request(cx)?) {
             Some(dispatch_request) => dispatch_request,
             None => return Poll::Ready(None),
         };
+        let entered = span.enter();
         // poll_next_request only returns Ready if there is room to buffer another request.
         // Therefore, we can call write_request without fear of erroring due to a full
         // buffer.
-        let request_id = dispatch_request.request_id;
         let request = ClientMessage::Request(Request {
             id: request_id,
-            message: dispatch_request.request,
+            message: request,
             context: context::Context {
-                deadline: dispatch_request.ctx.deadline,
-                trace_context: dispatch_request.ctx.trace_context,
+                deadline: ctx.deadline,
+                trace_context: ctx.trace_context,
             },
         });
         self.transport_pin_mut().start_send(request)?;
+        let deadline = ctx.deadline;
+        tracing::info!(
+            tarpc.deadline = %humantime::format_rfc3339(deadline),
+            "SendRequest"
+        );
+        drop(entered);
+
         self.in_flight_requests()
-            .insert_request(
-                request_id,
-                dispatch_request.ctx,
-                dispatch_request.response_completion,
-            )
+            .insert_request(request_id, ctx, span, response_completion)
             .expect("Request IDs should be unique");
         Poll::Ready(Some(Ok(())))
     }
 
     fn poll_write_cancel<'a>(self: &'a mut Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
-        let (context, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
-            Some((context, request_id)) => (context, request_id),
+        let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
+            Some(triple) => triple,
             None => return Poll::Ready(None),
         };
+        let _entered = span.enter();
 
-        let trace_id = *context.trace_id();
         let cancel = ClientMessage::Cancel {
             trace_context: context.trace_context,
             request_id,
         };
         self.transport_pin_mut().start_send(cancel)?;
-        trace!("[{}] Cancel message sent.", trace_id);
+        tracing::info!("CancelRequest");
         Poll::Ready(Some(Ok(())))
     }
 
@@ -473,15 +474,15 @@ where
                 .context("failed to write to transport")?,
         ) {
             (Poll::Ready(None), _) => {
-                info!("Shutdown: read half closed, so shutting down.");
+                tracing::info!("Shutdown: read half closed, so shutting down.");
                 return Poll::Ready(Ok(()));
             }
             (read, Poll::Ready(None)) => {
-                if self.in_flight_requests().is_empty() {
-                    info!("Shutdown: write half closed, and no requests in flight.");
+                if self.in_flight_requests.is_empty() {
+                    tracing::info!("Shutdown: write half closed, and no requests in flight.");
                     return Poll::Ready(Ok(()));
                 }
-                info!(
+                tracing::info!(
                     "Shutdown: write half closed, and {} requests in flight.",
                     self.in_flight_requests().len()
                 );
@@ -502,6 +503,7 @@ where
 #[derive(Debug)]
 struct DispatchRequest<Req, Resp> {
     pub ctx: context::Context,
+    pub span: Span,
     pub request_id: u64,
     pub request: Req,
     pub response_completion: oneshot::Sender<Response<Resp>>,
 }
 
@@ -518,16 +520,14 @@ struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
 
 /// Returns a channel to send request cancellation messages.
 fn cancellations() -> (RequestCancellation, CanceledRequests) {
     // Unbounded because messages are sent in the drop fn. This is fine, because it's still
-    // bounded by the number of in-flight requests. Additionally, each request has a clone
-    // of the sender, so the bounded channel would have the same behavior,
-    // since it guarantees a slot.
+    // bounded by the number of in-flight requests.
     let (tx, rx) = mpsc::unbounded_channel();
     (RequestCancellation(tx), CanceledRequests(rx))
 }
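Note: `DispatchRequest` now carries the originating RPC's `Span` across the mpsc channel, which is why `poll_write_request` above can enter it and emit `SendRequest` from the dispatch task while still attributing the event to the right trace. The general pattern, sketched:

    use tracing::Instrument;

    let span = tracing::info_span!("RPC");
    tokio::spawn(
        async {
            // Recorded inside the "RPC" span even though this runs on a
            // different task than the one that created the span.
            tracing::info!("SendRequest");
        }
        .instrument(span),
    );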
 impl RequestCancellation {
     /// Cancels the request with ID `request_id`.
-    fn cancel(&mut self, request_id: u64) {
+    fn cancel(&self, request_id: u64) {
         let _ = self.0.send(request_id);
     }
 }
 
@@ -549,8 +549,8 @@ impl Stream for CanceledRequests {
 #[cfg(test)]
 mod tests {
     use super::{
-        cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
-        RequestDispatch,
+        cancellations, CanceledRequests, Channel, DispatchRequest, RequestCancellation,
+        RequestDispatch, ResponseGuard,
     };
     use crate::{
         client::{in_flight_requests::InFlightRequests, Config},
         context,
         transport::{self, channel::UnboundedChannel},
         ClientMessage, Response,
     };
+    use assert_matches::assert_matches;
     use futures::{prelude::*, task::*};
-    use std::{pin::Pin, sync::atomic::AtomicUsize, sync::Arc};
+    use std::{
+        convert::TryFrom,
+        pin::Pin,
+        sync::atomic::{AtomicUsize, Ordering},
+        sync::Arc,
+    };
     use tokio::sync::{mpsc, oneshot};
+    use tracing::Span;
+
+    #[tokio::test]
+    async fn response_completes_request_future() {
+        let (mut dispatch, mut _channel, mut server_channel) = set_up();
+        let cx = &mut Context::from_waker(&noop_waker_ref());
+        let (tx, mut rx) = oneshot::channel();
+
+        dispatch
+            .in_flight_requests
+            .insert_request(0, context::current(), Span::current(), tx)
+            .unwrap();
+        server_channel
+            .send(Response {
+                request_id: 0,
+                message: Ok("Resp".into()),
+            })
+            .await
+            .unwrap();
+        assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
+        assert_matches!(rx.try_recv(), Ok(Response { request_id: 0, message: Ok(resp) }) if resp == "Resp");
+    }
 
     #[tokio::test]
     async fn dispatch_response_cancels_on_drop() {
         let (cancellation, mut canceled_requests) = cancellations();
-        let (_, response) = oneshot::channel();
-        drop(DispatchResponse::<String> {
-            response,
-            cancellation: Some(cancellation),
+        let (_, mut response) = oneshot::channel();
+        drop(ResponseGuard::<String> {
+            response: &mut response,
+            cancellation: &cancellation,
             request_id: 3,
-            ctx: context::current(),
         });
         // resp's drop() is run, which should send a cancel message.
         let cx = &mut Context::from_waker(&noop_waker_ref());
@@ -580,23 +607,22 @@ mod tests {
     #[tokio::test]
     async fn dispatch_response_doesnt_cancel_after_complete() {
         let (cancellation, mut canceled_requests) = cancellations();
-        let (tx, response) = oneshot::channel();
+        let (tx, mut response) = oneshot::channel();
         tx.send(Response {
             request_id: 0,
             message: Ok("well done"),
         })
         .unwrap();
-        {
-            DispatchResponse {
-                response,
-                cancellation: Some(cancellation),
-                request_id: 3,
-                ctx: context::current(),
-            }
-            .await
-            .unwrap();
-            // resp's drop() is run, but should not send a cancel message.
-        }
+        // resp's drop() is run, but should not send a cancel message.
+        ResponseGuard {
+            response: &mut response,
+            cancellation: &cancellation,
+            request_id: 3,
+        }
+        .response()
+        .await
+        .unwrap();
+        drop(cancellation);
         let cx = &mut Context::from_waker(&noop_waker_ref());
         assert_eq!(canceled_requests.0.poll_recv(cx), Poll::Ready(None));
     }
 
     #[tokio::test]
     async fn stage_request() {
         let (mut dispatch, mut channel, _server_channel) = set_up();
-        let dispatch = Pin::new(&mut dispatch);
         let cx = &mut Context::from_waker(&noop_waker_ref());
+        let (tx, mut rx) = oneshot::channel();
 
-        let _resp = send_request(&mut channel, "hi").await;
+        let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;
 
-        let req = dispatch.poll_next_request(cx).ready();
+        let req = dispatch.as_mut().poll_next_request(cx).ready();
         assert!(req.is_some());
 
         let req = req.unwrap();
@@ -621,10 +647,10 @@ mod tests {
     #[tokio::test]
     async fn stage_request_channel_dropped_doesnt_panic() {
         let (mut dispatch, mut channel, mut server_channel) = set_up();
-        let mut dispatch = Pin::new(&mut dispatch);
         let cx = &mut Context::from_waker(&noop_waker_ref());
+        let (tx, mut rx) = oneshot::channel();
 
-        let _ = send_request(&mut channel, "hi").await;
+        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
         drop(channel);
 
         assert!(dispatch.as_mut().poll(cx).is_ready());
@@ -642,61 +668,68 @@ mod tests {
     #[tokio::test]
     async fn stage_request_response_future_dropped_is_canceled_before_sending() {
         let (mut dispatch, mut channel, _server_channel) = set_up();
-        let dispatch = Pin::new(&mut dispatch);
         let cx = &mut Context::from_waker(&noop_waker_ref());
+        let (tx, mut rx) = oneshot::channel();
 
-        let _ = send_request(&mut channel, "hi").await;
+        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
 
         // Drop the channel so polling returns none if no requests are currently ready.
         drop(channel);
         // Test that a request future dropped before it's processed by dispatch will cause the request
         // to not be added to the in-flight request map.
-        assert!(dispatch.poll_next_request(cx).ready().is_none());
+        assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
     }
 
     #[tokio::test]
     async fn stage_request_response_future_dropped_is_canceled_after_sending() {
         let (mut dispatch, mut channel, _server_channel) = set_up();
         let cx = &mut Context::from_waker(&noop_waker_ref());
-        let mut dispatch = Pin::new(&mut dispatch);
+        let (tx, mut rx) = oneshot::channel();
 
-        let req = send_request(&mut channel, "hi").await;
+        let req = send_request(&mut channel, "hi", tx, &mut rx).await;
         assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
-        assert!(!dispatch.in_flight_requests().is_empty());
+        assert!(!dispatch.in_flight_requests.is_empty());
 
         // Test that a request future dropped after it's processed by dispatch will cause the request
         // to be removed from the in-flight request map.
         drop(req);
-        if let Poll::Ready(Some(_)) = dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
-            // ok
-        } else {
-            panic!("Expected request to be cancelled")
-        };
-        assert!(dispatch.in_flight_requests().is_empty());
+        assert_matches!(
+            dispatch.as_mut().poll_next_cancellation(cx),
+            Poll::Ready(Some(Ok(_)))
+        );
+        assert!(dispatch.in_flight_requests.is_empty());
     }
 
     #[tokio::test]
     async fn stage_request_response_closed_skipped() {
         let (mut dispatch, mut channel, _server_channel) = set_up();
-        let dispatch = Pin::new(&mut dispatch);
         let cx = &mut Context::from_waker(&noop_waker_ref());
+        let (tx, mut rx) = oneshot::channel();
 
         // Test that a request future that's closed its receiver but not yet canceled its request --
         // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
         // map.
-        let mut resp = send_request(&mut channel, "hi").await;
+        let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
         resp.response.close();
 
-        assert!(dispatch.poll_next_request(cx).is_pending());
+        assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
     }
 
     fn set_up() -> (
-        RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
+        Pin<
+            Box<
+                RequestDispatch<
+                    String,
+                    String,
+                    UnboundedChannel<Response<String>, ClientMessage<String>>,
+                >,
+            >,
+        >,
         Channel<String, String>,
         UnboundedChannel<ClientMessage<String>, Response<String>>,
     ) {
-        let _ = env_logger::try_init();
+        let _ = tracing_subscriber::fmt().with_test_writer().try_init();
 
         let (to_dispatch, pending_requests) = mpsc::channel(1);
         let (cancel_tx, canceled_requests) = mpsc::unbounded_channel();
@@ -717,17 +750,31 @@ mod tests {
             next_request_id: Arc::new(AtomicUsize::new(0)),
         };
 
-        (dispatch, channel, server_channel)
+        (Box::pin(dispatch), channel, server_channel)
     }
 
-    async fn send_request(
-        channel: &mut Channel<String, String>,
+    async fn send_request<'a>(
+        channel: &'a mut Channel<String, String>,
         request: &str,
-    ) -> DispatchResponse<String> {
-        channel
-            .send(context::current(), request.to_string())
-            .await
-            .unwrap()
+        response_completion: oneshot::Sender<Response<String>>,
+        response: &'a mut oneshot::Receiver<Response<String>>,
+    ) -> ResponseGuard<'a, String> {
+        let request_id =
+            u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
+        let request = DispatchRequest {
+            ctx: context::current(),
+            span: Span::current(),
+            request_id,
+            request: request.to_string(),
+            response_completion,
+        };
+        channel.to_dispatch.send(request).await.unwrap();
+
+        ResponseGuard {
+            response,
+            cancellation: &channel.cancellation,
+            request_id,
+        }
     }
 
     async fn send_response(
diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs
index cdfa63df..d4d49e12 100644
--- a/tarpc/src/client/in_flight_requests.rs
+++ b/tarpc/src/client/in_flight_requests.rs
@@ -5,7 +5,6 @@ use crate::{
 };
 use fnv::FnvHashMap;
 use futures::ready;
-use log::{debug, trace};
 use std::{
     collections::hash_map,
     io,
@@ -13,6 +12,7 @@ use std::{
 };
 use tokio::sync::oneshot;
 use tokio_util::time::delay_queue::{self, DelayQueue};
+use tracing::Span;
 
 /// Requests already written to the wire that haven't yet received responses.
 #[derive(Debug)]
@@ -33,6 +33,7 @@ impl<Resp> Default for InFlightRequests<Resp> {
 #[derive(Debug)]
 struct RequestData<Resp> {
     ctx: context::Context,
+    span: Span,
     response_completion: oneshot::Sender<Response<Resp>>,
     /// The key to remove the timer for the request's deadline.
     deadline_key: delay_queue::Key,
 }
 
@@ -59,20 +60,16 @@ impl<Resp> InFlightRequests<Resp> {
         &mut self,
         request_id: u64,
         ctx: context::Context,
+        span: Span,
         response_completion: oneshot::Sender<Response<Resp>>,
     ) -> Result<(), AlreadyExistsError> {
         match self.request_data.entry(request_id) {
             hash_map::Entry::Vacant(vacant) => {
                 let timeout = ctx.deadline.time_until();
-                trace!(
-                    "[{}] Queuing request with timeout {:?}.",
-                    ctx.trace_id(),
-                    timeout,
-                );
-
                 let deadline_key = self.deadlines.insert(request_id, timeout);
                 vacant.insert(RequestData {
                     ctx,
+                    span,
                     response_completion,
                     deadline_key,
                 });
@@ -85,15 +82,15 @@ impl<Resp> InFlightRequests<Resp> {
     /// Removes a request without aborting. Returns true iff the request was found.
     pub fn complete_request(&mut self, response: Response<Resp>) -> bool {
         if let Some(request_data) = self.request_data.remove(&response.request_id) {
+            let _entered = request_data.span.enter();
+            tracing::info!("ReceiveResponse");
             self.request_data.compact(0.1);
-
-            trace!("[{}] Received response.", request_data.ctx.trace_id());
             self.deadlines.remove(&request_data.deadline_key);
-            request_data.complete(response);
+            let _ = request_data.response_completion.send(response);
             return true;
         }
 
-        debug!(
+        tracing::debug!(
             "No in-flight request found for request_id = {}.",
             response.request_id
         );
@@ -104,12 +101,11 @@ impl<Resp> InFlightRequests<Resp> {
     /// Cancels a request without completing (typically used when a request handle was dropped
     /// before the request completed).
-    pub fn cancel_request(&mut self, request_id: u64) -> Option<context::Context> {
+    pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> {
         if let Some(request_data) = self.request_data.remove(&request_id) {
             self.request_data.compact(0.1);
-            trace!("[{}] Cancelling request.", request_data.ctx.trace_id());
             self.deadlines.remove(&request_data.deadline_key);
-            Some(request_data.ctx)
+            Some((request_data.ctx, request_data.span))
         } else {
             None
         }
@@ -122,8 +118,12 @@ impl<Resp> InFlightRequests<Resp> {
             Some(Ok(expired)) => {
                 let request_id = expired.into_inner();
                 if let Some(request_data) = self.request_data.remove(&request_id) {
+                    let _entered = request_data.span.enter();
+                    tracing::error!("DeadlineExceeded");
                     self.request_data.compact(0.1);
-                    request_data.complete(Self::deadline_exceeded_error(request_id));
+                    let _ = request_data
+                        .response_completion
+                        .send(Self::deadline_exceeded_error(request_id));
                 }
                 Some(Ok(request_id))
             }
@@ -142,21 +142,3 @@ impl<Resp> InFlightRequests<Resp> {
         }
     }
 }
-
-/// When InFlightRequests is dropped, any outstanding requests are completed with a
-/// deadline-exceeded error.
-impl<Resp> Drop for InFlightRequests<Resp> {
-    fn drop(&mut self) {
-        let deadlines = &mut self.deadlines;
-        for (_, request_data) in self.request_data.drain() {
-            let expired = deadlines.remove(&request_data.deadline_key);
-            request_data.complete(Self::deadline_exceeded_error(expired.into_inner()));
-        }
-    }
-}
-
-impl<Resp> RequestData<Resp> {
-    fn complete(self, response: Response<Resp>) {
-        let _ = self.response_completion.send(response);
-    }
-}
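Note: with the `Drop` impl gone, expiration is handled solely in `poll_expired`: the request is completed with `Self::deadline_exceeded_error(request_id)` and a `DeadlineExceeded` event is recorded in its span. On the calling side this surfaces as an `io::Error`; a sketch of handling it, assuming (as tarpc's convention for expired requests) the error maps to `io::ErrorKind::TimedOut`:

    match client.hello(context::current(), "Alice".into()).await {
        Ok(greeting) => println!("{}", greeting),
        Err(e) if e.kind() == std::io::ErrorKind::TimedOut => {
            eprintln!("deadline exceeded before a response arrived");
        }
        Err(e) => eprintln!("rpc failed: {}", e),
    }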
diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs
index 30ef1c98..b6e89d77 100644
--- a/tarpc/src/context.rs
+++ b/tarpc/src/context.rs
@@ -8,8 +8,13 @@
 //! client to server and is used by the server to enforce response deadlines.
 
 use crate::trace::{self, TraceId};
+use opentelemetry::trace::TraceContextExt;
 use static_assertions::assert_impl_all;
-use std::time::{Duration, SystemTime};
+use std::{
+    convert::TryFrom,
+    time::{Duration, SystemTime},
+};
+use tracing_opentelemetry::OpenTelemetrySpanExt;
 
 /// A request context that carries request-scoped information like deadlines and trace information.
 /// It is sent from client to server and is used by the server to enforce response deadlines.
@@ -33,10 +38,6 @@ pub struct Context {
 
 assert_impl_all!(Context: Send, Sync);
 
-tokio::task_local! {
-    static CURRENT_CONTEXT: Context;
-}
-
 fn ten_seconds_from_now() -> SystemTime {
     SystemTime::now() + Duration::from_secs(10)
 }
 
 /// Returns the context for the current request, or a default Context if no request is active.
 pub fn current() -> Context {
     Context::current()
 }
 
-impl Context {
-    /// Returns a Context containing a new root trace context and a default deadline.
-    pub fn new_root() -> Self {
-        Self {
-            deadline: ten_seconds_from_now(),
-            trace_context: trace::Context::new_root(),
-        }
+#[derive(Clone)]
+struct Deadline(SystemTime);
+
+impl Default for Deadline {
+    fn default() -> Self {
+        Self(ten_seconds_from_now())
     }
+}
 
+impl Context {
     /// Returns the context for the current request, or a default Context if no request is active.
     pub fn current() -> Self {
-        CURRENT_CONTEXT
-            .try_with(Self::clone)
-            .unwrap_or_else(|_| Self::new_root())
+        let span = tracing::Span::current();
+        Self {
+            trace_context: trace::Context::try_from(&span)
+                .unwrap_or_else(|_| trace::Context::default()),
+            deadline: span
+                .context()
+                .get::<Deadline>()
+                .cloned()
+                .unwrap_or_default()
+                .0,
+        }
     }
 
     /// Returns the ID of the request-scoped trace.
     pub fn trace_id(&self) -> &TraceId {
         &self.trace_context.trace_id
     }
-
-    /// Run a future with this context as the current context.
-    pub async fn scope<F>(self, f: F) -> F::Output
-    where
-        F: std::future::Future,
-    {
-        CURRENT_CONTEXT.scope(self, f).await
-    }
 }
 
-#[cfg(test)]
-use {
-    assert_matches::assert_matches, futures::prelude::*, futures_test::task::noop_context,
-    std::task::Poll,
-};
-
-#[test]
-fn context_current_has_no_parent() {
-    let ctx = current();
-    assert_matches!(ctx.trace_context.parent_id, None);
-}
-
-#[test]
-fn context_root_has_no_parent() {
-    let ctx = Context::new_root();
-    assert_matches!(ctx.trace_context.parent_id, None);
-}
-
-#[test]
-fn context_scope() {
-    let ctx = Context::new_root();
-    let mut ctx_copy = Box::pin(ctx.scope(async { current() }));
-    assert_matches!(ctx_copy.poll_unpin(&mut noop_context()),
-        Poll::Ready(Context {
-            deadline,
-            trace_context: trace::Context { trace_id, span_id, parent_id },
-        }) if deadline == ctx.deadline
-            && trace_id == ctx.trace_context.trace_id
-            && span_id == ctx.trace_context.span_id
-            && parent_id == ctx.trace_context.parent_id);
-}
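Note: with the task-local removed, `Context::current()` is derived entirely from the active `tracing` span — the trace context via `trace::Context::try_from(&span)`, and the deadline via the `Deadline` extension that `SpanExt::set_context` (below) stashes in the span's OpenTelemetry context. Inside a server handler that means downstream RPCs automatically inherit the caller's deadline, as in the README's 10s − 2s = 8s example. A sketch, where `backend_client` is hypothetical:

    async fn hello(self, _: context::Context, name: String) -> String {
        // Executing inside the server-side "RPC" span, so this picks up the
        // original request's trace id and remaining deadline.
        let ctx = tarpc::context::current();
        // A nested client call made with `ctx` propagates both.
        let _ = self.backend_client.hello(ctx, name.clone()).await;
        format!("Hello, {}!", name)
    }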
+/// An extension trait for [`tracing::Span`] for propagating tarpc Contexts.
+pub(crate) trait SpanExt {
+    /// Sets the given context on this span. Newly-created spans will be children of the given
+    /// context's trace context.
+    fn set_context(&self, context: &Context);
+}
+
+impl SpanExt for tracing::Span {
+    fn set_context(&self, context: &Context) {
+        self.set_parent(
+            opentelemetry::Context::new()
+                .with_remote_span_context(opentelemetry::trace::SpanContext::new(
+                    opentelemetry::trace::TraceId::from(context.trace_context.trace_id),
+                    opentelemetry::trace::SpanId::from(context.trace_context.span_id),
+                    context.trace_context.sampling_decision as u8,
+                    true,
+                    opentelemetry::trace::TraceState::default(),
+                ))
+                .with_value(Deadline(context.deadline)),
+        );
+    }
+}
diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs
index 693a0ba8..e1e9998d 100644
--- a/tarpc/src/lib.rs
+++ b/tarpc/src/lib.rs
@@ -199,6 +199,7 @@
 #![cfg_attr(docsrs, feature(doc_cfg))]
 
 #[cfg(feature = "serde1")]
+#[doc(hidden)]
 pub use serde;
 
 #[cfg(feature = "serde-transport")]
diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs
index b6dce520..e2a357ad 100644
--- a/tarpc/src/server.rs
+++ b/tarpc/src/server.rs
@@ -6,7 +6,10 @@
 //! Provides a server that concurrently handles many connections sending multiplexed requests.
 
-use crate::{context, ClientMessage, PollIo, Request, Response, Transport};
+use crate::{
+    context::{self, SpanExt},
+    trace, ClientMessage, PollIo, Request, Response, Transport,
+};
 use futures::{
     future::{AbortRegistration, Abortable},
     prelude::*,
@@ -14,12 +17,11 @@ use futures::{
     stream::Fuse,
     task::*,
 };
-use humantime::format_rfc3339;
 use in_flight_requests::{AlreadyExistsError, InFlightRequests};
-use log::{debug, info, trace};
 use pin_project::pin_project;
-use std::{fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
+use std::{convert::TryFrom, fmt, hash::Hash, io, marker::PhantomData, pin::Pin, time::SystemTime};
 use tokio::sync::mpsc;
+use tracing::{info_span, instrument::Instrument, Span};
 
 mod filter;
 mod in_flight_requests;
@@ -67,6 +69,11 @@ pub trait Serve<Req> {
     /// Type of response future.
     type Fut: Future<Output = Self::Resp>;
 
+    /// Extracts a method name from the request.
+    fn method(&self, _request: &Req) -> Option<&'static str> {
+        None
+    }
+
     /// Responds to a single request.
     fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
 }
@@ -219,12 +226,18 @@ where
     /// Type of response sink item.
     type Resp;
 
+    /// The wrapped transport.
+    type Transport;
+
     /// Configuration of the channel.
     fn config(&self) -> &Config;
 
     /// Returns the number of in-flight requests over this channel.
     fn in_flight_requests(&self) -> usize;
 
+    /// Returns the transport underlying the channel.
+    fn transport(&self) -> &Self::Transport;
+
     /// Caps the number of concurrent requests to `limit`.
     fn max_concurrent_requests(self, limit: usize) -> Throttler<Self>
     where
         Self: Sized,
@@ -240,6 +253,7 @@ where
         self: Pin<&mut Self>,
         id: u64,
         deadline: SystemTime,
+        span: Span,
     ) -> Result<AbortRegistration, AlreadyExistsError>;
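Note: `method` ships with a default body returning `None`, so hand-written `Serve` impls keep compiling; the macro-generated impls (plugins/src/lib.rs above) override it with `"Service.method"` names that the server records as `otel.name`. A hand-rolled impl might look like this (a sketch; `Echo` is illustrative):

    use futures::future::{self, Ready};
    use tarpc::{context, server::Serve};

    #[derive(Clone)]
    struct Echo;

    impl Serve<String> for Echo {
        type Resp = String;
        type Fut = Ready<String>;

        fn method(&self, _req: &String) -> Option<&'static str> {
            Some("Echo.echo")
        }

        fn serve(self, _ctx: context::Context, req: String) -> Self::Fut {
            future::ready(req)
        }
    }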
- debug!("Request {} did not complete before deadline", request_id); - Ready - } + // No need to send a response, since the client wouldn't be waiting for one + // anymore. + Poll::Ready(Some(_)) => Ready, Poll::Ready(None) => Closed, Poll::Pending => Pending, }; @@ -313,18 +324,10 @@ where trace_context, request_id, } => { - if self.in_flight_requests_mut().cancel_request(request_id) { - let remaining = self.in_flight_requests.len(); - trace!( - "[{}] Request canceled. In-flight requests = {}", - trace_context.trace_id, - remaining, - ); - } else { - trace!( - "[{}] Received cancellation, but response handler \ - is already complete.", - trace_context.trace_id, + if !self.in_flight_requests_mut().cancel_request(request_id) { + tracing::trace!( + rpc.trace_id = %trace_context.trace_id, + "Received cancellation, but response handler is already complete.", ); } Ready @@ -354,11 +357,19 @@ where } fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { - self.as_mut() + if let Some(span) = self + .as_mut() .project() .in_flight_requests - .remove_request(response.request_id); - self.project().transport.start_send(response) + .remove_request(response.request_id) + { + let _entered = span.enter(); + tracing::info!("SendResponse"); + self.project().transport.start_send(response) + } else { + // If the request isn't tracked anymore, there's no need to send the response. + Ok(()) + } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -382,6 +393,7 @@ where { type Req = Req; type Resp = Resp; + type Transport = T; fn config(&self) -> &Config { &self.config @@ -391,14 +403,19 @@ where self.in_flight_requests.len() } + fn transport(&self) -> &Self::Transport { + self.get_ref() + } + fn start_request( self: Pin<&mut Self>, id: u64, deadline: SystemTime, + span: Span, ) -> Result { self.project() .in_flight_requests - .start_request(id, deadline) + .start_request(id, deadline, span) } } @@ -412,9 +429,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver<(context::Context, Response)>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender<(context::Context, Response)>, + responses_tx: mpsc::Sender>, } impl Requests @@ -429,7 +446,7 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver<(context::Context, Response)> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } @@ -439,33 +456,45 @@ where ) -> PollIo> { loop { match ready!(self.channel_pin_mut().poll_next(cx)?) { - Some(request) => { - trace!( - "[{}] Handling request with deadline {}.", - request.context.trace_id(), - format_rfc3339(request.context.deadline), + Some(mut request) => { + let span = info_span!( + "RPC", + rpc.trace_id = %request.context.trace_id(), + otel.kind = "server", + otel.name = tracing::field::Empty, ); - - match self - .channel_pin_mut() - .start_request(request.id, request.context.deadline) - { + span.set_context(&request.context); + request.context.trace_context = + trace::Context::try_from(&span).unwrap_or_else(|_| { + tracing::trace!( + "OpenTelemetry subscriber not installed; making unsampled + child context." 
+                            );
+                            request.context.trace_context.new_child()
+                        });
+                    let entered = span.enter();
+                    tracing::info!("ReceiveRequest");
+                    let start = self.channel_pin_mut().start_request(
+                        request.id,
+                        request.context.deadline,
+                        span.clone(),
+                    );
+                    match start {
                         Ok(abort_registration) => {
+                            let response_tx = self.responses_tx.clone();
+                            drop(entered);
                             return Poll::Ready(Some(Ok(InFlightRequest {
                                 request,
-                                response_tx: self.responses_tx.clone(),
+                                response_tx,
                                 abort_registration,
-                            })))
+                                span,
+                            })));
                         }
                         // Instead of closing the channel if a duplicate request is sent, just
                         // ignore it, since it's already being processed. Note that we cannot
                         // return Poll::Pending here, since nothing has scheduled a wakeup yet.
                         Err(AlreadyExistsError) => {
-                            info!(
-                                "[{}] Request ID {} delivered more than once.",
-                                request.context.trace_id(),
-                                request.id
-                            );
+                            tracing::trace!("DuplicateRequest");
                             continue;
                         }
                     }
@@ -481,12 +510,7 @@ where
         read_half_closed: bool,
     ) -> PollIo<()> {
         match self.as_mut().poll_next_response(cx)? {
-            Poll::Ready(Some((context, response))) => {
-                trace!(
-                    "[{}] Staging response. In-flight requests = {}.",
-                    context.trace_id(),
-                    self.channel.in_flight_requests(),
-                );
+            Poll::Ready(Some(response)) => {
                 // A Ready result from poll_next_response means the Channel is ready to be written
                 // to. Therefore, we can call start_send without worry of a full buffer.
                 self.channel_pin_mut().start_send(response)?;
@@ -520,7 +544,7 @@ where
     fn poll_next_response(
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
-    ) -> PollIo<(context::Context, Response<C::Resp>)> {
+    ) -> PollIo<Response<C::Resp>> {
         ready!(self.ensure_writeable(cx)?);
 
         match ready!(self.pending_responses_mut().poll_recv(cx)) {
@@ -555,8 +579,9 @@ where
 #[derive(Debug)]
 pub struct InFlightRequest<Req, Resp> {
     request: Request<Req>,
-    response_tx: mpsc::Sender<(context::Context, Response<Resp>)>,
+    response_tx: mpsc::Sender<Response<Resp>>,
     abort_registration: AbortRegistration,
+    span: Span,
 }
 
 impl<Req, Resp> InFlightRequest<Req, Resp> {
     /// message](ClientMessage::Cancel) for this request.
     /// 2. The request [deadline](crate::context::Context::deadline) is reached.
     /// 3. The service function completes.
- pub fn execute(self, serve: S) -> impl Future + pub async fn execute(self, serve: S) where S: Serve, { let Self { abort_registration, - request, - response_tx, - } = self; - Abortable::new( - async move { - let Request { + request: + Request { context, message, id: request_id, - } = request; - context - .scope(async { - let response = serve.serve(context, message).await; - let response = Response { - request_id, - message: Ok(response), - }; - let _ = response_tx.send((context, response)).await; - }) - .await; + }, + response_tx, + span, + } = self; + let method = serve.method(&message); + span.record("otel.name", &method.unwrap_or("")); + let _ = Abortable::new( + async move { + let response = serve.serve(context, message).await; + tracing::info!("CompleteRequest"); + let response = Response { + request_id, + message: Ok(response), + }; + let _ = response_tx.send(response).await; + tracing::info!("BufferResponse"); }, abort_registration, ) - .unwrap_or_else(|_| {}) + .instrument(span) + .await; } } @@ -711,7 +738,7 @@ where while let Some(channel) = ready!(self.inner_pin_mut().poll_next(cx)) { tokio::spawn(channel.execute(self.serve.clone())); } - info!("Server shutting down."); + tracing::info!("Server shutting down."); Poll::Ready(()) } } @@ -737,7 +764,7 @@ where }); } Err(e) => { - info!("Requests stream errored out: {}", e); + tracing::info!("Requests stream errored out: {}", e); break; } } @@ -823,10 +850,12 @@ mod tests { channel .as_mut() - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); assert_matches!( - channel.as_mut().start_request(0, SystemTime::now()), + channel + .as_mut() + .start_request(0, SystemTime::now(), Span::current()), Err(AlreadyExistsError) ); } @@ -838,11 +867,11 @@ mod tests { tokio::time::pause(); let abort_registration0 = channel .as_mut() - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let abort_registration1 = channel .as_mut() - .start_request(1, SystemTime::now()) + .start_request(1, SystemTime::now(), Span::current()) .unwrap(); tokio::time::advance(std::time::Duration::from_secs(1000)).await; @@ -861,7 +890,11 @@ mod tests { tokio::time::pause(); let abort_registration = channel .as_mut() - .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .start_request( + 0, + SystemTime::now() + Duration::from_millis(100), + Span::current(), + ) .unwrap(); tx.send(ClientMessage::Cancel { @@ -886,7 +919,11 @@ mod tests { tokio::time::pause(); let _abort_registration = channel .as_mut() - .start_request(0, SystemTime::now() + Duration::from_millis(100)) + .start_request( + 0, + SystemTime::now() + Duration::from_millis(100), + Span::current(), + ) .unwrap(); drop(tx); @@ -924,7 +961,7 @@ mod tests { tokio::time::pause(); let abort_registration = channel .as_mut() - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); tokio::time::advance(std::time::Duration::from_secs(1000)).await; @@ -943,7 +980,7 @@ mod tests { channel .as_mut() - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); assert_eq!(channel.in_flight_requests(), 1); channel @@ -961,6 +998,11 @@ mod tests { let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); // Response written to the transport. 
+ requests + .as_mut() + .channel_pin_mut() + .start_request(0, SystemTime::now(), Span::current()) + .unwrap(); requests .as_mut() .channel_pin_mut() @@ -975,20 +1017,17 @@ mod tests { .as_mut() .project() .responses_tx - .send(( - context::current(), - Response { - request_id: 1, - message: Ok(()), - }, - )) + .send(Response { + request_id: 1, + message: Ok(()), + }) .await .unwrap(); requests .as_mut() .channel_pin_mut() - .start_request(1, SystemTime::now()) + .start_request(1, SystemTime::now(), Span::current()) .unwrap(); assert_matches!( @@ -1002,6 +1041,11 @@ mod tests { let (mut requests, _tx) = test_bounded_requests::<(), ()>(0); // Response written to the transport. + requests + .as_mut() + .channel_pin_mut() + .start_request(0, SystemTime::now(), Span::current()) + .unwrap(); requests .as_mut() .channel_pin_mut() @@ -1014,22 +1058,18 @@ mod tests { // Response waiting to be written. requests .as_mut() - .project() - .responses_tx - .send(( - context::current(), - Response { - request_id: 1, - message: Ok(()), - }, - )) - .await + .channel_pin_mut() + .start_request(1, SystemTime::now(), Span::current()) .unwrap(); - requests .as_mut() - .channel_pin_mut() - .start_request(1, SystemTime::now()) + .project() + .responses_tx + .send(Response { + request_id: 1, + message: Ok(()), + }) + .await .unwrap(); assert_matches!( diff --git a/tarpc/src/server/filter.rs b/tarpc/src/server/filter.rs index 18935e26..b5b5a958 100644 --- a/tarpc/src/server/filter.rs +++ b/tarpc/src/server/filter.rs @@ -10,7 +10,6 @@ use crate::{ }; use fnv::FnvHashMap; use futures::{future::AbortRegistration, prelude::*, ready, stream::Fuse, task::*}; -use log::{debug, info, trace}; use pin_project::pin_project; use std::sync::{Arc, Weak}; use std::{ @@ -18,6 +17,7 @@ use std::{ time::SystemTime, }; use tokio::sync::mpsc; +use tracing::{debug, info, trace}; /// A single-threaded filter that drops channels based on per-key limits. 
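Reviewer note on the filter below: a channel stays counted for its key exactly as long as it holds its `Arc` tracker, so the number of live channels per key is just the tracker's strong count. A condensed sketch of that idea, with assumed names rather than the crate's implementation:

```rust
// Per-key channel counting via Arc strong counts. Illustrative only.
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::{Arc, Weak};

struct KeyCounts<K> {
    max_per_key: usize,
    counts: HashMap<K, Weak<K>>,
}

impl<K: Hash + Eq + Clone> KeyCounts<K> {
    /// Returns a tracker for the new channel to hold, or None at the limit.
    fn try_open(&mut self, key: K) -> Option<Arc<K>> {
        if let Some(weak) = self.counts.get(&key) {
            if weak.strong_count() >= self.max_per_key {
                return None; // at the per-key limit
            }
            if let Some(tracker) = weak.upgrade() {
                return Some(tracker);
            }
            // Every channel for this key was just dropped; fall through and
            // start a fresh tracker.
        }
        let tracker = Arc::new(key.clone());
        self.counts.insert(key, Arc::downgrade(&tracker));
        Some(tracker)
    }
}
```

When the last clone of a tracker drops, the real code notifies the filter over a channel so the stale map entry can be removed; the sketch elides that cleanup.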
#[pin_project] @@ -103,6 +103,7 @@ where { type Req = C::Req; type Resp = C::Resp; + type Transport = C::Transport; fn config(&self) -> &server::Config { self.inner.config() @@ -112,12 +113,17 @@ where self.inner.in_flight_requests() } + fn transport(&self) -> &Self::Transport { + self.inner.transport() + } + fn start_request( mut self: Pin<&mut Self>, id: u64, deadline: SystemTime, + span: tracing::Span, ) -> Result { - self.inner_pin_mut().start_request(id, deadline) + self.inner_pin_mut().start_request(id, deadline, span) } } @@ -171,11 +177,10 @@ where let tracker = self.as_mut().increment_channels_for_key(key.clone())?; trace!( - "[{}] Opening channel ({}/{}) channels for key.", - key, - Arc::strong_count(&tracker), - self.channels_per_key - ); + channel_filter_key = %key, + open_channels = Arc::strong_count(&tracker), + max_open_channels = self.channels_per_key, + "Opening channel"); Ok(TrackedChannel { tracker, @@ -200,9 +205,10 @@ where let count = o.get().strong_count(); if count >= TryFrom::try_from(*self_.channels_per_key).unwrap() { info!( - "[{}] Opened max channels from key ({}/{}).", - key, count, self_.channels_per_key - ); + channel_filter_key = %key, + open_channels = count, + max_open_channels = *self_.channels_per_key, + "At open channel limit"); Err(key) } else { Ok(o.get().upgrade().unwrap_or_else(|| { @@ -233,7 +239,9 @@ where let self_ = self.project(); match ready!(self_.dropped_keys.poll_recv(cx)) { Some(key) => { - debug!("All channels dropped for key [{}]", key); + debug!( + channel_filter_key = %key, + "All channels dropped"); self_.key_counts.remove(&key); self_.key_counts.compact(0.1); Poll::Ready(()) diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index 3399e262..5c251621 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -14,6 +14,7 @@ use std::{ time::SystemTime, }; use tokio_util::time::delay_queue::{self, DelayQueue}; +use tracing::Span; /// A data structure that tracks in-flight requests. It aborts requests, /// either on demand or when a request deadline expires. @@ -23,13 +24,15 @@ pub struct InFlightRequests { deadlines: DelayQueue, } -#[derive(Debug)] /// Data needed to clean up a single in-flight request. +#[derive(Debug)] struct RequestData { /// Aborts the response handler for the associated request. abort_handle: AbortHandle, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, + /// The client span. + span: Span, } /// An error returned when a request attempted to start with the same ID as a request already @@ -48,6 +51,7 @@ impl InFlightRequests { &mut self, request_id: u64, deadline: SystemTime, + span: Span, ) -> Result { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { @@ -57,6 +61,7 @@ impl InFlightRequests { vacant.insert(RequestData { abort_handle, deadline_key, + span, }); Ok(abort_registration) } @@ -66,12 +71,17 @@ impl InFlightRequests { /// Cancels an in-flight request. Returns true iff the request was found. 
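Reviewer note: `in_flight_requests.rs` pairs one `AbortHandle` with one `DelayQueue` entry per request, so cancellation and deadline expiry both resolve to an `abort()`. Stripped of the span and compaction details it carries here, the mechanism is roughly this (illustrative names, not the crate's):

```rust
// Deadline bookkeeping with AbortHandle + DelayQueue. Illustrative only.
use futures::future::{AbortHandle, AbortRegistration};
use std::collections::HashMap;
use std::time::{Duration, SystemTime};
use tokio_util::time::delay_queue::{self, DelayQueue};

struct Deadlines {
    handles: HashMap<u64, (AbortHandle, delay_queue::Key)>,
    deadlines: DelayQueue<u64>,
}

impl Deadlines {
    /// Registers a request; the returned registration wraps its handler.
    fn start(&mut self, request_id: u64, deadline: SystemTime) -> AbortRegistration {
        let timeout = deadline
            .duration_since(SystemTime::now())
            .unwrap_or(Duration::from_secs(0));
        let (handle, registration) = AbortHandle::new_pair();
        let key = self.deadlines.insert(request_id, timeout);
        self.handles.insert(request_id, (handle, key));
        registration
    }

    /// Called when the DelayQueue yields an expired request_id; the queue
    /// entry is already gone, so only the handler needs aborting.
    fn expire(&mut self, request_id: u64) {
        if let Some((handle, _key)) = self.handles.remove(&request_id) {
            handle.abort();
        }
    }
}
```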
pub fn cancel_request(&mut self, request_id: u64) -> bool { - if let Some(request_data) = self.request_data.remove(&request_id) { + if let Some(RequestData { + span, + abort_handle, + deadline_key, + }) = self.request_data.remove(&request_id) + { + let _entered = span.enter(); self.request_data.compact(0.1); - - request_data.abort_handle.abort(); - self.deadlines.remove(&request_data.deadline_key); - + abort_handle.abort(); + self.deadlines.remove(&deadline_key); + tracing::info!("ReceiveCancel"); true } else { false @@ -80,15 +90,13 @@ impl InFlightRequests { /// Removes a request without aborting. Returns true iff the request was found. /// This method should be used when a response is being sent. - pub fn remove_request(&mut self, request_id: u64) -> bool { + pub fn remove_request(&mut self, request_id: u64) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); - self.deadlines.remove(&request_data.deadline_key); - - true + Some(request_data.span) } else { - false + None } } @@ -96,9 +104,14 @@ impl InFlightRequests { pub fn poll_expired(&mut self, cx: &mut Context) -> PollIo { Poll::Ready(match ready!(self.deadlines.poll_expired(cx)) { Some(Ok(expired)) => { - if let Some(request_data) = self.request_data.remove(expired.get_ref()) { + if let Some(RequestData { + abort_handle, span, .. + }) = self.request_data.remove(expired.get_ref()) + { + let _entered = span.enter(); self.request_data.compact(0.1); - request_data.abort_handle.abort(); + abort_handle.abort(); + tracing::error!("DeadlineExceeded"); } Some(Ok(expired.into_inner())) } @@ -133,7 +146,7 @@ mod tests { let mut in_flight_requests = InFlightRequests::default(); assert_eq!(in_flight_requests.len(), 0); in_flight_requests - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); assert_eq!(in_flight_requests.len(), 1); } @@ -142,7 +155,7 @@ mod tests { async fn polling_expired_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -164,7 +177,7 @@ mod tests { async fn cancel_request_aborts() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); @@ -180,11 +193,11 @@ mod tests { async fn remove_request_doesnt_abort() { let mut in_flight_requests = InFlightRequests::default(); let abort_registration = in_flight_requests - .start_request(0, SystemTime::now()) + .start_request(0, SystemTime::now(), Span::current()) .unwrap(); let mut abortable_future = Box::new(Abortable::new(pending::<()>(), abort_registration)); - assert_eq!(in_flight_requests.remove_request(0), true); + assert_matches!(in_flight_requests.remove_request(0), Some(_)); assert_matches!( abortable_future.poll_unpin(&mut noop_context()), Poll::Pending diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index 97ba49f3..94ffce38 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -11,10 +11,8 @@ use crate::{ }; use futures::{future::AbortRegistration, task::*, Sink, Stream}; use pin_project::pin_project; -use 
std::collections::VecDeque; -use std::io; -use std::pin::Pin; -use std::time::SystemTime; +use std::{collections::VecDeque, io, pin::Pin, time::SystemTime}; +use tracing::Span; #[pin_project] pub(crate) struct FakeChannel { @@ -70,6 +68,7 @@ where { type Req = Req; type Resp = Resp; + type Transport = (); fn config(&self) -> &Config { &self.config @@ -79,14 +78,19 @@ where self.in_flight_requests.len() } + fn transport(&self) -> &() { + &() + } + fn start_request( self: Pin<&mut Self>, id: u64, deadline: SystemTime, + span: Span, ) -> Result { self.project() .in_flight_requests - .start_request(id, deadline) + .start_request(id, deadline, span) } } diff --git a/tarpc/src/server/throttle.rs b/tarpc/src/server/throttle.rs index 04181b21..0b68ba12 100644 --- a/tarpc/src/server/throttle.rs +++ b/tarpc/src/server/throttle.rs @@ -7,9 +7,9 @@ use super::{Channel, Config}; use crate::{Response, ServerError}; use futures::{future::AbortRegistration, prelude::*, ready, task::*}; -use log::debug; use pin_project::pin_project; use std::{io, pin::Pin, time::SystemTime}; +use tracing::Span; /// A [`Channel`] that limits the number of concurrent /// requests by throttling. @@ -55,11 +55,11 @@ where match ready!(self.as_mut().project().inner.poll_next(cx)?) { Some(request) => { - debug!( - "[{}] Client has reached in-flight request limit ({}/{}).", - request.context.trace_id(), - self.as_mut().in_flight_requests(), - self.as_mut().project().max_in_flight_requests, + tracing::debug!( + rpc.trace_id = %request.context.trace_id(), + in_flight_requests = self.as_mut().in_flight_requests(), + max_in_flight_requests = *self.as_mut().project().max_in_flight_requests, + "At in-flight request limit", ); self.as_mut().start_send(Response { @@ -112,6 +112,7 @@ where { type Req = ::Req; type Resp = ::Resp; + type Transport = ::Transport; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -121,12 +122,17 @@ where self.inner.config() } + fn transport(&self) -> &Self::Transport { + self.inner.transport() + } + fn start_request( self: Pin<&mut Self>, id: u64, deadline: SystemTime, + span: Span, ) -> Result { - self.project().inner.start_request(id, deadline) + self.project().inner.start_request(id, deadline, span) } } @@ -196,7 +202,11 @@ mod tests { throttler .inner .in_flight_requests - .start_request(i, SystemTime::now() + Duration::from_secs(1)) + .start_request( + i, + SystemTime::now() + Duration::from_secs(1), + Span::current(), + ) .unwrap(); } assert_eq!(throttler.as_mut().in_flight_requests(), 5); @@ -212,7 +222,11 @@ mod tests { pin_mut!(throttler); throttler .as_mut() - .start_request(1, SystemTime::now() + Duration::from_secs(1)) + .start_request( + 1, + SystemTime::now() + Duration::from_secs(1), + Span::current(), + ) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 1); } @@ -305,16 +319,21 @@ mod tests { impl Channel for PendingSink>, Response> { type Req = Req; type Resp = Resp; + type Transport = (); fn config(&self) -> &Config { unimplemented!() } fn in_flight_requests(&self) -> usize { 0 } + fn transport(&self) -> &() { + &() + } fn start_request( self: Pin<&mut Self>, _id: u64, _deadline: SystemTime, + _span: tracing::Span, ) -> Result { unimplemented!() } @@ -332,7 +351,11 @@ mod tests { throttler .inner .in_flight_requests - .start_request(0, SystemTime::now() + Duration::from_secs(1)) + .start_request( + 0, + SystemTime::now() + Duration::from_secs(1), + Span::current(), + ) .unwrap(); throttler .as_mut() diff --git a/tarpc/src/trace.rs b/tarpc/src/trace.rs 
index d889b283..b60326c0 100644
--- a/tarpc/src/trace.rs
+++ b/tarpc/src/trace.rs
@@ -16,11 +16,14 @@
 //! This crate's design is based on [opencensus
 //! tracing](https://opencensus.io/core-concepts/tracing/).

+use opentelemetry::trace::TraceContextExt;
 use rand::Rng;
 use std::{
+    convert::TryFrom,
     fmt::{self, Formatter},
-    mem,
+    num::{NonZeroU128, NonZeroU64},
 };
+use tracing_opentelemetry::OpenTelemetrySpanExt;

 /// A context for tracing the execution of processes, distributed or otherwise.
 ///
@@ -36,33 +39,50 @@ pub struct Context {
     /// before making an RPC, and the span ID is sent to the server. The server is free to create
     /// its own spans, for which it sets the client's span as the parent span.
     pub span_id: SpanId,
-    /// An identifier of the span that originated the current span. For example, if a server sends
-    /// an RPC in response to a client request that included a span, the server would create a span
-    /// for the RPC and set its parent to the span_id in the incoming request's context.
-    ///
-    /// If `parent_id` is `None`, then this is a root context.
-    pub parent_id: Option<SpanId>,
+    /// Indicates whether a sampler has already decided whether or not to sample the trace
+    /// associated with the Context. If `sampling_decision` is None, then a decision has not yet
+    /// been made. Downstream samplers do not need to abide by "no sample" decisions--for example,
+    /// an upstream client may choose to never sample, which may not make sense for the client's
+    /// dependencies. On the other hand, if an upstream process has chosen to sample this trace,
+    /// then the downstream samplers are expected to respect that decision and also sample the
+    /// trace. Otherwise, the full trace would not be able to be reconstructed.
+    pub sampling_decision: SamplingDecision,
 }

 /// A 128-bit UUID identifying a trace. All spans caused by the same originating span share the
 /// same trace ID.
-#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
 #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
 pub struct TraceId(#[cfg_attr(feature = "serde1", serde(with = "u128_serde"))] u128);

 /// A 64-bit identifier of a span within a trace. The identifier is unique within the span's trace.
-#[derive(Debug, Default, PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
 #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
 pub struct SpanId(u64);

+/// Indicates whether a sampler has decided whether or not to sample the trace associated with the
+/// Context. Downstream samplers do not need to abide by "no sample" decisions--for example, an
+/// upstream client may choose to never sample, which may not make sense for the client's
+/// dependencies. On the other hand, if an upstream process has chosen to sample this trace, then
+/// the downstream samplers are expected to respect that decision and also sample the trace.
+/// Otherwise, the full trace would not be able to be reconstructed reliably.
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub enum SamplingDecision {
+    /// The associated span was sampled by its creating process. Child spans must also be sampled.
+    Sampled = opentelemetry::trace::TRACE_FLAG_SAMPLED,
+    /// The associated span was not sampled by its creating process.
+    Unsampled = opentelemetry::trace::TRACE_FLAG_NOT_SAMPLED,
+}
+
 impl Context {
-    /// Constructs a new root context. A root context is one with no parent span.
-    pub fn new_root() -> Self {
-        let rng = &mut rand::thread_rng();
-        Context {
-            trace_id: TraceId::random(rng),
-            span_id: SpanId::random(rng),
-            parent_id: None,
+    /// Constructs a new context with the trace ID and sampling decision inherited from the parent.
+    pub(crate) fn new_child(&self) -> Self {
+        Self {
+            trace_id: self.trace_id,
+            span_id: SpanId::random(&mut rand::thread_rng()),
+            sampling_decision: self.sampling_decision,
         }
     }
 }
@@ -71,17 +91,119 @@ impl TraceId {
     /// Returns a random trace ID that can be assumed to be globally unique if `rng` generates
     /// actually-random numbers.
     pub fn random<R: Rng>(rng: &mut R) -> Self {
-        TraceId(u128::from(rng.next_u64()) << mem::size_of::<u64>() | u128::from(rng.next_u64()))
+        TraceId(rng.gen::<NonZeroU128>().get())
+    }
+
+    /// Returns true iff the trace ID is 0.
+    pub fn is_none(&self) -> bool {
+        self.0 == 0
     }
 }

 impl SpanId {
     /// Returns a random span ID that can be assumed to be unique within a single trace.
     pub fn random<R: Rng>(rng: &mut R) -> Self {
-        SpanId(rng.next_u64())
+        SpanId(rng.gen::<NonZeroU64>().get())
+    }
+
+    /// Returns true iff the span ID is 0.
+    pub fn is_none(&self) -> bool {
+        self.0 == 0
+    }
+}
+
+impl From<TraceId> for u128 {
+    fn from(trace_id: TraceId) -> Self {
+        trace_id.0
+    }
+}
+
+impl From<u128> for TraceId {
+    fn from(trace_id: u128) -> Self {
+        Self(trace_id)
+    }
+}
+
+impl From<SpanId> for u64 {
+    fn from(span_id: SpanId) -> Self {
+        span_id.0
+    }
+}
+
+impl From<u64> for SpanId {
+    fn from(span_id: u64) -> Self {
+        Self(span_id)
+    }
+}
+
+impl From<opentelemetry::trace::TraceId> for TraceId {
+    fn from(trace_id: opentelemetry::trace::TraceId) -> Self {
+        Self::from(trace_id.to_u128())
+    }
+}
+
+impl From<TraceId> for opentelemetry::trace::TraceId {
+    fn from(trace_id: TraceId) -> Self {
+        Self::from_u128(trace_id.into())
+    }
+}
+
+impl From<opentelemetry::trace::SpanId> for SpanId {
+    fn from(span_id: opentelemetry::trace::SpanId) -> Self {
+        Self::from(span_id.to_u64())
+    }
+}
+
+impl From<SpanId> for opentelemetry::trace::SpanId {
+    fn from(span_id: SpanId) -> Self {
+        Self::from_u64(span_id.0)
+    }
+}
+
+impl TryFrom<&tracing::Span> for Context {
+    type Error = NoActiveSpan;
+
+    fn try_from(span: &tracing::Span) -> Result<Self, Self::Error> {
+        let context = span.context();
+        if context.has_active_span() {
+            Ok(Self::from(context.span()))
+        } else {
+            Err(NoActiveSpan)
+        }
+    }
+}
+
+impl From<&dyn opentelemetry::trace::Span> for Context {
+    fn from(span: &dyn opentelemetry::trace::Span) -> Self {
+        let otel_ctx = span.span_context();
+        Self {
+            trace_id: TraceId::from(otel_ctx.trace_id()),
+            span_id: SpanId::from(otel_ctx.span_id()),
+            sampling_decision: SamplingDecision::from(otel_ctx),
+        }
+    }
+}
+
+impl From<&opentelemetry::trace::SpanContext> for SamplingDecision {
+    fn from(context: &opentelemetry::trace::SpanContext) -> Self {
+        if context.is_sampled() {
+            SamplingDecision::Sampled
+        } else {
+            SamplingDecision::Unsampled
+        }
     }
 }

+impl Default for SamplingDecision {
+    fn default() -> Self {
+        Self::Unsampled
+    }
+}
+
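Reviewer note: with the conversions above, tarpc's IDs round-trip losslessly through the OpenTelemetry ID types, which is what allows a `Context` to be rebuilt from whatever span the subscriber is currently recording. A quick usage sketch, assuming the impls land as written:

```rust
// Round-tripping tarpc trace IDs through the opentelemetry ID types.
use tarpc::trace::{SpanId, TraceId};

fn main() {
    let mut rng = rand::thread_rng();

    let trace_id = TraceId::random(&mut rng);
    let otel: opentelemetry::trace::TraceId = trace_id.into();
    assert_eq!(TraceId::from(otel), trace_id);

    let span_id = SpanId::random(&mut rng);
    let otel: opentelemetry::trace::SpanId = span_id.into();
    assert_eq!(SpanId::from(otel), span_id);
}
```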
+/// Returned when a [`Context`] cannot be constructed from a [`Span`](tracing::Span).
+#[derive(Debug)]
+pub struct NoActiveSpan;
+
 impl fmt::Display for TraceId {
     fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
         write!(f, "{:02x}", self.0)?;
@@ -89,6 +211,13 @@ impl fmt::Display for TraceId {
     }
 }

+impl fmt::Debug for TraceId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
+
 impl fmt::Display for SpanId {
     fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
         write!(f, "{:02x}", self.0)?;
@@ -96,6 +225,13 @@ impl fmt::Display for SpanId {
     }
 }

+impl fmt::Debug for SpanId {
+    fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> {
+        write!(f, "{:02x}", self.0)?;
+        Ok(())
+    }
+}
+
 #[cfg(feature = "serde1")]
 mod u128_serde {
     pub fn serialize<S>(u: &u128, serializer: S) -> Result<S::Ok, S::Error>
diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs
index 8456b884..2494a40f 100644
--- a/tarpc/src/transport/channel.rs
+++ b/tarpc/src/transport/channel.rs
@@ -152,12 +152,12 @@ mod tests {
     };
     use assert_matches::assert_matches;
     use futures::{prelude::*, stream};
-    use log::trace;
     use std::io;
+    use tracing::trace;

     #[tokio::test]
     async fn integration() -> io::Result<()> {
-        let _ = env_logger::try_init();
+        let _ = tracing_subscriber::fmt::try_init();

         let (client_channel, server_channel) = transport::channel::unbounded();
         tokio::spawn(
@@ -175,8 +175,8 @@ mod tests {

         let client = client::new(client::Config::default(), client_channel).spawn()?;

-        let response1 = client.call(context::current(), "123".into()).await?;
-        let response2 = client.call(context::current(), "abc".into()).await?;
+        let response1 = client.call(context::current(), "", "123".into()).await?;
+        let response2 = client.call(context::current(), "", "abc".into()).await?;

         trace!("response1: {:?}, response2: {:?}", response1, response2);

diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs
index 0fc16eb4..eafa0c0d 100644
--- a/tarpc/tests/service_functional.rs
+++ b/tarpc/tests/service_functional.rs
@@ -40,7 +40,7 @@ impl Service for Server {

 #[tokio::test]
 async fn sequential() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();

@@ -82,7 +82,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> io::Result<()> {
         }
     }

-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();

@@ -117,7 +117,7 @@ async fn serde() -> io::Result<()> {
     use tarpc::serde_transport;
     use tokio_serde::formats::Json;

-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let transport = tarpc::serde_transport::tcp::listen("localhost:56789", Json::default).await?;
     let addr = transport.local_addr();
@@ -143,7 +143,7 @@ async fn serde() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -167,7 +167,7 @@ async fn concurrent() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent_join() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(
@@ -192,7 +192,7 @@ async fn concurrent_join() -> io::Result<()> {

 #[tokio::test]
 async fn concurrent_join_all() -> io::Result<()> {
-    let _ = env_logger::try_init();
+    let _ = tracing_subscriber::fmt::try_init();

     let (tx, rx) = channel::unbounded();
     tokio::spawn(
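Reviewer note: the tests now call `tracing_subscriber::fmt::try_init()` where they previously used `env_logger::try_init()`. For local debugging, a slightly richer initializer might look like this (assumes tracing-subscriber 0.2 with its env-filter feature enabled; not part of this diff):

```rust
// RUST_LOG-style filtering plus an event whenever a span closes, so the
// per-RPC spans introduced by this change show up with their timings.
use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter};

fn init_test_tracing() {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .with_span_events(FmtSpan::CLOSE)
        .try_init();
}
```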