[pve-devel] [PATCH proxmox-firewall 02/37] config: firewall: add types for ip addresses
Stefan Hanreich
s.hanreich at proxmox.com
Tue Apr 2 19:15:54 CEST 2024
Includes types for all kinds of IP values that can occur in the
firewall config. Additionally, FromStr implementations are available
for parsing from the config files.
Co-authored-by: Wolfgang Bumiller <w.bumiller at proxmox.com>
Signed-off-by: Stefan Hanreich <s.hanreich at proxmox.com>
---
proxmox-ve-config/src/firewall/mod.rs | 1 +
.../src/firewall/types/address.rs | 624 ++++++++++++++++++
proxmox-ve-config/src/firewall/types/mod.rs | 3 +
proxmox-ve-config/src/lib.rs | 1 +
4 files changed, 629 insertions(+)
create mode 100644 proxmox-ve-config/src/firewall/mod.rs
create mode 100644 proxmox-ve-config/src/firewall/types/address.rs
create mode 100644 proxmox-ve-config/src/firewall/types/mod.rs
diff --git a/proxmox-ve-config/src/firewall/mod.rs b/proxmox-ve-config/src/firewall/mod.rs
new file mode 100644
index 0000000..cd40856
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/mod.rs
@@ -0,0 +1 @@
+pub mod types;
diff --git a/proxmox-ve-config/src/firewall/types/address.rs b/proxmox-ve-config/src/firewall/types/address.rs
new file mode 100644
index 0000000..ce2f1cd
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/types/address.rs
@@ -0,0 +1,624 @@
+use std::fmt;
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
+use std::ops::Deref;
+
+use anyhow::{bail, format_err, Error};
+use serde_with::DeserializeFromStr;
+
+/// Address family of an IP address, CIDR or IP list entry.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum Family {
+ /// IPv4; displayed as `Ipv4`.
+ V4,
+ /// IPv6; displayed as `Ipv6`.
+ V6,
+}
+
+impl fmt::Display for Family {
+    /// Writes the name of the family, either `Ipv4` or `Ipv6`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let name = match self {
+            Family::V4 => "Ipv4",
+            Family::V6 => "Ipv6",
+        };
+
+        f.write_str(name)
+    }
+}
+
+/// An IPv4 or IPv6 address with a prefix length (CIDR notation).
+///
+/// Equality is only derived for test builds.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum Cidr {
+ /// An IPv4 CIDR.
+ Ipv4(Ipv4Cidr),
+ /// An IPv6 CIDR.
+ Ipv6(Ipv6Cidr),
+}
+
+impl Cidr {
+    /// Builds an IPv4 CIDR from an address and a prefix length.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error of [`Ipv4Cidr::new`] if `mask` is out of range.
+    pub fn new_v4(addr: impl Into<Ipv4Addr>, mask: u8) -> Result<Self, Error> {
+        Ipv4Cidr::new(addr, mask).map(Self::Ipv4)
+    }
+
+    /// Builds an IPv6 CIDR from an address and a prefix length.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error of [`Ipv6Cidr::new`] if `mask` is out of range.
+    pub fn new_v6(addr: impl Into<Ipv6Addr>, mask: u8) -> Result<Self, Error> {
+        Ipv6Cidr::new(addr, mask).map(Self::Ipv6)
+    }
+
+    /// Returns the address family of the wrapped CIDR.
+    pub const fn family(&self) -> Family {
+        match self {
+            Self::Ipv4(_) => Family::V4,
+            Self::Ipv6(_) => Family::V6,
+        }
+    }
+
+    /// Returns `true` if this is the [`Cidr::Ipv4`] variant.
+    pub fn is_ipv4(&self) -> bool {
+        self.family() == Family::V4
+    }
+
+    /// Returns `true` if this is the [`Cidr::Ipv6`] variant.
+    pub fn is_ipv6(&self) -> bool {
+        self.family() == Family::V6
+    }
+}
+
+impl fmt::Display for Cidr {
+    /// Formats the wrapped CIDR exactly like its own `Display` implementation
+    /// does (`<address>/<mask>`).
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // Delegate to the inner type's `Display` so that no intermediate
+        // `String` has to be allocated just to be written out again.
+        match self {
+            Self::Ipv4(cidr) => fmt::Display::fmt(cidr, f),
+            Self::Ipv6(cidr) => fmt::Display::fmt(cidr, f),
+        }
+    }
+}
+
+impl std::str::FromStr for Cidr {
+    type Err = Error;
+
+    /// Parses `s` as an IPv4 CIDR first and, if that fails, as an IPv6 CIDR.
+    ///
+    /// # Errors
+    ///
+    /// Fails if neither parser accepts `s`. The errors of the individual
+    /// parsers are discarded in favour of one generic message.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        s.parse::<Ipv4Cidr>()
+            .map(Cidr::Ipv4)
+            .or_else(|_| s.parse::<Ipv6Cidr>().map(Cidr::Ipv6))
+            .map_err(|_| format_err!("invalid ip address or CIDR: {s:?}"))
+    }
+}
+
+impl From<Ipv4Cidr> for Cidr {
+    /// Wraps an IPv4 CIDR in the [`Cidr::Ipv4`] variant.
+    fn from(value: Ipv4Cidr) -> Self {
+        Self::Ipv4(value)
+    }
+}
+
+impl From<Ipv6Cidr> for Cidr {
+    /// Wraps an IPv6 CIDR in the [`Cidr::Ipv6`] variant.
+    fn from(value: Ipv6Cidr) -> Self {
+        Self::Ipv6(value)
+    }
+}
+
+/// Number of bits in an IPv4 address.
+const IPV4_LENGTH: u8 = 32;
+
+/// An IPv4 address together with a prefix length.
+///
+/// Equality is only derived for test builds.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Ipv4Cidr {
+ // stored exactly as supplied, bits beyond the prefix are not cleared
+ addr: Ipv4Addr,
+ // prefix length in bits; `new` and `from_str` reject values above 32
+ mask: u8,
+}
+
+impl Ipv4Cidr {
+    /// Creates a CIDR from an address and a prefix length.
+    ///
+    /// The address is stored as given; bits beyond the prefix are kept.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `mask` is larger than 32.
+    pub fn new(addr: impl Into<Ipv4Addr>, mask: u8) -> Result<Self, Error> {
+        if mask > IPV4_LENGTH {
+            bail!("mask out of range for ipv4 cidr ({mask})");
+        }
+
+        let addr = addr.into();
+
+        Ok(Self { addr, mask })
+    }
+
+    /// Returns whether the first `mask` bits of `other` equal those of this
+    /// CIDR's address.
+    pub fn contains_address(&self, other: &Ipv4Addr) -> bool {
+        // Number of trailing bits that do not take part in the comparison.
+        let host_bits = u32::from(IPV4_LENGTH.saturating_sub(self.mask));
+
+        // Shifting by the full width (mask 0) is rejected by `checked_shr`;
+        // mapping that case to 0 makes a `/0` match every address.
+        let network_of = |addr: &Ipv4Addr| u32::from(*addr).checked_shr(host_bits).unwrap_or(0);
+
+        network_of(&self.addr) == network_of(other)
+    }
+
+    /// Returns the address part of the CIDR.
+    pub fn address(&self) -> &Ipv4Addr {
+        &self.addr
+    }
+
+    /// Returns the prefix length in bits.
+    pub fn mask(&self) -> u8 {
+        self.mask
+    }
+}
+
+impl<T: Into<Ipv4Addr>> From<T> for Ipv4Cidr {
+    /// Turns a single address into a CIDR with a prefix length of 32.
+    fn from(value: T) -> Self {
+        let addr: Ipv4Addr = value.into();
+
+        Self {
+            addr,
+            mask: IPV4_LENGTH,
+        }
+    }
+}
+
+impl std::str::FromStr for Ipv4Cidr {
+    type Err = Error;
+
+    /// Parses `<address>` or `<address>/<mask>`.
+    ///
+    /// Without a `/` the prefix length defaults to 32.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the mask is not a valid `u8`, the address is not a valid IPv4
+    /// address, or the mask is rejected by [`Ipv4Cidr::new`].
+    fn from_str(s: &str) -> Result<Self, Error> {
+        match s.split_once('/') {
+            Some((addr, mask)) => {
+                // The mask is checked before the address, so a broken mask is
+                // reported even if the address is invalid as well.
+                let mask: u8 = mask
+                    .parse()
+                    .map_err(|_| format_err!("invalid mask in ipv4 cidr: {s:?}"))?;
+                let addr: Ipv4Addr = addr.parse()?;
+
+                Self::new(addr, mask)
+            }
+            None => Ok(Self {
+                addr: s.parse()?,
+                mask: IPV4_LENGTH,
+            }),
+        }
+    }
+}
+
+impl fmt::Display for Ipv4Cidr {
+    /// Writes the CIDR as `<address>/<mask>`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let Self { addr, mask } = self;
+
+        write!(f, "{addr}/{mask}")
+    }
+}
+
+/// Number of bits in an IPv6 address; also the upper bound for IPv6 masks.
+const IPV6_LENGTH: u8 = 128;
+
+/// An IPv6 address together with a prefix length.
+///
+/// Equality is only derived for test builds.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct Ipv6Cidr {
+ // stored exactly as supplied, bits beyond the prefix are not cleared
+ addr: Ipv6Addr,
+ // prefix length in bits; `new` and `from_str` reject values above 128
+ mask: u8,
+}
+
+impl Ipv6Cidr {
+    /// Creates a CIDR from an address and a prefix length.
+    ///
+    /// The address is stored as given; bits beyond the prefix are kept.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `mask` is larger than 128. The offending value is part of the
+    /// error message, matching the IPv4 counterpart.
+    pub fn new(addr: impl Into<Ipv6Addr>, mask: u8) -> Result<Self, Error> {
+        if mask > IPV6_LENGTH {
+            bail!("mask out of range for ipv6 cidr ({mask})");
+        }
+
+        Ok(Self {
+            addr: addr.into(),
+            mask,
+        })
+    }
+
+    /// Returns whether the first `mask` bits of `other` equal those of this
+    /// CIDR's address.
+    pub fn contains_address(&self, other: &Ipv6Addr) -> bool {
+        let bits = u128::from_be_bytes(self.addr.octets());
+        let other_bits = u128::from_be_bytes(other.octets());
+
+        // Number of trailing bits that do not take part in the comparison.
+        let shift_amount: u32 = IPV6_LENGTH.saturating_sub(self.mask).into();
+
+        // Shifting by the full width (mask 0) is rejected by `checked_shr`;
+        // mapping that case to 0 makes a `/0` match every address.
+        bits.checked_shr(shift_amount).unwrap_or(0)
+            == other_bits.checked_shr(shift_amount).unwrap_or(0)
+    }
+
+    /// Returns the address part of the CIDR.
+    pub fn address(&self) -> &Ipv6Addr {
+        &self.addr
+    }
+
+    /// Returns the prefix length in bits.
+    pub fn mask(&self) -> u8 {
+        self.mask
+    }
+}
+
+impl std::str::FromStr for Ipv6Cidr {
+    type Err = Error;
+
+    /// Parses `<address>` or `<address>/<mask>`.
+    ///
+    /// Without a `/` the prefix length defaults to 128.
+    ///
+    /// # Errors
+    ///
+    /// Fails if the mask is not a valid `u8`, the address is not a valid IPv6
+    /// address, or the mask is rejected by [`Ipv6Cidr::new`].
+    fn from_str(s: &str) -> Result<Self, Error> {
+        match s.split_once('/') {
+            Some((addr, mask)) => {
+                // The mask is checked before the address, so a broken mask is
+                // reported even if the address is invalid as well.
+                let mask: u8 = mask
+                    .parse()
+                    .map_err(|_| format_err!("invalid mask in ipv6 cidr: {s:?}"))?;
+                let addr: Ipv6Addr = addr.parse()?;
+
+                Self::new(addr, mask)
+            }
+            None => Ok(Self {
+                addr: s.parse()?,
+                mask: IPV6_LENGTH,
+            }),
+        }
+    }
+}
+
+impl fmt::Display for Ipv6Cidr {
+    /// Writes the CIDR as `<address>/<mask>`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let Self { addr, mask } = self;
+
+        write!(f, "{addr}/{mask}")
+    }
+}
+
+impl<T: Into<Ipv6Addr>> From<T> for Ipv6Cidr {
+    /// Turns a single address into a CIDR with a prefix length of 128.
+    fn from(addr: T) -> Self {
+        let addr: Ipv6Addr = addr.into();
+
+        Self {
+            addr,
+            mask: IPV6_LENGTH,
+        }
+    }
+}
+
+/// A single element of an IP list: either a CIDR or an address range.
+///
+/// Equality is only derived for test builds.
+#[derive(Clone, Debug)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub enum IpEntry {
+ /// An address with a prefix length.
+ Cidr(Cidr),
+ /// A range given as `(start, end)`; displayed as `start-end`.
+ // NOTE(review): the variant is public, so a value that was not produced
+ // by `from_str` may hold addresses of different families or in the wrong
+ // order; `family()` hits `unreachable!` for a mixed-family range.
+ Range(IpAddr, IpAddr),
+}
+
+impl std::str::FromStr for IpEntry {
+    type Err = Error;
+
+    /// Parses either a CIDR / plain address (e.g. `10.0.0.0/16`, `fe80::1`)
+    /// or a range of two plain addresses separated by `-`
+    /// (e.g. `192.168.0.1-192.168.99.255`).
+    ///
+    /// Both ends of a range have to belong to the same family and the start
+    /// must not be greater than the end. A range whose start equals its end
+    /// is accepted.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `s` is empty, contains more than one `-`, is not a valid
+    /// CIDR (when there is no `-`), or if the range ends are not two valid
+    /// addresses of the same family in ascending order.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        if s.is_empty() {
+            bail!("Empty IP specification!")
+        }
+
+        // Without a separator the whole input has to be a CIDR.
+        let (beg, end) = match s.split_once('-') {
+            Some(parts) => parts,
+            None => return Ok(IpEntry::Cidr(s.parse()?)),
+        };
+
+        // A second separator means more than two elements were given.
+        if end.contains('-') {
+            bail!("Invalid amount of elements in IpEntry!")
+        }
+
+        if let (Ok(beg), Ok(end)) = (beg.parse::<Ipv4Addr>(), end.parse::<Ipv4Addr>()) {
+            // Only a start that is really greater than the end is an error,
+            // which is also what the message states.
+            if beg > end {
+                bail!("start address is greater than end address!");
+            }
+
+            return Ok(IpEntry::Range(beg.into(), end.into()));
+        }
+
+        if let (Ok(beg), Ok(end)) = (beg.parse::<Ipv6Addr>(), end.parse::<Ipv6Addr>()) {
+            if beg > end {
+                bail!("start address is greater than end address!");
+            }
+
+            return Ok(IpEntry::Range(beg.into(), end.into()));
+        }
+
+        bail!("start and end are not valid IP addresses of the same type!")
+    }
+}
+
+impl fmt::Display for IpEntry {
+    /// Writes a CIDR in its own notation and a range as `<start>-<end>`.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            IpEntry::Range(start, end) => write!(f, "{start}-{end}"),
+            IpEntry::Cidr(cidr) => write!(f, "{cidr}"),
+        }
+    }
+}
+
+impl IpEntry {
+    /// Returns the address family of the entry.
+    ///
+    /// # Panics
+    ///
+    /// Panics for a `Range` whose two addresses belong to different
+    /// families. Parsing never produces such a value.
+    fn family(&self) -> Family {
+        match self {
+            Self::Cidr(cidr) => cidr.family(),
+            Self::Range(IpAddr::V4(_), IpAddr::V4(_)) => Family::V4,
+            Self::Range(IpAddr::V6(_), IpAddr::V6(_)) => Family::V6,
+            // should never be reached due to constructors validating that
+            // start type == end type
+            Self::Range(..) => unreachable!("invalid IP entry"),
+        }
+    }
+}
+
+impl From<Cidr> for IpEntry {
+    /// Wraps a CIDR in the [`IpEntry::Cidr`] variant.
+    fn from(cidr: Cidr) -> Self {
+        Self::Cidr(cidr)
+    }
+}
+
+/// A list of [`IpEntry`] values that all belong to the same address family.
+///
+/// `new`, `from_str` and the `From` conversion never produce an empty list.
+/// Equality is only derived for test builds.
+#[derive(Clone, Debug, DeserializeFromStr)]
+#[cfg_attr(test, derive(Eq, PartialEq))]
+pub struct IpList {
+ // guaranteed to have the same family
+ entries: Vec<IpEntry>,
+ // the family shared by all `entries`
+ family: Family,
+}
+
+// Gives read-only access to the entries through the `Vec` API. There is no
+// `DerefMut`, so entries cannot be added or replaced through this path.
+impl Deref for IpList {
+ type Target = Vec<IpEntry>;
+
+ /// Returns the list's entries.
+ fn deref(&self) -> &Self::Target {
+ &self.entries
+ }
+}
+
+impl<T: Into<IpEntry>> From<T> for IpList {
+    /// Builds a list with exactly one entry; the list takes its family from
+    /// that entry.
+    fn from(value: T) -> Self {
+        let entry: IpEntry = value.into();
+        let family = entry.family();
+
+        Self {
+            entries: vec![entry],
+            family,
+        }
+    }
+}
+
+impl std::str::FromStr for IpList {
+    type Err = Error;
+
+    /// Parses a comma-separated list of [`IpEntry`] values.
+    ///
+    /// The first entry determines the family of the list; every following
+    /// entry has to have the same family.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `s` is empty, if any element is not a valid [`IpEntry`], or
+    /// if the elements do not all share one family. Elements are checked in
+    /// order and the first problem encountered is reported.
+    fn from_str(s: &str) -> Result<Self, Error> {
+        if s.is_empty() {
+            bail!("Empty IP specification!")
+        }
+
+        let mut entries = Vec::new();
+        let mut list_family = None;
+
+        for element in s.split(',') {
+            let entry: IpEntry = element.parse()?;
+            let family = entry.family();
+
+            // The first entry sets the family, later ones are compared to it.
+            if *list_family.get_or_insert(family) != family {
+                bail!("Incompatible families in IPList!")
+            }
+
+            entries.push(entry);
+        }
+
+        // `split` yields at least one element, so the family is always set
+        // here; turning `None` into an error avoids an `unwrap()`.
+        let family = list_family.ok_or_else(|| format_err!("empty ip list"))?;
+
+        Ok(IpList { entries, family })
+    }
+}
+
+impl IpList {
+    /// Creates a list from already constructed entries.
+    ///
+    /// # Errors
+    ///
+    /// Fails if `entries` is empty or if the entries do not all belong to
+    /// the same family.
+    pub fn new(entries: Vec<IpEntry>) -> Result<Self, Error> {
+        let mut families = entries.iter().map(IpEntry::family);
+
+        // The first entry decides which family the list has.
+        let family = match families.next() {
+            Some(family) => family,
+            None => bail!("no elements in ip list entries"),
+        };
+
+        if families.any(|other| other != family) {
+            bail!("non-matching families in entries list");
+        }
+
+        Ok(Self { entries, family })
+    }
+
+    /// Returns the family shared by all entries of the list.
+    pub fn family(&self) -> Family {
+        self.family
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::net::{Ipv4Addr, Ipv6Addr};
+
+ // Parsing of IPv4 CIDRs and prefix matching via `contains_address`.
+ #[test]
+ fn test_v4_cidr() {
+ // a /0 matches every address
+ let mut cidr: Ipv4Cidr = "0.0.0.0/0".parse().expect("valid IPv4 CIDR");
+
+ assert_eq!(cidr.addr, Ipv4Addr::new(0, 0, 0, 0));
+ assert_eq!(cidr.mask, 0);
+
+ assert!(cidr.contains_address(&Ipv4Addr::new(0, 0, 0, 0)));
+ assert!(cidr.contains_address(&Ipv4Addr::new(255, 255, 255, 255)));
+
+ // without a mask the CIDR defaults to /32 and matches one address only
+ cidr = "192.168.100.1".parse().expect("valid IPv4 CIDR");
+
+ assert_eq!(cidr.addr, Ipv4Addr::new(192, 168, 100, 1));
+ assert_eq!(cidr.mask, 32);
+
+ assert!(cidr.contains_address(&Ipv4Addr::new(192, 168, 100, 1)));
+ assert!(!cidr.contains_address(&Ipv4Addr::new(192, 168, 100, 2)));
+ assert!(!cidr.contains_address(&Ipv4Addr::new(192, 168, 100, 0)));
+
+ // a /24 matches the whole last octet, but not the neighbouring networks
+ cidr = "10.100.5.0/24".parse().expect("valid IPv4 CIDR");
+
+ assert_eq!(cidr.mask, 24);
+
+ assert!(cidr.contains_address(&Ipv4Addr::new(10, 100, 5, 0)));
+ assert!(cidr.contains_address(&Ipv4Addr::new(10, 100, 5, 1)));
+ assert!(cidr.contains_address(&Ipv4Addr::new(10, 100, 5, 100)));
+ assert!(cidr.contains_address(&Ipv4Addr::new(10, 100, 5, 255)));
+ assert!(!cidr.contains_address(&Ipv4Addr::new(10, 100, 4, 255)));
+ assert!(!cidr.contains_address(&Ipv4Addr::new(10, 100, 6, 0)));
+
+ // invalid masks and addresses are rejected
+ "0.0.0.0/-1".parse::<Ipv4Cidr>().unwrap_err();
+ "0.0.0.0/33".parse::<Ipv4Cidr>().unwrap_err();
+ "256.256.256.256/10".parse::<Ipv4Cidr>().unwrap_err();
+
+ // wrong family, garbage and empty input are rejected
+ "fe80::1/64".parse::<Ipv4Cidr>().unwrap_err();
+ "qweasd".parse::<Ipv4Cidr>().unwrap_err();
+ "".parse::<Ipv4Cidr>().unwrap_err();
+ }
+
+ // Parsing of IPv6 CIDRs and prefix matching via `contains_address`.
+ #[test]
+ fn test_v6_cidr() {
+ // the address is kept as written, bits beyond the /64 are not cleared
+ let mut cidr: Ipv6Cidr = "abab::1/64".parse().expect("valid IPv6 CIDR");
+
+ assert_eq!(cidr.addr, Ipv6Addr::new(0xABAB, 0, 0, 0, 0, 0, 0, 1));
+ assert_eq!(cidr.mask, 64);
+
+ assert!(cidr.contains_address(&Ipv6Addr::new(0xABAB, 0, 0, 0, 0, 0, 0, 0)));
+ assert!(cidr.contains_address(&Ipv6Addr::new(
+ 0xABAB, 0, 0, 0, 0xAAAA, 0xAAAA, 0xAAAA, 0xAAAA
+ )));
+ assert!(cidr.contains_address(&Ipv6Addr::new(
+ 0xABAB, 0, 0, 0, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF
+ )));
+ assert!(!cidr.contains_address(&Ipv6Addr::new(0xABAB, 0, 0, 1, 0, 0, 0, 0)));
+ assert!(!cidr.contains_address(&Ipv6Addr::new(
+ 0xABAA, 0, 0, 0, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF
+ )));
+
+ // without a mask the CIDR defaults to /128 and matches one address only
+ cidr = "eeee::1".parse().expect("valid IPv6 CIDR");
+
+ assert_eq!(cidr.mask, 128);
+
+ assert!(cidr.contains_address(&Ipv6Addr::new(0xEEEE, 0, 0, 0, 0, 0, 0, 1)));
+ assert!(!cidr.contains_address(&Ipv6Addr::new(
+ 0xEEED, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF, 0xFFFF
+ )));
+ assert!(!cidr.contains_address(&Ipv6Addr::new(0xEEEE, 0, 0, 0, 0, 0, 0, 0)));
+
+ // invalid masks and addresses are rejected
+ "eeee::1/-1".parse::<Ipv6Cidr>().unwrap_err();
+ "eeee::1/129".parse::<Ipv6Cidr>().unwrap_err();
+ "gggg::1/64".parse::<Ipv6Cidr>().unwrap_err();
+
+ // wrong family, garbage and empty input are rejected
+ "192.168.0.1".parse::<Ipv6Cidr>().unwrap_err();
+ "qweasd".parse::<Ipv6Cidr>().unwrap_err();
+ "".parse::<Ipv6Cidr>().unwrap_err();
+ }
+
+ // Parsing of single IP list entries: plain addresses, CIDRs and ranges.
+ #[test]
+ fn test_parse_ip_entry() {
+ // a plain IPv4 address becomes a /32 CIDR
+ let mut entry: IpEntry = "10.0.0.1".parse().expect("valid IP entry");
+
+ assert_eq!(entry, Cidr::new_v4([10, 0, 0, 1], 32).unwrap().into());
+
+ entry = "10.0.0.0/16".parse().expect("valid IP entry");
+
+ assert_eq!(entry, Cidr::new_v4([10, 0, 0, 0], 16).unwrap().into());
+
+ // two addresses separated by '-' form a range
+ entry = "192.168.0.1-192.168.99.255"
+ .parse()
+ .expect("valid IP entry");
+
+ assert_eq!(
+ entry,
+ IpEntry::Range([192, 168, 0, 1].into(), [192, 168, 99, 255].into())
+ );
+
+ // a plain IPv6 address becomes a /128 CIDR
+ entry = "fe80::1".parse().expect("valid IP entry");
+
+ assert_eq!(
+ entry,
+ Cidr::new_v6([0xFE80, 0, 0, 0, 0, 0, 0, 1], 128)
+ .unwrap()
+ .into()
+ );
+
+ entry = "fe80::1/48".parse().expect("valid IP entry");
+
+ assert_eq!(
+ entry,
+ Cidr::new_v6([0xFE80, 0, 0, 0, 0, 0, 0, 1], 48)
+ .unwrap()
+ .into()
+ );
+
+ entry = "fd80::1-fd80::ffff".parse().expect("valid IP entry");
+
+ assert_eq!(
+ entry,
+ IpEntry::Range(
+ [0xFD80, 0, 0, 0, 0, 0, 0, 1].into(),
+ [0xFD80, 0, 0, 0, 0, 0, 0, 0xFFFF].into(),
+ )
+ );
+
+ // rejected: start greater than end, mixed families, a CIDR as range
+ // end, more than two range elements and garbage
+ "192.168.100.0-192.168.99.255"
+ .parse::<IpEntry>()
+ .unwrap_err();
+ "192.168.100.0-fe80::1".parse::<IpEntry>().unwrap_err();
+ "192.168.100.0-192.168.200.0/16"
+ .parse::<IpEntry>()
+ .unwrap_err();
+ "192.168.100.0-192.168.200.0-192.168.250.0"
+ .parse::<IpEntry>()
+ .unwrap_err();
+ "qweasd".parse::<IpEntry>().unwrap_err();
+ }
+
+ // Parsing of comma-separated IP lists, including the family of the list.
+ #[test]
+ fn test_parse_ip_list() {
+ // plain address, CIDR and range may be mixed as long as the family matches
+ let mut ip_list: IpList = "192.168.0.1,192.168.100.0/24,172.16.0.0-172.32.255.255"
+ .parse()
+ .expect("valid IP list");
+
+ assert_eq!(
+ ip_list,
+ IpList {
+ entries: vec![
+ IpEntry::Cidr(Cidr::new_v4([192, 168, 0, 1], 32).unwrap()),
+ IpEntry::Cidr(Cidr::new_v4([192, 168, 100, 0], 24).unwrap()),
+ IpEntry::Range([172, 16, 0, 0].into(), [172, 32, 255, 255].into()),
+ ],
+ family: Family::V4,
+ }
+ );
+
+ // a single element is a valid list
+ ip_list = "fe80::1/64".parse().expect("valid IP list");
+
+ assert_eq!(
+ ip_list,
+ IpList {
+ entries: vec![IpEntry::Cidr(
+ Cidr::new_v6([0xFE80, 0, 0, 0, 0, 0, 0, 1], 64).unwrap()
+ ),],
+ family: Family::V6,
+ }
+ );
+
+ // mixing families is rejected
+ "192.168.0.1,fe80::1".parse::<IpList>().unwrap_err();
+
+ // empty input and garbage are rejected
+ "".parse::<IpList>().unwrap_err();
+ "proxmox".parse::<IpList>().unwrap_err();
+ }
+
+ // Construction of IP lists through `IpList::new`.
+ #[test]
+ fn test_construct_ip_list() {
+ // the family of the list follows the family of its entries
+ let mut ip_list = IpList::new(vec![Cidr::new_v4([10, 0, 0, 0], 8).unwrap().into()])
+ .expect("valid ip list");
+
+ assert_eq!(ip_list.family(), Family::V4);
+
+ ip_list =
+ IpList::new(vec![Cidr::new_v6([0x000; 8], 8).unwrap().into()]).expect("valid ip list");
+
+ assert_eq!(ip_list.family(), Family::V6);
+
+ // an empty list has no family and is therefore rejected
+ IpList::new(vec![]).expect_err("empty ip list is invalid");
+
+ IpList::new(vec![
+ Cidr::new_v4([10, 0, 0, 0], 8).unwrap().into(),
+ Cidr::new_v6([0x0000; 8], 8).unwrap().into(),
+ ])
+ .expect_err("cannot mix ip families in ip list");
+ }
+}
diff --git a/proxmox-ve-config/src/firewall/types/mod.rs b/proxmox-ve-config/src/firewall/types/mod.rs
new file mode 100644
index 0000000..de534b4
--- /dev/null
+++ b/proxmox-ve-config/src/firewall/types/mod.rs
@@ -0,0 +1,3 @@
+pub mod address;
+
+pub use address::Cidr;
diff --git a/proxmox-ve-config/src/lib.rs b/proxmox-ve-config/src/lib.rs
index e69de29..a0734b8 100644
--- a/proxmox-ve-config/src/lib.rs
+++ b/proxmox-ve-config/src/lib.rs
@@ -0,0 +1 @@
+pub mod firewall;
--
2.39.2
More information about the pve-devel
mailing list