diff --git a/Cargo.toml b/Cargo.toml index e8d6b8a..6e27a87 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,18 @@ stderrlog = "0.6.0" structopt = "0.3.26" tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] } tonic = "0.12.2" +opentelemetry_sdk = { version = "*", features = ["rt-tokio"] } +opentelemetry-otlp = { version = "*", features = ["grpc-tonic"] } +opentelemetry = "0.27.1" +tonic-tracing-opentelemetry = "0.24.3" +tower = "0.5.2" +tracing-opentelemetry-instrumentation-sdk = "0.24.1" +http = "1.2.0" +axum-tracing-opentelemetry = "0.25.0" +opentelemetry-stdout = "0.27.0" +tracing-subscriber = { version="0.3.19", features = ["fmt", "env-filter"]} +tracing-opentelemetry = "0.28.0" +tracing = "0.1.41" [build-dependencies] tonic-build = "0.12.2" diff --git a/src/lib.rs b/src/lib.rs index d9a124b..97435ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,17 +15,18 @@ use std::env; use std::sync::Arc; use anyhow::Result; +use log::info; use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio::task::JoinHandle; -use tonic::transport::Channel; use tonic::Status; pub mod torchftpb { tonic::include_proto!("torchft"); } +use crate::net::Channel; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest}; use pyo3::prelude::*; @@ -301,8 +302,7 @@ impl From for StatusError { } } -#[pymodule] -fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { +fn init_logging() -> PyResult<()> { // setup logging on import let mut log = stderrlog::new(); log.verbosity(2) @@ -316,6 +316,92 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { log.init() .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(()) +} + +fn init_tracing() -> PyResult<()> { + use opentelemetry::trace::Tracer; + use opentelemetry::trace::TracerProvider as OpenTelemetryTracerProvider; + use opentelemetry_otlp::WithExportConfig; + use opentelemetry_sdk::trace::TracerProvider; + use tracing_subscriber::layer::SubscriberExt; + use tracing_subscriber::{filter::EnvFilter, Layer}; + + fn set_tracer_provider(tracer_provider: TracerProvider) -> PyResult<()> { + opentelemetry::global::set_tracer_provider(tracer_provider.clone()); + + let layer = tracing_opentelemetry::layer() + .with_error_records_to_exceptions(true) + .with_tracer(tracer_provider.tracer("")); + + // Create a new tracing::Fmt layer to print the logs to stdout. It has a + // default filter of `info` level and above, and `debug` and above for logs + // from OpenTelemetry crates. The filter levels can be customized as needed. + let filter_fmt = + EnvFilter::new("info").add_directive("opentelemetry=debug".parse().unwrap()); + let fmt_layer = tracing_subscriber::fmt::layer() + .with_thread_names(true) + .with_filter(filter_fmt); + + let subscriber = tracing_subscriber::registry().with(fmt_layer).with(layer); + tracing::subscriber::set_global_default(subscriber) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + info!("OpenTelemetry tracing enabled"); + + Ok(()) + } + + match env::var("TORCHFT_OTEL_OTLP") { + Ok(endpoint) => { + let runtime = Runtime::new()?; + + runtime.block_on(async move { + info!("Enabling OpenTelemetry OTLP with {}", endpoint); + let exporter = opentelemetry_otlp::SpanExporter::builder() + .with_tonic() + .with_endpoint(endpoint) + .with_timeout(Duration::from_secs(10)) + .build() + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + + let tracer_provider = TracerProvider::builder() + .with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio) + .build(); + + set_tracer_provider(tracer_provider)?; + + Ok::<(), pyo3::PyErr>(()) + })?; + } + Err(_) => {} + }; + match env::var("TORCHFT_OTEL_STDOUT") { + Ok(_) => { + info!("Enabling OpenTelemetry stdout"); + let exporter = opentelemetry_stdout::SpanExporter::default(); + let tracer_provider = TracerProvider::builder() + .with_simple_exporter(exporter) + .build(); + + set_tracer_provider(tracer_provider)?; + } + Err(_) => {} + } + + let tracer = opentelemetry::global::tracer("my_tracer"); + tracer.in_span("doing_work", |cx| { + // Traced app logic here... + }); + + Ok(()) +} + +#[pymodule] +fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { + init_logging()?; + init_tracing()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/src/lighthouse.rs b/src/lighthouse.rs index f151fbf..1e39079 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -20,6 +20,7 @@ use axum::{ routing::{get, post}, Router, }; +use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer}; use gethostname::gethostname; use log::{error, info}; use structopt::StructOpt; @@ -31,6 +32,7 @@ use tonic::service::Routes; use tonic::transport::server::TcpIncoming; use tonic::transport::Server; use tonic::{Request, Response, Status}; +use tonic_tracing_opentelemetry::middleware::server::OtelGrpcLayer; use crate::manager::manager_client_new; use crate::torchftpb::{ @@ -345,12 +347,17 @@ impl Lighthouse { let self_clone = self.clone(); move |path| async { self_clone.kill(path).await } }), - ); + ) + // include trace context as header into the response + .layer(OtelInResponseLayer::default()) + //start OpenTelemetry trace on incoming request + .layer(OtelAxumLayer::default()); // register the GRPC service let routes = Routes::from(app).add_service(LighthouseServiceServer::new(self)); Server::builder() + .layer(OtelGrpcLayer::default()) // allow non-GRPC connections .accept_http1(true) .add_routes(routes) @@ -571,9 +578,7 @@ mod tests { use super::*; use std::ops::Sub; - use tonic::transport::Channel; - - use crate::net::connect; + use crate::net::{connect, Channel}; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; async fn lighthouse_client_new(addr: String) -> Result> { diff --git a/src/manager.rs b/src/manager.rs index 982500a..c0716f7 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -16,11 +16,11 @@ use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::sleep; use tonic::transport::server::TcpIncoming; -use tonic::transport::Channel; use tonic::transport::Server; use tonic::{Request, Response, Status}; +use tonic_tracing_opentelemetry::middleware::server::OtelGrpcLayer; -use crate::net::connect; +use crate::net::{connect, Channel}; use crate::timeout::try_parse_grpc_timeout; use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; @@ -64,8 +64,9 @@ pub async fn manager_client_new( connect_timeout: Duration, ) -> Result> { info!("ManagerClient: establishing connection to {}", &addr); - let conn = connect(addr, connect_timeout).await?; - Ok(ManagerServiceClient::new(conn)) + let channel = connect(addr, connect_timeout).await?; + + Ok(ManagerServiceClient::new(channel)) } pub async fn lighthouse_client_new( @@ -73,8 +74,8 @@ pub async fn lighthouse_client_new( connect_timeout: Duration, ) -> Result> { info!("LighthouseClient: establishing connection to {}", &addr); - let conn = connect(addr, connect_timeout).await?; - Ok(LighthouseServiceClient::new(conn)) + let channel = connect(addr, connect_timeout).await?; + Ok(LighthouseServiceClient::new(channel)) } impl Manager { @@ -146,6 +147,7 @@ impl Manager { TcpIncoming::from_listener(listener, true, None).map_err(|e| anyhow::anyhow!(e))?; Server::builder() + .layer(OtelGrpcLayer::default()) .add_service(ManagerServiceServer::new(self)) .serve_with_incoming(incoming) .await diff --git a/src/net.rs b/src/net.rs index e6d9b69..1d19984 100644 --- a/src/net.rs +++ b/src/net.rs @@ -1,12 +1,16 @@ use std::time::Duration; use anyhow::Result; -use tonic::transport::{Channel, Endpoint}; +use tonic::transport::Endpoint; +use tonic_tracing_opentelemetry::middleware::client::{OtelGrpcLayer, OtelGrpcService}; +use tower::ServiceBuilder; use crate::retry::{retry_backoff, ExponentialBackoff}; +pub type Channel = OtelGrpcService; + pub async fn connect_once(addr: String, connect_timeout: Duration) -> Result { - let conn = Endpoint::new(addr)? + let channel = Endpoint::new(addr)? .connect_timeout(connect_timeout) // Enable HTTP2 keep alives .http2_keep_alive_interval(Duration::from_secs(60)) @@ -16,7 +20,9 @@ pub async fn connect_once(addr: String, connect_timeout: Duration) -> Result Result {