Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit a556bd7

Browse filesBrowse files
committed
Optimize Wtf8Codepoints::count
1 parent f239de0 commit a556bd7
Copy full SHA for a556bd7

File tree

2 files changed

+173
-2
lines changed
Filter options

2 files changed

+173
-2
lines changed

‎common/src/wtf8/core_str_count.rs

Copy file name to clipboard
+161Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//! Modified from core::str::count
2+
3+
use super::Wtf8;
4+
5+
const USIZE_SIZE: usize = core::mem::size_of::<usize>();
6+
const UNROLL_INNER: usize = 4;
7+
8+
#[inline]
9+
pub(super) fn count_chars(s: &Wtf8) -> usize {
10+
if s.len() < USIZE_SIZE * UNROLL_INNER {
11+
// Avoid entering the optimized implementation for strings where the
12+
// difference is not likely to matter, or where it might even be slower.
13+
// That said, a ton of thought was not spent on the particular threshold
14+
// here, beyond "this value seems to make sense".
15+
char_count_general_case(s.as_bytes())
16+
} else {
17+
do_count_chars(s)
18+
}
19+
}
20+
21+
fn do_count_chars(s: &Wtf8) -> usize {
22+
// For correctness, `CHUNK_SIZE` must be:
23+
//
24+
// - Less than or equal to 255, otherwise we'll overflow bytes in `counts`.
25+
// - A multiple of `UNROLL_INNER`, otherwise our `break` inside the
26+
// `body.chunks(CHUNK_SIZE)` loop is incorrect.
27+
//
28+
// For performance, `CHUNK_SIZE` should be:
29+
// - Relatively cheap to `/` against (so some simple sum of powers of two).
30+
// - Large enough to avoid paying for the cost of the `sum_bytes_in_usize`
31+
// too often.
32+
const CHUNK_SIZE: usize = 192;
33+
34+
// Check the properties of `CHUNK_SIZE` and `UNROLL_INNER` that are required
35+
// for correctness.
36+
const _: () = assert!(CHUNK_SIZE < 256);
37+
const _: () = assert!(CHUNK_SIZE % UNROLL_INNER == 0);
38+
39+
// SAFETY: transmuting `[u8]` to `[usize]` is safe except for size
40+
// differences which are handled by `align_to`.
41+
let (head, body, tail) = unsafe { s.as_bytes().align_to::<usize>() };
42+
43+
// This should be quite rare, and basically exists to handle the degenerate
44+
// cases where align_to fails (as well as miri under symbolic alignment
45+
// mode).
46+
//
47+
// The `unlikely` helps discourage LLVM from inlining the body, which is
48+
// nice, as we would rather not mark the `char_count_general_case` function
49+
// as cold.
50+
if unlikely(body.is_empty() || head.len() > USIZE_SIZE || tail.len() > USIZE_SIZE) {
51+
return char_count_general_case(s.as_bytes());
52+
}
53+
54+
let mut total = char_count_general_case(head) + char_count_general_case(tail);
55+
// Split `body` into `CHUNK_SIZE` chunks to reduce the frequency with which
56+
// we call `sum_bytes_in_usize`.
57+
for chunk in body.chunks(CHUNK_SIZE) {
58+
// We accumulate intermediate sums in `counts`, where each byte contains
59+
// a subset of the sum of this chunk, like a `[u8; size_of::<usize>()]`.
60+
let mut counts = 0;
61+
62+
let (unrolled_chunks, remainder) = slice_as_chunks::<_, UNROLL_INNER>(chunk);
63+
for unrolled in unrolled_chunks {
64+
for &word in unrolled {
65+
// Because `CHUNK_SIZE` is < 256, this addition can't cause the
66+
// count in any of the bytes to overflow into a subsequent byte.
67+
counts += contains_non_continuation_byte(word);
68+
}
69+
}
70+
71+
// Sum the values in `counts` (which, again, is conceptually a `[u8;
72+
// size_of::<usize>()]`), and accumulate the result into `total`.
73+
total += sum_bytes_in_usize(counts);
74+
75+
// If there's any data in `remainder`, then handle it. This will only
76+
// happen for the last `chunk` in `body.chunks()` (because `CHUNK_SIZE`
77+
// is divisible by `UNROLL_INNER`), so we explicitly break at the end
78+
// (which seems to help LLVM out).
79+
if !remainder.is_empty() {
80+
// Accumulate all the data in the remainder.
81+
let mut counts = 0;
82+
for &word in remainder {
83+
counts += contains_non_continuation_byte(word);
84+
}
85+
total += sum_bytes_in_usize(counts);
86+
break;
87+
}
88+
}
89+
total
90+
}
91+
92+
// Checks each byte of `w` to see if it contains the first byte in a UTF-8
93+
// sequence. Bytes in `w` which are continuation bytes are left as `0x00` (e.g.
94+
// false), and bytes which are non-continuation bytes are left as `0x01` (e.g.
95+
// true)
96+
#[inline]
97+
fn contains_non_continuation_byte(w: usize) -> usize {
98+
const LSB: usize = usize_repeat_u8(0x01);
99+
((!w >> 7) | (w >> 6)) & LSB
100+
}
101+
102+
// Morally equivalent to `values.to_ne_bytes().into_iter().sum::<usize>()`, but
103+
// more efficient.
104+
#[inline]
105+
fn sum_bytes_in_usize(values: usize) -> usize {
106+
const LSB_SHORTS: usize = usize_repeat_u16(0x0001);
107+
const SKIP_BYTES: usize = usize_repeat_u16(0x00ff);
108+
109+
let pair_sum: usize = (values & SKIP_BYTES) + ((values >> 8) & SKIP_BYTES);
110+
pair_sum.wrapping_mul(LSB_SHORTS) >> ((USIZE_SIZE - 2) * 8)
111+
}
112+
113+
// This is the most direct implementation of the concept of "count the number of
114+
// bytes in the string which are not continuation bytes", and is used for the
115+
// head and tail of the input string (the first and last item in the tuple
116+
// returned by `slice::align_to`).
117+
fn char_count_general_case(s: &[u8]) -> usize {
118+
s.iter()
119+
.filter(|&&byte| !super::core_str::utf8_is_cont_byte(byte))
120+
.count()
121+
}
122+
123+
// polyfills of unstable library features
124+
125+
const fn usize_repeat_u8(x: u8) -> usize {
126+
usize::from_ne_bytes([x; size_of::<usize>()])
127+
}
128+
129+
const fn usize_repeat_u16(x: u16) -> usize {
130+
let mut r = 0usize;
131+
let mut i = 0;
132+
while i < size_of::<usize>() {
133+
// Use `wrapping_shl` to make it work on targets with 16-bit `usize`
134+
r = r.wrapping_shl(16) | (x as usize);
135+
i += 2;
136+
}
137+
r
138+
}
139+
140+
fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
141+
assert!(N != 0, "chunk size must be non-zero");
142+
let len_rounded_down = slice.len() / N * N;
143+
// SAFETY: The rounded-down value is always the same or smaller than the
144+
// original length, and thus must be in-bounds of the slice.
145+
let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
146+
// SAFETY: We already panicked for zero, and ensured by construction
147+
// that the length of the subslice is a multiple of N.
148+
let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
149+
(array_slice, remainder)
150+
}
151+
152+
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
153+
let new_len = slice.len() / N;
154+
// SAFETY: We cast a slice of `new_len * N` elements into
155+
// a slice of `new_len` many `N` elements chunks.
156+
unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), new_len) }
157+
}
158+
159+
fn unlikely(x: bool) -> bool {
160+
x
161+
}

‎common/src/wtf8/mod.rs

Copy file name to clipboardExpand all lines: common/src/wtf8/mod.rs
+12-2Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use bstr::{ByteSlice, ByteVec};
5353

5454
mod core_char;
5555
mod core_str;
56+
mod core_str_count;
5657

5758
const UTF8_REPLACEMENT_CHARACTER: &str = "\u{FFFD}";
5859

@@ -1256,6 +1257,10 @@ impl Iterator for Wtf8CodePoints<'_> {
12561257
fn last(mut self) -> Option<Self::Item> {
12571258
self.next_back()
12581259
}
1260+
1261+
fn count(self) -> usize {
1262+
core_str_count::count_chars(self.as_wtf8())
1263+
}
12591264
}
12601265

12611266
impl DoubleEndedIterator for Wtf8CodePoints<'_> {
@@ -1277,8 +1282,8 @@ impl<'a> Wtf8CodePoints<'a> {
12771282

12781283
#[derive(Clone)]
12791284
pub struct Wtf8CodePointIndices<'a> {
1280-
pub(super) front_offset: usize,
1281-
pub(super) iter: Wtf8CodePoints<'a>,
1285+
front_offset: usize,
1286+
iter: Wtf8CodePoints<'a>,
12821287
}
12831288

12841289
impl Iterator for Wtf8CodePointIndices<'_> {
@@ -1308,6 +1313,11 @@ impl Iterator for Wtf8CodePointIndices<'_> {
13081313
// No need to go through the entire string.
13091314
self.next_back()
13101315
}
1316+
1317+
#[inline]
1318+
fn count(self) -> usize {
1319+
self.iter.count()
1320+
}
13111321
}
13121322

13131323
impl DoubleEndedIterator for Wtf8CodePointIndices<'_> {

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.