[pve-devel] [PATCH v3 proxmox-websocket-tunnel 2/4] add tunnel implementation
Fabian Grünbichler
f.gruenbichler at proxmox.com
Wed Dec 22 14:52:41 CET 2021
the websocket tunnel helper accepts control commands (encoded as
single-line JSON) on stdin, and prints responses on stdout.
the following commands are available:
- "connect" a 'control' tunnel via a websocket
- "forward" a local unix socket to a remote socket via a websocket
-- if requested, this will ask for a ticket via the control tunnel after
accepting a new connection on the unix socket
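- there is no separate "close" command (CloseCmd was dropped in v2): the control tunnel and any forwarded sockets are shut down when the helper exits, e.g. after EOF on stdin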
any other JSON input (without the 'control' flag set) is forwarded as-is
to the remote end of the control tunnel.
internally, the tunnel helper will spawn tokio tasks for
- handling the control tunnel connection (new commands are passed in via
an mpsc channel together with a oneshot channel for the response)
- handling accepting new connections on each forwarded unix socket
- handling forwarding data over accepted forwarded connections
Signed-off-by: Fabian Grünbichler <f.gruenbichler at proxmox.com>
---
Notes:
v3:
- rebased, use proxmox-sys instead of proxmox for linux::random_data
v2:
- dropped CloseCmd and related code
- moved call to get_ticket into forward handler task
- bubble up errors instead of unwrap()
Cargo.toml | 15 ++
src/main.rs | 396 ++++++++++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 411 insertions(+)
create mode 100644 src/main.rs
diff --git a/Cargo.toml b/Cargo.toml
index 939184c..7f24602 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,3 +9,18 @@ description = "Proxmox websocket tunneling helper"
exclude = ["debian"]
[dependencies]
+anyhow = "1.0"
+base64 = "0.13"
+futures = "0.3"
+futures-util = "0.3"
+hex = "0.4"
+hyper = { version = "0.14" }
+openssl = "0.10"
+percent-encoding = "2"
+proxmox-http = { version = "0.6", path = "../proxmox/proxmox-http", features = ["websocket", "client"] }
+proxmox-sys = { version = "0.2.2", path = "../proxmox/proxmox-sys" }
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+tokio = { version = "1", features = ["io-std", "io-util", "macros", "rt-multi-thread", "sync"] }
+tokio-stream = { version = "0.1", features = ["io-util"] }
+tokio-util = "0.6"
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..582214c
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,396 @@
+use anyhow::{bail, format_err, Error};
+
+use std::collections::VecDeque;
+use std::sync::Arc;
+
+use futures::future::FutureExt;
+use futures::select;
+
+use hyper::client::{Client, HttpConnector};
+use hyper::header::{SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE};
+use hyper::upgrade::Upgraded;
+use hyper::{Body, Request, StatusCode};
+
+use openssl::ssl::{SslConnector, SslMethod};
+use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
+
+use serde::{Deserialize, Serialize};
+use serde_json::{Map, Value};
+use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
+use tokio::net::{UnixListener, UnixStream};
+use tokio::sync::{mpsc, oneshot};
+use tokio_stream::wrappers::LinesStream;
+use tokio_stream::StreamExt;
+
+use proxmox_http::client::HttpsConnector;
+use proxmox_http::websocket::{OpCode, WebSocket, WebSocketReader, WebSocketWriter};
+
+/// Type of a single input line, as determined by `CtrlTunnel::parse_cmd`.
+///
+/// Control commands are named in kebab-case in the `cmd` field (`"connect"`,
+/// `"forward"`); lines without `"control": true` become `NonControl`.
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "kebab-case")]
+enum CmdType {
+    /// establish the control tunnel websocket
+    Connect,
+    /// forward a local unix socket, using a new websocket per accepted connection
+    Forward,
+    /// pass the line through to the remote end of the control tunnel
+    NonControl,
+}
+
+/// Remaining JSON fields of an input line after `parse_cmd` stripped the `control`
+/// key (and, for control commands, the `cmd` key).
+type CmdData = Map<String, Value>;
+
+/// Payload (the `data` object) of a `connect` control command.
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "kebab-case")]
+struct ConnectCmdData {
+    /// target URL for the websocket connection
+    url: String,
+    /// expected fingerprint of the server's TLS certificate
+    fingerprint: Option<String>,
+    /// additional request headers (e.g. authorization), each a `[name, value]` pair
+    headers: Option<Vec<(String, String)>>,
+}
+
+/// Payload (the `data` object) of a `forward` control command.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "kebab-case")]
+struct ForwardCmdData {
+    /// target URL for each forwarding websocket connection
+    url: String,
+    /// additional request headers (e.g. authorization), each a `[name, value]` pair
+    headers: Option<Vec<(String, String)>>,
+    /// expected fingerprint of the server's TLS certificate
+    fingerprint: Option<String>,
+    /// path of the local UNIX socket to listen on
+    unix: String,
+    /// if set, these parameters are sent (with `cmd: "ticket"`) over the control tunnel
+    /// to obtain a ticket for every accepted connection
+    ticket: Option<Map<String, Value>>,
+}
+
+/// State of the stdin command loop.
+struct CtrlTunnel {
+    /// Command channel into the control tunnel task; `None` until a `connect` command
+    /// succeeded. Each message carries a oneshot sender that receives the reply line.
+    sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
+}
+
+impl CtrlTunnel {
+    /// Reads single-line JSON commands from stdin until EOF and dispatches them.
+    ///
+    /// - `connect` opens the control tunnel (prints nothing on success),
+    /// - `forward` starts forwarding a unix socket and prints a `{"success": ...}` line,
+    /// - anything else is sent over the control tunnel and the reply line is printed.
+    ///
+    /// Returns `Ok(())` on EOF. The first error (invalid JSON, failed command, broken
+    /// tunnel) ends the loop and is returned - for `forward` only after the failure JSON
+    /// has been printed.
+    async fn read_cmd_loop(mut self) -> Result<(), Error> {
+        let mut stdin_stream = LinesStream::new(BufReader::new(tokio::io::stdin()).lines());
+        while let Some(line) = stdin_stream.next().await {
+            let line = line.map_err(|err| format_err!("Failed to read from STDIN - {}", err))?;
+            let (cmd_type, data) = Self::parse_cmd(&line)?;
+            match cmd_type {
+                CmdType::Connect => self.handle_connect_cmd(data).await?,
+                CmdType::Forward => {
+                    let res = self.handle_forward_cmd(data).await;
+                    match &res {
+                        Ok(()) => println!("{}", serde_json::json!({"success": true})),
+                        Err(msg) => println!(
+                            "{}",
+                            serde_json::json!({"success": false, "msg": msg.to_string()})
+                        ),
+                    }
+                    res?
+                }
+                CmdType::NonControl => {
+                    let reply = self.handle_tunnel_cmd(data).await?;
+                    println!("{}", reply);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Parses one input line into its command type and the remaining JSON fields.
+    ///
+    /// A `control` key is always removed. If it was `true`, the `cmd` key is removed too
+    /// and must name a control [`CmdType`]; otherwise the line is [`CmdType::NonControl`].
+    fn parse_cmd(line: &str) -> Result<(CmdType, CmdData), Error> {
+        let mut json: Map<String, Value> = serde_json::from_str(line)?;
+        match json.remove("control") {
+            Some(Value::Bool(true)) => {
+                match json.remove("cmd").map(serde_json::from_value::<CmdType>) {
+                    None => bail!("input has 'control' flag, but no control 'cmd' set.."),
+                    Some(Err(e)) => bail!("failed to parse control cmd - {}", e),
+                    Some(Ok(cmd_type)) => Ok((cmd_type, json)),
+                }
+            }
+            _ => Ok((CmdType::NonControl, json)),
+        }
+    }
+
+    /// Opens a TLS websocket connection to `url` and returns the upgraded stream.
+    ///
+    /// `extra_headers` are added to the upgrade request, replacing a header of the same
+    /// name. If `fingerprint` is given, the server is trusted if and only if its leaf
+    /// certificate's SHA-256 fingerprint matches it (case-insensitive, `:` separators
+    /// optional); otherwise regular peer verification is used.
+    ///
+    /// Errors on invalid URL/headers, TLS or connection failures, and if the server does
+    /// not answer with `101 Switching Protocols`.
+    async fn websocket_connect(
+        url: String,
+        extra_headers: Vec<(String, String)>,
+        fingerprint: Option<String>,
+    ) -> Result<Upgraded, Error> {
+        let ws_key = proxmox_sys::linux::random_data(16)?;
+        let ws_key = base64::encode(&ws_key);
+        let mut req = Request::builder()
+            .uri(url)
+            .header(UPGRADE, "websocket")
+            .header(SEC_WEBSOCKET_VERSION, "13")
+            .header(SEC_WEBSOCKET_KEY, ws_key)
+            .body(Body::empty())?;
+
+        let headers = req.headers_mut();
+        for (name, value) in extra_headers {
+            let name = hyper::header::HeaderName::from_bytes(name.as_bytes())?;
+            let value = hyper::header::HeaderValue::from_str(&value)?;
+            headers.insert(name, value);
+        }
+
+        let mut ssl_connector_builder = SslConnector::builder(SslMethod::tls())?;
+        match fingerprint {
+            Some(expected) => {
+                // accept both `AA:BB:..` and `aabb..` notation
+                let expected = expected.replace(':', "").to_lowercase();
+                // PEER mode makes the handshake fail when the callback returns false
+                ssl_connector_builder.set_verify_callback(
+                    openssl::ssl::SslVerifyMode::PEER,
+                    move |_valid, ctx| Self::verify_fingerprint(&expected, ctx),
+                );
+            }
+            None => ssl_connector_builder.set_verify(openssl::ssl::SslVerifyMode::PEER),
+        }
+
+        let mut httpc = HttpConnector::new();
+        httpc.enforce_http(false); // we want https...
+        httpc.set_connect_timeout(Some(std::time::Duration::new(10, 0)));
+        let https = HttpsConnector::with_connector(httpc, ssl_connector_builder.build(), 120);
+
+        let client = Client::builder().build::<_, Body>(https);
+        let res = client.request(req).await?;
+        if res.status() != StatusCode::SWITCHING_PROTOCOLS {
+            bail!("server didn't upgrade: {}", res.status());
+        }
+
+        hyper::upgrade::on(res)
+            .await
+            .map_err(|err| format_err!("failed to upgrade - {}", err))
+    }
+
+    /// TLS verify callback used when a certificate fingerprint is pinned.
+    ///
+    /// Certificates above the leaf (depth > 0) are accepted whatever the regular chain
+    /// verification reported; the leaf (depth 0) is accepted if and only if its SHA-256
+    /// digest, hex-encoded in lowercase without separators, equals `expected`.
+    fn verify_fingerprint(expected: &str, ctx: &mut openssl::x509::X509StoreContextRef) -> bool {
+        if ctx.error_depth() != 0 {
+            return true;
+        }
+
+        let cert = match ctx.current_cert() {
+            Some(cert) => cert,
+            None => {
+                eprintln!("SSL context lacks current certificate.");
+                return false;
+            }
+        };
+
+        let fp = match cert.digest(openssl::hash::MessageDigest::sha256()) {
+            Ok(fp) => hex::encode(&*fp),
+            Err(err) => {
+                eprintln!("failed to calculate certificate FP - {}", err);
+                return false;
+            }
+        };
+
+        if fp == expected {
+            true
+        } else {
+            eprintln!("certificate fingerprint does not match expected fingerprint!");
+            eprintln!("expected:    {}", expected);
+            eprintln!("encountered: {}", fp);
+            false
+        }
+    }
+
+    /// Handles a `connect` control command: opens the control websocket described by
+    /// `data["data"]` and spawns the task serving it.
+    ///
+    /// Errors if `data` is missing/invalid, a tunnel is already connected, or the
+    /// connection fails. Later tunnel failures are only logged to stderr.
+    async fn handle_connect_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
+        let data: ConnectCmdData = data
+            .remove("data")
+            .ok_or_else(|| format_err!("'connect' command missing 'data'"))
+            .map(serde_json::from_value)??;
+        let ConnectCmdData {
+            url,
+            fingerprint,
+            headers,
+        } = data;
+
+        if self.sender.is_some() {
+            bail!("already connected!");
+        }
+
+        let upgraded =
+            Self::websocket_connect(url.clone(), headers.unwrap_or_default(), fingerprint).await?;
+
+        let (tx, rx) = mpsc::unbounded_channel();
+        self.sender = Some(tx);
+        tokio::spawn(async move {
+            if let Err(err) = Self::handle_ctrl_tunnel(upgraded, rx).await {
+                eprintln!("Tunnel to {} failed - {}", url, err);
+            }
+        });
+
+        Ok(())
+    }
+
+    /// Handles a `forward` control command: binds the local unix socket named in
+    /// `data["data"]` and spawns a task accepting connections on it.
+    ///
+    /// Every accepted connection is forwarded over its own websocket in a detached task.
+    /// If `ticket` parameters are set, a ticket is first requested via the control
+    /// tunnel, which therefore must already be connected.
+    ///
+    /// Errors if `data` is missing/invalid, a ticket is requested without a control
+    /// tunnel, or binding the socket fails.
+    async fn handle_forward_cmd(&mut self, mut data: CmdData) -> Result<(), Error> {
+        let data: ForwardCmdData = data
+            .remove("data")
+            .ok_or_else(|| format_err!("'forward' command missing 'data'"))
+            .map(serde_json::from_value)??;
+
+        if self.sender.is_none() && data.ticket.is_some() {
+            bail!("dynamically requesting ticket requires cmd tunnel connection!");
+        }
+
+        let unix_listener = UnixListener::bind(data.unix.clone())?;
+        let data = Arc::new(data);
+        let cmd_sender = self.sender.clone();
+
+        tokio::spawn(async move {
+            loop {
+                match unix_listener.accept().await {
+                    Ok((unix_stream, _)) => {
+                        eprintln!("accepted new connection on '{}'", data.unix);
+                        let cmd_sender = cmd_sender.clone();
+                        let data = Arc::clone(&data);
+                        // The JoinHandle is intentionally not kept: nothing awaits it, and
+                        // storing one per connection would grow without bound.
+                        tokio::spawn(async move {
+                            if let Err(err) = Self::handle_forward_tunnel(
+                                cmd_sender,
+                                Arc::clone(&data),
+                                unix_stream,
+                            )
+                            .await
+                            {
+                                eprintln!("Tunnel for {} failed - {}", data.unix, err);
+                            }
+                        });
+                    }
+                    Err(err) => eprintln!(
+                        "Failed to accept unix connection on {} - {}",
+                        data.unix, err
+                    ),
+                }
+            }
+        });
+
+        Ok(())
+    }
+
+    /// Sends a non-control line over the control tunnel and waits for the reply line.
+    ///
+    /// Errors if no tunnel is connected, or if the tunnel task is gone or dropped the
+    /// request without answering.
+    async fn handle_tunnel_cmd(&mut self, data: CmdData) -> Result<String, Error> {
+        let sender = match &self.sender {
+            Some(sender) => sender,
+            None => bail!("not connected!"),
+        };
+
+        let data = Value::Object(data);
+        match data.get("cmd") {
+            Some(cmd) => eprintln!("-> sending command {} to remote", cmd),
+            None => eprintln!("-> sending data line to remote"),
+        }
+
+        let (tx, rx) = oneshot::channel::<String>();
+        sender.send((data, tx))?;
+        let res = rx.await?;
+        eprintln!("<- got reply");
+        Ok(res)
+    }
+
+    /// Drives the control tunnel websocket until it is closed.
+    ///
+    /// Each command received on `cmd_receiver` is written as one JSON line. Each line
+    /// read back is handed to the oldest pending requester; this assumes the remote
+    /// answers in request order. When the command channel closes, a websocket Close frame
+    /// is sent. The loop ends when the websocket stream ends after shutdown started,
+    /// either because of that or because the reader reported a control message on its
+    /// close channel.
+    ///
+    /// Errors if the stream ends unexpectedly, reading/writing fails, or a reply arrives
+    /// with no requester waiting.
+    async fn handle_ctrl_tunnel(
+        websocket: Upgraded,
+        mut cmd_receiver: mpsc::UnboundedReceiver<(Value, oneshot::Sender<String>)>,
+    ) -> Result<(), Error> {
+        let (tunnel_reader, tunnel_writer) = tokio::io::split(websocket);
+        let (ws_close_tx, mut ws_close_rx) = mpsc::unbounded_channel();
+        let ws_reader = WebSocketReader::new(tunnel_reader, ws_close_tx);
+        let mut ws_writer = WebSocketWriter::new(Some([0, 0, 0, 0]), tunnel_writer);
+
+        let mut framed_reader =
+            tokio_util::codec::FramedRead::new(ws_reader, tokio_util::codec::LinesCodec::new());
+
+        let mut resp_tx_queue: VecDeque<oneshot::Sender<String>> = VecDeque::new();
+        let mut shutting_down = false;
+        // once set, `cmd_receiver.recv()` would resolve immediately forever
+        let mut cmd_closed = false;
+
+        loop {
+            let mut close_future = ws_close_rx.recv().boxed().fuse();
+            let mut frame_future = framed_reader.next().boxed().fuse();
+            // a terminated future is skipped by `select!`, avoiding a busy loop
+            let mut cmd_future = if cmd_closed {
+                futures::future::Fuse::terminated()
+            } else {
+                cmd_receiver.recv().boxed().fuse()
+            };
+
+            select! {
+                res = close_future => {
+                    let res = res.ok_or_else(|| format_err!("WS control channel closed"))?;
+                    eprintln!("WS: received control message: '{:?}'", res);
+                    shutting_down = true;
+                },
+                res = frame_future => {
+                    match res {
+                        None if shutting_down => {
+                            eprintln!("WS closed");
+                            break;
+                        },
+                        None => bail!("WS closed unexpectedly"),
+                        Some(Ok(res)) => {
+                            resp_tx_queue
+                                .pop_front()
+                                .ok_or_else(|| format_err!("no response handler"))?
+                                .send(res)
+                                .map_err(|msg| format_err!("failed to send tunnel response '{}' back to requester - receiver already closed?", msg))?;
+                        },
+                        Some(Err(err)) => {
+                            bail!("reading from control tunnel failed - WS receive failed: {}", err);
+                        },
+                    }
+                },
+                res = cmd_future => {
+                    match res {
+                        None => {
+                            cmd_closed = true;
+                            if !shutting_down {
+                                eprintln!("CMD channel closed, shutting down");
+                                ws_writer.send_control_frame(Some([1,2,3,4]), OpCode::Close, &[]).await?;
+                                shutting_down = true;
+                            }
+                        },
+                        // while shutting down, requests are dropped unanswered
+                        Some(_) if shutting_down => {},
+                        Some((msg, resp_tx)) => {
+                            resp_tx_queue.push_back(resp_tx);
+
+                            let line = format!("{}\n", msg);
+                            ws_writer.write_all(line.as_bytes()).await?;
+                            ws_writer.flush().await?;
+                        },
+                    }
+                },
+            };
+        }
+
+        Ok(())
+    }
+
+    /// Forwards one accepted unix connection over a new websocket.
+    ///
+    /// If `data.ticket` is set and a control tunnel is available, a ticket is requested
+    /// first and appended to the websocket URL. Returns once the forwarded connection
+    /// ends.
+    async fn handle_forward_tunnel(
+        cmd_sender: Option<mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>>,
+        data: Arc<ForwardCmdData>,
+        unix: UnixStream,
+    ) -> Result<(), Error> {
+        let data = match &cmd_sender {
+            Some(sender) if data.ticket.is_some() => Self::get_ticket(sender, data).await?,
+            _ => data,
+        };
+
+        let upgraded = Self::websocket_connect(
+            data.url.clone(),
+            data.headers.clone().unwrap_or_default(),
+            data.fingerprint.clone(),
+        )
+        .await?;
+
+        let ws = WebSocket {
+            mask: Some([0, 0, 0, 0]),
+        };
+        eprintln!("established new WS for forwarding '{}'", data.unix);
+        ws.serve_connection(upgraded, unix).await?;
+
+        eprintln!("done handling forwarded connection from '{}'", data.unix);
+
+        Ok(())
+    }
+
+    /// Requests a websocket ticket over the control tunnel.
+    ///
+    /// Sends the `ticket` parameters of `cmd_data` with `cmd: "ticket"` and expects a JSON
+    /// reply with a string `ticket` field. Returns a copy of `cmd_data` whose URL has
+    /// `ticket=<percent-encoded ticket>` appended. If the URL already ends in `?` or `&`
+    /// it is appended directly; otherwise the proper separator is inserted.
+    async fn get_ticket(
+        cmd_sender: &mpsc::UnboundedSender<(Value, oneshot::Sender<String>)>,
+        cmd_data: Arc<ForwardCmdData>,
+    ) -> Result<Arc<ForwardCmdData>, Error> {
+        eprintln!("requesting WS ticket via tunnel");
+        let mut ticket_cmd = match cmd_data.ticket.clone() {
+            Some(ticket_cmd) => ticket_cmd,
+            None => bail!("can't get ticket without ticket parameters"),
+        };
+        ticket_cmd.insert("cmd".to_string(), serde_json::json!("ticket"));
+
+        let (tx, rx) = oneshot::channel::<String>();
+        cmd_sender.send((Value::Object(ticket_cmd), tx))?;
+        let reply = rx.await?;
+        let mut reply: Map<String, Value> = serde_json::from_str(&reply)?;
+        let ticket = reply
+            .remove("ticket")
+            .ok_or_else(|| format_err!("failed to retrieve ticket via tunnel"))?;
+        let ticket = ticket
+            .as_str()
+            .ok_or_else(|| format_err!("failed to format received ticket"))?;
+        let ticket = utf8_percent_encode(ticket, NON_ALPHANUMERIC).to_string();
+
+        let mut url = cmd_data.url.clone();
+        if !url.ends_with('?') && !url.ends_with('&') {
+            url.push(if url.contains('?') { '&' } else { '?' });
+        }
+        url.push_str("ticket=");
+        url.push_str(&ticket);
+
+        let mut data = cmd_data;
+        Arc::make_mut(&mut data).url = url;
+        Ok(data)
+    }
+}
+
+/// Process entry point; runs `do_main` on the runtime set up by `#[tokio::main]`.
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    do_main().await
+}
+
+/// Starts with no control tunnel and processes stdin commands until EOF or error.
+async fn do_main() -> Result<(), Error> {
+    CtrlTunnel { sender: None }.read_cmd_loop().await
+}
--
2.30.2
More information about the pve-devel
mailing list