Allow surrogates in str #5587

Merged: merged 10 commits on Mar 26, 2025

Changes from 1 commit:
Optimize Wtf8Codepoints::count
coolreader18 committed Mar 26, 2025
commit bd55baefa6f39aff3c7e3a6e365c0ce1a187bccf
common/src/wtf8/core_str_count.rs (new file: 161 additions, 0 deletions)
@@ -0,0 +1,161 @@
//! Modified from core::str::count

use super::Wtf8;

const USIZE_SIZE: usize = core::mem::size_of::<usize>();
const UNROLL_INNER: usize = 4;

#[inline]
pub(super) fn count_chars(s: &Wtf8) -> usize {
    if s.len() < USIZE_SIZE * UNROLL_INNER {
        // Avoid entering the optimized implementation for strings where the
        // difference is not likely to matter, or where it might even be slower.
        // That said, a ton of thought was not spent on the particular threshold
        // here, beyond "this value seems to make sense".
        char_count_general_case(s.as_bytes())
    } else {
        do_count_chars(s)
    }
}
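
// Added note: on a 64-bit target this threshold is 8 * 4 = 32 bytes; anything
// shorter takes the simple per-byte path below.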

fn do_count_chars(s: &Wtf8) -> usize {
    // For correctness, `CHUNK_SIZE` must be:
    //
    // - Less than or equal to 255, otherwise we'll overflow bytes in `counts`.
    // - A multiple of `UNROLL_INNER`, otherwise our `break` inside the
    //   `body.chunks(CHUNK_SIZE)` loop is incorrect.
    //
    // For performance, `CHUNK_SIZE` should be:
    // - Relatively cheap to `/` against (so some simple sum of powers of two).
    // - Large enough to avoid paying for the cost of the `sum_bytes_in_usize`
    //   too often.
    const CHUNK_SIZE: usize = 192;

    // Check the properties of `CHUNK_SIZE` and `UNROLL_INNER` that are required
    // for correctness.
    const _: () = assert!(CHUNK_SIZE < 256);
    const _: () = assert!(CHUNK_SIZE % UNROLL_INNER == 0);

    // SAFETY: transmuting `[u8]` to `[usize]` is safe except for size
    // differences which are handled by `align_to`.
    let (head, body, tail) = unsafe { s.as_bytes().align_to::<usize>() };

    // This should be quite rare, and basically exists to handle the degenerate
    // cases where align_to fails (as well as miri under symbolic alignment
    // mode).
    //
    // The `unlikely` helps discourage LLVM from inlining the body, which is
    // nice, as we would rather not mark the `char_count_general_case` function
    // as cold.
    if unlikely(body.is_empty() || head.len() > USIZE_SIZE || tail.len() > USIZE_SIZE) {
        return char_count_general_case(s.as_bytes());
    }

    let mut total = char_count_general_case(head) + char_count_general_case(tail);
    // Split `body` into `CHUNK_SIZE` chunks to reduce the frequency with which
    // we call `sum_bytes_in_usize`.
    for chunk in body.chunks(CHUNK_SIZE) {
        // We accumulate intermediate sums in `counts`, where each byte contains
        // a subset of the sum of this chunk, like a `[u8; size_of::<usize>()]`.
        let mut counts = 0;

        let (unrolled_chunks, remainder) = slice_as_chunks::<_, UNROLL_INNER>(chunk);
        for unrolled in unrolled_chunks {
            for &word in unrolled {
                // Because `CHUNK_SIZE` is < 256, this addition can't cause the
                // count in any of the bytes to overflow into a subsequent byte.
                counts += contains_non_continuation_byte(word);
            }
        }

        // Sum the values in `counts` (which, again, is conceptually a `[u8;
        // size_of::<usize>()]`), and accumulate the result into `total`.
        total += sum_bytes_in_usize(counts);

        // If there's any data in `remainder`, then handle it. This will only
        // happen for the last `chunk` in `body.chunks()` (because `CHUNK_SIZE`
        // is divisible by `UNROLL_INNER`), so we explicitly break at the end
        // (which seems to help LLVM out).
        if !remainder.is_empty() {
            // Accumulate all the data in the remainder.
            let mut counts = 0;
            for &word in remainder {
                counts += contains_non_continuation_byte(word);
            }
            total += sum_bytes_in_usize(counts);
            break;
        }
    }
    total
}

// Checks each byte of `w` to see if it contains the first byte in a UTF-8
// sequence. Bytes in `w` which are continuation bytes are left as `0x00` (i.e.
// false), and bytes which are non-continuation bytes are left as `0x01` (i.e.
// true).
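// Added illustration: for a continuation byte such as 0x80 = 0b1000_0000,
// bit 7 of `!w` and bit 6 of `w` are both 0, so its lane yields 0; for an
// ASCII byte like 0x61 or a lead byte like 0xE2, one of the two bits is 1 and
// the lane yields 1. After either shift, bit 0 of each byte still originates
// from the same byte (its bit 7 or bit 6), and the `& LSB` mask discards all
// other bits, so lanes cannot bleed into one another.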
#[inline]
fn contains_non_continuation_byte(w: usize) -> usize {
    const LSB: usize = usize_repeat_u8(0x01);
    ((!w >> 7) | (w >> 6)) & LSB
}

// Morally equivalent to `values.to_ne_bytes().into_iter().sum::<usize>()`, but
// more efficient.
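// Added illustration: on a 64-bit target, `values & SKIP_BYTES` keeps byte
// lanes 0, 2, 4, and 6 while `(values >> 8) & SKIP_BYTES` keeps lanes 1, 3,
// 5, and 7, so each 16-bit lane of `pair_sum` holds the sum of one adjacent
// byte pair (at most 2 * 255, which fits in 16 bits). `wrapping_mul(LSB_SHORTS)`
// then accumulates every 16-bit lane into the topmost lane, and the final
// shift of (USIZE_SIZE - 2) * 8 = 48 bits extracts that lane.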
#[inline]
fn sum_bytes_in_usize(values: usize) -> usize {
    const LSB_SHORTS: usize = usize_repeat_u16(0x0001);
    const SKIP_BYTES: usize = usize_repeat_u16(0x00ff);

    let pair_sum: usize = (values & SKIP_BYTES) + ((values >> 8) & SKIP_BYTES);
    pair_sum.wrapping_mul(LSB_SHORTS) >> ((USIZE_SIZE - 2) * 8)
}

// This is the most direct implementation of the concept of "count the number of
// bytes in the string which are not continuation bytes", and is used for the
// head and tail of the input string (the first and last item in the tuple
// returned by `slice::align_to`).
fn char_count_general_case(s: &[u8]) -> usize {
    s.iter()
        .filter(|&&byte| !super::core_str::utf8_is_cont_byte(byte))
        .count()
}

// polyfills of unstable library features

const fn usize_repeat_u8(x: u8) -> usize {
    usize::from_ne_bytes([x; size_of::<usize>()])
}

const fn usize_repeat_u16(x: u16) -> usize {
    let mut r = 0usize;
    let mut i = 0;
    while i < size_of::<usize>() {
        // Use `wrapping_shl` to make it work on targets with 16-bit `usize`
        r = r.wrapping_shl(16) | (x as usize);
        i += 2;
    }
    r
}

fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
    assert!(N != 0, "chunk size must be non-zero");
    let len_rounded_down = slice.len() / N * N;
    // SAFETY: The rounded-down value is always the same or smaller than the
    // original length, and thus must be in-bounds of the slice.
    let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
    // SAFETY: We already panicked for zero, and ensured by construction
    // that the length of the subslice is a multiple of N.
    let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
    (array_slice, remainder)
}

unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    let new_len = slice.len() / N;
    // SAFETY: We cast a slice of `new_len * N` elements into
    // a slice of `new_len` many `N` elements chunks.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), new_len) }
}

fn unlikely(x: bool) -> bool {
    x
}
common/src/wtf8/mod.rs (12 additions, 2 deletions)
@@ -53,6 +53,7 @@ use bstr::{ByteSlice, ByteVec};
 
 mod core_char;
 mod core_str;
+mod core_str_count;
 
 const UTF8_REPLACEMENT_CHARACTER: &str = "\u{FFFD}";
 
@@ -1256,6 +1257,10 @@ impl Iterator for Wtf8CodePoints<'_> {
     fn last(mut self) -> Option<Self::Item> {
         self.next_back()
     }
+
+    fn count(self) -> usize {
+        core_str_count::count_chars(self.as_wtf8())
+    }
 }
 
 impl DoubleEndedIterator for Wtf8CodePoints<'_> {
@@ -1277,8 +1282,8 @@ impl<'a> Wtf8CodePoints<'a> {
 
 #[derive(Clone)]
 pub struct Wtf8CodePointIndices<'a> {
-    pub(super) front_offset: usize,
-    pub(super) iter: Wtf8CodePoints<'a>,
+    front_offset: usize,
+    iter: Wtf8CodePoints<'a>,
 }
 
 impl Iterator for Wtf8CodePointIndices<'_> {
@@ -1308,6 +1313,11 @@ impl Iterator for Wtf8CodePointIndices<'_> {
         // No need to go through the entire string.
         self.next_back()
     }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.iter.count()
+    }
 }
 
 impl DoubleEndedIterator for Wtf8CodePointIndices<'_> {
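
With both `count` overrides in place, counting code points dispatches to the vectorized `count_chars` instead of stepping the iterator once per code point. A hypothetical call site (assuming a `code_points()` accessor on `Wtf8`, consistent with the iterator types in this diff):

    // Resolves to core_str_count::count_chars via the specialized
    // Iterator::count above.
    let n = some_wtf8.code_points().count();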