[pbs-devel] applied: [PATCH proxmox v2 1/4] proxmox/tools/websocket: introduce WebSocketError and use it
Wolfgang Bumiller
w.bumiller at proxmox.com
Fri Jul 17 14:38:15 CEST 2020
applied series
On Fri, Jul 17, 2020 at 02:17:07PM +0200, Dominik Csapak wrote:
> this patch introduces a custom WebSocketError, so that we can
> detect errors in the websocket protocol and react to them in the
> right way
>
> the channel now sends a Result<(OpCode, Box<[u8]>), WebSocketError>
> so that we can either react to a control frame, or react to an
> erroneous data stream (which means sending a close frame
> with the appropriate error code)
>
> while at it, change FrameHeader::try_from_bytes to return
> Result<Option<FrameHeader>, WebSocketError> instead of a nested
> Result with the *guessed* remaining size. This was neither used by
> us, nor was it really meaningful, since this can change during the
> decode every time new data is read (extensions, mask, payload length, etc.)
> so simply returning an Option is enough information for us
>
> Signed-off-by: Dominik Csapak <d.csapak at proxmox.com>
> ---
> changes from v1:
> * fix one instance of Ok(Err) for try_from_bytes that was leftover
> * omit the pub const slices and implement to_be_bytes for the ErrorKind
> instead
> * (correctly) use extend_from_slice instead of copy_from_slice
> * impl std::error::Error instead of From<WebSocketError> for anyhow::Error
>
> proxmox/src/tools/websocket.rs | 184 ++++++++++++++++++++++++---------
> 1 file changed, 133 insertions(+), 51 deletions(-)
>
> diff --git a/proxmox/src/tools/websocket.rs b/proxmox/src/tools/websocket.rs
> index fc9a0c5..c6775f0 100644
> --- a/proxmox/src/tools/websocket.rs
> +++ b/proxmox/src/tools/websocket.rs
> @@ -7,7 +7,7 @@
> use std::pin::Pin;
> use std::task::{Context, Poll};
> use std::cmp::min;
> -use std::io::{self, ErrorKind};
> +use std::io;
> use std::future::Future;
>
> use futures::select;
> @@ -29,9 +29,65 @@ use hyper::header::{
> use futures::future::FutureExt;
> use futures::ready;
>
> -use crate::io_format_err;
> +use crate::sys::error::io_err_other;
> use crate::tools::byte_buffer::ByteBuffer;
>
> +// see RFC6455 section 7.4.1
> +#[derive(Debug, Clone, Copy)]
> +#[repr(u16)]
> +pub enum WebSocketErrorKind {
> + Normal = 1000,
> + ProtocolError = 1002,
> + InvalidData = 1003,
> + Other = 1008,
> + Unexpected = 1011,
> +}
> +
> +impl WebSocketErrorKind {
> + #[inline]
> + pub fn to_be_bytes(self) -> [u8; 2] {
> + (self as u16).to_be_bytes()
> + }
> +}
> +
> +impl std::fmt::Display for WebSocketErrorKind {
> + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
> + write!(f, "{}", *self as u16)
> + }
> +}
> +
> +#[derive(Debug, Clone)]
> +pub struct WebSocketError{
> + kind: WebSocketErrorKind,
> + message: String,
> +}
> +
> +impl WebSocketError {
> + pub fn new(kind: WebSocketErrorKind, message: &str) -> Self {
> + Self{
> + kind,
> + message: message.to_string()
> + }
> + }
> +
> + pub fn generate_frame_payload(&self) -> Vec<u8> {
> + let msglen = self.message.len().min(125);
> + let code = self.kind.to_be_bytes();
> + let mut data = Vec::with_capacity(msglen + 2);
> + data.extend_from_slice(&code);
> + data.extend_from_slice(&self.message.as_bytes()[..msglen]);
> + data
> + }
> +}
> +
> +impl std::fmt::Display for WebSocketError {
> + fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
> + write!(f, "{} (Code: {})", self.message, self.kind)
> + }
> +}
> +
> +impl std::error::Error for WebSocketError {}
> +
> #[repr(u8)]
> #[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
> /// Represents an OpCode of a websocket frame
> @@ -293,25 +349,23 @@ impl FrameHeader {
> /// Tries to parse a FrameHeader from bytes.
> ///
> /// When there are not enough bytes to completely parse the header,
> - /// returns Ok(Err(size)) where size determines how many bytes
> - /// are missing to parse further (this amount can change when more
> - /// information is available)
> + /// returns Ok(None)
> ///
> /// Example:
> /// ```
> /// # use proxmox::tools::websocket::*;
> /// # use std::io;
> - /// # fn main() -> io::Result<()> {
> + /// # fn main() -> Result<(), WebSocketError> {
> /// let frame = create_frame(None, &[0,1,2,3], OpCode::Ping)?;
> /// let header = FrameHeader::try_from_bytes(&frame[..1])?;
> /// match header {
> - /// Ok(_) => unreachable!(),
> - /// Err(x) => assert_eq!(x, 1),
> + /// Some(_) => unreachable!(),
> + /// None => {},
> /// }
> /// let header = FrameHeader::try_from_bytes(&frame[..2])?;
> /// match header {
> - /// Err(x) => unreachable!(),
> - /// Ok(header) => assert_eq!(header, FrameHeader{
> + /// None => unreachable!(),
> + /// Some(header) => assert_eq!(header, FrameHeader{
> /// fin: true,
> /// mask: None,
> /// frametype: OpCode::Ping,
> @@ -322,19 +376,19 @@ impl FrameHeader {
> /// # Ok(())
> /// # }
> /// ```
> - pub fn try_from_bytes(data: &[u8]) -> io::Result<Result<FrameHeader, usize>> {
> + pub fn try_from_bytes(data: &[u8]) -> Result<Option<FrameHeader>, WebSocketError> {
> let len = data.len();
> if len < 2 {
> - return Ok(Err(2 - len));
> + return Ok(None);
> }
>
> let data = data;
>
> // we do not support extensions
> if data[0] & 0b01110000 > 0 {
> - return Err(io::Error::new(
> - ErrorKind::InvalidData,
> - "Extensions not supported",
> + return Err(WebSocketError::new(
> + WebSocketErrorKind::ProtocolError,
> + "Extensions not supported",
> ));
> }
>
> @@ -347,14 +401,17 @@ impl FrameHeader {
> 9 => OpCode::Ping,
> 10 => OpCode::Pong,
> other => {
> - return Err(io::Error::new(ErrorKind::InvalidData, format!("Unknown OpCode {}", other)));
> + return Err(WebSocketError::new(
> + WebSocketErrorKind::ProtocolError,
> + &format!("Unknown OpCode {}", other),
> + ));
> }
> };
>
> if !fin && frametype.is_control() {
> - return Err(io::Error::new(
> - ErrorKind::InvalidData,
> - "Control frames cannot be fragmented",
> + return Err(WebSocketError::new(
> + WebSocketErrorKind::ProtocolError,
> + "Control frames cannot be fragmented",
> ));
> }
>
> @@ -368,14 +425,14 @@ impl FrameHeader {
> let mut payload_len: usize = (data[1] & 0b01111111).into();
> if payload_len == 126 {
> if len < 4 {
> - return Ok(Err(4 - len));
> + return Ok(None);
> }
> payload_len = u16::from_be_bytes([data[2], data[3]]) as usize;
> mask_offset += 2;
> payload_offset += 2;
> } else if payload_len == 127 {
> if len < 10 {
> - return Ok(Err(10 - len));
> + return Ok(None);
> }
> payload_len = u64::from_be_bytes([
> data[2], data[3], data[4], data[5], data[6], data[7], data[8], data[9],
> @@ -385,16 +442,16 @@ impl FrameHeader {
> }
>
> if payload_len > 125 && frametype.is_control() {
> - return Err(io::Error::new(
> - ErrorKind::InvalidData,
> - "Control frames cannot carry more than 125 bytes of data",
> + return Err(WebSocketError::new(
> + WebSocketErrorKind::ProtocolError,
> + "Control frames cannot carry more than 125 bytes of data",
> ));
> }
>
> let mask = match mask_bit {
> true => {
> if len < mask_offset + 4 {
> - return Ok(Err(mask_offset + 4 - len));
> + return Ok(None);
> }
> let mut mask = [0u8; 4];
> mask.copy_from_slice(&data[mask_offset as usize..payload_offset as usize]);
> @@ -403,7 +460,7 @@ impl FrameHeader {
> false => None,
> };
>
> - Ok(Ok(FrameHeader {
> + Ok(Some(FrameHeader {
> fin,
> mask,
> frametype,
> @@ -413,6 +470,8 @@ impl FrameHeader {
> }
> }
>
> +type WebSocketReadResult = Result<(OpCode, Box<[u8]>), WebSocketError>;
> +
> /// Wraps a reader that implements AsyncRead and implements it itself.
> ///
> /// On read, reads the underlying reader and tries to decode the frames and
> @@ -422,7 +481,7 @@ impl FrameHeader {
> /// Has an internal Buffer for storing incomplete headers.
> pub struct WebSocketReader<R: AsyncRead> {
> reader: Option<R>,
> - sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>,
> + sender: mpsc::UnboundedSender<WebSocketReadResult>,
> read_buffer: Option<ByteBuffer>,
> header: Option<FrameHeader>,
> state: ReaderState<R>,
> @@ -431,11 +490,11 @@ pub struct WebSocketReader<R: AsyncRead> {
> impl<R: AsyncReadExt> WebSocketReader<R> {
> /// Creates a new WebSocketReader with the given CallBack for control frames
> /// and a default buffer size of 4096.
> - pub fn new(reader: R, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader<R> {
> + pub fn new(reader: R, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
> Self::with_capacity(reader, 4096, sender)
> }
>
> - pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<(OpCode, Box<[u8]>)>) -> WebSocketReader<R> {
> + pub fn with_capacity(reader: R, capacity: usize, sender: mpsc::UnboundedSender<WebSocketReadResult>) -> WebSocketReader<R> {
> WebSocketReader {
> reader: Some(reader),
> sender,
> @@ -512,13 +571,19 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
> let mut header = match this.header.take() {
> Some(header) => header,
> None => {
> - let header = match FrameHeader::try_from_bytes(&read_buffer[..])? {
> - Ok(header) => header,
> - Err(_) => {
> + let header = match FrameHeader::try_from_bytes(&read_buffer[..]) {
> + Ok(Some(header)) => header,
> + Ok(None) => {
> this.state = ReaderState::NoData;
> this.read_buffer = Some(read_buffer);
> continue;
> - }
> + },
> + Err(err) => {
> + if let Err(err) = this.sender.send(Err(err.clone())) {
> + return Poll::Ready(Err(io_err_other(err)));
> + }
> + return Poll::Ready(Err(io_err_other(err)));
> + },
> };
>
> read_buffer.consume(header.header_len as usize);
> @@ -531,7 +596,7 @@ impl<R: AsyncReadExt + Unpin + Send + 'static> AsyncRead for WebSocketReader<R>
>
> let mut data = read_buffer.remove_data(header.payload_len);
> mask_bytes(header.mask, &mut data);
> - if let Err(err) = this.sender.send((header.frametype, data)) {
> + if let Err(err) = this.sender.send(Ok((header.frametype, data))) {
> eprintln!("error sending control frame: {}", err);
> }
>
> @@ -639,10 +704,37 @@ impl WebSocket {
> Ok((Self { text }, response))
> }
>
> + async fn handle_channel_message<W>(
> + result: WebSocketReadResult,
> + writer: &mut WebSocketWriter<W>
> + ) -> Result<OpCode, Error>
> + where
> + W: AsyncWrite + Unpin + Send,
> + {
> + match result {
> + Ok((OpCode::Ping, msg)) => {
> + writer.send_control_frame(None, OpCode::Pong, &msg).await?;
> + Ok(OpCode::Pong)
> + }
> + Ok((OpCode::Close, msg)) => {
> + writer.send_control_frame(None, OpCode::Close, &msg).await?;
> + Ok(OpCode::Close)
> + }
> + Ok((opcode, _)) => {
> + // ignore other frames
> + Ok(opcode)
> + },
> + Err(err) => {
> + writer.send_control_frame(None, OpCode::Close, &err.generate_frame_payload()).await?;
> + Err(Error::from(err))
> + }
> + }
> + }
> +
> async fn copy_to_websocket<R, W>(
> mut reader: &mut R,
> - writer: &mut WebSocketWriter<W>,
> - receiver: &mut mpsc::UnboundedReceiver<(OpCode, Box<[u8]>)>) -> Result<bool, Error>
> + mut writer: &mut WebSocketWriter<W>,
> + receiver: &mut mpsc::UnboundedReceiver<WebSocketReadResult>) -> Result<bool, Error>
> where
> R: AsyncRead + Unpin + Send,
> W: AsyncWrite + Unpin + Send,
> @@ -654,20 +746,10 @@ impl WebSocket {
> let bytes = select!{
> res = buf.read_from_async(&mut reader).fuse() => res?,
> res = receiver.recv().fuse() => {
> - let (opcode, msg) = res.ok_or(format_err!("control channel closed"))?;
> - match opcode {
> - OpCode::Ping => {
> - writer.send_control_frame(None, OpCode::Pong, &msg).await?;
> - continue;
> - }
> - OpCode::Close => {
> - writer.send_control_frame(None, OpCode::Close, &msg).await?;
> - return Ok(true);
> - }
> - _ => {
> - // ignore other frames
> - continue;
> - }
> + let res = res.ok_or_else(|| format_err!("control channel closed"))?;
> + match Self::handle_channel_message(res, &mut writer).await? {
> + OpCode::Close => return Ok(true),
> + _ => { continue; },
> }
> }
> };
> @@ -720,7 +802,7 @@ impl WebSocket {
> res = term_future.fuse() => match res {
> Ok(sent_close) if !sent_close => {
> // status code 1000 => 0x03E8
> - wswriter.send_control_frame(None, OpCode::Close, &[0x03, 0xE8]).await?;
> + wswriter.send_control_frame(None, OpCode::Close, &WebSocketErrorKind::Normal.to_be_bytes()).await?;
> Ok(())
> }
> Ok(_) => Ok(()),
> --
> 2.20.1
More information about the pbs-devel
mailing list