[pbs-devel] [PATCH v3 proxmox 1/3] rest-server: Refactor `AcceptBuilder`, provide support for optional TLS

Wolfgang Bumiller w.bumiller at proxmox.com
Thu Nov 16 08:35:31 CET 2023


On Tue, Oct 31, 2023 at 07:47:03PM +0100, Max Carrara wrote:
> The new public function `accept_tls_optional()` is added, which
> accepts both plain TCP streams and TCP streams running TLS. Plain TCP
> streams are sent along via a separate channel in order to clearly
> distinguish between "secure" and "insecure" connections.
> 
> Furthermore, instead of `AcceptBuilder` itself holding a reference to
> an `SslAcceptor`, its public functions now take the acceptor as an
> argument. The public functions are also renamed to distinguish their
> functionality more explicitly:
> 
>   * `accept()` --> `accept_tls()`
>   *        NEW --> `accept_tls_optional()`
> 
> Signed-off-by: Max Carrara <m.carrara at proxmox.com>
> ---
>  Changes v1 --> v2:
>   * No more `BiAcceptBuilder`, `AcceptBuilder` is refactored instead
>   * `AcceptBuilder` doesn't hold a reference to `SslAcceptor` anymore
>   * Avoid unnecessary `#[cfg]`s
>   * Avoid unnecessarily duplicated code (already mostly done by getting
>     rid of `BiAcceptBuilder`)
>   * Some clippy stuff
> 
>  Changes v2 --> v3:
>   * Incorporate previously applied clippy fixes
> 
>  proxmox-rest-server/src/connection.rs | 373 ++++++++++++++++++++------
>  1 file changed, 287 insertions(+), 86 deletions(-)
> 
> diff --git a/proxmox-rest-server/src/connection.rs b/proxmox-rest-server/src/connection.rs
> index 1bec28d..ab8c7db 100644
> --- a/proxmox-rest-server/src/connection.rs
> +++ b/proxmox-rest-server/src/connection.rs
> @@ -8,15 +8,16 @@ use std::pin::Pin;
>  use std::sync::{Arc, Mutex};
>  use std::time::Duration;
>  
> -use anyhow::Context as _;
> -use anyhow::Error;
> +use anyhow::{format_err, Context as _, Error};
>  use futures::FutureExt;
> +use hyper::server::accept;
>  use openssl::ec::{EcGroup, EcKey};
>  use openssl::nid::Nid;
>  use openssl::pkey::{PKey, Private};
>  use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
>  use openssl::x509::X509;
>  use tokio::net::{TcpListener, TcpStream};
> +use tokio::sync::mpsc;
>  use tokio_openssl::SslStream;
>  use tokio_stream::wrappers::ReceiverStream;
>  
> @@ -133,10 +134,14 @@ impl TlsAcceptorBuilder {
>      }
>  }
>  
> -#[cfg(feature = "rate-limited-stream")]
> -type ClientStreamResult = Pin<Box<SslStream<RateLimitedStream<TcpStream>>>>;
>  #[cfg(not(feature = "rate-limited-stream"))]
> -type ClientStreamResult = Pin<Box<SslStream<TcpStream>>>;
> +type InsecureClientStream = TcpStream;
> +#[cfg(feature = "rate-limited-stream")]
> +type InsecureClientStream = RateLimitedStream<TcpStream>;
> +
> +type InsecureClientStreamResult = Pin<Box<InsecureClientStream>>;
> +
> +type ClientStreamResult = Pin<Box<SslStream<InsecureClientStream>>>;
>  
>  #[cfg(feature = "rate-limited-stream")]
>  type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit>, Option<SharedRateLimit>)
> @@ -145,7 +150,6 @@ type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit
>      + 'static;
>  
>  pub struct AcceptBuilder {
> -    acceptor: Arc<Mutex<SslAcceptor>>,
>      debug: bool,
>      tcp_keepalive_time: u32,
>      max_pending_accepts: usize,
> @@ -154,16 +158,9 @@ pub struct AcceptBuilder {
>      lookup_rate_limiter: Option<Arc<LookupRateLimiter>>,
>  }
>  
> -impl AcceptBuilder {
> -    pub fn new() -> Result<Self, Error> {
> -        Ok(Self::with_acceptor(Arc::new(Mutex::new(
> -            TlsAcceptorBuilder::new().build()?,
> -        ))))
> -    }
> -
> -    pub fn with_acceptor(acceptor: Arc<Mutex<SslAcceptor>>) -> Self {
> +impl Default for AcceptBuilder {
> +    fn default() -> Self {
>          Self {
> -            acceptor,
>              debug: false,
>              tcp_keepalive_time: 120,
>              max_pending_accepts: 1024,
> @@ -172,6 +169,12 @@ impl AcceptBuilder {
>              lookup_rate_limiter: None,
>          }
>      }
> +}
> +
> +impl AcceptBuilder {
> +    pub fn new() -> Self {
> +        Default::default()
> +    }
>  
>      pub fn debug(mut self, debug: bool) -> Self {
>          self.debug = debug;
> @@ -193,114 +196,312 @@ impl AcceptBuilder {
>          self.lookup_rate_limiter = Some(lookup_rate_limiter);
>          self
>      }
> +}
>  
> -    pub fn accept(
> +impl AcceptBuilder {
> +    pub fn accept_tls(
>          self,
>          listener: TcpListener,
> -    ) -> impl hyper::server::accept::Accept<Conn = ClientStreamResult, Error = Error> {
> -        let (sender, receiver) = tokio::sync::mpsc::channel(self.max_pending_accepts);
> +        acceptor: Arc<Mutex<SslAcceptor>>,
> +    ) -> impl accept::Accept<Conn = ClientStreamResult, Error = Error> {
> +        let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
> +
> +        tokio::spawn(self.accept_connections(listener, acceptor, secure_sender.into()));
> +
> +        accept::from_stream(ReceiverStream::new(secure_receiver))
> +    }
> +
> +    pub fn accept_tls_optional(
> +        self,
> +        listener: TcpListener,
> +        acceptor: Arc<Mutex<SslAcceptor>>,
> +    ) -> (
> +        impl accept::Accept<Conn = ClientStreamResult, Error = Error>,
> +        impl accept::Accept<Conn = InsecureClientStreamResult, Error = Error>,
> +    ) {
> +        let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
> +        let (insecure_sender, insecure_receiver) = mpsc::channel(self.max_pending_accepts);
> +
> +        tokio::spawn(self.accept_connections(
> +            listener,
> +            acceptor,
> +            (secure_sender, insecure_sender).into(),
> +        ));
> +
> +        (
> +            accept::from_stream(ReceiverStream::new(secure_receiver)),
> +            accept::from_stream(ReceiverStream::new(insecure_receiver)),
> +        )
> +    }
> +}
> +
> +type ClientSender = mpsc::Sender<Result<ClientStreamResult, Error>>;
> +type InsecureClientSender = mpsc::Sender<Result<InsecureClientStreamResult, Error>>;
>  
> -        tokio::spawn(self.accept_connections(listener, sender));
> +enum Sender {
> +    Secure(ClientSender),
> +    SecureAndInsecure(ClientSender, InsecureClientSender),
> +}
>  
> -        //receiver
> -        hyper::server::accept::from_stream(ReceiverStream::new(receiver))
> +impl From<ClientSender> for Sender {
> +    fn from(sender: ClientSender) -> Self {
> +        Sender::Secure(sender)
>      }
> +}
> +
> +impl From<(ClientSender, InsecureClientSender)> for Sender {
> +    fn from(senders: (ClientSender, InsecureClientSender)) -> Self {
> +        Sender::SecureAndInsecure(senders.0, senders.1)
> +    }
> +}
>  
> +impl AcceptBuilder {
>      async fn accept_connections(
>          self,
>          listener: TcpListener,
> -        sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
> +        acceptor: Arc<Mutex<SslAcceptor>>,
> +        sender: Sender,
>      ) {
>          let accept_counter = Arc::new(());
>          let mut shutdown_future = crate::shutdown_future().fuse();
>  
>          loop {
> -            let (sock, peer) = futures::select! {
> -                res = listener.accept().fuse() => match res {
> -                    Ok(conn) => conn,
> +            let socket = futures::select! {
> +                res = self.try_setup_socket(&listener).fuse() => match res {
> +                    Ok(socket) => socket,
>                      Err(err) => {
> -                        eprintln!("error accepting tcp connection: {err}");
> +                        log::error!("couldn't set up TCP socket: {err}");
>                          continue;
>                      }
>                  },
> -                _ =  shutdown_future => break,
> +                _ = shutdown_future => break,
>              };
> -            #[cfg(not(feature = "rate-limited-stream"))]
> -            {
> -                let _ = &peer;
> -            }
>  
> -            sock.set_nodelay(true).unwrap();
> -            let _ = proxmox_sys::linux::socket::set_tcp_keepalive(
> -                sock.as_raw_fd(),
> -                self.tcp_keepalive_time,
> -            );
> +            let acceptor = Arc::clone(&acceptor);
> +            let accept_counter = Arc::clone(&accept_counter);
>  
> -            #[cfg(feature = "rate-limited-stream")]
> -            let sock = match self.lookup_rate_limiter.clone() {
> -                Some(lookup) => {
> -                    RateLimitedStream::with_limiter_update_cb(sock, move || lookup(peer))
> +            if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
> +                log::error!("connection rejected - too many open connections");
> +                continue;
> +            }
> +
> +            match sender {
> +                Sender::Secure(ref secure_sender) => {
> +                    let accept_future = Self::do_accept_tls(
> +                        socket,
> +                        acceptor,
> +                        accept_counter,
> +                        self.debug,
> +                        secure_sender.clone(),
> +                    );
> +
> +                    tokio::spawn(accept_future);
> +                }
> +                Sender::SecureAndInsecure(ref secure_sender, ref insecure_sender) => {
> +                    let accept_future = Self::do_accept_tls_optional(
> +                        socket,
> +                        acceptor,
> +                        accept_counter,
> +                        self.debug,
> +                        secure_sender.clone(),
> +                        insecure_sender.clone(),
> +                    );
> +
> +                    tokio::spawn(accept_future);
>                  }
> -                None => RateLimitedStream::with_limiter(sock, None, None),
>              };
> +        }
> +    }
>  
> -            let ssl = {
> -                // limit acceptor_guard scope
> -                // Acceptor can be reloaded using the command socket "reload-certificate" command
> -                let acceptor_guard = self.acceptor.lock().unwrap();
> +    async fn try_setup_socket(
> +        &self,
> +        listener: &TcpListener,
> +    ) -> Result<InsecureClientStream, Error> {
> +        let (socket, peer) = match listener.accept().await {
> +            Ok(connection) => connection,
> +            Err(error) => {
> +                return Err(format_err!(error)).context("error while accepting tcp stream")
> +            }
> +        };
>  
> -                match openssl::ssl::Ssl::new(acceptor_guard.context()) {
> -                    Ok(ssl) => ssl,
> -                    Err(err) => {
> -                        eprintln!("failed to create Ssl object from Acceptor context - {err}");
> -                        continue;
> -                    }
> -                }
> -            };
> +        socket
> +            .set_nodelay(true)
> +            .context("error while setting TCP_NODELAY on socket")?;
> +
> +        proxmox_sys::linux::socket::set_tcp_keepalive(socket.as_raw_fd(), self.tcp_keepalive_time)
> +            .context("error while setting SO_KEEPALIVE on socket")?;
>  
> -            let stream = match tokio_openssl::SslStream::new(ssl, sock) {
> -                Ok(stream) => stream,
> +        #[cfg(feature = "rate-limited-stream")]
> +        let socket = match self.lookup_rate_limiter.clone() {
> +            Some(lookup) => RateLimitedStream::with_limiter_update_cb(socket, move || lookup(peer)),
> +            None => RateLimitedStream::with_limiter(socket, None, None),
> +        };
> +
> +        #[cfg(not(feature = "rate-limited-stream"))]
> +        let _peer = peer;
> +
> +        Ok(socket)
> +    }
> +
> +    async fn do_accept_tls(
> +        socket: InsecureClientStream,
> +        acceptor: Arc<Mutex<SslAcceptor>>,
> +        accept_counter: Arc<()>,
> +        debug: bool,
> +        secure_sender: ClientSender,
> +    ) {
> +        let ssl = {
> +            // limit acceptor_guard scope
> +            // Acceptor can be reloaded using the command socket "reload-certificate" command
> +            let acceptor_guard = acceptor.lock().unwrap();
> +
> +            match openssl::ssl::Ssl::new(acceptor_guard.context()) {
> +                Ok(ssl) => ssl,
>                  Err(err) => {
> -                    eprintln!("failed to create SslStream using ssl and connection socket - {err}");
> -                    continue;
> +                    log::error!("failed to create Ssl object from Acceptor context - {err}");
> +                    return;
>                  }
> -            };
> +            }
> +        };
>  
> -            let mut stream = Box::pin(stream);
> -            let sender = sender.clone();
> +        let secure_stream = match tokio_openssl::SslStream::new(ssl, socket) {
> +            Ok(stream) => stream,
> +            Err(err) => {
> +                log::error!("failed to create SslStream using ssl and connection socket - {err}");
> +                return;
> +            }
> +        };
>  
> -            if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
> -                eprintln!("connection rejected - too many open connections");
> -                continue;
> +        let mut secure_stream = Box::pin(secure_stream);
> +
> +        let accept_future =
> +            tokio::time::timeout(Duration::new(10, 0), secure_stream.as_mut().accept());
> +
> +        let result = accept_future.await;
> +
> +        match result {
> +            Ok(Ok(())) => {
> +                if secure_sender.send(Ok(secure_stream)).await.is_err() && debug {
> +                    log::error!("detected closed connection channel");
> +                }
> +            }
> +            Ok(Err(err)) => {
> +                if debug {
> +                    log::error!("https handshake failed - {err}");
> +                }
>              }
> +            Err(_) => {
> +                if debug {
> +                    log::error!("https handshake timeout");
> +                }
> +            }
> +        }
>  
> -            let accept_counter = Arc::clone(&accept_counter);
> -            tokio::spawn(async move {
> -                let accept_future =
> -                    tokio::time::timeout(Duration::new(10, 0), stream.as_mut().accept());
> +        drop(accept_counter); // decrease reference count
> +    }
>  
> -                let result = accept_future.await;
> +    async fn do_accept_tls_optional(
> +        socket: InsecureClientStream,
> +        acceptor: Arc<Mutex<SslAcceptor>>,
> +        accept_counter: Arc<()>,
> +        debug: bool,
> +        secure_sender: ClientSender,
> +        insecure_sender: InsecureClientSender,
> +    ) {
> +        let client_initiates_handshake = {
> +            #[cfg(feature = "rate-limited-stream")]
> +            let socket = socket.inner();
>  
> -                match result {
> -                    Ok(Ok(())) => {
> -                        if sender.send(Ok(stream)).await.is_err() && self.debug {
> -                            log::error!("detect closed connection channel");
> -                        }
> -                    }
> -                    Ok(Err(err)) => {
> -                        if self.debug {
> -                            log::error!("https handshake failed - {err}");
> -                        }
> -                    }
> -                    Err(_) => {
> -                        if self.debug {
> -                            log::error!("https handshake timeout");
> -                        }
> -                    }
> +            #[cfg(not(feature = "rate-limited-stream"))]
> +            let socket = &socket;
> +
> +            match Self::wait_for_client_tls_handshake(socket).await {
> +                Ok(initiates_handshake) => initiates_handshake,
> +                Err(err) => {
> +                    log::error!("error checking for TLS handshake: {err}");
> +                    return;
>                  }
> +            }
> +        };
> +
> +        if !client_initiates_handshake {
> +            let insecure_stream = Box::pin(socket);
>  
> -                drop(accept_counter); // decrease reference count
> -            });
> +            if insecure_sender.send(Ok(insecure_stream)).await.is_err() && debug {
> +                log::error!("detected closed connection channel")
> +            }
> +
> +            return;
>          }
> +
> +        Self::do_accept_tls(socket, acceptor, accept_counter, debug, secure_sender).await
>      }
> +
> +    async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result<bool, Error> {
> +        const MS_TIMEOUT: u64 = 1000;
> +        const BYTES_BUF_SIZE: usize = 128;
> +
> +        let mut buf = [0; BYTES_BUF_SIZE];
> +        let mut last_peek_size = 0;
> +
> +        let future = async {
> +            loop {
> +                let peek_size = incoming_stream
> +                    .peek(&mut buf)
> +                    .await
> +                    .context("couldn't peek into incoming tcp stream")?;
> +
> +                if contains_tls_handshake_fragment(&buf) {
> +                    return Ok(true);
> +                }
> +
> +                // No more new data came in
> +                if peek_size == last_peek_size {
> +                    return Ok(false);
> +                }
> +
> +                last_peek_size = peek_size;
> +
> +                // yields to event loop; this future blocks otherwise ad infinitum
> +                tokio::time::sleep(Duration::from_millis(0)).await;

Just noticed this - how about using `tokio::task::yield_now()` here instead?
Also, this makes me wish epoll had an edge-triggering flag that also
fires when additional data arrives, not just when going up from zero.





More information about the pbs-devel mailing list