diff --git a/crates/twirp-build/README.md b/crates/twirp-build/README.md index ada959f..efbe879 100644 --- a/crates/twirp-build/README.md +++ b/crates/twirp-build/README.md @@ -104,7 +104,7 @@ mod haberdash { include!(concat!(env!("OUT_DIR"), "/service.haberdash.v1.rs")); } -use haberdash::{HaberdasherApiClient, MakeHatRequest, MakeHatResponse}; +use haberdash::{HaberdasherApi, MakeHatRequest, MakeHatResponse}; #[tokio::main] pub async fn main() { diff --git a/crates/twirp-build/src/lib.rs b/crates/twirp-build/src/lib.rs index 0540c67..d2e8e18 100644 --- a/crates/twirp-build/src/lib.rs +++ b/crates/twirp-build/src/lib.rs @@ -15,9 +15,6 @@ struct Service { /// The name of the server trait, as parsed into a Rust identifier. server_name: syn::Ident, - /// The name of the client trait, as parsed into a Rust identifier. - client_name: syn::Ident, - /// The fully qualified protobuf name of this Service. fqn: String, @@ -43,7 +40,6 @@ impl Service { fn from_prost(s: prost_build::Service) -> Self { let fqn = format!("{}.{}", s.package, s.proto_name); let server_name = format_ident!("{}", &s.name); - let client_name = format_ident!("{}Client", &s.name); let methods = s .methods .into_iter() @@ -52,7 +48,6 @@ impl Service { Self { server_name, - client_name, fqn, methods, } @@ -161,8 +156,6 @@ impl prost_build::ServiceGenerator for ServiceGenerator { // generate the twirp client // - let client_name = service.client_name; - let mut client_trait_methods = Vec::with_capacity(service.methods.len()); let mut client_methods = Vec::with_capacity(service.methods.len()); for m in &service.methods { let name = &m.name; @@ -170,24 +163,17 @@ impl prost_build::ServiceGenerator for ServiceGenerator { let output_type = &m.output_type; let request_path = format!("{}/{}", service.fqn, m.proto_name); - client_trait_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError>; - }); - client_methods.push(quote! { - async fn #name(&self, req: #input_type) -> Result<#output_type, twirp::ClientError> { + async fn #name(&self, _ctx: twirp::Context, req: #input_type) -> Result<#output_type, Self::Error> { self.request(#request_path, req).await } - }) + }); } let client_trait = quote! { #[twirp::async_trait::async_trait] - pub trait #client_name: Send + Sync { - #(#client_trait_methods)* - } + impl #server_name for twirp::client::Client { + type Error = twirp::ClientError; - #[twirp::async_trait::async_trait] - impl #client_name for twirp::client::Client { #(#client_methods)* } }; diff --git a/crates/twirp/src/context.rs b/crates/twirp/src/context.rs index 9e5cd0b..465c6fb 100644 --- a/crates/twirp/src/context.rs +++ b/crates/twirp/src/context.rs @@ -1,7 +1,5 @@ use std::sync::{Arc, Mutex}; -use http::Extensions; - /// Context allows passing information between twirp rpc handlers and http middleware by providing /// access to extensions on the `http::Request` and `http::Response`. /// @@ -9,24 +7,31 @@ use http::Extensions; /// handler code. #[derive(Default)] pub struct Context { - extensions: Extensions, - resp_extensions: Arc>, + req_extensions: http::Extensions, + resp_extensions: Arc>, } impl Context { - pub fn new(extensions: Extensions, resp_extensions: Arc>) -> Self { + pub fn new( + req_extensions: http::Extensions, + resp_extensions: Arc>, + ) -> Self { Self { - extensions, + req_extensions, resp_extensions, } } + pub fn extensions_mut(&mut self) -> &mut http::Extensions { + &mut self.req_extensions + } + /// Get a request extension. pub fn get(&self) -> Option<&T> where T: Clone + Send + Sync + 'static, { - self.extensions.get::() + self.req_extensions.get::() } /// Insert a response extension. diff --git a/example/src/bin/advanced-server.rs b/example/src/bin/advanced-server.rs index cd24fa3..4927daa 100644 --- a/example/src/bin/advanced-server.rs +++ b/example/src/bin/advanced-server.rs @@ -144,7 +144,6 @@ async fn request_id_middleware( #[cfg(test)] mod test { - use service::haberdash::v1::HaberdasherApiClient; use twirp::client::Client; use twirp::url::Url; @@ -228,7 +227,9 @@ mod test { let url = Url::parse(&format!("http://localhost:{}/twirp/", server.port)).unwrap(); let client = Client::from_base_url(url).unwrap(); - let resp = client.make_hat(MakeHatRequest { inches: 1 }).await; + let resp = client + .make_hat(Context::default(), MakeHatRequest { inches: 1 }) + .await; println!("{:?}", resp); assert_eq!(resp.unwrap().size, 1); diff --git a/example/src/bin/client.rs b/example/src/bin/client.rs index 89c6e71..efbcc41 100644 --- a/example/src/bin/client.rs +++ b/example/src/bin/client.rs @@ -1,8 +1,9 @@ +use service::haberdash::v1::HaberdasherApi; use twirp::async_trait::async_trait; use twirp::client::{Client, ClientBuilder, Middleware, Next}; use twirp::reqwest::{Request, Response}; use twirp::url::Url; -use twirp::GenericError; +use twirp::{Context, GenericError}; pub mod service { pub mod haberdash { @@ -13,15 +14,16 @@ pub mod service { } use service::haberdash::v1::{ - GetStatusRequest, GetStatusResponse, HaberdasherApiClient, MakeHatRequest, MakeHatResponse, + GetStatusRequest, GetStatusResponse, MakeHatRequest, MakeHatResponse, }; #[tokio::main] pub async fn main() -> Result<(), GenericError> { // basic client - use service::haberdash::v1::HaberdasherApiClient; let client = Client::from_base_url(Url::parse("http://localhost:3000/twirp/")?)?; - let resp = client.make_hat(MakeHatRequest { inches: 1 }).await; + let resp = client + .make_hat(Context::default(), MakeHatRequest { inches: 1 }) + .await; eprintln!("{:?}", resp); // customize the client with middleware @@ -34,7 +36,7 @@ pub async fn main() -> Result<(), GenericError> { .build()?; let resp = client .with_host("localhost") - .make_hat(MakeHatRequest { inches: 1 }) + .make_hat(Context::default(), MakeHatRequest { inches: 1 }) .await; eprintln!("{:?}", resp); @@ -74,18 +76,21 @@ impl Middleware for PrintResponseHeaders { struct MockHaberdasherApiClient; #[async_trait] -impl HaberdasherApiClient for MockHaberdasherApiClient { +impl HaberdasherApi for MockHaberdasherApiClient { + type Error = twirp::client::ClientError; async fn make_hat( &self, + _ctx: Context, _req: MakeHatRequest, - ) -> Result { + ) -> Result { todo!() } async fn get_status( &self, + _ctx: Context, _req: GetStatusRequest, - ) -> Result { + ) -> Result { todo!() } } diff --git a/example/src/bin/simple-server.rs b/example/src/bin/simple-server.rs index 12eb18b..33d5120 100644 --- a/example/src/bin/simple-server.rs +++ b/example/src/bin/simple-server.rs @@ -89,7 +89,6 @@ struct ResponseInfo(u16); #[cfg(test)] mod test { - use service::haberdash::v1::HaberdasherApiClient; use twirp::client::Client; use twirp::url::Url; use twirp::TwirpErrorCode; @@ -174,7 +173,9 @@ mod test { let url = Url::parse(&format!("http://localhost:{}/twirp/", server.port)).unwrap(); let client = Client::from_base_url(url).unwrap(); - let resp = client.make_hat(MakeHatRequest { inches: 1 }).await; + let resp = client + .make_hat(Context::default(), MakeHatRequest { inches: 1 }) + .await; println!("{:?}", resp); assert_eq!(resp.unwrap().size, 1);