[pve-devel] [PATCH v3 conntrack-tool 1/4] initial commit

Mira Limbeck m.limbeck at proxmox.com
Tue Feb 16 17:56:39 CET 2021


Dumping conntrack information and importing conntrack information works
for IPv4 and IPv6. No filtering is supported for now. pve-conntrack-tool
will always return both IPv4 and IPv6 conntracks together.

Conntracks are serialized as JSON and printed to STDOUT, one conntrack per
line. When inserting, data is read from STDIN line by line, and each line is
expected to be one JSON object representing a conntrack.

Currently some conntrack attributes are not supported. These are
HELPER_INFO, CONNLABELS and CONNLABELS_MASK. The reason for this is that
handling of variable length attributes does not seem to be correctly
implemented in libnetfilter_conntrack. To fix this we would probably have
to use libmnl directly.

Conntracks containing protonum 2 (IGMP) are ignored in the dump as
they can't be inserted using libnetfilter_conntrack (conntrack-tools'
conntrack also exhibits the same behavior).

Signed-off-by: Mira Limbeck <m.limbeck at proxmox.com>
---
v3:
 - split the functionality out of Socket into separate files
 - fixed MNL_SOCKET_BUFFER_SIZE to check configured page size at runtime
   (added lazy_static dependency)
 - changed is_ipv6 function to check the attribute keys instead of values
 - merged v2 patch 5 into patches 1 and 3 and fixed calling the closure in
   the callback
v2:
 - changed Conntracks to Socket
 - reworked a lot of the code for less code duplication
 - reduced usage of 'unsafe'
 - added/changed things based on @Wobu's suggestions (off-list)

 Cargo.toml                 |  15 ++
 src/conntrack.rs           | 338 +++++++++++++++++++++++++++++++++++++
 src/main.rs                |  67 ++++++++
 src/mnl.rs                 | 142 ++++++++++++++++
 src/netfilter_conntrack.rs | 168 ++++++++++++++++++
 src/socket.rs              | 104 ++++++++++++
 src/utils.rs               |  26 +++
 7 files changed, 860 insertions(+)
 create mode 100644 Cargo.toml
 create mode 100644 src/conntrack.rs
 create mode 100644 src/main.rs
 create mode 100644 src/mnl.rs
 create mode 100644 src/netfilter_conntrack.rs
 create mode 100644 src/socket.rs
 create mode 100644 src/utils.rs

diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..4936ec7
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "pve-conntrack-tool"
+version = "1.0.0"
+authors = ["Mira Limbeck <m.limbeck at proxmox.com>"]
+edition = "2018"
+license = "AGPL-3"
+
+exclude = [ "build", "debian" ]
+
+[dependencies]
+anyhow = "1.0.26"
+lazy_static = "1.4.0"
+libc = "0.2.79"
+serde = { version = "1.0.106", features = ["derive"] }
+serde_json = "1.0.41"
diff --git a/src/conntrack.rs b/src/conntrack.rs
new file mode 100644
index 0000000..6abd4a5
--- /dev/null
+++ b/src/conntrack.rs
@@ -0,0 +1,338 @@
+use crate::mnl::{IPCTNL_MSG_CT_GET, IPCTNL_MSG_CT_NEW, MNL_SOCKET_BUFFER_SIZE};
+use crate::netfilter_conntrack::{
+    nf_conntrack, nfct_attr_is_set, nfct_destroy, nfct_get_attr, nfct_get_attr_u16,
+    nfct_get_attr_u32, nfct_get_attr_u64, nfct_get_attr_u8, nfct_new, nfct_nlmsg_build,
+    nfct_nlmsg_parse, nfct_set_attr, nfct_set_attr_l, nfct_set_attr_u16, nfct_set_attr_u32,
+    nfct_set_attr_u64, nfct_set_attr_u8, CTAttr,
+};
+use crate::socket::Socket;
+use crate::utils::build_msg_header;
+
+use anyhow::{bail, Result};
+use serde::{Deserialize, Serialize};
+
+use std::convert::TryInto;
+use std::ffi::CString;
+
+// nlmsg_type values for ctnetlink requests: the nfnetlink subsystem id goes
+// into the upper byte, the ctnetlink message type into the lower byte.
+
+// Dump request for all conntracks (GET with NLM_F_DUMP, ACK requested).
+const CONNTRACK_QUERY_MSG_TYPE: u16 =
+    ((libc::NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_GET) as u16;
+const CONNTRACK_QUERY_FLAGS: u16 =
+    (libc::NLM_F_ACK | libc::NLM_F_REQUEST | libc::NLM_F_DUMP) as u16;
+// Insert request for a single conntrack (NEW with NLM_F_CREATE, ACK requested).
+const CONNTRACK_INSERT_MSG_TYPE: u16 =
+    ((libc::NFNL_SUBSYS_CTNETLINK << 8) | IPCTNL_MSG_CT_NEW) as u16;
+const CONNTRACK_INSERT_FLAGS: u16 =
+    (libc::NLM_F_ACK | libc::NLM_F_REQUEST | libc::NLM_F_CREATE) as u16;
+
+/// How the value of a conntrack attribute is read from (`parse_nf_conntrack`)
+/// a libnetfilter_conntrack object; see `ALL_ATTRIBUTES` for the mapping.
+enum AttrType {
+    U8,
+    U16,
+    U32,
+    U64,
+    /// 16 bytes, handled as `[u32; 4]` (used for IPv6 addresses).
+    U128,
+    /// NUL-terminated string, truncated to the given length if one is set.
+    String(Option<u32>),
+    /// Variable length attribute; not supported yet and skipped on dump.
+    VarLen,
+}
+
+/// Value of a single conntrack attribute as stored in the JSON representation.
+#[derive(Debug, Serialize, Deserialize)]
+enum AttrValue {
+    U8(u8),
+    U16(u16),
+    U32(u32),
+    U64(u64),
+    U128([u32; 4]),
+    String(CString),
+}
+
+/// A single conntrack attribute. The attribute key is stored under `type`,
+/// and `value` is flattened into the same JSON object, so the object carries
+/// one additional key named after the `AttrValue` variant.
+#[derive(Debug, Serialize, Deserialize)]
+struct Attr {
+    #[serde(rename = "type")]
+    key: CTAttr,
+    #[serde(flatten)]
+    value: AttrValue,
+}
+
+/// One conntrack entry: the attributes that were set on it. `dump` prints one
+/// of these per line as JSON, and `insert` reads them back line by line.
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Conntrack {
+    attributes: Vec<Attr>,
+}
+
+impl Conntrack {
+    /// Whether this conntrack carries at least one IPv6-specific attribute
+    /// (any key listed in `IPV6_ATTRIBUTES`); only the keys are inspected,
+    /// never the values.
+    fn is_ipv6(&self) -> bool {
+        self.attributes
+            .iter()
+            .any(|entry| IPV6_ATTRIBUTES.contains(&entry.key))
+    }
+}
+
+/// Builds a libnetfilter_conntrack object from a deserialized `Conntrack`.
+///
+/// Every attribute value must use the representation `ALL_ATTRIBUTES` lists
+/// for its key. The setters below are handed a pointer to the value, and
+/// libnetfilter_conntrack presumably copies as many bytes as the attribute
+/// itself needs (verify against its setter table). A `U8` value given for an
+/// IPv6 address would then be read out of bounds, so such mismatches, as well
+/// as unknown or unsupported keys, are rejected with an error.
+///
+/// Returns the object, which the caller must free with `nfct_destroy`, and
+/// the `CString`s whose pointers were passed to `nfct_set_attr`. The caller
+/// must keep those strings alive for as long as the object is used.
+fn build_nf_conntrack(ct: Conntrack) -> Result<(*mut nf_conntrack, Vec<CString>)> {
+    let cth = unsafe { nfct_new() };
+    if cth.is_null() {
+        bail!("Failed to create new conntrack object");
+    }
+
+    let mut strings = Vec::new();
+    for attr in ct.attributes {
+        if !value_matches_type(attr.key, &attr.value) {
+            // don't leak the partially built object on invalid input
+            unsafe {
+                nfct_destroy(cth);
+            }
+            bail!("Invalid value type for conntrack attribute {:?}", attr.key);
+        }
+
+        match attr.value {
+            AttrValue::U8(v) => unsafe {
+                nfct_set_attr_u8(cth, attr.key, v);
+            },
+            AttrValue::U16(v) => unsafe {
+                nfct_set_attr_u16(cth, attr.key, v);
+            },
+            AttrValue::U32(v) => unsafe {
+                nfct_set_attr_u32(cth, attr.key, v);
+            },
+            AttrValue::U64(v) => unsafe {
+                nfct_set_attr_u64(cth, attr.key, v);
+            },
+            AttrValue::U128(v) => unsafe {
+                nfct_set_attr_l(cth, attr.key, v.as_ptr() as _, std::mem::size_of_val(&v));
+            },
+            AttrValue::String(v) => unsafe {
+                nfct_set_attr(cth, attr.key, v.as_ptr() as _);
+                // keep the string alive, `cth` may still refer to it
+                strings.push(v);
+            },
+        }
+    }
+    Ok((cth, strings))
+}
+
+/// Returns `true` if `value` has the representation `ALL_ATTRIBUTES` declares
+/// for `key`. Keys missing from the table and variable length attributes are
+/// never accepted.
+fn value_matches_type(key: CTAttr, value: &AttrValue) -> bool {
+    let ty = match ALL_ATTRIBUTES.iter().find(|(attr, _)| *attr == key) {
+        Some((_, ty)) => ty,
+        None => return false,
+    };
+    matches!(
+        (ty, value),
+        (AttrType::U8, AttrValue::U8(_))
+            | (AttrType::U16, AttrValue::U16(_))
+            | (AttrType::U32, AttrValue::U32(_))
+            | (AttrType::U64, AttrValue::U64(_))
+            | (AttrType::U128, AttrValue::U128(_))
+            | (AttrType::String(_), AttrValue::String(_))
+    )
+}
+
+/// Converts a parsed libnetfilter_conntrack object into a `Conntrack`.
+///
+/// Returns `None` for IGMP conntracks (protonum 2), as those can't be inserted
+/// again. The `ID` attribute and variable length attributes are not exported.
+fn parse_nf_conntrack(ct: *const nf_conntrack) -> Option<Conntrack> {
+    // The protocol check does not depend on the attribute being processed, so
+    // it is done once up front instead of on every loop iteration.
+    if unsafe { nfct_get_attr_u8(ct, CTAttr::ORIG_L4PROTO) } == 2 {
+        return None;
+    }
+
+    let mut attributes = Vec::new();
+    for (attr, ty) in ALL_ATTRIBUTES {
+        let attr = *attr;
+        // A negative return signals an error, so treat it like "not set"
+        // rather than reading a missing value.
+        if attr == CTAttr::ID || unsafe { nfct_attr_is_set(ct, attr) } <= 0 {
+            continue;
+        }
+
+        let value = match ty {
+            AttrType::U8 => AttrValue::U8(unsafe { nfct_get_attr_u8(ct, attr) }),
+            AttrType::U16 => AttrValue::U16(unsafe { nfct_get_attr_u16(ct, attr) }),
+            AttrType::U32 => AttrValue::U32(unsafe { nfct_get_attr_u32(ct, attr) }),
+            AttrType::U64 => AttrValue::U64(unsafe { nfct_get_attr_u64(ct, attr) }),
+            AttrType::U128 => {
+                // SAFETY: the attribute is set, so the pointer refers to its
+                // 16 byte value; read it without assuming any alignment.
+                let ptr = unsafe { nfct_get_attr(ct, attr) } as *const [u32; 4];
+                AttrValue::U128(unsafe { std::ptr::read_unaligned(ptr) })
+            }
+            AttrType::String(max_len) => {
+                // SAFETY: the attribute is set and stored NUL-terminated.
+                let ptr = unsafe { nfct_get_attr(ct, attr) };
+                let bytes = unsafe { std::ffi::CStr::from_ptr(ptr as _) }.to_bytes();
+                let len = match *max_len {
+                    Some(max) => bytes.len().min(max as usize),
+                    None => bytes.len(),
+                };
+                // bytes taken from a CStr never contain an interior NUL
+                let s = CString::new(&bytes[..len]).expect("C string without interior NUL");
+                AttrValue::String(s)
+            }
+            // variable length attributes are not supported yet
+            AttrType::VarLen => continue,
+        };
+        attributes.push(Attr { key: attr, value });
+    }
+
+    Some(Conntrack { attributes })
+}
+
+/// Dumps every conntrack known to the kernel, IPv4 entries first and IPv6
+/// entries second, each family using a fresh sequence number.
+pub fn query_all(socket: &mut Socket) -> Result<Vec<Conntrack>> {
+    let mut conntracks = Vec::new();
+    for family in &[libc::AF_INET, libc::AF_INET6] {
+        let request_seq = socket.seq();
+        query_impl(socket, &mut conntracks, request_seq, *family as _)?;
+    }
+    Ok(conntracks)
+}
+
+/// Dumps all conntracks of address family `proto` and appends them to `cts`.
+///
+/// `seq` is used both in the request header and for matching the replies in
+/// `Socket::send_and_receive`.
+fn query_impl(
+    socket: &mut Socket,
+    cts: &mut Vec<Conntrack>,
+    seq: u32,
+    proto: u8,
+) -> Result<()> {
+    let mut buf = vec![0u8; *MNL_SOCKET_BUFFER_SIZE as _];
+    let hdr = build_msg_header(
+        buf.as_mut_ptr() as _,
+        CONNTRACK_QUERY_MSG_TYPE,
+        CONNTRACK_QUERY_FLAGS,
+        seq,
+        proto,
+    );
+
+    // The callback can't return an error, so remember allocation failures and
+    // report them once all replies have been processed.
+    let mut alloc_failed = false;
+    let mut cb = |nlh| {
+        let ct = unsafe { nfct_new() };
+        if ct.is_null() {
+            alloc_failed = true;
+            return;
+        }
+
+        // Skip messages libnetfilter_conntrack fails to parse instead of
+        // exporting whatever ended up in `ct`.
+        if unsafe { nfct_nlmsg_parse(nlh, ct) } < 0 {
+            eprintln!("Failed to parse conntrack message, skipping it");
+        } else if let Some(conntrack) = parse_nf_conntrack(ct) {
+            cts.push(conntrack);
+        }
+
+        unsafe {
+            nfct_destroy(ct);
+        }
+    };
+    // pass the request's sequence number so replies are matched against it
+    socket.send_and_receive(hdr, seq, &mut cb)?;
+
+    if alloc_failed {
+        bail!("Failed to create new conntrack object");
+    }
+    Ok(())
+}
+
+/// Sends a ctnetlink NEW request (with NLM_F_CREATE) for `ct` and waits for
+/// the reply.
+///
+/// The request uses AF_INET6 if any IPv6 attribute is present, AF_INET
+/// otherwise.
+pub fn insert(socket: &mut Socket, ct: Conntrack) -> Result<()> {
+    let proto = if ct.is_ipv6() {
+        libc::AF_INET6 as u8
+    } else {
+        libc::AF_INET as u8
+    };
+    let seq = socket.seq();
+
+    let mut buf = vec![0u8; *MNL_SOCKET_BUFFER_SIZE as _];
+    let hdr = build_msg_header(
+        buf.as_mut_ptr() as _,
+        CONNTRACK_INSERT_MSG_TYPE,
+        CONNTRACK_INSERT_FLAGS,
+        seq,
+        proto,
+    );
+
+    // `_strings` keeps the string attribute values alive until the message
+    // has been built from `cth`.
+    let (cth, _strings) = build_nf_conntrack(ct)?;
+
+    let res = unsafe { nfct_nlmsg_build(hdr, cth) };
+    unsafe {
+        nfct_destroy(cth);
+    }
+    // don't send a bare header if building the attributes failed
+    if res < 0 {
+        bail!("Failed to build conntrack netlink message");
+    }
+
+    // pass the request's sequence number so the reply is matched against it
+    socket.send_and_receive(hdr, seq, &mut |_| {})
+}
+
+/// Every conntrack attribute together with the representation used to read
+/// it in `parse_nf_conntrack`. The entries follow the order of the `CTAttr`
+/// enum; `VarLen` attributes are currently not exported.
+const ALL_ATTRIBUTES: &[(CTAttr, AttrType)] = &[
+    (CTAttr::ORIG_IPV4_SRC, AttrType::U32),        /* u32 bits */
+    (CTAttr::ORIG_IPV4_DST, AttrType::U32),        /* u32 bits */
+    (CTAttr::REPL_IPV4_SRC, AttrType::U32),        /* u32 bits */
+    (CTAttr::REPL_IPV4_DST, AttrType::U32),        /* u32 bits */
+    (CTAttr::ORIG_IPV6_SRC, AttrType::U128),       /* u128 bits */
+    (CTAttr::ORIG_IPV6_DST, AttrType::U128),       /* u128 bits */
+    (CTAttr::REPL_IPV6_SRC, AttrType::U128),       /* u128 bits */
+    (CTAttr::REPL_IPV6_DST, AttrType::U128),       /* u128 bits */
+    (CTAttr::ORIG_PORT_SRC, AttrType::U16),        /* u16 bits */
+    (CTAttr::ORIG_PORT_DST, AttrType::U16),        /* u16 bits */
+    (CTAttr::REPL_PORT_SRC, AttrType::U16),        /* u16 bits */
+    (CTAttr::REPL_PORT_DST, AttrType::U16),        /* u16 bits */
+    (CTAttr::ICMP_TYPE, AttrType::U8),             /* u8 bits */
+    (CTAttr::ICMP_CODE, AttrType::U8),             /* u8 bits */
+    (CTAttr::ICMP_ID, AttrType::U16),              /* u16 bits */
+    (CTAttr::ORIG_L3PROTO, AttrType::U8),          /* u8 bits */
+    (CTAttr::REPL_L3PROTO, AttrType::U8),          /* u8 bits */
+    (CTAttr::ORIG_L4PROTO, AttrType::U8),          /* u8 bits */
+    (CTAttr::REPL_L4PROTO, AttrType::U8),          /* u8 bits */
+    (CTAttr::TCP_STATE, AttrType::U8),             /* u8 bits */
+    (CTAttr::SNAT_IPV4, AttrType::U32),            /* u32 bits */
+    (CTAttr::DNAT_IPV4, AttrType::U32),            /* u32 bits */
+    (CTAttr::SNAT_PORT, AttrType::U16),            /* u16 bits */
+    (CTAttr::DNAT_PORT, AttrType::U16),            /* u16 bits */
+    (CTAttr::TIMEOUT, AttrType::U32),              /* u32 bits */
+    (CTAttr::MARK, AttrType::U32),                 /* u32 bits */
+    (CTAttr::ORIG_COUNTER_PACKETS, AttrType::U64), /* u64 bits */
+    (CTAttr::REPL_COUNTER_PACKETS, AttrType::U64), /* u64 bits */
+    (CTAttr::ORIG_COUNTER_BYTES, AttrType::U64),   /* u64 bits */
+    (CTAttr::REPL_COUNTER_BYTES, AttrType::U64),   /* u64 bits */
+    (CTAttr::USE, AttrType::U32),                  /* u32 bits */
+    (CTAttr::ID, AttrType::U32),                   /* u32 bits */
+    (CTAttr::STATUS, AttrType::U32),               /* u32 bits  */
+    (CTAttr::TCP_FLAGS_ORIG, AttrType::U8),        /* u8 bits */
+    (CTAttr::TCP_FLAGS_REPL, AttrType::U8),        /* u8 bits */
+    (CTAttr::TCP_MASK_ORIG, AttrType::U8),         /* u8 bits */
+    (CTAttr::TCP_MASK_REPL, AttrType::U8),         /* u8 bits */
+    (CTAttr::MASTER_IPV4_SRC, AttrType::U32),      /* u32 bits */
+    (CTAttr::MASTER_IPV4_DST, AttrType::U32),      /* u32 bits */
+    (CTAttr::MASTER_IPV6_SRC, AttrType::U128),     /* u128 bits */
+    (CTAttr::MASTER_IPV6_DST, AttrType::U128),     /* u128 bits */
+    (CTAttr::MASTER_PORT_SRC, AttrType::U16),      /* u16 bits */
+    (CTAttr::MASTER_PORT_DST, AttrType::U16),      /* u16 bits */
+    (CTAttr::MASTER_L3PROTO, AttrType::U8),        /* u8 bits */
+    (CTAttr::MASTER_L4PROTO, AttrType::U8),        /* u8 bits */
+    (CTAttr::SECMARK, AttrType::U32),              /* u32 bits */
+    (CTAttr::ORIG_NAT_SEQ_CORRECTION_POS, AttrType::U32), /* u32 bits */
+    (CTAttr::ORIG_NAT_SEQ_OFFSET_BEFORE, AttrType::U32), /* u32 bits */
+    (CTAttr::ORIG_NAT_SEQ_OFFSET_AFTER, AttrType::U32), /* u32 bits */
+    (CTAttr::REPL_NAT_SEQ_CORRECTION_POS, AttrType::U32), /* u32 bits */
+    (CTAttr::REPL_NAT_SEQ_OFFSET_BEFORE, AttrType::U32), /* u32 bits */
+    (CTAttr::REPL_NAT_SEQ_OFFSET_AFTER, AttrType::U32), /* u32 bits */
+    (CTAttr::SCTP_STATE, AttrType::U8),            /* u8 bits */
+    (CTAttr::SCTP_VTAG_ORIG, AttrType::U32),       /* u32 bits */
+    (CTAttr::SCTP_VTAG_REPL, AttrType::U32),       /* u32 bits */
+    (CTAttr::HELPER_NAME, AttrType::String(Some(30))), /* string (30 bytes max) */
+    (CTAttr::DCCP_STATE, AttrType::U8),            /* u8 bits */
+    (CTAttr::DCCP_ROLE, AttrType::U8),             /* u8 bits */
+    (CTAttr::DCCP_HANDSHAKE_SEQ, AttrType::U64),   /* u64 bits */
+    (CTAttr::TCP_WSCALE_ORIG, AttrType::U8),       /* u8 bits */
+    (CTAttr::TCP_WSCALE_REPL, AttrType::U8),       /* u8 bits */
+    (CTAttr::ZONE, AttrType::U16),                 /* u16 bits */
+    (CTAttr::SECCTX, AttrType::String(None)),      /* string */
+    (CTAttr::TIMESTAMP_START, AttrType::U64),      /* u64 bits; linux >= 2.6.38 */
+    (CTAttr::TIMESTAMP_STOP, AttrType::U64),       /* u64 bits; linux >= 2.6.38 */
+    (CTAttr::HELPER_INFO, AttrType::VarLen),       /* variable length */
+    (CTAttr::CONNLABELS, AttrType::VarLen),        /* variable length */
+    (CTAttr::CONNLABELS_MASK, AttrType::VarLen),   /* variable length */
+    (CTAttr::ORIG_ZONE, AttrType::U16),            /* u16 bits */
+    (CTAttr::REPL_ZONE, AttrType::U16),            /* u16 bits */
+    (CTAttr::SNAT_IPV6, AttrType::U128),           /* u128 bits */
+    (CTAttr::DNAT_IPV6, AttrType::U128),           /* u128 bits */
+    (CTAttr::SYNPROXY_ISN, AttrType::U32),         /* u32 bits */
+    (CTAttr::SYNPROXY_ITS, AttrType::U32),         /* u32 bits */
+    (CTAttr::SYNPROXY_TSOFF, AttrType::U32),       /* u32 bits */
+];
+
+/// Attributes that only exist for IPv6 conntracks. `Conntrack::is_ipv6`
+/// checks for any of these keys to choose the address family used by
+/// `insert`.
+const IPV6_ATTRIBUTES: &[CTAttr] = &[
+    CTAttr::ORIG_IPV6_SRC,
+    CTAttr::ORIG_IPV6_DST,
+    CTAttr::REPL_IPV6_SRC,
+    CTAttr::REPL_IPV6_DST,
+    CTAttr::MASTER_IPV6_SRC,
+    CTAttr::MASTER_IPV6_DST,
+    CTAttr::SNAT_IPV6,
+    CTAttr::DNAT_IPV6,
+];
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..792d487
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,67 @@
+mod mnl;
+mod netfilter_conntrack;
+
+use std::io::{stdin, BufRead, BufReader};
+use std::os::unix::ffi::OsStringExt;
+
+use anyhow::{bail, format_err, Result};
+
+mod socket;
+mod conntrack;
+mod utils;
+
+use socket::Socket;
+use conntrack::Conntrack;
+
+/// Entry point: `pve-conntrack-tool dump` prints all conntracks as JSON, one
+/// per line, to STDOUT; `pve-conntrack-tool insert` reads that format from
+/// STDIN and inserts each conntrack.
+fn main() -> Result<()> {
+    let args = std::env::args_os()
+        .map(|os| String::from_utf8(os.into_vec()))
+        .collect::<std::result::Result<Vec<String>, _>>()
+        .map_err(|err| format_err!("Invalid UTF8 argument: {}", err))?;
+    if args.len() != 2 {
+        bail!("Either 'dump' or 'insert' command required.");
+    }
+
+    // Validate the command before opening the netlink socket.
+    match args[1].as_str() {
+        "dump" => {
+            use std::io::Write;
+
+            let mut socket = Socket::open()?;
+            let cts = conntrack::query_all(&mut socket)
+                .map_err(|err| format_err!("Error querying conntracks: {}", err))?;
+
+            // Lock STDOUT once instead of once per line; write errors (e.g. a
+            // closed pipe) are returned instead of panicking like `println!`.
+            let stdout = std::io::stdout();
+            let mut out = stdout.lock();
+            for ct in cts.iter() {
+                match serde_json::to_string(ct) {
+                    Ok(s) => writeln!(out, "{}", s)?,
+                    Err(err) => {
+                        eprintln!("Failed to serialize conntrack: {}", err);
+                        break;
+                    }
+                }
+            }
+        }
+        "insert" => {
+            let mut socket = Socket::open()?;
+            for line in BufReader::new(stdin()).lines() {
+                // Report read errors as such instead of turning them into an
+                // empty line that then fails to deserialize.
+                let line = match line {
+                    Ok(line) => line,
+                    Err(err) => {
+                        eprintln!("Failed to read conntrack from STDIN: {}", err);
+                        break;
+                    }
+                };
+                // A blank line (e.g. a trailing newline) must not abort the import.
+                if line.trim().is_empty() {
+                    continue;
+                }
+
+                let ct: Conntrack = match serde_json::from_str(&line) {
+                    Ok(ct) => ct,
+                    Err(err) => {
+                        eprintln!("Failed to deserialize conntrack: {}", err);
+                        break;
+                    }
+                };
+                if let Err(err) = conntrack::insert(&mut socket, ct) {
+                    eprintln!("Error inserting conntrack: {}", err);
+                }
+            }
+        }
+        other => bail!("Unknown command: {}", other),
+    }
+
+    Ok(())
+}
+
diff --git a/src/mnl.rs b/src/mnl.rs
new file mode 100644
index 0000000..737f59a
--- /dev/null
+++ b/src/mnl.rs
@@ -0,0 +1,142 @@
+#![allow(dead_code)]
+use lazy_static::lazy_static;
+
+// Port id argument for `mnl_socket_bind` requesting an automatically chosen id.
+pub const MNL_SOCKET_AUTOPID: libc::c_int = 0;
+// Size of the receive buffer used in `Socket::send_and_receive`.
+pub const MNL_SOCKET_DUMP_SIZE: libc::c_int = 32768;
+lazy_static! {
+    // Runtime counterpart of libmnl's `MNL_SOCKET_BUFFER_SIZE`: the system page
+    // size, capped at 8192 bytes. Used as the size of request buffers. The
+    // process exits if the page size can't be determined.
+    pub static ref MNL_SOCKET_BUFFER_SIZE: libc::c_long = {
+        let pagesize = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
+        if pagesize == -1 {
+            eprintln!("Error retrieving page size: {}", std::io::Error::last_os_error());
+            std::process::exit(1);
+        }
+        std::cmp::min(pagesize, 8192)
+    };
+}
+
+// Return values of `mnl_cb_run` and of its callbacks.
+pub const MNL_CB_ERROR: libc::c_int = -1;
+pub const MNL_CB_STOP: libc::c_int = 0;
+pub const MNL_CB_OK: libc::c_int = 1;
+
+/// Opaque libmnl socket handle; only ever used behind a pointer.
+#[repr(C)]
+pub struct mnl_socket {
+    _private: [u8; 0],
+}
+
+// ctnetlink conntrack message types (low byte of nlmsg_type). The values must
+// match the C definitions, presumably `enum cntl_msg_types` in
+// <linux/netfilter/nfnetlink_conntrack.h> (TODO confirm).
+pub const IPCTNL_MSG_CT_NEW: libc::c_int = 0;
+pub const IPCTNL_MSG_CT_GET: libc::c_int = 1;
+pub const IPCTNL_MSG_CT_DELETE: libc::c_int = 2;
+pub const IPCTNL_MSG_CT_GET_CTRZERO: libc::c_int = 3;
+pub const IPCTNL_MSG_CT_GET_STATS_CPU: libc::c_int = 4;
+pub const IPCTNL_MSG_CT_GET_STATS: libc::c_int = 5;
+pub const IPCTNL_MSG_CT_GET_DYING: libc::c_int = 6;
+pub const IPCTNL_MSG_CT_GET_UNCONFIRMED: libc::c_int = 7;
+pub const IPCTNL_MSG_MAX: libc::c_int = 8;
+
+// ctnetlink expectation message types; currently unused by this tool.
+pub const IPCTNL_MSG_EXP_NEW: libc::c_int = 0;
+pub const IPCTNL_MSG_EXP_GET: libc::c_int = 1;
+pub const IPCTNL_MSG_EXP_DELETE: libc::c_int = 2;
+pub const IPCTNL_MSG_EXP_GET_STATS_CPU: libc::c_int = 3;
+pub const IPCTNL_MSG_EXP_MAX: libc::c_int = 4;
+
+// FFI declarations for the parts of libmnl used (or kept for later use) by
+// this tool. They must stay ABI-compatible with the C prototypes.
+// NOTE(review): `mnl_socket_bind` takes a non-const `struct mnl_socket *` in
+// C; `*const` here is ABI-compatible but could be aligned for clarity.
+#[link(name = "mnl")]
+extern "C" {
+    // socket handling
+    pub fn mnl_socket_open(bus: libc::c_int) -> *mut mnl_socket;
+    pub fn mnl_socket_bind(
+        nl: *const mnl_socket,
+        groups: libc::c_uint,
+        pid: libc::pid_t,
+    ) -> libc::c_int;
+    pub fn mnl_socket_close(nl: *mut mnl_socket) -> libc::c_int;
+    pub fn mnl_socket_get_portid(nl: *mut mnl_socket) -> libc::c_uint;
+    pub fn mnl_socket_sendto(
+        nl: *mut mnl_socket,
+        buf: *const libc::c_void,
+        len: libc::size_t,
+    ) -> libc::ssize_t;
+    pub fn mnl_socket_recvfrom(
+        nl: *mut mnl_socket,
+        buf: *mut libc::c_void,
+        bufsiz: libc::size_t,
+    ) -> libc::ssize_t;
+
+    // message construction
+    pub fn mnl_nlmsg_put_header(buf: *mut libc::c_void) -> *mut libc::nlmsghdr;
+    pub fn mnl_nlmsg_put_extra_header(
+        nlh: *mut libc::nlmsghdr,
+        size: libc::size_t,
+    ) -> *mut libc::c_void;
+    pub fn mnl_nlmsg_get_payload(nlh: *const libc::nlmsghdr) -> *mut libc::c_void;
+
+    // runs `cb_data` for the messages in `buf`, see `Socket::send_and_receive`
+    pub fn mnl_cb_run(
+        buf: *const libc::c_void,
+        numbytes: libc::size_t,
+        seq: libc::c_uint,
+        portid: libc::c_uint,
+        cb_data: Option<mnl_cb_t>,
+        data: *mut libc::c_void,
+    ) -> libc::c_int;
+
+    // attribute parsing
+    pub fn mnl_attr_parse(
+        nlh: *const libc::nlmsghdr,
+        offset: libc::c_uint,
+        cb: mnl_attr_cb_t,
+        data: *mut libc::c_void,
+    ) -> libc::c_int;
+    pub fn mnl_attr_get_type(attr: *const libc::nlattr) -> u16;
+    pub fn mnl_attr_type_valid(attr: *const libc::nlattr, maxtype: u16) -> libc::c_int;
+
+    // message batching
+    pub fn mnl_nlmsg_batch_start(buf: *mut libc::c_void, limit: libc::size_t) -> *mut mnl_nlmsg_batch;
+    pub fn mnl_nlmsg_batch_stop(b: *mut mnl_nlmsg_batch);
+    pub fn mnl_nlmsg_batch_next(b: *mut mnl_nlmsg_batch) -> bool;
+    pub fn mnl_nlmsg_batch_current(b: *mut mnl_nlmsg_batch) -> *mut libc::c_void;
+    pub fn mnl_nlmsg_batch_head(b: *mut mnl_nlmsg_batch) -> *mut libc::c_void;
+    pub fn mnl_nlmsg_batch_size(b: *mut mnl_nlmsg_batch) -> libc::size_t;
+}
+
+// C naming is kept for the libmnl callback types to mirror the C API.
+/// Per-message callback invoked by `mnl_cb_run`.
+#[allow(non_camel_case_types)]
+pub type mnl_cb_t = extern "C" fn(nlh: *const libc::nlmsghdr, data: *mut libc::c_void) -> libc::c_int;
+/// Per-attribute callback invoked by `mnl_attr_parse`.
+#[allow(non_camel_case_types)]
+pub type mnl_attr_cb_t =
+    extern "C" fn(attr: *const libc::nlattr, data: *mut libc::c_void) -> libc::c_int;
+
+/// nfnetlink header placed right after the netlink header (see
+/// `build_msg_header`), carrying the address family of the request.
+#[repr(C)]
+pub struct nfgenmsg {
+    pub nfgen_family: u8,
+    pub version: u8,
+    pub res_id: u16, // TODO any better solution for __be16?
+}
+
+// Top-level ctnetlink attribute types. The values must match the C
+// definitions, presumably `enum ctattr_type` in
+// <linux/netfilter/nfnetlink_conntrack.h> (TODO confirm).
+pub const CTA_UNSPEC: u16 = 0;
+pub const CTA_TUPLE_ORIG: u16 = 1;
+pub const CTA_TUPLE_REPLY: u16 = 2;
+pub const CTA_STATUS: u16 = 3;
+pub const CTA_PROTOINFO: u16 = 4;
+pub const CTA_HELP: u16 = 5;
+pub const CTA_NAT_SRC: u16 = 6;
+pub const CTA_NAT: u16 = CTA_NAT_SRC; // backwards compatibility
+pub const CTA_TIMEOUT: u16 = 7;
+pub const CTA_MARK: u16 = 8;
+pub const CTA_COUNTERS_ORIG: u16 = 9;
+pub const CTA_COUNTERS_REPLY: u16 = 10;
+pub const CTA_USE: u16 = 11;
+pub const CTA_ID: u16 = 12;
+pub const CTA_NAT_DST: u16 = 13;
+pub const CTA_TUPLE_MASTER: u16 = 14;
+pub const CTA_SEQ_ADJ_ORIG: u16 = 15;
+pub const CTA_NAT_SEQ_ADJ_ORIG: u16 = CTA_SEQ_ADJ_ORIG;
+pub const CTA_SEQ_ADJ_REPLY: u16 = 16;
+pub const CTA_NAT_SEQ_ADJ_REPLY: u16 = CTA_SEQ_ADJ_REPLY;
+pub const CTA_SECMARK: u16 = 17; // obsolete
+pub const CTA_ZONE: u16 = 18;
+pub const CTA_SECCTX: u16 = 19;
+pub const CTA_TIMESTAMP: u16 = 20;
+pub const CTA_MARK_MASK: u16 = 21;
+pub const CTA_LABELS: u16 = 22;
+pub const CTA_LABELS_MASK: u16 = 23;
+pub const CTA_SYNPROXY: u16 = 24;
+pub const CTA_MAX: u16 = CTA_SYNPROXY;
+
+/// Opaque libmnl message batch handle; only ever used behind a pointer.
+#[repr(C)]
+pub struct mnl_nlmsg_batch {
+    _private: [u8; 0],
+}
diff --git a/src/netfilter_conntrack.rs b/src/netfilter_conntrack.rs
new file mode 100644
index 0000000..a9e67e4
--- /dev/null
+++ b/src/netfilter_conntrack.rs
@@ -0,0 +1,168 @@
+#![allow(dead_code)]
+
+/// Opaque libnetfilter_conntrack conntrack object; only ever used behind a
+/// pointer obtained from `nfct_new` and released with `nfct_destroy`.
+#[repr(C)]
+pub struct nf_conntrack {
+    _private: [u8; 0],
+}
+
+// FFI declarations for the libnetfilter_conntrack functions used (or kept for
+// later use) by this tool. `CTAttr` is passed by value as the C attribute enum,
+// which relies on its `#[repr(u32)]` layout.
+#[link(name = "netfilter_conntrack")]
+extern "C" {
+    // object lifetime
+    pub fn nfct_new() -> *mut nf_conntrack;
+    pub fn nfct_destroy(ct: *mut nf_conntrack);
+
+    // conversion between netlink messages and conntrack objects
+    pub fn nfct_nlmsg_build(nlh: *mut libc::nlmsghdr, ct: *const nf_conntrack) -> libc::c_int;
+    pub fn nfct_nlmsg_parse(nlh: *const libc::nlmsghdr, ct: *mut nf_conntrack) -> libc::c_int;
+
+    pub fn nfct_snprintf(
+        buf: *mut libc::c_char,
+        size: libc::c_uint,
+        ct: *const nf_conntrack,
+        msg_type: libc::c_uint,
+        out_type: libc::c_uint,
+        out_flags: libc::c_uint,
+    ) -> libc::c_int;
+
+    pub fn nfct_setobjopt(ct: *mut nf_conntrack, option: libc::c_uint) -> libc::c_int;
+    pub fn nfct_getobjopt(ct: *const nf_conntrack, option: libc::c_uint) -> libc::c_int;
+
+    // attribute setters
+    pub fn nfct_set_attr(ct: *mut nf_conntrack, type_: CTAttr, value: *const libc::c_void);
+    pub fn nfct_set_attr_u8(ct: *mut nf_conntrack, type_: CTAttr, value: u8);
+    pub fn nfct_set_attr_u16(ct: *mut nf_conntrack, type_: CTAttr, value: u16);
+    pub fn nfct_set_attr_u32(ct: *mut nf_conntrack, type_: CTAttr, value: u32);
+    pub fn nfct_set_attr_u64(ct: *mut nf_conntrack, type_: CTAttr, value: u64);
+    pub fn nfct_set_attr_l(ct: *mut nf_conntrack, type_: CTAttr, value: *const libc::c_void, len: libc::size_t);
+
+    // attribute getters
+    pub fn nfct_get_attr(ct: *const nf_conntrack, type_: CTAttr) -> *const libc::c_void;
+    pub fn nfct_get_attr_u8(ct: *const nf_conntrack, type_: CTAttr) -> u8;
+    pub fn nfct_get_attr_u16(ct: *const nf_conntrack, type_: CTAttr) -> u16;
+    pub fn nfct_get_attr_u32(ct: *const nf_conntrack, type_: CTAttr) -> u32;
+    pub fn nfct_get_attr_u64(ct: *const nf_conntrack, type_: CTAttr) -> u64;
+
+    pub fn nfct_attr_is_set(ct: *const nf_conntrack, type_: CTAttr) -> libc::c_int;
+}
+
+// Option and flag values for `nfct_setobjopt`, `nfct_getobjopt` and
+// `nfct_snprintf`. They presumably mirror libnetfilter_conntrack's public
+// header and must match the C values (TODO confirm); unused by this tool so far.
+
+// set option
+pub const NFCT_SOPT_UNDO_SNAT: u32 = 0;
+pub const NFCT_SOPT_UNDO_DNAT: u32 = 1;
+pub const NFCT_SOPT_UNDO_SPAT: u32 = 2;
+pub const NFCT_SOPT_UNDO_DPAT: u32 = 3;
+pub const NFCT_SOPT_SETUP_ORIGINAL: u32 = 4;
+pub const NFCT_SOPT_SETUP_REPLY: u32 = 5;
+
+// get option
+pub const NFCT_GOPT_IS_SNAT: u32 = 0;
+pub const NFCT_GOPT_IS_DNAT: u32 = 1;
+pub const NFCT_GOPT_IS_SPAT: u32 = 2;
+pub const NFCT_GOPT_IS_DPAT: u32 = 3;
+
+// output type
+pub const NFCT_O_PLAIN: u32 = 0;
+pub const NFCT_O_DEFAULT: u32 = NFCT_O_PLAIN;
+pub const NFCT_O_XML: u32 = 1;
+pub const NFCT_O_MAX: u32 = 2;
+
+// output flags
+pub const NFCT_OF_SHOW_LAYER3_BIT: u32 = 0;
+pub const NFCT_OF_SHOW_LAYER3: u32 = 1 << NFCT_OF_SHOW_LAYER3_BIT;
+pub const NFCT_OF_TIME_BIT: u32 = 1;
+pub const NFCT_OF_TIME: u32 = 1 << NFCT_OF_TIME_BIT;
+pub const NFCT_OF_ID_BIT: u32 = 2;
+pub const NFCT_OF_ID: u32 = 1 << NFCT_OF_ID_BIT;
+pub const NFCT_OF_TIMESTAMP_BIT: u32 = 3;
+pub const NFCT_OF_TIMESTAMP: u32 = 1 << NFCT_OF_TIMESTAMP_BIT;
+
+// message type
+pub const NFCT_T_UNKNOWN: u32 = 0;
+pub const NFCT_T_NEW_BIT: u32 = 0;
+pub const NFCT_T_NEW: u32 = 1 << NFCT_T_NEW_BIT;
+pub const NFCT_T_UPDATE_BIT: u32 = 1;
+pub const NFCT_T_UPDATE: u32 = 1 << NFCT_T_UPDATE_BIT;
+pub const NFCT_T_DESTROY_BIT: u32 = 2;
+pub const NFCT_T_DESTROY: u32 = 1 << NFCT_T_DESTROY_BIT;
+pub const NFCT_T_ALL: u32 = NFCT_T_NEW | NFCT_T_UPDATE | NFCT_T_DESTROY;
+pub const NFCT_T_ERROR_BIT: u32 = 31;
+pub const NFCT_T_ERROR: u32 = 1 << NFCT_T_ERROR_BIT;
+
+/// libnetfilter_conntrack attribute identifiers.
+///
+/// Passed by value to the `nfct_*_attr*` functions, so the `#[repr(u32)]`
+/// discriminants must match the C attribute enum exactly. The variant names
+/// are also the values of the `type` key in the JSON output/input.
+#[repr(u32)]
+#[non_exhaustive]
+#[derive(Debug, Copy, Clone, PartialEq)]
+#[derive(serde::Deserialize, serde::Serialize)]
+#[allow(non_camel_case_types)]
+pub enum CTAttr {
+    ORIG_IPV4_SRC = 0,			/* u32 bits */
+    ORIG_IPV4_DST = 1,			/* u32 bits */
+    REPL_IPV4_SRC = 2,		        /* u32 bits */
+    REPL_IPV4_DST = 3,			/* u32 bits */
+    ORIG_IPV6_SRC = 4,			/* u128 bits */
+    ORIG_IPV6_DST = 5,			/* u128 bits */
+    REPL_IPV6_SRC = 6,			/* u128 bits */
+    REPL_IPV6_DST = 7,			/* u128 bits */
+    ORIG_PORT_SRC = 8,			/* u16 bits */
+    ORIG_PORT_DST = 9,			/* u16 bits */
+    REPL_PORT_SRC = 10,			/* u16 bits */
+    REPL_PORT_DST = 11,			/* u16 bits */
+    ICMP_TYPE = 12,			/* u8 bits */
+    ICMP_CODE = 13,			/* u8 bits */
+    ICMP_ID = 14,			/* u16 bits */
+    ORIG_L3PROTO = 15,			/* u8 bits */
+    REPL_L3PROTO = 16,			/* u8 bits */
+    ORIG_L4PROTO = 17,			/* u8 bits */
+    REPL_L4PROTO = 18,			/* u8 bits */
+    TCP_STATE = 19,			/* u8 bits */
+    SNAT_IPV4 = 20,			/* u32 bits */
+    DNAT_IPV4 = 21,			/* u32 bits */
+    SNAT_PORT = 22,			/* u16 bits */
+    DNAT_PORT = 23,			/* u16 bits */
+    TIMEOUT = 24,			/* u32 bits */
+    MARK = 25,				/* u32 bits */
+    ORIG_COUNTER_PACKETS = 26,		/* u64 bits */
+    REPL_COUNTER_PACKETS = 27,		/* u64 bits */
+    ORIG_COUNTER_BYTES = 28,		/* u64 bits */
+    REPL_COUNTER_BYTES = 29,		/* u64 bits */
+    USE = 30,				/* u32 bits */
+    ID = 31,				/* u32 bits */
+    STATUS = 32,			/* u32 bits  */
+    TCP_FLAGS_ORIG = 33,		/* u8 bits */
+    TCP_FLAGS_REPL = 34,		/* u8 bits */
+    TCP_MASK_ORIG = 35,			/* u8 bits */
+    TCP_MASK_REPL = 36,		        /* u8 bits */
+    MASTER_IPV4_SRC = 37,		/* u32 bits */
+    MASTER_IPV4_DST = 38,		/* u32 bits */
+    MASTER_IPV6_SRC = 39,		/* u128 bits */
+    MASTER_IPV6_DST = 40,		/* u128 bits */
+    MASTER_PORT_SRC = 41,		/* u16 bits */
+    MASTER_PORT_DST = 42,		/* u16 bits */
+    MASTER_L3PROTO = 43,		/* u8 bits */
+    MASTER_L4PROTO = 44,		/* u8 bits */
+    SECMARK = 45,			/* u32 bits */
+    ORIG_NAT_SEQ_CORRECTION_POS = 46,	/* u32 bits */
+    ORIG_NAT_SEQ_OFFSET_BEFORE = 47,	/* u32 bits */
+    ORIG_NAT_SEQ_OFFSET_AFTER = 48,	/* u32 bits */
+    REPL_NAT_SEQ_CORRECTION_POS = 49,	/* u32 bits */
+    REPL_NAT_SEQ_OFFSET_BEFORE = 50,	/* u32 bits */
+    REPL_NAT_SEQ_OFFSET_AFTER = 51,	/* u32 bits */
+    SCTP_STATE = 52,			/* u8 bits */
+    SCTP_VTAG_ORIG = 53,		/* u32 bits */
+    SCTP_VTAG_REPL = 54,		/* u32 bits */
+    HELPER_NAME = 55,			/* string (30 bytes max) */
+    DCCP_STATE = 56,			/* u8 bits */
+    DCCP_ROLE = 57,			/* u8 bits */
+    DCCP_HANDSHAKE_SEQ = 58,		/* u64 bits */
+    TCP_WSCALE_ORIG = 59,		/* u8 bits */
+    TCP_WSCALE_REPL = 60,		/* u8 bits */
+    ZONE = 61,				/* u16 bits */
+    SECCTX = 62,			/* string */
+    TIMESTAMP_START = 63,		/* u64 bits, linux >= 2.6.38 */
+    TIMESTAMP_STOP = 64,		/* u64 bits, linux >= 2.6.38 */
+    HELPER_INFO = 65,			/* variable length */
+    CONNLABELS = 66,			/* variable length */
+    CONNLABELS_MASK = 67,		/* variable length */
+    ORIG_ZONE = 68,			/* u16 bits */
+    REPL_ZONE = 69,			/* u16 bits */
+    SNAT_IPV6 = 70,			/* u128 bits */
+    DNAT_IPV6 = 71,			/* u128 bits */
+    SYNPROXY_ISN = 72,			/* u32 bits */
+    SYNPROXY_ITS = 73,			/* u32 bits */
+    SYNPROXY_TSOFF = 74,		/* u32 bits */
+    // number of attributes, not a real attribute
+    MAX = 75,
+}
diff --git a/src/socket.rs b/src/socket.rs
new file mode 100644
index 0000000..0167c1e
--- /dev/null
+++ b/src/socket.rs
@@ -0,0 +1,104 @@
+use crate::mnl::{
+    mnl_cb_run, mnl_socket, mnl_socket_bind, mnl_socket_close, mnl_socket_get_portid,
+    mnl_socket_open, mnl_socket_recvfrom, mnl_socket_sendto, MNL_CB_OK, MNL_CB_STOP,
+    MNL_SOCKET_AUTOPID, MNL_SOCKET_DUMP_SIZE,
+};
+
+use anyhow::{bail, Result};
+
+use std::ptr::NonNull;
+
+/// Owned libmnl netlink socket together with the sequence counter used for
+/// requests sent over it. The socket is closed when the value is dropped.
+pub struct Socket {
+    socket: NonNull<mnl_socket>,
+    // last sequence number handed out by `seq()`
+    seq: u32,
+}
+
+impl Socket {
+    /// Opens a `NETLINK_NETFILTER` socket through libmnl and binds it with an
+    /// automatically assigned port id (`MNL_SOCKET_AUTOPID`).
+    pub fn open() -> Result<Self> {
+        let socket = unsafe { mnl_socket_open(libc::NETLINK_NETFILTER) };
+        let socket = match NonNull::new(socket) {
+            Some(s) => s,
+            None => {
+                let err = std::io::Error::last_os_error();
+                bail!("Failed to open MNL socket: {}", err)
+            }
+        };
+        // Take ownership right away so `Drop` closes the socket on every error
+        // path below; a failed bind must not leak it.
+        let this = Self { socket, seq: 0 };
+
+        let res = unsafe { mnl_socket_bind(this.socket.as_ptr(), 0, MNL_SOCKET_AUTOPID) };
+        if res < 0 {
+            // read errno before `this` is dropped, as closing may overwrite it
+            let err = std::io::Error::last_os_error();
+            bail!("Failed to bind MNL socket: {}", err);
+        }
+
+        Ok(this)
+    }
+
+    /// Increments and returns the sequence number for the next request.
+    pub fn seq(&mut self) -> u32 {
+        self.seq += 1;
+        self.seq
+    }
+
+    /// Sends `msg` and feeds every reply to `cb` until `mnl_cb_run` returns
+    /// `MNL_CB_STOP` (or lower).
+    ///
+    /// `msg` must point to a complete message whose `nlmsg_len` is set.
+    /// `seq` is passed on to `mnl_cb_run` together with the socket's port id.
+    /// `cb` is invoked from an `extern "C"` trampoline and therefore must not
+    /// panic.
+    pub fn send_and_receive(
+        &mut self,
+        msg: *const libc::nlmsghdr,
+        seq: u32,
+        mut cb: &mut dyn FnMut(*const libc::nlmsghdr),
+    ) -> Result<()> {
+        let res =
+            unsafe { mnl_socket_sendto(self.socket.as_ptr(), msg as _, (*msg).nlmsg_len as _) };
+        if res == -1 {
+            let err = std::io::Error::last_os_error();
+            bail!("Failed to send message: {}", err);
+        }
+
+        let portid = unsafe { mnl_socket_get_portid(self.socket.as_ptr()) };
+        let mut buffer = [0u8; MNL_SOCKET_DUMP_SIZE as _];
+
+        loop {
+            let res = unsafe {
+                mnl_socket_recvfrom(
+                    self.socket.as_ptr(),
+                    buffer.as_mut_ptr() as _,
+                    MNL_SOCKET_DUMP_SIZE as _,
+                )
+            };
+            if res == -1 {
+                let err = std::io::Error::last_os_error();
+                bail!("Failed to read message: {}", err);
+            }
+
+            // `callback` recovers `cb` from the data pointer (a pointer to the
+            // `&mut dyn FnMut` fat reference).
+            let res = unsafe {
+                mnl_cb_run(
+                    buffer.as_ptr() as _,
+                    res as _,
+                    seq,
+                    portid,
+                    Some(callback),
+                    &mut cb as *mut _ as _,
+                )
+            };
+            if res == -1 {
+                let err = std::io::Error::last_os_error();
+                bail!("Failed to run callback: {}", err);
+            } else if res <= MNL_CB_STOP {
+                break;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl Drop for Socket {
+    /// Closes the libmnl socket; a failure can only be reported, not handled.
+    fn drop(&mut self) {
+        let status = unsafe { mnl_socket_close(self.socket.as_ptr()) };
+        if status >= 0 {
+            return;
+        }
+        eprintln!("Error closing socket");
+    }
+}
+
+/// `mnl_cb_run` trampoline: forwards every message to the Rust closure whose
+/// reference `Socket::send_and_receive` passed as the data pointer, and
+/// always asks libmnl to continue.
+extern "C" fn callback(nlh: *const libc::nlmsghdr, data_ptr: *mut libc::c_void) -> libc::c_int {
+    // SAFETY: `data_ptr` is the `&mut cb` (a `*mut &mut dyn FnMut`) handed to
+    // `mnl_cb_run` in `Socket::send_and_receive`, valid for that call.
+    let handler: &mut &mut dyn FnMut(*const libc::nlmsghdr) =
+        unsafe { &mut *(data_ptr as *mut &mut dyn FnMut(*const libc::nlmsghdr)) };
+    handler(nlh);
+    MNL_CB_OK
+}
diff --git a/src/utils.rs b/src/utils.rs
new file mode 100644
index 0000000..f4d0853
--- /dev/null
+++ b/src/utils.rs
@@ -0,0 +1,26 @@
+use crate::mnl::{mnl_nlmsg_put_extra_header, mnl_nlmsg_put_header, nfgenmsg};
+
+/// Writes a netlink header followed by an nfnetlink (`nfgenmsg`) header at
+/// the start of `buf` and returns a pointer to the netlink header.
+///
+/// `buf` must be large enough for both headers plus any attributes appended
+/// later; `proto` becomes the nfnetlink address family.
+pub fn build_msg_header(
+    buf: *mut libc::c_void,
+    ty: u16,
+    flags: u16,
+    seq: u32,
+    proto: u8,
+) -> *mut libc::nlmsghdr {
+    // SAFETY: the caller provides a writable buffer big enough for the headers.
+    unsafe {
+        let header = mnl_nlmsg_put_header(buf);
+        (*header).nlmsg_type = ty;
+        (*header).nlmsg_flags = flags;
+        (*header).nlmsg_seq = seq;
+
+        let extra = mnl_nlmsg_put_extra_header(header, std::mem::size_of::<nfgenmsg>());
+        let nfgen = extra as *mut nfgenmsg;
+        (*nfgen).nfgen_family = proto;
+        (*nfgen).version = libc::NFNETLINK_V0 as _;
+        (*nfgen).res_id = 0;
+
+        header
+    }
+}
-- 
2.20.1






More information about the pve-devel mailing list