[pbs-devel] applied: [PATCH proxmox v2 1/4] proxmox/tools/websocket: introduce WebSocketError and use it

Wolfgang Bumiller w.bumiller at proxmox.com
Fri Jul 17 14:38:15 CEST 2020


applied series

On Fri, Jul 17, 2020 at 02:17:07PM +0200, Dominik Csapak wrote:
> this patch introduces a custom WebSocketError, so that we can
> detect errors in the websocket protocol and react to them in the
> right way
> 
> the channel now sends a Result<(OpCode, Box<[u8]>), WebSocketError>
> so that we can either react to a control frame, or react to an
> erroneous data stream (which means sending a close frame
> with the appropriate error code)
> 
> while at it, change FrameHeader::try_from_bytes to return
> Result<Option<FrameHeader>, WebSocketError> instead of a nested
> Result with the *guessed* remaining size. This was neither used by
> us, nor was it really meaningful, since the remaining size can change
> every time new data is read while decoding (extensions, mask, payload
> length, etc.), so simply returning an Option is enough information for us
> 
> Signed-off-by: Dominik Csapak <d.csapak at proxmox.com>
> ---
> changes from v1:
> * fix one instance of Ok(Err) for try_from_bytes that was left over
> * omit the pub const slices and implement to_be_bytes for the ErrorKind
>   instead
> * (correctly) use extend_from_slice instead of copy_from_slice
> * impl std::error::Error instead of From<WebSocketError> for anyhow::Error
> 
>  proxmox/src/tools/websocket.rs | 184 ++++++++++++++++++++++++---------
>  1 file changed, 133 insertions(+), 51 deletions(-)
> 
> diff --git a/proxmox/src/tools/websocket.rs b/proxmox/src/tools/websocket.rs
> index fc9a0c5..c6775f0 100644
> --- a/proxmox/src/tools/websocket.rs
> +++ b/proxmox/src/tools/websocket.rs
> @@ -7,7 +7,7 @@
>  use std::pin::Pin;
>  use std::task::{Context, Poll};
>  use std::cmp::min;
> -use std::io::{self, ErrorKind};
> +use std::io;
>  use std::future::Future;
>  
>  use futures::select;
> @@ -29,9 +29,65 @@ use hyper::header::{
>  use futures::future::FutureExt;
>  use futures::ready;
>  
> -use crate::io_format_err;
> +use crate::sys::error::io_err_other;
>  use crate::tools::byte_buffer::ByteBuffer;
>  
> +// see RFC6455 section 7.4.1
> +#[derive(Debug, Clone, Copy)]
> +#[repr(u16)]
> +pub enum WebSocketErrorKind {
> +    Normal = 1000,
> +    ProtocolError = 1002,
> +    InvalidData = 1003,
> +    Other = 1008,
> +    Unexpected = 1011,
> +}
> +
> +impl WebSocketErrorKind {
> +    #[inline]
> +    pub fn to_be_bytes(self) -> [u8; 2] {
> +        (self as u16).to_be_bytes()
> +    }
> +}
> +
> +impl std::fmt::Display for WebSocketErrorKind {
> +    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
> +        write!(f, "{}", *self as u16)
> +    }
> +}
> +
> +#[derive(Debug, Clone)]
> +pub struct WebSocketError{
> +    kind: WebSocketErrorKind,
> +    message: String,
> +}
> +
> +impl WebSocketError {
> +    pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
> +        Self{
> +            kind,
> +            message: message.to_string()
> +        }
> +    }
> +
> +    pub fn generate_frame_payload(&self) -> Vec<u8> {
> +        let msglen = self.message.len().min(125);
> +        let code = self.kind.to_be_bytes();
> +        let mut data = Vec::with_capacity(msglen + 2);
> +        data.extend_from_slice(&code);
> +        data.extend_from_slice(&self.message.as_bytes()[..msglen]);
> +        data
> +    }
> +}
> +
> +impl std::fmt::Display for WebSocketError {
> +    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
> +        write!(f, "{} (Code: {})", self.message, self.kind)
> +    }
> +}
> +
> +impl std::error::Error for WebSocketError {}
> +
>  #[repr(u8)]
>  #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
>  /// Represents an OpCode of a websocket frame
> @@ -293,25 +349,23 @@ impl FrameHeader {
>      /// Tries to parse a FrameHeader from bytes.
>      ///
>      /// When there are not enough bytes to completely parse the header,
> -    /// returns Ok(Err(size)) where size determines how many bytes
> -    /// are missing to parse further (this amount can change when more
> -    /// information is available)
> +    /// returns Ok(None)
>      ///
>      /// Example:
>      /// ```
>      /// # use proxmox::tools::websocket::*;
>      /// # use std::io;
> -    /// # fn main() -> io::Result<()> {
> +    /// # fn main() -> Result<(), WebSocketError> {
>      /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
>      /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
>      /// match header {
> -    ///     Ok(_) => unreachable!(),
> -    ///     Err(x) => assert_eq!(x, 1),
> +    ///     Some(_) => unreachable!(),
> +    ///     None => {},
>      /// }
>      /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
>      /// match header {
> -    ///     Err(x) => unreachable!(),
> -    ///     Ok(header) => assert_eq!(header, FrameHeader{
> +    ///     None => unreachable!(),
> +    ///     Some(header) => assert_eq!(header, FrameHeader{
>      ///         fin: true,
>      ///         mask: None,
>      ///         frametype: OpCode::Ping,
> @@ -322,19 +376,19 @@ impl FrameHeader {
>      /// # Ok(())
>      /// # }
>      /// ```
> -    pub fn try_from_bytes(data: &[u8]) -> io::Result<Result<FrameHeader, usize>> {
> +    pub fn try_from_bytes(data: &[u8]) -> Result<Option<FrameHeader>, WebSocketError> {
>          let len = data.len();
>          if len < 2 {
> -            return Ok(Err(2 - len));
> +            return Ok(None);
>          }
>  
>          let data = data;
>  
>          // we do not support extensions
>          if data[0] & 0b01110000 > 0 {
> -            return Err(io::Error::new(
> -                ErrorKind::InvalidData,
> -                "Extensions not supported",
> +            return Err(WebSocketError::new(
> +                    WebSocketErrorKind::ProtocolError,
> +                    "Extensions not supported",
>              ));
>          }
>  
> @@ -347,14 +401,17 @@ impl FrameHeader {
>              9 => OpCode::Ping,
>              10 => OpCode::Pong,
>              other => {
> -                return Err(io::Error::new(ErrorKind::InvalidData, format!("Unknown OpCode {}", other)));
> +                return Err(WebSocketError::new(
> +                        WebSocketErrorKind::ProtocolError,
> +                        &format!("Unknown OpCode {}", other),
> +                ));
>              }
>          };
>  
>          if !fin && frametype.is_control() {
> -            return Err(io::Error::new(
> -                ErrorKind::InvalidData,
> -                "Control frames cannot be fragmented",
> +            return Err(WebSocketError::new(
> +                    WebSocketErrorKind::ProtocolError,
> +                    "Control frames cannot be fragmented",
>              ));
>          }
>  
> @@ -368,14 +425,14 @@ impl FrameHeader {
>          let mut payload_len: usize = (data[1] & 0b01111111).into();
>          if payload_len == 126 {
>              if len < 4 {
> -                return Ok(Err(4 - len));
> +                return Ok(None);
>              }
>              payload_len = u16::from_be_bytes([data[2], data[3]]) as usize;
>              mask_offset += 2;
>              payload_offset += 2;
>          } else if payload_len == 127 {
>              if len < 10 {
> -                return Ok(Err(10 - len));
> +                return Ok(None);
>              }
>              payload_len = u64::from_be_bytes([
>                  data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9],
> @@ -385,16 +442,16 @@ impl FrameHeader {
>          }
>  
>          if payload_len > 125 && frametype.is_control() {
> -            return Err(io::Error::new(
> -                ErrorKind::InvalidData,
> -                "Control frames cannot carry more than 125 bytes of data",
> +            return Err(WebSocketError::new(
> +                    WebSocketErrorKind::ProtocolError,
> +                    "Control frames cannot carry more than 125 bytes of data",
>              ));
>          }
>  
>          let mask = match mask_bit {
>              true => {
>                  if len < mask_offset + 4 {
> -                    return Ok(Err(mask_offset + 4 - len));
> +                    return Ok(None);
>                  }
>                  let mut mask = [0u8; 4];
>                  mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]);
> @@ -403,7 +460,7 @@ impl FrameHeader {
>              false => None,
>          };
>  
> -        Ok(Ok(FrameHeader {
> +        Ok(Some(FrameHeader {
>              fin,
>              mask,
>              frametype,
> @@ -413,6 +470,8 @@ impl FrameHeader {
>      }
>  }
>  
> +type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>;
> +
>  /// Wraps a reader that implements AsyncRead and implements it itself.
>  ///
>  /// On read, reads the underlying reader and tries to decode the frames and
> @@ -422,7 +481,7 @@ impl FrameHeader {
>  /// Has an internal Buffer for storing incomplete headers.
>  pub struct WebSocketReader<R: AsyncRead> {
>      reader: Option<R>,
> -    sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>,
> +    sender: mpsc::UnboundedSender<WebSocketReadResult>,
>      read_buffer: Option<ByteBuffer>,
>      header: Option<FrameHeader>,
>      state: ReaderState<R>,
> @@ -431,11 +490,11 @@ pub struct WebSocketReader<R: AsyncRead> {
>  impl<R: AsyncReadExt> WebSocketReader<R> {
>      /// Creates a new WebSocketReader with the given CallBack for control frames
>      /// and a default buffer size of 4096.
> -    pub fn new(reader: R, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader<R> {
> +    pub fn new(reader: R, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
>          Self::with_capacity(reader, 4096, sender)
>      }
>  
> -    pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader<R> {
> +    pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
>          WebSocketReader {
>              reader: Some(reader),
>              sender,
> @@ -512,13 +571,19 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
>                      let mut header = match this.header.take() {
>                          Some(header) => header,
>                          None => {
> -                            let header = match FrameHeader::try_from_bytes(&read_buffer[..])? {
> -                                Ok(header) => header,
> -                                Err(_) => {
> +                            let header = match FrameHeader::try_from_bytes(&read_buffer[..]) {
> +                                Ok(Some(header)) => header,
> +                                Ok(None) => {
>                                      this.state = ReaderState::NoData;
>                                      this.read_buffer = Some(read_buffer);
>                                      continue;
> -                                }
> +                                },
> +                                Err(err) => {
> +                                    if let Err(err) = this.sender.send(Err(err.clone())) {
> +                                        return Poll::Ready(Err(io_err_other(err)));
> +                                    }
> +                                    return Poll::Ready(Err(io_err_other(err)));
> +                                },
>                              };
>  
>                              read_buffer.consume(header.header_len as usize);
> @@ -531,7 +596,7 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
>  
>                              let mut data = read_buffer.remove_data(header.payload_len);
>                              mask_bytes(header.mask, &mut data);
> -                            if let Err(err) = this.sender.send((header.frametype, data)) {
> +                            if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
>                                  eprintln!("error sending control frame: {}", err);
>                              }
>  
> @@ -639,10 +704,37 @@ impl WebSocket {
>          Ok((Self { text }, response))
>      }
>  
> +    async fn handle_channel_message<W>(
> +        result: WebSocketReadResult,
> +        writer: &mut WebSocketWriter<W>
> +    ) -> Result<OpCode, Error>
> +    where
> +        W: AsyncWrite + Unpin + Send,
> +    {
> +        match result {
> +            Ok((OpCode::Ping, msg)) => {
> +                writer.send_control_frame(None, OpCode::Pong, &msg).await?;
> +                Ok(OpCode::Pong)
> +            }
> +            Ok((OpCode::Close, msg)) => {
> +                writer.send_control_frame(None, OpCode::Close, &msg).await?;
> +                Ok(OpCode::Close)
> +            }
> +            Ok((opcode, _)) => {
> +                // ignore other frames
> +                Ok(opcode)
> +            },
> +            Err(err) => {
> +                writer.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()).await?;
> +                Err(Error::from(err))
> +            }
> +        }
> +    }
> +
>      async fn copy_to_websocket<R, W>(
>          mut reader: &mut R,
> -        writer: &mut WebSocketWriter<W>,
> -        receiver: &mut mpsc::UnboundedReceiver<(OpCode, Box<[u8]>)>) -> Result<bool, Error>
> +        mut writer: &mut WebSocketWriter<W>,
> +        receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>) -> Result<bool, Error>
>      where
>          R: AsyncRead + Unpin + Send,
>          W: AsyncWrite + Unpin + Send,
> @@ -654,20 +746,10 @@ impl WebSocket {
>                  let bytes = select!{
>                      res = buf.read_from_async(&mut reader).fuse() => res?,
>                      res = receiver.recv().fuse() => {
> -                        let (opcode, msg) = res.ok_or(format_err!("control channel closed"))?;
> -                        match opcode {
> -                            OpCode::Ping => {
> -                                writer.send_control_frame(None, OpCode::Pong, &msg).await?;
> -                                continue;
> -                            }
> -                            OpCode::Close => {
> -                                writer.send_control_frame(None, OpCode::Close, &msg).await?;
> -                                return Ok(true);
> -                            }
> -                            _ => {
> -                                // ignore other frames
> -                                continue;
> -                            }
> +                        let res = res.ok_or_else(|| format_err!("control channel closed"))?;
> +                        match Self::handle_channel_message(res, &mut writer).await? {
> +                            OpCode::Close => return Ok(true),
> +                            _ => { continue; },
>                          }
>                      }
>                  };
> @@ -720,7 +802,7 @@ impl WebSocket {
>              res = term_future.fuse() => match res {
>                  Ok(sent_close) if !sent_close => {
>                      // status code 1000 => 0x03E8
> -                    wswriter.send_control_frame(None, OpCode::Close, &[0x03, 0xE8]).await?;
> +                    wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?;
>                      Ok(())
>                  }
>                  Ok(_) => Ok(()),
> -- 
> 2.20.1





More information about the pbs-devel mailing list