tower/steer/mod.rs
//! This module provides functionality to aid in routing requests between [`Service`]s.
2//!
3//! # Example
4//!
5//! [`Steer`] can for example be used to create a router, akin to what you might find in web
6//! frameworks.
7//!
8//! Here, `GET /` will be sent to the `root` service, while all other requests go to `not_found`.
9//!
10//! ```rust
11//! # use std::task::{Context, Poll};
12//! # use tower_service::Service;
13//! # use futures_util::future::{ready, Ready, poll_fn};
14//! # use tower::steer::Steer;
15//! # use tower::service_fn;
16//! # use tower::util::BoxService;
17//! # use tower::ServiceExt;
18//! # use std::convert::Infallible;
19//! use http::{Request, Response, StatusCode, Method};
20//!
21//! # #[tokio::main]
22//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
23//! // Service that responds to `GET /`
24//! let root = service_fn(|req: Request<String>| async move {
25//! # assert_eq!(req.uri().path(), "/");
26//! let res = Response::new("Hello, World!".to_string());
27//! Ok::<_, Infallible>(res)
28//! });
29//! // We have to box the service so its type gets erased and we can put it in a `Vec` with other
30//! // services
31//! let root = BoxService::new(root);
32//!
33//! // Service that responds with `404 Not Found` to all requests
34//! let not_found = service_fn(|req: Request<String>| async move {
35//! let res = Response::builder()
36//! .status(StatusCode::NOT_FOUND)
37//! .body(String::new())
38//! .expect("response is valid");
39//! Ok::<_, Infallible>(res)
40//! });
41//! // Box that as well
42//! let not_found = BoxService::new(not_found);
43//!
44//! let mut svc = Steer::new(
45//! // All services we route between
46//! vec![root, not_found],
47//! // How we pick which service to send the request to
48//! |req: &Request<String>, _services: &[_]| {
49//! if req.method() == Method::GET && req.uri().path() == "/" {
50//! 0 // Index of `root`
51//! } else {
52//! 1 // Index of `not_found`
53//! }
54//! },
55//! );
56//!
57//! // This request will get sent to `root`
58//! let req = Request::get("/").body(String::new()).unwrap();
59//! let res = svc.ready().await?.call(req).await?;
60//! assert_eq!(res.into_body(), "Hello, World!");
61//!
62//! // This request will get sent to `not_found`
63//! let req = Request::get("/does/not/exist").body(String::new()).unwrap();
64//! let res = svc.ready().await?.call(req).await?;
65//! assert_eq!(res.status(), StatusCode::NOT_FOUND);
66//! assert_eq!(res.into_body(), "");
67//! #
68//! # Ok(())
69//! # }
70//! ```
71use std::task::{Context, Poll};
72use std::{collections::VecDeque, fmt, marker::PhantomData};
73use tower_service::Service;
74
/// This is how callers of [`Steer`] tell it which `Service` a `Req` corresponds to.
///
/// [`Steer`] asks its picker once per dispatched request, handing it a reference to the request
/// and the full slice of services it routes between.
///
/// Every `FnMut(&Req, &[S]) -> usize` closure or function implements this trait, so a picker may
/// keep state between requests (for example a round-robin counter) as well as being a pure
/// function of the request.
pub trait Picker<S, Req> {
    /// Return an index into the list of `Service`s passed to [`Steer::new`].
    ///
    /// `r` is the request about to be dispatched and `services` is the list the index refers to.
    /// [`Steer`] indexes its service list with the returned value, so an out-of-bounds index
    /// makes the dispatching call panic.
    fn pick(&mut self, r: &Req, services: &[S]) -> usize;
}

impl<S, F, Req> Picker<S, Req> for F
where
    // `pick` receives `&mut self`, so `FnMut` is all that is needed; every `Fn` is also `FnMut`,
    // which keeps existing stateless pickers working.
    F: FnMut(&Req, &[S]) -> usize,
{
    fn pick(&mut self, r: &Req, services: &[S]) -> usize {
        // Call through an explicit reborrow of the `&mut F` receiver.
        (*self)(r, services)
    }
}
89
/// [`Steer`] routes each request to one of several [`Service`]s that all handle the same type of
/// request.
///
/// A typical use case is a sharded service. For every request, [`Steer`]:
/// 1. Waits (in [`Service::poll_ready`]) until *all* of its services are ready.
/// 2. Asks the provided [`Picker`] which [`Service`] the request belongs to.
/// 3. Calls that [`Service`] with the request and returns the future produced by the call.
///
/// [`Steer`] has to wait for every service because it cannot know in advance which [`Service`]
/// the next request is destined for, and it is unwilling to buffer items indefinitely. The result
/// is head-of-line blocking, unless the component services themselves buffer items indefinitely
/// and therefore always return [`Poll::Ready`]. For example, wrapping each component service in a
/// [`Buffer`] with a high enough limit (the maximum number of concurrent requests) will prevent
/// head-of-line blocking in [`Steer`].
///
/// [`Buffer`]: crate::buffer::Buffer
pub struct Steer<S, F, Req> {
    /// Chooses, per request, which entry of `services` receives it.
    router: F,
    /// The services being routed between, in the order they were supplied.
    services: Vec<S>,
    /// Indices into `services` that must still report readiness before the next call.
    not_ready: VecDeque<usize>,
    /// Binds the `Req` type parameter to the struct; no request is stored.
    _phantom: PhantomData<Req>,
}

impl<S, F, Req> Steer<S, F, Req> {
    /// Make a new [`Steer`] from a list of [`Service`]s and a [`Picker`].
    ///
    /// Note: the order of the [`Service`]s is significant, because the value returned by
    /// [`Picker::pick`] is used as an index into this list.
    pub fn new(services: impl IntoIterator<Item = S>, router: F) -> Self {
        let services = services.into_iter().collect::<Vec<S>>();
        // Nothing has been polled yet, so every index starts out in the not-ready queue.
        let not_ready = (0..services.len()).collect();
        Self {
            router,
            services,
            not_ready,
            _phantom: PhantomData,
        }
    }
}
129
130impl<S, Req, F> Service<Req> for Steer<S, F, Req>
131where
132 S: Service<Req>,
133 F: Picker<S, Req>,
134{
135 type Response = S::Response;
136 type Error = S::Error;
137 type Future = S::Future;
138
139 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
140 loop {
141 // must wait for *all* services to be ready.
142 // this will cause head-of-line blocking unless the underlying services are always ready.
143 if self.not_ready.is_empty() {
144 return Poll::Ready(Ok(()));
145 } else {
146 if self.services[self.not_ready[0]]
147 .poll_ready(cx)?
148 .is_pending()
149 {
150 return Poll::Pending;
151 }
152
153 self.not_ready.pop_front();
154 }
155 }
156 }
157
158 fn call(&mut self, req: Req) -> Self::Future {
159 assert!(
160 self.not_ready.is_empty(),
161 "Steer must wait for all services to be ready. Did you forget to call poll_ready()?"
162 );
163
164 let idx = self.router.pick(&req, &self.services[..]);
165 let cl = &mut self.services[idx];
166 self.not_ready.push_back(idx);
167 cl.call(req)
168 }
169}
170
171impl<S, F, Req> Clone for Steer<S, F, Req>
172where
173 S: Clone,
174 F: Clone,
175{
176 fn clone(&self) -> Self {
177 Self {
178 router: self.router.clone(),
179 services: self.services.clone(),
180 not_ready: self.not_ready.clone(),
181 _phantom: PhantomData,
182 }
183 }
184}
185
186impl<S, F, Req> fmt::Debug for Steer<S, F, Req>
187where
188 S: fmt::Debug,
189 F: fmt::Debug,
190{
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 let Self {
193 router,
194 services,
195 not_ready,
196 _phantom,
197 } = self;
198 f.debug_struct("Steer")
199 .field("router", router)
200 .field("services", services)
201 .field("not_ready", not_ready)
202 .finish()
203 }
204}