Commit 9340e5c
Auto merge of #103779 - the8472:simd-str-contains, r=thomcc
x86_64 SSE2 fast-path for str.contains(&str) and short needles

Based on Wojciech Muła's [SIMD-friendly algorithms for substring searching](http://0x80.pl/articles/simd-strfind.html#sse-avx2).

The two-way algorithm is Big-O efficient, but it needs to preprocess the needle to find a "critical factorization" of it. This additional work is significant for short needles. Additionally, it mostly advances needle.len() bytes at a time. The SIMD-based approach used here, on the other hand, can advance based on its vector width, which can exceed the needle length. Pathological cases exist, but because the fast path is limited to short needles the worst-case blowup is also small.

Benchmarks taken on a Zen 2, compiled with `-Ccodegen-units=1`:

```
OLD:
test str::bench_contains_16b_in_long      ... bench: 504 ns/iter (+/- 14)  = 5061 MB/s
test str::bench_contains_2b_repeated_long ... bench: 948 ns/iter (+/- 175) = 2690 MB/s
test str::bench_contains_32b_in_long      ... bench: 445 ns/iter (+/- 6)   = 5732 MB/s
test str::bench_contains_bad_naive        ... bench: 130 ns/iter (+/- 1)   = 569 MB/s
test str::bench_contains_bad_simd         ... bench:  84 ns/iter (+/- 8)   = 880 MB/s
test str::bench_contains_equal            ... bench: 142 ns/iter (+/- 7)   = 394 MB/s
test str::bench_contains_short_long       ... bench: 677 ns/iter (+/- 25)  = 3768 MB/s
test str::bench_contains_short_short      ... bench:  27 ns/iter (+/- 2)   = 2074 MB/s

NEW:
test str::bench_contains_16b_in_long      ... bench:  82 ns/iter (+/- 0)   = 31109 MB/s
test str::bench_contains_2b_repeated_long ... bench:  73 ns/iter (+/- 0)   = 34945 MB/s
test str::bench_contains_32b_in_long      ... bench:  71 ns/iter (+/- 1)   = 35929 MB/s
test str::bench_contains_bad_naive        ... bench:   7 ns/iter (+/- 0)   = 10571 MB/s
test str::bench_contains_bad_simd         ... bench:  97 ns/iter (+/- 41)  = 762 MB/s
test str::bench_contains_equal            ... bench:   4 ns/iter (+/- 0)   = 14000 MB/s
test str::bench_contains_short_long       ... bench:  73 ns/iter (+/- 0)   = 34945 MB/s
test str::bench_contains_short_short      ... bench:  12 ns/iter (+/- 0)   = 4666 MB/s
```
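The two-probe idea from the article translates into a few lines. Below is a minimal sketch using stable `std::arch` SSE2 intrinsics rather than the unstable `core::simd` types the patch itself uses inside libcore; `probe_contains` is a hypothetical standalone helper, and it omits the loop unrolling and the right-aligned tail vector of the real implementation:

```rust
// A minimal sketch of the first/last-byte probe search, assuming SSE2.
// Hypothetical helper, not the patch's API.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
fn probe_contains(haystack: &[u8], needle: &[u8]) -> bool {
    use std::arch::x86_64::*;
    assert!(needle.len() >= 2);
    let last = needle.len() - 1;
    let mut i = 0;
    unsafe {
        let first = _mm_set1_epi8(needle[0] as i8);
        let lastv = _mm_set1_epi8(needle[last] as i8);
        // Each iteration tests 16 candidate start positions at once.
        while i + last + 16 <= haystack.len() {
            let a = _mm_loadu_si128(haystack.as_ptr().add(i).cast());
            let b = _mm_loadu_si128(haystack.as_ptr().add(i + last).cast());
            // A set bit marks a position where both the first and the last
            // needle byte matched; only those need a full comparison.
            let mut mask = _mm_movemask_epi8(_mm_and_si128(
                _mm_cmpeq_epi8(a, first),
                _mm_cmpeq_epi8(b, lastv),
            )) as u32;
            while mask != 0 {
                let pos = i + mask.trailing_zeros() as usize;
                if &haystack[pos..pos + needle.len()] == needle {
                    return true;
                }
                mask &= mask - 1; // clear the lowest set bit
            }
            i += 16;
        }
    }
    // Check the remaining positions naively (the real code instead reuses the
    // vector probe on a right-aligned final chunk).
    haystack[i..].windows(needle.len()).any(|w| w == needle)
}
```

Each 16-byte iteration rules out 16 candidate positions at once, and a full needle comparison only runs where both probe bytes matched, which is what keeps the per-byte cost low even though verification is quadratic in the worst case.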
2 parents 251831e + a2b2010

File tree: 3 files changed, +311 -12 lines changed

Diff for: library/alloc/benches/str.rs

+58 -7

```diff
@@ -1,3 +1,4 @@
+use core::iter::Iterator;
 use test::{black_box, Bencher};
 
 #[bench]
@@ -122,14 +123,13 @@ fn bench_contains_short_short(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
     let needle = "sit";
 
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(haystack.contains(needle));
+        assert!(black_box(haystack).contains(black_box(needle)));
     })
 }
 
-#[bench]
-fn bench_contains_short_long(b: &mut Bencher) {
-    let haystack = "\
+static LONG_HAYSTACK: &str = "\
 Lorem ipsum dolor sit amet, consectetur adipiscing elit. Suspendisse quis lorem sit amet dolor \
 ultricies condimentum. Praesent iaculis purus elit, ac malesuada quam malesuada in. Duis sed orci \
 eros. Suspendisse sit amet magna mollis, mollis nunc luctus, imperdiet mi. Integer fringilla non \
@@ -164,10 +164,48 @@ feugiat. Etiam quis mauris vel risus luctus mattis a a nunc. Nullam orci quam, i
 vehicula in, porttitor ut nibh. Duis sagittis adipiscing nisl vitae congue. Donec mollis risus eu \
 leo suscipit, varius porttitor nulla porta. Pellentesque ut sem nec nisi euismod vehicula. Nulla \
 malesuada sollicitudin quam eu fermentum.";
+
+#[bench]
+fn bench_contains_2b_repeated_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "::";
+
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_short_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
     let needle = "english";
 
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_16b_in_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "english language";
+
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_32b_in_long(b: &mut Bencher) {
+    let haystack = LONG_HAYSTACK;
+    let needle = "the english language sample text";
+
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!haystack.contains(needle));
+        assert!(!black_box(haystack).contains(black_box(needle)));
     })
 }
 
@@ -176,8 +214,20 @@ fn bench_contains_bad_naive(b: &mut Bencher) {
     let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
     let needle = "aaaaaaaab";
 
+    b.bytes = haystack.len() as u64;
+    b.iter(|| {
+        assert!(!black_box(haystack).contains(black_box(needle)));
+    })
+}
+
+#[bench]
+fn bench_contains_bad_simd(b: &mut Bencher) {
+    let haystack = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
+    let needle = "aaabaaaa";
+
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(!haystack.contains(needle));
+        assert!(!black_box(haystack).contains(black_box(needle)));
     })
 }
 
@@ -186,8 +236,9 @@ fn bench_contains_equal(b: &mut Bencher) {
     let haystack = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
    let needle = "Lorem ipsum dolor sit amet, consectetur adipiscing elit.";
 
+    b.bytes = haystack.len() as u64;
     b.iter(|| {
-        assert!(haystack.contains(needle));
+        assert!(black_box(haystack).contains(black_box(needle)));
     })
 }
 
```
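A note on the recurring `black_box` change above: with plain string literals, the optimizer can evaluate `contains` at compile time and the benchmark collapses into a constant assertion. Wrapping both operands makes them opaque. A minimal standalone sketch using the stable `std::hint::black_box` (the benches use the `test` crate's equivalent):

```rust
use std::hint::black_box;

fn main() {
    let haystack = "Lorem ipsum dolor sit amet";
    let needle = "ipsum";
    // Without black_box the call below could be folded away entirely; with it,
    // both values are treated as opaque runtime inputs, so the search actually runs.
    assert!(black_box(haystack).contains(black_box(needle)));
}
```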

Diff for: library/alloc/tests/str.rs

+21 -5

```diff
@@ -1590,11 +1590,27 @@ fn test_bool_from_str() {
     assert_eq!("not even a boolean".parse::<bool>().ok(), None);
 }
 
-fn check_contains_all_substrings(s: &str) {
-    assert!(s.contains(""));
-    for i in 0..s.len() {
-        for j in i + 1..=s.len() {
-            assert!(s.contains(&s[i..j]));
+fn check_contains_all_substrings(haystack: &str) {
+    let mut modified_needle = String::new();
+
+    for i in 0..haystack.len() {
+        // check different haystack lengths since we special-case short haystacks.
+        let haystack = &haystack[0..i];
+        assert!(haystack.contains(""));
+        for j in 0..haystack.len() {
+            for k in j + 1..=haystack.len() {
+                let needle = &haystack[j..k];
+                assert!(haystack.contains(needle));
+                modified_needle.clear();
+                modified_needle.push_str(needle);
+                modified_needle.replace_range(0..1, "\0");
+                assert!(!haystack.contains(&modified_needle));
+
+                modified_needle.clear();
+                modified_needle.push_str(needle);
+                modified_needle.replace_range(needle.len() - 1..needle.len(), "\0");
+                assert!(!haystack.contains(&modified_needle));
+            }
         }
     }
 }
```
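The test deliberately corrupts the first and the last byte of each needle, since those are exactly the two positions the SIMD probes inspect. A tiny standalone check mirroring that idea (illustrative only, not part of the patch):

```rust
fn main() {
    let haystack = "abcdef";
    let needle = "bcd";
    assert!(haystack.contains(needle));
    // Corrupting either probe byte must make the search fail:
    assert!(!haystack.contains("\0cd")); // first byte corrupted
    assert!(!haystack.contains("bc\0")); // last byte corrupted
}
```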

Diff for: library/core/src/str/pattern.rs

+232

```diff
@@ -39,6 +39,7 @@
 )]
 
 use crate::cmp;
+use crate::cmp::Ordering;
 use crate::fmt;
 use crate::slice::memchr;
 
@@ -946,6 +947,32 @@ impl<'a, 'b> Pattern<'a> for &'b str {
         haystack.as_bytes().starts_with(self.as_bytes())
     }
 
+    /// Checks whether the pattern matches anywhere in the haystack.
+    #[inline]
+    fn is_contained_in(self, haystack: &'a str) -> bool {
+        if self.len() == 0 {
+            return true;
+        }
+
+        match self.len().cmp(&haystack.len()) {
+            Ordering::Less => {
+                if self.len() == 1 {
+                    return haystack.as_bytes().contains(&self.as_bytes()[0]);
+                }
+
+                #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+                if self.len() <= 32 {
+                    if let Some(result) = simd_contains(self, haystack) {
+                        return result;
+                    }
+                }
+
+                self.into_searcher(haystack).next_match().is_some()
+            }
+            _ => self == haystack,
+        }
+    }
+
     /// Removes the pattern from the front of haystack, if it matches.
     #[inline]
     fn strip_prefix_of(self, haystack: &'a str) -> Option<&'a str> {
@@ -1684,3 +1711,208 @@ impl TwoWayStrategy for RejectAndMatch {
         SearchStep::Match(a, b)
     }
 }
+
+/// SIMD search for short needles based on
+/// Wojciech Muła's "SIMD-friendly algorithms for substring searching"[0]
+///
+/// It skips ahead by the vector width on each iteration (rather than by the needle length,
+/// as two-way does) by probing the first and last byte of the needle across the whole vector
+/// width, and only doing full needle comparisons when the vectorized probe indicates
+/// potential matches.
+///
+/// Since the x86_64 baseline only offers SSE2 we only use u8x16 here.
+/// If we ever ship std for x86-64-v3 or adapt this for other platforms then wider vectors
+/// should be evaluated.
+///
+/// For haystacks smaller than vector-size + needle length it falls back to
+/// a naive O(n*m) search, so this implementation should not be called on larger needles.
+///
+/// [0]: http://0x80.pl/articles/simd-strfind.html#sse-avx2
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
+#[inline]
+fn simd_contains(needle: &str, haystack: &str) -> Option<bool> {
+    let needle = needle.as_bytes();
+    let haystack = haystack.as_bytes();
+
+    debug_assert!(needle.len() > 1);
+
+    use crate::ops::BitAnd;
+    use crate::simd::mask8x16 as Mask;
+    use crate::simd::u8x16 as Block;
+    use crate::simd::{SimdPartialEq, ToBitMask};
+
+    let first_probe = needle[0];
+
+    // the offset used for the 2nd vector
+    let second_probe_offset = if needle.len() == 2 {
+        // never bail out on len=2 needles because the probes will fully cover them and have
+        // no degenerate cases.
+        1
+    } else {
+        // try a few bytes in case first and last byte of the needle are the same
+        let Some(second_probe_offset) = (needle.len().saturating_sub(4)..needle.len()).rfind(|&idx| needle[idx] != first_probe) else {
+            // fall back to other search methods if we can't find any different bytes
+            // since we could otherwise hit some degenerate cases
+            return None;
+        };
+        second_probe_offset
+    };
+
+    // do a naive search if the haystack is too small to fit
+    if haystack.len() < Block::LANES + second_probe_offset {
+        return Some(haystack.windows(needle.len()).any(|c| c == needle));
+    }
+
+    let first_probe: Block = Block::splat(first_probe);
+    let second_probe: Block = Block::splat(needle[second_probe_offset]);
+    // the first byte is already checked by the outer loop. to verify a match only the
+    // remainder has to be compared.
+    let trimmed_needle = &needle[1..];
+
+    // this #[cold] is load-bearing, benchmark before removing it...
+    let check_mask = #[cold]
+    |idx, mask: u16, skip: bool| -> bool {
+        if skip {
+            return false;
+        }
+
+        // and so is this. optimizations are weird.
+        let mut mask = mask;
+
+        while mask != 0 {
+            let trailing = mask.trailing_zeros();
+            let offset = idx + trailing as usize + 1;
+            // SAFETY: mask has between 0 and 15 trailing zeroes. we skip one additional byte
+            // that was already compared and then take trimmed_needle.len() bytes. This is
+            // within the bounds defined by the outer loop.
+            unsafe {
+                let sub = haystack.get_unchecked(offset..).get_unchecked(..trimmed_needle.len());
+                if small_slice_eq(sub, trimmed_needle) {
+                    return true;
+                }
+            }
+            mask &= !(1 << trailing);
+        }
+        return false;
+    };
+
+    let test_chunk = |idx| -> u16 {
+        // SAFETY: this requires at least LANES bytes being readable at idx,
+        // which is ensured by the loop ranges (see comments below)
+        let a: Block = unsafe { haystack.as_ptr().add(idx).cast::<Block>().read_unaligned() };
+        // SAFETY: this requires LANES + second_probe_offset bytes being readable at idx
+        let b: Block = unsafe {
+            haystack.as_ptr().add(idx).add(second_probe_offset).cast::<Block>().read_unaligned()
+        };
+        let eq_first: Mask = a.simd_eq(first_probe);
+        let eq_last: Mask = b.simd_eq(second_probe);
+        let both = eq_first.bitand(eq_last);
+        let mask = both.to_bitmask();
+
+        return mask;
+    };
+
+    let mut i = 0;
+    let mut result = false;
+    // The loop condition must ensure that there's enough headroom to read LANES bytes,
+    // and not only at the current index but also at the index shifted by second_probe_offset
+    const UNROLL: usize = 4;
+    while i + second_probe_offset + UNROLL * Block::LANES < haystack.len() && !result {
+        let mut masks = [0u16; UNROLL];
+        for j in 0..UNROLL {
+            masks[j] = test_chunk(i + j * Block::LANES);
+        }
+        for j in 0..UNROLL {
+            let mask = masks[j];
+            if mask != 0 {
+                result |= check_mask(i + j * Block::LANES, mask, result);
+            }
+        }
+        i += UNROLL * Block::LANES;
+    }
+    while i + second_probe_offset + Block::LANES < haystack.len() && !result {
+        let mask = test_chunk(i);
+        if mask != 0 {
+            result |= check_mask(i, mask, result);
+        }
+        i += Block::LANES;
+    }
+
+    // Process the tail that didn't fit into LANES-sized steps.
+    // This simply repeats the same procedure but as a right-aligned chunk instead
+    // of a left-aligned one. The last byte must be exactly flush with the string end so
+    // we don't miss a single byte or read out of bounds.
+    let i = haystack.len() - second_probe_offset - Block::LANES;
+    let mask = test_chunk(i);
+    if mask != 0 {
+        result |= check_mask(i, mask, result);
+    }
+
+    Some(result)
+}
+
+/// Compares short slices for equality.
+///
+/// It avoids a call to libc's memcmp, which is faster on long slices
+/// due to SIMD optimizations but incurs function call overhead.
+///
+/// # Safety
+///
+/// Both slices must have the same length.
+#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))] // only called on x86
+#[inline]
+unsafe fn small_slice_eq(x: &[u8], y: &[u8]) -> bool {
+    // This function is adapted from
+    // https://github.com/BurntSushi/memchr/blob/8037d11b4357b0f07be2bb66dc2659d9cf28ad32/src/memmem/util.rs#L32
+
+    // If we don't have enough bytes to do 4-byte at a time loads, then
+    // fall back to the naive slow version.
+    //
+    // Potential alternative: We could do a copy_nonoverlapping combined with a mask instead
+    // of a loop. Benchmark it.
+    if x.len() < 4 {
+        for (&b1, &b2) in x.iter().zip(y) {
+            if b1 != b2 {
+                return false;
+            }
+        }
+        return true;
+    }
+    // When we have 4 or more bytes to compare, then proceed in chunks of 4 at
+    // a time using unaligned loads.
+    //
+    // Also, why do 4 byte loads instead of, say, 8 byte loads? The reason is
+    // that this particular version of memcmp is likely to be called with tiny
+    // needles. That means that if we do 8 byte loads, then a higher proportion
+    // of memcmp calls will use the slower variant above. With that said, this
+    // is a hypothesis and is only loosely supported by benchmarks. There's
+    // likely some improvement that could be made here. The main thing here
+    // though is to optimize for latency, not throughput.
+
+    // SAFETY: Via the conditional above, we know that both `px` and `py`
+    // have the same length, so `px < pxend` implies that `py < pyend`.
+    // Thus, dereferencing both `px` and `py` in the loop below is safe.
+    //
+    // Moreover, we set `pxend` and `pyend` to be 4 bytes before the actual
+    // end of `px` and `py`. Thus, the final dereference outside of the
+    // loop is guaranteed to be valid. (The final comparison will overlap with
+    // the last comparison done in the loop for lengths that aren't multiples
+    // of four.)
+    //
+    // Finally, we needn't worry about alignment here, since we do unaligned
+    // loads.
+    unsafe {
+        let (mut px, mut py) = (x.as_ptr(), y.as_ptr());
+        let (pxend, pyend) = (px.add(x.len() - 4), py.add(y.len() - 4));
+        while px < pxend {
+            let vx = (px as *const u32).read_unaligned();
+            let vy = (py as *const u32).read_unaligned();
+            if vx != vy {
+                return false;
+            }
+            px = px.add(4);
+            py = py.add(4);
+        }
+        let vx = (pxend as *const u32).read_unaligned();
+        let vy = (pyend as *const u32).read_unaligned();
+        vx == vy
+    }
+}
```
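The overlapping final load in `small_slice_eq` is what lets a fixed-width comparison cover a variable length without a byte-by-byte tail loop. A safe sketch of the same trick for lengths 4 through 7 (hypothetical helper, not part of the patch):

```rust
// Safe illustration of the overlapping-window trick used by small_slice_eq:
// for 4 <= len <= 7, the first and last 4-byte windows overlap and together
// cover every byte, so two fixed-size comparisons suffice.
fn eq_4_to_7(x: &[u8], y: &[u8]) -> bool {
    assert!(x.len() == y.len() && (4..=7).contains(&x.len()));
    let n = x.len();
    x[..4] == y[..4] && x[n - 4..] == y[n - 4..]
}

fn main() {
    assert!(eq_4_to_7(b"abcdef", b"abcdef"));
    assert!(!eq_4_to_7(b"abcdef", b"abcxef"));
}
```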
