// axum/serve/listener.rs
use std::{fmt, future::Future, time::Duration};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};
/// A source of incoming connections that `axum` can serve.
///
/// Implementors hand out one `(Io, Addr)` pair per accepted connection and can
/// report the address they are bound to.
pub trait Listener: Send + 'static {
    /// The stream type produced for every accepted connection.
    type Io: AsyncRead + AsyncWrite + Unpin + Send + 'static;

    /// The address type returned by both `accept` and `local_addr`.
    type Addr: Send;

    /// Waits for the next incoming connection and returns it with its address.
    ///
    /// There is no error channel in the return type. When the underlying accept
    /// operation is fallible, the implementation itself is responsible for
    /// logging failures and retrying until a connection is obtained.
    fn accept(&mut self) -> impl Future<Output = (Self::Io, Self::Addr)> + Send;

    /// Reports the local address this listener is bound to.
    fn local_addr(&self) -> io::Result<Self::Addr>;
}
impl Listener for TcpListener {
    type Io = TcpStream;
    type Addr = std::net::SocketAddr;

    /// Repeatedly calls the inherent `TcpListener::accept` until it succeeds.
    /// Every failure is passed to `handle_accept_error` before the next attempt.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        loop {
            // `Self::accept` resolves to the inherent method, not this trait method.
            let outcome = Self::accept(self).await;
            match outcome {
                Ok(connection) => break connection,
                Err(err) => handle_accept_error(err).await,
            }
        }
    }

    /// Delegates to the inherent `TcpListener::local_addr`.
    #[inline]
    fn local_addr(&self) -> io::Result<Self::Addr> {
        Self::local_addr(self)
    }
}
#[cfg(unix)]
impl Listener for tokio::net::UnixListener {
    type Io = tokio::net::UnixStream;
    type Addr = tokio::net::unix::SocketAddr;

    /// Repeatedly calls the inherent `UnixListener::accept` until it succeeds.
    /// Every failure is passed to `handle_accept_error` before the next attempt.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        loop {
            // `Self::accept` resolves to the inherent method, not this trait method.
            let outcome = Self::accept(self).await;
            match outcome {
                Ok(connection) => break connection,
                Err(err) => handle_accept_error(err).await,
            }
        }
    }

    /// Delegates to the inherent `UnixListener::local_addr`.
    #[inline]
    fn local_addr(&self) -> io::Result<Self::Addr> {
        Self::local_addr(self)
    }
}
/// Extension methods available on every [`Listener`].
pub trait ListenerExt: Listener + Sized {
    /// Wraps this listener so that `tap_fn` receives a mutable reference to each
    /// accepted `Io` before that `Io` is returned from `accept`.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get, serve::ListenerExt};
    /// use tracing::trace;
    ///
    /// # async {
    /// let router = Router::new().route("/", get(|| async { "Hello, World!" }));
    ///
    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
    ///     .await
    ///     .unwrap()
    ///     .tap_io(|tcp_stream| {
    ///         if let Err(err) = tcp_stream.set_nodelay(true) {
    ///             trace!("failed to set TCP_NODELAY on incoming connection: {err:#}");
    ///         }
    ///     });
    /// axum::serve(listener, router).await.unwrap();
    /// # };
    /// ```
    fn tap_io<F>(self, tap_fn: F) -> TapIo<Self, F>
    where
        F: FnMut(&mut Self::Io) + Send + 'static,
    {
        let listener = self;
        TapIo { listener, tap_fn }
    }
}
// Blanket impl: every `Listener` picks up the `ListenerExt` helpers.
impl<L> ListenerExt for L
where
    L: Listener,
{
}
/// Return type of [`ListenerExt::tap_io`].
///
/// Wraps another listener. Each `Io` accepted by that listener is passed to
/// `tap_fn` by mutable reference, then returned.
///
/// See that method for details.
pub struct TapIo<L, F> {
    /// The wrapped listener that performs the actual accepting.
    listener: L,
    /// Closure invoked on every accepted `Io`.
    tap_fn: F,
}
/// Prints the wrapped listener. The tap closure is omitted because closures
/// generally do not implement `Debug`.
impl<L, F> fmt::Debug for TapIo<L, F>
where
    L: Listener + fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("TapIo");
        builder.field("listener", &self.listener);
        builder.finish_non_exhaustive()
    }
}
impl<L, F> Listener for TapIo<L, F>
where
    L: Listener,
    F: FnMut(&mut L::Io) + Send + 'static,
{
    type Io = L::Io;
    type Addr = L::Addr;

    /// Accepts from the inner listener and runs the tap closure on the new `Io`.
    /// It then returns the pair.
    async fn accept(&mut self) -> (Self::Io, Self::Addr) {
        let (mut stream, address) = self.listener.accept().await;
        let tap = &mut self.tap_fn;
        tap(&mut stream);
        (stream, address)
    }

    /// Forwards to the inner listener's `local_addr`.
    fn local_addr(&self) -> io::Result<Self::Addr> {
        self.listener.local_addr()
    }
}
/// Reacts to a failed accept call.
///
/// Errors classified by `is_connection_error` are ignored, so the calling loop
/// retries right away. Any other error is logged at the `error` level, followed
/// by a one-second pause before returning.
async fn handle_accept_error(e: io::Error) {
    if !is_connection_error(&e) {
        // [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
        //
        // > A possible scenario is that the process has hit the max open files
        // > allowed, and so trying to accept a new connection will fail with
        // > `EMFILE`. In some cases, it's preferable to just wait for some time, if
        // > the application will likely close some files (or connections), and try
        // > to accept the connection again. If this option is `true`, the error
        // > will be logged at the `error` level, since it is still a big deal,
        // > and then the listener will sleep for 1 second.
        //
        // hyper allowed customizing this but axum does not.
        error!("accept error: {e}");
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
}
/// Returns `true` when `e` has kind `ConnectionRefused`, `ConnectionAborted`,
/// or `ConnectionReset`.
fn is_connection_error(e: &io::Error) -> bool {
    let kind = e.kind();
    kind == io::ErrorKind::ConnectionRefused
        || kind == io::ErrorKind::ConnectionAborted
        || kind == io::ErrorKind::ConnectionReset
}