[pve-devel] [PATCH proxmox-websocket-tunnel 2/2] add tunnel implementation

Fabian Grünbichler f.gruenbichler at proxmox.com
Tue Apr 13 14:16:22 CEST 2021


the websocket tunnel helper accepts control commands (encoded as
single-line JSON) on stdin, and prints responses on stdout.

the following commands are available:
- "connect" a 'control' tunnel via a websocket
- "forward" a local unix socket to a remote socket via a websocket
-- if requested, this will ask for a ticket via the control tunnel after
accepting a new connection on the unix socket
- "close" the control tunnel and any forwarded socket

any other JSON input (without the 'control' flag set) is forwarded as-is
to the remote end of the control tunnel.

internally, the tunnel helper will spawn tokio tasks for
- handling the control tunnel connection (new commands are passed in via
an mpsc channel together with a oneshot channel for the response)
- handling accepting new connections on each forwarded unix socket
- handling forwarding data over accepted forwarded connections

Signed-off-by: Fabian Grünbichler <f.gruenbichler at proxmox.com>
---

Notes:
    proxmox_backup dependency is just there for HTTP client code that should be extracted somewhere else
    proxmox dependency requires websocket patches exposing a bit more stuff for client usage

    full repo available on my staff git..

 Cargo.toml  |  14 ++
 src/main.rs | 396 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 410 insertions(+)
 create mode 100644 src/main.rs

diff --git a/Cargo.toml b/Cargo.toml
index 939184c..18ba297 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,3 +9,17 @@ description = "Proxmox websocket tunneling helper"
 exclude = ["debian"]
 
 [dependencies]
+anyhow = "1.0"
+futures = "0.3"
+futures-util = "0.3"
+hyper = { version = "0.14" }
+openssl = "0.10"
+percent-encoding = "2"
+proxmox = { version = "0.11", path = "../proxmox/proxmox", features = ["websocket"] }
+# just for tools::http and tools::async_io::EitherStream, need to move them somewhere else
+proxmox-backup = { version = "1.0.8", path = "../proxmox-backup" }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+tokio = { version = "1", features = ["io-std", "io-util", "macros", "rt-multi-thread", "sync"] }
+tokio-stream = { version = "0.1", features = ["io-util"] }
+tokio-util = "0.6"
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..6d21352
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,396 @@
+use anyhow::{bail, format_err, Error};
+
+use std::collections::VecDeque;
+use std::sync::{Arc, Mutex};
+
+use futures::future::FutureExt;
+use futures::select;
+
+use hyper::client::{Client, HttpConnector};
+use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL, SEC_WEBSOCKET_VERSION, UPGRADE};
+use hyper::upgrade::Upgraded;
+use hyper::{Body, Request, StatusCode};
+
+use openssl::ssl::{SslConnector, SslMethod};
+use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
+
+use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
+use tokio::net::{UnixListener, UnixStream};
+use tokio::sync::{mpsc, oneshot};
+use tokio_stream::wrappers::LinesStream;
+use tokio_stream::StreamExt;
+
+use proxmox::tools::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
+use proxmox_backup::tools::http::HttpsConnector;
+
+/// Kind of command read from stdin.
+///
+/// With `rename_all = "kebab-case"` the variants are (de)serialized as `connect`, `forward`,
+/// `close-cmd` and `non-control`.
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "kebab-case")]
+enum CmdType {
+    /// Establish the control tunnel.
+    Connect,
+    /// Forward a local unix socket over a websocket.
+    Forward,
+    /// Request to close the tunnel.
+    CloseCmd,
+    /// Not a control command; `parse_cmd` assigns this to input lacking `"control": true`.
+    NonControl,
+}
+
+/// Top-level JSON fields of an input line that remain after `parse_cmd` has removed the
+/// `control` flag (and `cmd`, for control commands).
+type CmdData = Map<String, Value>;
+
+/// Payload of the `connect` control command (taken from its `data` member).
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(rename_all = "kebab-case")]
+struct ConnectCmdData {
+    /// target URL for WS connection
+    url: String,
+    /// fingerprint of TLS certificate
+    fingerprint: Option<String>,
+    /// additional headers such as authorization, as `(name, value)` pairs
+    headers: Option<Vec<(String, String)>>,
+}
+
+/// Payload of the `forward` control command (taken from its `data` member).
+#[derive(Serialize, Deserialize, Debug, Clone)]
+#[serde(rename_all = "kebab-case")]
+struct ForwardCmdData {
+    /// target URL for WS connection
+    url: String,
+    /// additional headers such as authorization, as `(name, value)` pairs
+    headers: Option<Vec<(String, String)>>,
+    /// fingerprint of TLS certificate
+    fingerprint: Option<String>,
+    /// local UNIX socket path for forwarding
+    unix: String,
+    /// request ticket using these parameters; `get_ticket` adds `"cmd": "ticket"` to them and
+    /// sends the result over the control tunnel
+    ticket: Option<Map<String, Value>>,
+}
+
+/// State of the tunnel helper: the control tunnel handle plus all forwarded listeners.
+struct CtrlTunnel {
+    /// Request channel to the control tunnel task; each request carries a oneshot sender for
+    /// its response line. `None` until a `connect` command succeeded.
+    sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
+    /// One shutdown signal sender per listener task spawned by a `forward` command.
+    forwarded: Arc<Mutex<Vec<oneshot::Sender<()>>>>,
+}
+
+impl CtrlTunnel {
+    /// Reads commands from stdin (one JSON object per line) and dispatches them until stdin
+    /// is closed.
+    ///
+    /// - `connect` sets up the control tunnel and prints nothing on success.
+    /// - `forward` prints a `{"success": ..}` JSON line on stdout.
+    /// - `close-cmd` drops the control tunnel sender and signals every forwarded listener.
+    /// - anything else is sent over the control tunnel and the response line is printed.
+    ///
+    /// # Errors
+    ///
+    /// The first failing read, parse or command ends the loop and is returned to the caller.
+    async fn read_cmd_loop(mut self) -> Result<(), Error> {
+        let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
+        while let Some(res) = stdin_stream.next().await {
+            let line = res.map_err(|err| format_err!("Failed to read from STDIN - {}", err))?;
+            let (cmd_type, data) = Self::parse_cmd(&line)?;
+
+            // exhaustive on purpose: a new `CmdType` variant must be handled here explicitly
+            match cmd_type {
+                CmdType::Connect => self.handle_connect_cmd(data).await?,
+                CmdType::Forward => {
+                    let res = self.handle_forward_cmd(data).await;
+                    match &res {
+                        Ok(()) => println!("{}", serde_json::json!({"success": true})),
+                        Err(msg) => println!(
+                            "{}",
+                            serde_json::json!({"success": false, "msg": msg.to_string()})
+                        ),
+                    };
+                    res?
+                }
+                CmdType::NonControl => {
+                    let res = self.handle_tunnel_cmd(data).await?;
+                    println!("{}", res);
+                }
+                CmdType::CloseCmd => {
+                    // The control tunnel task sees its command channel close once every clone
+                    // of this sender is gone; dropping ours is the part we control here.
+                    self.sender.take();
+
+                    // No `.await` happens while the std mutex guard is alive.
+                    for shutdown in self.forwarded.lock().unwrap().drain(..) {
+                        // a listener task that already exited dropped its receiver - ignore
+                        let _ = shutdown.send(());
+                    }
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Splits one input line into its command type and the remaining JSON fields.
+    ///
+    /// The line must be a JSON object. Unless it carries `"control": true` it is classified
+    /// as [`CmdType::NonControl`]; otherwise its `cmd` member selects the command type. The
+    /// `control` member (and `cmd`, for control commands) is removed from the returned map.
+    fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
+        let mut fields: CmdData = serde_json::from_str(line)?;
+
+        let is_control = matches!(fields.remove("control"), Some(Value::Bool(true)));
+        if !is_control {
+            return Ok((CmdType::NonControl, fields));
+        }
+
+        let cmd = fields
+            .remove("cmd")
+            .ok_or_else(|| format_err!("input has 'control' flag, but no control 'cmd' set.."))?;
+        let cmd_type = serde_json::from_value::<CmdType>(cmd)
+            .map_err(|e| format_err!("failed to parse control cmd - {}", e))?;
+
+        Ok((cmd_type, fields))
+    }
+
+    /// Opens an HTTP(S) connection to `url` and upgrades it to a websocket.
+    ///
+    /// * `text` - selects the `text` instead of the `binary` websocket sub-protocol
+    /// * `extra_headers` - additional request headers (e.g. authorization); a header given
+    ///   here replaces a same-named header set by this function
+    /// * `fingerprint` - if set, the peer certificate is accepted only when its digest matches
+    ///   this value; if unset, regular chain verification is used
+    ///
+    /// # Errors
+    ///
+    /// Fails on an invalid URL or header, on connection/TLS errors, and if the server does not
+    /// answer with `101 Switching Protocols`.
+    async fn websocket_connect(
+        url: String,
+        text: bool,
+        extra_headers: Vec<(String, String)>,
+        fingerprint: Option<String>,
+    ) -> Result<Upgraded, Error> {
+        // RFC 6455 requires the key to be a base64-encoded, randomly selected 16 byte nonce.
+        // NOTE(review): the server's Sec-WebSocket-Accept answer is still not checked here.
+        let mut ws_key = [0u8; 16];
+        openssl::rand::rand_bytes(&mut ws_key)?;
+        let ws_key = openssl::base64::encode_block(&ws_key);
+
+        // the URL comes from user input, so a malformed one must not panic
+        let mut req = Request::builder()
+            .uri(url)
+            .header(UPGRADE, "websocket")
+            .header(SEC_WEBSOCKET_VERSION, "13")
+            .header(SEC_WEBSOCKET_KEY, ws_key)
+            .header(SEC_WEBSOCKET_PROTOCOL, if text { "text" } else { "binary" })
+            .body(Body::empty())?;
+
+        let headers = req.headers_mut();
+        for (name, value) in extra_headers {
+            let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
+            let value = hyper::header::HeaderValue::from_str(&value)?;
+            headers.insert(name, value);
+        }
+
+        let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
+        match fingerprint {
+            Some(expected) => {
+                // accept both "AA:BB:.." and "aabb.." spellings
+                // NOTE(review): assumes the fingerprint is a SHA-256 certificate digest - confirm
+                let expected = expected
+                    .chars()
+                    .filter(|c| *c != ':')
+                    .collect::<String>()
+                    .to_lowercase();
+
+                ssl_connector_builder.set_verify_callback(
+                    openssl::ssl::SslVerifyMode::PEER,
+                    move |_preverified, ctx| {
+                        // only the peer certificate (depth 0) is pinned; the rest of the
+                        // chain is irrelevant when trusting an explicit fingerprint
+                        if ctx.error_depth() != 0 {
+                            return true;
+                        }
+
+                        let cert = match ctx.current_cert() {
+                            Some(cert) => cert,
+                            None => {
+                                eprintln!("TLS verification failed - no current certificate");
+                                return false;
+                            }
+                        };
+
+                        let digest = match cert.digest(openssl::hash::MessageDigest::sha256()) {
+                            Ok(digest) => digest,
+                            Err(err) => {
+                                eprintln!("failed to calculate certificate fingerprint - {}", err);
+                                return false;
+                            }
+                        };
+
+                        let actual: String =
+                            digest.iter().map(|byte| format!("{:02x}", byte)).collect();
+                        if actual == expected {
+                            true
+                        } else {
+                            eprintln!(
+                                "certificate fingerprint mismatch - expected '{}', got '{}'",
+                                expected, actual
+                            );
+                            false
+                        }
+                    },
+                );
+            }
+            None => ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER),
+        }
+
+        let mut httpc = HttpConnector::new();
+        httpc.enforce_http(false); // we want https...
+        httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
+        let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build());
+
+        let client = Client::builder().build::<_, Body>(https);
+        let res = client.request(req).await?;
+        if res.status() != StatusCode::SWITCHING_PROTOCOLS {
+            bail!("server didn't upgrade: {}", res.status());
+        }
+
+        hyper::upgrade::on(res)
+            .await
+            .map_err(|err| format_err!("failed to upgrade - {}", err))
+    }
+
+    /// Handles the `connect` command: opens the control websocket and spawns the task
+    /// servicing it.
+    ///
+    /// Fails if the command has no (valid) `data` member, if a control tunnel was already
+    /// set up, or if the websocket connection cannot be established.
+    async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
+        let payload = data
+            .remove("data")
+            .ok_or_else(|| format_err!("'connect' command missing 'data'"))?;
+        let ConnectCmdData {
+            url,
+            fingerprint,
+            headers,
+        } = serde_json::from_value(payload)?;
+
+        if self.sender.is_some() {
+            bail!("already connected!");
+        }
+
+        let upgraded =
+            Self::websocket_connect(url.clone(), false, headers.unwrap_or_default(), fingerprint)
+                .await?;
+
+        // requests reach the tunnel task through this channel
+        let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
+        self.sender = Some(cmd_tx);
+
+        tokio::spawn(async move {
+            if let Err(err) = Self::handle_ctrl_tunnel(upgraded, cmd_rx).await {
+                eprintln!("Tunnel to {} failed - {}", url, err);
+            }
+        });
+
+        Ok(())
+    }
+
+    /// Handles the `forward` command: binds the requested unix socket and spawns a task that
+    /// accepts connections on it, forwarding each one over its own websocket.
+    ///
+    /// If ticket parameters are given, a fresh ticket is requested via the control tunnel
+    /// for every accepted connection before the websocket is opened.
+    ///
+    /// The spawned listener task ends (closing the listening socket) once its shutdown
+    /// signal fires or the matching sender stored in `self.forwarded` is dropped.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `data` is missing or invalid, if a ticket is requested without a control
+    /// tunnel, or if the unix socket cannot be bound.
+    async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
+        let data: ForwardCmdData = data
+            .remove("data")
+            .ok_or_else(|| format_err!("'forward' command missing 'data'"))
+            .map(serde_json::from_value)??;
+
+        if self.sender.is_none() && data.ticket.is_some() {
+            bail!("dynamically requesting ticket requires cmd tunnel connection!");
+        }
+
+        // the path is user input - report bind failures instead of panicking
+        let unix_listener = UnixListener::bind(&data.unix)
+            .map_err(|err| format_err!("failed to bind unix socket '{}' - {}", data.unix, err))?;
+        let (tx, rx) = oneshot::channel();
+        let data = Arc::new(data);
+
+        self.forwarded.lock().unwrap().push(tx);
+        let cmd_sender = self.sender.clone();
+
+        tokio::spawn(async move {
+            let mut rx = rx.fuse();
+            loop {
+                let accept = unix_listener.accept().fuse();
+                tokio::pin!(accept);
+                let data2 = data.clone();
+                select! {
+                    _ = rx => {
+                        eprintln!("received shutdown signal, closing unix listener stream");
+                        // leaving the loop drops the listener; without this the task would
+                        // keep accepting connections forever
+                        break;
+                    },
+                    res = accept => match res {
+                        Ok((unix_stream, _)) => {
+                            eprintln!("accepted new connection on '{}'", data2.unix);
+                            let data3: Result<Arc<ForwardCmdData>, Error> = match (&cmd_sender, &data2.ticket) {
+                                (Some(cmd_sender), Some(_)) => Self::get_ticket(cmd_sender, data2.clone()).await,
+                                _ => Ok(data2.clone()),
+                            };
+
+                            match data3 {
+                                Ok(data3) => {
+                                    // each connection gets its own task so that a slow or
+                                    // failing tunnel does not block the accept loop
+                                    tokio::spawn(async move {
+                                        if let Err(err) = Self::handle_forward_tunnel(data3.clone(), unix_stream).await {
+                                            eprintln!("Tunnel for {} failed - {}", data3.unix, err);
+                                        }
+                                    });
+                                },
+                                Err(err) => {
+                                    eprintln!("Failed to request ticket for connection on '{}' - {}", data2.unix, err);
+                                },
+                            };
+                        },
+                        Err(err) => eprintln!("Failed to accept unix connection on {} - {}", data2.unix, err),
+                    },
+                };
+            }
+        });
+
+        Ok(())
+    }
+
+    /// Sends a non-control command over the control tunnel and waits for its response line.
+    ///
+    /// Fails with "not connected!" if no control tunnel was set up, and if the tunnel task is
+    /// gone before a response arrives.
+    async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
+        let sender = match self.sender.as_ref() {
+            Some(sender) => sender,
+            None => bail!("not connected!"),
+        };
+
+        let request = Value::from(data);
+        eprintln!("tunnel request: '{}'", request);
+
+        // the tunnel task answers through the oneshot channel passed along with the request
+        let (resp_tx, resp_rx) = oneshot::channel::<String>();
+        sender.send((request, resp_tx))?;
+
+        let response = resp_rx.await?;
+        eprintln!("tunnel response: '{}'", response);
+        Ok(response)
+    }
+
+    /// Services the control tunnel: writes each queued command as one line to the websocket
+    /// and hands each received line to the oldest pending requester.
+    ///
+    /// Requests and responses are matched purely by order (`resp_tx_queue` is FIFO), so the
+    /// remote end is presumably expected to answer in request order - verify against the
+    /// server implementation.
+    ///
+    /// When `cmd_receiver` is closed, a websocket Close frame is sent and the loop keeps
+    /// running until the reader reports the end of the stream.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the websocket ends before shutdown was initiated, if a line arrives while no
+    /// request is pending, if a requester is already gone, or if writing to the websocket
+    /// fails.
+    async fn handle_ctrl_tunnel(
+        websocket: Upgraded,
+        mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
+    ) -> Result<(), Error> {
+        let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
+        // NOTE(review): the reader presumably reports received control frames via this
+        // channel - confirm against proxmox::tools::websocket
+        let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
+        let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
+        // NOTE(review): arguments look like (mask, text, writer); an all-zero mask - confirm
+        let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), false, tunnel_writer);
+
+        // responses are newline-delimited
+        let mut framed_reader =
+            tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
+
+        // response channels of requests that were sent but not answered yet, oldest first
+        let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
+        let mut shutting_down = false;
+
+        loop {
+            select! {
+                res = ws_close_rx.recv().fuse() => {
+                    let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
+                    eprintln!("WS: received control message: '{:?}'", res);
+                },
+                res = framed_reader.next().fuse() => {
+                    match res {
+                        None if shutting_down => {
+                            eprintln!("WS closed");
+                            break;
+                        },
+                        None => bail!("WS closed unexpectedly"),
+                        Some(Ok(res)) => {
+                            // hand the line to the oldest pending requester
+                            resp_tx_queue
+                                .pop_front()
+                                .ok_or_else(|| format_err!("no response handler"))?
+                                .send(res)
+                                .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
+                        },
+                        Some(Err(err)) => {
+                            eprintln!("WS: received failed - {}", err);
+                            // TODO handle?
+                        },
+                    }
+                },
+                res = cmd_receiver.recv().fuse() => {
+                    // NOTE(review): once the command channel is closed, `recv()` presumably
+                    // resolves to `None` immediately on every iteration, so this arm may spin
+                    // until the peer closes the websocket - confirm
+                    if shutting_down { continue };
+                    match res {
+                        None => {
+                            eprintln!("CMD channel closed, shutting down");
+                            ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
+                            shutting_down = true;
+                        },
+                        Some((msg, resp_tx)) => {
+                            // queue the response channel before sending the request
+                            resp_tx_queue.push_back(resp_tx);
+
+                            let line = format!("{}\n", msg);
+                            ws_writer.write_all(line.as_bytes()).await?;
+                            ws_writer.flush().await?;
+                        },
+                    }
+                },
+            };
+        }
+
+        Ok(())
+    }
+
+    /// Serves one accepted unix connection: opens a new binary websocket to the configured
+    /// URL and relays data between it and `unix` until the connection is done.
+    async fn handle_forward_tunnel(
+        data: Arc<ForwardCmdData>,
+        unix: UnixStream,
+    ) -> Result<(), Error> {
+        let headers = data.headers.clone().unwrap_or_default();
+        let fingerprint = data.fingerprint.clone();
+        let upgraded =
+            Self::websocket_connect(data.url.clone(), false, headers, fingerprint).await?;
+        eprintln!("established new WS for forwarding '{}'", data.unix);
+
+        let ws = WebSocket {
+            text: false,
+            mask: Some([0, 0, 0, 0]),
+        };
+        ws.serve_connection(upgraded, unix).await?;
+        eprintln!("done handling forwarded connection from '{}'", data.unix);
+
+        Ok(())
+    }
+
+    /// Requests a websocket ticket over the control tunnel and returns a copy of `cmd_data`
+    /// whose URL carries that ticket.
+    ///
+    /// The `ticket` parameters of `cmd_data` are sent as a `"cmd": "ticket"` request via
+    /// `cmd_sender`; the response must be a JSON object with a string member `ticket`. The
+    /// percent-encoded ticket is appended to the URL as `ticket=<value>` without a separator,
+    /// so the configured URL presumably already ends in `?` or `&` - verify against the
+    /// caller.
+    async fn get_ticket(
+        cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
+        cmd_data: Arc<ForwardCmdData>,
+    ) -> Result<Arc<ForwardCmdData>, Error> {
+        eprintln!("requesting WS ticket via tunnel");
+
+        let mut request = cmd_data
+            .ticket
+            .clone()
+            .ok_or_else(|| format_err!("can't get ticket without ticket parameters"))?;
+        request.insert("cmd".to_string(), serde_json::json!("ticket"));
+
+        let (resp_tx, resp_rx) = oneshot::channel::<String>();
+        cmd_sender.send((Value::Object(request), resp_tx))?;
+        let response = resp_rx.await?;
+
+        let mut response: Map<String, Value> = serde_json::from_str(&response)?;
+        let ticket = match response.remove("ticket") {
+            Some(Value::String(ticket)) => ticket,
+            Some(_) => bail!("failed to format received ticket"),
+            None => bail!("failed to retrieve ticket via tunnel"),
+        };
+        let ticket = utf8_percent_encode(&ticket, NON_ALPHANUMERIC).to_string();
+
+        // the shared command data stays untouched, only this connection gets the ticket URL
+        let mut forward = (*cmd_data).clone();
+        forward.url.push_str("ticket=");
+        forward.url.push_str(&ticket);
+        Ok(Arc::new(forward))
+    }
+}
+
+/// Entry point; runs `do_main` on the tokio runtime and returns its result.
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    do_main().await
+}
+
+/// Creates a not yet connected `CtrlTunnel` and runs its stdin command loop to completion.
+async fn do_main() -> Result<(), Error> {
+    let forwarded = Arc::new(Mutex::new(Vec::new()));
+    CtrlTunnel {
+        sender: None,
+        forwarded,
+    }
+    .read_cmd_loop()
+    .await
+}
-- 
2.20.1






More information about the pve-devel mailing list