[pve-devel] [PATCH proxmox-firewall 09/37] config: firewall: add types for rules

Stefan Hanreich s.hanreich at proxmox.com
Tue Apr 2 19:16:01 CEST 2024


Additionally, we implement FromStr for all rule types and their parts, which
can be used to parse firewall config rules. Rule parsing works by first
collecting the different options into a HashMap and only then deserializing
a struct from the parsed options.

This intermediate step makes rule parsing a lot easier, since we can
reuse serde's deserialization logic. It also lets us separate the
parsing/deserialization logic from the validation logic.

Co-authored-by: Wolfgang Bumiller <w.bumiller at proxmox.com>
Signed-off-by: Stefan Hanreich <s.hanreich at proxmox.com>
---
 proxmox-ve-config/src/firewall/parse.rs       | 185 ++++
 proxmox-ve-config/src/firewall/types/mod.rs   |   3 +
 proxmox-ve-config/src/firewall/types/rule.rs  | 412 ++++++++
 .../src/firewall/types/rule_match.rs          | 953 ++++++++++++++++++
 4 files changed, 1553 insertions(+)
 create mode 100644 proxmox-ve-config/src/firewall/types/rule.rs
 create mode 100644 proxmox-ve-config/src/firewall/types/rule_match.rs

diff --git a/proxmox-ve-config/src/firewall/parse.rs b/proxmox-ve-config/src/firewall/parse.rs
index 669623b..227e045 100644
--- a/proxmox-ve-config/src/firewall/parse.rs
+++ b/proxmox-ve-config/src/firewall/parse.rs
@@ -1,3 +1,5 @@
+use std::fmt;
+
 use anyhow::{bail, format_err, Error};
 
 /// Parses out a "name" which can be alphanumeric and include dashes.
@@ -78,3 +80,186 @@ pub fn parse_bool(value: &str) -> Result<bool, Error> {
         },
     )
 }
+
+/// `&str` deserializer which also accepts an `Option`.
+///
+/// Serde's `StrDeserializer` does not. This wrapper answers `deserialize_option` with
+/// `visit_some`, so a present value deserializes into an `Option<T>` target as
+/// `Some(value)`; all other requests are delegated to the wrapped serde deserializer.
+#[derive(Clone, Copy, Debug)]
+pub struct SomeStrDeserializer<'a, E>(serde::de::value::StrDeserializer<'a, E>);
+
+impl<'de, 'a, E> serde::de::Deserializer<'de> for SomeStrDeserializer<'a, E>
+where
+    E: serde::de::Error,
+{
+    type Error = E;
+
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.0.deserialize_any(visitor)
+    }
+
+    // The reason this wrapper exists: a value that is present is always `Some(..)`.
+    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_some(self.0)
+    }
+
+    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.0.deserialize_str(visitor)
+    }
+
+    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        self.0.deserialize_string(visitor)
+    }
+
+    // Hand the wrapped string deserializer to the visitor as its `EnumAccess`, so enum
+    // values can be deserialized from the string.
+    fn deserialize_enum<V>(
+        self,
+        _name: &str,
+        _variants: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        visitor.visit_enum(self.0)
+    }
+
+    serde::forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char
+        bytes byte_buf unit unit_struct newtype_struct seq tuple
+        tuple_struct map struct identifier ignored_any
+    }
+}
+
+/// Borrowed string wrapper whose `IntoDeserializer` impl produces a
+/// [`SomeStrDeserializer`], so that `Option` targets receive `Some(..)` for a present value.
+#[derive(Clone, Debug)]
+pub struct SomeStr<'a>(pub &'a str);
+
+impl<'a> From<&'a str> for SomeStr<'a> {
+    fn from(value: &'a str) -> Self {
+        SomeStr(value)
+    }
+}
+
+impl<'de, 'a, E> serde::de::IntoDeserializer<'de, E> for SomeStr<'a>
+where
+    E: serde::de::Error,
+{
+    type Deserializer = SomeStrDeserializer<'a, E>;
+
+    fn into_deserializer(self) -> Self::Deserializer {
+        let SomeStr(text) = self;
+        let inner: serde::de::value::StrDeserializer<'a, E> = text.into_deserializer();
+        SomeStrDeserializer(inner)
+    }
+}
+
+/// Owned `String` deserializer which also accepts an `Option`.
+///
+/// Serde's `StringDeserializer` does not. Requests for an `Option` are answered with
+/// `visit_some`; every other request goes to the wrapped serde deserializer.
+#[derive(Clone, Debug)]
+pub struct SomeStringDeserializer<E>(serde::de::value::StringDeserializer<E>);
+
+impl<'de, E> serde::de::Deserializer<'de> for SomeStringDeserializer<E>
+where
+    E: serde::de::Error,
+{
+    type Error = E;
+
+    // Everything without a dedicated method below falls back to `deserialize_any`.
+    serde::forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char
+        bytes byte_buf unit unit_struct newtype_struct seq tuple
+        tuple_struct map struct identifier ignored_any
+    }
+
+    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let Self(inner) = self;
+        inner.deserialize_any(visitor)
+    }
+
+    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // A value that is present is never `None`.
+        let Self(inner) = self;
+        visitor.visit_some(inner)
+    }
+
+    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let Self(inner) = self;
+        inner.deserialize_str(visitor)
+    }
+
+    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        let Self(inner) = self;
+        inner.deserialize_string(visitor)
+    }
+
+    fn deserialize_enum<V>(
+        self,
+        _name: &str,
+        _variants: &'static [&'static str],
+        visitor: V,
+    ) -> Result<V::Value, Self::Error>
+    where
+        V: serde::de::Visitor<'de>,
+    {
+        // The wrapped deserializer doubles as the `EnumAccess` for the visitor.
+        let Self(inner) = self;
+        visitor.visit_enum(inner)
+    }
+}
+
+/// Owned `String` wrapper which implements `IntoDeserializer` via `SomeStringDeserializer`.
+///
+/// Can be built from either a `&str` (copied) or a `String` (moved).
+#[derive(Clone, Debug)]
+pub struct SomeString(pub String);
+
+impl From<&str> for SomeString {
+    fn from(s: &str) -> Self {
+        Self::from(s.to_string())
+    }
+}
+
+impl From<String> for SomeString {
+    fn from(s: String) -> Self {
+        Self(s)
+    }
+}
+
+impl<'de, E> serde::de::IntoDeserializer<'de, E> for SomeString
+where
+    E: serde::de::Error,
+{
+    type Deserializer = SomeStringDeserializer<E>;
+
+    fn into_deserializer(self) -> Self::Deserializer {
+        SomeStringDeserializer(self.0.into_deserializer())
+    }
+}
+
+#[derive(Debug)]
+pub struct SerdeStringError(String);
+
+impl std::error::Error for SerdeStringError {}
+
+impl fmt::Display for SerdeStringError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+impl serde::de::Error for SerdeStringError {
+    fn custom<T: fmt::Display>(msg: T) -> Self {
+        Self(msg.to_string())
+    }
+}
diff --git a/proxmox-ve-config/src/firewall/types/mod.rs b/proxmox-ve-config/src/firewall/types/mod.rs
index 5833787..b4a6b12 100644
--- a/proxmox-ve-config/src/firewall/types/mod.rs
+++ b/proxmox-ve-config/src/firewall/types/mod.rs
@@ -3,7 +3,10 @@ pub mod alias;
 pub mod ipset;
 pub mod log;
 pub mod port;
+pub mod rule;
+pub mod rule_match;
 
 pub use address::Cidr;
 pub use alias::Alias;
 pub use ipset::Ipset;
+pub use rule::Rule;
diff --git a/proxmox-ve-config/src/firewall/types/rule.rs b/proxmox-ve-config/src/firewall/types/rule.rs
new file mode 100644
index 0000000..20deb3a
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/types/rule.rs
@@ -0,0 +1,412 @@
+use core::fmt::Display;
+use std::fmt;
+use std::str::FromStr;
+
+use anyhow::{bail, ensure, format_err, Error};
+
+use crate::firewall::parse::match_name;
+use crate::firewall::types::rule_match::RuleMatch;
+use crate::firewall::types::rule_match::RuleOptions;
+
+/// Direction of the traffic a rule applies to.
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+pub enum Direction {
+    #[default]
+    In,
+    Out,
+}
+
+impl std::str::FromStr for Direction {
+    type Err = Error;
+
+    /// Parses `IN` or `OUT`, ignoring ASCII case.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        for (name, dir) in [("IN", Direction::In), ("OUT", Direction::Out)] {
+            if s.eq_ignore_ascii_case(name) {
+                return Ok(dir);
+            }
+        }
+
+        bail!("invalid direction: {s:?}, expected 'IN' or 'OUT'");
+    }
+}
+
+serde_plain::derive_deserialize_from_fromstr!(Direction, "valid packet direction");
+
+/// Formats the direction in lowercase (`in` / `out`).
+impl fmt::Display for Direction {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Direction::In => f.write_str("in"),
+            Direction::Out => f.write_str("out"),
+        }
+    }
+}
+
+/// Action taken for packets that match a rule.
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
+pub enum Verdict {
+    Accept,
+    Reject,
+    #[default]
+    Drop,
+}
+
+impl std::str::FromStr for Verdict {
+    type Err = Error;
+
+    /// Parses `ACCEPT`, `REJECT` or `DROP`, ignoring ASCII case.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        if s.eq_ignore_ascii_case("ACCEPT") {
+            Ok(Verdict::Accept)
+        } else if s.eq_ignore_ascii_case("REJECT") {
+            Ok(Verdict::Reject)
+        } else if s.eq_ignore_ascii_case("DROP") {
+            Ok(Verdict::Drop)
+        } else {
+            bail!("invalid verdict {s:?}, expected one of 'ACCEPT', 'REJECT' or 'DROP'")
+        }
+    }
+}
+
+/// Formats the verdict in uppercase, as written in the config.
+impl Display for Verdict {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(match self {
+            Verdict::Accept => "ACCEPT",
+            Verdict::Reject => "REJECT",
+            Verdict::Drop => "DROP",
+        })
+    }
+}
+
+serde_plain::derive_deserialize_from_fromstr!(Verdict, "valid verdict");
+
+/// A single firewall rule, as parsed from one config line.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Rule {
+    /// Whether the rule is disabled (the line starts with `|`).
+    pub(crate) disabled: bool,
+    /// Group reference or match rule.
+    pub(crate) kind: Kind,
+    /// Trailing comment after `#`, if any.
+    pub(crate) comment: Option<String>,
+}
+
+/// Gives direct access to the rule's [`Kind`], e.g. `rule.is_group()`.
+impl std::ops::Deref for Rule {
+    type Target = Kind;
+
+    fn deref(&self) -> &Self::Target {
+        &self.kind
+    }
+}
+
+impl std::ops::DerefMut for Rule {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.kind
+    }
+}
+
+impl FromStr for Rule {
+    type Err = Error;
+
+    /// Parses a single firewall rule line.
+    ///
+    /// Syntax: `[|]<rule> [# <comment>]`. A leading `|` marks the rule as disabled, and
+    /// `<rule>` is either a `GROUP ...` rule or a direction/verdict match rule.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the input contains a line break or the rule part cannot be parsed.
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        if input.contains(['\n', '\r']) {
+            bail!("rule must not contain any newlines!");
+        }
+
+        // Everything after the *first* `#` is the comment, so comments may themselves
+        // contain `#`. A bare `#` (empty comment) is stripped and yields no comment.
+        let (line, comment) = match input.split_once('#') {
+            Some((line, comment)) => {
+                let comment = comment.trim();
+                (line.trim(), Some(comment).filter(|c| !c.is_empty()))
+            }
+            None => (input.trim(), None),
+        };
+
+        let (disabled, line) = match line.strip_prefix('|') {
+            Some(line) => (true, line.trim_start()),
+            None => (false, line),
+        };
+
+        // todo: case insensitive?
+        let kind = if line.starts_with("GROUP") {
+            Kind::from(line.parse::<RuleGroup>()?)
+        } else {
+            Kind::from(line.parse::<RuleMatch>()?)
+        };
+
+        Ok(Self {
+            disabled,
+            comment: comment.map(str::to_string),
+            kind,
+        })
+    }
+}
+
+impl Rule {
+    /// Interface the rule is restricted to, if any.
+    pub fn iface(&self) -> Option<&str> {
+        match self.kind() {
+            Kind::Match(rule_match) => rule_match.iface(),
+            Kind::Group(rule_group) => rule_group.iface(),
+        }
+    }
+
+    /// Whether the rule is disabled.
+    pub fn disabled(&self) -> bool {
+        self.disabled
+    }
+
+    /// Group reference or match rule.
+    pub fn kind(&self) -> &Kind {
+        &self.kind
+    }
+
+    /// Comment attached to the rule, if any.
+    pub fn comment(&self) -> Option<&str> {
+        self.comment.as_deref()
+    }
+}
+
+/// The two kinds of firewall rules.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum Kind {
+    /// Reference to a security group (`GROUP <name>`).
+    Group(RuleGroup),
+    /// Regular match rule (`IN`/`OUT` with a verdict).
+    Match(RuleMatch),
+}
+
+impl Kind {
+    /// Returns `true` for a security group rule.
+    pub fn is_group(&self) -> bool {
+        matches!(self, Self::Group(..))
+    }
+
+    /// Returns `true` for a match rule.
+    pub fn is_match(&self) -> bool {
+        matches!(self, Self::Match(..))
+    }
+}
+
+impl From<RuleGroup> for Kind {
+    fn from(group: RuleGroup) -> Self {
+        Self::Group(group)
+    }
+}
+
+impl From<RuleMatch> for Kind {
+    fn from(rule_match: RuleMatch) -> Self {
+        Self::Match(rule_match)
+    }
+}
+
+/// A rule that references a security group, optionally restricted to an interface.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct RuleGroup {
+    pub(crate) group: String,
+    pub(crate) iface: Option<String>,
+}
+
+impl RuleGroup {
+    /// Builds a group rule from parsed options.
+    ///
+    /// # Errors
+    ///
+    /// Fails if any option other than the interface is set.
+    pub(crate) fn from_options(group: String, options: RuleOptions) -> Result<Self, Error> {
+        // Exhaustive destructuring: a new field in `RuleOptions` forces a decision here.
+        let RuleOptions {
+            proto,
+            dport,
+            sport,
+            dest,
+            source,
+            iface,
+            log,
+            icmp_type,
+        } = options;
+
+        ensure!(
+            proto.is_none()
+                && dport.is_none()
+                && sport.is_none()
+                && dest.is_none()
+                && source.is_none()
+                && log.is_none()
+                && icmp_type.is_none(),
+            "only interface parameter is permitted for group rules"
+        );
+
+        Ok(Self { group, iface })
+    }
+
+    /// Name of the referenced security group.
+    pub fn group(&self) -> &str {
+        &self.group
+    }
+
+    /// Interface the group rule is restricted to, if any.
+    pub fn iface(&self) -> Option<&str> {
+        self.iface.as_deref()
+    }
+}
+
+impl FromStr for RuleGroup {
+    type Err = Error;
+
+    /// Parses `GROUP <name> [options...]`; the keyword is matched case-insensitively.
+    fn from_str(input: &str) -> Result<Self, Self::Err> {
+        let (keyword, after_keyword) = match match_name(input) {
+            Some(parts) => parts,
+            None => bail!("expected a leading keyword in rule group"),
+        };
+
+        if !keyword.eq_ignore_ascii_case("group") {
+            bail!("Expected keyword GROUP")
+        }
+
+        let (name, after_name) = match match_name(after_keyword.trim()) {
+            Some(parts) => parts,
+            None => bail!("expected a name for rule group"),
+        };
+
+        let options: RuleOptions = after_name.trim_start().parse()?;
+
+        Self::from_options(name.to_owned(), options)
+    }
+}
+
+// Parsing tests for `Rule`: valid rules of both kinds and malformed option lists.
+#[cfg(test)]
+mod tests {
+    use crate::firewall::types::{
+        address::{IpEntry, IpList},
+        alias::{AliasName, AliasScope},
+        ipset::{IpsetName, IpsetScope},
+        log::LogLevel,
+        rule_match::{Icmp, IcmpCode, IpAddrMatch, IpMatch, Ports, Protocol, Udp},
+        Cidr,
+    };
+
+    use super::*;
+
+    #[test]
+    fn test_parse_rule() {
+        // disabled group rule with an interface option and a trailing comment
+        let mut rule: Rule = "|GROUP tgr -i eth0 # acomm".parse().expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: true,
+                comment: Some("acomm".to_string()),
+                kind: Kind::Group(RuleGroup {
+                    group: "tgr".to_string(),
+                    iface: Some("eth0".to_string()),
+                }),
+            },
+        );
+
+        // match rule with protocol, source/destination ports and log level
+        rule = "IN ACCEPT -p udp -dport 33 -sport 22 -log warning"
+            .parse()
+            .expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::In,
+                    verdict: Verdict::Accept,
+                    proto: Some(Udp::new(Ports::from_u16(22, 33)).into()),
+                    log: Some(LogLevel::Warning),
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // double-dash option name and the `-i` alias for `iface`
+        rule = "IN ACCEPT --proto udp -i eth0".parse().expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::In,
+                    verdict: Verdict::Accept,
+                    proto: Some(Udp::new(Ports::new(None, None)).into()),
+                    iface: Some("eth0".to_string()),
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // surrounding whitespace, source/dest IP lists, and `port-unreachable`, which is
+        // not an ICMP type name and therefore ends up as an ICMP code
+        rule = " OUT DROP \
+          -source 10.0.0.0/24 -dest 20.0.0.0-20.255.255.255,192.168.0.0/16 \
+          -p icmp -log nolog -icmp-type port-unreachable "
+            .parse()
+            .expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::Out,
+                    verdict: Verdict::Drop,
+                    ip: IpMatch::new(
+                        IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], 24).unwrap())),
+                        IpAddrMatch::Ip(
+                            IpList::new(vec![
+                                IpEntry::Range([20, 0, 0, 0].into(), [20, 255, 255, 255].into()),
+                                IpEntry::Cidr(Cidr::new_v4([192, 168, 0, 0], 16).unwrap()),
+                            ])
+                            .unwrap()
+                        ),
+                    )
+                    .ok(),
+                    proto: Some(Protocol::Icmp(Icmp::new_code(IcmpCode::Named(
+                        "port-unreachable"
+                    )))),
+                    log: Some(LogLevel::Nolog),
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // macro invocation as the action
+        rule = "IN BGP(ACCEPT) --log crit --iface eth0"
+            .parse()
+            .expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::In,
+                    verdict: Verdict::Accept,
+                    log: Some(LogLevel::Critical),
+                    fw_macro: Some("BGP".to_string()),
+                    iface: Some("eth0".to_string()),
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // alias reference as source, ipset reference (`+` prefix) as destination
+        rule = "IN ACCEPT --source dc/test --dest +dc/test"
+            .parse()
+            .expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::In,
+                    verdict: Verdict::Accept,
+                    ip: Some(
+                        IpMatch::new(
+                            IpAddrMatch::Alias(AliasName::new(AliasScope::Datacenter, "test")),
+                            IpAddrMatch::Set(IpsetName::new(IpsetScope::Datacenter, "test"),),
+                        )
+                        .unwrap()
+                    ),
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // minimal rule without any options
+        rule = "IN REJECT".parse().expect("valid rule");
+
+        assert_eq!(
+            rule,
+            Rule {
+                disabled: false,
+                comment: None,
+                kind: Kind::Match(RuleMatch {
+                    dir: Direction::In,
+                    verdict: Verdict::Reject,
+                    ..Default::default()
+                }),
+            }
+        );
+
+        // malformed option lists must be rejected
+        "IN DROP ---log crit"
+            .parse::<Rule>()
+            .expect_err("too many dashes in option");
+
+        "IN DROP --log --iface eth0"
+            .parse::<Rule>()
+            .expect_err("no value for option");
+
+        "IN DROP --log crit --iface"
+            .parse::<Rule>()
+            .expect_err("no value for option");
+    }
+}
diff --git a/proxmox-ve-config/src/firewall/types/rule_match.rs b/proxmox-ve-config/src/firewall/types/rule_match.rs
new file mode 100644
index 0000000..ae5345c
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/types/rule_match.rs
@@ -0,0 +1,953 @@
+use std::collections::HashMap;
+use std::fmt;
+use std::str::FromStr;
+
+use serde::Deserialize;
+
+use anyhow::{bail, format_err, Error};
+use serde::de::IntoDeserializer;
+
+use crate::firewall::parse::{match_name, match_non_whitespace, SomeStr};
+use crate::firewall::types::address::{Family, IpList};
+use crate::firewall::types::alias::AliasName;
+use crate::firewall::types::ipset::IpsetName;
+use crate::firewall::types::log::LogLevel;
+use crate::firewall::types::port::PortList;
+use crate::firewall::types::rule::{Direction, Verdict};
+
+/// Raw `-name value` options of a rule line, before rule-specific validation.
+///
+/// Values are kept as unparsed strings (except `log`); the rule types interpret them in
+/// their own `from_options` constructors.
+#[derive(Clone, Debug, Default, Deserialize)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+#[serde(deny_unknown_fields, rename_all = "kebab-case")]
+pub(crate) struct RuleOptions {
+    #[serde(alias = "p")]
+    pub(crate) proto: Option<String>,
+
+    pub(crate) dport: Option<String>,
+    pub(crate) sport: Option<String>,
+
+    pub(crate) dest: Option<String>,
+    pub(crate) source: Option<String>,
+
+    #[serde(alias = "i")]
+    pub(crate) iface: Option<String>,
+
+    pub(crate) log: Option<LogLevel>,
+    pub(crate) icmp_type: Option<String>,
+}
+
+impl FromStr for RuleOptions {
+    type Err = Error;
+
+    /// Parses whitespace-separated `-name value` pairs; the name may use one or two dashes.
+    ///
+    /// Each option name may appear only once. Unknown names are rejected during
+    /// deserialization (`deny_unknown_fields`).
+    fn from_str(line: &str) -> Result<Self, Self::Err> {
+        let mut options = HashMap::new();
+        let mut rest = line.trim_start();
+
+        while !rest.is_empty() {
+            let after_dash = rest
+                .strip_prefix('-')
+                .ok_or_else(|| format_err!("expected an option starting with '-'"))?;
+            // a second dash is optional: `-log` and `--log` are equivalent
+            let after_dashes = after_dash.strip_prefix('-').unwrap_or(after_dash);
+
+            let (param, after_param) = match_name(after_dashes)
+                .ok_or_else(|| format_err!("expected a parameter name after '-'"))?;
+
+            let (value, after_value) = match_non_whitespace(after_param.trim_start())
+                .ok_or_else(|| format_err!("expected a value for {param:?}"))?;
+
+            if options.insert(param, SomeStr(value)).is_some() {
+                bail!("duplicate option in rule: {param}")
+            }
+
+            rest = after_value.trim_start();
+        }
+
+        let deserializer =
+            IntoDeserializer::<'_, crate::firewall::parse::SerdeStringError>::into_deserializer(
+                options,
+            );
+
+        Ok(RuleOptions::deserialize(deserializer)?)
+    }
+}
+
+/// A rule matching on direction, interface, addresses and protocol, with a verdict that
+/// may be applied through a macro.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct RuleMatch {
+    pub(crate) dir: Direction,
+    pub(crate) verdict: Verdict,
+    pub(crate) fw_macro: Option<String>,
+
+    pub(crate) iface: Option<String>,
+    pub(crate) log: Option<LogLevel>,
+    pub(crate) ip: Option<IpMatch>,
+    pub(crate) proto: Option<Protocol>,
+}
+
+impl RuleMatch {
+    /// Builds a match rule from already-parsed options.
+    ///
+    /// # Errors
+    ///
+    /// Fails if both `dport` and `icmp-type` are set, or if the address or protocol
+    /// options are invalid.
+    pub(crate) fn from_options(
+        dir: Direction,
+        verdict: Verdict,
+        fw_macro: impl Into<Option<String>>,
+        options: RuleOptions,
+    ) -> Result<Self, Error> {
+        if options.icmp_type.is_some() && options.dport.is_some() {
+            bail!("dport and icmp-type are mutually exclusive");
+        }
+
+        // todo: check protocol & IP Version compatibility
+
+        // `ip` and `proto` only borrow `options`, so they are evaluated before
+        // `iface`/`log` are moved out of it.
+        Ok(Self {
+            ip: IpMatch::from_options(&options)?,
+            proto: Protocol::from_options(&options)?,
+            dir,
+            verdict,
+            fw_macro: fw_macro.into(),
+            iface: options.iface,
+            log: options.log,
+        })
+    }
+
+    /// Direction of traffic the rule applies to.
+    pub fn direction(&self) -> Direction {
+        self.dir
+    }
+
+    /// Interface the rule is restricted to, if any.
+    pub fn iface(&self) -> Option<&str> {
+        self.iface.as_deref()
+    }
+
+    /// Verdict for matching packets.
+    pub fn verdict(&self) -> Verdict {
+        self.verdict
+    }
+
+    /// Name of the macro the verdict is applied through, if any.
+    pub fn fw_macro(&self) -> Option<&str> {
+        self.fw_macro.as_deref()
+    }
+
+    /// Log level, if set.
+    pub fn log(&self) -> Option<LogLevel> {
+        self.log
+    }
+
+    /// Source/destination address match, if any.
+    pub fn ip(&self) -> Option<&IpMatch> {
+        self.ip.as_ref()
+    }
+
+    /// Protocol match, if any.
+    pub fn proto(&self) -> Option<&Protocol> {
+        self.proto.as_ref()
+    }
+}
+
+/// Splits the action off the front of a rule line.
+///
+/// Accepts a plain verdict (`ACCEPT`) or a macro invocation (`SSH(ACCEPT)`). Returns
+/// `(macro name, verdict, rest of the line)`, with leading whitespace trimmed from the rest.
+fn parse_action(line: &str) -> Result<(Option<&str>, Verdict, &str), Error> {
+    let (name, after_name) =
+        match_name(line).ok_or_else(|| format_err!("expected a verdict or macro name"))?;
+
+    match after_name.strip_prefix('(') {
+        None => {
+            // no parentheses: the name itself is the verdict
+            let verdict: Verdict = name.parse()?;
+            Ok((None, verdict, after_name.trim_start()))
+        }
+        Some(inside_parens) => {
+            // `<macro>(<verdict>)`
+            let (verdict_str, after_verdict) =
+                match_name(inside_parens).ok_or_else(|| format_err!("expected a verdict"))?;
+            let after_paren = after_verdict
+                .strip_prefix(')')
+                .ok_or_else(|| format_err!("expected closing ')' after verdict"))?;
+
+            let verdict: Verdict = verdict_str.parse()?;
+
+            Ok((Some(name), verdict, after_paren.trim_start()))
+        }
+    }
+}
+
+impl FromStr for RuleMatch {
+    type Err = Error;
+
+    /// Parses `<DIR> <action> [options...]`, e.g. `IN ACCEPT -p tcp -dport 22`.
+    fn from_str(line: &str) -> Result<Self, Self::Err> {
+        let (dir_str, after_dir) =
+            match_name(line).ok_or_else(|| format_err!("expected a direction"))?;
+        let direction = dir_str.parse::<Direction>()?;
+
+        let (fw_macro, verdict, after_action) = parse_action(after_dir.trim_start())?;
+        let options = after_action.trim_start().parse::<RuleOptions>()?;
+
+        let fw_macro: Option<String> = fw_macro.map(String::from);
+        Self::from_options(direction, verdict, fw_macro, options)
+    }
+}
+
+/// Source and/or destination address match of a rule.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct IpMatch {
+    pub(crate) src: Option<IpAddrMatch>,
+    pub(crate) dst: Option<IpAddrMatch>,
+}
+
+impl IpMatch {
+    /// Creates a new address match.
+    ///
+    /// # Errors
+    ///
+    /// Fails if neither `src` nor `dst` is set, or if both have a known address family
+    /// (see `IpAddrMatch::family`) and the families differ.
+    pub fn new(
+        src: impl Into<Option<IpAddrMatch>>,
+        dst: impl Into<Option<IpAddrMatch>>,
+    ) -> Result<Self, Error> {
+        let source = src.into();
+        let dest = dst.into();
+
+        if source.is_none() && dest.is_none() {
+            bail!("either src or dst must be set")
+        }
+
+        if let (Some(src), Some(dst)) = (&source, &dest) {
+            if src.family() != dst.family() {
+                bail!("src and dst family must be equal")
+            }
+        }
+
+        let ip_match = Self {
+            src: source,
+            dst: dest,
+        };
+
+        Ok(ip_match)
+    }
+
+    /// Builds the address match from the `-source` / `-dest` rule options.
+    ///
+    /// Returns `Ok(None)` if neither option is set.
+    ///
+    /// # Errors
+    ///
+    /// Fails if a value cannot be parsed, or if [`IpMatch::new`] rejects the combination
+    /// (e.g. source and destination of different address families). Such errors must not
+    /// be turned into `None`: that would silently drop the address restriction and make
+    /// the rule match any address.
+    fn from_options(options: &RuleOptions) -> Result<Option<Self>, Error> {
+        let src = options
+            .source
+            .as_ref()
+            .map(|elem| elem.parse::<IpAddrMatch>())
+            .transpose()?;
+
+        let dst = options
+            .dest
+            .as_ref()
+            .map(|elem| elem.parse::<IpAddrMatch>())
+            .transpose()?;
+
+        if src.is_none() && dst.is_none() {
+            return Ok(None);
+        }
+
+        IpMatch::new(src, dst).map(Some)
+    }
+
+    /// Source address match, if any.
+    pub fn src(&self) -> Option<&IpAddrMatch> {
+        self.src.as_ref()
+    }
+
+    /// Destination address match, if any.
+    pub fn dst(&self) -> Option<&IpAddrMatch> {
+        self.dst.as_ref()
+    }
+}
+
+/// One side (source or destination) of an address match.
+#[derive(Clone, Debug, Deserialize)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum IpAddrMatch {
+    Ip(IpList),
+    Set(IpsetName),
+    Alias(AliasName),
+}
+
+impl IpAddrMatch {
+    /// Address family of a literal IP list; `None` for ipset and alias references.
+    pub fn family(&self) -> Option<Family> {
+        match self {
+            IpAddrMatch::Ip(list) => Some(list.family()),
+            IpAddrMatch::Set(_) | IpAddrMatch::Alias(_) => None,
+        }
+    }
+}
+
+impl FromStr for IpAddrMatch {
+    type Err = Error;
+
+    /// Tries, in order: a literal IP list, an ipset name, an alias name.
+    fn from_str(value: &str) -> Result<Self, Error> {
+        if value.is_empty() {
+            bail!("empty IP specification");
+        }
+
+        value
+            .parse::<IpList>()
+            .map(IpAddrMatch::Ip)
+            .or_else(|_| value.parse::<IpsetName>().map(IpAddrMatch::Set))
+            .or_else(|_| value.parse::<AliasName>().map(IpAddrMatch::Alias))
+            .map_err(|_| format_err!("invalid IP specification: {value}"))
+    }
+}
+
+/// Protocol match of a rule.
+///
+/// Protocols with a dedicated variant carry their protocol-specific options; any other
+/// protocol is kept by name or number.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum Protocol {
+    Dccp(Ports),
+    Sctp(Sctp),
+    Tcp(Tcp),
+    Udp(Udp),
+    UdpLite(Ports),
+    Icmp(Icmp),
+    Icmpv6(Icmpv6),
+    Named(String),
+    Numeric(u8),
+}
+
+impl Protocol {
+    /// Builds the protocol match from the `-proto`, port and `-icmp-type` options.
+    ///
+    /// Returns `Ok(None)` if no protocol is given. Known protocols are recognized by
+    /// (case-sensitive) name or by protocol number, e.g. `tcp` or `6`.
+    ///
+    /// NOTE(review): port / icmp-type options are not read for `Named` / `Numeric`
+    /// protocols, nor when no protocol is given; confirm they are rejected elsewhere.
+    pub(crate) fn from_options(options: &RuleOptions) -> Result<Option<Self>, Error> {
+        options
+            .proto
+            .as_deref()
+            .map(|proto| -> Result<Self, Error> {
+                Ok(match proto {
+                    "dccp" | "33" => Protocol::Dccp(Ports::from_options(options)?),
+                    "sctp" | "132" => Protocol::Sctp(Sctp::from_options(options)?),
+                    "tcp" | "6" => Protocol::Tcp(Tcp::from_options(options)?),
+                    "udp" | "17" => Protocol::Udp(Udp::from_options(options)?),
+                    "udplite" | "136" => Protocol::UdpLite(Ports::from_options(options)?),
+                    "icmp" | "1" => Protocol::Icmp(Icmp::from_options(options)?),
+                    "ipv6-icmp" | "icmpv6" | "58" => {
+                        Protocol::Icmpv6(Icmpv6::from_options(options)?)
+                    }
+                    other => other
+                        .parse::<u8>()
+                        .map_or_else(|_| Protocol::Named(other.to_string()), Protocol::Numeric),
+                })
+            })
+            .transpose()
+    }
+
+    /// Address family implied by the protocol: IPv4 for ICMP, IPv6 for ICMPv6, else `None`.
+    pub fn family(&self) -> Option<Family> {
+        match self {
+            Protocol::Icmp(_) => Some(Family::V4),
+            Protocol::Icmpv6(_) => Some(Family::V6),
+            Protocol::Dccp(_)
+            | Protocol::Sctp(_)
+            | Protocol::Tcp(_)
+            | Protocol::Udp(_)
+            | Protocol::UdpLite(_)
+            | Protocol::Named(_)
+            | Protocol::Numeric(_) => None,
+        }
+    }
+}
+
+/// UDP protocol match with optional source/destination ports.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Udp {
+    ports: Ports,
+}
+
+impl Udp {
+    /// Creates a UDP match on the given ports.
+    pub fn new(ports: Ports) -> Self {
+        Self { ports }
+    }
+
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        Ports::from_options(options).map(Self::new)
+    }
+
+    /// Port match of this rule.
+    pub fn ports(&self) -> &Ports {
+        &self.ports
+    }
+}
+
+impl From<Udp> for Protocol {
+    fn from(udp: Udp) -> Self {
+        Self::Udp(udp)
+    }
+}
+
+/// Optional source and destination port lists.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Ports {
+    sport: Option<PortList>,
+    dport: Option<PortList>,
+}
+
+impl Ports {
+    /// Creates a port match; either side may be omitted by passing `None`.
+    pub fn new(sport: impl Into<Option<PortList>>, dport: impl Into<Option<PortList>>) -> Self {
+        let sport = sport.into();
+        let dport = dport.into();
+        Self { sport, dport }
+    }
+
+    /// Parses the `-sport` / `-dport` options (source first) into port lists.
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        let sport = options
+            .sport
+            .as_deref()
+            .map(str::parse::<PortList>)
+            .transpose()?;
+        let dport = options
+            .dport
+            .as_deref()
+            .map(str::parse::<PortList>)
+            .transpose()?;
+
+        Ok(Self { sport, dport })
+    }
+
+    /// Convenience constructor for single numeric ports.
+    pub fn from_u16(sport: impl Into<Option<u16>>, dport: impl Into<Option<u16>>) -> Self {
+        let sport: Option<u16> = sport.into();
+        let dport: Option<u16> = dport.into();
+        Self::new(sport.map(PortList::from), dport.map(PortList::from))
+    }
+
+    /// Source port list, if any.
+    pub fn sport(&self) -> Option<&PortList> {
+        self.sport.as_ref()
+    }
+
+    /// Destination port list, if any.
+    pub fn dport(&self) -> Option<&PortList> {
+        self.dport.as_ref()
+    }
+}
+
+/// TCP protocol match with optional source/destination ports.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Tcp {
+    ports: Ports,
+}
+
+impl Tcp {
+    /// Creates a TCP match on the given ports.
+    pub fn new(ports: Ports) -> Self {
+        Self { ports }
+    }
+
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        Ports::from_options(options).map(Self::new)
+    }
+
+    /// Port match of this rule.
+    pub fn ports(&self) -> &Ports {
+        &self.ports
+    }
+}
+
+impl From<Tcp> for Protocol {
+    fn from(tcp: Tcp) -> Self {
+        Self::Tcp(tcp)
+    }
+}
+
+/// SCTP protocol match with optional source/destination ports.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Sctp {
+    ports: Ports,
+}
+
+impl Sctp {
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        Ports::from_options(options).map(|ports| Self { ports })
+    }
+
+    /// Port match of this rule.
+    pub fn ports(&self) -> &Ports {
+        &self.ports
+    }
+}
+
+/// ICMP (IPv4) protocol match on an optional type and/or code.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Icmp {
+    ty: Option<IcmpType>,
+    code: Option<IcmpCode>,
+}
+
+impl Icmp {
+    /// Creates a match on the given ICMP type only.
+    pub fn new_ty(ty: IcmpType) -> Self {
+        Self {
+            ty: Some(ty),
+            code: None,
+        }
+    }
+
+    /// Creates a match on the given ICMP code only.
+    pub fn new_code(code: IcmpCode) -> Self {
+        Self {
+            ty: None,
+            code: Some(code),
+        }
+    }
+
+    /// Builds the match from the `-icmp-type` option; without that option neither type
+    /// nor code is set.
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        match options.icmp_type.as_deref() {
+            Some(spec) => spec.parse(),
+            None => Ok(Self::default()),
+        }
+    }
+
+    /// ICMP type, if set.
+    pub fn ty(&self) -> Option<&IcmpType> {
+        self.ty.as_ref()
+    }
+
+    /// ICMP code, if set.
+    pub fn code(&self) -> Option<&IcmpCode> {
+        self.code.as_ref()
+    }
+}
+
+impl FromStr for Icmp {
+    type Err = Error;
+
+    /// Interprets the string as an ICMP type if possible, otherwise as an ICMP code.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if let Ok(ty) = s.parse::<IcmpType>() {
+            return Ok(Self::new_ty(ty));
+        }
+
+        if let Ok(code) = s.parse::<IcmpCode>() {
+            return Ok(Self::new_code(code));
+        }
+
+        bail!("supplied string is neither a valid icmp type nor code");
+    }
+}
+
+/// ICMP type, given either as a number or as one of the names in `ICMP_TYPES`.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum IcmpType {
+    Numeric(u8),
+    Named(&'static str),
+}
+
+// Name -> type table, looked up with a binary search: MUST BE SORTED by name!
+const ICMP_TYPES: &[(&str, u8)] = &[
+    ("address-mask-reply", 18),
+    ("address-mask-request", 17),
+    ("destination-unreachable", 3),
+    ("echo-reply", 0),
+    ("echo-request", 8),
+    ("info-reply", 16),
+    ("info-request", 15),
+    ("parameter-problem", 12),
+    ("redirect", 5),
+    ("router-advertisement", 9),
+    ("router-solicitation", 10),
+    ("source-quench", 4),
+    ("time-exceeded", 11),
+    ("timestamp-reply", 14),
+    ("timestamp-request", 13),
+];
+
+impl std::str::FromStr for IcmpType {
+    type Err = Error;
+
+    /// Accepts a decimal number (surrounding whitespace allowed) or an exact,
+    /// case-sensitive name from `ICMP_TYPES`.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        if let Ok(number) = s.trim().parse::<u8>() {
+            return Ok(Self::Numeric(number));
+        }
+
+        match ICMP_TYPES.binary_search_by(|&(name, _)| name.cmp(s)) {
+            Ok(index) => Ok(Self::Named(ICMP_TYPES[index].0)),
+            Err(_) => bail!("{s:?} is not a valid icmp type"),
+        }
+    }
+}
+
+impl fmt::Display for IcmpType {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            IcmpType::Named(name) => f.write_str(name),
+            IcmpType::Numeric(number) => write!(f, "{number}"),
+        }
+    }
+}
+
+/// An ICMP code, either as a raw number or as a symbolic name.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum IcmpCode {
+    /// A numeric ICMP code.
+    Numeric(u8),
+    /// A symbolic ICMP code; when produced by `from_str` it is an entry of `ICMP_CODES`.
+    Named(&'static str),
+}
+
+/// Known ICMP code names together with their numeric values.
+///
+/// MUST BE SORTED by name: `IcmpCode::from_str` looks names up via binary search.
+const ICMP_CODES: &[(&str, u8)] = &[
+    ("admin-prohibited", 13),
+    ("host-prohibited", 10),
+    ("host-unreachable", 1),
+    ("net-prohibited", 9),
+    ("net-unreachable", 0),
+    ("port-unreachable", 3),
+    ("prot-unreachable", 2),
+];
+
+impl std::str::FromStr for IcmpCode {
+    type Err = Error;
+
+    /// Parses an ICMP code given as a `u8` number or as one of the names in `ICMP_CODES`.
+    ///
+    /// Surrounding whitespace is ignored for both forms; the numeric form is tried first.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the (trimmed) input is neither a `u8` nor a known code name.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        // Trim once so numeric and named inputs get identical whitespace handling.
+        let trimmed = s.trim();
+
+        if let Ok(code) = trimmed.parse::<u8>() {
+            return Ok(Self::Numeric(code));
+        }
+
+        if let Ok(index) = ICMP_CODES.binary_search_by(|v| v.0.cmp(trimmed)) {
+            return Ok(Self::Named(ICMP_CODES[index].0));
+        }
+
+        bail!("{s:?} is not a valid icmp code");
+    }
+}
+
+impl fmt::Display for IcmpCode {
+    /// Writes the number or the name, exactly as stored.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            IcmpCode::Numeric(code) => write!(f, "{code}"),
+            IcmpCode::Named(code) => write!(f, "{code}"),
+        }
+    }
+}
+
+/// ICMPv6 protocol options of a rule: an optional ICMPv6 type and an optional ICMPv6 code.
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Icmpv6 {
+    pub ty: Option<Icmpv6Type>,
+    pub code: Option<Icmpv6Code>,
+}
+
+impl Icmpv6 {
+    /// Creates ICMPv6 options with only the type set.
+    pub fn new_ty(ty: Icmpv6Type) -> Self {
+        Self {
+            ty: Some(ty),
+            code: None,
+        }
+    }
+
+    /// Creates ICMPv6 options with only the code set.
+    pub fn new_code(code: Icmpv6Code) -> Self {
+        Self {
+            ty: None,
+            code: Some(code),
+        }
+    }
+
+    /// Builds the ICMPv6 options from the `icmp_type` rule option.
+    ///
+    /// The option value is parsed with `Icmpv6`'s `FromStr` impl. When the option is
+    /// absent, both `ty` and `code` are `None`.
+    fn from_options(options: &RuleOptions) -> Result<Self, Error> {
+        match options.icmp_type.as_deref() {
+            Some(value) => value.parse(),
+            None => Ok(Self::default()),
+        }
+    }
+
+    /// Returns the ICMPv6 type, if one is set.
+    pub fn ty(&self) -> Option<&Icmpv6Type> {
+        self.ty.as_ref()
+    }
+
+    /// Returns the ICMPv6 code, if one is set.
+    pub fn code(&self) -> Option<&Icmpv6Code> {
+        self.code.as_ref()
+    }
+}
+
+impl FromStr for Icmpv6 {
+    type Err = Error;
+
+    /// Interprets `s` as an ICMPv6 type, or failing that as an ICMPv6 code.
+    ///
+    /// Type parsing is attempted first, so a plain number always becomes a type.
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        if let Ok(ty) = s.parse::<Icmpv6Type>() {
+            Ok(Self::new_ty(ty))
+        } else if let Ok(code) = s.parse::<Icmpv6Code>() {
+            Ok(Self::new_code(code))
+        } else {
+            bail!("supplied string is neither a valid icmpv6 type nor code")
+        }
+    }
+}
+
+/// An ICMPv6 type, either as a raw number or as a symbolic name.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum Icmpv6Type {
+    /// A numeric ICMPv6 type.
+    Numeric(u8),
+    /// A symbolic ICMPv6 type; when produced by `from_str` it is an entry of `ICMPV6_TYPES`.
+    Named(&'static str),
+}
+
+/// Known ICMPv6 type names together with their numeric values.
+///
+/// MUST BE SORTED by name: `Icmpv6Type::from_str` looks names up via binary search.
+/// Note that several names may share a number (e.g. 132 appears twice).
+const ICMPV6_TYPES: &[(&str, u8)] = &[
+    ("destination-unreachable", 1),
+    ("echo-reply", 129),
+    ("echo-request", 128),
+    ("ind-neighbor-advert", 142),
+    ("ind-neighbor-solicit", 141),
+    ("mld-listener-done", 132),
+    ("mld-listener-query", 130),
+    ("mld-listener-reduction", 132),
+    ("mld-listener-report", 131),
+    ("mld2-listener-report", 143),
+    ("nd-neighbor-advert", 136),
+    ("nd-neighbor-solicit", 135),
+    ("nd-redirect", 137),
+    ("nd-router-advert", 134),
+    ("nd-router-solicit", 133),
+    ("packet-too-big", 2),
+    ("parameter-problem", 4),
+    ("router-renumbering", 138),
+    ("time-exceeded", 3),
+];
+
+impl std::str::FromStr for Icmpv6Type {
+    type Err = Error;
+
+    /// Parses an ICMPv6 type given as a `u8` number or as one of the names in `ICMPV6_TYPES`.
+    ///
+    /// Surrounding whitespace is ignored for both forms; the numeric form is tried first.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the (trimmed) input is neither a `u8` nor a known type name.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        // Trim once so numeric and named inputs get identical whitespace handling.
+        let trimmed = s.trim();
+
+        if let Ok(ty) = trimmed.parse::<u8>() {
+            return Ok(Self::Numeric(ty));
+        }
+
+        if let Ok(index) = ICMPV6_TYPES.binary_search_by(|v| v.0.cmp(trimmed)) {
+            return Ok(Self::Named(ICMPV6_TYPES[index].0));
+        }
+
+        bail!("{s:?} is not a valid icmpv6 type");
+    }
+}
+
+impl fmt::Display for Icmpv6Type {
+    /// Writes the number or the name, exactly as stored.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Icmpv6Type::Numeric(ty) => write!(f, "{ty}"),
+            Icmpv6Type::Named(ty) => write!(f, "{ty}"),
+        }
+    }
+}
+
+/// An ICMPv6 code, either as a raw number or as a symbolic name.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum Icmpv6Code {
+    /// A numeric ICMPv6 code.
+    Numeric(u8),
+    /// A symbolic ICMPv6 code; when produced by `from_str` it is an entry of `ICMPV6_CODES`.
+    Named(&'static str),
+}
+
+/// Known ICMPv6 code names together with their numeric values.
+///
+/// MUST BE SORTED by name: `Icmpv6Code::from_str` looks names up via binary search.
+const ICMPV6_CODES: &[(&str, u8)] = &[
+    ("addr-unreachable", 3),
+    ("admin-prohibited", 1),
+    ("no-route", 0),
+    ("policy-fail", 5),
+    ("port-unreachable", 4),
+    ("reject-route", 6),
+];
+
+impl std::str::FromStr for Icmpv6Code {
+    type Err = Error;
+
+    /// Parses an ICMPv6 code given as a `u8` number or as one of the names in `ICMPV6_CODES`.
+    ///
+    /// Surrounding whitespace is ignored for both forms; the numeric form is tried first.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the (trimmed) input is neither a `u8` nor a known code name.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        // Trim once so numeric and named inputs get identical whitespace handling.
+        let trimmed = s.trim();
+
+        if let Ok(code) = trimmed.parse::<u8>() {
+            return Ok(Self::Numeric(code));
+        }
+
+        if let Ok(index) = ICMPV6_CODES.binary_search_by(|v| v.0.cmp(trimmed)) {
+            return Ok(Self::Named(ICMPV6_CODES[index].0));
+        }
+
+        bail!("{s:?} is not a valid icmpv6 code");
+    }
+}
+
+impl fmt::Display for Icmpv6Code {
+    /// Writes the number or the name, exactly as stored.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Icmpv6Code::Numeric(code) => write!(f, "{code}"),
+            Icmpv6Code::Named(code) => write!(f, "{code}"),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::firewall::types::Cidr;
+
+    use super::*;
+
+    #[test]
+    fn test_parse_action() {
+        assert_eq!(parse_action("REJECT").unwrap(), (None, Verdict::Reject, ""));
+
+        assert_eq!(
+            parse_action("SSH(ACCEPT) qweasd").unwrap(),
+            (Some("SSH"), Verdict::Accept, "qweasd")
+        );
+    }
+
+    #[test]
+    fn test_parse_ip_addr_match() {
+        for input in [
+            "10.0.0.0/8",
+            "10.0.0.0/8,192.168.0.0-192.168.255.255,172.16.0.1",
+            "dc/test",
+            "+guest/proxmox",
+        ] {
+            input.parse::<IpAddrMatch>().expect("valid ip match");
+        }
+
+        for input in [
+            "10.0.0.0/",
+            "10.0.0.0/8,192.168.256.0-192.168.255.255,172.16.0.1",
+            "dcc/test",
+            "+guest/",
+            "",
+        ] {
+            input.parse::<IpAddrMatch>().expect_err("invalid ip match");
+        }
+    }
+
+    #[test]
+    fn test_parse_options() {
+        let mut options: RuleOptions =
+            "-p udp --sport 123 --dport 234 -source 127.0.0.1 --dest 127.0.0.1 -i ens1 --log crit"
+                .parse()
+                .expect("valid option string");
+
+        assert_eq!(
+            options,
+            RuleOptions {
+                proto: Some("udp".to_string()),
+                sport: Some("123".to_string()),
+                dport: Some("234".to_string()),
+                source: Some("127.0.0.1".to_string()),
+                dest: Some("127.0.0.1".to_string()),
+                iface: Some("ens1".to_string()),
+                log: Some(LogLevel::Critical),
+                icmp_type: None,
+            }
+        );
+
+        options = "".parse().expect("valid option string");
+
+        assert_eq!(options, RuleOptions::default(),);
+    }
+
+    #[test]
+    fn test_construct_ip_match() {
+        IpMatch::new(
+            IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], 8).unwrap())),
+            IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], 8).unwrap())),
+        )
+        .expect("valid ip match");
+
+        IpMatch::new(
+            IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], 8).unwrap())),
+            IpAddrMatch::Ip(IpList::from(Cidr::new_v6([0x0000; 8], 8).unwrap())),
+        )
+        .expect_err("cannot mix ip families");
+
+        IpMatch::new(None, None).expect_err("at least one ip must be set");
+    }
+
+    #[test]
+    fn test_from_options() {
+        let mut options = RuleOptions {
+            proto: Some("tcp".to_string()),
+            sport: Some("123".to_string()),
+            dport: Some("234".to_string()),
+            source: Some("192.168.0.1".to_string()),
+            dest: Some("10.0.0.1".to_string()),
+            iface: Some("eth123".to_string()),
+            log: Some(LogLevel::Error),
+            ..Default::default()
+        };
+
+        assert_eq!(
+            Protocol::from_options(&options).unwrap().unwrap(),
+            Protocol::Tcp(Tcp::new(Ports::from_u16(123, 234))),
+        );
+
+        assert_eq!(
+            IpMatch::from_options(&options).unwrap().unwrap(),
+            IpMatch::new(
+                IpAddrMatch::Ip(IpList::from(Cidr::new_v4([192, 168, 0, 1], 32).unwrap()),),
+                IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 1], 32).unwrap()),)
+            )
+            .unwrap(),
+        );
+
+        options = RuleOptions::default();
+
+        assert_eq!(Protocol::from_options(&options).unwrap(), None,);
+
+        assert_eq!(IpMatch::from_options(&options).unwrap(), None,);
+
+        options = RuleOptions {
+            proto: Some("tcp".to_string()),
+            sport: Some("qwe".to_string()),
+            source: Some("qwe".to_string()),
+            ..Default::default()
+        };
+
+        Protocol::from_options(&options).expect_err("invalid source port");
+
+        IpMatch::from_options(&options).expect_err("invalid source address");
+
+        options = RuleOptions {
+            icmp_type: Some("port-unreachable".to_string()),
+            dport: Some("123".to_string()),
+            ..Default::default()
+        };
+
+        RuleMatch::from_options(Direction::In, Verdict::Drop, None, options)
+            .expect_err("cannot mix dport and icmp-type");
+    }
+
+    #[test]
+    fn test_parse_icmp() {
+        let mut icmp: Icmp = "info-request".parse().expect("valid icmp type");
+
+        assert_eq!(
+            icmp,
+            Icmp {
+                ty: Some(IcmpType::Named("info-request")),
+                code: None
+            }
+        );
+
+        icmp = "12".parse().expect("valid icmp type");
+
+        assert_eq!(
+            icmp,
+            Icmp {
+                ty: Some(IcmpType::Numeric(12)),
+                code: None
+            }
+        );
+
+        icmp = "port-unreachable".parse().expect("valid icmp code");
+
+        assert_eq!(
+            icmp,
+            Icmp {
+                ty: None,
+                code: Some(IcmpCode::Named("port-unreachable"))
+            }
+        );
+    }
+
+    #[test]
+    fn test_parse_icmp6() {
+        let mut icmp: Icmpv6 = "echo-reply".parse().expect("valid icmpv6 type");
+
+        assert_eq!(
+            icmp,
+            Icmpv6 {
+                ty: Some(Icmpv6Type::Named("echo-reply")),
+                code: None
+            }
+        );
+
+        icmp = "12".parse().expect("valid icmpv6 type");
+
+        assert_eq!(
+            icmp,
+            Icmpv6 {
+                ty: Some(Icmpv6Type::Numeric(12)),
+                code: None
+            }
+        );
+
+        icmp = "admin-prohibited".parse().expect("valid icmpv6 code");
+
+        assert_eq!(
+            icmp,
+            Icmpv6 {
+                ty: None,
+                code: Some(Icmpv6Code::Named("admin-prohibited"))
+            }
+        );
+    }
+
+    #[test]
+    fn test_icmp_tables_sorted() {
+        // The `FromStr` impls look names up with `binary_search_by`, which silently
+        // misses entries if a table is not strictly sorted by name.
+        for table in [ICMP_TYPES, ICMP_CODES, ICMPV6_TYPES, ICMPV6_CODES] {
+            assert!(
+                table.windows(2).all(|pair| pair[0].0 < pair[1].0),
+                "icmp lookup table is not strictly sorted: {table:?}"
+            );
+        }
+    }
+
+    #[test]
+    fn test_icmp_table_names_parse() {
+        // Every listed name must be reachable through the corresponding parser.
+        for &(name, _) in ICMP_TYPES {
+            assert_eq!(name.parse::<IcmpType>().unwrap(), IcmpType::Named(name));
+        }
+
+        for &(name, _) in ICMP_CODES {
+            assert_eq!(name.parse::<IcmpCode>().unwrap(), IcmpCode::Named(name));
+        }
+
+        for &(name, _) in ICMPV6_TYPES {
+            assert_eq!(name.parse::<Icmpv6Type>().unwrap(), Icmpv6Type::Named(name));
+        }
+
+        for &(name, _) in ICMPV6_CODES {
+            assert_eq!(name.parse::<Icmpv6Code>().unwrap(), Icmpv6Code::Named(name));
+        }
+    }
+}
-- 
2.39.2




More information about the pve-devel mailing list