[pve-devel] [PATCH proxmox-firewall 03/37] config: firewall: add types for ports

Stefan Hanreich s.hanreich at proxmox.com
Tue Apr 2 19:15:55 CEST 2024


Adds types for all kinds of port-related values in the firewall config
as well as FromStr implementations for parsing them from the config.

Also adds a helper for parsing the named ports from `/etc/services`.

Co-authored-by: Wolfgang Bumiller <w.bumiller at proxmox.com>
Signed-off-by: Stefan Hanreich <s.hanreich at proxmox.com>
---
 proxmox-ve-config/src/firewall/mod.rs        |   1 +
 proxmox-ve-config/src/firewall/ports.rs      |  78 ++++++++
 proxmox-ve-config/src/firewall/types/mod.rs  |   1 +
 proxmox-ve-config/src/firewall/types/port.rs | 181 +++++++++++++++++++
 4 files changed, 261 insertions(+)
 create mode 100644 proxmox-ve-config/src/firewall/ports.rs
 create mode 100644 proxmox-ve-config/src/firewall/types/port.rs

diff --git a/proxmox-ve-config/src/firewall/mod.rs b/proxmox-ve-config/src/firewall/mod.rs
index cd40856..a9f65bf 100644
--- a/proxmox-ve-config/src/firewall/mod.rs
+++ b/proxmox-ve-config/src/firewall/mod.rs
@@ -1 +1,2 @@
+pub mod ports;
 pub mod types;
diff --git a/proxmox-ve-config/src/firewall/ports.rs b/proxmox-ve-config/src/firewall/ports.rs
new file mode 100644
index 0000000..96527f1
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/ports.rs
@@ -0,0 +1,78 @@
+use anyhow::{format_err, Error};
+use std::sync::OnceLock;
+
+/// Mapping from service names (and their aliases) to port numbers, as listed in
+/// a services(5)-style file such as `/etc/services`.
+#[derive(Default)]
+struct NamedPorts {
+    ports: std::collections::HashMap<String, u16>,
+}
+
+impl NamedPorts {
+    /// Load the named ports from `/etc/services`.
+    ///
+    /// Loading is best-effort: if the file cannot be opened the table is simply
+    /// empty, so every lookup fails.
+    fn new() -> Self {
+        match std::fs::File::open("/etc/services") {
+            Ok(file) => Self::from_reader(std::io::BufReader::new(file)),
+            Err(_) => Self::default(),
+        }
+    }
+
+    /// Parse services(5)-formatted lines from `reader`.
+    ///
+    /// Every non-empty, non-comment line is expected to look like
+    /// `name port[/protocol] [alias ...] [# comment]`. Lines without a valid
+    /// `u16` port are skipped, and when a name appears more than once the last
+    /// occurrence wins. A read error stops parsing but keeps the entries
+    /// collected so far.
+    fn from_reader<R: std::io::BufRead>(reader: R) -> Self {
+        let mut this = Self::default();
+
+        for line in reader.lines() {
+            let Ok(line) = line else { break };
+            let line = line.trim_start();
+
+            if line.is_empty() || line.starts_with('#') {
+                continue;
+            }
+
+            let mut fields = line.split_ascii_whitespace();
+
+            let Some(name) = fields.next() else { continue };
+
+            // The second field is `port/protocol`; only the port number is kept.
+            let Some(port) = fields
+                .next()
+                .and_then(|field| field.split('/').next())
+                .and_then(|number| number.parse::<u16>().ok())
+            else {
+                continue;
+            };
+
+            this.ports.insert(name.to_string(), port);
+
+            // Any remaining fields are aliases, up to an optional trailing comment.
+            for alias in fields.take_while(|field| !field.starts_with('#')) {
+                this.ports.insert(alias.to_string(), port);
+            }
+        }
+
+        this
+    }
+
+    /// Look up the port number registered for a service name or alias.
+    fn find(&self, name: &str) -> Option<u16> {
+        self.ports.get(name).copied()
+    }
+}
+
+/// Return the process-wide named-port table, loading `/etc/services` the
+/// first time it is needed.
+fn named_ports() -> &'static NamedPorts {
+    static TABLE: OnceLock<NamedPorts> = OnceLock::new();
+    TABLE.get_or_init(NamedPorts::new)
+}
+
+/// Resolve a service name (or one of its aliases) to a port via `/etc/services`.
+///
+/// Fails if the name is not listed, which includes the case where
+/// `/etc/services` could not be opened.
+pub fn parse_named_port(name: &str) -> Result<u16, Error> {
+    match named_ports().find(name) {
+        Some(port) => Ok(port),
+        None => Err(format_err!("unknown port name {name:?}")),
+    }
+}
diff --git a/proxmox-ve-config/src/firewall/types/mod.rs b/proxmox-ve-config/src/firewall/types/mod.rs
index de534b4..b740e5d 100644
--- a/proxmox-ve-config/src/firewall/types/mod.rs
+++ b/proxmox-ve-config/src/firewall/types/mod.rs
@@ -1,3 +1,4 @@
 pub mod address;
+pub mod port;
 
 pub use address::Cidr;
diff --git a/proxmox-ve-config/src/firewall/types/port.rs b/proxmox-ve-config/src/firewall/types/port.rs
new file mode 100644
index 0000000..c1252d9
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/types/port.rs
@@ -0,0 +1,181 @@
+use std::fmt;
+use std::ops::Deref;
+
+use anyhow::{bail, Error};
+use serde_with::DeserializeFromStr;
+
+use crate::firewall::ports::parse_named_port;
+
+/// A single port or an inclusive port range, as found in the firewall config.
+#[derive(Debug, Clone)]
+#[cfg_attr(test, derive(PartialEq, Eq))]
+pub enum PortEntry {
+    /// One port number.
+    Port(u16),
+    /// A range given by its first and last port.
+    Range(u16, u16),
+}
+
+impl fmt::Display for PortEntry {
+    /// Render a single port as its number and a range as `begin-end`.
+    ///
+    /// Note that ranges are written with a dash here, whereas the `FromStr`
+    /// implementation expects them separated by a colon.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match *self {
+            PortEntry::Port(port) => write!(f, "{port}"),
+            PortEntry::Range(begin, end) => write!(f, "{begin}-{end}"),
+        }
+    }
+}
+
+/// Parse one port, given either as a decimal number or as a service name that
+/// is resolved through `/etc/services`.
+fn parse_port(port: &str) -> Result<u16, Error> {
+    // The name lookup only runs when the numeric parse fails.
+    let resolved = port
+        .parse::<u16>()
+        .ok()
+        .or_else(|| parse_named_port(port).ok());
+
+    match resolved {
+        Some(number) => Ok(number),
+        None => bail!("invalid port specification: {port}"),
+    }
+}
+
+impl std::str::FromStr for PortEntry {
+    type Err = Error;
+
+    /// Parse a single port (`22`, `ssh`) or an inclusive range (`1024:2048`).
+    ///
+    /// Whitespace around the whole entry and around either side of the `:` is
+    /// ignored; each side may be a number or a service name.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        // Trim once up front so both branches see the same input; otherwise a
+        // single port like `" 22"` would reach `parse_port` untrimmed and be
+        // rejected, which also breaks lists such as `"22, 80"`.
+        let s = s.trim();
+
+        Ok(match s.split_once(':') {
+            None => PortEntry::from(parse_port(s)?),
+            Some((first, second)) => {
+                let bounds = (parse_port(first.trim())?, parse_port(second.trim())?);
+                PortEntry::try_from(bounds)?
+            }
+        })
+    }
+}
+
+/// A bare port number converts into [`PortEntry::Port`].
+impl From<u16> for PortEntry {
+    fn from(number: u16) -> Self {
+        Self::Port(number)
+    }
+}
+
+/// Build a [`PortEntry::Range`] from `(begin, end)`.
+///
+/// A range whose start lies after its end is rejected; equal bounds are allowed.
+impl TryFrom<(u16, u16)> for PortEntry {
+    type Error = Error;
+
+    fn try_from((begin, end): (u16, u16)) -> Result<Self, Error> {
+        if begin <= end {
+            Ok(PortEntry::Range(begin, end))
+        } else {
+            bail!("start port is greater than end port!")
+        }
+    }
+}
+
+/// A list of [`PortEntry`] values.
+///
+/// Deserialized from the same comma-separated string form that its `FromStr`
+/// implementation accepts.
+#[derive(Debug, Clone, DeserializeFromStr)]
+#[cfg_attr(test, derive(PartialEq, Eq))]
+pub struct PortList(pub(crate) Vec<PortEntry>);
+
+/// Collect port entries into a list, e.g. `entries.into_iter().collect::<PortList>()`.
+impl FromIterator<PortEntry> for PortList {
+    fn from_iter<I: IntoIterator<Item = PortEntry>>(entries: I) -> Self {
+        let list: Vec<PortEntry> = entries.into_iter().collect();
+        Self(list)
+    }
+}
+
+/// Anything convertible into a single [`PortEntry`] becomes a one-element list.
+impl<T: Into<PortEntry>> From<T> for PortList {
+    fn from(value: T) -> Self {
+        let entry: PortEntry = value.into();
+        Self(vec![entry])
+    }
+}
+
+/// Give read-only access to the underlying entries, so a `PortList` can be
+/// iterated and indexed like a `Vec<PortEntry>`.
+impl Deref for PortList {
+    type Target = Vec<PortEntry>;
+
+    fn deref(&self) -> &Vec<PortEntry> {
+        let PortList(entries) = self;
+        entries
+    }
+}
+
+impl std::str::FromStr for PortList {
+    type Err = Error;
+
+    /// Parse a comma-separated list of port entries, e.g. `22,1024:2048,https`.
+    ///
+    /// Each entry is parsed as a [`PortEntry`]. An empty or whitespace-only
+    /// string is rejected, as is any empty entry (such as a trailing comma).
+    fn from_str(s: &str) -> Result<Self, Error> {
+        let s = s.trim();
+
+        // Check the trimmed input so whitespace-only strings get the same error
+        // as empty ones. `split` always yields at least one item, so a
+        // successful parse below can never produce an empty list.
+        if s.is_empty() {
+            bail!("empty port specification");
+        }
+
+        s.split(',').map(str::parse::<PortEntry>).collect()
+    }
+}
+
+impl fmt::Display for PortList {
+    /// Render a single entry bare and several entries as a braced,
+    /// comma-separated set such as `{22,80-90}`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let braced = self.0.len() > 1;
+
+        if braced {
+            f.write_str("{")?;
+        }
+
+        for (index, entry) in self.0.iter().enumerate() {
+            if index > 0 {
+                f.write_str(",")?;
+            }
+            // Forward the formatter so any caller-supplied flags reach each entry.
+            fmt::Display::fmt(entry, f)?;
+        }
+
+        if braced {
+            f.write_str("}")?;
+        }
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // NOTE: named ports (`ssh`, `https`) are resolved through `/etc/services`,
+    // so these tests assume a host whose services file lists them.
+
+    #[test]
+    fn test_parse_port_entry() {
+        let mut port_entry: PortEntry = "12345".parse().expect("valid port entry");
+        assert_eq!(port_entry, PortEntry::from(12345));
+
+        port_entry = "0:65535".parse().expect("valid port entry");
+        assert_eq!(port_entry, PortEntry::try_from((0, 65535)).unwrap());
+
+        "65536".parse::<PortEntry>().unwrap_err();
+        "100:100000".parse::<PortEntry>().unwrap_err();
+        "qweasd".parse::<PortEntry>().unwrap_err();
+        "".parse::<PortEntry>().unwrap_err();
+    }
+
+    #[test]
+    fn test_parse_port_range_bounds() {
+        // A range may start and end on the same port ...
+        let port_entry: PortEntry = "22:22".parse().expect("valid port entry");
+        assert_eq!(port_entry, PortEntry::Range(22, 22));
+
+        // ... but must not be reversed.
+        "80:22".parse::<PortEntry>().unwrap_err();
+        PortEntry::try_from((80, 22)).unwrap_err();
+    }
+
+    #[test]
+    fn test_parse_port_list() {
+        let mut port_list: PortList = "12345".parse().expect("valid port list");
+        assert_eq!(port_list, PortList::from(12345));
+
+        port_list = "12345,0:65535,1337,ssh:80,https"
+            .parse()
+            .expect("valid port list");
+
+        assert_eq!(
+            port_list,
+            PortList(vec![
+                PortEntry::from(12345),
+                PortEntry::try_from((0, 65535)).unwrap(),
+                PortEntry::from(1337),
+                PortEntry::try_from((22, 80)).unwrap(),
+                PortEntry::from(443),
+            ])
+        );
+
+        "0::1337".parse::<PortList>().unwrap_err();
+        "0:1337,".parse::<PortList>().unwrap_err();
+        "70000".parse::<PortList>().unwrap_err();
+        "qweasd".parse::<PortList>().unwrap_err();
+        "".parse::<PortList>().unwrap_err();
+    }
+
+    #[test]
+    fn test_display() {
+        assert_eq!(PortEntry::from(22).to_string(), "22");
+        assert_eq!(PortEntry::try_from((80, 90)).unwrap().to_string(), "80-90");
+
+        // A single entry is printed bare, several entries as a braced set.
+        assert_eq!(PortList::from(22).to_string(), "22");
+        let port_list = PortList(vec![
+            PortEntry::from(22),
+            PortEntry::try_from((80, 90)).unwrap(),
+        ]);
+        assert_eq!(port_list.to_string(), "{22,80-90}");
+    }
+}
-- 
2.39.2




More information about the pve-devel mailing list