[pbs-devel] [PATCH proxmox 1/3] rest-server: Add `BiAcceptBuilder`

Max Carrara m.carrara at proxmox.com
Thu Jun 22 11:15:24 CEST 2023


This builder is similar to `AcceptBuilder`, but is also able to distinguish
between plain TCP streams and TCP streams carrying TLS.

It does so by peeking into the stream's buffer and checking whether
the client is initiating a TLS handshake.

Newly accepted plain TCP streams are passed on via a separate channel,
so that "secure" (TLS) and "insecure" (plain TCP) connections are
clearly kept apart.

Signed-off-by: Max Carrara <m.carrara at proxmox.com>
---
 proxmox-rest-server/src/connection.rs | 327 ++++++++++++++++++++++++++
 1 file changed, 327 insertions(+)

diff --git a/proxmox-rest-server/src/connection.rs b/proxmox-rest-server/src/connection.rs
index 7681f00..937b5d7 100644
--- a/proxmox-rest-server/src/connection.rs
+++ b/proxmox-rest-server/src/connection.rs
@@ -302,3 +302,330 @@ impl AcceptBuilder {
         }
     }
 }
+
+/// Socket type of an accepted connection, wrapped in a `RateLimitedStream`
+/// when the `rate-limited-stream` feature is enabled.
+#[cfg(feature = "rate-limited-stream")]
+type ClientStream = RateLimitedStream<TcpStream>;
+/// Socket type of an accepted connection (plain `TcpStream`).
+#[cfg(not(feature = "rate-limited-stream"))]
+type ClientStream = TcpStream;
+
+/// Connection type handed out for plain (non-TLS) TCP connections.
+type InsecureClientStreamResult = Pin<Box<ClientStream>>;
+
+/// Builder for a pair of acceptors that separate TLS connections from plain
+/// TCP connections arriving on the same listener.
+///
+/// See [`BiAcceptBuilder::accept`].
+pub struct BiAcceptBuilder {
+    /// TLS acceptor; `None` means every connection is treated as plain TCP.
+    acceptor: Option<Arc<Mutex<SslAcceptor>>>,
+    /// Capacity of each connection channel and limit for pending accepts.
+    max_pending_accepts: usize,
+    /// Keepalive time applied to every accepted socket.
+    tcp_keepalive_time: u32,
+    /// Whether handshake failures and closed channels are logged.
+    debug: bool,
+
+    /// Optional lookup used to attach rate limiters to each connection.
+    #[cfg(feature = "rate-limited-stream")]
+    lookup_rate_limiter: Option<Arc<LookupRateLimiter>>,
+}
+
+impl Default for BiAcceptBuilder {
+    /// No acceptor, a keepalive time of 120, at most 1024 pending accepts,
+    /// debug logging disabled and no rate limiter lookup.
+    fn default() -> Self {
+        Self {
+            acceptor: None,
+            max_pending_accepts: 1024,
+            tcp_keepalive_time: 120,
+            debug: false,
+            #[cfg(feature = "rate-limited-stream")]
+            lookup_rate_limiter: None,
+        }
+    }
+}
+
+impl BiAcceptBuilder {
+    /// Creates a builder without a TLS acceptor.
+    ///
+    /// Without an acceptor no TLS detection takes place and every connection
+    /// is yielded by the insecure acceptor returned from [`Self::accept`].
+    pub fn new() -> Self {
+        Default::default()
+    }
+
+    /// Creates a builder that sets up TLS for every connection whose first
+    /// bytes look like a TLS handshake.
+    ///
+    /// The mutex is locked anew for each connection, so an acceptor swapped in
+    /// behind it (e.g. on a certificate reload) is used for all subsequent
+    /// connections.
+    pub fn with_acceptor(acceptor: Arc<Mutex<SslAcceptor>>) -> Self {
+        Self {
+            acceptor: Some(acceptor),
+            ..Default::default()
+        }
+    }
+
+    /// Enables logging of failed or timed out TLS handshakes and of
+    /// connections that could not be delivered because the receiving side of
+    /// a channel was closed.
+    pub fn debug(mut self, debug: bool) -> Self {
+        self.debug = debug;
+        self
+    }
+
+    /// Sets the keepalive time applied to every accepted socket via
+    /// `proxmox_sys::linux::socket::set_tcp_keepalive` (default: 120,
+    /// presumably seconds - see that function).
+    pub fn tcp_keepalive_time(mut self, time: u32) -> Self {
+        self.tcp_keepalive_time = time;
+        self
+    }
+
+    /// Sets the capacity of each connection channel as well as the limit of
+    /// connections that may be in the accepting stage at once (default: 1024).
+    pub fn max_pending_accepts(mut self, count: usize) -> Self {
+        self.max_pending_accepts = count;
+        self
+    }
+
+    /// Sets the lookup passed to `RateLimitedStream::with_limiter_update_cb`;
+    /// it is called with the peer address of each accepted connection.
+    #[cfg(feature = "rate-limited-stream")]
+    pub fn rate_limiter_lookup(mut self, lookup_rate_limiter: Arc<LookupRateLimiter>) -> Self {
+        self.lookup_rate_limiter = Some(lookup_rate_limiter);
+        self
+    }
+
+    /// Spawns a task accepting connections on `listener` and returns two
+    /// acceptors for use with hyper.
+    ///
+    /// The first acceptor yields connections that completed a TLS handshake,
+    /// the second one yields plain TCP connections. The background task stops
+    /// accepting once `crate::shutdown_future()` resolves.
+    ///
+    /// Must be called from within a tokio runtime, as it uses `tokio::spawn`.
+    pub fn accept(
+        self,
+        listener: TcpListener,
+    ) -> (
+        impl hyper::server::accept::Accept<Conn = ClientStreamResult, Error = Error>,
+        impl hyper::server::accept::Accept<Conn = InsecureClientStreamResult, Error = Error>,
+    ) {
+        use hyper::server::accept;
+        use tokio::sync::mpsc::channel;
+
+        let (secure_sender, secure_receiver) = channel(self.max_pending_accepts);
+        let (insecure_sender, insecure_receiver) = channel(self.max_pending_accepts);
+
+        tokio::spawn(self.accept_connections(listener, secure_sender, insecure_sender));
+
+        (
+            accept::from_stream(ReceiverStream::new(secure_receiver)),
+            accept::from_stream(ReceiverStream::new(insecure_receiver)),
+        )
+    }
+
+    /// Accept loop: applies socket options and the pending-accept limit, then
+    /// classifies each connection in a task of its own.
+    async fn accept_connections(
+        self,
+        listener: TcpListener,
+        secure_sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
+        insecure_sender: tokio::sync::mpsc::Sender<Result<InsecureClientStreamResult, Error>>,
+    ) {
+        // every in-flight connection holds a clone, so the strong count is the
+        // number of pending accepts plus this original
+        let accept_counter = Arc::new(());
+        let mut shutdown_future = crate::shutdown_future().fuse();
+
+        loop {
+            let (sock, peer) = futures::select! {
+                res = listener.accept().fuse() => match res {
+                    Ok(conn) => conn,
+                    Err(err) => {
+                        log::error!("error accepting tcp connection: {err}");
+                        continue;
+                    }
+                },
+                _ = shutdown_future => break,
+            };
+
+            #[cfg(not(feature = "rate-limited-stream"))]
+            drop(peer);
+
+            // must not panic: a panic here would end the whole accept loop.
+            // TCP_NODELAY only affects latency, so keep the connection anyway.
+            if let Err(err) = sock.set_nodelay(true) {
+                log::error!("failed to set TCP_NODELAY on accepted connection - {err}");
+            }
+
+            // best-effort, errors are deliberately ignored
+            let _ = proxmox_sys::linux::socket::set_tcp_keepalive(
+                sock.as_raw_fd(),
+                self.tcp_keepalive_time,
+            );
+
+            #[cfg(feature = "rate-limited-stream")]
+            let sock = match self.lookup_rate_limiter.clone() {
+                Some(lookup) => {
+                    RateLimitedStream::with_limiter_update_cb(sock, move || lookup(peer))
+                }
+                None => RateLimitedStream::with_limiter(sock, None, None),
+            };
+
+            let accept_counter = Arc::clone(&accept_counter);
+
+            if Arc::strong_count(&accept_counter) > self.max_pending_accepts {
+                log::error!("connection rejected - too many open connections");
+                continue;
+            }
+
+            // TLS detection and handshake run in their own task so that slow
+            // or idle clients cannot block the accept loop
+            tokio::spawn(Self::do_accept(
+                sock,
+                self.acceptor.clone(),
+                accept_counter,
+                secure_sender.clone(),
+                insecure_sender.clone(),
+                self.debug,
+            ));
+        }
+    }
+
+    /// Classifies a single connection and sends it to the matching channel.
+    ///
+    /// Without an acceptor, or if the client does not start a TLS handshake,
+    /// the connection is sent unchanged over `insecure_sender`. Otherwise a
+    /// TLS handshake with a 10 second timeout is performed and, on success,
+    /// the TLS stream is sent over `secure_sender`. On failure the connection
+    /// is dropped; handshake failures are only logged in debug mode.
+    ///
+    /// `accept_counter` is held until this function returns.
+    async fn do_accept(
+        sock: ClientStream,
+        acceptor: Option<Arc<Mutex<SslAcceptor>>>,
+        accept_counter: Arc<()>,
+        secure_sender: tokio::sync::mpsc::Sender<Result<ClientStreamResult, Error>>,
+        insecure_sender: tokio::sync::mpsc::Sender<Result<InsecureClientStreamResult, Error>>,
+        debug: bool,
+    ) {
+        // released (decreasing the pending-accept count) when this task ends
+        let _accept_counter = accept_counter;
+
+        let acceptor = match acceptor {
+            Some(acceptor) => acceptor,
+            None => {
+                Self::send_insecure(sock, insecure_sender, debug).await;
+                return;
+            }
+        };
+
+        let client_initiates_handshake = {
+            #[cfg(feature = "rate-limited-stream")]
+            let sock = sock.inner();
+
+            #[cfg(not(feature = "rate-limited-stream"))]
+            let sock = &sock;
+
+            match Self::wait_for_client_tls_handshake(sock).await {
+                Ok(initiates_handshake) => initiates_handshake,
+                Err(err) => {
+                    log::error!("error checking for TLS handshake: {err}");
+                    return;
+                }
+            }
+        };
+
+        if !client_initiates_handshake {
+            Self::send_insecure(sock, insecure_sender, debug).await;
+            return;
+        }
+
+        let ssl = {
+            // limit acceptor_guard scope - the synchronous mutex guard must not
+            // be held across an await point.
+            // Acceptor can be reloaded using the command socket "reload-certificate" command
+            let acceptor_guard = acceptor.lock().unwrap();
+
+            match openssl::ssl::Ssl::new(acceptor_guard.context()) {
+                Ok(ssl) => ssl,
+                Err(err) => {
+                    log::error!("failed to create Ssl object from Acceptor context - {err}");
+                    return;
+                }
+            }
+        };
+
+        let secure_stream = match tokio_openssl::SslStream::new(ssl, sock) {
+            Ok(stream) => stream,
+            Err(err) => {
+                log::error!("failed to create SslStream using ssl and connection socket - {err}");
+                return;
+            }
+        };
+
+        let mut secure_stream = Box::pin(secure_stream);
+
+        let handshake =
+            tokio::time::timeout(Duration::new(10, 0), secure_stream.as_mut().accept()).await;
+
+        match handshake {
+            Ok(Ok(())) => {
+                if secure_sender.send(Ok(secure_stream)).await.is_err() && debug {
+                    log::error!("detected closed connection channel");
+                }
+            }
+            Ok(Err(err)) => {
+                if debug {
+                    log::error!("https handshake failed - {err}");
+                }
+            }
+            Err(_) => {
+                if debug {
+                    log::error!("https handshake timeout");
+                }
+            }
+        }
+    }
+
+    /// Sends a plain TCP connection over the insecure channel, logging (in
+    /// debug mode) if the receiving side is already closed.
+    async fn send_insecure(
+        sock: ClientStream,
+        insecure_sender: tokio::sync::mpsc::Sender<Result<InsecureClientStreamResult, Error>>,
+        debug: bool,
+    ) {
+        let insecure_stream: InsecureClientStreamResult = Box::pin(sock);
+
+        if insecure_sender.send(Ok(insecure_stream)).await.is_err() && debug {
+            log::error!("detected closed connection channel")
+        };
+    }
+
+    /// Peeks at the first bytes sent by the client (without consuming them)
+    /// and checks whether they start a TLS handshake.
+    ///
+    /// Returns `Ok(false)` if the received bytes do not start with a handshake
+    /// record header and no new data arrived since the previous peek, or if
+    /// the one second timeout elapses (e.g. a client that sends nothing).
+    ///
+    /// # Errors
+    ///
+    /// Fails if peeking into the stream fails.
+    async fn wait_for_client_tls_handshake(incoming_stream: &TcpStream) -> Result<bool, Error> {
+        const MS_TIMEOUT: u64 = 1000;
+        const BYTES_BUF_SIZE: usize = 128;
+
+        let mut buf = [0; BYTES_BUF_SIZE];
+        let mut last_peek_size = 0;
+
+        let future = async {
+            loop {
+                let peek_size = incoming_stream
+                    .peek(&mut buf)
+                    .await
+                    .context("couldn't peek into incoming tcp stream")?;
+
+                // only inspect bytes actually received - the remainder of
+                // `buf` is zero padding and must not be mistaken for data
+                if contains_tls_handshake_fragment(&buf[..peek_size]) {
+                    return Ok(true);
+                }
+
+                // No more new data came in (this includes EOF, peek size 0)
+                if peek_size == last_peek_size {
+                    return Ok(false);
+                }
+
+                last_peek_size = peek_size;
+
+                // give the client a chance to send more data before peeking again
+                tokio::task::yield_now().await;
+            }
+        };
+
+        tokio::time::timeout(Duration::from_millis(MS_TIMEOUT), future)
+            .await
+            .unwrap_or(Ok(false))
+    }
+}
+
+/// Checks whether `buf` starts with an [SSL 3.0 / TLS plaintext record
+/// header][0] announcing a handshake message.
+///
+/// Such a header might look as follows:
+/// ```ignore
+/// [0x16, 0x3, 0x1, 0x02, 0x00, ...]
+/// //  |    |    |     |_____|
+/// //  |    |    |            \__ content length interpreted as big-endian u16
+/// //  |    |    |                must not exceed 0x4000 (2^14) bytes
+/// //  |    |    |
+/// //  |    |     \__ any minor version
+/// //  |    |
+/// //  |     \__ major version 3
+/// //  |
+/// //   \__ content type is handshake(22)
+/// ```
+///
+/// Only the first five bytes are inspected; a shorter buffer yields `false`.
+/// A match is a strong heuristic that the client is starting a TLS
+/// handshake, not a proof.
+///
+/// [0]: https://datatracker.ietf.org/doc/html/rfc6101#section-5.2
+#[inline]
+fn contains_tls_handshake_fragment(buf: &[u8]) -> bool {
+    // max length of a TLS plaintext fragment
+    const MAX_FRAGMENT_LENGTH: u16 = 1 << 14;
+
+    match buf {
+        [0x16, 0x03, _minor, len_hi, len_lo, ..] => {
+            u16::from_be_bytes([*len_hi, *len_lo]) <= MAX_FRAGMENT_LENGTH
+        }
+        _ => false,
+    }
+}
-- 
2.30.2






More information about the pbs-devel mailing list