[pbs-devel] [PATCH v3 proxmox 1/3] rest-server: Refactor `AcceptBuilder`, provide support for optional TLS

Max Carrara m.carrara at proxmox.com
Tue Oct 31 19:47:03 CET 2023


The new public function `accept_tls_optional()` is added, which
accepts both plain TCP streams and TCP streams running TLS. Plain TCP
streams are sent along via a separate channel in order to clearly
distinguish between "secure" and "insecure" connections.

Furthermore, instead of `AcceptBuilder` itself holding a reference to
an `SslAcceptor`, its public functions now take the acceptor as an
argument. As a consequence, `AcceptBuilder::new()` no longer returns a
`Result`, `AcceptBuilder::with_acceptor()` is removed and `Default` is
implemented for `AcceptBuilder`. The public functions are renamed to
make their behavior more explicit:

  * `accept()` --> `accept_tls()`
  *        NEW --> `accept_tls_optional()`

Signed-off-by: Max Carrara <m.carrara at proxmox.com>
---
 Changes v1 --> v2:
  * No more `BiAcceptBuilder`, `AcceptBuilder` is refactored instead
  * `AcceptBuilder` doesn't hold a reference to `SslAcceptor` anymore
  * Avoid unnecessary `#[cfg]`s
  * Avoid unnecessarily duplicated code (already mostly done by getting
    rid of `BiAcceptBuilder`)
  * Some clippy stuff

 Changes v2 --> v3:
  * Incorporate previously applied clippy fixes

 proxmox-rest-server/src/connection.rs | 373 ++++++++++++++++++++------
 1 file changed, 287 insertions(+), 86 deletions(-)

diff --git a/proxmox-rest-server/src/connection.rs b/proxmox-rest-server/src/connection.rs
index 1bec28d..ab8c7db 100644
--- a/proxmox-rest-server/src/connection.rs
+++ b/proxmox-rest-server/src/connection.rs
@@ -8,15 +8,16 @@ use std::pin::Pin;
 use std::sync::{Arc, Mutex};
 use std::time::Duration;
 
-use anyhow::Context as _;
-use anyhow::Error;
+use anyhow::{format_err, Context as _, Error};
 use futures::FutureExt;
+use hyper::server::accept;
 use openssl::ec::{EcGroup, EcKey};
 use openssl::nid::Nid;
 use openssl::pkey::{PKey, Private};
 use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
 use openssl::x509::X509;
 use tokio::net::{TcpListener, TcpStream};
+use tokio::sync::mpsc;
 use tokio_openssl::SslStream;
 use tokio_stream::wrappers::ReceiverStream;
 
@@ -133,10 +134,14 @@ impl TlsAcceptorBuilder {
     }
 }
 
-#[cfg(feature = "rate-limited-stream")]
-type ClientStreamResult = Pin<Box<SslStream<RateLimitedStream<TcpStream>>>>;
 #[cfg(not(feature = "rate-limited-stream"))]
-type ClientStreamResult = Pin<Box<SslStream<TcpStream>>>;
+type InsecureClientStream = TcpStream;
+#[cfg(feature = "rate-limited-stream")]
+type InsecureClientStream = RateLimitedStream<TcpStream>;
+
+type InsecureClientStreamResult = Pin<Box<InsecureClientStream>>;
+
+type ClientStreamResult = Pin<Box<SslStream<InsecureClientStream>>>;
 
 #[cfg(feature = "rate-limited-stream")]
 type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit>, Option<SharedRateLimit>)
@@ -145,7 +150,6 @@ type LookupRateLimiter = dyn Fn(std::net::SocketAddr) -> (Option<SharedRateLimit
     + 'static;
 
 pub struct AcceptBuilder {
-    acceptor: Arc<Mutex<SslAcceptor>>,
     debug: bool,
     tcp_keepalive_time: u32,
     max_pending_accepts: usize,
@@ -154,16 +158,9 @@ pub struct AcceptBuilder {
     lookup_rate_limiter: Option<Arc<LookupRateLimiter>>,
 }
 
-impl AcceptBuilder {
-    pub fn new() -> Result<Self, Error> {
-        Ok(Self::with_acceptor(Arc::new(Mutex::new(
-            TlsAcceptorBuilder::new().build()?,
-        ))))
-    }
-
-    pub fn with_acceptor(acceptor: Arc<Mutex<SslAcceptor>>) -> Self {
+impl Default for AcceptBuilder {
+    fn default() -> Self {
         Self {
-            acceptor,
             debug: false,
             tcp_keepalive_time: 120,
             max_pending_accepts: 1024,
@@ -172,6 +169,12 @@ impl AcceptBuilder {
             lookup_rate_limiter: None,
         }
     }
+}
+
+impl AcceptBuilder {
+    pub fn new() -> Self {
+        Default::default()
+    }
 
     pub fn debug(mut self, debug: bool) -> Self {
         self.debug = debug;
@@ -193,114 +196,312 @@ impl AcceptBuilder {
         self.lookup_rate_limiter = Some(lookup_rate_limiter);
         self
     }
+}
 
-    pub fn accept(
+impl AcceptBuilder {
+    pub fn accept_tls(
         self,
         listener: TcpListener,
-    ) -> impl hyper::server::accept::Accept<Conn = ClientStreamResult, Error = Error> {
-        let (sender, receiver) = tokio::sync::mpsc::channel(self.max_pending_accepts);
+        acceptor: Arc<Mutex<SslAcceptor>>,
+    ) -> impl accept::Accept<Conn = ClientStreamResult, Error = Error> {
+        let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
+
+        tokio::spawn(self.accept_connections(listener, acceptor, secure_sender.into()));
+
+        accept::from_stream(ReceiverStream::new(secure_receiver))
+    }
+
+    pub fn accept_tls_optional(
+        self,
+        listener: TcpListener,
+        acceptor: Arc<Mutex<SslAcceptor>>,
+    ) -> (
+        impl accept::Accept<Conn = ClientStreamResult, Error = Error>,
+        impl accept::Accept<Conn = InsecureClientStreamResult, Error = Error>,
+    ) {
+        let (secure_sender, secure_receiver) = mpsc::channel(self.max_pending_accepts);
+        let (insecure_sender, insecure_receiver) = mpsc::channel(self.max_pending_accepts);
+
+        tokio::spawn(self.accept_connections(
+            listener,
+            acceptor,
+            (secure_sender, insecure_sender).into(),
+        ));
+
+        (
+            accept::from_stream(ReceiverStream::new(secure_receiver)),
+            accept::from_stream(ReceiverStream::new(insecure_receiver)),
+        )
+    }
+}
+
+type ClientSender = mpsc::Sender<Result<ClientStreamResult, Error>>;
+type InsecureClientSender = mpsc::Sender<Result<InsecureClientStreamResult, Error>>;
 
-        tokio::spawn(self.accept_connections(listener, sender));
+enum Sender {
+    Secure(ClientSender),
+    SecureAndInsecure(ClientSender, InsecureClientSender),
+}
 
-        //receiver
-        hyper::server::accept::from_stream(ReceiverStream::new(receiver))
+impl From<ClientSender> for Sender {
+    fn from(sender: ClientSender) -> Self {
+        Sender::Secure(sender)
     }
+}
+
+impl From<(ClientSender, InsecureClientSender)> for Sender {
+    fn from(senders: (ClientSender, InsecureClientSender)) -> Self {
+        Sender::SecureAndInsecure(senders.0, senders.1)
+    }
+}
 
+impl AcceptBuilder {
     async fn accept_connections(
         self,
         listener: TcpListener,
-        sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
+        acceptor: Arc<Mutex<SslAcceptor>>,
+        sender: Sender,
     ) {
         let accept_counter = Arc::new(());
         let mut shutdown_future = crate::shutdown_future().fuse();
 
         loop {
-            let (sock, peer) = futures::select! {
-                res = listener.accept().fuse() => match res {
-                    Ok(conn) => conn,
+            let socket = futures::select! {
+                res = self.try_setup_socket(&listener).fuse() => match res {
+                    Ok(socket) => socket,
                     Err(err) => {
-                        eprintln!("error accepting tcp connection: {err}");
+                        log::error!("couldn't set up TCP socket: {err}");
                         continue;
                     }
                 },
-                _ =  shutdown_future => break,
+                _ = shutdown_future => break,
             };
-            #[cfg(not(feature = "rate-limited-stream"))]
-            {
-                let _ = &peer;
-            }
 
-            sock.set_nodelay(true).unwrap();
-            let _ = proxmox_sys::linux::socket::set_tcp_keepalive(
-                sock.as_raw_fd(),
-                self.tcp_keepalive_time,
-            );
+            let acceptor = Arc::clone(&acceptor);
+            let accept_counter = Arc::clone(&accept_counter);
 
-            #[cfg(feature = "rate-limited-stream")]
-            let sock = match self.lookup_rate_limiter.clone() {
-                Some(lookup) => {
-                    RateLimitedStream::with_limiter_update_cb(sock, move || lookup(peer))
+            if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
+                log::error!("connection rejected - too many open connections");
+                continue;
+            }
+
+            match sender {
+                Sender::Secure(ref secure_sender) => {
+                    let accept_future = Self::do_accept_tls(
+                        socket,
+                        acceptor,
+                        accept_counter,
+                        self.debug,
+                        secure_sender.clone(),
+                    );
+
+                    tokio::spawn(accept_future);
+                }
+                Sender::SecureAndInsecure(ref secure_sender, ref insecure_sender) => {
+                    let accept_future = Self::do_accept_tls_optional(
+                        socket,
+                        acceptor,
+                        accept_counter,
+                        self.debug,
+                        secure_sender.clone(),
+                        insecure_sender.clone(),
+                    );
+
+                    tokio::spawn(accept_future);
                 }
-                None => RateLimitedStream::with_limiter(sock, None, None),
             };
+        }
+    }
 
-            let ssl = {
-                // limit acceptor_guard scope
-                // Acceptor can be reloaded using the command socket "reload-certificate" command
-                let acceptor_guard = self.acceptor.lock().unwrap();
+    async fn try_setup_socket(
+        &self,
+        listener: &TcpListener,
+    ) -> Result<InsecureClientStream, Error> {
+        let (socket, peer) = match listener.accept().await {
+            Ok(connection) => connection,
+            Err(error) => {
+                return Err(format_err!(error)).context("error while accepting tcp stream")
+            }
+        };
 
-                match openssl::ssl::Ssl::new(acceptor_guard.context()) {
-                    Ok(ssl) => ssl,
-                    Err(err) => {
-                        eprintln!("failed to create Ssl object from Acceptor context - {err}");
-                        continue;
-                    }
-                }
-            };
+        socket
+            .set_nodelay(true)
+            .context("error while setting TCP_NODELAY on socket")?;
+
+        proxmox_sys::linux::socket::set_tcp_keepalive(socket.as_raw_fd(), self.tcp_keepalive_time)
+            .context("error while setting SO_KEEPALIVE on socket")?;
 
-            let stream = match tokio_openssl::SslStream::new(ssl, sock) {
-                Ok(stream) => stream,
+        #[cfg(feature = "rate-limited-stream")]
+        let socket = match self.lookup_rate_limiter.clone() {
+            Some(lookup) => RateLimitedStream::with_limiter_update_cb(socket, move || lookup(peer)),
+            None => RateLimitedStream::with_limiter(socket, None, None),
+        };
+
+        #[cfg(not(feature = "rate-limited-stream"))]
+        let _peer = peer;
+
+        Ok(socket)
+    }
+
+    async fn do_accept_tls(
+        socket: InsecureClientStream,
+        acceptor: Arc<Mutex<SslAcceptor>>,
+        accept_counter: Arc<()>,
+        debug: bool,
+        secure_sender: ClientSender,
+    ) {
+        let ssl = {
+            // limit acceptor_guard scope
+            // Acceptor can be reloaded using the command socket "reload-certificate" command
+            let acceptor_guard = acceptor.lock().unwrap();
+
+            match openssl::ssl::Ssl::new(acceptor_guard.context()) {
+                Ok(ssl) => ssl,
                 Err(err) => {
-                    eprintln!("failed to create SslStream using ssl and connection socket - {err}");
-                    continue;
+                    log::error!("failed to create Ssl object from Acceptor context - {err}");
+                    return;
                 }
-            };
+            }
+        };
 
-            let mut stream = Box::pin(stream);
-            let sender = sender.clone();
+        let secure_stream = match tokio_openssl::SslStream::new(ssl, socket) {
+            Ok(stream) => stream,
+            Err(err) => {
+                log::error!("failed to create SslStream using ssl and connection socket - {err}");
+                return;
+            }
+        };
 
-            if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
-                eprintln!("connection rejected - too many open connections");
-                continue;
+        let mut secure_stream = Box::pin(secure_stream);
+
+        let accept_future =
+            tokio::time::timeout(Duration::new(10, 0), secure_stream.as_mut().accept());
+
+        let result = accept_future.await;
+
+        match result {
+            Ok(Ok(())) => {
+                if secure_sender.send(Ok(secure_stream)).await.is_err() && debug {
+                    log::error!("detected closed connection channel");
+                }
+            }
+            Ok(Err(err)) => {
+                if debug {
+                    log::error!("https handshake failed - {err}");
+                }
             }
+            Err(_) => {
+                if debug {
+                    log::error!("https handshake timeout");
+                }
+            }
+        }
 
-            let accept_counter = Arc::clone(&accept_counter);
-            tokio::spawn(async move {
-                let accept_future =
-                    tokio::time::timeout(Duration::new(10, 0), stream.as_mut().accept());
+        drop(accept_counter); // decrease reference count
+    }
 
-                let result = accept_future.await;
+    async fn do_accept_tls_optional(
+        socket: InsecureClientStream,
+        acceptor: Arc<Mutex<SslAcceptor>>,
+        accept_counter: Arc<()>,
+        debug: bool,
+        secure_sender: ClientSender,
+        insecure_sender: InsecureClientSender,
+    ) {
+        let client_initiates_handshake = {
+            #[cfg(feature = "rate-limited-stream")]
+            let socket = socket.inner();
 
-                match result {
-                    Ok(Ok(())) => {
-                        if sender.send(Ok(stream)).await.is_err() && self.debug {
-                            log::error!("detect closed connection channel");
-                        }
-                    }
-                    Ok(Err(err)) => {
-                        if self.debug {
-                            log::error!("https handshake failed - {err}");
-                        }
-                    }
-                    Err(_) => {
-                        if self.debug {
-                            log::error!("https handshake timeout");
-                        }
-                    }
+            #[cfg(not(feature = "rate-limited-stream"))]
+            let socket = &socket;
+
+            match Self::wait_for_client_tls_handshake(socket).await {
+                Ok(initiates_handshake) => initiates_handshake,
+                Err(err) => {
+                    log::error!("error checking for TLS handshake: {err}");
+                    return;
                 }
+            }
+        };
+
+        if !client_initiates_handshake {
+            let insecure_stream = Box::pin(socket);
 
-                drop(accept_counter); // decrease reference count
-            });
+            if insecure_sender.send(Ok(insecure_stream)).await.is_err() && debug {
+                log::error!("detected closed connection channel")
+            }
+
+            return;
         }
+
+        Self::do_accept_tls(socket, acceptor, accept_counter, debug, secure_sender).await
     }
+
+    async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result<bool, Error> {
+        const MS_TIMEOUT: u64 = 1000;
+        const BYTES_BUF_SIZE: usize = 128;
+
+        let mut buf = [0; BYTES_BUF_SIZE];
+        let mut last_peek_size = 0;
+
+        let future = async {
+            loop {
+                let peek_size = incoming_stream
+                    .peek(&mut buf)
+                    .await
+                    .context("couldn't peek into incoming tcp stream")?;
+
+                if contains_tls_handshake_fragment(&buf) {
+                    return Ok(true);
+                }
+
+                // No more new data came in
+                if peek_size == last_peek_size {
+                    return Ok(false);
+                }
+
+                last_peek_size = peek_size;
+
+                // yield to the event loop; otherwise this future would block ad infinitum
+                tokio::time::sleep(Duration::from_millis(0)).await;
+            }
+        };
+
+        tokio::time::timeout(Duration::from_millis(MS_TIMEOUT), future)
+            .await
+            .unwrap_or(Ok(false))
+    }
+}
+
+/// Checks whether the given buffer starts with an
+/// [SSL 3.0 / TLS plaintext fragment][0] that is part of an SSL / TLS handshake.
+///
+/// Such a fragment might look as follows:
+/// ```ignore
+/// [0x16, 0x3, 0x1, 0x02, 0x00, ...]
+/// //  |    |    |     |_____|
+/// //  |    |    |            \__ content length interpreted as big-endian u16
+/// //  |    |    |                must not exceed 0x4000 (2^14) bytes
+/// //  |    |    |
+/// //  |    |     \__ any minor version
+/// //  |    |
+/// //  |     \__ major version 3
+/// //  |
+/// //   \__ content type is handshake(22)
+/// ```
+///
+/// If a slice like this is detected at the beginning of the given buffer,
+/// a TLS handshake is very likely being made.
+///
+/// [0]: https://datatracker.ietf.org/doc/html/rfc6101#section-5.2
+#[inline]
+fn contains_tls_handshake_fragment(buf: &[u8]) -> bool {
+    const SLICE_LENGTH: usize = 5;
+    const CONTENT_SIZE: u16 = 1 << 14; // max length of a TLS plaintext fragment
+
+    if buf.len() < SLICE_LENGTH {
+        return false;
+    }
+
+    buf[0] == 0x16 && buf[1] == 0x3 && (((buf[3] as u16) << 8) + buf[4] as u16) <= CONTENT_SIZE
 }
-- 
2.39.2






More information about the pbs-devel mailing list