From 458501f79c64a4cc85593667e4842f2b4e0cc7f0 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Wed, 24 Jan 2024 20:11:14 -0600 Subject: [PATCH 1/2] Make the generated code less complicated --- crates/twirp-build/src/lib.rs | 23 ++++----------- crates/twirp/src/details.rs | 53 +++++++++++++++++++++++++++++++---- crates/twirp/src/test.rs | 38 ++++++------------------- 3 files changed, 61 insertions(+), 53 deletions(-) diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index c0aa439..985a335 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -43,7 +43,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { where T: {service_name} + Send + Sync + 'static, {{ - twirp::Router::new()"#, + twirp::details::TwirpRouterBuilder::new(api)"#, ) .unwrap(); for m in &service.methods { @@ -51,29 +51,16 @@ where let rust_method_name = &m.name; writeln!( buf, - r#" .route( - "/{uri}", - twirp::details::post( - |twirp::details::State(api): twirp::details::State>, - req: twirp::details::Request| async move {{ - twirp::server::handle_request( - req, - move |req| async move {{ - api.{rust_method_name}(req).await - }}, - ) - .await - }}, - ), - )"#, + r#" .route("/{uri}", |api: std::sync::Arc| move |req| async move {{ + api.{rust_method_name}(req).await + }})"#, ) .unwrap(); } writeln!( buf, r#" - .with_state(api) - .fallback(twirp::server::not_found_handler) + .build() }}"# ) .unwrap(); diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index 1ab5989..1af08ef 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -1,10 +1,51 @@ //! Undocumented features that are public for use in generated code (see `twirp-build`). -#[doc(hidden)] -pub use axum::extract::{Request, State}; +use std::future::Future; -#[doc(hidden)] -pub use axum::routing::post; +use axum::extract::{Request, State}; +use axum::Router; -#[doc(hidden)] -pub use axum::response::Response; +use crate::{server, TwirpErrorResponse}; + +/// Builder object used by generated code to build a router. +pub struct TwirpRouterBuilder { + service: S, + router: Router, +} + +impl TwirpRouterBuilder +where + S: Clone + Send + Sync + 'static, +{ + pub fn new(service: S) -> Self { + TwirpRouterBuilder { + service, + router: Router::new(), + } + } + + pub fn route(self, url: &str, f: F) -> Self + where + F: Fn(S) -> G + Clone + Send + 'static, + G: FnOnce(RequestMessage) -> Fut + Clone + Sync + Send + 'static, + Fut: Future> + Send, + RequestMessage: prost::Message + Default + serde::de::DeserializeOwned, + ResponseMessage: prost::Message + serde::Serialize, + { + TwirpRouterBuilder { + service: self.service, + router: self.router.route( + url, + axum::routing::post(move |State(api): State, req: Request| async move { + server::handle_request(req, f(api)).await + }), + ), + } + } + + pub fn build(self) -> axum::Router { + self.router + .fallback(crate::server::not_found_handler) + .with_state(self.service) + } +} diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index d3386ec..c0b35d0 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -11,6 +11,7 @@ use serde::de::DeserializeOwned; use tokio::task::JoinHandle; use tokio::time::Instant; +use crate::details::TwirpRouterBuilder; use crate::server::Timings; use crate::{error, Client, Result, TwirpErrorResponse}; @@ -28,35 +29,14 @@ pub fn test_api_router() -> Router { let api = Arc::new(TestAPIServer {}); // NB: This part would be generated - let test_router = crate::Router::new() - .route( - "/Ping", - crate::details::post( - |crate::details::State(api): crate::details::State>, - req: crate::details::Request| async move { - crate::server::handle_request( - req, - move |req| async move { api.ping(req).await }, - ) - .await - }, - ), - ) - .route( - "/Boom", - crate::details::post( - |crate::details::State(api): crate::details::State>, - req: crate::details::Request| async move { - crate::server::handle_request( - req, - move |req| async move { api.boom(req).await }, - ) - .await - }, - ), - ) - .fallback(crate::server::not_found_handler) - .with_state(api); + let test_router = TwirpRouterBuilder::new(api) + .route("/Ping", |api: Arc| { + move |req| async move { api.ping(req).await } + }) + .route("/Boom", |api: Arc| { + move |req| async move { api.boom(req).await } + }) + .build(); axum::Router::new() .nest("/twirp/test.TestAPI", test_router) From 538ee51af63bbef41906e0b1b32507ead739f7d8 Mon Sep 17 00:00:00 2001 From: Jason Orendorff Date: Thu, 25 Jan 2024 09:27:29 -0600 Subject: [PATCH 2/2] Simplify TwirpRouterBuilder interface --- README.md | 5 +++++ crates/twirp-build/src/lib.rs | 5 +++-- crates/twirp/src/details.rs | 23 +++++++++++++++-------- crates/twirp/src/lib.rs | 14 ++++++++++---- crates/twirp/src/server.rs | 34 ++++++++++++++++++++++++++-------- crates/twirp/src/test.rs | 14 ++++++++------ 6 files changed, 67 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 16b00fa..085e58f 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ mod haberdash { include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs")); } +use axum::Router; use haberdash::{MakeHatRequest, MakeHatResponse}; #[tokio::main] @@ -87,6 +88,10 @@ impl haberdash::HaberdasherAPI for HaberdasherAPIServer { } ``` +This code creates an `axum::Router`, then hands it off to `axum::serve()` to handle networking. +This use of `axum::serve` is optional. After building `app`, you can instead invoke it from any +`hyper`-based server by importing `twirp::tower::Service` and doing `app.call(request).await`. + ## Usage (client side) On the client side, you also get a generated twirp client (based on the rpc endpoints in your proto). Include the generated code, create a client, and start making rpc calls: diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 985a335..c3800ac 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -19,7 +19,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let service_fqn = format!("{}.{}", service.package, service_name); writeln!(buf).unwrap(); - writeln!(buf, "pub const SERVICE_FQN: &str = \"{service_fqn}\";").unwrap(); + writeln!(buf, "pub const SERVICE_FQN: &str = \"/{service_fqn}\";").unwrap(); // // generate the twirp server @@ -48,10 +48,11 @@ where .unwrap(); for m in &service.methods { let uri = &m.proto_name; + let req_type = &m.input_type; let rust_method_name = &m.name; writeln!( buf, - r#" .route("/{uri}", |api: std::sync::Arc| move |req| async move {{ + r#" .route("/{uri}", |api: std::sync::Arc, req: {req_type}| async move {{ api.{rust_method_name}(req).await }})"#, ) diff --git a/crates/twirp/src/details.rs b/crates/twirp/src/details.rs index 1af08ef..4d74de7 100644 --- a/crates/twirp/src/details.rs +++ b/crates/twirp/src/details.rs @@ -7,7 +7,10 @@ use axum::Router; use crate::{server, TwirpErrorResponse}; -/// Builder object used by generated code to build a router. +/// Builder object used by generated code to build a Twirp service. +/// +/// The type `S` is something like `Arc`, which can be cheaply cloned for each +/// incoming request, providing access to the Rust value that actually implements the RPCs. pub struct TwirpRouterBuilder { service: S, router: Router, @@ -24,25 +27,29 @@ where } } - pub fn route(self, url: &str, f: F) -> Self + /// Add a handler for an `rpc` to the router. + /// + /// The generated code passes a closure that calls the method, like + /// `|api: Arc, req: MakeHatRequest| async move { api.make_hat(req) }`. + pub fn route(self, url: &str, f: F) -> Self where - F: Fn(S) -> G + Clone + Send + 'static, - G: FnOnce(RequestMessage) -> Fut + Clone + Sync + Send + 'static, - Fut: Future> + Send, - RequestMessage: prost::Message + Default + serde::de::DeserializeOwned, - ResponseMessage: prost::Message + serde::Serialize, + F: Fn(S, Req) -> Fut + Clone + Sync + Send + 'static, + Fut: Future> + Send, + Req: prost::Message + Default + serde::de::DeserializeOwned, + Res: prost::Message + serde::Serialize, { TwirpRouterBuilder { service: self.service, router: self.router.route( url, axum::routing::post(move |State(api): State, req: Request| async move { - server::handle_request(req, f(api)).await + server::handle_request(api, req, f).await }), ), } } + /// Finish building the axum router. pub fn build(self) -> axum::Router { self.router .fallback(crate::server::not_found_handler) diff --git a/crates/twirp/src/lib.rs b/crates/twirp/src/lib.rs index ee05cab..87563c1 100644 --- a/crates/twirp/src/lib.rs +++ b/crates/twirp/src/lib.rs @@ -8,18 +8,24 @@ pub mod server; #[cfg(any(test, feature = "test-support"))] pub mod test; +#[doc(hidden)] pub mod details; pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result}; pub use error::*; // many constructors like `invalid_argument()` -pub use server::{Router, Timings}; -// Re-export `reqwest` so that it's easy to implement middleware. +// Re-export this crate's dependencies that users are likely to code against. These can be used to +// import the exact versions of these libraries `twirp` is built with -- useful if your project is +// so sprawling that it builds multiple versions of some crates. +pub use axum; pub use reqwest; - -// Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate. +pub use tower; pub use url; +/// Re-export of `axum::Router`, the type that encapsulates a server-side implementation of a Twirp +/// service. +pub use axum::Router; + pub(crate) fn serialize_proto_message(m: T) -> Vec where T: prost::Message, diff --git a/crates/twirp/src/server.rs b/crates/twirp/src/server.rs index de886e0..d78f27c 100644 --- a/crates/twirp/src/server.rs +++ b/crates/twirp/src/server.rs @@ -1,8 +1,12 @@ +//! Support for serving Twirp APIs. +//! +//! There is not much to see in the documentation here. This API is meant to be used with +//! `twirp-build`. See for details and an example. + use std::fmt::Debug; use axum::body::Body; use axum::response::IntoResponse; -pub use axum::Router; use futures::Future; use http_body_util::BodyExt; use hyper::{header, Request, Response}; @@ -13,9 +17,6 @@ use tokio::time::{Duration, Instant}; use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF}; use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse}; -/// The canonical twirp path prefix. You don't have to use this, but it's the default. -pub const DEFAULT_TWIRP_PATH_PREFIX: &str = "/twirp"; - // TODO: Properly implement JsonPb (de)serialization as it is slightly different // than standard JSON. #[derive(Debug, Clone, Copy, Default)] @@ -39,9 +40,13 @@ impl BodyFormat { } /// Entry point used in code generated by `twirp-build`. -pub async fn handle_request(req: Request, f: F) -> Response +pub(crate) async fn handle_request( + service: S, + req: Request, + f: F, +) -> Response where - F: FnOnce(Req) -> Fut + Clone + Sync + Send + 'static, + F: FnOnce(S, Req) -> Fut + Clone + Sync + Send + 'static, Fut: Future> + Send, Req: prost::Message + Default + serde::de::DeserializeOwned, Resp: prost::Message + serde::Serialize, @@ -64,7 +69,7 @@ where } }; - let res = f(req).await; + let res = f(service, req).await; timings.set_response_handled(); let mut resp = match write_response(res, resp_fmt) { @@ -111,7 +116,7 @@ where BodyFormat::Pb => Response::builder() .header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF) .body(Body::from(serialize_proto_message(response)))?, - _ => { + BodyFormat::JsonPb => { let data = serde_json::to_string(&response)?; Response::builder() .header(header::CONTENT_TYPE, CONTENT_TYPE_JSON) @@ -126,6 +131,19 @@ where /// Axum handler function that returns 404 Not Found with a Twirp JSON payload. /// /// `axum::Router`'s default fallback handler returns a 404 Not Found with no body content. +/// Use this fallback instead for full Twirp compliance. +/// +/// # Usage +/// +/// ``` +/// use axum::Router; +/// +/// # fn build_app(twirp_routes: Router) -> Router { +/// let app = Router::new() +/// .nest("/twirp", twirp_routes) +/// .fallback(twirp::server::not_found_handler); +/// # app } +/// ``` pub async fn not_found_handler() -> Response { error::bad_route("not found").into_response() } diff --git a/crates/twirp/src/test.rs b/crates/twirp/src/test.rs index c0b35d0..e5407d1 100644 --- a/crates/twirp/src/test.rs +++ b/crates/twirp/src/test.rs @@ -30,12 +30,14 @@ pub fn test_api_router() -> Router { // NB: This part would be generated let test_router = TwirpRouterBuilder::new(api) - .route("/Ping", |api: Arc| { - move |req| async move { api.ping(req).await } - }) - .route("/Boom", |api: Arc| { - move |req| async move { api.boom(req).await } - }) + .route( + "/Ping", + |api: Arc, req: PingRequest| async move { api.ping(req).await }, + ) + .route( + "/Boom", + |api: Arc, req: PingRequest| async move { api.boom(req).await }, + ) .build(); axum::Router::new()