[pve-devel] [PATCH proxmox-websocket-tunnel 2/4] add tunnel implementation

Dominik Csapak d.csapak at proxmox.com
Tue Nov 9 13:54:07 CET 2021


Looks mostly fine; some comments inline.

On 11/5/21 14:03, Fabian Grünbichler wrote:
> the websocket tunnel helper accepts control commands (encoded as
> single-line JSON) on stdin, and prints responses on stdout.
> 
> the following commands are available:
> - "connect" a 'control' tunnel via a websocket
> - "forward" a local unix socket to a remote socket via a websocket
> -- if requested, this will ask for a ticket via the control tunnel after
> accepting a new connection on the unix socket
> - "close" the control tunnel and any forwarded socket
> 
> any other json input (without the 'control' flag set) is forwarded as-is
> to the remote end of the control tunnel.
> 
> internally, the tunnel helper will spawn tokio tasks for
> - handling the control tunnel connection (new commands are passed in via
> an mpsc channel together with a oneshot channel for the response)
> - handling accepting new connections on each forwarded unix socket
> - handling forwarding data over accepted forwarded connections
> 
> Signed-off-by: Fabian Grünbichler <f.gruenbichler at proxmox.com>
> ---
> 
> Notes:
>      requires proxmox-http with changes and bumped version
> 
>   Cargo.toml  |  13 ++
>   src/main.rs | 410 ++++++++++++++++++++++++++++++++++++++++++++++++++++
>   2 files changed, 423 insertions(+)
>   create mode 100644 src/main.rs
> 
> diff --git a/Cargo.toml b/Cargo.toml
> index 939184c..9d2a8c6 100644
> --- a/Cargo.toml
> +++ b/Cargo.toml
> @@ -9,3 +9,16 @@ description = "Proxmox websocket tunneling helper"
>   exclude = ["debian"]
>   
>   [dependencies]
> +anyhow = "1.0"
> +base64 = "0.12"
> +futures = "0.3"
> +futures-util = "0.3"
> +hyper = { version = "0.14" }
> +openssl = "0.10"
> +percent-encoding = "2"
> +proxmox-http = { version = "0.5.2", path = "../proxmox/proxmox-http", features = ["websocket", "client"] }
> +serde = { version = "1.0", features = ["derive"] }
> +serde_json = "1.0"
> +tokio = { version = "1", features = ["io-std", "io-util", "macros", "rt-multi-thread", "sync"] }
> +tokio-stream = { version = "0.1", features = ["io-util"] }
> +tokio-util = "0.6"
> diff --git a/src/main.rs b/src/main.rs
> new file mode 100644
> index 0000000..150c1cf
> --- /dev/null
> +++ b/src/main.rs
> @@ -0,0 +1,410 @@
> +use anyhow::{bail, format_err, Error};
> +
> +use std::collections::VecDeque;
> +use std::sync::{Arc, Mutex};
> +
> +use futures::future::FutureExt;
> +use futures::select;
> +
> +use hyper::client::{Client, HttpConnector};
> +use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
> +use hyper::upgrade::Upgraded;
> +use hyper::{Body, Request, StatusCode};
> +
> +use openssl::ssl::{SslConnector, SslMethod};
> +use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
> +
> +use serde::{Deserialize, Serialize};
> +use serde_json::{Map, Value};
> +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
> +use tokio::net::{UnixListener, UnixStream};
> +use tokio::sync::{mpsc, oneshot};
> +use tokio_stream::wrappers::LinesStream;
> +use tokio_stream::StreamExt;
> +
> +use proxmox_http::client::HttpsConnector;
> +use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
> +
> +/// Kind of a line read from STDIN, as determined by `CtrlTunnel::parse_cmd`.
> +///
> +/// Control lines (`"control": true`) name one of these variants in their
> +/// `"cmd"` field using kebab-case (e.g. "connect", "forward", "close-cmd").
> +/// Lines without the control flag are mapped to `NonControl`.
> +#[derive(Serialize, Deserialize, Debug)]
> +#[serde(rename_all = "kebab-case")]
> +enum CmdType {
> +    /// Open the control tunnel websocket.
> +    Connect,
> +    /// Forward a local unix socket over websockets.
> +    Forward,
> +    /// NOTE(review): has no handler in `read_cmd_loop` (falls into its catch-all arm).
> +    CloseCmd,

this is never used

> +    /// Not a control command; relayed as-is over the control tunnel.
> +    NonControl,
> +}
> +
> +/// JSON object of a STDIN line, with the `control` key removed (and `cmd`, for
> +/// control commands).
> +type CmdData = Map<String, Value>;
> +
> +/// Payload (`data` field) of a `connect` control command.
> +#[derive(Serialize, Deserialize, Debug)]
> +#[serde(rename_all = "kebab-case")]
> +struct ConnectCmdData {
> +    /// target URL for WS connection
> +    url: String,
> +    /// fingerprint of TLS certificate
> +    fingerprint: Option<String>,
> +    /// additional headers such as authorization
> +    headers: Option<Vec<(String, String)>>,
> +}
> +
> +/// Payload (`data` field) of a `forward` control command.
> +#[derive(Serialize, Deserialize, Debug, Clone)]
> +#[serde(rename_all = "kebab-case")]
> +struct ForwardCmdData {
> +    /// target URL for WS connection
> +    url: String,
> +    /// additional headers such as authorization
> +    headers: Option<Vec<(String, String)>>,
> +    /// fingerprint of TLS certificate
> +    fingerprint: Option<String>,
> +    /// local UNIX socket path for forwarding
> +    unix: String,
> +    /// parameters for requesting a ticket over the control tunnel before each
> +    /// forwarded connection is established
> +    ticket: Option<Map<String, Value>>,
> +}
> +
> +/// State shared by the STDIN command handlers.
> +struct CtrlTunnel {
> +    /// Channel to the control-tunnel task; `None` until a `connect` succeeds.
> +    /// Each entry pairs a JSON message with a oneshot for the remote's reply line.
> +    sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
> +    /// Shutdown triggers for forwarding listeners, one pushed per `forward` command.
> +    forwarded: Arc<Mutex<Vec<oneshot::Sender<()>>>>,

for now, this is really not used (see my comments further below)

> +}
> +
> +impl CtrlTunnel {
> +    /// Reads newline-delimited JSON commands from STDIN and dispatches them.
> +    ///
> +    /// `forward` results are reported on STDOUT as `{"success": ...}` objects, and
> +    /// replies to non-control lines are printed as-is. The loop ends at EOF; any
> +    /// error (read failure, unparsable line, failed or unsupported command) ends
> +    /// it early and is returned.
> +    async fn read_cmd_loop(mut self) -> Result<(), Error> {
> +        let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
> +        while let Some(res) = stdin_stream.next().await {
> +            match res {
> +                Ok(line) => {
> +                    let (cmd_type, data) = Self::parse_cmd(&line)?;
> +                    match cmd_type {
> +                        CmdType::Connect => self.handle_connect_cmd(data).await,
> +                        CmdType::Forward => {
> +                            let res = self.handle_forward_cmd(data).await;
> +                            match &res {
> +                                Ok(()) => println!("{}", serde_json::json!({"success": true})),
> +                                Err(msg) => println!(
> +                                    "{}",
> +                                    serde_json::json!({"success": false, "msg": msg.to_string()})
> +                                ),
> +                            };
> +                            res
> +                        }
> +                        CmdType::NonControl => self
> +                            .handle_tunnel_cmd(data)
> +                            .await
> +                            .map(|res| println!("{}", res)),
> +                        // no handler yet - return an error instead of panicking on
> +                        // user-controlled input; matching exhaustively also makes the
> +                        // compiler flag any variant added later
> +                        CmdType::CloseCmd => {
> +                            Err(format_err!("control command 'close-cmd' is not implemented"))
> +                        }
> +                    }
> +                }
> +                Err(err) => bail!("Failed to read from STDIN - {}", err),
> +            }?;
> +        }
> +
> +        Ok(())
> +    }
> +
> +    /// Splits one STDIN line into its command type and remaining payload.
> +    ///
> +    /// Lines with `"control": true` must carry a `cmd` naming a `CmdType`;
> +    /// everything else is treated as `CmdType::NonControl`. The `control` key
> +    /// (and, for control lines, `cmd`) is removed from the returned map.
> +    fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
> +        let mut obj: CmdData = serde_json::from_str(line)?;
> +        let is_control = matches!(obj.remove("control"), Some(Value::Bool(true)));
> +        if !is_control {
> +            return Ok((CmdType::NonControl, obj));
> +        }
> +
> +        let raw_cmd = match obj.remove("cmd") {
> +            Some(raw_cmd) => raw_cmd,
> +            None => bail!("input has 'control' flag, but no control 'cmd' set.."),
> +        };
> +        let cmd_type = serde_json::from_value::<CmdType>(raw_cmd)
> +            .map_err(|e| format_err!("failed to parse control cmd - {}", e))?;
> +        Ok((cmd_type, obj))
> +    }
> +
> +    /// Opens a websocket connection to `url` and returns the upgraded stream.
> +    ///
> +    /// `extra_headers` are inserted after the websocket handshake headers; an
> +    /// entry with an already-present name replaces the earlier value. If
> +    /// `fingerprint` is set, TLS peer verification is currently disabled and the
> +    /// fingerprint itself is not checked yet (see FIXME below); otherwise the peer
> +    /// is verified. Fails on invalid headers, request errors, or when the server
> +    /// does not answer with 101 Switching Protocols.
> +    async fn websocket_connect(
> +        url: String,
> +        extra_headers: Vec<(String, String)>,
> +        fingerprint: Option<String>,
> +    ) -> Result<Upgraded, Error> {
> +        // Sec-WebSocket-Key: 16 bytes from `random_data`, base64-encoded
> +        let ws_key = proxmox::sys::linux::random_data(16)?;
> +        let ws_key = base64::encode(&ws_key);
> +        // NOTE(review): the `.unwrap()` below panics if `url` is not a valid URI
> +        let mut req = Request::builder()
> +            .uri(url)
> +            .header(UPGRADE, "websocket")
> +            .header(SEC_WEBSOCKET_VERSION, "13")
> +            .header(SEC_WEBSOCKET_KEY, ws_key)
> +            .body(Body::empty())
> +            .unwrap();
> +
> +        let headers = req.headers_mut();
> +        for (name, value) in extra_headers {
> +            let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
> +            let value = hyper::header::HeaderValue::from_str(&value)?;
> +            headers.insert(name, value);
> +        }
> +
> +        let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls()).unwrap();

not sure if this unwrap cannot fail though?

> +        if fingerprint.is_some() {
> +            // FIXME actually verify fingerprint via callback!
> +            ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::NONE);
> +        } else {
> +            ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER);
> +        }
> +
> +        let mut httpc = HttpConnector::new();
> +        httpc.enforce_http(false); // we want https...
> +        httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
> +        // NOTE(review): 120 is presumably the TCP keepalive in seconds - confirm in proxmox-http
> +        let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
> +
> +        let client = Client::builder().build::<_, Body>(https);
> +        let res = client.request(req).await?;
> +        if res.status() != StatusCode::SWITCHING_PROTOCOLS {
> +            bail!("server didn't upgrade: {}", res.status());
> +        }
> +
> +        hyper::upgrade::on(res)
> +            .await
> +            .map_err(|err| format_err!("failed to upgrade - {}", err))
> +    }
> +
> +    /// Handles a `connect` control command: opens the control websocket and
> +    /// spawns the task that relays commands over it.
> +    ///
> +    /// Fails if the `data` payload is missing or invalid, a control tunnel is
> +    /// already set up, or the websocket cannot be established. Errors of the
> +    /// spawned tunnel task are only logged to STDERR.
> +    async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
> +        let payload = match data.remove("data") {
> +            Some(payload) => payload,
> +            None => bail!("'connect' command missing 'data'"),
> +        };
> +        let ConnectCmdData {
> +            url,
> +            fingerprint,
> +            headers,
> +        } = serde_json::from_value::<ConnectCmdData>(payload)?;
> +
> +        if self.sender.is_some() {
> +            bail!("already connected!");
> +        }
> +
> +        let upgraded =
> +            Self::websocket_connect(url.clone(), headers.unwrap_or_default(), fingerprint).await?;
> +
> +        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
> +        self.sender = Some(cmd_tx);
> +        tokio::spawn(async move {
> +            if let Err(err) = Self::handle_ctrl_tunnel(upgraded, cmd_rx).await {
> +                eprintln!("Tunnel to {} failed - {}", url, err);
> +            }
> +        });
> +
> +        Ok(())
> +    }
> +
> +    /// Handles a `forward` control command: binds the unix socket `data.unix`
> +    /// and spawns a task that forwards each accepted connection over its own
> +    /// websocket to `data.url`.
> +    ///
> +    /// If `ticket` parameters are set, a ticket is requested over the control
> +    /// tunnel for every accepted connection, so a control tunnel is required.
> +    /// Returns as soon as the listener task has been spawned.
> +    async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
> +        let data: ForwardCmdData = data
> +            .remove("data")
> +            .ok_or_else(|| format_err!("'forward' command missing 'data'"))
> +            .map(serde_json::from_value)??;
> +
> +        if self.sender.is_none() && data.ticket.is_some() {
> +            bail!("dynamically requesting ticket requires cmd tunnel connection!");
> +        }
> +
> +        // NOTE(review): panics if the path cannot be bound (e.g. it already exists)
> +        let unix_listener = UnixListener::bind(data.unix.clone()).unwrap();

it would be better to bubble the error up instead of unwrapping here
(AFAICS, the rest of the unwraps are ok, since they cannot really fail?)

> +        // shutdown trigger for the listener task below; the sender half is kept
> +        // in `self.forwarded`
> +        let (tx, rx) = oneshot::channel();
> +        let data = Arc::new(data);
> +
> +        self.forwarded.lock().unwrap().push(tx);

we push the 'tx' here into the forwarded vec, but never use it again
(no other 'lock()' call in the file)

> +        // `None` when no control tunnel is connected; only used for ticket requests
> +        let cmd_sender = self.sender.clone();
> +
> +        tokio::spawn(async move {
> +            let mut rx = rx.fuse();
> +            // handles of the per-connection forwarding tasks, aborted on shutdown
> +            // NOTE(review): finished handles are never removed from this list
> +            let mut tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new();
> +            loop {
> +                let accept = unix_listener.accept().fuse();
> +                tokio::pin!(accept);
> +                let data2 = data.clone();
> +                // wait for either the shutdown signal or a new unix connection
> +                select! {
> +                    _ = rx => {
> +                        eprintln!("received shutdown signal, closing unix listener stream and forwarding handlers");
> +                        for task in tasks {
> +                            task.abort();
> +                        }
> +                        break;
> +                    },

which makes this branch dead code

so i'd drop the forwarded part and simplify this to

match unix_listener.accept().await {
...
}

> +                    res = accept => match res {
> +                        Ok((unix_stream, _)) => {
> +                            eprintln!("accepted new connection on '{}'", data2.unix);
> +                            // the ticket request is awaited inline, so no further
> +                            // connections are accepted until it completes
> +                            let data3: Result<Arc<ForwardCmdData>, Error> = match (&cmd_sender, &data2.ticket) {
> +                                (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data2.clone()).await,\

the get_ticket could probably be inside the 'handle_forward_tunnel' this 
way, another client could connect while the first ticket is checked.
not necessary for now though, since we do not connect in parallel atm

> +                                _ => Ok(data2.clone()),
> +                            };
> +
> +                            match data3 {
> +                                Ok(data3) => {
> +                                    let task = tokio::spawn(async move {
> +                                        if let Err(err) = Self::handle_forward_tunnel(data3.clone(), unix_stream).await {
> +                                            eprintln!("Tunnel for {} failed - {}", data3.unix, err);
> +                                        }
> +                                    });
> +                                    tasks.push(task);
> +                                },
> +                                Err(err) => {
> +                                    // NOTE(review): this error comes from `get_ticket`, not from
> +                                    // accepting; the accepted stream is simply dropped
> +                                    eprintln!("Failed to accept unix connection - {}", err);
> +                                },
> +                            };
> +                        },
> +                        Err(err) => eprintln!("Failed to accept unix connection on {} - {}", data2.unix, err),
> +                    },
> +                };
> +            }
> +        });
> +
> +        Ok(())
> +    }
> +
> +    /// Sends a non-control JSON object over the control tunnel and waits for the
> +    /// remote's reply line.
> +    ///
> +    /// Fails if no control tunnel is connected, the tunnel task no longer
> +    /// receives commands, or the reply channel is dropped without an answer.
> +    async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
> +        let sender = match self.sender.as_ref() {
> +            Some(sender) => sender,
> +            None => bail!("not connected!"),
> +        };
> +
> +        let message = Value::Object(data);
> +        let (reply_tx, reply_rx) = oneshot::channel::<String>();
> +        match message.get("cmd") {
> +            Some(cmd) => eprintln!("-> sending command {} to remote", cmd),
> +            None => eprintln!("-> sending data line to remote"),
> +        }
> +
> +        sender.send((message, reply_tx))?;
> +        let reply = reply_rx.await?;
> +        eprintln!("<- got reply");
> +        Ok(reply)
> +    }
> +
> +    /// Drives the control websocket: writes queued commands as JSON lines and
> +    /// hands each received line to the oldest waiting requester (FIFO order).
> +    ///
> +    /// Shutdown starts when the websocket reader reports a message on its close
> +    /// channel, or when the command channel is closed (a Close frame is sent
> +    /// then). Commands arriving during shutdown are dropped. The loop ends once
> +    /// the websocket stream finishes during shutdown; an unexpected end of the
> +    /// stream, a read error, or a reply without a pending requester is an error.
> +    async fn handle_ctrl_tunnel(
> +        websocket: Upgraded,
> +        mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
> +    ) -> Result<(), Error> {
> +        let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
> +        let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
> +        let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
> +        let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
> +
> +        let mut framed_reader =
> +            tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
> +
> +        // reply handlers, in the order their requests were written to the tunnel
> +        let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
> +        let mut shutting_down = false;
> +        // once the command channel has reported closed, `recv()` resolves to
> +        // `None` immediately on every call - polling it again would turn this
> +        // loop into a busy spin while waiting for the websocket to finish
> +        let mut cmd_channel_closed = false;
> +
> +        loop {
> +            let mut close_future = ws_close_rx.recv().boxed().fuse();
> +            let mut frame_future = framed_reader.next().boxed().fuse();
> +            // a terminated `Fuse` is skipped by `select!`
> +            let mut cmd_future = if cmd_channel_closed {
> +                futures::future::Fuse::terminated()
> +            } else {
> +                cmd_receiver.recv().boxed().fuse()
> +            };
> +
> +            select! {
> +                res = close_future => {
> +                    let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
> +                    eprintln!("WS: received control message: '{:?}'", res);
> +                    shutting_down = true;
> +                },
> +                res = frame_future => {
> +                    match res {
> +                        None if shutting_down => {
> +                            eprintln!("WS closed");
> +                            break;
> +                        },
> +                        None => bail!("WS closed unexpectedly"),
> +                        Some(Ok(res)) => {
> +                            resp_tx_queue
> +                                .pop_front()
> +                                .ok_or_else(|| format_err!("no response handler"))?
> +                                .send(res)
> +                                .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
> +                        },
> +                        Some(Err(err)) => {
> +                            bail!("reading from control tunnel failed - WS receive failed: {}", err);
> +                        },
> +                    }
> +                },
> +                res = cmd_future => {
> +                    match res {
> +                        None => {
> +                            cmd_channel_closed = true;
> +                            if !shutting_down {
> +                                eprintln!("CMD channel closed, shutting down");
> +                                ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
> +                                shutting_down = true;
> +                            }
> +                        },
> +                        // already shutting down - drop the request (and its reply channel)
> +                        Some(_) if shutting_down => {},
> +                        Some((msg, resp_tx)) => {
> +                            resp_tx_queue.push_back(resp_tx);
> +
> +                            let line = format!("{}\n", msg);
> +                            ws_writer.write_all(line.as_bytes()).await?;
> +                            ws_writer.flush().await?;
> +                        },
> +                    }
> +                },
> +            };
> +        }
> +
> +        Ok(())
> +    }
> +
> +    /// Forwards one accepted unix socket connection over a freshly opened
> +    /// websocket to `data.url`.
> +    ///
> +    /// Returns when `serve_connection` finishes; connection and forwarding errors
> +    /// are returned to the caller.
> +    async fn handle_forward_tunnel(
> +        data: Arc<ForwardCmdData>,
> +        unix: UnixStream,
> +    ) -> Result<(), Error> {
> +        let extra_headers = data.headers.clone().unwrap_or_default();
> +        let upgraded =
> +            Self::websocket_connect(data.url.clone(), extra_headers, data.fingerprint.clone())
> +                .await?;
> +
> +        let forwarder = WebSocket {
> +            mask: Some([0, 0, 0, 0]),
> +        };
> +        eprintln!("established new WS for forwarding '{}'", data.unix);
> +        forwarder.serve_connection(upgraded, unix).await?;
> +        eprintln!("done handling forwarded connection from '{}'", data.unix);
> +
> +        Ok(())
> +    }
> +
> +    /// Requests a websocket ticket over the control tunnel and returns a copy of
> +    /// `cmd_data` whose URL carries it as the percent-encoded `ticket` query
> +    /// parameter.
> +    ///
> +    /// The `ticket` parameters of `cmd_data` are sent with `"cmd": "ticket"`; the
> +    /// reply must be a JSON object with a string `ticket` field. A `?` or `&`
> +    /// separator is inserted unless the URL already ends with one.
> +    async fn get_ticket(
> +        cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
> +        cmd_data: Arc<ForwardCmdData>,
> +    ) -> Result<Arc<ForwardCmdData>, Error> {
> +        eprintln!("requesting WS ticket via tunnel");
> +        let mut ticket_cmd = match cmd_data.ticket.clone() {
> +            Some(ticket_cmd) => ticket_cmd,
> +            None => bail!("can't get ticket without ticket parameters"),
> +        };
> +        ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
> +
> +        let (tx, rx) = oneshot::channel::<String>();
> +        cmd_sender.send((Value::Object(ticket_cmd), tx))?;
> +        let reply = rx.await?;
> +        let mut reply: Map<String, Value> = serde_json::from_str(&reply)?;
> +        let ticket = reply
> +            .remove("ticket")
> +            .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
> +        let ticket = ticket
> +            .as_str()
> +            .ok_or_else(|| format_err!("failed to format received ticket"))?;
> +        let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
> +
> +        let mut url = cmd_data.url.clone();
> +        // append as a proper query parameter instead of gluing `ticket=` onto the
> +        // path or onto the value of a previous parameter
> +        if !url.ends_with('?') && !url.ends_with('&') {
> +            url.push(if url.contains('?') { '&' } else { '?' });
> +        }
> +        url.push_str("ticket=");
> +        url.push_str(&ticket);
> +
> +        let mut data = (*cmd_data).clone();
> +        data.url = url;
> +        Ok(Arc::new(data))
> +    }
> +}
> +
> +/// Entry point: runs `do_main` on the tokio runtime set up by `#[tokio::main]`.
> +#[tokio::main]
> +async fn main() -> Result<(), Error> {
> +    do_main().await
> +}
> +
> +/// Creates an unconnected `CtrlTunnel` and processes STDIN commands until EOF
> +/// or until a command fails.
> +async fn do_main() -> Result<(), Error> {
> +    CtrlTunnel {
> +        sender: None,
> +        forwarded: Arc::default(),
> +    }
> +    .read_cmd_loop()
> +    .await
> +}
> 






More information about the pve-devel mailing list