[pbs-devel] [PATCH v3 proxmox 1/3] rest-server: Refactor `AcceptBuilder`, provide support for optional TLS
Wolfgang Bumiller
w.bumiller at proxmox.com
Thu Nov 16 08:35:31 CET 2023
On Tue, Oct 31, 2023 at 07:47:03PM +0100, Max Carrara wrote:
> A new public function, `accept_tls_optional()`, is added; it accepts
> both plain TCP streams and TCP streams running TLS. Plain TCP
> streams are sent along via a separate channel in order to clearly
> distinguish between "secure" and "insecure" connections.
>
> Furthermore, instead of `AcceptBuilder` itself holding a reference to
> an `SslAcceptor`, its public functions now take the acceptor as an
> argument. The public functions are also renamed to make the
> difference in their behavior more explicit:
>
> * `accept()` --> `accept_tls()`
> * NEW --> `accept_tls_optional()`
>
> Signed-off-by: Max Carrara <m.carrara at proxmox.com>
> ---
> Changes v1 --> v2:
> * No more `BiAcceptBuilder`, `AcceptBuilder` is refactored instead
> * `AcceptBuilder` doesn't hold a reference to `SslAcceptor` anymore
> * Avoid unnecessary `#[cfg]`s
> * Avoid unnecessarily duplicated code (already mostly done by getting
> rid of `BiAcceptBuilder`)
> * Some clippy stuff
>
> Changes v2 --> v3:
> * Incorporate previously applied clippy fixes
>
> proxmox-rest-server/src/connection.rs | 373 ++++++++++++++++++++------
> 1 file changed, 287 insertions(+), 86 deletions(-)
>
> diff --git a/proxmox-rest-server/src/connection.rs b/proxmox-rest-server/src/connection.rs
> index 1bec28d..ab8c7db 100644
> --- a/proxmox-rest-server/src/connection.rs
> +++ b/proxmox-rest-server/src/connection.rs
> @@ -8,15 +8,16 @@ use std::pin::Pin;
> use std::sync::{Arc, Mutex};
> use std::time::Duration;
>
> -use anyhow::Context as _;
> -use anyhow::Error;
> +use anyhow::{format_err, Context as _, Error};
> use futures::FutureExt;
> +use hyper::server::accept;
> use openssl::ec::{EcGroup, EcKey};
> use openssl::nid::Nid;
> use openssl::pkey::{PKey, Private};
> use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
> use openssl::x509::X509;
> use tokio::net::{TcpListener, TcpStream};
> +use tokio::sync::mpsc;
> use tokio_openssl::SslStream;
> use tokio_stream::wrappers::ReceiverStream;
>
> @@ -133,10 +134,14 @@ impl TlsAcceptorBuilder {
> }
> }
>
> -#[cfg(feature = "rate-limited-stream")]
> -type ClientStreamResult = Pin<Box<SslStream<RateLimitedStream<TcpStream>>>>;
> #[cfg(not(feature = "rate-limited-stream"))]
> -type ClientStreamResult = Pin<Box<SslStream<TcpStream>>>;
> +type InsecureClientStream = TcpStream;
> +#[cfg(feature = "rate-limited-stream")]
> +type InsecureClientStream = RateLimitedStream<TcpStream>;
> +
> +type InsecureClientStreamResult = Pin<Box<InsecureClientStream>>;
> +
> +type ClientStreamResult = Pin<Box<SslStream<InsecureClientStream>>>;
>
> #[cfg(feature = "rate-limited-stream")]
> type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit>, Option<SharedRateLimit>)
> @@ -145,7 +150,6 @@ type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit
> + 'static;
>
> pub struct AcceptBuilder {
> - acceptor: Arc<Mutex<SslAcceptor>>,
> debug: bool,
> tcp_keepalive_time: u32,
> max_pending_accepts: usize,
> @@ -154,16 +158,9 @@ pub struct AcceptBuilder {
> lookup_rate_limiter: Option<Arc<LookupRateLimiter>>,
> }
>
> -impl AcceptBuilder {
> - pub fn new() -> Result<Self, Error> {
> - Ok(Self::with_acceptor(Arc::new(Mutex::new(
> - TlsAcceptorBuilder::new().build()?,
> - ))))
> - }
> -
> - pub fn with_acceptor(acceptor: Arc<Mutex<SslAcceptor>>) -> Self {
> +impl Default for AcceptBuilder {
> + fn default() -> Self {
> Self {
> - acceptor,
> debug: false,
> tcp_keepalive_time: 120,
> max_pending_accepts: 1024,
> @@ -172,6 +169,12 @@ impl AcceptBuilder {
> lookup_rate_limiter: None,
> }
> }
> +}
> +
> +impl AcceptBuilder {
> + pub fn new() -> Self {
> + Default::default()
> + }
>
> pub fn debug(mut self, debug: bool) -> Self {
> self.debug = debug;
> @@ -193,114 +196,312 @@ impl AcceptBuilder {
> self.lookup_rate_limiter = Some(lookup_rate_limiter);
> self
> }
> +}
>
> - pub fn accept(
> +impl AcceptBuilder {
> + pub fn accept_tls(
> self,
> listener: TcpListener,
> - ) -> impl hyper::server::accept::Accept<Conn = ClientStreamResult, Error = Error> {
> - let (sender, receiver) = tokio::sync::mpsc::channel(self.max_pending_accepts);
> + acceptor: Arc<Mutex<SslAcceptor>>,
> + ) -> impl accept::Accept<Conn = ClientStreamResult, Error = Error> {
> + let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
> +
> + tokio::spawn(self.accept_connections(listener, acceptor, secure_sender.into()));
> +
> + accept::from_stream(ReceiverStream::new(secure_receiver))
> + }
> +
> + pub fn accept_tls_optional(
> + self,
> + listener: TcpListener,
> + acceptor: Arc<Mutex<SslAcceptor>>,
> + ) -> (
> + impl accept::Accept<Conn = ClientStreamResult, Error = Error>,
> + impl accept::Accept<Conn = InsecureClientStreamResult, Error = Error>,
> + ) {
> + let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
> + let (insecure_sender, insecure_receiver) = mpsc::channel(self.max_pending_accepts);
> +
> + tokio::spawn(self.accept_connections(
> + listener,
> + acceptor,
> + (secure_sender, insecure_sender).into(),
> + ));
> +
> + (
> + accept::from_stream(ReceiverStream::new(secure_receiver)),
> + accept::from_stream(ReceiverStream::new(insecure_receiver)),
> + )
> + }
> +}
> +
> +type ClientSender = mpsc::Sender<Result<ClientStreamResult, Error>>;
> +type InsecureClientSender = mpsc::Sender<Result<InsecureClientStreamResult, Error>>;
>
> - tokio::spawn(self.accept_connections(listener, sender));
> +enum Sender {
> + Secure(ClientSender),
> + SecureAndInsecure(ClientSender, InsecureClientSender),
> +}
>
> - //receiver
> - hyper::server::accept::from_stream(ReceiverStream::new(receiver))
> +impl From<ClientSender> for Sender {
> + fn from(sender: ClientSender) -> Self {
> + Sender::Secure(sender)
> }
> +}
> +
> +impl From<(ClientSender, InsecureClientSender)> for Sender {
> + fn from(senders: (ClientSender, InsecureClientSender)) -> Self {
> + Sender::SecureAndInsecure(senders.0, senders.1)
> + }
> +}
>
> +impl AcceptBuilder {
> async fn accept_connections(
> self,
> listener: TcpListener,
> - sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
> + acceptor: Arc<Mutex<SslAcceptor>>,
> + sender: Sender,
> ) {
> let accept_counter = Arc::new(());
> let mut shutdown_future = crate::shutdown_future().fuse();
>
> loop {
> - let (sock, peer) = futures::select! {
> - res = listener.accept().fuse() => match res {
> - Ok(conn) => conn,
> + let socket = futures::select! {
> + res = self.try_setup_socket(&listener).fuse() => match res {
> + Ok(socket) => socket,
> Err(err) => {
> - eprintln!("error accepting tcp connection: {err}");
> + log::error!("couldn't set up TCP socket: {err}");
> continue;
> }
> },
> - _ = shutdown_future => break,
> + _ = shutdown_future => break,
> };
> - #[cfg(not(feature = "rate-limited-stream"))]
> - {
> - let _ = &peer;
> - }
>
> - sock.set_nodelay(true).unwrap();
> - let _ = proxmox_sys::linux::socket::set_tcp_keepalive(
> - sock.as_raw_fd(),
> - self.tcp_keepalive_time,
> - );
> + let acceptor = Arc::clone(&acceptor);
> + let accept_counter = Arc::clone(&accept_counter);
>
> - #[cfg(feature = "rate-limited-stream")]
> - let sock = match self.lookup_rate_limiter.clone() {
> - Some(lookup) => {
> - RateLimitedStream::with_limiter_update_cb(sock, move || lookup(peer))
> + if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
> + log::error!("connection rejected - too many open connections");
> + continue;
> + }
> +
> + match sender {
> + Sender::Secure(ref secure_sender) => {
> + let accept_future = Self::do_accept_tls(
> + socket,
> + acceptor,
> + accept_counter,
> + self.debug,
> + secure_sender.clone(),
> + );
> +
> + tokio::spawn(accept_future);
> + }
> + Sender::SecureAndInsecure(ref secure_sender, ref insecure_sender) => {
> + let accept_future = Self::do_accept_tls_optional(
> + socket,
> + acceptor,
> + accept_counter,
> + self.debug,
> + secure_sender.clone(),
> + insecure_sender.clone(),
> + );
> +
> + tokio::spawn(accept_future);
> }
> - None => RateLimitedStream::with_limiter(sock, None, None),
> };
> + }
> + }
>
> - let ssl = {
> - // limit acceptor_guard scope
> - // Acceptor can be reloaded using the command socket "reload-certificate" command
> - let acceptor_guard = self.acceptor.lock().unwrap();
> + async fn try_setup_socket(
> + &self,
> + listener: &TcpListener,
> + ) -> Result<InsecureClientStream, Error> {
> + let (socket, peer) = match listener.accept().await {
> + Ok(connection) => connection,
> + Err(error) => {
> + return Err(format_err!(error)).context("error while accepting tcp stream")
> + }
> + };
>
> - match openssl::ssl::Ssl::new(acceptor_guard.context()) {
> - Ok(ssl) => ssl,
> - Err(err) => {
> - eprintln!("failed to create Ssl object from Acceptor context - {err}");
> - continue;
> - }
> - }
> - };
> + socket
> + .set_nodelay(true)
> + .context("error while setting TCP_NODELAY on socket")?;
> +
> + proxmox_sys::linux::socket::set_tcp_keepalive(socket.as_raw_fd(), self.tcp_keepalive_time)
> + .context("error while setting SO_KEEPALIVE on socket")?;
>
> - let stream = match tokio_openssl::SslStream::new(ssl, sock) {
> - Ok(stream) => stream,
> + #[cfg(feature = "rate-limited-stream")]
> + let socket = match self.lookup_rate_limiter.clone() {
> + Some(lookup) => RateLimitedStream::with_limiter_update_cb(socket, move || lookup(peer)),
> + None => RateLimitedStream::with_limiter(socket, None, None),
> + };
> +
> + #[cfg(not(feature = "rate-limited-stream"))]
> + let _peer = peer;
> +
> + Ok(socket)
> + }
> +
> + async fn do_accept_tls(
> + socket: InsecureClientStream,
> + acceptor: Arc<Mutex<SslAcceptor>>,
> + accept_counter: Arc<()>,
> + debug: bool,
> + secure_sender: ClientSender,
> + ) {
> + let ssl = {
> + // limit acceptor_guard scope
> + // Acceptor can be reloaded using the command socket "reload-certificate" command
> + let acceptor_guard = acceptor.lock().unwrap();
> +
> + match openssl::ssl::Ssl::new(acceptor_guard.context()) {
> + Ok(ssl) => ssl,
> Err(err) => {
> - eprintln!("failed to create SslStream using ssl and connection socket - {err}");
> - continue;
> + log::error!("failed to create Ssl object from Acceptor context - {err}");
> + return;
> }
> - };
> + }
> + };
>
> - let mut stream = Box::pin(stream);
> - let sender = sender.clone();
> + let secure_stream = match tokio_openssl::SslStream::new(ssl, socket) {
> + Ok(stream) => stream,
> + Err(err) => {
> + log::error!("failed to create SslStream using ssl and connection socket - {err}");
> + return;
> + }
> + };
>
> - if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
> - eprintln!("connection rejected - too many open connections");
> - continue;
> + let mut secure_stream = Box::pin(secure_stream);
> +
> + let accept_future =
> + tokio::time::timeout(Duration::new(10, 0), secure_stream.as_mut().accept());
> +
> + let result = accept_future.await;
> +
> + match result {
> + Ok(Ok(())) => {
> + if secure_sender.send(Ok(secure_stream)).await.is_err() && debug {
> + log::error!("detected closed connection channel");
> + }
> + }
> + Ok(Err(err)) => {
> + if debug {
> + log::error!("https handshake failed - {err}");
> + }
> }
> + Err(_) => {
> + if debug {
> + log::error!("https handshake timeout");
> + }
> + }
> + }
>
> - let accept_counter = Arc::clone(&accept_counter);
> - tokio::spawn(async move {
> - let accept_future =
> - tokio::time::timeout(Duration::new(10, 0), stream.as_mut().accept());
> + drop(accept_counter); // decrease reference count
> + }
>
> - let result = accept_future.await;
> + async fn do_accept_tls_optional(
> + socket: InsecureClientStream,
> + acceptor: Arc<Mutex<SslAcceptor>>,
> + accept_counter: Arc<()>,
> + debug: bool,
> + secure_sender: ClientSender,
> + insecure_sender: InsecureClientSender,
> + ) {
> + let client_initiates_handshake = {
> + #[cfg(feature = "rate-limited-stream")]
> + let socket = socket.inner();
>
> - match result {
> - Ok(Ok(())) => {
> - if sender.send(Ok(stream)).await.is_err() && self.debug {
> - log::error!("detect closed connection channel");
> - }
> - }
> - Ok(Err(err)) => {
> - if self.debug {
> - log::error!("https handshake failed - {err}");
> - }
> - }
> - Err(_) => {
> - if self.debug {
> - log::error!("https handshake timeout");
> - }
> - }
> + #[cfg(not(feature = "rate-limited-stream"))]
> + let socket = &socket;
> +
> + match Self::wait_for_client_tls_handshake(socket).await {
> + Ok(initiates_handshake) => initiates_handshake,
> + Err(err) => {
> + log::error!("error checking for TLS handshake: {err}");
> + return;
> }
> + }
> + };
> +
> + if !client_initiates_handshake {
> + let insecure_stream = Box::pin(socket);
>
> - drop(accept_counter); // decrease reference count
> - });
> + if insecure_sender.send(Ok(insecure_stream)).await.is_err() && debug {
> + log::error!("detected closed connection channel")
> + }
> +
> + return;
> }
> +
> + Self::do_accept_tls(socket, acceptor, accept_counter, debug, secure_sender).await
> }
> +
> + async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result<bool, Error> {
> + const MS_TIMEOUT: u64 = 1000;
> + const BYTES_BUF_SIZE: usize = 128;
> +
> + let mut buf = [0; BYTES_BUF_SIZE];
> + let mut last_peek_size = 0;
> +
> + let future = async {
> + loop {
> + let peek_size = incoming_stream
> + .peek(&mut buf)
> + .await
> + .context("couldn't peek into incoming tcp stream")?;
> +
> + if contains_tls_handshake_fragment(&buf) {
> + return Ok(true);
> + }
> +
> + // No more new data came in
> + if peek_size == last_peek_size {
> + return Ok(false);
> + }
> +
> + last_peek_size = peek_size;
> +
> + // yields to event loop; this future blocks otherwise ad infinitum
> + tokio::time::sleep(Duration::from_millis(0)).await;
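Just noticed this - how about `tokio::task::yield_now()` instead? It
states the intent (yield back to the scheduler) directly, rather than
going through a zero-length timer.
Also, this makes me wish epoll had an edge-triggering flag that also
fires when additional data arrives, not just when the amount of readable
data goes up from zero.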
More information about the pbs-devel mailing list