[pbs-devel] [PATCH proxmox 1/2] add tools/zero: add fast zero comparison code

Dominik Csapak d.csapak at proxmox.com
Fri Dec 11 13:08:57 CET 2020


that can make use of SSE/AVX instructions where available

this is mostly a direct translation of qemu's util/bufferiszero.c

this is originally from Wolfgang Bumiller

Signed-off-by: Dominik Csapak <d.csapak at proxmox.com>
---
 proxmox/src/tools/mod.rs  |   1 +
 proxmox/src/tools/zero.rs | 233 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 234 insertions(+)
 create mode 100644 proxmox/src/tools/zero.rs

diff --git a/proxmox/src/tools/mod.rs b/proxmox/src/tools/mod.rs
index ff3a720..49418a4 100644
--- a/proxmox/src/tools/mod.rs
+++ b/proxmox/src/tools/mod.rs
@@ -20,6 +20,7 @@ pub mod serde;
 pub mod time;
 pub mod uuid;
 pub mod vec;
+pub mod zero;
 
 #[cfg(feature = "websocket")]
 pub mod websocket;
diff --git a/proxmox/src/tools/zero.rs b/proxmox/src/tools/zero.rs
new file mode 100644
index 0000000..7493262
--- /dev/null
+++ b/proxmox/src/tools/zero.rs
@@ -0,0 +1,233 @@
+#[cfg(test)]
+mod test {
+    /// Shared driver for the zero-detection tests.
+    ///
+    /// `func` is the implementation under test. It is first given an all-zero
+    /// 512 byte buffer, then the same buffer with exactly one byte set to `1`,
+    /// once for every byte position.
+    ///
+    /// When `short` is `true` the prefixes of length 1 to 8 are checked as
+    /// well, each time with the non-zero byte in the last position of the
+    /// prefix. Implementations that cannot handle such short inputs pass
+    /// `false`.
+    pub(super) fn do_zero_test(func: fn(&[u8]) -> bool, short: bool) {
+        // An array literal is zero-initialised without any `unsafe`.
+        let mut buf = [0u8; 512];
+        assert!(func(&buf), "all-zero buffer reported as non-zero");
+
+        for i in 0..buf.len() {
+            buf[i] = 1;
+            assert!(!func(&buf), "non-zero byte at offset {} not detected", i);
+            buf[i] = 0;
+        }
+
+        if short {
+            for i in 0..8 {
+                assert!(
+                    func(&buf[..=i]),
+                    "all-zero prefix of {} bytes reported as non-zero",
+                    i + 1
+                );
+                buf[i] = 1;
+                assert!(
+                    !func(&buf[..=i]),
+                    "non-zero last byte of {} byte prefix not detected",
+                    i + 1
+                );
+                buf[i] = 0;
+            }
+        }
+    }
+}
+
+/// x86_64 SIMD implementations of the zero check and the run-time selection
+/// between them.
+///
+/// The three routines are translations of their counterparts in qemu's
+/// `util/bufferiszero.c`. They read the head and the tail of the buffer with
+/// fixed-size unaligned loads and therefore must only be called with slices
+/// of at least `MIN_LEN` bytes.
+///
+/// NOTE(review): the routines are not annotated with
+/// `#[target_feature(enable = "...")]`, so the compiler may be unable to
+/// inline the intrinsics - worth confirming in the generated code.
+#[cfg(target_arch = "x86_64")]
+mod x86_64 {
+    use std::arch::x86_64::*;
+
+    /// Smallest slice length the vector routines may be called with.
+    const MIN_LEN: usize = 64;
+
+    // Direct translation of buffer_zero_sse2() of qemu's util/bufferiszero.c
+    fn buffer_is_zero_sse2(buf_slice: &[u8]) -> bool {
+        let len = buf_slice.len();
+        debug_assert!(len >= MIN_LEN, "SSE2 zero check needs at least 64 bytes");
+        let buf = buf_slice.as_ptr();
+        // SAFETY: with `len >= MIN_LEN` every load below stays inside
+        // `buf_slice`: `p` and `e` are 16-byte aligned addresses within the
+        // slice, the loop only reads the 64 bytes in front of `p` while
+        // `p <= e`, and the head/tail loads are explicitly unaligned.
+        unsafe {
+            // Unaligned head of 16 bytes.
+            let mut t = _mm_loadu_si128(buf as *const __m128i);
+            let mut p = ((buf as usize + 5*16) & !0xf) as *const __m128i;
+            let e = ((buf as usize + len) & !0xf) as *const __m128i;
+            let zero: __m128i = _mm_setzero_si128();
+            // Loop over 16 byte aligned blocks of 64. Each iteration tests the
+            // data gathered by the previous one, then gathers the next block.
+            while p <= e {
+                _mm_prefetch(p as *const i8, _MM_HINT_T0);
+                t = _mm_cmpeq_epi8(t, zero);
+                if _mm_movemask_epi8(t) != 0xFFFF {
+                    return false;
+                }
+                t = *p.offset(-4);
+                t = _mm_or_si128(t, *p.offset(-3));
+                t = _mm_or_si128(t, *p.offset(-2));
+                t = _mm_or_si128(t, *p.offset(-1));
+                // `p` may end up beyond the slice here; it is only compared,
+                // never dereferenced, so use wrapping arithmetic.
+                p = p.wrapping_add(4);
+            }
+            // Finish the aligned tail.
+            t = _mm_or_si128(t, *e.offset(-3));
+            t = _mm_or_si128(t, *e.offset(-2));
+            t = _mm_or_si128(t, *e.offset(-1));
+            // Finish the unaligned tail.
+            t = _mm_or_si128(t, _mm_loadu_si128(
+                buf.add(len-16) as *const __m128i));
+            _mm_movemask_epi8(_mm_cmpeq_epi8(t, zero)) == 0xFFFF
+        }
+    }
+    #[test]
+    fn test_sse2() {
+        super::test::do_zero_test(buffer_is_zero_sse2, false);
+    }
+
+    // Direct translation of buffer_zero_sse4() of qemu's util/bufferiszero.c
+    fn buffer_is_zero_sse4_1(buf_slice: &[u8]) -> bool {
+        let len = buf_slice.len();
+        debug_assert!(len >= MIN_LEN, "SSE4.1 zero check needs at least 64 bytes");
+        let buf = buf_slice.as_ptr();
+        // SAFETY: same access pattern as the SSE2 variant; with
+        // `len >= MIN_LEN` all loads stay inside `buf_slice`. The caller must
+        // have verified that the CPU supports SSE4.1.
+        unsafe {
+            // Unaligned head of 16 bytes.
+            let mut t = _mm_loadu_si128(buf as *const __m128i);
+            let mut p = ((buf as usize + 5*16) & !0xf) as *const __m128i;
+            let e = ((buf as usize + len) & !0xf) as *const __m128i;
+            // Loop over 16 byte aligned blocks of 64.
+            while p <= e {
+                _mm_prefetch(p as *const i8, _MM_HINT_T0);
+                if _mm_testz_si128(t, t) == 0 {
+                    return false;
+                }
+                t = *p.offset(-4);
+                t = _mm_or_si128(t, *p.offset(-3));
+                t = _mm_or_si128(t, *p.offset(-2));
+                t = _mm_or_si128(t, *p.offset(-1));
+                // Only compared afterwards, never dereferenced out of bounds.
+                p = p.wrapping_add(4);
+            }
+            // Finish the aligned tail.
+            t = _mm_or_si128(t, *e.offset(-3));
+            t = _mm_or_si128(t, *e.offset(-2));
+            t = _mm_or_si128(t, *e.offset(-1));
+            // Finish the unaligned tail.
+            t = _mm_or_si128(t, _mm_loadu_si128(
+                buf.add(len-16) as *const __m128i));
+            _mm_testz_si128(t, t) != 0
+        }
+    }
+    #[test]
+    fn test_sse4_1() {
+        // Executing SSE4.1 instructions on a CPU without them would fault.
+        if !is_x86_feature_detected!("sse4.1") {
+            return;
+        }
+        super::test::do_zero_test(buffer_is_zero_sse4_1, false);
+    }
+
+    // Direct translation of buffer_zero_avx2() of qemu's util/bufferiszero.c
+    fn buffer_is_zero_avx2(buf_slice: &[u8]) -> bool {
+        let len = buf_slice.len();
+        debug_assert!(len >= MIN_LEN, "AVX2 zero check needs at least 64 bytes");
+        let buf = buf_slice.as_ptr();
+        // SAFETY: with `len >= MIN_LEN` the two leading and the two trailing
+        // 32 byte loads are in bounds; the `len - 4*32` / `len - 3*32` loads
+        // are only reached when `len > 128`. The caller must have verified
+        // that the CPU supports AVX2.
+        unsafe {
+            // Unaligned head of 32 bytes.
+            let mut t = _mm256_loadu_si256(buf as *const __m256i);
+            let mut p = ((buf as usize + 5*32) & !0x1f) as *const __m256i;
+            let e = ((buf as usize + len) & !0x1f) as *const __m256i;
+            if p <= e {
+                // loop over 32 byte aligned blocks of 128
+                while p <= e {
+                    _mm_prefetch(p as *const i8, _MM_HINT_T0);
+                    if _mm256_testz_si256(t, t) == 0 {
+                        return false;
+                    }
+                    t = *p.offset(-4);
+                    t = _mm256_or_si256(t, *p.offset(-3));
+                    t = _mm256_or_si256(t, *p.offset(-2));
+                    t = _mm256_or_si256(t, *p.offset(-1));
+                    // Only compared afterwards, never dereferenced out of bounds.
+                    p = p.wrapping_add(4);
+                }
+                // Finish the last block of 128 unaligned.
+                t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 4*32) as *const __m256i));
+                t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 3*32) as *const __m256i));
+            } else {
+                // Too short for the aligned loop: cover the buffer with
+                // (possibly overlapping) unaligned loads instead.
+                t = _mm256_or_si256(t, _mm256_loadu_si256(
+                    buf.add(32) as *const __m256i));
+                if len > 128 {
+                    t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 4*32) as *const __m256i));
+                    t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 3*32) as *const __m256i));
+                }
+            }
+            t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 2*32) as *const __m256i));
+            t = _mm256_or_si256(t, _mm256_loadu_si256(buf.add(len - 1*32) as *const __m256i));
+            _mm256_testz_si256(t, t) != 0
+        }
+    }
+    #[test]
+    fn test_avx2() {
+        // Executing AVX2 instructions on a CPU without them would fault.
+        if !is_x86_feature_detected!("avx2") {
+            return;
+        }
+        super::test::do_zero_test(buffer_is_zero_avx2, false);
+    }
+
+    /// Selects the fastest routine supported by the running CPU and stores it
+    /// in `super::BUFFER_IS_ZERO_FUNC`.
+    ///
+    /// Preference order: AVX2, SSE4.1, SSE2, then the portable
+    /// `super::buffer_is_zero_compat`. Detection is left to the standard
+    /// library's `is_x86_feature_detected!` instead of decoding CPUID and
+    /// XGETBV by hand.
+    pub(super) fn init() {
+        let func: fn(&[u8]) -> bool = if is_x86_feature_detected!("avx2") {
+            buffer_is_zero_avx2
+        } else if is_x86_feature_detected!("sse4.1") {
+            buffer_is_zero_sse4_1
+        } else if is_x86_feature_detected!("sse2") {
+            buffer_is_zero_sse2
+        } else {
+            super::buffer_is_zero_compat
+        };
+
+        // SAFETY: writes a `static mut`; the parent module invokes `init()`
+        // from within `Once::call_once`, so there is only ever one writer.
+        unsafe {
+            super::BUFFER_IS_ZERO_FUNC = func;
+        }
+    }
+}
+
+/// Portable zero check that works on 64-bit words.
+///
+/// Returns `true` if every byte of `buf` is zero (an empty slice counts as
+/// all-zero). Slices shorter than 8 bytes are tested byte by byte. Longer
+/// ones are covered by an unaligned 8 byte load at the start, aligned 64-bit
+/// loads over the middle and an unaligned 8 byte load over the *last eight
+/// bytes of the slice*, so trailing bytes behind the last aligned word are
+/// checked too and nothing outside the slice is read.
+fn buffer_is_zero_compat(buf: &[u8]) -> bool {
+    let len = buf.len();
+    if len < 8 {
+        // Too short for a single 64-bit load.
+        return buf.iter().all(|&b| b == 0);
+    }
+
+    let start = buf.as_ptr();
+    let addr = start as usize;
+
+    // SAFETY: `len >= 8`, so the head load `[0, 8)` and the tail load
+    // `[len - 8, len)` are in bounds. `ptr` and `end` are 8-byte aligned
+    // addresses with `start < ptr <= end <= start + len`, and the loops only
+    // dereference words that lie completely in front of `end`.
+    unsafe {
+        // First 8-byte aligned address behind `start` (1 to 8 bytes in).
+        let mut ptr = start.add(8 - (addr & 7)) as *const u64;
+        // Last 8-byte aligned address that is not behind the end of the slice.
+        let end = start.add(len - ((addr + len) & 7)) as *const u64;
+
+        // Unaligned head.
+        let mut t = (start as *const u64).read_unaligned();
+
+        // Aligned middle in blocks of eight words. Each iteration tests what
+        // the previous one gathered. Compare addresses instead of forming
+        // `ptr + 8`, which could point outside the allocation.
+        while end as usize - ptr as usize >= 8 * 8 {
+            // XXX: add a prefetch_read_data() once it's stable...
+            if t != 0 {
+                return false;
+            }
+
+            t = *ptr | *ptr.add(1) | *ptr.add(2) | *ptr.add(3)
+              | *ptr.add(4) | *ptr.add(5) | *ptr.add(6) | *ptr.add(7);
+            ptr = ptr.add(8);
+        }
+        // Remaining aligned words.
+        while ptr < end {
+            t |= *ptr;
+            ptr = ptr.add(1);
+        }
+
+        // Unaligned tail: the last eight bytes of the slice itself.
+        t |= (start.add(len - 8) as *const u64).read_unaligned();
+        t == 0
+    }
+}
+
+/// Regression test: a non-zero last byte must be found for every alignment
+/// and for lengths that do not end on an 8-byte boundary.
+#[test]
+fn test_zero_compat_unaligned_tail() {
+    let mut backing = [0u8; 96];
+    for start in 0..8 {
+        for len in 8..80 {
+            let last = start + len - 1;
+            backing[last] = 1;
+            assert!(!buffer_is_zero_compat(&backing[start..start + len]));
+            backing[last] = 0;
+            assert!(buffer_is_zero_compat(&backing[start..start + len]));
+        }
+    }
+}
+
+/// Runs the shared driver against the portable fallback, including the
+/// prefixes of 1 to 8 bytes.
+#[test]
+fn test_zero_compat() {
+    let include_short_prefixes = true;
+    test::do_zero_test(buffer_is_zero_compat, include_short_prefixes);
+}
+
+/// Guards the one-time selection of the implementation.
+static BUF_IS_ZERO_GUARD: ::std::sync::Once = ::std::sync::Once::new();
+/// Implementation used for buffers of 64 bytes and more.
+///
+/// Starts out as the portable fallback and is replaced once, under
+/// `BUF_IS_ZERO_GUARD`, by the best routine for the running CPU.
+pub(self) static mut BUFFER_IS_ZERO_FUNC: fn(&[u8]) -> bool
+    = buffer_is_zero_compat;
+
+/// Stores the implementation to use in `BUFFER_IS_ZERO_FUNC`.
+///
+/// Must only be run through `BUF_IS_ZERO_GUARD`.
+fn select_implementation() {
+    // `#[cfg]` removes the call entirely on other architectures; the `cfg!()`
+    // macro would still require the `x86_64` module to exist there.
+    #[cfg(target_arch = "x86_64")]
+    x86_64::init();
+
+    #[cfg(not(target_arch = "x86_64"))]
+    // SAFETY: only executed inside `Once::call_once`, i.e. by a single writer.
+    unsafe {
+        BUFFER_IS_ZERO_FUNC = buffer_is_zero_compat;
+    }
+}
+
+/// Returns `true` if `buf` consists of zero bytes only.
+///
+/// An empty buffer contains no non-zero byte and therefore yields `true`,
+/// matching both `buffer_is_zero_compat` and qemu's `buffer_is_zero()`.
+/// Buffers shorter than 64 bytes are handled by the portable implementation,
+/// because the vectorised routines need at least 64 bytes. Longer buffers go
+/// to the routine selected for the running CPU on first use.
+pub fn buffer_is_zero(buf: &[u8]) -> bool {
+    if buf.is_empty() {
+        return true;
+    }
+    if buf.len() < 64 {
+        return buffer_is_zero_compat(buf);
+    }
+
+    // Once the selection is done this returns after a single atomic load; it
+    // also orders the read below after the write done during selection.
+    BUF_IS_ZERO_GUARD.call_once(select_implementation);
+
+    // SAFETY: the only write to the static happened inside `call_once` above,
+    // which has completed, so this is a race-free read of a plain fn pointer.
+    unsafe { BUFFER_IS_ZERO_FUNC(buf) }
+}
+
+#[test]
+fn test_initialization() {
+    test::do_zero_test(buffer_is_zero, false);
+
+    // Inputs that never reach the selected implementation.
+    assert!(buffer_is_zero(&[]));
+    assert!(buffer_is_zero(&[0u8; 63]));
+    let mut short = [0u8; 63];
+    short[62] = 1;
+    assert!(!buffer_is_zero(&short));
+}
-- 
2.20.1





More information about the pbs-devel mailing list