Skip to content

Commit

Permalink
Simplify TwirpRouterBuilder interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jorendorff committed Jan 25, 2024
1 parent 458501f commit 538ee51
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 28 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ mod haberdash {
include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs"));
}

use axum::Router;
use haberdash::{MakeHatRequest, MakeHatResponse};

#[tokio::main]
Expand Down Expand Up @@ -87,6 +88,10 @@ impl haberdash::HaberdasherAPI for HaberdasherAPIServer {
}
```

This code creates an `axum::Router`, then hands it off to `axum::serve()` to handle networking.
This use of `axum::serve` is optional. After building `app`, you can instead invoke it from any
`hyper`-based server by importing `twirp::tower::Service` and doing `app.call(request).await`.

## Usage (client side)

On the client side, you also get a generated twirp client (based on the rpc endpoints in your proto). Include the generated code, create a client, and start making rpc calls:
Expand Down
5 changes: 3 additions & 2 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
let service_fqn = format!("{}.{}", service.package, service_name);
writeln!(buf).unwrap();

writeln!(buf, "pub const SERVICE_FQN: &str = \"{service_fqn}\";").unwrap();
writeln!(buf, "pub const SERVICE_FQN: &str = \"/{service_fqn}\";").unwrap();

//
// generate the twirp server
Expand Down Expand Up @@ -48,10 +48,11 @@ where
.unwrap();
for m in &service.methods {
let uri = &m.proto_name;
let req_type = &m.input_type;
let rust_method_name = &m.name;
writeln!(
buf,
r#" .route("/{uri}", |api: std::sync::Arc<T>| move |req| async move {{
r#" .route("/{uri}", |api: std::sync::Arc<T>, req: {req_type}| async move {{
api.{rust_method_name}(req).await
}})"#,
)
Expand Down
23 changes: 15 additions & 8 deletions crates/twirp/src/details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ use axum::Router;

use crate::{server, TwirpErrorResponse};

/// Builder object used by generated code to build a router.
/// Builder object used by generated code to build a Twirp service.
///
/// The type `S` is something like `Arc<MyExampleAPIServer>`, which can be cheaply cloned for each
/// incoming request, providing access to the Rust value that actually implements the RPCs.
pub struct TwirpRouterBuilder<S> {
service: S,
router: Router<S>,
Expand All @@ -24,25 +27,29 @@ where
}
}

pub fn route<F, G, Fut, RequestMessage, ResponseMessage>(self, url: &str, f: F) -> Self
/// Add a handler for an `rpc` to the router.
///
/// The generated code passes a closure that calls the method, like
/// `|api: Arc<HaberdasherAPIServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
where
F: Fn(S) -> G + Clone + Send + 'static,
G: FnOnce(RequestMessage) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<ResponseMessage, TwirpErrorResponse>> + Send,
RequestMessage: prost::Message + Default + serde::de::DeserializeOwned,
ResponseMessage: prost::Message + serde::Serialize,
F: Fn(S, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + serde::Serialize,
{
TwirpRouterBuilder {
service: self.service,
router: self.router.route(
url,
axum::routing::post(move |State(api): State<S>, req: Request| async move {
server::handle_request(req, f(api)).await
server::handle_request(api, req, f).await
}),
),
}
}

/// Finish building the axum router.
pub fn build(self) -> axum::Router {
self.router
.fallback(crate::server::not_found_handler)
Expand Down
14 changes: 10 additions & 4 deletions crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,24 @@ pub mod server;
#[cfg(any(test, feature = "test-support"))]
pub mod test;

#[doc(hidden)]
pub mod details;

pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use error::*; // many constructors like `invalid_argument()`
pub use server::{Router, Timings};

// Re-export `reqwest` so that it's easy to implement middleware.
// Re-export this crate's dependencies that users are likely to code against. These can be used to
// import the exact versions of these libraries `twirp` is built with -- useful if your project is
// so sprawling that it builds multiple versions of some crates.
pub use axum;
pub use reqwest;

// Re-export `url so that the generated code works without additional dependencies beyond just the `twirp` crate.
pub use tower;
pub use url;

/// Re-export of `axum::Router`, the type that encapsulates a server-side implementation of a Twirp
/// service.
pub use axum::Router;

pub(crate) fn serialize_proto_message<T>(m: T) -> Vec<u8>
where
T: prost::Message,
Expand Down
34 changes: 26 additions & 8 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
//! Support for serving Twirp APIs.
//!
//! There is not much to see in the documentation here. This API is meant to be used with
//! `twirp-build`. See <https://github.com/github/twirp-rs#usage> for details and an example.
use std::fmt::Debug;

use axum::body::Body;
use axum::response::IntoResponse;
pub use axum::Router;
use futures::Future;
use http_body_util::BodyExt;
use hyper::{header, Request, Response};
Expand All @@ -13,9 +17,6 @@ use tokio::time::{Duration, Instant};
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse};

/// The canonical twirp path prefix. You don't have to use this, but it's the default.
pub const DEFAULT_TWIRP_PATH_PREFIX: &str = "/twirp";

// TODO: Properly implement JsonPb (de)serialization as it is slightly different
// than standard JSON.
#[derive(Debug, Clone, Copy, Default)]
Expand All @@ -39,9 +40,13 @@ impl BodyFormat {
}

/// Entry point used in code generated by `twirp-build`.
pub async fn handle_request<F, Fut, Req, Resp>(req: Request<Body>, f: F) -> Response<Body>
pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
service: S,
req: Request<Body>,
f: F,
) -> Response<Body>
where
F: FnOnce(Req) -> Fut + Clone + Sync + Send + 'static,
F: FnOnce(S, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
Expand All @@ -64,7 +69,7 @@ where
}
};

let res = f(req).await;
let res = f(service, req).await;
timings.set_response_handled();

let mut resp = match write_response(res, resp_fmt) {
Expand Down Expand Up @@ -111,7 +116,7 @@ where
BodyFormat::Pb => Response::builder()
.header(header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(Body::from(serialize_proto_message(response)))?,
_ => {
BodyFormat::JsonPb => {
let data = serde_json::to_string(&response)?;
Response::builder()
.header(header::CONTENT_TYPE, CONTENT_TYPE_JSON)
Expand All @@ -126,6 +131,19 @@ where
/// Axum handler function that returns 404 Not Found with a Twirp JSON payload.
///
/// `axum::Router`'s default fallback handler returns a 404 Not Found with no body content.
/// Use this fallback instead for full Twirp compliance.
///
/// # Usage
///
/// ```
/// use axum::Router;
///
/// # fn build_app(twirp_routes: Router) -> Router {
/// let app = Router::new()
/// .nest("/twirp", twirp_routes)
/// .fallback(twirp::server::not_found_handler);
/// # app }
/// ```
pub async fn not_found_handler() -> Response<Body> {
error::bad_route("not found").into_response()
}
Expand Down
14 changes: 8 additions & 6 deletions crates/twirp/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@ pub fn test_api_router() -> Router {

// NB: This part would be generated
let test_router = TwirpRouterBuilder::new(api)
.route("/Ping", |api: Arc<TestAPIServer>| {
move |req| async move { api.ping(req).await }
})
.route("/Boom", |api: Arc<TestAPIServer>| {
move |req| async move { api.boom(req).await }
})
.route(
"/Ping",
|api: Arc<TestAPIServer>, req: PingRequest| async move { api.ping(req).await },
)
.route(
"/Boom",
|api: Arc<TestAPIServer>, req: PingRequest| async move { api.boom(req).await },
)
.build();

axum::Router::new()
Expand Down

0 comments on commit 538ee51

Please sign in to comment.