use std::{
convert::Infallible,
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
sync::Arc,
};
use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::rt::{TokioExecutor, TokioIo};
#[cfg(any(feature = "http1", feature = "http2"))]
use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService};
use tokio::sync::watch;
use tower::ServiceExt as _;
use tower_service::Service;
mod listener;
pub use self::listener::{Listener, ListenerExt, TapIo};
/// Serve `make_service` on connections accepted from `listener`.
///
/// `make_service` is called once for every accepted connection with an
/// [`IncomingStream`] describing it; the service it returns handles all requests
/// arriving on that connection.
///
/// Nothing happens until the returned [`Serve`] is awaited (it implements
/// `IntoFuture`). The server spawns Tokio tasks, so it must be driven inside a
/// Tokio runtime. See [`Serve::with_graceful_shutdown`] to stop the server when a
/// signal future completes.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub fn serve<L, M, S>(listener: L, make_service: M) -> Serve<L, M, S>
where
    L: Listener,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S>,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    // `S` is only pinned down by the bounds above; nothing of that type is stored.
    let marker = PhantomData;
    Serve {
        make_service,
        listener,
        _marker: marker,
    }
}
/// A server that has been configured but not started; returned by [`serve`].
///
/// Await it (through `IntoFuture`) to start accepting connections, or call
/// [`Serve::with_graceful_shutdown`] to attach a shutdown signal first.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct Serve<L, M, S> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to build that connection's service.
    make_service: M,
    // Ties the per-connection service type `S` to this struct without storing one.
    _marker: PhantomData<S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Serve<L, M, S>
where
    L: Listener,
{
    /// Attach a graceful-shutdown signal to this server.
    ///
    /// Once `signal` completes, the server stops accepting new connections,
    /// asks every open connection to shut down gracefully, and the server future
    /// resolves after all connection tasks have finished. `signal` is awaited on
    /// its own spawned Tokio task, hence the `Send + 'static` bound.
    pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<L, M, S, F>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let Self {
            listener,
            make_service,
            _marker: marker,
        } = self;

        WithGracefulShutdown {
            listener,
            make_service,
            signal,
            _marker: marker,
        }
    }

    /// Returns the local address reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        Listener::local_addr(&self.listener)
    }
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> Debug for Serve<L, M, S>
where
    L: Debug + 'static,
    M: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Serve` makes this impl fail to
        // compile until the new field is considered here.
        let Self {
            listener,
            make_service,
            _marker: _,
        } = self;

        f.debug_struct("Serve")
            .field("listener", listener)
            .field("make_service", make_service)
            .finish()
    }
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S> IntoFuture for Serve<L, M, S>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    // A plain `Serve` is a graceful-shutdown server whose signal never completes,
    // so it keeps accepting connections indefinitely.
    fn into_future(self) -> Self::IntoFuture {
        let never_fires = std::future::pending::<()>();
        let graceful = self.with_graceful_shutdown(never_fires);
        graceful.into_future()
    }
}
/// A server with a graceful-shutdown signal attached; created by
/// [`Serve::with_graceful_shutdown`].
///
/// Await it (through `IntoFuture`) to run the server until `signal` completes and
/// all in-flight connection tasks have finished.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
#[must_use = "futures must be awaited or polled"]
pub struct WithGracefulShutdown<L, M, S, F> {
    /// Source of incoming connections.
    listener: L,
    /// Called once per accepted connection to build that connection's service.
    make_service: M,
    /// Future whose completion starts graceful shutdown.
    signal: F,
    // Ties the per-connection service type `S` to this struct without storing one.
    _marker: PhantomData<S>,
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
{
    /// Returns the local address reported by the underlying listener.
    pub fn local_addr(&self) -> io::Result<L::Addr> {
        let Self { listener, .. } = self;
        listener.local_addr()
    }
}
// `S` is deliberately NOT required to be `Debug`: it only appears inside
// `PhantomData<S>` (which is `Debug` for every `S`) and is never printed. This
// matches the bounds of `Serve`'s `Debug` impl, so attaching a shutdown signal does
// not make a debuggable server lose its `Debug` impl.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> Debug for WithGracefulShutdown<L, M, S, F>
where
    L: Debug + 'static,
    M: Debug,
    F: Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            listener,
            make_service,
            signal,
            _marker: _,
        } = self;

        f.debug_struct("WithGracefulShutdown")
            .field("listener", listener)
            .field("make_service", make_service)
            .field("signal", signal)
            .finish()
    }
}
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<L, M, S, F> IntoFuture for WithGracefulShutdown<L, M, S, F>
where
    L: Listener,
    L::Addr: Debug,
    M: for<'a> Service<IncomingStream<'a, L>, Error = Infallible, Response = S> + Send + 'static,
    for<'a> <M as Service<IncomingStream<'a, L>>>::Future: Send,
    S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
    S::Future: Send,
    F: Future<Output = ()> + Send + 'static,
{
    type Output = io::Result<()>;
    type IntoFuture = private::ServeFuture;

    // Builds the server future. Every `tokio::spawn` happens inside the boxed
    // `async move` block, so calling `into_future` itself does no work and does not
    // need a Tokio runtime; the runtime is only needed once the future is polled.
    fn into_future(self) -> Self::IntoFuture {
        let Self {
            mut listener,
            mut make_service,
            signal,
            _marker: _,
        } = self;

        private::ServeFuture(Box::pin(async move {
            // Shutdown broadcast. No value is ever sent on this channel; instead the
            // only receiver is dropped once `signal` completes, which resolves every
            // `signal_tx.closed()` future (the accept loop below and each connection
            // task). `Arc` lets each connection task hold its own handle.
            let (signal_tx, signal_rx) = watch::channel(());
            let signal_tx = Arc::new(signal_tx);
            tokio::spawn(async move {
                signal.await;
                trace!("received graceful shutdown signal. Telling tasks to shutdown");
                drop(signal_rx);
            });

            // Completion tracking. Each connection task holds a clone of `close_rx`
            // and drops it when finished; `close_tx.closed()` resolves only after the
            // original and every clone have been dropped.
            let (close_tx, close_rx) = watch::channel(());

            loop {
                // Race the next accepted connection against the shutdown signal.
                let (io, remote_addr) = tokio::select! {
                    conn = listener.accept() => conn,
                    _ = signal_tx.closed() => {
                        trace!("signal received, not accepting new connections");
                        break;
                    }
                };
                let io = TokioIo::new(io);
                trace!("connection {remote_addr:?} accepted");

                // `M::Error` is `Infallible`; the empty `match` proves this error
                // branch can never be taken.
                poll_fn(|cx| make_service.poll_ready(cx))
                    .await
                    .unwrap_or_else(|err| match err {});
                // Build this connection's service, converting hyper's `Incoming`
                // request body into axum's `Body`.
                let tower_service = make_service
                    .call(IncomingStream {
                        io: &io,
                        remote_addr,
                    })
                    .await
                    .unwrap_or_else(|err| match err {})
                    .map_request(|req: Request<Incoming>| req.map(Body::new));
                let hyper_service = TowerToHyperService::new(tower_service);
                let signal_tx = Arc::clone(&signal_tx);
                let close_rx = close_rx.clone();
                // Serve the connection on its own task so the accept loop can continue.
                tokio::spawn(async move {
                    #[allow(unused_mut)]
                    let mut builder = Builder::new(TokioExecutor::new());
                    // Presumably enables extended CONNECT (e.g. WebSockets over
                    // HTTP/2) — confirm against hyper's `enable_connect_protocol` docs.
                    #[cfg(feature = "http2")]
                    builder.http2().enable_connect_protocol();
                    let conn = builder.serve_connection_with_upgrades(io, hyper_service);
                    pin_mut!(conn);
                    // Fused so that, after firing once, this branch stays pending and
                    // `graceful_shutdown` is not requested again.
                    let signal_closed = signal_tx.closed().fuse();
                    pin_mut!(signal_closed);
                    loop {
                        tokio::select! {
                            result = conn.as_mut() => {
                                if let Err(_err) = result {
                                    trace!("failed to serve connection: {_err:#}");
                                }
                                break;
                            }
                            _ = &mut signal_closed => {
                                trace!("signal received in task, starting graceful shutdown");
                                // Keep looping: the connection must still be polled
                                // until it finishes shutting down.
                                conn.as_mut().graceful_shutdown();
                            }
                        }
                    }
                    // Signals to the server future that this connection is done.
                    drop(close_rx);
                });
            }
            // Drop our own receiver first; otherwise `close_tx.closed()` would
            // never resolve. Then release the listener before waiting.
            drop(close_rx);
            drop(listener);
            trace!(
                "waiting for {} task(s) to finish",
                close_tx.receiver_count()
            );
            close_tx.closed().await;
            Ok(())
        }))
    }
}
/// An accepted connection, handed to the `make_service` passed to `serve` so it
/// can build the service for that connection.
///
/// It borrows the connection's IO object and carries the remote address reported
/// by the [`Listener`].
#[derive(Debug)]
pub struct IncomingStream<'a, L>
where
    L: Listener,
{
    /// The accepted IO object, wrapped for hyper.
    io: &'a TokioIo<L::Io>,
    /// Peer address as returned by `Listener::accept`.
    remote_addr: L::Addr,
}
impl<L> IncomingStream<'_, L>
where
L: Listener,
{
pub fn io(&self) -> &L::Io {
self.io.inner()
}
pub fn remote_addr(&self) -> &L::Addr {
&self.remote_addr
}
}
mod private {
    use std::{
        future::Future,
        io,
        pin::Pin,
        task::{Context, Poll},
    };

    /// Opaque future that runs the server; it resolves with the server's
    /// `io::Result<()>`.
    pub struct ServeFuture(pub(super) futures_util::future::BoxFuture<'static, io::Result<()>>);

    impl Future for ServeFuture {
        type Output = io::Result<()>;

        #[inline]
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // `ServeFuture` only holds a `Pin<Box<_>>`, which is `Unpin`, so the
            // outer pin can be removed and the boxed future polled directly.
            let boxed = &mut self.get_mut().0;
            boxed.as_mut().poll(cx)
        }
    }

    impl std::fmt::Debug for ServeFuture {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // The inner future is opaque, so only the type name is shown.
            let builder = &mut f.debug_struct("ServeFuture");
            builder.finish_non_exhaustive()
        }
    }
}
// Tests for `serve`: compile-time coverage of the accepted make-service types,
// `local_addr` reporting, and end-to-end serving over a custom IO type.
#[cfg(test)]
mod tests {
    use std::{
        future::{pending, IntoFuture as _},
        net::{IpAddr, Ipv4Addr},
    };
    use axum_core::{body::Body, extract::Request};
    use http::StatusCode;
    use hyper_util::rt::TokioIo;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::{
        io::{self, AsyncRead, AsyncWrite},
        net::TcpListener,
    };
    #[cfg(unix)]
    use super::IncomingStream;
    use super::{serve, Listener};
    #[cfg(unix)]
    use crate::extract::connect_info::Connected;
    use crate::{
        body::to_bytes,
        handler::{Handler, HandlerWithoutStateExt},
        routing::get,
        serve::ListenerExt,
        Router,
    };

    // Compile-only check: this function is never called (hence `dead_code`). It
    // verifies that `serve` accepts routers, method routers and handlers, in their
    // plain, `into_make_service` and `into_make_service_with_connect_info` forms,
    // over TCP, TCP with `tap_io`, and (on unix) Unix-domain listeners. The values
    // returned by `serve` are intentionally left un-awaited (`unused_must_use`).
    #[allow(dead_code, unused_must_use)]
    async fn if_it_compiles_it_works() {
        // Connect-info type for Unix-domain connections.
        #[derive(Clone, Debug)]
        struct UdsConnectInfo;
        #[cfg(unix)]
        impl Connected<IncomingStream<'_, UnixListener>> for UdsConnectInfo {
            fn connect_info(_stream: IncomingStream<'_, UnixListener>) -> Self {
                Self
            }
        }
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";
        // A TCP listener that sets TCP_NODELAY on each accepted stream via `tap_io`.
        let tcp_nodelay_listener = || async {
            TcpListener::bind(addr).await.unwrap().tap_io(|tcp_stream| {
                if let Err(err) = tcp_stream.set_nodelay(true) {
                    eprintln!("failed to set TCP_NODELAY on incoming connection: {err:#}");
                }
            })
        };

        // `Router` as the make-service.
        serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        // Also awaited, to check that the result resolves to a `Result`.
        serve(tcp_nodelay_listener().await, router.clone())
            .await
            .unwrap();
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), router.clone());
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            router.clone().into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.clone().into_make_service(),
        );
        serve(
            TcpListener::bind(addr).await.unwrap(),
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            router
                .clone()
                .into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            router.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // `MethodRouter` as the make-service.
        serve(TcpListener::bind(addr).await.unwrap(), get(handler));
        serve(tcp_nodelay_listener().await, get(handler));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), get(handler));
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service(),
        );
        serve(
            TcpListener::bind(addr).await.unwrap(),
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            get(handler).into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            get(handler).into_make_service_with_connect_info::<UdsConnectInfo>(),
        );

        // Bare handlers as the make-service.
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_service());
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.with_state(()),
        );
        serve(tcp_nodelay_listener().await, handler.with_state(()));
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.with_state(()));
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service(),
        );
        serve(tcp_nodelay_listener().await, handler.into_make_service());
        #[cfg(unix)]
        serve(UnixListener::bind("").unwrap(), handler.into_make_service());
        serve(
            TcpListener::bind(addr).await.unwrap(),
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        serve(
            tcp_nodelay_listener().await,
            handler.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        );
        #[cfg(unix)]
        serve(
            UnixListener::bind("").unwrap(),
            handler.into_make_service_with_connect_info::<UdsConnectInfo>(),
        );
    }

    // Minimal handler used by the compile-only checks above.
    async fn handler() {}

    // `Serve::local_addr` reports the bound IP and an OS-assigned (non-zero) port.
    #[crate::test]
    async fn test_serve_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";
        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone());
        let address = server.local_addr().unwrap();
        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // Same check after attaching a (never-firing) shutdown signal.
    #[crate::test]
    async fn test_with_graceful_shutdown_local_addr() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";
        let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone())
            .with_graceful_shutdown(pending());
        let address = server.local_addr().unwrap();
        assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
        assert_ne!(address.port(), 0);
    }

    // `into_future` must not panic when called outside a Tokio runtime context:
    // the listener is bound inside `block_on`, but `into_future` runs outside it.
    #[test]
    fn into_future_outside_tokio() {
        let router: Router = Router::new();
        let addr = "0.0.0.0:0";
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap();
        _ = serve(listener, router).into_future();
    }

    // End-to-end: serve over an in-memory duplex stream through a custom `Listener`
    // and issue one HTTP/1 request from a hyper client.
    #[crate::test]
    async fn serving_on_custom_io_type() {
        // Yields its single stream once, then never accepts again.
        struct ReadyListener<T>(Option<T>);
        impl<T> Listener for ReadyListener<T>
        where
            T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
        {
            type Io = T;
            type Addr = ();
            async fn accept(&mut self) -> (Self::Io, Self::Addr) {
                match self.0.take() {
                    Some(server) => (server, ()),
                    None => std::future::pending().await,
                }
            }
            fn local_addr(&self) -> io::Result<Self::Addr> {
                Ok(())
            }
        }
        let (client, server) = io::duplex(1024);
        let listener = ReadyListener(Some(server));
        let app = Router::new().route("/", get(|| async { "Hello, World!" }));
        tokio::spawn(serve(listener, app).into_future());
        let stream = TokioIo::new(client);
        let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await.unwrap();
        tokio::spawn(conn);
        let request = Request::builder().body(Body::empty()).unwrap();
        let response = sender.send_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = Body::new(response.into_body());
        let body = to_bytes(body, usize::MAX).await.unwrap();
        let body = String::from_utf8(body.to_vec()).unwrap();
        assert_eq!(body, "Hello, World!");
    }
}