diff --git a/Cargo.lock b/Cargo.lock index b214780cb4..1fa60eac29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2338,6 +2338,7 @@ version = "0.4.0" dependencies = [ "ascii", "bitflags 2.8.0", + "bstr", "cfg-if", "itertools 0.14.0", "libc", @@ -2345,6 +2346,7 @@ dependencies = [ "malachite-base", "malachite-bigint", "malachite-q", + "memchr", "num-complex", "num-traits", "once_cell", @@ -2472,6 +2474,7 @@ dependencies = [ "criterion", "num_enum", "optional", + "rustpython-common", ] [[package]] diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index ea82b0166c..6f402513fd 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -1066,8 +1066,6 @@ def test_hash(self): hash(b) self.assertEqual(hash(a), hash(b)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_capitalize_nonascii(self): # check that titlecased chars are lowered correctly # \u1ffc is the titlecased char diff --git a/Lib/test/test_cmd_line_script.py b/Lib/test/test_cmd_line_script.py index e40069d780..833dc6b15d 100644 --- a/Lib/test/test_cmd_line_script.py +++ b/Lib/test/test_cmd_line_script.py @@ -574,6 +574,7 @@ def test_pep_409_verbiage(self): self.assertTrue(text[1].startswith(' File ')) self.assertTrue(text[3].startswith('NameError')) + @unittest.expectedFailureIf(sys.platform == "linux", "TODO: RUSTPYTHON") def test_non_ascii(self): # Mac OS X denies the creation of a file with an invalid UTF-8 name. 
# Windows allows creating a name with an arbitrary bytes name, but diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index f29e91e088..df04653c66 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -1698,8 +1698,6 @@ def test_decode_invalid(self): class NameprepTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_nameprep(self): from encodings.idna import nameprep for pos, (orig, prepped) in enumerate(nameprep_tests): diff --git a/Lib/test/test_difflib.py b/Lib/test/test_difflib.py index 5592a2d5a3..0d669afe61 100644 --- a/Lib/test/test_difflib.py +++ b/Lib/test/test_difflib.py @@ -373,8 +373,6 @@ def test_byte_content(self): check(difflib.diff_bytes(context, a, a, b'a', b'a', b'2005', b'2013')) check(difflib.diff_bytes(context, a, b, b'a', b'b', b'2005', b'2013')) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_byte_filenames(self): # somebody renamed a file from ISO-8859-2 to UTF-8 fna = b'\xb3odz.txt' # "łodz.txt" diff --git a/Lib/test/test_import/__init__.py b/Lib/test/test_import/__init__.py index c2f181cc86..89e5ec1534 100644 --- a/Lib/test/test_import/__init__.py +++ b/Lib/test/test_import/__init__.py @@ -1305,6 +1305,8 @@ def exec_module(*args): else: importlib.SourceLoader.exec_module = old_exec_module + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipUnless(TESTFN_UNENCODABLE, 'need TESTFN_UNENCODABLE') def test_unencodable_filename(self): # Issue #11619: The Python parser and the import machinery must not diff --git a/Lib/test/test_json/test_scanstring.py b/Lib/test/test_json/test_scanstring.py index 140a7c12a2..682dc74999 100644 --- a/Lib/test/test_json/test_scanstring.py +++ b/Lib/test/test_json/test_scanstring.py @@ -143,10 +143,4 @@ def test_overflow(self): class TestPyScanstring(TestScanstring, PyTest): pass -# TODO: RUSTPYTHON -class TestPyScanstring(TestScanstring, PyTest): - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_bad_escapes(self): - 
super().test_bad_escapes() class TestCScanstring(TestScanstring, CTest): pass diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py index 8e23b88676..7609ecea79 100644 --- a/Lib/test/test_ntpath.py +++ b/Lib/test/test_ntpath.py @@ -1032,12 +1032,6 @@ class NtCommonTest(test_genericpath.CommonTest, unittest.TestCase): pathmodule = ntpath attributes = ['relpath'] - # TODO: RUSTPYTHON - if sys.platform == "linux": - @unittest.expectedFailure - def test_nonascii_abspath(self): - super().test_nonascii_abspath() - # TODO: RUSTPYTHON if sys.platform == "win32": # TODO: RUSTPYTHON, ValueError: illegal environment variable name diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py index a060a3deef..3c0b6af3c6 100644 --- a/Lib/test/test_re.py +++ b/Lib/test/test_re.py @@ -854,8 +854,6 @@ def test_string_boundaries(self): # Can match around the whitespace. self.assertEqual(len(re.findall(r"\B", " ")), 2) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bigcharset(self): self.assertEqual(re.match("([\u2222\u2223])", "\u2222").group(1), "\u2222") @@ -2233,6 +2231,7 @@ def test_bug_40736(self): with self.assertRaisesRegex(TypeError, "got 'type'"): re.search("x*", type) + @unittest.skip("TODO: RUSTPYTHON: flaky, improve perf") @requires_resource('cpu') def test_search_anchor_at_beginning(self): s = 'x'*10**7 diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py index a36d7bbe2a..9b787950fc 100644 --- a/Lib/test/test_smtplib.py +++ b/Lib/test/test_smtplib.py @@ -1459,8 +1459,6 @@ def test_send_unicode_with_SMTPUTF8_via_low_level_API(self): self.assertIn('SMTPUTF8', self.serv.last_mail_options) self.assertEqual(self.serv.last_rcpt_options, []) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_send_message_uses_smtputf8_if_addrs_non_ascii(self): msg = EmailMessage() msg['From'] = "Páolo " diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 5b82853102..0e3eb08b82 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py 
@@ -1578,7 +1578,7 @@ def test_getnameinfo(self): # only IP addresses are allowed self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) - @unittest.expectedFailureIf(sys.platform != "darwin", "TODO: RUSTPYTHON; socket.gethostbyname_ex") + @unittest.skip("TODO: RUSTPYTHON: flaky on CI?") @unittest.skipUnless(support.is_resource_enabled('network'), 'network is not enabled') def test_idna(self): @@ -5519,8 +5519,6 @@ def testBytesAddr(self): self.addCleanup(os_helper.unlink, path) self.assertEqual(self.sock.getsockname(), path) - # TODO: RUSTPYTHON, surrogateescape - @unittest.expectedFailure def testSurrogateescapeBind(self): # Test binding to a valid non-ASCII pathname, with the # non-ASCII bytes supplied using surrogateescape encoding. diff --git a/Lib/test/test_sqlite3/test_types.py b/Lib/test/test_sqlite3/test_types.py index d7631ec938..45f30824d0 100644 --- a/Lib/test/test_sqlite3/test_types.py +++ b/Lib/test/test_sqlite3/test_types.py @@ -95,8 +95,6 @@ def test_too_large_int(self): row = self.cur.fetchone() self.assertIsNone(row) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_string_with_surrogates(self): for value in 0xd8ff, 0xdcff: with self.assertRaises(UnicodeEncodeError): diff --git a/Lib/test/test_ucn.py b/Lib/test/test_ucn.py index 6d082a0942..f6d69540b9 100644 --- a/Lib/test/test_ucn.py +++ b/Lib/test/test_ucn.py @@ -102,8 +102,6 @@ def test_cjk_unified_ideographs(self): self.checkletter("CJK UNIFIED IDEOGRAPH-2B81D", "\U0002B81D") self.checkletter("CJK UNIFIED IDEOGRAPH-3134A", "\U0003134A") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bmp_characters(self): for code in range(0x10000): char = chr(code) diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 17c9f01cd8..5c2c6c29b1 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -721,8 +721,6 @@ def test_isspace(self): '\U0001F40D', '\U0001F46F']: self.assertFalse(ch.isspace(), '{!a} is not space.'.format(ch)) - # TODO: 
RUSTPYTHON - @unittest.expectedFailure @support.requires_resource('cpu') def test_isspace_invariant(self): for codepoint in range(sys.maxunicode + 1): diff --git a/Lib/test/test_unicodedata.py b/Lib/test/test_unicodedata.py index f5b4b6218e..c9e0b234ef 100644 --- a/Lib/test/test_unicodedata.py +++ b/Lib/test/test_unicodedata.py @@ -99,8 +99,6 @@ def test_function_checksum(self): result = h.hexdigest() self.assertEqual(result, self.expectedchecksum) - # TODO: RUSTPYTHON - @unittest.expectedFailure @requires_resource('cpu') def test_name_inverse_lookup(self): for i in range(sys.maxunicode + 1): @@ -326,8 +324,6 @@ def test_ucd_510(self): self.assertTrue("\u1d79".upper()=='\ua77d') self.assertTrue(".".upper()=='.') - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bug_5828(self): self.assertEqual("\u1d79".lower(), "\u1d79") # Only U+0000 should have U+0000 as its upper/lower/titlecase variant @@ -347,8 +343,6 @@ def test_bug_4971(self): self.assertEqual("\u01c5".title(), "\u01c5") self.assertEqual("\u01c6".title(), "\u01c5") - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_linebreak_7643(self): for i in range(0x10000): lines = (chr(i) + 'A').splitlines() diff --git a/common/Cargo.toml b/common/Cargo.toml index 589170064c..e9aeba7459 100644 --- a/common/Cargo.toml +++ b/common/Cargo.toml @@ -16,12 +16,14 @@ rustpython-literal = { workspace = true } ascii = { workspace = true } bitflags = { workspace = true } +bstr = { workspace = true } cfg-if = { workspace = true } itertools = { workspace = true } libc = { workspace = true } malachite-bigint = { workspace = true } malachite-q = { workspace = true } malachite-base = { workspace = true } +memchr = { workspace = true } num-complex = { workspace = true } num-traits = { workspace = true } once_cell = { workspace = true } diff --git a/common/src/cformat.rs b/common/src/cformat.rs index 94762aceab..e62ffca65e 100644 --- a/common/src/cformat.rs +++ b/common/src/cformat.rs @@ -11,11 +11,13 @@ use std::{ 
str::FromStr, }; +use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; + #[derive(Debug, PartialEq)] pub enum CFormatErrorType { UnmatchedKeyParentheses, MissingModuloSign, - UnsupportedFormatChar(char), + UnsupportedFormatChar(CodePoint), IncompleteFormat, IntTooBig, // Unimplemented, @@ -39,7 +41,9 @@ impl fmt::Display for CFormatError { UnsupportedFormatChar(c) => write!( f, "unsupported format character '{}' ({:#x}) at index {}", - c, c as u32, self.index + c, + c.to_u32(), + self.index ), IntTooBig => write!(f, "width/precision too big"), _ => write!(f, "unexpected error parsing format string"), @@ -160,7 +164,7 @@ pub trait FormatBuf: fn concat(self, other: Self) -> Self; } -pub trait FormatChar: Copy + Into + From { +pub trait FormatChar: Copy + Into + From { fn to_char_lossy(self) -> char; fn eq_char(self, c: char) -> bool; } @@ -188,6 +192,29 @@ impl FormatChar for char { } } +impl FormatBuf for Wtf8Buf { + type Char = CodePoint; + fn chars(&self) -> impl Iterator { + self.code_points() + } + fn len(&self) -> usize { + (**self).len() + } + fn concat(mut self, other: Self) -> Self { + self.extend([other]); + self + } +} + +impl FormatChar for CodePoint { + fn to_char_lossy(self) -> char { + self.to_char_lossy() + } + fn eq_char(self, c: char) -> bool { + self == c + } +} + impl FormatBuf for Vec { type Char = u8; fn chars(&self) -> impl Iterator { @@ -801,6 +828,15 @@ impl FromStr for CFormatString { } } +pub type CFormatWtf8 = CFormatStrOrBytes; + +impl CFormatWtf8 { + pub fn parse_from_wtf8(s: &Wtf8) -> Result { + let mut iter = s.code_points().enumerate().peekable(); + Self::parse(&mut iter) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 858d3b8c6b..7d99646c31 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -1,12 +1,22 @@ use std::ops::Range; +use num_traits::ToPrimitive; + +use crate::str::StrKind; +use crate::wtf8::{Wtf8, Wtf8Buf}; + pub type EncodeErrorResult = 
Result<(EncodeReplace, usize), E>; pub type DecodeErrorResult = Result<(S, Option, usize), E>; -pub trait StrBuffer: AsRef { - fn is_ascii(&self) -> bool { - self.as_ref().is_ascii() +pub trait StrBuffer: AsRef { + fn is_compatible_with(&self, kind: StrKind) -> bool { + let s = self.as_ref(); + match kind { + StrKind::Ascii => s.is_ascii(), + StrKind::Utf8 => s.is_utf8(), + StrKind::Wtf8 => true, + } } } @@ -16,7 +26,7 @@ pub trait ErrorHandler { type BytesBuf: AsRef<[u8]>; fn handle_encode_error( &self, - data: &str, + data: &Wtf8, char_range: Range, reason: &str, ) -> EncodeErrorResult; @@ -27,7 +37,7 @@ pub trait ErrorHandler { reason: &str, ) -> DecodeErrorResult; fn error_oob_restart(&self, i: usize) -> Self::Error; - fn error_encoding(&self, data: &str, char_range: Range, reason: &str) -> Self::Error; + fn error_encoding(&self, data: &Wtf8, char_range: Range, reason: &str) -> Self::Error; } pub enum EncodeReplace { Str(S), @@ -63,19 +73,19 @@ fn decode_utf8_compatible( errors: &E, decode: DecodeF, handle_error: ErrF, -) -> Result<(String, usize), E::Error> +) -> Result<(Wtf8Buf, usize), E::Error> where DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>, ErrF: Fn(&[u8], Option) -> HandleResult<'_>, { if data.is_empty() { - return Ok((String::new(), 0)); + return Ok((Wtf8Buf::new(), 0)); } // we need to coerce the lifetime to that of the function body rather than the // anonymous input lifetime, so that we can assign it data borrowed from data_from_err let mut data = data; let mut data_from_err: E::BytesBuf; - let mut out = String::with_capacity(data.len()); + let mut out = Wtf8Buf::with_capacity(data.len()); let mut remaining_index = 0; let mut remaining_data = data; loop { @@ -98,7 +108,7 @@ where err_idx..err_len.map_or_else(|| data.len(), |len| err_idx + len); let (replace, new_data, restart) = errors.handle_decode_error(data, err_range, reason)?; - out.push_str(replace.as_ref()); + out.push_wtf8(replace.as_ref()); if let Some(new_data) = new_data { 
data_from_err = new_data; data = data_from_err.as_ref(); @@ -116,21 +126,68 @@ where Ok((out, remaining_index)) } +#[inline] +fn encode_utf8_compatible( + s: &Wtf8, + errors: &E, + err_reason: &str, + target_kind: StrKind, +) -> Result, E::Error> { + let full_data = s; + let mut data = s; + let mut char_data_index = 0; + let mut out = Vec::::new(); + while let Some((char_i, (byte_i, _))) = data + .code_point_indices() + .enumerate() + .find(|(_, (_, c))| !target_kind.can_encode(*c)) + { + out.extend_from_slice(&data.as_bytes()[..byte_i]); + let char_start = char_data_index + char_i; + + // number of non-compatible chars between the first non-compatible char and the next compatible char + let non_compat_run_length = data[byte_i..] + .code_points() + .take_while(|c| !target_kind.can_encode(*c)) + .count(); + let char_range = char_start..char_start + non_compat_run_length; + let (replace, char_restart) = + errors.handle_encode_error(full_data, char_range.clone(), err_reason)?; + match replace { + EncodeReplace::Str(s) => { + if s.is_compatible_with(target_kind) { + out.extend_from_slice(s.as_ref().as_bytes()); + } else { + return Err(errors.error_encoding(full_data, char_range, err_reason)); + } + } + EncodeReplace::Bytes(b) => { + out.extend_from_slice(b.as_ref()); + } + } + data = crate::str::try_get_codepoints(full_data, char_restart..) 
+ .ok_or_else(|| errors.error_oob_restart(char_restart))?; + char_data_index = char_restart; + } + out.extend_from_slice(data.as_bytes()); + Ok(out) +} + pub mod utf8 { use super::*; pub const ENCODING_NAME: &str = "utf-8"; #[inline] - pub fn encode(s: &str, _errors: &E) -> Result, E::Error> { - Ok(s.as_bytes().to_vec()) + pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { + encode_utf8_compatible(s, errors, "surrogates not allowed", StrKind::Utf8) } pub fn decode( data: &[u8], errors: &E, final_decode: bool, - ) -> Result<(String, usize), E::Error> { + ) -> Result<(Wtf8Buf, usize), E::Error> { decode_utf8_compatible( data, errors, @@ -180,14 +237,14 @@ pub mod latin_1 { const ERR_REASON: &str = "ordinal not in range(256)"; #[inline] - pub fn encode(s: &str, errors: &E) -> Result, E::Error> { + pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { let full_data = s; let mut data = s; let mut char_data_index = 0; let mut out = Vec::::new(); loop { match data - .char_indices() + .code_point_indices() .enumerate() .find(|(_, (_, c))| !c.is_ascii()) { @@ -198,17 +255,16 @@ pub mod latin_1 { Some((char_i, (byte_i, ch))) => { out.extend_from_slice(&data.as_bytes()[..byte_i]); let char_start = char_data_index + char_i; - if (ch as u32) <= 255 { - out.push(ch as u8); - let char_restart = char_start + 1; - data = crate::str::try_get_chars(full_data, char_restart..) - .ok_or_else(|| errors.error_oob_restart(char_restart))?; - char_data_index = char_restart; + if let Some(byte) = ch.to_u32().to_u8() { + out.push(byte); + // if the codepoint is between 128..=255, it's utf8-length is 2 + data = &data[byte_i + 2..]; + char_data_index = char_start + 1; } else { // number of non-latin_1 chars between the first non-latin_1 char and the next latin_1 char let non_latin_1_run_length = data[byte_i..] 
- .chars() - .take_while(|c| (*c as u32) > 255) + .code_points() + .take_while(|c| c.to_u32() > 255) .count(); let char_range = char_start..char_start + non_latin_1_run_length; let (replace, char_restart) = errors.handle_encode_error( @@ -218,7 +274,7 @@ pub mod latin_1 { )?; match replace { EncodeReplace::Str(s) => { - if s.as_ref().chars().any(|c| (c as u32) > 255) { + if s.as_ref().code_points().any(|c| c.to_u32() > 255) { return Err( errors.error_encoding(full_data, char_range, ERR_REASON) ); @@ -229,7 +285,7 @@ pub mod latin_1 { out.extend_from_slice(b.as_ref()); } } - data = crate::str::try_get_chars(full_data, char_restart..) + data = crate::str::try_get_codepoints(full_data, char_restart..) .ok_or_else(|| errors.error_oob_restart(char_restart))?; char_data_index = char_restart; } @@ -240,10 +296,10 @@ pub mod latin_1 { Ok(out) } - pub fn decode(data: &[u8], _errors: &E) -> Result<(String, usize), E::Error> { + pub fn decode(data: &[u8], _errors: &E) -> Result<(Wtf8Buf, usize), E::Error> { let out: String = data.iter().map(|c| *c as char).collect(); let out_len = out.len(); - Ok((out, out_len)) + Ok((out.into(), out_len)) } } @@ -256,54 +312,11 @@ pub mod ascii { const ERR_REASON: &str = "ordinal not in range(128)"; #[inline] - pub fn encode(s: &str, errors: &E) -> Result, E::Error> { - let full_data = s; - let mut data = s; - let mut char_data_index = 0; - let mut out = Vec::::new(); - loop { - match data - .char_indices() - .enumerate() - .find(|(_, (_, c))| !c.is_ascii()) - { - None => { - out.extend_from_slice(data.as_bytes()); - break; - } - Some((char_i, (byte_i, _))) => { - out.extend_from_slice(&data.as_bytes()[..byte_i]); - let char_start = char_data_index + char_i; - // number of non-ascii chars between the first non-ascii char and the next ascii char - let non_ascii_run_length = - data[byte_i..].chars().take_while(|c| !c.is_ascii()).count(); - let char_range = char_start..char_start + non_ascii_run_length; - let (replace, char_restart) = - 
errors.handle_encode_error(full_data, char_range.clone(), ERR_REASON)?; - match replace { - EncodeReplace::Str(s) => { - if !s.is_ascii() { - return Err( - errors.error_encoding(full_data, char_range, ERR_REASON) - ); - } - out.extend_from_slice(s.as_ref().as_bytes()); - } - EncodeReplace::Bytes(b) => { - out.extend_from_slice(b.as_ref()); - } - } - data = crate::str::try_get_chars(full_data, char_restart..) - .ok_or_else(|| errors.error_oob_restart(char_restart))?; - char_data_index = char_restart; - continue; - } - } - } - Ok(out) + pub fn encode(s: &Wtf8, errors: &E) -> Result, E::Error> { + encode_utf8_compatible(s, errors, ERR_REASON, StrKind::Ascii) } - pub fn decode(data: &[u8], errors: &E) -> Result<(String, usize), E::Error> { + pub fn decode(data: &[u8], errors: &E) -> Result<(Wtf8Buf, usize), E::Error> { decode_utf8_compatible( data, errors, diff --git a/common/src/format.rs b/common/src/format.rs index d9f821658b..75d0996796 100644 --- a/common/src/format.rs +++ b/common/src/format.rs @@ -7,8 +7,10 @@ use rustpython_literal::format::Case; use std::ops::Deref; use std::{cmp, str::FromStr}; +use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; + trait FormatParse { - fn parse(text: &str) -> (Option, &str) + fn parse(text: &Wtf8) -> (Option, &Wtf8) where Self: Sized; } @@ -23,20 +25,20 @@ pub enum FormatConversion { } impl FormatParse for FormatConversion { - fn parse(text: &str) -> (Option, &str) { + fn parse(text: &Wtf8) -> (Option, &Wtf8) { let Some(conversion) = Self::from_string(text) else { return (None, text); }; - let mut chars = text.chars(); + let mut chars = text.code_points(); chars.next(); // Consume the bang chars.next(); // Consume one r,s,a char - (Some(conversion), chars.as_str()) + (Some(conversion), chars.as_wtf8()) } } impl FormatConversion { - pub fn from_char(c: char) -> Option { - match c { + pub fn from_char(c: CodePoint) -> Option { + match c.to_char_lossy() { 's' => Some(FormatConversion::Str), 'r' => Some(FormatConversion::Repr), 'a' => 
Some(FormatConversion::Ascii), @@ -45,9 +47,9 @@ impl FormatConversion { } } - fn from_string(text: &str) -> Option { - let mut chars = text.chars(); - if chars.next() != Some('!') { + fn from_string(text: &Wtf8) -> Option { + let mut chars = text.code_points(); + if chars.next()? != '!' { return None; } @@ -64,8 +66,8 @@ pub enum FormatAlign { } impl FormatAlign { - fn from_char(c: char) -> Option { - match c { + fn from_char(c: CodePoint) -> Option { + match c.to_char_lossy() { '<' => Some(FormatAlign::Left), '>' => Some(FormatAlign::Right), '=' => Some(FormatAlign::AfterSign), @@ -76,10 +78,10 @@ impl FormatAlign { } impl FormatParse for FormatAlign { - fn parse(text: &str) -> (Option, &str) { - let mut chars = text.chars(); + fn parse(text: &Wtf8) -> (Option, &Wtf8) { + let mut chars = text.code_points(); if let Some(maybe_align) = chars.next().and_then(Self::from_char) { - (Some(maybe_align), chars.as_str()) + (Some(maybe_align), chars.as_wtf8()) } else { (None, text) } @@ -94,12 +96,12 @@ pub enum FormatSign { } impl FormatParse for FormatSign { - fn parse(text: &str) -> (Option, &str) { - let mut chars = text.chars(); - match chars.next() { - Some('-') => (Some(Self::Minus), chars.as_str()), - Some('+') => (Some(Self::Plus), chars.as_str()), - Some(' ') => (Some(Self::MinusOrSpace), chars.as_str()), + fn parse(text: &Wtf8) -> (Option, &Wtf8) { + let mut chars = text.code_points(); + match chars.next().and_then(CodePoint::to_char) { + Some('-') => (Some(Self::Minus), chars.as_wtf8()), + Some('+') => (Some(Self::Plus), chars.as_wtf8()), + Some(' ') => (Some(Self::MinusOrSpace), chars.as_wtf8()), _ => (None, text), } } @@ -112,11 +114,11 @@ pub enum FormatGrouping { } impl FormatParse for FormatGrouping { - fn parse(text: &str) -> (Option, &str) { - let mut chars = text.chars(); - match chars.next() { - Some('_') => (Some(Self::Underscore), chars.as_str()), - Some(',') => (Some(Self::Comma), chars.as_str()), + fn parse(text: &Wtf8) -> (Option, &Wtf8) { + let 
mut chars = text.code_points(); + match chars.next().and_then(CodePoint::to_char) { + Some('_') => (Some(Self::Underscore), chars.as_wtf8()), + Some(',') => (Some(Self::Comma), chars.as_wtf8()), _ => (None, text), } } @@ -161,25 +163,25 @@ impl From<&FormatType> for char { } impl FormatParse for FormatType { - fn parse(text: &str) -> (Option, &str) { - let mut chars = text.chars(); - match chars.next() { - Some('s') => (Some(Self::String), chars.as_str()), - Some('b') => (Some(Self::Binary), chars.as_str()), - Some('c') => (Some(Self::Character), chars.as_str()), - Some('d') => (Some(Self::Decimal), chars.as_str()), - Some('o') => (Some(Self::Octal), chars.as_str()), - Some('n') => (Some(Self::Number(Case::Lower)), chars.as_str()), - Some('N') => (Some(Self::Number(Case::Upper)), chars.as_str()), - Some('x') => (Some(Self::Hex(Case::Lower)), chars.as_str()), - Some('X') => (Some(Self::Hex(Case::Upper)), chars.as_str()), - Some('e') => (Some(Self::Exponent(Case::Lower)), chars.as_str()), - Some('E') => (Some(Self::Exponent(Case::Upper)), chars.as_str()), - Some('f') => (Some(Self::FixedPoint(Case::Lower)), chars.as_str()), - Some('F') => (Some(Self::FixedPoint(Case::Upper)), chars.as_str()), - Some('g') => (Some(Self::GeneralFormat(Case::Lower)), chars.as_str()), - Some('G') => (Some(Self::GeneralFormat(Case::Upper)), chars.as_str()), - Some('%') => (Some(Self::Percentage), chars.as_str()), + fn parse(text: &Wtf8) -> (Option, &Wtf8) { + let mut chars = text.code_points(); + match chars.next().and_then(CodePoint::to_char) { + Some('s') => (Some(Self::String), chars.as_wtf8()), + Some('b') => (Some(Self::Binary), chars.as_wtf8()), + Some('c') => (Some(Self::Character), chars.as_wtf8()), + Some('d') => (Some(Self::Decimal), chars.as_wtf8()), + Some('o') => (Some(Self::Octal), chars.as_wtf8()), + Some('n') => (Some(Self::Number(Case::Lower)), chars.as_wtf8()), + Some('N') => (Some(Self::Number(Case::Upper)), chars.as_wtf8()), + Some('x') => 
(Some(Self::Hex(Case::Lower)), chars.as_wtf8()), + Some('X') => (Some(Self::Hex(Case::Upper)), chars.as_wtf8()), + Some('e') => (Some(Self::Exponent(Case::Lower)), chars.as_wtf8()), + Some('E') => (Some(Self::Exponent(Case::Upper)), chars.as_wtf8()), + Some('f') => (Some(Self::FixedPoint(Case::Lower)), chars.as_wtf8()), + Some('F') => (Some(Self::FixedPoint(Case::Upper)), chars.as_wtf8()), + Some('g') => (Some(Self::GeneralFormat(Case::Lower)), chars.as_wtf8()), + Some('G') => (Some(Self::GeneralFormat(Case::Upper)), chars.as_wtf8()), + Some('%') => (Some(Self::Percentage), chars.as_wtf8()), _ => (None, text), } } @@ -188,7 +190,7 @@ impl FormatParse for FormatType { #[derive(Debug, PartialEq)] pub struct FormatSpec { conversion: Option, - fill: Option, + fill: Option, align: Option, sign: Option, alternate_form: bool, @@ -198,17 +200,17 @@ pub struct FormatSpec { format_type: Option, } -fn get_num_digits(text: &str) -> usize { - for (index, character) in text.char_indices() { - if !character.is_ascii_digit() { +fn get_num_digits(text: &Wtf8) -> usize { + for (index, character) in text.code_point_indices() { + if !character.is_char_and(|c| c.is_ascii_digit()) { return index; } } text.len() } -fn parse_fill_and_align(text: &str) -> (Option, Option, &str) { - let char_indices: Vec<(usize, char)> = text.char_indices().take(3).collect(); +fn parse_fill_and_align(text: &Wtf8) -> (Option, Option, &Wtf8) { + let char_indices: Vec<(usize, CodePoint)> = text.code_point_indices().take(3).collect(); if char_indices.is_empty() { (None, None, text) } else if char_indices.len() == 1 { @@ -225,12 +227,12 @@ fn parse_fill_and_align(text: &str) -> (Option, Option, &str) } } -fn parse_number(text: &str) -> Result<(Option, &str), FormatSpecError> { +fn parse_number(text: &Wtf8) -> Result<(Option, &Wtf8), FormatSpecError> { let num_digits: usize = get_num_digits(text); if num_digits == 0 { return Ok((None, text)); } - if let Ok(num) = text[..num_digits].parse::() { + if let Some(num) 
= parse_usize(&text[..num_digits]) { Ok((Some(num), &text[num_digits..])) } else { // NOTE: this condition is different from CPython @@ -238,27 +240,27 @@ fn parse_number(text: &str) -> Result<(Option, &str), FormatSpecError> { } } -fn parse_alternate_form(text: &str) -> (bool, &str) { - let mut chars = text.chars(); - match chars.next() { - Some('#') => (true, chars.as_str()), +fn parse_alternate_form(text: &Wtf8) -> (bool, &Wtf8) { + let mut chars = text.code_points(); + match chars.next().and_then(CodePoint::to_char) { + Some('#') => (true, chars.as_wtf8()), _ => (false, text), } } -fn parse_zero(text: &str) -> (bool, &str) { - let mut chars = text.chars(); - match chars.next() { - Some('0') => (true, chars.as_str()), +fn parse_zero(text: &Wtf8) -> (bool, &Wtf8) { + let mut chars = text.code_points(); + match chars.next().and_then(CodePoint::to_char) { + Some('0') => (true, chars.as_wtf8()), _ => (false, text), } } -fn parse_precision(text: &str) -> Result<(Option, &str), FormatSpecError> { - let mut chars = text.chars(); - Ok(match chars.next() { +fn parse_precision(text: &Wtf8) -> Result<(Option, &Wtf8), FormatSpecError> { + let mut chars = text.code_points(); + Ok(match chars.next().and_then(CodePoint::to_char) { Some('.') => { - let (size, remaining) = parse_number(chars.as_str())?; + let (size, remaining) = parse_number(chars.as_wtf8())?; if let Some(size) = size { if size > i32::MAX as usize { return Err(FormatSpecError::PrecisionTooBig); @@ -273,7 +275,10 @@ fn parse_precision(text: &str) -> Result<(Option, &str), FormatSpecError> } impl FormatSpec { - pub fn parse(text: &str) -> Result { + pub fn parse(text: impl AsRef) -> Result { + Self::_parse(text.as_ref()) + } + fn _parse(text: &Wtf8) -> Result { // get_integer in CPython let (conversion, text) = FormatConversion::parse(text); let (mut fill, mut align, text) = parse_fill_and_align(text); @@ -289,7 +294,7 @@ impl FormatSpec { } if zero && fill.is_none() { - fill.replace('0'); + 
fill.replace('0'.into()); align = align.or(Some(FormatAlign::AfterSign)); } @@ -306,10 +311,8 @@ impl FormatSpec { }) } - fn compute_fill_string(fill_char: char, fill_chars_needed: i32) -> String { - (0..fill_chars_needed) - .map(|_| fill_char) - .collect::() + fn compute_fill_string(fill_char: CodePoint, fill_chars_needed: i32) -> Wtf8Buf { + (0..fill_chars_needed).map(|_| fill_char).collect() } fn add_magnitude_separators_for_char( @@ -625,7 +628,7 @@ impl FormatSpec { let align = self.align.unwrap_or(default_align); let num_chars = magnitude_str.char_len(); - let fill_char = self.fill.unwrap_or(' '); + let fill_char = self.fill.unwrap_or(' '.into()); let fill_chars_needed: i32 = self.width.map_or(0, |w| { cmp::max(0, (w as i32) - (num_chars as i32) - (sign_str.len() as i32)) }); @@ -726,20 +729,20 @@ impl FromStr for FormatSpec { #[derive(Debug, PartialEq)] pub enum FieldNamePart { - Attribute(String), + Attribute(Wtf8Buf), Index(usize), - StringIndex(String), + StringIndex(Wtf8Buf), } impl FieldNamePart { fn parse_part( - chars: &mut impl PeekingNext, + chars: &mut impl PeekingNext, ) -> Result, FormatParseError> { chars .next() - .map(|ch| match ch { + .map(|ch| match ch.to_char_lossy() { '.' => { - let mut attribute = String::new(); + let mut attribute = Wtf8Buf::new(); for ch in chars.peeking_take_while(|ch| *ch != '.' 
&& *ch != '[') { attribute.push(ch); } @@ -750,12 +753,12 @@ impl FieldNamePart { } } '[' => { - let mut index = String::new(); + let mut index = Wtf8Buf::new(); for ch in chars { if ch == ']' { return if index.is_empty() { Err(FormatParseError::EmptyAttribute) - } else if let Ok(index) = index.parse::() { + } else if let Some(index) = parse_usize(&index) { Ok(FieldNamePart::Index(index)) } else { Ok(FieldNamePart::StringIndex(index)) @@ -775,7 +778,7 @@ impl FieldNamePart { pub enum FieldType { Auto, Index(usize), - Keyword(String), + Keyword(Wtf8Buf), } #[derive(Debug, PartialEq)] @@ -784,17 +787,20 @@ pub struct FieldName { pub parts: Vec, } +fn parse_usize(s: &Wtf8) -> Option { + s.as_str().ok().and_then(|s| s.parse().ok()) +} + impl FieldName { - pub fn parse(text: &str) -> Result { - let mut chars = text.chars().peekable(); - let mut first = String::new(); - for ch in chars.peeking_take_while(|ch| *ch != '.' && *ch != '[') { - first.push(ch); - } + pub fn parse(text: &Wtf8) -> Result { + let mut chars = text.code_points().peekable(); + let first: Wtf8Buf = chars + .peeking_take_while(|ch| *ch != '.' 
&& *ch != '[') + .collect(); let field_type = if first.is_empty() { FieldType::Auto - } else if let Ok(index) = first.parse::() { + } else if let Some(index) = parse_usize(&first) { FieldType::Index(index) } else { FieldType::Keyword(first) @@ -812,11 +818,11 @@ impl FieldName { #[derive(Debug, PartialEq)] pub enum FormatPart { Field { - field_name: String, - conversion_spec: Option, - format_spec: String, + field_name: Wtf8Buf, + conversion_spec: Option, + format_spec: Wtf8Buf, }, - Literal(String), + Literal(Wtf8Buf), } #[derive(Debug, PartialEq)] @@ -825,8 +831,8 @@ pub struct FormatString { } impl FormatString { - fn parse_literal_single(text: &str) -> Result<(char, &str), FormatParseError> { - let mut chars = text.chars(); + fn parse_literal_single(text: &Wtf8) -> Result<(CodePoint, &Wtf8), FormatParseError> { + let mut chars = text.code_points(); // This should never be called with an empty str let first_char = chars.next().unwrap(); // isn't this detectable only with bytes operation? 
@@ -836,15 +842,15 @@ impl FormatString { return if maybe_next_char.is_none() || maybe_next_char.unwrap() != first_char { Err(FormatParseError::UnescapedStartBracketInLiteral) } else { - Ok((first_char, chars.as_str())) + Ok((first_char, chars.as_wtf8())) }; } - Ok((first_char, chars.as_str())) + Ok((first_char, chars.as_wtf8())) } - fn parse_literal(text: &str) -> Result<(FormatPart, &str), FormatParseError> { + fn parse_literal(text: &Wtf8) -> Result<(FormatPart, &Wtf8), FormatParseError> { let mut cur_text = text; - let mut result_string = String::new(); + let mut result_string = Wtf8Buf::new(); while !cur_text.is_empty() { match FormatString::parse_literal_single(cur_text) { Ok((next_char, remaining)) => { @@ -860,14 +866,14 @@ impl FormatString { } } } - Ok((FormatPart::Literal(result_string), "")) + Ok((FormatPart::Literal(result_string), "".as_ref())) } - fn parse_part_in_brackets(text: &str) -> Result { - let mut chars = text.chars().peekable(); + fn parse_part_in_brackets(text: &Wtf8) -> Result { + let mut chars = text.code_points().peekable(); - let mut left = String::new(); - let mut right = String::new(); + let mut left = Wtf8Buf::new(); + let mut right = Wtf8Buf::new(); let mut split = false; let mut selected = &mut left; @@ -899,12 +905,12 @@ impl FormatString { } // before the comma is a keyword or arg index, after the comma is maybe a spec. - let arg_part: &str = &left; + let arg_part: &Wtf8 = &left; - let format_spec = if split { right } else { String::new() }; + let format_spec = if split { right } else { Wtf8Buf::new() }; // left can still be the conversion (!r, !s, !a) - let parts: Vec<&str> = arg_part.splitn(2, '!').collect(); + let parts: Vec<&Wtf8> = arg_part.splitn(2, "!".as_ref()).collect(); // before the bang is a keyword or arg index, after the comma is maybe a conversion spec. 
let arg_part = parts[0]; @@ -913,7 +919,7 @@ impl FormatString { .map(|conversion| { // conversions are only every one character conversion - .chars() + .code_points() .exactly_one() .map_err(|_| FormatParseError::UnknownConversion) }) @@ -926,13 +932,13 @@ impl FormatString { }) } - fn parse_spec(text: &str) -> Result<(FormatPart, &str), FormatParseError> { + fn parse_spec(text: &Wtf8) -> Result<(FormatPart, &Wtf8), FormatParseError> { let mut nested = false; let mut end_bracket_pos = None; - let mut left = String::new(); + let mut left = Wtf8Buf::new(); // There may be one layer nesting brackets in spec - for (idx, c) in text.char_indices() { + for (idx, c) in text.code_point_indices() { if idx == 0 { if c != '{' { return Err(FormatParseError::MissingStartBracket); @@ -959,7 +965,7 @@ impl FormatString { } } if let Some(pos) = end_bracket_pos { - let (_, right) = text.split_at(pos); + let right = &text[pos..]; let format_part = FormatString::parse_part_in_brackets(&left)?; Ok((format_part, &right[1..])) } else { @@ -970,14 +976,14 @@ impl FormatString { pub trait FromTemplate<'a>: Sized { type Err; - fn from_str(s: &'a str) -> Result; + fn from_str(s: &'a Wtf8) -> Result; } impl<'a> FromTemplate<'a> for FormatString { type Err = FormatParseError; - fn from_str(text: &'a str) -> Result { - let mut cur_text: &str = text; + fn from_str(text: &'a Wtf8) -> Result { + let mut cur_text: &Wtf8 = text; let mut parts: Vec = Vec::new(); while !cur_text.is_empty() { // Try to parse both literals and bracketed format parts until we @@ -1001,6 +1007,14 @@ mod tests { #[test] fn test_fill_and_align() { + let parse_fill_and_align = |text| { + let (fill, align, rest) = parse_fill_and_align(str::as_ref(text)); + ( + fill.and_then(CodePoint::to_char), + align, + rest.as_str().unwrap(), + ) + }; assert_eq!( parse_fill_and_align(" <"), (Some(' '), Some(FormatAlign::Left), "") @@ -1043,7 +1057,7 @@ mod tests { fn test_fill_and_width() { let expected = Ok(FormatSpec { conversion: None, 
- fill: Some('<'), + fill: Some('<'.into()), align: Some(FormatAlign::Right), sign: None, alternate_form: false, @@ -1059,7 +1073,7 @@ mod tests { fn test_all() { let expected = Ok(FormatSpec { conversion: None, - fill: Some('<'), + fill: Some('<'.into()), align: Some(FormatAlign::Right), sign: Some(FormatSign::Minus), alternate_form: true, @@ -1167,33 +1181,33 @@ mod tests { fn test_format_parse() { let expected = Ok(FormatString { format_parts: vec![ - FormatPart::Literal("abcd".to_owned()), + FormatPart::Literal("abcd".into()), FormatPart::Field { - field_name: "1".to_owned(), + field_name: "1".into(), conversion_spec: None, - format_spec: String::new(), + format_spec: "".into(), }, - FormatPart::Literal(":".to_owned()), + FormatPart::Literal(":".into()), FormatPart::Field { - field_name: "key".to_owned(), + field_name: "key".into(), conversion_spec: None, - format_spec: String::new(), + format_spec: "".into(), }, ], }); - assert_eq!(FormatString::from_str("abcd{1}:{key}"), expected); + assert_eq!(FormatString::from_str("abcd{1}:{key}".as_ref()), expected); } #[test] fn test_format_parse_multi_byte_char() { - assert!(FormatString::from_str("{a:%ЫйЯЧ}").is_ok()); + assert!(FormatString::from_str("{a:%ЫйЯЧ}".as_ref()).is_ok()); } #[test] fn test_format_parse_fail() { assert_eq!( - FormatString::from_str("{s"), + FormatString::from_str("{s".as_ref()), Err(FormatParseError::UnmatchedBracket) ); } @@ -1201,27 +1215,27 @@ mod tests { #[test] fn test_square_brackets_inside_format() { assert_eq!( - FormatString::from_str("{[:123]}"), + FormatString::from_str("{[:123]}".as_ref()), Ok(FormatString { format_parts: vec![FormatPart::Field { - field_name: "[:123]".to_owned(), + field_name: "[:123]".into(), conversion_spec: None, - format_spec: "".to_owned(), + format_spec: "".into(), }], }), ); - assert_eq!(FormatString::from_str("{asdf[:123]asdf}"), { + assert_eq!(FormatString::from_str("{asdf[:123]asdf}".as_ref()), { Ok(FormatString { format_parts: vec![FormatPart::Field { 
- field_name: "asdf[:123]asdf".to_owned(), + field_name: "asdf[:123]asdf".into(), conversion_spec: None, - format_spec: "".to_owned(), + format_spec: "".into(), }], }) }); - assert_eq!(FormatString::from_str("{[1234}"), { + assert_eq!(FormatString::from_str("{[1234}".as_ref()), { Err(FormatParseError::MissingRightBracket) }); } @@ -1230,17 +1244,17 @@ mod tests { fn test_format_parse_escape() { let expected = Ok(FormatString { format_parts: vec![ - FormatPart::Literal("{".to_owned()), + FormatPart::Literal("{".into()), FormatPart::Field { - field_name: "key".to_owned(), + field_name: "key".into(), conversion_spec: None, - format_spec: String::new(), + format_spec: "".into(), }, - FormatPart::Literal("}ddfe".to_owned()), + FormatPart::Literal("}ddfe".into()), ], }); - assert_eq!(FormatString::from_str("{{{key}}}ddfe"), expected); + assert_eq!(FormatString::from_str("{{{key}}}ddfe".as_ref()), expected); } #[test] @@ -1277,52 +1291,44 @@ mod tests { #[test] fn test_parse_field_name() { + let parse = |s: &str| FieldName::parse(s.as_ref()); assert_eq!( - FieldName::parse(""), + parse(""), Ok(FieldName { field_type: FieldType::Auto, parts: Vec::new(), }) ); assert_eq!( - FieldName::parse("0"), + parse("0"), Ok(FieldName { field_type: FieldType::Index(0), parts: Vec::new(), }) ); assert_eq!( - FieldName::parse("key"), + parse("key"), Ok(FieldName { - field_type: FieldType::Keyword("key".to_owned()), + field_type: FieldType::Keyword("key".into()), parts: Vec::new(), }) ); assert_eq!( - FieldName::parse("key.attr[0][string]"), + parse("key.attr[0][string]"), Ok(FieldName { - field_type: FieldType::Keyword("key".to_owned()), + field_type: FieldType::Keyword("key".into()), parts: vec![ - FieldNamePart::Attribute("attr".to_owned()), + FieldNamePart::Attribute("attr".into()), FieldNamePart::Index(0), - FieldNamePart::StringIndex("string".to_owned()) + FieldNamePart::StringIndex("string".into()) ], }) ); + assert_eq!(parse("key.."), Err(FormatParseError::EmptyAttribute)); + 
assert_eq!(parse("key[]"), Err(FormatParseError::EmptyAttribute)); + assert_eq!(parse("key["), Err(FormatParseError::MissingRightBracket)); assert_eq!( - FieldName::parse("key.."), - Err(FormatParseError::EmptyAttribute) - ); - assert_eq!( - FieldName::parse("key[]"), - Err(FormatParseError::EmptyAttribute) - ); - assert_eq!( - FieldName::parse("key["), - Err(FormatParseError::MissingRightBracket) - ); - assert_eq!( - FieldName::parse("key[0]after"), + parse("key[0]after"), Err(FormatParseError::InvalidCharacterAfterRightBracket) ); } diff --git a/common/src/lib.rs b/common/src/lib.rs index 760ba8c55f..e83a9af43a 100644 --- a/common/src/lib.rs +++ b/common/src/lib.rs @@ -29,6 +29,7 @@ pub mod static_cell; pub mod str; #[cfg(windows)] pub mod windows; +pub mod wtf8; pub mod vendored { pub use ascii; diff --git a/common/src/str.rs b/common/src/str.rs index 89d2381d3e..176b5d0f87 100644 --- a/common/src/str.rs +++ b/common/src/str.rs @@ -1,9 +1,9 @@ -use crate::{ - atomic::{PyAtomic, Radium}, - format::CharLen, - hash::PyHash, -}; -use ascii::AsciiString; +use crate::atomic::{PyAtomic, Radium}; +use crate::format::CharLen; +use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf}; +use ascii::{AsciiChar, AsciiStr, AsciiString}; +use core::fmt; +use core::sync::atomic::Ordering::Relaxed; use std::ops::{Bound, RangeBounds}; #[cfg(not(target_arch = "wasm32"))] @@ -14,133 +14,323 @@ pub type wchar_t = libc::wchar_t; pub type wchar_t = u32; /// Utf8 + state.ascii (+ PyUnicode_Kind in future) -#[derive(Debug, Copy, Clone, PartialEq)] -pub enum PyStrKind { +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum StrKind { Ascii, Utf8, + Wtf8, } -impl std::ops::BitOr for PyStrKind { +impl std::ops::BitOr for StrKind { type Output = Self; fn bitor(self, other: Self) -> Self { + use StrKind::*; match (self, other) { - (Self::Ascii, Self::Ascii) => Self::Ascii, - _ => Self::Utf8, + (Wtf8, _) | (_, Wtf8) => Wtf8, + (Utf8, _) | (_, Utf8) => Utf8, + (Ascii, Ascii) => Ascii, } } 
} -impl PyStrKind { - #[inline] - pub fn new_data(self) -> PyStrKindData { +impl StrKind { + pub fn is_ascii(&self) -> bool { + matches!(self, Self::Ascii) + } + + pub fn is_utf8(&self) -> bool { + matches!(self, Self::Ascii | Self::Utf8) + } + + #[inline(always)] + pub fn can_encode(&self, code: CodePoint) -> bool { match self { - PyStrKind::Ascii => PyStrKindData::Ascii, - PyStrKind::Utf8 => PyStrKindData::Utf8(Radium::new(usize::MAX)), + StrKind::Ascii => code.is_ascii(), + StrKind::Utf8 => code.to_char().is_some(), + StrKind::Wtf8 => true, + } + } +} + +pub trait DeduceStrKind { + fn str_kind(&self) -> StrKind; +} + +impl DeduceStrKind for str { + fn str_kind(&self) -> StrKind { + if self.is_ascii() { + StrKind::Ascii + } else { + StrKind::Utf8 + } + } +} + +impl DeduceStrKind for Wtf8 { + fn str_kind(&self) -> StrKind { + if self.is_ascii() { + StrKind::Ascii + } else if self.is_utf8() { + StrKind::Utf8 + } else { + StrKind::Wtf8 } } } +impl DeduceStrKind for String { + fn str_kind(&self) -> StrKind { + (**self).str_kind() + } +} + +impl DeduceStrKind for Wtf8Buf { + fn str_kind(&self) -> StrKind { + (**self).str_kind() + } +} + +impl DeduceStrKind for &T { + fn str_kind(&self) -> StrKind { + (**self).str_kind() + } +} + +impl DeduceStrKind for Box { + fn str_kind(&self) -> StrKind { + (**self).str_kind() + } +} + #[derive(Debug)] -pub enum PyStrKindData { - Ascii, - // uses usize::MAX as a sentinel for "uncomputed" - Utf8(PyAtomic), +pub enum PyKindStr<'a> { + Ascii(&'a AsciiStr), + Utf8(&'a str), + Wtf8(&'a Wtf8), } -impl PyStrKindData { - #[inline] - pub fn kind(&self) -> PyStrKind { - match self { - PyStrKindData::Ascii => PyStrKind::Ascii, - PyStrKindData::Utf8(_) => PyStrKind::Utf8, +#[derive(Debug, Clone)] +pub struct StrData { + data: Box, + kind: StrKind, + len: StrLen, +} + +struct StrLen(PyAtomic); + +impl From for StrLen { + #[inline(always)] + fn from(value: usize) -> Self { + Self(Radium::new(value)) + } +} + +impl fmt::Debug for StrLen { + fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let len = self.0.load(Relaxed); + if len == usize::MAX { + f.write_str("") + } else { + len.fmt(f) } } } -pub struct BorrowedStr<'a> { - bytes: &'a [u8], - kind: PyStrKindData, - #[allow(dead_code)] - hash: PyAtomic, +impl StrLen { + #[inline(always)] + fn zero() -> Self { + 0usize.into() + } + #[inline(always)] + fn uncomputed() -> Self { + usize::MAX.into() + } +} + +impl Clone for StrLen { + fn clone(&self) -> Self { + Self(self.0.load(Relaxed).into()) + } } -impl<'a> BorrowedStr<'a> { - /// # Safety - /// `s` have to be an ascii string - #[inline] - pub unsafe fn from_ascii_unchecked(s: &'a [u8]) -> Self { - debug_assert!(s.is_ascii()); +impl Default for StrData { + fn default() -> Self { Self { - bytes: s, - kind: PyStrKind::Ascii.new_data(), - hash: PyAtomic::::new(0), + data: >::default(), + kind: StrKind::Ascii, + len: StrLen::zero(), } } +} + +impl From> for StrData { + fn from(value: Box) -> Self { + // doing the check is ~10x faster for ascii, and is actually only 2% slower worst case for + // non-ascii; see https://github.com/RustPython/RustPython/pull/2586#issuecomment-844611532 + let kind = value.str_kind(); + unsafe { Self::new_str_unchecked(value, kind) } + } +} + +impl From> for StrData { + #[inline] + fn from(value: Box) -> Self { + // doing the check is ~10x faster for ascii, and is actually only 2% slower worst case for + // non-ascii; see https://github.com/RustPython/RustPython/pull/2586#issuecomment-844611532 + let kind = value.str_kind(); + unsafe { Self::new_str_unchecked(value.into(), kind) } + } +} +impl From> for StrData { #[inline] - pub fn from_bytes(s: &'a [u8]) -> Self { - let k = if s.is_ascii() { - PyStrKind::Ascii.new_data() + fn from(value: Box) -> Self { + Self { + len: value.len().into(), + data: value.into(), + kind: StrKind::Ascii, + } + } +} + +impl From for StrData { + fn from(ch: AsciiChar) -> Self { + AsciiString::from(ch).into_boxed_ascii_str().into() + 
} +} + +impl From for StrData { + fn from(ch: char) -> Self { + if let Ok(ch) = ascii::AsciiChar::from_ascii(ch) { + ch.into() + } else { + Self { + data: ch.to_string().into(), + kind: StrKind::Utf8, + len: 1.into(), + } + } + } +} + +impl From for StrData { + fn from(ch: CodePoint) -> Self { + if let Some(ch) = ch.to_char() { + ch.into() } else { - PyStrKind::Utf8.new_data() + Self { + data: Wtf8Buf::from(ch).into(), + kind: StrKind::Wtf8, + len: 1.into(), + } + } + } +} + +impl StrData { + /// # Safety + /// + /// Given `bytes` must be valid data for given `kind` + pub unsafe fn new_str_unchecked(data: Box, kind: StrKind) -> Self { + let len = match kind { + StrKind::Ascii => data.len().into(), + _ => StrLen::uncomputed(), }; + Self { data, kind, len } + } + + /// # Safety + /// + /// `char_len` must be accurate. + pub unsafe fn new_with_char_len(data: Box, kind: StrKind, char_len: usize) -> Self { Self { - bytes: s, - kind: k, - hash: PyAtomic::::new(0), + data, + kind, + len: char_len.into(), } } #[inline] - pub fn as_str(&self) -> &str { - unsafe { - // SAFETY: Both PyStrKind::{Ascii, Utf8} are valid utf8 string - std::str::from_utf8_unchecked(self.bytes) + pub fn as_wtf8(&self) -> &Wtf8 { + &self.data + } + + #[inline] + pub fn as_str(&self) -> Option<&str> { + self.kind + .is_utf8() + .then(|| unsafe { std::str::from_utf8_unchecked(self.data.as_bytes()) }) + } + + pub fn as_ascii(&self) -> Option<&AsciiStr> { + self.kind + .is_ascii() + .then(|| unsafe { AsciiStr::from_ascii_unchecked(self.data.as_bytes()) }) + } + + pub fn kind(&self) -> StrKind { + self.kind + } + + #[inline] + pub fn as_str_kind(&self) -> PyKindStr<'_> { + match self.kind { + StrKind::Ascii => { + PyKindStr::Ascii(unsafe { AsciiStr::from_ascii_unchecked(self.data.as_bytes()) }) + } + StrKind::Utf8 => { + PyKindStr::Utf8(unsafe { std::str::from_utf8_unchecked(self.data.as_bytes()) }) + } + StrKind::Wtf8 => PyKindStr::Wtf8(&self.data), } } + #[inline] + pub fn len(&self) -> usize { + 
self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + #[inline] pub fn char_len(&self) -> usize { - match self.kind { - PyStrKindData::Ascii => self.bytes.len(), - PyStrKindData::Utf8(ref len) => match len.load(core::sync::atomic::Ordering::Relaxed) { - usize::MAX => self._compute_char_len(), - len => len, - }, + match self.len.0.load(Relaxed) { + usize::MAX => self._compute_char_len(), + len => len, } } #[cold] fn _compute_char_len(&self) -> usize { - match self.kind { - PyStrKindData::Utf8(ref char_len) => { - let len = self.as_str().chars().count(); - // len cannot be usize::MAX, since vec.capacity() < sys.maxsize - char_len.store(len, core::sync::atomic::Ordering::Relaxed); - len - } - _ => unsafe { - debug_assert!(false); // invalid for non-utf8 strings - std::hint::unreachable_unchecked() - }, - } + let len = if let Some(s) = self.as_str() { + // utf8 chars().count() is optimized + s.chars().count() + } else { + self.data.code_points().count() + }; + // len cannot be usize::MAX, since vec.capacity() < sys.maxsize + self.len.0.store(len, Relaxed); + len } -} -impl std::ops::Deref for BorrowedStr<'_> { - type Target = str; - fn deref(&self) -> &str { - self.as_str() + pub fn nth_char(&self, index: usize) -> CodePoint { + match self.as_str_kind() { + PyKindStr::Ascii(s) => s[index].into(), + PyKindStr::Utf8(s) => s.chars().nth(index).unwrap().into(), + PyKindStr::Wtf8(w) => w.code_points().nth(index).unwrap(), + } } } -impl std::fmt::Display for BorrowedStr<'_> { +impl std::fmt::Display for StrData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.as_str().fmt(f) + self.data.fmt(f) } } -impl CharLen for BorrowedStr<'_> { +impl CharLen for StrData { fn char_len(&self) -> usize { self.char_len() } @@ -181,6 +371,41 @@ pub fn char_range_end(s: &str, nchars: usize) -> Option { Some(i) } +pub fn try_get_codepoints(w: &Wtf8, range: impl RangeBounds) -> Option<&Wtf8> { + let mut chars = w.code_points(); + let 
start = match range.start_bound() { + Bound::Included(&i) => i, + Bound::Excluded(&i) => i + 1, + Bound::Unbounded => 0, + }; + for _ in 0..start { + chars.next()?; + } + let s = chars.as_wtf8(); + let range_len = match range.end_bound() { + Bound::Included(&i) => i + 1 - start, + Bound::Excluded(&i) => i - start, + Bound::Unbounded => return Some(s), + }; + codepoint_range_end(s, range_len).map(|end| &s[..end]) +} + +pub fn get_codepoints(w: &Wtf8, range: impl RangeBounds) -> &Wtf8 { + try_get_codepoints(w, range).unwrap() +} + +#[inline] +pub fn codepoint_range_end(s: &Wtf8, nchars: usize) -> Option { + let i = match nchars.checked_sub(1) { + Some(last_char_index) => { + let (index, c) = s.code_point_indices().nth(last_char_index)?; + index + c.len_wtf8() + } + None => 0, + }; + Some(i) +} + pub fn zfill(bytes: &[u8], width: usize) -> Vec { if width <= bytes.len() { bytes.to_vec() diff --git a/common/src/wtf8/core_char.rs b/common/src/wtf8/core_char.rs new file mode 100644 index 0000000000..1444e8e130 --- /dev/null +++ b/common/src/wtf8/core_char.rs @@ -0,0 +1,113 @@ +//! Unstable functions from [`core::char`] + +use core::slice; + +pub const MAX_LEN_UTF8: usize = 4; +pub const MAX_LEN_UTF16: usize = 2; + +// UTF-8 ranges and tags for encoding characters +const TAG_CONT: u8 = 0b1000_0000; +const TAG_TWO_B: u8 = 0b1100_0000; +const TAG_THREE_B: u8 = 0b1110_0000; +const TAG_FOUR_B: u8 = 0b1111_0000; +const MAX_ONE_B: u32 = 0x80; +const MAX_TWO_B: u32 = 0x800; +const MAX_THREE_B: u32 = 0x10000; + +#[inline] +#[must_use] +pub const fn len_utf8(code: u32) -> usize { + match code { + ..MAX_ONE_B => 1, + ..MAX_TWO_B => 2, + ..MAX_THREE_B => 3, + _ => 4, + } +} + +#[inline] +#[must_use] +const fn len_utf16(code: u32) -> usize { + if (code & 0xFFFF) == code { 1 } else { 2 } +} + +/// Encodes a raw `u32` value as UTF-8 into the provided byte buffer, +/// and then returns the subslice of the buffer that contains the encoded character. 
+/// +/// Unlike `char::encode_utf8`, this method also handles codepoints in the surrogate range. +/// (Creating a `char` in the surrogate range is UB.) +/// The result is valid [generalized UTF-8] but not valid UTF-8. +/// +/// [generalized UTF-8]: https://simonsapin.github.io/wtf-8/#generalized-utf8 +/// +/// # Panics +/// +/// Panics if the buffer is not large enough. +/// A buffer of length four is large enough to encode any `char`. +#[doc(hidden)] +#[inline] +pub fn encode_utf8_raw(code: u32, dst: &mut [u8]) -> &mut [u8] { + let len = len_utf8(code); + match (len, &mut *dst) { + (1, [a, ..]) => { + *a = code as u8; + } + (2, [a, b, ..]) => { + *a = (code >> 6 & 0x1F) as u8 | TAG_TWO_B; + *b = (code & 0x3F) as u8 | TAG_CONT; + } + (3, [a, b, c, ..]) => { + *a = (code >> 12 & 0x0F) as u8 | TAG_THREE_B; + *b = (code >> 6 & 0x3F) as u8 | TAG_CONT; + *c = (code & 0x3F) as u8 | TAG_CONT; + } + (4, [a, b, c, d, ..]) => { + *a = (code >> 18 & 0x07) as u8 | TAG_FOUR_B; + *b = (code >> 12 & 0x3F) as u8 | TAG_CONT; + *c = (code >> 6 & 0x3F) as u8 | TAG_CONT; + *d = (code & 0x3F) as u8 | TAG_CONT; + } + _ => { + panic!( + "encode_utf8: need {len} bytes to encode U+{code:04X} but buffer has just {dst_len}", + dst_len = dst.len(), + ) + } + }; + // SAFETY: `<&mut [u8]>::as_mut_ptr` is guaranteed to return a valid pointer and `len` has been tested to be within bounds. + unsafe { slice::from_raw_parts_mut(dst.as_mut_ptr(), len) } +} + +/// Encodes a raw `u32` value as UTF-16 into the provided `u16` buffer, +/// and then returns the subslice of the buffer that contains the encoded character. +/// +/// Unlike `char::encode_utf16`, this method also handles codepoints in the surrogate range. +/// (Creating a `char` in the surrogate range is UB.) +/// +/// # Panics +/// +/// Panics if the buffer is not large enough. +/// A buffer of length 2 is large enough to encode any `char`. 
+#[doc(hidden)] +#[inline] +pub fn encode_utf16_raw(mut code: u32, dst: &mut [u16]) -> &mut [u16] { + let len = len_utf16(code); + match (len, &mut *dst) { + (1, [a, ..]) => { + *a = code as u16; + } + (2, [a, b, ..]) => { + code -= 0x1_0000; + *a = (code >> 10) as u16 | 0xD800; + *b = (code & 0x3FF) as u16 | 0xDC00; + } + _ => { + panic!( + "encode_utf16: need {len} bytes to encode U+{code:04X} but buffer has just {dst_len}", + dst_len = dst.len(), + ) + } + }; + // SAFETY: `<&mut [u16]>::as_mut_ptr` is guaranteed to return a valid pointer and `len` has been tested to be within bounds. + unsafe { slice::from_raw_parts_mut(dst.as_mut_ptr(), len) } +} diff --git a/common/src/wtf8/core_str.rs b/common/src/wtf8/core_str.rs new file mode 100644 index 0000000000..56f715cf75 --- /dev/null +++ b/common/src/wtf8/core_str.rs @@ -0,0 +1,113 @@ +//! Operations related to UTF-8 validation. +//! +//! Copied from `core::str::validations` + +/// Returns the initial codepoint accumulator for the first byte. +/// The first byte is special, only want bottom 5 bits for width 2, 4 bits +/// for width 3, and 3 bits for width 4. +#[inline] +const fn utf8_first_byte(byte: u8, width: u32) -> u32 { + (byte & (0x7F >> width)) as u32 +} + +/// Returns the value of `ch` updated with continuation byte `byte`. +#[inline] +const fn utf8_acc_cont_byte(ch: u32, byte: u8) -> u32 { + (ch << 6) | (byte & CONT_MASK) as u32 +} + +/// Checks whether the byte is a UTF-8 continuation byte (i.e., starts with the +/// bits `10`). +#[inline] +pub(super) const fn utf8_is_cont_byte(byte: u8) -> bool { + (byte as i8) < -64 +} + +/// Reads the next code point out of a byte iterator (assuming a +/// UTF-8-like encoding). 
+/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string +#[inline] +pub unsafe fn next_code_point<'a, I: Iterator>(bytes: &mut I) -> Option { + // Decode UTF-8 + let x = *bytes.next()?; + if x < 128 { + return Some(x as u32); + } + + // Multibyte case follows + // Decode from a byte combination out of: [[[x y] z] w] + // NOTE: Performance is sensitive to the exact formulation here + let init = utf8_first_byte(x, 2); + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let y = unsafe { *bytes.next().unwrap_unchecked() }; + let mut ch = utf8_acc_cont_byte(init, y); + if x >= 0xE0 { + // [[x y z] w] case + // 5th bit in 0xE0 .. 0xEF is always clear, so `init` is still valid + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let z = unsafe { *bytes.next().unwrap_unchecked() }; + let y_z = utf8_acc_cont_byte((y & CONT_MASK) as u32, z); + ch = init << 12 | y_z; + if x >= 0xF0 { + // [x y z w] case + // use only the lower 3 bits of `init` + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let w = unsafe { *bytes.next().unwrap_unchecked() }; + ch = (init & 7) << 18 | utf8_acc_cont_byte(y_z, w); + } + } + + Some(ch) +} + +/// Reads the last code point out of a byte iterator (assuming a +/// UTF-8-like encoding). +/// +/// # Safety +/// +/// `bytes` must produce a valid UTF-8-like (UTF-8 or WTF-8) string +#[inline] +pub unsafe fn next_code_point_reverse<'a, I>(bytes: &mut I) -> Option +where + I: DoubleEndedIterator, +{ + // Decode UTF-8 + let w = match *bytes.next_back()? { + next_byte if next_byte < 128 => return Some(next_byte as u32), + back_byte => back_byte, + }; + + // Multibyte case follows + // Decode from a byte combination out of: [x [y [z w]]] + let mut ch; + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. 
+ let z = unsafe { *bytes.next_back().unwrap_unchecked() }; + ch = utf8_first_byte(z, 2); + if utf8_is_cont_byte(z) { + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let y = unsafe { *bytes.next_back().unwrap_unchecked() }; + ch = utf8_first_byte(y, 3); + if utf8_is_cont_byte(y) { + // SAFETY: `bytes` produces an UTF-8-like string, + // so the iterator must produce a value here. + let x = unsafe { *bytes.next_back().unwrap_unchecked() }; + ch = utf8_first_byte(x, 4); + ch = utf8_acc_cont_byte(ch, y); + } + ch = utf8_acc_cont_byte(ch, z); + } + ch = utf8_acc_cont_byte(ch, w); + + Some(ch) +} + +/// Mask of the value bits of a continuation byte. +const CONT_MASK: u8 = 0b0011_1111; diff --git a/common/src/wtf8/core_str_count.rs b/common/src/wtf8/core_str_count.rs new file mode 100644 index 0000000000..cff5a4b076 --- /dev/null +++ b/common/src/wtf8/core_str_count.rs @@ -0,0 +1,161 @@ +//! Modified from core::str::count + +use super::Wtf8; + +const USIZE_SIZE: usize = core::mem::size_of::(); +const UNROLL_INNER: usize = 4; + +#[inline] +pub(super) fn count_chars(s: &Wtf8) -> usize { + if s.len() < USIZE_SIZE * UNROLL_INNER { + // Avoid entering the optimized implementation for strings where the + // difference is not likely to matter, or where it might even be slower. + // That said, a ton of thought was not spent on the particular threshold + // here, beyond "this value seems to make sense". + char_count_general_case(s.as_bytes()) + } else { + do_count_chars(s) + } +} + +fn do_count_chars(s: &Wtf8) -> usize { + // For correctness, `CHUNK_SIZE` must be: + // + // - Less than or equal to 255, otherwise we'll overflow bytes in `counts`. + // - A multiple of `UNROLL_INNER`, otherwise our `break` inside the + // `body.chunks(CHUNK_SIZE)` loop is incorrect. + // + // For performance, `CHUNK_SIZE` should be: + // - Relatively cheap to `/` against (so some simple sum of powers of two). 
+ // - Large enough to avoid paying for the cost of the `sum_bytes_in_usize` + // too often. + const CHUNK_SIZE: usize = 192; + + // Check the properties of `CHUNK_SIZE` and `UNROLL_INNER` that are required + // for correctness. + const _: () = assert!(CHUNK_SIZE < 256); + const _: () = assert!(CHUNK_SIZE % UNROLL_INNER == 0); + + // SAFETY: transmuting `[u8]` to `[usize]` is safe except for size + // differences which are handled by `align_to`. + let (head, body, tail) = unsafe { s.as_bytes().align_to::() }; + + // This should be quite rare, and basically exists to handle the degenerate + // cases where align_to fails (as well as miri under symbolic alignment + // mode). + // + // The `unlikely` helps discourage LLVM from inlining the body, which is + // nice, as we would rather not mark the `char_count_general_case` function + // as cold. + if unlikely(body.is_empty() || head.len() > USIZE_SIZE || tail.len() > USIZE_SIZE) { + return char_count_general_case(s.as_bytes()); + } + + let mut total = char_count_general_case(head) + char_count_general_case(tail); + // Split `body` into `CHUNK_SIZE` chunks to reduce the frequency with which + // we call `sum_bytes_in_usize`. + for chunk in body.chunks(CHUNK_SIZE) { + // We accumulate intermediate sums in `counts`, where each byte contains + // a subset of the sum of this chunk, like a `[u8; size_of::()]`. + let mut counts = 0; + + let (unrolled_chunks, remainder) = slice_as_chunks::<_, UNROLL_INNER>(chunk); + for unrolled in unrolled_chunks { + for &word in unrolled { + // Because `CHUNK_SIZE` is < 256, this addition can't cause the + // count in any of the bytes to overflow into a subsequent byte. + counts += contains_non_continuation_byte(word); + } + } + + // Sum the values in `counts` (which, again, is conceptually a `[u8; + // size_of::()]`), and accumulate the result into `total`. + total += sum_bytes_in_usize(counts); + + // If there's any data in `remainder`, then handle it. 
This will only + // happen for the last `chunk` in `body.chunks()` (because `CHUNK_SIZE` + // is divisible by `UNROLL_INNER`), so we explicitly break at the end + // (which seems to help LLVM out). + if !remainder.is_empty() { + // Accumulate all the data in the remainder. + let mut counts = 0; + for &word in remainder { + counts += contains_non_continuation_byte(word); + } + total += sum_bytes_in_usize(counts); + break; + } + } + total +} + +// Checks each byte of `w` to see if it contains the first byte in a UTF-8 +// sequence. Bytes in `w` which are continuation bytes are left as `0x00` (e.g. +// false), and bytes which are non-continuation bytes are left as `0x01` (e.g. +// true) +#[inline] +fn contains_non_continuation_byte(w: usize) -> usize { + const LSB: usize = usize_repeat_u8(0x01); + ((!w >> 7) | (w >> 6)) & LSB +} + +// Morally equivalent to `values.to_ne_bytes().into_iter().sum::()`, but +// more efficient. +#[inline] +fn sum_bytes_in_usize(values: usize) -> usize { + const LSB_SHORTS: usize = usize_repeat_u16(0x0001); + const SKIP_BYTES: usize = usize_repeat_u16(0x00ff); + + let pair_sum: usize = (values & SKIP_BYTES) + ((values >> 8) & SKIP_BYTES); + pair_sum.wrapping_mul(LSB_SHORTS) >> ((USIZE_SIZE - 2) * 8) +} + +// This is the most direct implementation of the concept of "count the number of +// bytes in the string which are not continuation bytes", and is used for the +// head and tail of the input string (the first and last item in the tuple +// returned by `slice::align_to`). 
+fn char_count_general_case(s: &[u8]) -> usize { + s.iter() + .filter(|&&byte| !super::core_str::utf8_is_cont_byte(byte)) + .count() +} + +// polyfills of unstable library features + +const fn usize_repeat_u8(x: u8) -> usize { + usize::from_ne_bytes([x; size_of::()]) +} + +const fn usize_repeat_u16(x: u16) -> usize { + let mut r = 0usize; + let mut i = 0; + while i < size_of::() { + // Use `wrapping_shl` to make it work on targets with 16-bit `usize` + r = r.wrapping_shl(16) | (x as usize); + i += 2; + } + r +} + +fn slice_as_chunks(slice: &[T]) -> (&[[T; N]], &[T]) { + assert!(N != 0, "chunk size must be non-zero"); + let len_rounded_down = slice.len() / N * N; + // SAFETY: The rounded-down value is always the same or smaller than the + // original length, and thus must be in-bounds of the slice. + let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) }; + // SAFETY: We already panicked for zero, and ensured by construction + // that the length of the subslice is a multiple of N. + let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) }; + (array_slice, remainder) +} + +unsafe fn slice_as_chunks_unchecked(slice: &[T]) -> &[[T; N]] { + let new_len = slice.len() / N; + // SAFETY: We cast a slice of `new_len * N` elements into + // a slice of `new_len` many `N` elements chunks. + unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), new_len) } +} + +fn unlikely(x: bool) -> bool { + x +} diff --git a/common/src/wtf8/mod.rs b/common/src/wtf8/mod.rs new file mode 100644 index 0000000000..f209b98a3e --- /dev/null +++ b/common/src/wtf8/mod.rs @@ -0,0 +1,1481 @@ +//! An implementation of [WTF-8], a utf8-compatible encoding that allows for +//! unpaired surrogate codepoints. This implementation additionally allows for +//! paired surrogates that are nonetheless treated as two separate codepoints. +//! +//! +//! RustPython uses this because CPython internally uses a variant of UCS-1/2/4 +//! 
as its string storage, which treats each `u8`/`u16`/`u32` value (depending +//! on the highest codepoint value in the string) as simply an integer, unlike +//! UTF-8 or UTF-16 where some characters are encoded using multi-byte +//! sequences. CPython additionally doesn't disallow the use of surrogates in +//! `str`s (which in UTF-16 pair together to represent codepoints with a value +//! higher than `u16::MAX`) and in fact takes quite extensive advantage of the +//! fact that they're allowed. The `surrogateescape` codec-error handler uses +//! them to represent byte sequences which are invalid in the given codec (e.g. +//! bytes with their high bit set in ASCII or UTF-8) by mapping them into the +//! surrogate range. `surrogateescape` is the default error handler in Python +//! for interacting with the filesystem, and thus if RustPython is to properly +//! support `surrogateescape`, its `str`s must be able to represent surrogates. +//! +//! We use WTF-8 over something more similar to CPython's string implementation +//! because of its compatibility with UTF-8, meaning that in the case where a +//! string has no surrogates, it can be viewed as a UTF-8 Rust [`str`] without +//! needing any copies or re-encoding. +//! +//! This implementation is mostly copied from the WTF-8 implementation in the +//! Rust 1.85 standard library, which is used as the backing for [`OsStr`] on +//! Windows targets. As previously mentioned, however, it is modified to not +//! join two surrogates into one codepoint when concatenating strings, in order +//! to match CPython's behavior. +//! +//! [WTF-8]: https://simonsapin.github.io/wtf-8 +//! 
[`OsStr`]: std::ffi::OsStr + +#![allow(clippy::precedence, clippy::match_overlapping_arm)] + +use core::fmt; +use core::hash::{Hash, Hasher}; +use core::iter::FusedIterator; +use core::mem; +use core::ops; +use core::slice; +use core::str; +use core_char::MAX_LEN_UTF8; +use core_char::{MAX_LEN_UTF16, encode_utf8_raw, encode_utf16_raw, len_utf8}; +use core_str::{next_code_point, next_code_point_reverse}; +use itertools::{Either, Itertools}; +use std::borrow::{Borrow, Cow}; +use std::collections::TryReserveError; +use std::string::String; +use std::vec::Vec; + +use bstr::{ByteSlice, ByteVec}; + +mod core_char; +mod core_str; +mod core_str_count; + +const UTF8_REPLACEMENT_CHARACTER: &str = "\u{FFFD}"; + +/// A Unicode code point: from U+0000 to U+10FFFF. +/// +/// Compares with the `char` type, +/// which represents a Unicode scalar value: +/// a code point that is not a surrogate (U+D800 to U+DFFF). +#[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +pub struct CodePoint { + value: u32, +} + +/// Format the code point as `U+` followed by four to six hexadecimal digits. +/// Example: `U+1F4A9` +impl fmt::Debug for CodePoint { + #[inline] + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(formatter, "U+{:04X}", self.value) + } +} + +impl fmt::Display for CodePoint { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.to_char_lossy().fmt(f) + } +} + +impl CodePoint { + /// Unsafely creates a new `CodePoint` without checking the value. + /// + /// # Safety + /// + /// `value` must be less than or equal to 0x10FFFF. + #[inline] + pub unsafe fn from_u32_unchecked(value: u32) -> CodePoint { + CodePoint { value } + } + + /// Creates a new `CodePoint` if the value is a valid code point. + /// + /// Returns `None` if `value` is above 0x10FFFF. + #[inline] + pub fn from_u32(value: u32) -> Option { + match value { + 0..=0x10FFFF => Some(CodePoint { value }), + _ => None, + } + } + + /// Creates a new `CodePoint` from a `char`. 
+ /// + /// Since all Unicode scalar values are code points, this always succeeds. + #[inline] + pub fn from_char(value: char) -> CodePoint { + CodePoint { + value: value as u32, + } + } + + /// Returns the numeric value of the code point. + #[inline] + pub fn to_u32(self) -> u32 { + self.value + } + + /// Returns the numeric value of the code point if it is a leading surrogate. + #[inline] + pub fn to_lead_surrogate(self) -> Option { + match self.value { + lead @ 0xD800..=0xDBFF => Some(lead as u16), + _ => None, + } + } + + /// Returns the numeric value of the code point if it is a trailing surrogate. + #[inline] + pub fn to_trail_surrogate(self) -> Option { + match self.value { + trail @ 0xDC00..=0xDFFF => Some(trail as u16), + _ => None, + } + } + + /// Optionally returns a Unicode scalar value for the code point. + /// + /// Returns `None` if the code point is a surrogate (from U+D800 to U+DFFF). + #[inline] + pub fn to_char(self) -> Option { + match self.value { + 0xD800..=0xDFFF => None, + _ => Some(unsafe { char::from_u32_unchecked(self.value) }), + } + } + + /// Returns a Unicode scalar value for the code point. + /// + /// Returns `'\u{FFFD}'` (the replacement character “�”) + /// if the code point is a surrogate (from U+D800 to U+DFFF). 
+ #[inline] + pub fn to_char_lossy(self) -> char { + self.to_char().unwrap_or('\u{FFFD}') + } + + pub fn is_char_and(self, f: impl FnOnce(char) -> bool) -> bool { + self.to_char().is_some_and(f) + } + + pub fn encode_wtf8(self, dst: &mut [u8]) -> &mut Wtf8 { + unsafe { Wtf8::from_mut_bytes_unchecked(encode_utf8_raw(self.value, dst)) } + } + + pub fn len_wtf8(&self) -> usize { + len_utf8(self.value) + } + + pub fn is_ascii(&self) -> bool { + self.is_char_and(|c| c.is_ascii()) + } +} + +impl From for CodePoint { + fn from(value: u16) -> Self { + unsafe { Self::from_u32_unchecked(value.into()) } + } +} + +impl From for CodePoint { + fn from(value: u8) -> Self { + char::from(value).into() + } +} + +impl From for CodePoint { + fn from(value: char) -> Self { + Self::from_char(value) + } +} + +impl From for CodePoint { + fn from(value: ascii::AsciiChar) -> Self { + Self::from_char(value.into()) + } +} + +impl From for Wtf8Buf { + fn from(ch: CodePoint) -> Self { + ch.encode_wtf8(&mut [0; MAX_LEN_UTF8]).to_owned() + } +} + +impl PartialEq for CodePoint { + fn eq(&self, other: &char) -> bool { + self.to_u32() == *other as u32 + } +} +impl PartialEq for char { + fn eq(&self, other: &CodePoint) -> bool { + *self as u32 == other.to_u32() + } +} + +/// An owned, growable string of well-formed WTF-8 data. +/// +/// Similar to `String`, but can additionally contain surrogate code points +/// if they’re not in a surrogate pair. 
+#[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Default)] +pub struct Wtf8Buf { + bytes: Vec, +} + +impl ops::Deref for Wtf8Buf { + type Target = Wtf8; + + fn deref(&self) -> &Wtf8 { + self.as_slice() + } +} + +impl ops::DerefMut for Wtf8Buf { + fn deref_mut(&mut self) -> &mut Wtf8 { + self.as_mut_slice() + } +} + +impl Borrow for Wtf8Buf { + fn borrow(&self) -> &Wtf8 { + self + } +} + +/// Formats the string in double quotes, with characters escaped according to +/// [`char::escape_debug`] and unpaired surrogates represented as `\u{xxxx}`, +/// where each `x` is a hexadecimal digit. +/// +/// For example, the code units [U+0061, U+D800, U+000A] are formatted as +/// `"a\u{D800}\n"`. +impl fmt::Debug for Wtf8Buf { + #[inline] + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&**self, formatter) + } +} + +/// Formats the string with unpaired surrogates substituted with the replacement +/// character, U+FFFD. +impl fmt::Display for Wtf8Buf { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&**self, formatter) + } +} + +impl Wtf8Buf { + /// Creates a new, empty WTF-8 string. + #[inline] + pub fn new() -> Wtf8Buf { + Wtf8Buf::default() + } + + /// Creates a new, empty WTF-8 string with pre-allocated capacity for `capacity` bytes. + #[inline] + pub fn with_capacity(capacity: usize) -> Wtf8Buf { + Wtf8Buf { + bytes: Vec::with_capacity(capacity), + } + } + + /// Creates a WTF-8 string from a WTF-8 byte vec. + /// + /// # Safety + /// + /// `value` must contain valid WTF-8. + #[inline] + pub unsafe fn from_bytes_unchecked(value: Vec) -> Wtf8Buf { + Wtf8Buf { bytes: value } + } + + /// Creates a WTF-8 string from a UTF-8 `String`. + /// + /// This takes ownership of the `String` and does not copy. + /// + /// Since WTF-8 is a superset of UTF-8, this always succeeds. 
+ #[inline] + pub fn from_string(string: String) -> Wtf8Buf { + Wtf8Buf { + bytes: string.into_bytes(), + } + } + + pub fn clear(&mut self) { + self.bytes.clear(); + } + + /// Creates a WTF-8 string from a potentially ill-formed UTF-16 slice of 16-bit code units. + /// + /// This is lossless: calling `.encode_wide()` on the resulting string + /// will always return the original code units. + pub fn from_wide(v: &[u16]) -> Wtf8Buf { + let mut string = Wtf8Buf::with_capacity(v.len()); + for item in char::decode_utf16(v.iter().cloned()) { + match item { + Ok(ch) => string.push_char(ch), + Err(surrogate) => { + let surrogate = surrogate.unpaired_surrogate(); + // Surrogates are known to be in the code point range. + let code_point = CodePoint::from(surrogate); + // Skip the WTF-8 concatenation check, + // surrogate pairs are already decoded by decode_utf16 + string.push(code_point); + } + } + } + string + } + + #[inline] + pub fn as_slice(&self) -> &Wtf8 { + unsafe { Wtf8::from_bytes_unchecked(&self.bytes) } + } + + #[inline] + pub fn as_mut_slice(&mut self) -> &mut Wtf8 { + // SAFETY: `self.bytes` is the storage of a `Wtf8Buf`, which holds + // well-formed WTF-8 by its type-level contract. The `&mut Wtf8` handed + // out is presumed not to allow edits that break well-formedness + // (TODO confirm against the mutating methods on `Wtf8`). + unsafe { Wtf8::from_mut_bytes_unchecked(&mut self.bytes) } + } + + /// Reserves capacity for at least `additional` more bytes to be inserted + /// in the given `Wtf8Buf`. + /// The collection may reserve more space to avoid frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new capacity exceeds `isize::MAX` bytes. + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.bytes.reserve(additional) + } + + /// Tries to reserve capacity for at least `additional` more bytes to be + /// inserted in the given `Wtf8Buf`. The `Wtf8Buf` may reserve more space to + /// avoid frequent reallocations.
After calling `try_reserve`, capacity will + /// be greater than or equal to `self.len() + additional`. Does nothing if + /// capacity is already sufficient. This method preserves the contents even + /// if an error occurs. + /// + /// # Errors + /// + /// If the capacity overflows, or the allocator reports a failure, then an error + /// is returned. + #[inline] + pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> { + self.bytes.try_reserve(additional) + } + + #[inline] + pub fn reserve_exact(&mut self, additional: usize) { + self.bytes.reserve_exact(additional) + } + + /// Tries to reserve the minimum capacity for exactly `additional` more + /// bytes to be inserted in the given `Wtf8Buf`. After calling + /// `try_reserve_exact`, capacity will be greater than or equal to + /// `self.len() + additional` if it returns `Ok(())`. + /// Does nothing if the capacity is already sufficient. + /// + /// Note that the allocator may give the `Wtf8Buf` more space than it + /// requests. Therefore, capacity can not be relied upon to be precisely + /// minimal. Prefer [`try_reserve`] if future insertions are expected. + /// + /// [`try_reserve`]: Wtf8Buf::try_reserve + /// + /// # Errors + /// + /// If the capacity overflows, or the allocator reports a failure, then an error + /// is returned. + #[inline] + pub fn try_reserve_exact(&mut self, additional: usize) -> Result<(), TryReserveError> { + self.bytes.try_reserve_exact(additional) + } + + #[inline] + pub fn shrink_to_fit(&mut self) { + self.bytes.shrink_to_fit() + } + + #[inline] + pub fn shrink_to(&mut self, min_capacity: usize) { + self.bytes.shrink_to(min_capacity) + } + + #[inline] + pub fn leak<'a>(self) -> &'a mut Wtf8 { + unsafe { Wtf8::from_mut_bytes_unchecked(self.bytes.leak()) } + } + + /// Returns the number of bytes that this string buffer can hold without reallocating. 
+ #[inline] + pub fn capacity(&self) -> usize { + self.bytes.capacity() + } + + /// Append a UTF-8 slice at the end of the string. + #[inline] + pub fn push_str(&mut self, other: &str) { + self.bytes.extend_from_slice(other.as_bytes()) + } + + /// Append a WTF-8 slice at the end of the string. + #[inline] + pub fn push_wtf8(&mut self, other: &Wtf8) { + self.bytes.extend_from_slice(&other.bytes); + } + + /// Append a Unicode scalar value at the end of the string. + #[inline] + pub fn push_char(&mut self, c: char) { + self.push(CodePoint::from_char(c)) + } + + /// Append a code point at the end of the string. + #[inline] + pub fn push(&mut self, code_point: CodePoint) { + self.push_wtf8(code_point.encode_wtf8(&mut [0; MAX_LEN_UTF8])) + } + + pub fn pop(&mut self) -> Option { + let ch = self.code_points().next_back()?; + let newlen = self.len() - ch.len_wtf8(); + self.bytes.truncate(newlen); + Some(ch) + } + + /// Shortens a string to the specified length. + /// + /// # Panics + /// + /// Panics if `new_len` > current length, + /// or if `new_len` is not a code point boundary. + #[inline] + pub fn truncate(&mut self, new_len: usize) { + assert!(is_code_point_boundary(self, new_len)); + self.bytes.truncate(new_len) + } + + /// Inserts a codepoint into this `Wtf8Buf` at a byte position. + #[inline] + pub fn insert(&mut self, idx: usize, c: CodePoint) { + self.insert_wtf8(idx, c.encode_wtf8(&mut [0; MAX_LEN_UTF8])) + } + + /// Inserts a WTF-8 slice into this `Wtf8Buf` at a byte position. + #[inline] + pub fn insert_wtf8(&mut self, idx: usize, w: &Wtf8) { + assert!(is_code_point_boundary(self, idx)); + + self.bytes.insert_str(idx, w) + } + + /// Consumes the WTF-8 string and tries to convert it to a vec of bytes. + #[inline] + pub fn into_bytes(self) -> Vec { + self.bytes + } + + /// Consumes the WTF-8 string and tries to convert it to UTF-8. + /// + /// This does not copy the data. 
+ /// + /// If the contents are not well-formed UTF-8 + /// (that is, if the string contains surrogates), + /// the original WTF-8 string is returned instead. + pub fn into_string(self) -> Result { + if self.is_utf8() { + Ok(unsafe { String::from_utf8_unchecked(self.bytes) }) + } else { + Err(self) + } + } + + /// Consumes the WTF-8 string and converts it lossily to UTF-8. + /// + /// This does not copy the data (but may overwrite parts of it in place). + /// + /// Surrogates are replaced with `"\u{FFFD}"` (the replacement character “�”) + pub fn into_string_lossy(mut self) -> String { + let mut pos = 0; + while let Some((surrogate_pos, _)) = self.next_surrogate(pos) { + pos = surrogate_pos + 3; + // Surrogates and the replacement character are all 3 bytes, so + // they can be substituted in-place. + self.bytes[surrogate_pos..pos].copy_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes()); + } + unsafe { String::from_utf8_unchecked(self.bytes) } + } + + /// Converts this `Wtf8Buf` into a boxed `Wtf8`. + #[inline] + pub fn into_box(self) -> Box { + // SAFETY: relies on `Wtf8` being `repr(transparent)`. + // NOTE(review): no `#[repr(transparent)]` attribute is visible on the + // `Wtf8` struct definition in this file — confirm the layout guarantee. + unsafe { mem::transmute(self.bytes.into_boxed_slice()) } + } + + /// Converts a `Box` into a `Wtf8Buf`. + pub fn from_box(boxed: Box) -> Wtf8Buf { + let bytes: Box<[u8]> = unsafe { mem::transmute(boxed) }; + Wtf8Buf { + bytes: bytes.into_vec(), + } + } +} + +/// Creates a new WTF-8 string from an iterator of code points. +/// +/// Each code point is appended as-is: a lead surrogate followed by a trail +/// surrogate stays as two separate code points and is not merged into one +/// supplementary code point. +impl FromIterator for Wtf8Buf { + fn from_iter>(iter: T) -> Wtf8Buf { + let mut string = Wtf8Buf::new(); + string.extend(iter); + string + } +} + +/// Append code points from an iterator to the string. +/// +/// Each code point is appended as-is: a lead surrogate followed by a trail +/// surrogate stays as two separate code points and is not merged into one +/// supplementary code point. 
+impl Extend for Wtf8Buf { + fn extend>(&mut self, iter: T) { + let iterator = iter.into_iter(); + let (low, _high) = iterator.size_hint(); + // Lower bound of one byte per code point (ASCII only) + self.bytes.reserve(low); + iterator.for_each(move |code_point| self.push(code_point)); + } +} + +impl Extend for Wtf8Buf { + fn extend>(&mut self, iter: T) { + self.extend(iter.into_iter().map(CodePoint::from)) + } +} + +impl> Extend for Wtf8Buf { + fn extend>(&mut self, iter: T) { + iter.into_iter() + .for_each(move |w| self.push_wtf8(w.as_ref())); + } +} + +impl> FromIterator for Wtf8Buf { + fn from_iter>(iter: T) -> Self { + let mut buf = Wtf8Buf::new(); + iter.into_iter().for_each(|w| buf.push_wtf8(w.as_ref())); + buf + } +} + +impl AsRef for Wtf8Buf { + fn as_ref(&self) -> &Wtf8 { + self + } +} + +impl From for Wtf8Buf { + fn from(s: String) -> Self { + Wtf8Buf::from_string(s) + } +} + +impl From<&str> for Wtf8Buf { + fn from(s: &str) -> Self { + Wtf8Buf::from_string(s.to_owned()) + } +} + +impl From for Wtf8Buf { + fn from(s: ascii::AsciiString) -> Self { + Wtf8Buf::from_string(s.into()) + } +} + +/// A borrowed slice of well-formed WTF-8 data. +/// +/// Similar to `&str`, but can additionally contain surrogate code points +/// if they’re not in a surrogate pair. +#[derive(PartialEq, Eq, PartialOrd, Ord)] +pub struct Wtf8 { + bytes: [u8], +} + +impl AsRef for Wtf8 { + fn as_ref(&self) -> &Wtf8 { + self + } +} + +impl ToOwned for Wtf8 { + type Owned = Wtf8Buf; + fn to_owned(&self) -> Self::Owned { + self.to_wtf8_buf() + } +} + +impl PartialEq for Wtf8 { + fn eq(&self, other: &str) -> bool { + self.as_bytes().eq(other.as_bytes()) + } +} + +/// Formats the string in double quotes, with characters escaped according to +/// [`char::escape_debug`] and unpaired surrogates represented as `\u{xxxx}`, +/// where each `x` is a hexadecimal digit. 
+impl fmt::Debug for Wtf8 { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + fn write_str_escaped(f: &mut fmt::Formatter<'_>, s: &str) -> fmt::Result { + use std::fmt::Write; + for c in s.chars().flat_map(|c| c.escape_debug()) { + f.write_char(c)? + } + Ok(()) + } + + formatter.write_str("\"")?; + let mut pos = 0; + while let Some((surrogate_pos, surrogate)) = self.next_surrogate(pos) { + write_str_escaped(formatter, unsafe { + str::from_utf8_unchecked(&self.bytes[pos..surrogate_pos]) + })?; + write!(formatter, "\\u{{{:x}}}", surrogate)?; + pos = surrogate_pos + 3; + } + write_str_escaped(formatter, unsafe { + str::from_utf8_unchecked(&self.bytes[pos..]) + })?; + formatter.write_str("\"") + } +} + +/// Formats the string with unpaired surrogates substituted with the replacement +/// character, U+FFFD. +impl fmt::Display for Wtf8 { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + let wtf8_bytes = &self.bytes; + let mut pos = 0; + loop { + match self.next_surrogate(pos) { + Some((surrogate_pos, _)) => { + formatter.write_str(unsafe { + str::from_utf8_unchecked(&wtf8_bytes[pos..surrogate_pos]) + })?; + formatter.write_str(UTF8_REPLACEMENT_CHARACTER)?; + pos = surrogate_pos + 3; + } + None => { + let s = unsafe { str::from_utf8_unchecked(&wtf8_bytes[pos..]) }; + if pos == 0 { + return s.fmt(formatter); + } else { + return formatter.write_str(s); + } + } + } + } + } +} + +impl Default for &Wtf8 { + fn default() -> Self { + unsafe { Wtf8::from_bytes_unchecked(&[]) } + } +} + +impl Wtf8 { + /// Creates a WTF-8 slice from a UTF-8 `&str` slice. + /// + /// Since WTF-8 is a superset of UTF-8, this always succeeds. + #[inline] + pub fn new + ?Sized>(value: &S) -> &Wtf8 { + value.as_ref() + } + + /// Creates a WTF-8 slice from a WTF-8 byte slice. + /// + /// # Safety + /// + /// `value` must contain valid WTF-8. 
+ #[inline] + pub unsafe fn from_bytes_unchecked(value: &[u8]) -> &Wtf8 { + // SAFETY: start with &[u8], end with fancy &[u8] + unsafe { &*(value as *const [u8] as *const Wtf8) } + } + + /// Creates a mutable WTF-8 slice from a mutable WTF-8 byte slice. + /// + /// Since the byte slice is not checked for valid WTF-8, this function is + /// marked unsafe. + #[inline] + unsafe fn from_mut_bytes_unchecked(value: &mut [u8]) -> &mut Wtf8 { + // SAFETY: start with &mut [u8], end with fancy &mut [u8] + unsafe { &mut *(value as *mut [u8] as *mut Wtf8) } + } + + /// Returns the length, in WTF-8 bytes. + #[inline] + pub fn len(&self) -> usize { + self.bytes.len() + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.bytes.is_empty() + } + + /// Returns the code point at `position` if it is in the ASCII range, + /// or `b'\xFF'` otherwise. + /// + /// # Panics + /// + /// Panics if `position` is beyond the end of the string. + #[inline] + pub fn ascii_byte_at(&self, position: usize) -> u8 { + match self.bytes[position] { + ascii_byte @ 0x00..=0x7F => ascii_byte, + _ => 0xFF, + } + } + + /// Returns an iterator for the string’s code points. + #[inline] + pub fn code_points(&self) -> Wtf8CodePoints<'_> { + Wtf8CodePoints { + bytes: self.bytes.iter(), + } + } + + /// Returns an iterator for the string’s code points and their indices. + #[inline] + pub fn code_point_indices(&self) -> Wtf8CodePointIndices<'_> { + Wtf8CodePointIndices { + front_offset: 0, + iter: self.code_points(), + } + } + + /// Access raw bytes of WTF-8 data + #[inline] + pub fn as_bytes(&self) -> &[u8] { + &self.bytes + } + + /// Tries to convert the string to UTF-8 and return a `&str` slice. + /// + /// Returns an `Err` carrying a `str::Utf8Error` if the bytes are not valid + /// UTF-8 (that is, if the string contains surrogates). + /// + /// This does not copy the data. + #[inline] + pub fn as_str(&self) -> Result<&str, str::Utf8Error> { + str::from_utf8(&self.bytes) + } + + /// Creates an owned `Wtf8Buf` from a borrowed `Wtf8`. 
+ pub fn to_wtf8_buf(&self) -> Wtf8Buf { + Wtf8Buf { + bytes: self.bytes.to_vec(), + } + } + + /// Lossily converts the string to UTF-8. + /// Returns a UTF-8 `&str` slice if the contents are well-formed in UTF-8. + /// + /// Surrogates are replaced with `"\u{FFFD}"` (the replacement character “�”). + /// + /// This only copies the data if necessary (if it contains any surrogate). + pub fn to_string_lossy(&self) -> Cow<'_, str> { + let Some((surrogate_pos, _)) = self.next_surrogate(0) else { + return Cow::Borrowed(unsafe { str::from_utf8_unchecked(&self.bytes) }); + }; + let wtf8_bytes = &self.bytes; + let mut utf8_bytes = Vec::with_capacity(self.len()); + utf8_bytes.extend_from_slice(&wtf8_bytes[..surrogate_pos]); + utf8_bytes.extend_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes()); + let mut pos = surrogate_pos + 3; + loop { + match self.next_surrogate(pos) { + Some((surrogate_pos, _)) => { + utf8_bytes.extend_from_slice(&wtf8_bytes[pos..surrogate_pos]); + utf8_bytes.extend_from_slice(UTF8_REPLACEMENT_CHARACTER.as_bytes()); + pos = surrogate_pos + 3; + } + None => { + utf8_bytes.extend_from_slice(&wtf8_bytes[pos..]); + return Cow::Owned(unsafe { String::from_utf8_unchecked(utf8_bytes) }); + } + } + } + } + + /// Converts the WTF-8 string to potentially ill-formed UTF-16 + /// and returns an iterator of 16-bit code units. + /// + /// This is lossless: + /// calling `Wtf8Buf::from_wide` on the resulting code units + /// would always return the original WTF-8 string. + /// + /// NOTE(review): a string that stores an adjacent lead/trail surrogate + /// pair as two code points would presumably come back as one supplementary + /// code point, since `from_wide` decodes with `char::decode_utf16` — confirm. 
+ #[inline] + pub fn encode_wide(&self) -> EncodeWide<'_> { + EncodeWide { + code_points: self.code_points(), + extra: 0, + } + } + + pub fn chunks(&self) -> Wtf8Chunks<'_> { + Wtf8Chunks { wtf8: self } + } + + pub fn map_utf8<'a, I>(&'a self, f: impl Fn(&'a str) -> I) -> impl Iterator + where + I: Iterator, + { + self.chunks().flat_map(move |chunk| match chunk { + Wtf8Chunk::Utf8(s) => Either::Left(f(s).map_into()), + Wtf8Chunk::Surrogate(c) => Either::Right(std::iter::once(c)), + }) + } + + #[inline] + fn next_surrogate(&self, mut pos: usize) -> Option<(usize, u16)> { + let mut iter = self.bytes[pos..].iter(); + loop { + let b = *iter.next()?; + if b < 0x80 { + pos += 1; + } else if b < 0xE0 { + iter.next(); + pos += 2; + } else if b == 0xED { + match (iter.next(), iter.next()) { + (Some(&b2), Some(&b3)) if b2 >= 0xA0 => { + return Some((pos, decode_surrogate(b2, b3))); + } + _ => pos += 3, + } + } else if b < 0xF0 { + iter.next(); + iter.next(); + pos += 3; + } else { + iter.next(); + iter.next(); + iter.next(); + pos += 4; + } + } + } + + pub fn clone_into(&self, buf: &mut Wtf8Buf) { + self.bytes.clone_into(&mut buf.bytes); + } + + /// Boxes this `Wtf8`. + #[inline] + pub fn into_box(&self) -> Box { + let boxed: Box<[u8]> = self.bytes.into(); + unsafe { mem::transmute(boxed) } + } + + /// Creates a boxed, empty `Wtf8`. 
+ pub fn empty_box() -> Box { + let boxed: Box<[u8]> = Default::default(); + unsafe { mem::transmute(boxed) } + } + + #[inline] + pub fn make_ascii_lowercase(&mut self) { + self.bytes.make_ascii_lowercase() + } + + #[inline] + pub fn make_ascii_uppercase(&mut self) { + self.bytes.make_ascii_uppercase() + } + + #[inline] + pub fn to_ascii_lowercase(&self) -> Wtf8Buf { + Wtf8Buf { + bytes: self.bytes.to_ascii_lowercase(), + } + } + + #[inline] + pub fn to_ascii_uppercase(&self) -> Wtf8Buf { + Wtf8Buf { + bytes: self.bytes.to_ascii_uppercase(), + } + } + + #[inline] + pub fn is_ascii(&self) -> bool { + self.bytes.is_ascii() + } + + #[inline] + pub fn is_utf8(&self) -> bool { + self.next_surrogate(0).is_none() + } + + #[inline] + pub fn eq_ignore_ascii_case(&self, other: &Self) -> bool { + self.bytes.eq_ignore_ascii_case(&other.bytes) + } + + pub fn split(&self, pat: &Wtf8) -> impl Iterator { + self.as_bytes() + .split_str(pat) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn splitn(&self, n: usize, pat: &Wtf8) -> impl Iterator { + self.as_bytes() + .splitn_str(n, pat) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn rsplit(&self, pat: &Wtf8) -> impl Iterator { + self.as_bytes() + .rsplit_str(pat) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn rsplitn(&self, n: usize, pat: &Wtf8) -> impl Iterator { + self.as_bytes() + .rsplitn_str(n, pat) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn trim(&self) -> &Self { + let w = self.bytes.trim(); + unsafe { Wtf8::from_bytes_unchecked(w) } + } + + pub fn trim_start(&self) -> &Self { + let w = self.bytes.trim_start(); + unsafe { Wtf8::from_bytes_unchecked(w) } + } + + pub fn trim_end(&self) -> &Self { + let w = self.bytes.trim_end(); + unsafe { Wtf8::from_bytes_unchecked(w) } + } + + pub fn trim_start_matches(&self, f: impl Fn(CodePoint) -> bool) -> &Self { + let mut iter = self.code_points(); + loop { + let old = iter.clone(); + match 
iter.next().map(&f) { + Some(true) => continue, + Some(false) => { + iter = old; + break; + } + None => return iter.as_wtf8(), + } + } + iter.as_wtf8() + } + + pub fn trim_end_matches(&self, f: impl Fn(CodePoint) -> bool) -> &Self { + let mut iter = self.code_points(); + loop { + let old = iter.clone(); + match iter.next_back().map(&f) { + Some(true) => continue, + Some(false) => { + iter = old; + break; + } + None => return iter.as_wtf8(), + } + } + iter.as_wtf8() + } + + pub fn trim_matches(&self, f: impl Fn(CodePoint) -> bool) -> &Self { + self.trim_start_matches(&f).trim_end_matches(&f) + } + + pub fn find(&self, pat: &Wtf8) -> Option { + memchr::memmem::find(self.as_bytes(), pat.as_bytes()) + } + + pub fn rfind(&self, pat: &Wtf8) -> Option { + memchr::memmem::rfind(self.as_bytes(), pat.as_bytes()) + } + + pub fn find_iter(&self, pat: &Wtf8) -> impl Iterator { + memchr::memmem::find_iter(self.as_bytes(), pat.as_bytes()) + } + + pub fn rfind_iter(&self, pat: &Wtf8) -> impl Iterator { + memchr::memmem::rfind_iter(self.as_bytes(), pat.as_bytes()) + } + + pub fn contains(&self, pat: &Wtf8) -> bool { + self.bytes.contains_str(pat) + } + + pub fn contains_code_point(&self, pat: CodePoint) -> bool { + self.bytes + .contains_str(pat.encode_wtf8(&mut [0; MAX_LEN_UTF8])) + } + + pub fn get(&self, range: impl ops::RangeBounds) -> Option<&Self> { + let start = match range.start_bound() { + ops::Bound::Included(&i) => i, + ops::Bound::Excluded(&i) => i.saturating_add(1), + ops::Bound::Unbounded => 0, + }; + let end = match range.end_bound() { + ops::Bound::Included(&i) => i.saturating_add(1), + ops::Bound::Excluded(&i) => i, + ops::Bound::Unbounded => self.len(), + }; + // is_code_point_boundary checks that the index is in [0, .len()] + if start <= end && is_code_point_boundary(self, start) && is_code_point_boundary(self, end) + { + Some(unsafe { slice_unchecked(self, start, end) }) + } else { + None + } + } + + pub fn ends_with(&self, w: &Wtf8) -> bool { + 
self.bytes.ends_with_str(w) + } + + pub fn starts_with(&self, w: &Wtf8) -> bool { + self.bytes.starts_with_str(w) + } + + pub fn strip_prefix(&self, w: &Wtf8) -> Option<&Self> { + self.bytes + .strip_prefix(w.as_bytes()) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn strip_suffix(&self, w: &Wtf8) -> Option<&Self> { + self.bytes + .strip_suffix(w.as_bytes()) + .map(|w| unsafe { Wtf8::from_bytes_unchecked(w) }) + } + + pub fn replace(&self, from: &Wtf8, to: &Wtf8) -> Wtf8Buf { + let w = self.bytes.replace(from, to); + unsafe { Wtf8Buf::from_bytes_unchecked(w) } + } + + pub fn replacen(&self, from: &Wtf8, to: &Wtf8, n: usize) -> Wtf8Buf { + let w = self.bytes.replacen(from, to, n); + unsafe { Wtf8Buf::from_bytes_unchecked(w) } + } +} + +impl AsRef for str { + fn as_ref(&self) -> &Wtf8 { + unsafe { Wtf8::from_bytes_unchecked(self.as_bytes()) } + } +} + +impl AsRef<[u8]> for Wtf8 { + fn as_ref(&self) -> &[u8] { + self.as_bytes() + } +} + +/// Returns a slice of the given string for the byte range \[`begin`..`end`). +/// +/// # Panics +/// +/// Panics when `begin` and `end` do not point to code point boundaries, +/// or point beyond the end of the string. +impl ops::Index> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::Range) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if range.start <= range.end + && is_code_point_boundary(self, range.start) + && is_code_point_boundary(self, range.end) + { + unsafe { slice_unchecked(self, range.start, range.end) } + } else { + slice_error_fail(self, range.start, range.end) + } + } +} + +/// Returns a slice of the given string from byte `begin` to its end. +/// +/// # Panics +/// +/// Panics when `begin` is not at a code point boundary, +/// or is beyond the end of the string. 
+impl ops::Index> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::RangeFrom) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if is_code_point_boundary(self, range.start) { + unsafe { slice_unchecked(self, range.start, self.len()) } + } else { + slice_error_fail(self, range.start, self.len()) + } + } +} + +/// Returns a slice of the given string from its beginning to byte `end`. +/// +/// # Panics +/// +/// Panics when `end` is not at a code point boundary, +/// or is beyond the end of the string. +impl ops::Index> for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, range: ops::RangeTo) -> &Wtf8 { + // is_code_point_boundary checks that the index is in [0, .len()] + if is_code_point_boundary(self, range.end) { + unsafe { slice_unchecked(self, 0, range.end) } + } else { + slice_error_fail(self, 0, range.end) + } + } +} + +impl ops::Index for Wtf8 { + type Output = Wtf8; + + #[inline] + fn index(&self, _range: ops::RangeFull) -> &Wtf8 { + self + } +} + +#[inline] +fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 { + // The first byte is assumed to be 0xED + 0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F +} + +/// Copied from str::is_char_boundary +#[inline] +pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool { + if index == 0 { + return true; + } + match slice.bytes.get(index) { + None => index == slice.len(), + Some(&b) => (b as i8) >= -0x40, + } +} + +/// Verify that `index` is at the edge of either a valid UTF-8 codepoint +/// (i.e. a codepoint that's not a surrogate) or of the whole string. +/// +/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`. +/// Splitting between surrogates is valid as far as WTF-8 is concerned, but +/// we do not permit it in the public API because WTF-8 is considered an +/// implementation detail. 
+#[track_caller] +#[inline] +pub fn check_utf8_boundary(slice: &Wtf8, index: usize) { + if index == 0 { + return; + } + match slice.bytes.get(index) { + Some(0xED) => (), // Might be a surrogate + Some(&b) if (b as i8) >= -0x40 => return, + Some(_) => panic!("byte index {index} is not a codepoint boundary"), + None if index == slice.len() => return, + None => panic!("byte index {index} is out of bounds"), + } + if slice.bytes[index + 1] >= 0xA0 { + // There's a surrogate after index. Now check before index. + if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 { + panic!("byte index {index} lies between surrogate codepoints"); + } + } +} + +/// Copied from core::str::raw::slice_unchecked +/// +/// # Safety +/// +/// `begin` and `end` must be within bounds and on codepoint boundaries. +#[inline] +pub unsafe fn slice_unchecked(s: &Wtf8, begin: usize, end: usize) -> &Wtf8 { + // SAFETY: memory layout of a &[u8] and &Wtf8 are the same + unsafe { + let len = end - begin; + let start = s.as_bytes().as_ptr().add(begin); + Wtf8::from_bytes_unchecked(slice::from_raw_parts(start, len)) + } +} + +/// Copied from core::str::raw::slice_error_fail +#[inline(never)] +pub fn slice_error_fail(s: &Wtf8, begin: usize, end: usize) -> ! { + assert!(begin <= end); + panic!("index {begin} and/or {end} in `{s:?}` do not lie on character boundary"); +} + +/// Iterator for the code points of a WTF-8 string. +/// +/// Created with the method `.code_points()`. 
+#[derive(Clone)] +pub struct Wtf8CodePoints<'a> { + bytes: slice::Iter<'a, u8>, +} + +impl Iterator for Wtf8CodePoints<'_> { + type Item = CodePoint; + + #[inline] + fn next(&mut self) -> Option { + // SAFETY: `self.bytes` has been created from a WTF-8 string + unsafe { next_code_point(&mut self.bytes).map(|c| CodePoint { value: c }) } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.bytes.len(); + (len.saturating_add(3) / 4, Some(len)) + } + + fn last(mut self) -> Option { + self.next_back() + } + + fn count(self) -> usize { + core_str_count::count_chars(self.as_wtf8()) + } +} + +impl DoubleEndedIterator for Wtf8CodePoints<'_> { + #[inline] + fn next_back(&mut self) -> Option { + // SAFETY: `str` invariant says `self.iter` is a valid WTF-8 string and + // the resulting `ch` is a valid Unicode Code Point. + unsafe { + next_code_point_reverse(&mut self.bytes).map(|ch| CodePoint::from_u32_unchecked(ch)) + } + } +} + +impl<'a> Wtf8CodePoints<'a> { + pub fn as_wtf8(&self) -> &'a Wtf8 { + unsafe { Wtf8::from_bytes_unchecked(self.bytes.as_slice()) } + } +} + +#[derive(Clone)] +pub struct Wtf8CodePointIndices<'a> { + front_offset: usize, + iter: Wtf8CodePoints<'a>, +} + +impl Iterator for Wtf8CodePointIndices<'_> { + type Item = (usize, CodePoint); + + #[inline] + fn next(&mut self) -> Option<(usize, CodePoint)> { + let pre_len = self.iter.bytes.len(); + match self.iter.next() { + None => None, + Some(ch) => { + let index = self.front_offset; + let len = self.iter.bytes.len(); + self.front_offset += pre_len - len; + Some((index, ch)) + } + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.iter.size_hint() + } + + #[inline] + fn last(mut self) -> Option<(usize, CodePoint)> { + // No need to go through the entire string. 
+ self.next_back() + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } +} + +impl DoubleEndedIterator for Wtf8CodePointIndices<'_> { + #[inline] + fn next_back(&mut self) -> Option<(usize, CodePoint)> { + self.iter.next_back().map(|ch| { + let index = self.front_offset + self.iter.bytes.len(); + (index, ch) + }) + } +} + +impl FusedIterator for Wtf8CodePointIndices<'_> {} + +/// Generates a wide character sequence for potentially ill-formed UTF-16. +#[derive(Clone)] +pub struct EncodeWide<'a> { + code_points: Wtf8CodePoints<'a>, + extra: u16, +} + +// Copied from libunicode/u_str.rs +impl Iterator for EncodeWide<'_> { + type Item = u16; + + #[inline] + fn next(&mut self) -> Option { + if self.extra != 0 { + let tmp = self.extra; + self.extra = 0; + return Some(tmp); + } + + let mut buf = [0; MAX_LEN_UTF16]; + self.code_points.next().map(|code_point| { + let n = encode_utf16_raw(code_point.value, &mut buf).len(); + if n == 2 { + self.extra = buf[1]; + } + buf[0] + }) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let (low, high) = self.code_points.size_hint(); + let ext = (self.extra != 0) as usize; + // every code point gets either one u16 or two u16, + // so this iterator is between 1 or 2 times as + // long as the underlying iterator. 
+ ( + low + ext, + high.and_then(|n| n.checked_mul(2)) + .and_then(|n| n.checked_add(ext)), + ) + } +} + +impl FusedIterator for EncodeWide<'_> {} + +pub struct Wtf8Chunks<'a> { + wtf8: &'a Wtf8, +} + +impl<'a> Iterator for Wtf8Chunks<'a> { + type Item = Wtf8Chunk<'a>; + + fn next(&mut self) -> Option { + match self.wtf8.next_surrogate(0) { + Some((0, surrogate)) => { + self.wtf8 = &self.wtf8[3..]; + Some(Wtf8Chunk::Surrogate(surrogate.into())) + } + Some((n, _)) => { + let s = unsafe { str::from_utf8_unchecked(&self.wtf8.as_bytes()[..n]) }; + self.wtf8 = &self.wtf8[n..]; + Some(Wtf8Chunk::Utf8(s)) + } + None => { + let s = + unsafe { str::from_utf8_unchecked(std::mem::take(&mut self.wtf8).as_bytes()) }; + (!s.is_empty()).then_some(Wtf8Chunk::Utf8(s)) + } + } + } +} + +pub enum Wtf8Chunk<'a> { + Utf8(&'a str), + Surrogate(CodePoint), +} + +impl Hash for CodePoint { + #[inline] + fn hash(&self, state: &mut H) { + self.value.hash(state) + } +} + +// == BOX IMPLS == + +/// # Safety +/// +/// `value` must be valid WTF-8. 
+pub unsafe fn from_boxed_wtf8_unchecked(value: Box<[u8]>) -> Box { + unsafe { Box::from_raw(Box::into_raw(value) as *mut Wtf8) } +} + +impl Clone for Box { + fn clone(&self) -> Self { + (&**self).into() + } +} + +impl Default for Box { + fn default() -> Self { + unsafe { from_boxed_wtf8_unchecked(Box::default()) } + } +} + +impl From<&Wtf8> for Box { + fn from(w: &Wtf8) -> Self { + w.into_box() + } +} + +impl From<&str> for Box { + fn from(s: &str) -> Self { + Box::::from(s).into() + } +} + +impl From> for Box { + fn from(s: Box) -> Self { + unsafe { from_boxed_wtf8_unchecked(s.into_boxed_bytes()) } + } +} + +impl From> for Box { + fn from(s: Box) -> Self { + >::from(s).into() + } +} + +impl From> for Box<[u8]> { + fn from(w: Box) -> Self { + unsafe { Box::from_raw(Box::into_raw(w) as *mut [u8]) } + } +} + +impl From for Box { + fn from(w: Wtf8Buf) -> Self { + w.into_box() + } +} + +impl From for Box { + fn from(s: String) -> Self { + s.into_boxed_str().into() + } +} diff --git a/stdlib/src/csv.rs b/stdlib/src/csv.rs index f07b40b3a2..4b79130111 100644 --- a/stdlib/src/csv.rs +++ b/stdlib/src/csv.rs @@ -198,9 +198,9 @@ mod _csv { ) -> PyResult { match_class!(match obj.get_attr("lineterminator", vm)? { s @ PyStr => { - Ok(if s.as_str().as_bytes().eq(b"\r\n") { + Ok(if s.as_bytes().eq(b"\r\n") { csv_core::Terminator::CRLF - } else if let Some(t) = s.as_str().as_bytes().first() { + } else if let Some(t) = s.as_bytes().first() { // Due to limitations in the current implementation within csv_core // the support for multiple characters in lineterminator is not complete. 
// only capture the first character @@ -942,7 +942,7 @@ mod _csv { ), ) })?; - let input = string.as_str().as_bytes(); + let input = string.as_bytes(); if input.is_empty() || input.starts_with(b"\n") { return Ok(PyIterReturn::Return(vm.ctx.new_list(vec![]).into())); } @@ -1101,11 +1101,11 @@ mod _csv { let field: PyObjectRef = field?; let stringified; let data: &[u8] = match_class!(match field { - ref s @ PyStr => s.as_str().as_bytes(), + ref s @ PyStr => s.as_bytes(), crate::builtins::PyNone => b"", ref obj => { stringified = obj.str(vm)?; - stringified.as_str().as_bytes() + stringified.as_bytes() } }); let mut input_offset = 0; diff --git a/stdlib/src/pyexpat.rs b/stdlib/src/pyexpat.rs index 3cfe048f17..2363e6bed4 100644 --- a/stdlib/src/pyexpat.rs +++ b/stdlib/src/pyexpat.rs @@ -136,7 +136,7 @@ mod _pyexpat { #[pymethod(name = "Parse")] fn parse(&self, data: PyStrRef, _isfinal: OptionalArg, vm: &VirtualMachine) { - let reader = Cursor::>::new(data.as_str().as_bytes().to_vec()); + let reader = Cursor::>::new(data.as_bytes().to_vec()); let parser = self.create_config().create_reader(reader); self.do_parse(vm, parser); } diff --git a/stdlib/src/pystruct.rs b/stdlib/src/pystruct.rs index f8d41414f7..2c7aa1ebf7 100644 --- a/stdlib/src/pystruct.rs +++ b/stdlib/src/pystruct.rs @@ -27,30 +27,27 @@ pub(crate) mod _struct { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { // CPython turns str to bytes but we do reversed way here // The only performance difference is this transition cost - let fmt = match_class! 
{ - match obj { - s @ PyStr => if s.is_ascii() { - Some(s) - } else { - None - }, - b @ PyBytes => if b.is_ascii() { - Some(unsafe { - PyStr::new_ascii_unchecked(b.as_bytes().to_vec()) - }.into_ref(&vm.ctx)) - } else { - None - }, - other => return Err(vm.new_type_error(format!("Struct() argument 1 must be a str or bytes object, not {}", other.class().name()))), - } - }.ok_or_else(|| vm.new_unicode_decode_error("Struct format must be a ascii string".to_owned()))?; + let fmt = match_class!(match obj { + s @ PyStr => s.is_ascii().then_some(s), + b @ PyBytes => ascii::AsciiStr::from_ascii(&b) + .ok() + .map(|s| vm.ctx.new_str(s)), + other => + return Err(vm.new_type_error(format!( + "Struct() argument 1 must be a str or bytes object, not {}", + other.class().name() + ))), + }) + .ok_or_else(|| { + vm.new_unicode_decode_error("Struct format must be a ascii string".to_owned()) + })?; Ok(IntoStructFormatBytes(fmt)) } } impl IntoStructFormatBytes { fn format_spec(&self, vm: &VirtualMachine) -> PyResult { - FormatSpec::parse(self.0.as_str().as_bytes(), vm) + FormatSpec::parse(self.0.as_bytes(), vm) } } diff --git a/stdlib/src/re.rs b/stdlib/src/re.rs index 5417e03fc7..647f4c69ad 100644 --- a/stdlib/src/re.rs +++ b/stdlib/src/re.rs @@ -9,10 +9,11 @@ mod re { * system. 
*/ use crate::vm::{ + PyObjectRef, PyPayload, PyResult, VirtualMachine, builtins::{PyInt, PyIntRef, PyStr, PyStrRef}, convert::{ToPyObject, TryFromObject}, function::{OptionalArg, PosArgs}, - match_class, PyObjectRef, PyResult, PyPayload, VirtualMachine, + match_class, }; use num_traits::Signed; use regex::bytes::{Captures, Regex, RegexBuilder}; @@ -158,11 +159,9 @@ mod re { } fn do_sub(pattern: &PyPattern, repl: PyStrRef, search_text: PyStrRef, limit: usize) -> String { - let out = pattern.regex.replacen( - search_text.as_str().as_bytes(), - limit, - repl.as_str().as_bytes(), - ); + let out = pattern + .regex + .replacen(search_text.as_bytes(), limit, repl.as_bytes()); String::from_utf8_lossy(&out).into_owned() } @@ -172,21 +171,21 @@ mod re { regex_text.push_str(pattern.regex.as_str()); let regex = Regex::new(®ex_text).unwrap(); regex - .captures(search_text.as_str().as_bytes()) + .captures(search_text.as_bytes()) .map(|captures| create_match(search_text.clone(), captures)) } fn do_search(regex: &PyPattern, search_text: PyStrRef) -> Option { regex .regex - .captures(search_text.as_str().as_bytes()) + .captures(search_text.as_bytes()) .map(|captures| create_match(search_text.clone(), captures)) } fn do_findall(vm: &VirtualMachine, pattern: &PyPattern, search_text: PyStrRef) -> PyResult { let out = pattern .regex - .captures_iter(search_text.as_str().as_bytes()) + .captures_iter(search_text.as_bytes()) .map(|captures| match captures.len() { 1 => { let full = captures.get(0).unwrap().as_bytes(); @@ -232,7 +231,7 @@ mod re { .map(|i| i.try_to_primitive::(vm)) .transpose()? 
.unwrap_or(0); - let text = search_text.as_str().as_bytes(); + let text = search_text.as_bytes(); // essentially Regex::split, but it outputs captures as well let mut output = Vec::new(); let mut last = 0; @@ -332,9 +331,7 @@ mod re { #[pymethod] fn sub(&self, repl: PyStrRef, text: PyStrRef, vm: &VirtualMachine) -> PyResult { - let replaced_text = self - .regex - .replace_all(text.as_str().as_bytes(), repl.as_str().as_bytes()); + let replaced_text = self.regex.replace_all(text.as_bytes(), repl.as_bytes()); let replaced_text = String::from_utf8_lossy(&replaced_text).into_owned(); Ok(vm.ctx.new_str(replaced_text)) } diff --git a/stdlib/src/socket.rs b/stdlib/src/socket.rs index 39bfde4bee..17daec7751 100644 --- a/stdlib/src/socket.rs +++ b/stdlib/src/socket.rs @@ -930,10 +930,15 @@ mod _socket { match family { #[cfg(unix)] c::AF_UNIX => { + use crate::vm::function::ArgStrOrBytesLike; use std::os::unix::ffi::OsStrExt; - let buf = crate::vm::function::ArgStrOrBytesLike::try_from_object(vm, addr)?; - let path = &*buf.borrow_bytes(); - socket2::SockAddr::unix(ffi::OsStr::from_bytes(path)) + let buf = ArgStrOrBytesLike::try_from_object(vm, addr)?; + let bytes = &*buf.borrow_bytes(); + let path = match &buf { + ArgStrOrBytesLike::Buf(_) => ffi::OsStr::from_bytes(bytes).into(), + ArgStrOrBytesLike::Str(s) => vm.fsencode(s)?, + }; + socket2::SockAddr::unix(path) .map_err(|_| vm.new_os_error("AF_UNIX path too long".to_owned()).into()) } c::AF_INET => { @@ -1704,7 +1709,7 @@ mod _socket { let path = ffi::OsStr::as_bytes(addr.as_pathname().unwrap_or("".as_ref()).as_ref()); let nul_pos = memchr::memchr(b'\0', path).unwrap_or(path.len()); let path = ffi::OsStr::from_bytes(&path[..nul_pos]); - return vm.ctx.new_str(path.to_string_lossy()).into(); + return vm.fsdecode(path).into(); } // TODO: support more address families (String::new(), 0).to_pyobject(vm) diff --git a/stdlib/src/sqlite.rs b/stdlib/src/sqlite.rs index 85ce8d80fd..67e94bd81b 100644 --- a/stdlib/src/sqlite.rs +++ 
b/stdlib/src/sqlite.rs @@ -2929,9 +2929,12 @@ mod _sqlite { } fn str_to_ptr_len(s: &PyStr, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> { - let len = c_int::try_from(s.byte_len()) + let s = s + .to_str() + .ok_or_else(|| vm.new_unicode_encode_error("surrogates not allowed".to_owned()))?; + let len = c_int::try_from(s.len()) .map_err(|_| vm.new_overflow_error("TEXT longer than INT_MAX bytes".to_owned()))?; - let ptr = s.as_str().as_ptr().cast(); + let ptr = s.as_ptr().cast(); Ok((ptr, len)) } @@ -3000,7 +3003,7 @@ mod _sqlite { ) -> PyResult<*const libc::c_char> { BEGIN_STATEMENTS .iter() - .find(|&&x| x[6..].eq_ignore_ascii_case(s.as_str().as_bytes())) + .find(|&&x| x[6..].eq_ignore_ascii_case(s.as_bytes())) .map(|&x| x.as_ptr().cast()) .ok_or_else(|| { vm.new_value_error( diff --git a/stdlib/src/ssl.rs b/stdlib/src/ssl.rs index 2b8ffc7d8c..10d1906448 100644 --- a/stdlib/src/ssl.rs +++ b/stdlib/src/ssl.rs @@ -403,8 +403,8 @@ mod _ssl { .to_str() .unwrap(); let (cert_file, cert_dir) = get_cert_file_dir(); - let cert_file = OsPath::new_str(cert_file).filename(vm)?; - let cert_dir = OsPath::new_str(cert_dir).filename(vm)?; + let cert_file = OsPath::new_str(cert_file).filename(vm); + let cert_dir = OsPath::new_str(cert_dir).filename(vm); Ok((cert_file_env, cert_file, cert_dir_env, cert_dir)) } @@ -708,7 +708,7 @@ mod _ssl { if !s.is_ascii() { return Err(invalid_cadata(vm)); } - X509::stack_from_pem(s.as_str().as_bytes()) + X509::stack_from_pem(s.as_bytes()) } Either::B(b) => b.with_ref(x509_stack_from_der), }; diff --git a/stdlib/src/unicodedata.rs b/stdlib/src/unicodedata.rs index 49f3ef6250..9af921d360 100644 --- a/stdlib/src/unicodedata.rs +++ b/stdlib/src/unicodedata.rs @@ -65,6 +65,7 @@ mod unicodedata { function::OptionalArg, }; use itertools::Itertools; + use rustpython_common::wtf8::{CodePoint, Wtf8Buf}; use ucd::{Codepoint, EastAsianWidth}; use unic_char_property::EnumeratedCharProperty; use unic_normal::StrNormalForm; @@ -84,14 +85,23 @@ mod 
unicodedata { Self { unic_version } } - fn check_age(&self, c: char) -> bool { - Age::of(c).is_some_and(|age| age.actual() <= self.unic_version) + fn check_age(&self, c: CodePoint) -> bool { + c.to_char() + .is_none_or(|c| Age::of(c).is_some_and(|age| age.actual() <= self.unic_version)) } - fn extract_char(&self, character: PyStrRef, vm: &VirtualMachine) -> PyResult> { - let c = character.as_str().chars().exactly_one().map_err(|_| { - vm.new_type_error("argument must be an unicode character, not str".to_owned()) - })?; + fn extract_char( + &self, + character: PyStrRef, + vm: &VirtualMachine, + ) -> PyResult> { + let c = character + .as_wtf8() + .code_points() + .exactly_one() + .map_err(|_| { + vm.new_type_error("argument must be an unicode character, not str".to_owned()) + })?; Ok(self.check_age(c).then_some(c)) } @@ -103,7 +113,10 @@ mod unicodedata { fn category(&self, character: PyStrRef, vm: &VirtualMachine) -> PyResult { Ok(self .extract_char(character, vm)? - .map_or(GeneralCategory::Unassigned, GeneralCategory::of) + .map_or(GeneralCategory::Unassigned, |c| { + c.to_char() + .map_or(GeneralCategory::Surrogate, GeneralCategory::of) + }) .abbr_name() .to_owned()) } @@ -111,7 +124,7 @@ mod unicodedata { #[pymethod] fn lookup(&self, name: PyStrRef, vm: &VirtualMachine) -> PyResult { if let Some(character) = unicode_names2::character(name.as_str()) { - if self.check_age(character) { + if self.check_age(character.into()) { return Ok(character.to_string()); } } @@ -129,7 +142,7 @@ mod unicodedata { if let Some(c) = c { if self.check_age(c) { - if let Some(name) = unicode_names2::name(c) { + if let Some(name) = c.to_char().and_then(unicode_names2::name) { return Ok(vm.ctx.new_str(name.to_string()).into()); } } @@ -144,7 +157,10 @@ mod unicodedata { vm: &VirtualMachine, ) -> PyResult<&'static str> { let bidi = match self.extract_char(character, vm)? 
{ - Some(c) => BidiClass::of(c).abbr_name(), + Some(c) => c + .to_char() + .map_or(BidiClass::LeftToRight, BidiClass::of) + .abbr_name(), None => "", }; Ok(bidi) @@ -159,19 +175,20 @@ mod unicodedata { ) -> PyResult<&'static str> { Ok(self .extract_char(character, vm)? + .and_then(|c| c.to_char()) .map_or(EastAsianWidth::Neutral, |c| c.east_asian_width()) .abbr_name()) } #[pymethod] - fn normalize(&self, form: super::NormalizeForm, unistr: PyStrRef) -> PyResult { + fn normalize(&self, form: super::NormalizeForm, unistr: PyStrRef) -> PyResult { use super::NormalizeForm::*; - let text = unistr.as_str(); + let text = unistr.as_wtf8(); let normalized_text = match form { - Nfc => text.nfc().collect::(), - Nfkc => text.nfkc().collect::(), - Nfd => text.nfd().collect::(), - Nfkd => text.nfkd().collect::(), + Nfc => text.map_utf8(|s| s.nfc()).collect(), + Nfkc => text.map_utf8(|s| s.nfkc()).collect(), + Nfd => text.map_utf8(|s| s.nfd()).collect(), + Nfkd => text.map_utf8(|s| s.nfkd()).collect(), }; Ok(normalized_text) } diff --git a/vm/src/anystr.rs b/vm/src/anystr.rs index d01136b0fb..89e7473441 100644 --- a/vm/src/anystr.rs +++ b/vm/src/anystr.rs @@ -7,28 +7,13 @@ use crate::{ use num_traits::{cast::ToPrimitive, sign::Signed}; #[derive(FromArgs)] -pub struct SplitArgs { +pub struct SplitArgs { #[pyarg(any, default)] sep: Option, #[pyarg(any, default = "-1")] maxsplit: isize, } -impl SplitArgs { - pub fn get_value(self, vm: &VirtualMachine) -> PyResult<(Option, isize)> { - let sep = if let Some(s) = self.sep { - let sep = s.as_ref(); - if sep.is_empty() { - return Err(vm.new_value_error("empty separator".to_owned())); - } - Some(s) - } else { - None - }; - Ok((sep, self.maxsplit)) - } -} - #[derive(FromArgs)] pub struct SplitLinesArgs { #[pyarg(any, default = "false")] @@ -132,9 +117,9 @@ impl StringRange for std::ops::Range { } } -pub trait AnyStrWrapper { - type Str: ?Sized + AnyStr; - fn as_ref(&self) -> &Self::Str; +pub trait AnyStrWrapper { + fn as_ref(&self) -> 
Option<&S>; + fn is_empty(&self) -> bool; } pub trait AnyStrContainer @@ -146,15 +131,18 @@ where fn push_str(&mut self, s: &S); } +pub trait AnyChar: Copy { + fn is_lowercase(self) -> bool; + fn is_uppercase(self) -> bool; + fn bytes_len(self) -> usize; +} + pub trait AnyStr { - type Char: Copy; + type Char: AnyChar; type Container: AnyStrContainer + Extend; - fn element_bytes_len(c: Self::Char) -> usize; - fn to_container(&self) -> Self::Container; fn as_bytes(&self) -> &[u8]; - fn chars(&self) -> impl Iterator; fn elements(&self) -> impl Iterator; fn get_bytes(&self, range: std::ops::Range) -> &Self; // FIXME: get_chars is expensive for str @@ -172,29 +160,35 @@ pub trait AnyStr { new } - fn py_split( + fn py_split( &self, args: SplitArgs, vm: &VirtualMachine, + full_obj: impl FnOnce() -> PyObjectRef, split: SP, splitn: SN, splitw: SW, - ) -> PyResult> + ) -> PyResult> where - T: TryFromObject + AnyStrWrapper, - SP: Fn(&Self, &Self, &VirtualMachine) -> Vec, - SN: Fn(&Self, &Self, usize, &VirtualMachine) -> Vec, - SW: Fn(&Self, isize, &VirtualMachine) -> Vec, + T: TryFromObject + AnyStrWrapper, + SP: Fn(&Self, &Self, &VirtualMachine) -> Vec, + SN: Fn(&Self, &Self, usize, &VirtualMachine) -> Vec, + SW: Fn(&Self, isize, &VirtualMachine) -> Vec, { - let (sep, maxsplit) = args.get_value(vm)?; - let splits = if let Some(pattern) = sep { - if maxsplit < 0 { - split(self, pattern.as_ref(), vm) + if args.sep.as_ref().is_some_and(|sep| sep.is_empty()) { + return Err(vm.new_value_error("empty separator".to_owned())); + } + let splits = if let Some(pattern) = args.sep { + let Some(pattern) = pattern.as_ref() else { + return Ok(vec![full_obj()]); + }; + if args.maxsplit < 0 { + split(self, pattern, vm) } else { - splitn(self, pattern.as_ref(), (maxsplit + 1) as usize, vm) + splitn(self, pattern, (args.maxsplit + 1) as usize, vm) } } else { - splitw(self, maxsplit, vm) + splitw(self, args.maxsplit, vm) }; Ok(splits) } @@ -242,13 +236,19 @@ pub trait AnyStr { func_default: FD, 
) -> &'a Self where - S: AnyStrWrapper, + S: AnyStrWrapper, FC: Fn(&'a Self, &Self) -> &'a Self, FD: Fn(&'a Self) -> &'a Self, { let chars = chars.flatten(); match chars { - Some(chars) => func_chars(self, chars.as_ref()), + Some(chars) => { + if let Some(chars) = chars.as_ref() { + func_chars(self, chars) + } else { + self + } + } None => func_default(self), } } @@ -281,7 +281,7 @@ pub trait AnyStr { fn py_pad(&self, left: usize, right: usize, fillchar: Self::Char) -> Self::Container { let mut u = Self::Container::with_capacity( - (left + right) * Self::element_bytes_len(fillchar) + self.bytes_len(), + (left + right) * fillchar.bytes_len() + self.bytes_len(), ); u.extend(std::iter::repeat(fillchar).take(left)); u.push_str(self); @@ -305,19 +305,17 @@ pub trait AnyStr { fn py_join( &self, - mut iter: impl std::iter::Iterator< - Item = PyResult + TryFromObject>, - >, + mut iter: impl std::iter::Iterator + TryFromObject>>, ) -> PyResult { let mut joined = if let Some(elem) = iter.next() { - elem?.as_ref().to_container() + elem?.as_ref().unwrap().to_container() } else { return Ok(Self::Container::new()); }; for elem in iter { let elem = elem?; joined.push_str(self); - joined.push_str(elem.as_ref()); + joined.push_str(elem.as_ref().unwrap()); } Ok(joined) } @@ -403,25 +401,34 @@ pub trait AnyStr { rustpython_common::str::zfill(self.as_bytes(), width) } - fn py_iscase(&self, is_case: F, is_opposite: G) -> bool - where - F: Fn(char) -> bool, - G: Fn(char) -> bool, - { - // Unified form of CPython functions: - // _Py_bytes_islower - // Py_bytes_isupper - // unicode_islower_impl - // unicode_isupper_impl - let mut cased = false; - for c in self.chars() { - if is_opposite(c) { + // Unified form of CPython functions: + // _Py_bytes_islower + // unicode_islower_impl + fn py_islower(&self) -> bool { + let mut lower = false; + for c in self.elements() { + if c.is_uppercase() { + return false; + } else if !lower && c.is_lowercase() { + lower = true + } + } + lower + } + + // 
Unified form of CPython functions: + // Py_bytes_isupper + // unicode_isupper_impl + fn py_isupper(&self) -> bool { + let mut upper = false; + for c in self.elements() { + if c.is_lowercase() { return false; - } else if !cased && is_case(c) { - cased = true + } else if !upper && c.is_uppercase() { + upper = true } } - cased + upper } } diff --git a/vm/src/builtins/genericalias.rs b/vm/src/builtins/genericalias.rs index c03e3145b2..0e0a34227b 100644 --- a/vm/src/builtins/genericalias.rs +++ b/vm/src/builtins/genericalias.rs @@ -124,7 +124,7 @@ impl PyGenericAlias { Ok(format!( "{}[{}]", repr_item(self.origin.clone().into(), vm)?, - if self.args.len() == 0 { + if self.args.is_empty() { "()".to_owned() } else { self.args @@ -261,23 +261,20 @@ fn subs_tvars( .and_then(|sub_params| { PyTupleRef::try_from_object(vm, sub_params) .ok() - .and_then(|sub_params| { - if sub_params.len() > 0 { - let sub_args = sub_params - .iter() - .map(|arg| { - if let Some(idx) = tuple_index(params, arg) { - argitems[idx].clone() - } else { - arg.clone() - } - }) - .collect::>(); - let sub_args: PyObjectRef = PyTuple::new_ref(sub_args, &vm.ctx).into(); - Some(obj.get_item(&*sub_args, vm)) - } else { - None - } + .filter(|sub_params| !sub_params.is_empty()) + .map(|sub_params| { + let sub_args = sub_params + .iter() + .map(|arg| { + if let Some(idx) = tuple_index(params, arg) { + argitems[idx].clone() + } else { + arg.clone() + } + }) + .collect::>(); + let sub_args: PyObjectRef = PyTuple::new_ref(sub_args, &vm.ctx).into(); + obj.get_item(&*sub_args, vm) }) }) .unwrap_or(Ok(obj)) diff --git a/vm/src/builtins/object.rs b/vm/src/builtins/object.rs index cce1422d56..be14327542 100644 --- a/vm/src/builtins/object.rs +++ b/vm/src/builtins/object.rs @@ -159,7 +159,7 @@ fn object_getstate_default(obj: &PyObject, required: bool, vm: &VirtualMachine) slots.set_item(name.as_str(), value, vm).unwrap(); } - if slots.len() > 0 { + if !slots.is_empty() { return (state, slots).to_pyresult(vm); } } diff 
--git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index a135af1bde..40823aa37b 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -784,9 +784,7 @@ impl Initializer for PySet { type Args = OptionalArg; fn init(zelf: PyRef, iterable: Self::Args, vm: &VirtualMachine) -> PyResult<()> { - if zelf.len() > 0 { - zelf.clear(); - } + zelf.clear(); if let OptionalArg::Present(it) = iterable { zelf.update(PosArgs::new(vec![it]), vm)?; } diff --git a/vm/src/builtins/str.rs b/vm/src/builtins/str.rs index 76cdca81ed..55cefae4f7 100644 --- a/vm/src/builtins/str.rs +++ b/vm/src/builtins/str.rs @@ -10,7 +10,7 @@ use crate::{ atomic_func, cformat::cformat_string, class::PyClassImpl, - common::str::{BorrowedStr, PyStrKind, PyStrKindData}, + common::str::{PyKindStr, StrData, StrKind}, convert::{IntoPyException, ToPyException, ToPyObject, ToPyResult}, format::{format, format_map}, function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue}, @@ -24,7 +24,7 @@ use crate::{ PyComparisonOp, Representable, SelfIter, Unconstructible, }, }; -use ascii::{AsciiStr, AsciiString}; +use ascii::{AsciiChar, AsciiStr, AsciiString}; use bstr::ByteSlice; use itertools::Itertools; use num_traits::ToPrimitive; @@ -35,8 +35,10 @@ use rustpython_common::{ format::{FormatSpec, FormatString, FromTemplate}, hash, lock::PyMutex, + str::DeduceStrKind, + wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk}, }; -use std::{char, fmt, ops::Range, string::ToString}; +use std::{borrow::Cow, char, fmt, ops::Range}; use unic_ucd_bidi::BidiClass; use unic_ucd_category::GeneralCategory; use unic_ucd_ident::{is_xid_continue, is_xid_start}; @@ -57,39 +59,59 @@ impl<'a> TryFromBorrowedObject<'a> for &'a str { #[pyclass(module = false, name = "str")] pub struct PyStr { - bytes: Box<[u8]>, - kind: PyStrKindData, + data: StrData, hash: PyAtomic, } impl fmt::Debug for PyStr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("PyStr") - .field("value", 
&self.as_str()) - .field("kind", &self.kind) + .field("value", &self.as_wtf8()) + .field("kind", &self.data.kind()) .field("hash", &self.hash) .finish() } } impl AsRef for PyStr { + #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { self.as_str() } } impl AsRef for Py { + #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { self.as_str() } } impl AsRef for PyStrRef { + #[track_caller] // <- can remove this once it doesn't panic fn as_ref(&self) -> &str { self.as_str() } } +impl AsRef for PyStr { + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() + } +} + +impl AsRef for Py { + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() + } +} + +impl AsRef for PyStrRef { + fn as_ref(&self) -> &Wtf8 { + self.as_wtf8() + } +} + impl<'a> From<&'a AsciiStr> for PyStr { fn from(s: &'a AsciiStr) -> Self { s.to_owned().into() @@ -98,7 +120,19 @@ impl<'a> From<&'a AsciiStr> for PyStr { impl From for PyStr { fn from(s: AsciiString) -> Self { - unsafe { Self::new_ascii_unchecked(s.into()) } + s.into_boxed_ascii_str().into() + } +} + +impl From> for PyStr { + fn from(s: Box) -> Self { + StrData::from(s).into() + } +} + +impl From for PyStr { + fn from(ch: AsciiChar) -> Self { + AsciiString::from(ch).into() } } @@ -108,12 +142,45 @@ impl<'a> From<&'a str> for PyStr { } } +impl<'a> From<&'a Wtf8> for PyStr { + fn from(s: &'a Wtf8) -> Self { + s.to_owned().into() + } +} + impl From for PyStr { fn from(s: String) -> Self { s.into_boxed_str().into() } } +impl From for PyStr { + fn from(w: Wtf8Buf) -> Self { + w.into_box().into() + } +} + +impl From for PyStr { + fn from(ch: char) -> Self { + StrData::from(ch).into() + } +} + +impl From for PyStr { + fn from(ch: CodePoint) -> Self { + StrData::from(ch).into() + } +} + +impl From for PyStr { + fn from(data: StrData) -> Self { + PyStr { + data, + hash: Radium::new(hash::SENTINEL), + } + } +} + impl<'a> From> for PyStr { fn from(s: std::borrow::Cow<'a, str>) -> Self { 
s.into_owned().into() @@ -123,19 +190,21 @@ impl<'a> From> for PyStr { impl From> for PyStr { #[inline] fn from(value: Box) -> Self { - // doing the check is ~10x faster for ascii, and is actually only 2% slower worst case for - // non-ascii; see https://github.com/RustPython/RustPython/pull/2586#issuecomment-844611532 - let is_ascii = value.is_ascii(); - let bytes = value.into_boxed_bytes(); - let kind = if is_ascii { - PyStrKind::Ascii - } else { - PyStrKind::Utf8 - } - .new_data(); + StrData::from(value).into() + } +} + +impl From> for PyStr { + #[inline] + fn from(value: Box) -> Self { + StrData::from(value).into() + } +} + +impl Default for PyStr { + fn default() -> Self { Self { - bytes, - kind, + data: StrData::default(), hash: Radium::new(hash::SENTINEL), } } @@ -146,7 +215,7 @@ pub type PyStrRef = PyRef; impl fmt::Display for PyStr { #[inline] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(self.as_str(), f) + self.as_wtf8().fmt(f) } } @@ -237,18 +306,18 @@ impl IterNext for PyStrIterator { let mut internal = zelf.internal.lock(); if let IterStatus::Active(s) = &internal.0.status { - let value = s.as_str(); + let value = s.as_wtf8(); if internal.1 == usize::MAX { - if let Some((offset, ch)) = value.char_indices().nth(internal.0.position) { + if let Some((offset, ch)) = value.code_point_indices().nth(internal.0.position) { internal.0.position += 1; - internal.1 = offset + ch.len_utf8(); + internal.1 = offset + ch.len_wtf8(); return Ok(PyIterReturn::Return(ch.to_pyobject(vm))); } } else if let Some(value) = value.get(internal.1..) 
{ - if let Some(ch) = value.chars().next() { + if let Some(ch) = value.code_points().next() { internal.0.position += 1; - internal.1 += ch.len_utf8(); + internal.1 += ch.len_wtf8(); return Ok(PyIterReturn::Return(ch.to_pyobject(vm))); } } @@ -292,7 +361,7 @@ impl Constructor for PyStr { if string.class().is(&cls) { Ok(string.into()) } else { - PyStr::from(string.as_str()) + PyStr::from(string.as_wtf8()) .into_ref_with_type(vm, cls) .map(Into::into) } @@ -301,20 +370,19 @@ impl Constructor for PyStr { impl PyStr { /// # Safety: Given `bytes` must be valid data for given `kind` - pub(crate) unsafe fn new_str_unchecked(bytes: Vec, kind: PyStrKind) -> Self { - let s = Self { - bytes: bytes.into_boxed_slice(), - kind: kind.new_data(), - hash: Radium::new(hash::SENTINEL), - }; - debug_assert!(matches!(s.kind, PyStrKindData::Ascii) || !s.as_str().is_ascii()); - s + unsafe fn new_str_unchecked(data: Box, kind: StrKind) -> Self { + unsafe { StrData::new_str_unchecked(data, kind) }.into() + } + + unsafe fn new_with_char_len>>(s: T, char_len: usize) -> Self { + let kind = s.str_kind(); + unsafe { StrData::new_with_char_len(s.into(), kind, char_len) }.into() } /// # Safety /// Given `bytes` must be ascii pub unsafe fn new_ascii_unchecked(bytes: Vec) -> Self { - unsafe { Self::new_str_unchecked(bytes, PyStrKind::Ascii) } + unsafe { AsciiString::from_ascii_unchecked(bytes) }.into() } pub fn new_ref(zelf: impl Into, ctx: &Context) -> PyRef { @@ -322,40 +390,70 @@ impl PyStr { PyRef::new_ref(zelf, ctx.types.str_type.to_owned(), None) } - fn new_substr(&self, s: String) -> Self { - let kind = if self.kind.kind() == PyStrKind::Ascii || s.is_ascii() { - PyStrKind::Ascii + fn new_substr(&self, s: Wtf8Buf) -> Self { + let kind = if self.kind().is_ascii() || s.is_ascii() { + StrKind::Ascii + } else if self.kind().is_utf8() || s.is_utf8() { + StrKind::Utf8 } else { - PyStrKind::Utf8 + StrKind::Wtf8 }; unsafe { // SAFETY: kind is properly decided for substring - 
Self::new_str_unchecked(s.into_bytes(), kind) + Self::new_str_unchecked(s.into(), kind) } } #[inline] + pub fn as_wtf8(&self) -> &Wtf8 { + self.data.as_wtf8() + } + + pub fn as_bytes(&self) -> &[u8] { + self.data.as_wtf8().as_bytes() + } + + // FIXME: make this return an Option + #[inline] + #[track_caller] // <- can remove this once it doesn't panic pub fn as_str(&self) -> &str { - unsafe { - // SAFETY: Both PyStrKind::{Ascii, Utf8} are valid utf8 string - std::str::from_utf8_unchecked(&self.bytes) - } + self.data.as_str().expect("str has surrogates") + } + + pub fn to_str(&self) -> Option<&str> { + self.data.as_str() + } + + pub fn to_string_lossy(&self) -> Cow<'_, str> { + self.to_str() + .map(Cow::Borrowed) + .unwrap_or_else(|| self.as_wtf8().to_string_lossy()) + } + + pub fn kind(&self) -> StrKind { + self.data.kind() + } + + #[inline] + pub fn as_str_kind(&self) -> PyKindStr<'_> { + self.data.as_str_kind() + } + + pub fn is_utf8(&self) -> bool { + self.kind().is_utf8() } fn char_all(&self, test: F) -> bool where F: Fn(char) -> bool, { - match self.kind.kind() { - PyStrKind::Ascii => self.bytes.iter().all(|&x| test(char::from(x))), - PyStrKind::Utf8 => self.as_str().chars().all(test), + match self.as_str_kind() { + PyKindStr::Ascii(s) => s.chars().all(|ch| test(ch.into())), + PyKindStr::Utf8(s) => s.chars().all(test), + PyKindStr::Wtf8(w) => w.code_points().all(|ch| ch.is_char_and(&test)), } } - fn borrow(&self) -> &BorrowedStr<'_> { - unsafe { std::mem::transmute(self) } - } - fn repeat(zelf: PyRef, value: isize, vm: &VirtualMachine) -> PyResult> { if value == 0 && zelf.class().is(vm.ctx.types.str_type) { // Special case: when some `str` is multiplied by `0`, @@ -369,10 +467,10 @@ impl PyStr { // This only works for `str` itself, not its subclasses. 
return Ok(zelf); } - zelf.as_str() + zelf.as_wtf8() .as_bytes() .mul(vm, value) - .map(|x| Self::from(unsafe { String::from_utf8_unchecked(x) }).into_ref(&vm.ctx)) + .map(|x| Self::from(unsafe { Wtf8Buf::from_bytes_unchecked(x) }).into_ref(&vm.ctx)) } } @@ -394,11 +492,11 @@ impl PyStr { #[pymethod(magic)] fn add(zelf: PyRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { if let Some(other) = other.payload::() { - let bytes = zelf.as_str().py_add(other.as_ref()); + let bytes = zelf.as_wtf8().py_add(other.as_wtf8()); Ok(unsafe { // SAFETY: `kind` is safely decided - let kind = zelf.kind.kind() | other.kind.kind(); - Self::new_str_unchecked(bytes.into_bytes(), kind) + let kind = zelf.kind() | other.kind(); + Self::new_str_unchecked(bytes.into(), kind) } .to_pyobject(vm)) } else if let Some(radd) = vm.get_method(other.clone(), identifier!(vm, __radd__)) { @@ -414,7 +512,7 @@ impl PyStr { fn _contains(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { if let Some(needle) = needle.payload::() { - Ok(self.as_str().contains(needle.as_str())) + Ok(memchr::memmem::find(self.as_bytes(), needle.as_bytes()).is_some()) } else { Err(vm.new_type_error(format!( "'in ' requires string as left operand, not {}", @@ -429,11 +527,11 @@ impl PyStr { } fn _getitem(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult { - match SequenceIndex::try_from_borrowed_object(vm, needle, "str")? { - SequenceIndex::Int(i) => self.getitem_by_index(vm, i).map(|x| x.to_string()), - SequenceIndex::Slice(slice) => self.getitem_by_slice(vm, slice), - } - .map(|x| self.new_substr(x).into_ref(&vm.ctx).into()) + let item = match SequenceIndex::try_from_borrowed_object(vm, needle, "str")? 
{ + SequenceIndex::Int(i) => self.getitem_by_index(vm, i)?.to_pyobject(vm), + SequenceIndex::Slice(slice) => self.getitem_by_slice(vm, slice)?.to_pyobject(vm), + }; + Ok(item) } #[pymethod(magic)] @@ -450,7 +548,7 @@ impl PyStr { } #[cold] fn _compute_hash(&self, vm: &VirtualMachine) -> hash::PyHash { - let hash_val = vm.state.hash_secret.hash_str(self.as_str()); + let hash_val = vm.state.hash_secret.hash_bytes(self.as_bytes()); debug_assert_ne!(hash_val, hash::SENTINEL); // like with char_len, we don't need a cmpxchg loop, since it'll always be the same value self.hash.store(hash_val, atomic::Ordering::Relaxed); @@ -459,26 +557,23 @@ impl PyStr { #[inline] pub fn byte_len(&self) -> usize { - self.bytes.len() + self.data.len() } #[inline] pub fn is_empty(&self) -> bool { - self.bytes.is_empty() + self.data.is_empty() } #[pymethod(name = "__len__")] #[inline] pub fn char_len(&self) -> usize { - self.borrow().char_len() + self.data.char_len() } #[pymethod(name = "isascii")] #[inline(always)] pub fn is_ascii(&self) -> bool { - match self.kind { - PyStrKindData::Ascii => true, - PyStrKindData::Utf8(_) => false, - } + matches!(self.kind(), StrKind::Ascii) } #[pymethod(magic)] @@ -495,6 +590,9 @@ impl PyStr { #[inline] pub(crate) fn repr(&self, vm: &VirtualMachine) -> PyResult { use crate::literal::escape::UnicodeEscape; + if !self.kind().is_utf8() { + return Ok(format!("{:?}", self.as_wtf8())); + } let escape = UnicodeEscape::new_repr(self.as_str()); escape .str_repr() @@ -503,10 +601,18 @@ impl PyStr { } #[pymethod] - fn lower(&self) -> String { - match self.kind.kind() { - PyStrKind::Ascii => self.as_str().to_ascii_lowercase(), - PyStrKind::Utf8 => self.as_str().to_lowercase(), + fn lower(&self) -> PyStr { + match self.as_str_kind() { + PyKindStr::Ascii(s) => s.to_ascii_lowercase().into(), + PyKindStr::Utf8(s) => s.to_lowercase().into(), + PyKindStr::Wtf8(w) => w + .chunks() + .map(|c| match c { + Wtf8Chunk::Utf8(s) => s.to_lowercase().into(), + 
Wtf8Chunk::Surrogate(c) => Wtf8Buf::from(c), + }) + .collect::() + .into(), } } @@ -517,58 +623,98 @@ impl PyStr { } #[pymethod] - fn upper(&self) -> String { - match self.kind.kind() { - PyStrKind::Ascii => self.as_str().to_ascii_uppercase(), - PyStrKind::Utf8 => self.as_str().to_uppercase(), + fn upper(&self) -> PyStr { + match self.as_str_kind() { + PyKindStr::Ascii(s) => s.to_ascii_uppercase().into(), + PyKindStr::Utf8(s) => s.to_uppercase().into(), + PyKindStr::Wtf8(w) => w + .chunks() + .map(|c| match c { + Wtf8Chunk::Utf8(s) => s.to_uppercase().into(), + Wtf8Chunk::Surrogate(c) => Wtf8Buf::from(c), + }) + .collect::() + .into(), } } #[pymethod] - fn capitalize(&self) -> String { - let mut chars = self.as_str().chars(); - if let Some(first_char) = chars.next() { - format!( - "{}{}", - first_char.to_uppercase(), - &chars.as_str().to_lowercase(), - ) - } else { - "".to_owned() + fn capitalize(&self) -> Wtf8Buf { + match self.as_str_kind() { + PyKindStr::Ascii(s) => { + let mut s = s.to_owned(); + if let [first, rest @ ..] 
= s.as_mut_slice() { + first.make_ascii_uppercase(); + ascii::AsciiStr::make_ascii_lowercase(rest.into()); + } + s.into() + } + PyKindStr::Utf8(s) => { + let mut chars = s.chars(); + let mut out = String::with_capacity(s.len()); + if let Some(c) = chars.next() { + out.extend(c.to_titlecase()); + out.push_str(&chars.as_str().to_lowercase()); + } + out.into() + } + PyKindStr::Wtf8(s) => { + let mut out = Wtf8Buf::with_capacity(s.len()); + let mut chars = s.code_points(); + if let Some(ch) = chars.next() { + match ch.to_char() { + Some(ch) => out.extend(ch.to_titlecase()), + None => out.push(ch), + } + for chunk in chars.as_wtf8().chunks() { + match chunk { + Wtf8Chunk::Utf8(s) => out.push_str(&s.to_lowercase()), + Wtf8Chunk::Surrogate(ch) => out.push(ch), + } + } + } + out + } } } #[pymethod] - fn split(&self, args: SplitArgs, vm: &VirtualMachine) -> PyResult> { - let elements = match self.kind.kind() { - PyStrKind::Ascii => self.as_str().py_split( + fn split(zelf: &Py, args: SplitArgs, vm: &VirtualMachine) -> PyResult> { + let elements = match zelf.as_str_kind() { + PyKindStr::Ascii(s) => s.py_split( args, vm, + || zelf.as_object().to_owned(), |v, s, vm| { v.as_bytes() .split_str(s) - .map(|s| { - unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.to_pyobject(vm) - }) + .map(|s| unsafe { AsciiStr::from_ascii_unchecked(s) }.to_pyobject(vm)) .collect() }, |v, s, n, vm| { v.as_bytes() .splitn_str(n, s) - .map(|s| { - unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.to_pyobject(vm) - }) + .map(|s| unsafe { AsciiStr::from_ascii_unchecked(s) }.to_pyobject(vm)) .collect() }, |v, n, vm| { v.as_bytes().py_split_whitespace(n, |s| { - unsafe { PyStr::new_ascii_unchecked(s.to_owned()) }.to_pyobject(vm) + unsafe { AsciiStr::from_ascii_unchecked(s) }.to_pyobject(vm) }) }, ), - PyStrKind::Utf8 => self.as_str().py_split( + PyKindStr::Utf8(s) => s.py_split( + args, + vm, + || zelf.as_object().to_owned(), + |v, s, vm| v.split(s).map(|s| vm.ctx.new_str(s).into()).collect(), + |v, 
s, n, vm| v.splitn(n, s).map(|s| vm.ctx.new_str(s).into()).collect(), + |v, n, vm| v.py_split_whitespace(n, |s| vm.ctx.new_str(s).into()), + ), + PyKindStr::Wtf8(w) => w.py_split( args, vm, + || zelf.as_object().to_owned(), |v, s, vm| v.split(s).map(|s| vm.ctx.new_str(s).into()).collect(), |v, s, n, vm| v.splitn(n, s).map(|s| vm.ctx.new_str(s).into()).collect(), |v, n, vm| v.py_split_whitespace(n, |s| vm.ctx.new_str(s).into()), @@ -578,10 +724,11 @@ impl PyStr { } #[pymethod] - fn rsplit(&self, args: SplitArgs, vm: &VirtualMachine) -> PyResult> { - let mut elements = self.as_str().py_split( + fn rsplit(zelf: &Py, args: SplitArgs, vm: &VirtualMachine) -> PyResult> { + let mut elements = zelf.as_wtf8().py_split( args, vm, + || zelf.as_object().to_owned(), |v, s, vm| v.rsplit(s).map(|s| vm.ctx.new_str(s).into()).collect(), |v, s, n, vm| v.rsplitn(n, s).map(|s| vm.ctx.new_str(s).into()).collect(), |v, n, vm| v.py_rsplit_whitespace(n, |s| vm.ctx.new_str(s).into()), @@ -593,14 +740,35 @@ impl PyStr { } #[pymethod] - fn strip(&self, chars: OptionalOption) -> String { - self.as_str() - .py_strip( - chars, - |s, chars| s.trim_matches(|c| chars.contains(c)), - |s| s.trim(), - ) - .to_owned() + fn strip(&self, chars: OptionalOption) -> PyStr { + match self.as_str_kind() { + PyKindStr::Ascii(s) => s + .py_strip( + chars, + |s, chars| { + let s = s + .as_str() + .trim_matches(|c| memchr::memchr(c as _, chars.as_bytes()).is_some()); + unsafe { AsciiStr::from_ascii_unchecked(s.as_bytes()) } + }, + |s| s.trim(), + ) + .into(), + PyKindStr::Utf8(s) => s + .py_strip( + chars, + |s, chars| s.trim_matches(|c| chars.contains(c)), + |s| s.trim(), + ) + .into(), + PyKindStr::Wtf8(w) => w + .py_strip( + chars, + |s, chars| s.trim_matches(|c| chars.code_points().contains(&c)), + |s| s.trim(), + ) + .into(), + } } #[pymethod] @@ -609,10 +777,10 @@ impl PyStr { chars: OptionalOption, vm: &VirtualMachine, ) -> PyRef { - let s = zelf.as_str(); + let s = zelf.as_wtf8(); let stripped = 
s.py_strip( chars, - |s, chars| s.trim_start_matches(|c| chars.contains(c)), + |s, chars| s.trim_start_matches(|c| chars.contains_code_point(c)), |s| s.trim_start(), ); if s == stripped { @@ -628,10 +796,10 @@ impl PyStr { chars: OptionalOption, vm: &VirtualMachine, ) -> PyRef { - let s = zelf.as_str(); + let s = zelf.as_wtf8(); let stripped = s.py_strip( chars, - |s, chars| s.trim_end_matches(|c| chars.contains(c)), + |s, chars| s.trim_end_matches(|c| chars.contains_code_point(c)), |s| s.trim_end(), ); if s == stripped { @@ -644,7 +812,7 @@ impl PyStr { #[pymethod] fn endswith(&self, options: anystr::StartsEndsWithArgs, vm: &VirtualMachine) -> PyResult { let (affix, substr) = - match options.prepare(self.as_str(), self.len(), |s, r| s.get_chars(r)) { + match options.prepare(self.as_wtf8(), self.len(), |s, r| s.get_chars(r)) { Some(x) => x, None => return Ok(false), }; @@ -652,7 +820,7 @@ impl PyStr { &affix, "endswith", "str", - |s, x: &Py| s.ends_with(x.as_str()), + |s, x: &Py| s.ends_with(x.as_wtf8()), vm, ) } @@ -664,7 +832,7 @@ impl PyStr { vm: &VirtualMachine, ) -> PyResult { let (affix, substr) = - match options.prepare(self.as_str(), self.len(), |s, r| s.get_chars(r)) { + match options.prepare(self.as_wtf8(), self.len(), |s, r| s.get_chars(r)) { Some(x) => x, None => return Ok(false), }; @@ -672,7 +840,7 @@ impl PyStr { &affix, "startswith", "str", - |s, x: &Py| s.starts_with(x.as_str()), + |s, x: &Py| s.starts_with(x.as_wtf8()), vm, ) } @@ -701,36 +869,33 @@ impl PyStr { #[pymethod] fn isalnum(&self) -> bool { - !self.bytes.is_empty() && self.char_all(char::is_alphanumeric) + !self.data.is_empty() && self.char_all(char::is_alphanumeric) } #[pymethod] fn isnumeric(&self) -> bool { - !self.bytes.is_empty() && self.char_all(char::is_numeric) + !self.data.is_empty() && self.char_all(char::is_numeric) } #[pymethod] fn isdigit(&self) -> bool { // python's isdigit also checks if exponents are digits, these are the unicode codepoints for exponents - let 
valid_codepoints: [u16; 10] = [ - 0x2070, 0x00B9, 0x00B2, 0x00B3, 0x2074, 0x2075, 0x2076, 0x2077, 0x2078, 0x2079, - ]; - let s = self.as_str(); - !s.is_empty() - && s.chars() - .filter(|c| !c.is_ascii_digit()) - .all(|c| valid_codepoints.contains(&(c as u16))) + !self.data.is_empty() + && self.char_all(|c| { + c.is_ascii_digit() + || matches!(c, '⁰' | '¹' | '²' | '³' | '⁴' | '⁵' | '⁶' | '⁷' | '⁸' | '⁹') + }) } #[pymethod] fn isdecimal(&self) -> bool { - !self.bytes.is_empty() + !self.data.is_empty() && self.char_all(|c| GeneralCategory::of(c) == GeneralCategory::DecimalNumber) } #[pymethod(name = "__mod__")] - fn modulo(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult { - cformat_string(vm, self.as_str(), values) + fn modulo(&self, values: PyObjectRef, vm: &VirtualMachine) -> PyResult { + cformat_string(vm, self.as_wtf8(), values) } #[pymethod(magic)] @@ -739,8 +904,9 @@ impl PyStr { } #[pymethod] - fn format(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult { - let format_str = FormatString::from_str(self.as_str()).map_err(|e| e.to_pyexception(vm))?; + fn format(&self, args: FuncArgs, vm: &VirtualMachine) -> PyResult { + let format_str = + FormatString::from_str(self.as_wtf8()).map_err(|e| e.to_pyexception(vm))?; format(&format_str, &args, vm) } @@ -749,9 +915,9 @@ impl PyStr { /// Return a formatted version of S, using substitutions from mapping. /// The substitutions are identified by braces ('{' and '}'). 
#[pymethod] - fn format_map(&self, mapping: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn format_map(&self, mapping: PyObjectRef, vm: &VirtualMachine) -> PyResult { let format_string = - FormatString::from_str(self.as_str()).map_err(|err| err.to_pyexception(vm))?; + FormatString::from_str(self.as_wtf8()).map_err(|err| err.to_pyexception(vm))?; format_map(&format_string, &mapping, vm) } @@ -767,7 +933,9 @@ impl PyStr { } let s = FormatSpec::parse(spec) - .and_then(|format_spec| format_spec.format_string(zelf.borrow())) + .and_then(|format_spec| { + format_spec.format_string(&CharLenStr(zelf.as_str(), zelf.char_len())) + }) .map_err(|err| err.into_pyexception(vm))?; Ok(vm.ctx.new_str(s)) } @@ -775,43 +943,45 @@ impl PyStr { /// Return a titlecased version of the string where words start with an /// uppercase character and the remaining characters are lowercase. #[pymethod] - fn title(&self) -> String { - let mut title = String::with_capacity(self.bytes.len()); + fn title(&self) -> Wtf8Buf { + let mut title = Wtf8Buf::with_capacity(self.data.len()); let mut previous_is_cased = false; - for c in self.as_str().chars() { + for c_orig in self.as_wtf8().code_points() { + let c = c_orig.to_char_lossy(); if c.is_lowercase() { if !previous_is_cased { title.extend(c.to_titlecase()); } else { - title.push(c); + title.push_char(c); } previous_is_cased = true; } else if c.is_uppercase() || c.is_titlecase() { if previous_is_cased { title.extend(c.to_lowercase()); } else { - title.push(c); + title.push_char(c); } previous_is_cased = true; } else { previous_is_cased = false; - title.push(c); + title.push(c_orig); } } title } #[pymethod] - fn swapcase(&self) -> String { - let mut swapped_str = String::with_capacity(self.bytes.len()); - for c in self.as_str().chars() { + fn swapcase(&self) -> Wtf8Buf { + let mut swapped_str = Wtf8Buf::with_capacity(self.data.len()); + for c_orig in self.as_wtf8().code_points() { + let c = c_orig.to_char_lossy(); // to_uppercase returns an 
iterator, to_ascii_uppercase returns the char if c.is_lowercase() { - swapped_str.push(c.to_ascii_uppercase()); + swapped_str.push_char(c.to_ascii_uppercase()); } else if c.is_uppercase() { - swapped_str.push(c.to_ascii_lowercase()); + swapped_str.push_char(c.to_ascii_lowercase()); } else { - swapped_str.push(c); + swapped_str.push(c_orig); } } swapped_str @@ -819,24 +989,24 @@ impl PyStr { #[pymethod] fn isalpha(&self) -> bool { - !self.bytes.is_empty() && self.char_all(char::is_alphabetic) + !self.data.is_empty() && self.char_all(char::is_alphabetic) } #[pymethod] - fn replace(&self, old: PyStrRef, new: PyStrRef, count: OptionalArg) -> String { - let s = self.as_str(); + fn replace(&self, old: PyStrRef, new: PyStrRef, count: OptionalArg) -> Wtf8Buf { + let s = self.as_wtf8(); match count { OptionalArg::Present(max_count) if max_count >= 0 => { if max_count == 0 || (s.is_empty() && !old.is_empty()) { // nothing to do; return the original bytes s.to_owned() } else if s.is_empty() && old.is_empty() { - new.as_str().to_owned() + new.as_wtf8().to_owned() } else { - s.replacen(old.as_str(), new.as_str(), max_count as usize) + s.replacen(old.as_wtf8(), new.as_wtf8(), max_count as usize) } } - _ => s.replace(old.as_str(), new.as_str()), + _ => s.replace(old.as_wtf8(), new.as_wtf8()), } } @@ -863,7 +1033,7 @@ impl PyStr { #[pymethod] fn isspace(&self) -> bool { use unic_ucd_bidi::bidi_class::abbr_names::*; - !self.bytes.is_empty() + !self.data.is_empty() && self.char_all(|c| { GeneralCategory::of(c) == GeneralCategory::SpaceSeparator || matches!(BidiClass::of(c), WS | B | S) @@ -873,41 +1043,39 @@ impl PyStr { // Return true if all cased characters in the string are lowercase and there is at least one cased character, false otherwise. 
#[pymethod] fn islower(&self) -> bool { - match self.kind.kind() { - PyStrKind::Ascii => self.bytes.py_iscase(char::is_lowercase, char::is_uppercase), - PyStrKind::Utf8 => self - .as_str() - .py_iscase(char::is_lowercase, char::is_uppercase), + match self.as_str_kind() { + PyKindStr::Ascii(s) => s.py_islower(), + PyKindStr::Utf8(s) => s.py_islower(), + PyKindStr::Wtf8(w) => w.py_islower(), } } // Return true if all cased characters in the string are uppercase and there is at least one cased character, false otherwise. #[pymethod] fn isupper(&self) -> bool { - match self.kind.kind() { - PyStrKind::Ascii => self.bytes.py_iscase(char::is_uppercase, char::is_lowercase), - PyStrKind::Utf8 => self - .as_str() - .py_iscase(char::is_uppercase, char::is_lowercase), + match self.as_str_kind() { + PyKindStr::Ascii(s) => s.py_isupper(), + PyKindStr::Utf8(s) => s.py_isupper(), + PyKindStr::Wtf8(w) => w.py_isupper(), } } #[pymethod] fn splitlines(&self, args: anystr::SplitLinesArgs, vm: &VirtualMachine) -> Vec { - let into_wrapper = |s: &str| self.new_substr(s.to_owned()).to_pyobject(vm); + let into_wrapper = |s: &Wtf8| self.new_substr(s.to_owned()).to_pyobject(vm); let mut elements = Vec::new(); let mut last_i = 0; - let self_str = self.as_str(); - let mut enumerated = self_str.char_indices().peekable(); + let self_str = self.as_wtf8(); + let mut enumerated = self_str.code_point_indices().peekable(); while let Some((i, ch)) = enumerated.next() { - let end_len = match ch { + let end_len = match ch.to_char_lossy() { '\n' => 1, '\r' => { let is_rn = enumerated.next_if(|(_, ch)| *ch == '\n').is_some(); if is_rn { 2 } else { 1 } } '\x0b' | '\x0c' | '\x1c' | '\x1d' | '\x1e' | '\u{0085}' | '\u{2028}' - | '\u{2029}' => ch.len_utf8(), + | '\u{2029}' => ch.len_wtf8(), _ => continue, }; let range = if args.keepends { @@ -937,27 +1105,27 @@ impl PyStr { if first.as_object().class().is(vm.ctx.types.str_type) { return Ok(first); } else { - first.as_str().to_owned() + 
first.as_wtf8().to_owned() } } - Err(iter) => zelf.as_str().py_join(iter)?, + Err(iter) => zelf.as_wtf8().py_join(iter)?, }; Ok(vm.ctx.new_str(joined)) } // FIXME: two traversals of str is expensive #[inline] - fn _to_char_idx(r: &str, byte_idx: usize) -> usize { - r[..byte_idx].chars().count() + fn _to_char_idx(r: &Wtf8, byte_idx: usize) -> usize { + r[..byte_idx].code_points().count() } #[inline] fn _find(&self, args: FindArgs, find: F) -> Option where - F: Fn(&str, &str) -> Option, + F: Fn(&Wtf8, &Wtf8) -> Option, { let (sub, range) = args.get_value(self.len()); - self.as_str().py_find(sub.as_str(), range, find) + self.as_wtf8().py_find(sub.as_wtf8(), range, find) } #[pymethod] @@ -986,9 +1154,9 @@ impl PyStr { #[pymethod] fn partition(&self, sep: PyStrRef, vm: &VirtualMachine) -> PyResult { - let (front, has_mid, back) = self.as_str().py_partition( - sep.as_str(), - || self.as_str().splitn(2, sep.as_str()), + let (front, has_mid, back) = self.as_wtf8().py_partition( + sep.as_wtf8(), + || self.as_wtf8().splitn(2, sep.as_wtf8()), vm, )?; let partition = ( @@ -1005,9 +1173,9 @@ impl PyStr { #[pymethod] fn rpartition(&self, sep: PyStrRef, vm: &VirtualMachine) -> PyResult { - let (back, has_mid, front) = self.as_str().py_partition( - sep.as_str(), - || self.as_str().rsplitn(2, sep.as_str()), + let (back, has_mid, front) = self.as_wtf8().py_partition( + sep.as_wtf8(), + || self.as_wtf8().rsplitn(2, sep.as_wtf8()), vm, )?; Ok(( @@ -1015,7 +1183,7 @@ impl PyStr { if has_mid { sep } else { - vm.ctx.new_str(ascii!("")) + vm.ctx.empty_str.to_owned() }, self.new_substr(back), ) @@ -1026,13 +1194,13 @@ impl PyStr { /// empty, `false` otherwise. 
#[pymethod] fn istitle(&self) -> bool { - if self.bytes.is_empty() { + if self.data.is_empty() { return false; } let mut cased = false; let mut previous_is_cased = false; - for c in self.as_str().chars() { + for c in self.as_wtf8().code_points().map(CodePoint::to_char_lossy) { if c.is_uppercase() || c.is_titlecase() { if previous_is_cased { return false; @@ -1055,15 +1223,15 @@ impl PyStr { #[pymethod] fn count(&self, args: FindArgs) -> usize { let (needle, range) = args.get_value(self.len()); - self.as_str() - .py_count(needle.as_str(), range, |h, n| h.matches(n).count()) + self.as_wtf8() + .py_count(needle.as_wtf8(), range, |h, n| h.find_iter(n).count()) } #[pymethod] - fn zfill(&self, width: isize) -> String { + fn zfill(&self, width: isize) -> Wtf8Buf { unsafe { - // SAFETY: this is safe-guaranteed because the original self.as_str() is valid utf8 - String::from_utf8_unchecked(self.as_str().py_zfill(width)) + // SAFETY: this is safe-guaranteed because the original self.as_wtf8() is valid wtf8 + Wtf8Buf::from_bytes_unchecked(self.as_wtf8().py_zfill(width)) } } @@ -1072,20 +1240,20 @@ impl PyStr { &self, width: isize, fillchar: OptionalArg, - pad: fn(&str, usize, char, usize) -> String, + pad: fn(&Wtf8, usize, CodePoint, usize) -> Wtf8Buf, vm: &VirtualMachine, - ) -> PyResult { - let fillchar = fillchar.map_or(Ok(' '), |ref s| { - s.as_str().chars().exactly_one().map_err(|_| { + ) -> PyResult { + let fillchar = fillchar.map_or(Ok(' '.into()), |ref s| { + s.as_wtf8().code_points().exactly_one().map_err(|_| { vm.new_type_error( "The fill character must be exactly one character long".to_owned(), ) }) })?; Ok(if self.len() as isize >= width { - String::from(self.as_str()) + self.as_wtf8().to_owned() } else { - pad(self.as_str(), width as usize, fillchar, self.len()) + pad(self.as_wtf8(), width as usize, fillchar, self.len()) }) } @@ -1095,7 +1263,7 @@ impl PyStr { width: isize, fillchar: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { 
self._pad(width, fillchar, AnyStr::py_center, vm) } @@ -1105,7 +1273,7 @@ impl PyStr { width: isize, fillchar: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { self._pad(width, fillchar, AnyStr::py_ljust, vm) } @@ -1115,7 +1283,7 @@ impl PyStr { width: isize, fillchar: OptionalArg, vm: &VirtualMachine, - ) -> PyResult { + ) -> PyResult { self._pad(width, fillchar, AnyStr::py_rjust, vm) } @@ -1247,26 +1415,42 @@ impl PyStr { } } +struct CharLenStr<'a>(&'a str, usize); +impl std::ops::Deref for CharLenStr<'_> { + type Target = str; + fn deref(&self) -> &Self::Target { + self.0 + } +} +impl crate::common::format::CharLen for CharLenStr<'_> { + fn char_len(&self) -> usize { + self.1 + } +} + #[pyclass] impl PyRef { #[pymethod(magic)] fn str(self, vm: &VirtualMachine) -> PyRefExact { - self.into_exact_or(&vm.ctx, |zelf| unsafe { - // Creating a copy with same kind is safe - PyStr::new_str_unchecked(zelf.bytes.to_vec(), zelf.kind.kind()).into_exact_ref(&vm.ctx) + self.into_exact_or(&vm.ctx, |zelf| { + PyStr::from(zelf.data.clone()).into_exact_ref(&vm.ctx) }) } } impl PyStrRef { - pub fn concat_in_place(&mut self, other: &str, vm: &VirtualMachine) { + pub fn is_empty(&self) -> bool { + (**self).is_empty() + } + + pub fn concat_in_place(&mut self, other: &Wtf8, vm: &VirtualMachine) { // TODO: call [A]Rc::get_mut on the str to try to mutate the data in place if other.is_empty() { return; } - let mut s = String::with_capacity(self.byte_len() + other.len()); - s.push_str(self.as_ref()); - s.push_str(other); + let mut s = Wtf8Buf::with_capacity(self.byte_len() + other.len()); + s.push_wtf8(self.as_ref()); + s.push_wtf8(other); *self = PyStr::from(s).into_ref(&vm.ctx); } } @@ -1296,7 +1480,7 @@ impl Comparable for PyStr { return Ok(res.into()); } let other = class_or_notimplemented!(Self, other); - Ok(op.eval_ord(zelf.as_str().cmp(other.as_str())).into()) + Ok(op.eval_ord(zelf.as_wtf8().cmp(other.as_wtf8())).into()) } } @@ -1352,8 +1536,7 @@ impl AsSequence 
for PyStr { }), item: atomic_func!(|seq, i, vm| { let zelf = PyStr::sequence_downcast(seq); - zelf.getitem_by_index(vm, i) - .map(|x| zelf.new_substr(x.to_string()).into_ref(&vm.ctx).into()) + zelf.getitem_by_index(vm, i).to_pyresult(vm) }), contains: atomic_func!( |seq, needle, vm| PyStr::sequence_downcast(seq)._contains(needle, vm) @@ -1396,9 +1579,21 @@ impl ToPyObject for String { } } +impl ToPyObject for Wtf8Buf { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str(self).into() + } +} + impl ToPyObject for char { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { - vm.ctx.new_str(self.to_string()).into() + vm.ctx.new_str(self).into() + } +} + +impl ToPyObject for CodePoint { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str(self).into() } } @@ -1414,6 +1609,18 @@ impl ToPyObject for &String { } } +impl ToPyObject for &Wtf8 { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str(self).into() + } +} + +impl ToPyObject for &Wtf8Buf { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str(self.clone()).into() + } +} + impl ToPyObject for &AsciiStr { fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { vm.ctx.new_str(self).into() @@ -1426,6 +1633,12 @@ impl ToPyObject for AsciiString { } } +impl ToPyObject for AsciiChar { + fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef { + vm.ctx.new_str(self).into() + } +} + type SplitArgs = anystr::SplitArgs; #[derive(FromArgs)] @@ -1452,96 +1665,137 @@ pub fn init(ctx: &Context) { } impl SliceableSequenceOp for PyStr { - type Item = char; - type Sliced = String; + type Item = CodePoint; + type Sliced = PyStr; fn do_get(&self, index: usize) -> Self::Item { - if self.is_ascii() { - self.bytes[index] as char - } else { - self.as_str().chars().nth(index).unwrap() - } + self.data.nth_char(index) } fn do_slice(&self, range: Range) -> Self::Sliced { - let value = self.as_str(); - if self.is_ascii() { - 
value[range].to_owned() - } else { - rustpython_common::str::get_chars(value, range).to_owned() + match self.as_str_kind() { + PyKindStr::Ascii(s) => s[range].into(), + PyKindStr::Utf8(s) => { + let char_len = range.len(); + let out = rustpython_common::str::get_chars(s, range); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } + } + PyKindStr::Wtf8(w) => { + let char_len = range.len(); + let out = rustpython_common::str::get_codepoints(w, range); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } + } } } fn do_slice_reverse(&self, range: Range) -> Self::Sliced { - if self.is_ascii() { - // this is an ascii string - let mut v = self.bytes[range].to_vec(); - v.reverse(); - unsafe { - // SAFETY: an ascii string is always utf8 - String::from_utf8_unchecked(v) + match self.as_str_kind() { + PyKindStr::Ascii(s) => { + let mut out = s[range].to_owned(); + out.as_mut_slice().reverse(); + out.into() + } + PyKindStr::Utf8(s) => { + let char_len = range.len(); + let mut out = String::with_capacity(2 * char_len); + out.extend( + s.chars() + .rev() + .skip(self.char_len() - range.end) + .take(range.len()), + ); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, range.len()) } + } + PyKindStr::Wtf8(w) => { + let char_len = range.len(); + let mut out = Wtf8Buf::with_capacity(2 * char_len); + out.extend( + w.code_points() + .rev() + .skip(self.char_len() - range.end) + .take(range.len()), + ); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } } - } else { - let mut s = String::with_capacity(self.bytes.len()); - s.extend( - self.as_str() - .chars() - .rev() - .skip(self.char_len() - range.end) - .take(range.end - range.start), - ); - s } } fn do_stepped_slice(&self, range: Range, step: usize) -> Self::Sliced { - if self.is_ascii() { - let v = self.bytes[range].iter().copied().step_by(step).collect(); - unsafe { - // SAFETY: Any subset of ascii 
string is a valid utf8 string - String::from_utf8_unchecked(v) + match self.as_str_kind() { + PyKindStr::Ascii(s) => s[range] + .as_slice() + .iter() + .copied() + .step_by(step) + .collect::() + .into(), + PyKindStr::Utf8(s) => { + let char_len = (range.len() / step) + 1; + let mut out = String::with_capacity(2 * char_len); + out.extend(s.chars().skip(range.start).take(range.len()).step_by(step)); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } + } + PyKindStr::Wtf8(w) => { + let char_len = (range.len() / step) + 1; + let mut out = Wtf8Buf::with_capacity(2 * char_len); + out.extend( + w.code_points() + .skip(range.start) + .take(range.len()) + .step_by(step), + ); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } } - } else { - let mut s = String::with_capacity(2 * ((range.len() / step) + 1)); - s.extend( - self.as_str() - .chars() - .skip(range.start) - .take(range.end - range.start) - .step_by(step), - ); - s } } fn do_stepped_slice_reverse(&self, range: Range, step: usize) -> Self::Sliced { - if self.is_ascii() { - // this is an ascii string - let v: Vec = self.bytes[range] - .iter() + match self.as_str_kind() { + PyKindStr::Ascii(s) => s[range] + .chars() .rev() - .copied() .step_by(step) - .collect(); - // TODO: from_utf8_unchecked? 
- String::from_utf8(v).unwrap() - } else { - // not ascii, so the codepoints have to be at least 2 bytes each - let mut s = String::with_capacity(2 * ((range.len() / step) + 1)); - s.extend( - self.as_str() - .chars() - .rev() - .skip(self.char_len() - range.end) - .take(range.end - range.start) - .step_by(step), - ); - s + .collect::() + .into(), + PyKindStr::Utf8(s) => { + let char_len = (range.len() / step) + 1; + // not ascii, so the codepoints have to be at least 2 bytes each + let mut out = String::with_capacity(2 * char_len); + out.extend( + s.chars() + .rev() + .skip(self.char_len() - range.end) + .take(range.len()) + .step_by(step), + ); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } + } + PyKindStr::Wtf8(w) => { + let char_len = (range.len() / step) + 1; + // not ascii, so the codepoints have to be at least 2 bytes each + let mut out = Wtf8Buf::with_capacity(2 * char_len); + out.extend( + w.code_points() + .rev() + .skip(self.char_len() - range.end) + .take(range.len()) + .step_by(step), + ); + // SAFETY: char_len is accurate + unsafe { PyStr::new_with_char_len(out, char_len) } + } } } fn empty() -> Self::Sliced { - String::new() + PyStr::default() } fn len(&self) -> usize { @@ -1561,10 +1815,30 @@ impl AsRef for PyExact { } } -impl AnyStrWrapper for PyStrRef { - type Str = str; - fn as_ref(&self) -> &str { - self.as_str() +impl AnyStrWrapper for PyStrRef { + fn as_ref(&self) -> Option<&Wtf8> { + Some(self.as_wtf8()) + } + fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +impl AnyStrWrapper for PyStrRef { + fn as_ref(&self) -> Option<&str> { + self.data.as_str() + } + fn is_empty(&self) -> bool { + self.data.is_empty() + } +} + +impl AnyStrWrapper for PyStrRef { + fn as_ref(&self) -> Option<&AsciiStr> { + self.data.as_ascii() + } + fn is_empty(&self) -> bool { + self.data.is_empty() } } @@ -1582,14 +1856,22 @@ impl AnyStrContainer for String { } } +impl anystr::AnyChar for char { + fn is_lowercase(self) 
-> bool { + self.is_lowercase() + } + fn is_uppercase(self) -> bool { + self.is_uppercase() + } + fn bytes_len(self) -> usize { + self.len_utf8() + } +} + impl AnyStr for str { type Char = char; type Container = String; - fn element_bytes_len(c: char) -> usize { - c.len_utf8() - } - fn to_container(&self) -> Self::Container { self.to_owned() } @@ -1598,10 +1880,6 @@ impl AnyStr for str { self.as_bytes() } - fn chars(&self) -> impl Iterator { - str::chars(self) - } - fn elements(&self) -> impl Iterator { str::chars(self) } @@ -1675,6 +1953,232 @@ impl AnyStr for str { } } +impl AnyStrContainer for Wtf8Buf { + fn new() -> Self { + Wtf8Buf::new() + } + + fn with_capacity(capacity: usize) -> Self { + Wtf8Buf::with_capacity(capacity) + } + + fn push_str(&mut self, other: &Wtf8) { + self.push_wtf8(other) + } +} + +impl anystr::AnyChar for CodePoint { + fn is_lowercase(self) -> bool { + self.is_char_and(char::is_lowercase) + } + fn is_uppercase(self) -> bool { + self.is_char_and(char::is_uppercase) + } + fn bytes_len(self) -> usize { + self.len_wtf8() + } +} + +impl AnyStr for Wtf8 { + type Char = CodePoint; + type Container = Wtf8Buf; + + fn to_container(&self) -> Self::Container { + self.to_owned() + } + + fn as_bytes(&self) -> &[u8] { + self.as_bytes() + } + + fn elements(&self) -> impl Iterator { + self.code_points() + } + + fn get_bytes(&self, range: std::ops::Range) -> &Self { + &self[range] + } + + fn get_chars(&self, range: std::ops::Range) -> &Self { + rustpython_common::str::get_codepoints(self, range) + } + + fn bytes_len(&self) -> usize { + self.len() + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn py_split_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + // CPython split_whitespace + let mut splits = Vec::new(); + let mut last_offset = 0; + let mut count = maxsplit; + for (offset, _) in self + .code_point_indices() + .filter(|(_, c)| c.is_char_and(|c| c.is_ascii_whitespace() || c == '\x0b')) 
+ { + if last_offset == offset { + last_offset += 1; + continue; + } + if count == 0 { + break; + } + splits.push(convert(&self[last_offset..offset])); + last_offset = offset + 1; + count -= 1; + } + if last_offset != self.len() { + splits.push(convert(&self[last_offset..])); + } + splits + } + + fn py_rsplit_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + // CPython rsplit_whitespace + let mut splits = Vec::new(); + let mut last_offset = self.len(); + let mut count = maxsplit; + for (offset, _) in self + .code_point_indices() + .rev() + .filter(|(_, c)| c.is_char_and(|c| c.is_ascii_whitespace() || c == '\x0b')) + { + if last_offset == offset + 1 { + last_offset -= 1; + continue; + } + if count == 0 { + break; + } + splits.push(convert(&self[offset + 1..last_offset])); + last_offset = offset; + count -= 1; + } + if last_offset != 0 { + splits.push(convert(&self[..last_offset])); + } + splits + } +} + +impl AnyStrContainer for AsciiString { + fn new() -> Self { + AsciiString::new() + } + + fn with_capacity(capacity: usize) -> Self { + AsciiString::with_capacity(capacity) + } + + fn push_str(&mut self, other: &AsciiStr) { + AsciiString::push_str(self, other) + } +} + +impl anystr::AnyChar for ascii::AsciiChar { + fn is_lowercase(self) -> bool { + self.is_lowercase() + } + fn is_uppercase(self) -> bool { + self.is_uppercase() + } + fn bytes_len(self) -> usize { + 1 + } +} + +const ASCII_WHITESPACES: [u8; 6] = [0x20, 0x09, 0x0a, 0x0c, 0x0d, 0x0b]; + +impl AnyStr for AsciiStr { + type Char = AsciiChar; + type Container = AsciiString; + + fn to_container(&self) -> Self::Container { + self.to_ascii_string() + } + + fn as_bytes(&self) -> &[u8] { + self.as_bytes() + } + + fn elements(&self) -> impl Iterator { + self.chars() + } + + fn get_bytes(&self, range: std::ops::Range) -> &Self { + &self[range] + } + + fn get_chars(&self, range: std::ops::Range) -> &Self { + &self[range] + } + + fn bytes_len(&self) -> usize { + 
self.len() + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn py_split_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + let mut splits = Vec::new(); + let mut count = maxsplit; + let mut haystack = self; + while let Some(offset) = haystack.as_bytes().find_byteset(ASCII_WHITESPACES) { + if offset != 0 { + if count == 0 { + break; + } + splits.push(convert(&haystack[..offset])); + count -= 1; + } + haystack = &haystack[offset + 1..]; + } + if !haystack.is_empty() { + splits.push(convert(haystack)); + } + splits + } + + fn py_rsplit_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + // CPython rsplit_whitespace + let mut splits = Vec::new(); + let mut count = maxsplit; + let mut haystack = self; + while let Some(offset) = haystack.as_bytes().rfind_byteset(ASCII_WHITESPACES) { + if offset + 1 != haystack.len() { + if count == 0 { + break; + } + splits.push(convert(&haystack[offset + 1..])); + count -= 1; + } + haystack = &haystack[..offset]; + } + if !haystack.is_empty() { + splits.push(convert(haystack)); + } + splits + } +} + /// The unique reference of interned PyStr /// Always intended to be used as a static reference pub type PyStrInterned = PyInterned; @@ -1688,7 +2192,7 @@ impl PyStrInterned { impl std::fmt::Display for PyStrInterned { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - std::fmt::Display::fmt(self.as_str(), f) + self.data.fmt(f) } } @@ -1717,7 +2221,7 @@ mod tests { ("Greek ῼitlecases ...", "greek ῳitlecases ..."), ]; for (title, input) in tests { - assert_eq!(PyStr::from(input).title().as_str(), title); + assert_eq!(PyStr::from(input).title().as_str(), Ok(title)); } } diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index e8ad67a666..969d6db937 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -768,7 +768,7 @@ impl PyType { value.class().slot_name(), )) })?; - if 
name.as_str().as_bytes().contains(&0) { + if name.as_bytes().contains(&0) { return Err(vm.new_value_error("type name must not contain null characters".to_owned())); } @@ -811,7 +811,7 @@ impl Constructor for PyType { let (name, bases, dict, kwargs): (PyStrRef, PyTupleRef, PyDictRef, KwArgs) = args.clone().bind(vm)?; - if name.as_str().as_bytes().contains(&0) { + if name.as_bytes().contains(&0) { return Err(vm.new_value_error("type name must not contain null characters".to_owned())); } diff --git a/vm/src/builtins/union.rs b/vm/src/builtins/union.rs index 2798724ae5..165113e216 100644 --- a/vm/src/builtins/union.rs +++ b/vm/src/builtins/union.rs @@ -204,7 +204,7 @@ impl PyUnion { vm, )?; let mut res; - if new_args.len() == 0 { + if new_args.is_empty() { res = make_union(&new_args, vm); } else { res = new_args.fast_getitem(0); diff --git a/vm/src/bytesinner.rs b/vm/src/bytesinner.rs index bcbbd4fce6..63d5148e04 100644 --- a/vm/src/bytesinner.rs +++ b/vm/src/bytesinner.rs @@ -355,13 +355,11 @@ impl PyBytesInner { } pub fn islower(&self) -> bool { - self.elements - .py_iscase(char::is_lowercase, char::is_uppercase) + self.elements.py_islower() } pub fn isupper(&self) -> bool { - self.elements - .py_iscase(char::is_uppercase, char::is_lowercase) + self.elements.py_isupper() } pub fn isspace(&self) -> bool { @@ -654,6 +652,7 @@ impl PyBytesInner { let elements = self.elements.py_split( options, vm, + || convert(&self.elements, vm), |v, s, vm| v.split_str(s).map(|v| convert(v, vm)).collect(), |v, s, n, vm| v.splitn_str(n, s).map(|v| convert(v, vm)).collect(), |v, n, vm| v.py_split_whitespace(n, |v| convert(v, vm)), @@ -673,6 +672,7 @@ impl PyBytesInner { let mut elements = self.elements.py_split( options, vm, + || convert(&self.elements, vm), |v, s, vm| v.rsplit_str(s).map(|v| convert(v, vm)).collect(), |v, s, n, vm| v.rsplitn_str(n, s).map(|v| convert(v, vm)).collect(), |v, n, vm| v.py_rsplit_whitespace(n, |v| convert(v, vm)), @@ -998,10 +998,12 @@ pub trait ByteOr: 
ToPrimitive { impl ByteOr for BigInt {} -impl AnyStrWrapper for PyBytesInner { - type Str = [u8]; - fn as_ref(&self) -> &[u8] { - &self.elements +impl AnyStrWrapper<[u8]> for PyBytesInner { + fn as_ref(&self) -> Option<&[u8]> { + Some(&self.elements) + } + fn is_empty(&self) -> bool { + self.elements.is_empty() } } @@ -1021,14 +1023,22 @@ impl AnyStrContainer<[u8]> for Vec { const ASCII_WHITESPACES: [u8; 6] = [0x20, 0x09, 0x0a, 0x0c, 0x0d, 0x0b]; +impl anystr::AnyChar for u8 { + fn is_lowercase(self) -> bool { + self.is_ascii_lowercase() + } + fn is_uppercase(self) -> bool { + self.is_ascii_uppercase() + } + fn bytes_len(self) -> usize { + 1 + } +} + impl AnyStr for [u8] { type Char = u8; type Container = Vec; - fn element_bytes_len(_: u8) -> usize { - 1 - } - fn to_container(&self) -> Self::Container { self.to_vec() } @@ -1037,10 +1047,6 @@ impl AnyStr for [u8] { self } - fn chars(&self) -> impl Iterator { - bstr::ByteSlice::chars(self) - } - fn elements(&self) -> impl Iterator { self.iter().copied() } diff --git a/vm/src/cformat.rs b/vm/src/cformat.rs index af78fde021..93c409172c 100644 --- a/vm/src/cformat.rs +++ b/vm/src/cformat.rs @@ -2,6 +2,7 @@ //! as per the [Python Docs](https://docs.python.org/3/library/stdtypes.html#printf-style-string-formatting). 
use crate::common::cformat::*; +use crate::common::wtf8::{CodePoint, Wtf8, Wtf8Buf}; use crate::{ AsObject, PyObjectRef, PyResult, TryFromBorrowedObject, TryFromObject, VirtualMachine, builtins::{ @@ -125,13 +126,13 @@ fn spec_format_string( spec: &CFormatSpec, obj: PyObjectRef, idx: usize, -) -> PyResult { +) -> PyResult { match &spec.format_type { CFormatType::String(conversion) => { let result = match conversion { CFormatConversion::Ascii => builtins::ascii(obj, vm)?.into(), - CFormatConversion::Str => obj.str(vm)?.as_str().to_owned(), - CFormatConversion::Repr => obj.repr(vm)?.as_str().to_owned(), + CFormatConversion::Str => obj.str(vm)?.as_wtf8().to_owned(), + CFormatConversion::Repr => obj.repr(vm)?.as_wtf8().to_owned(), CFormatConversion::Bytes => { // idx is the position of the %, we want the position of the b return Err(vm.new_value_error(format!( @@ -146,16 +147,18 @@ fn spec_format_string( CNumberType::DecimalD | CNumberType::DecimalI | CNumberType::DecimalU => { match_class!(match &obj { ref i @ PyInt => { - Ok(spec.format_number(i.as_bigint())) + Ok(spec.format_number(i.as_bigint()).into()) } ref f @ PyFloat => { - Ok(spec.format_number(&try_f64_to_bigint(f.to_f64(), vm)?)) + Ok(spec + .format_number(&try_f64_to_bigint(f.to_f64(), vm)?) 
+ .into()) } obj => { if let Some(method) = vm.get_method(obj.clone(), identifier!(vm, __int__)) { let result = method?.call((), vm)?; if let Some(i) = result.payload::() { - return Ok(spec.format_number(i.as_bigint())); + return Ok(spec.format_number(i.as_bigint()).into()); } } Err(vm.new_type_error(format!( @@ -168,7 +171,7 @@ fn spec_format_string( } _ => { if let Some(i) = obj.payload::() { - Ok(spec.format_number(i.as_bigint())) + Ok(spec.format_number(i.as_bigint()).into()) } else { Err(vm.new_type_error(format!( "%{} format: an integer is required, not {}", @@ -180,21 +183,21 @@ fn spec_format_string( }, CFormatType::Float(_) => { let value = ArgIntoFloat::try_from_object(vm, obj)?; - Ok(spec.format_float(value.into())) + Ok(spec.format_float(value.into()).into()) } CFormatType::Character(CCharacterType::Character) => { if let Some(i) = obj.payload::() { let ch = i .as_bigint() .to_u32() - .and_then(char::from_u32) + .and_then(CodePoint::from_u32) .ok_or_else(|| { vm.new_overflow_error("%c arg not in range(0x110000)".to_owned()) })?; return Ok(spec.format_char(ch)); } if let Some(s) = obj.payload::() { - if let Ok(ch) = s.as_str().chars().exactly_one() { + if let Ok(ch) = s.as_wtf8().code_points().exactly_one() { return Ok(spec.format_char(ch)); } } @@ -374,17 +377,16 @@ pub(crate) fn cformat_bytes( pub(crate) fn cformat_string( vm: &VirtualMachine, - format_string: &str, + format_string: &Wtf8, values_obj: PyObjectRef, -) -> PyResult { - let format = format_string - .parse::() +) -> PyResult { + let format = CFormatWtf8::parse_from_wtf8(format_string) .map_err(|err| vm.new_value_error(err.to_string()))?; let (num_specifiers, mapping_required) = format .check_specifiers() .ok_or_else(|| specifier_error(vm))?; - let mut result = String::new(); + let mut result = Wtf8Buf::new(); let is_mapping = values_obj.class().has_attr(identifier!(vm, __getitem__)) && !values_obj.fast_isinstance(vm.ctx.types.tuple_type) @@ -399,7 +401,7 @@ pub(crate) fn cformat_string( { 
for (_, part) in format.iter() { match part { - CFormatPart::Literal(literal) => result.push_str(literal), + CFormatPart::Literal(literal) => result.push_wtf8(literal), CFormatPart::Spec(_) => unreachable!(), } } @@ -415,11 +417,11 @@ pub(crate) fn cformat_string( return if is_mapping { for (idx, part) in format { match part { - CFormatPart::Literal(literal) => result.push_str(&literal), + CFormatPart::Literal(literal) => result.push_wtf8(&literal), CFormatPart::Spec(CFormatSpecKeyed { mapping_key, spec }) => { let value = values_obj.get_item(&mapping_key.unwrap(), vm)?; let part_result = spec_format_string(vm, &spec, value, idx)?; - result.push_str(&part_result); + result.push_wtf8(&part_result); } } } @@ -439,7 +441,7 @@ pub(crate) fn cformat_string( for (idx, part) in format { match part { - CFormatPart::Literal(literal) => result.push_str(&literal), + CFormatPart::Literal(literal) => result.push_wtf8(&literal), CFormatPart::Spec(CFormatSpecKeyed { mut spec, .. }) => { try_update_quantity_from_tuple( vm, @@ -456,7 +458,7 @@ pub(crate) fn cformat_string( } }?; let part_result = spec_format_string(vm, &spec, value, idx)?; - result.push_str(&part_result); + result.push_wtf8(&part_result); } } } diff --git a/vm/src/codecs.rs b/vm/src/codecs.rs index e104097413..bdb9b4b809 100644 --- a/vm/src/codecs.rs +++ b/vm/src/codecs.rs @@ -1,3 +1,5 @@ +use rustpython_common::wtf8::{CodePoint, Wtf8Buf}; + use crate::{ AsObject, Context, PyObject, PyObjectRef, PyPayload, PyResult, TryFromObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTuple, PyTupleRef}, @@ -424,12 +426,13 @@ fn xmlcharrefreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<( } let range = extract_unicode_error_range(&err, vm)?; let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or(""); + let s_after_start = + 
crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); let num_chars = range.len(); // capacity rough guess; assuming that the codepoints are 3 digits in decimal + the &#; let mut out = String::with_capacity(num_chars * 6); - for c in s_after_start.chars().take(num_chars) { - write!(out, "&#{};", c as u32).unwrap() + for c in s_after_start.code_points().take(num_chars) { + write!(out, "&#{};", c.to_u32()).unwrap() } Ok((out, range.end)) } @@ -448,12 +451,13 @@ fn backslashreplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(S } let range = extract_unicode_error_range(&err, vm)?; let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or(""); + let s_after_start = + crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); let num_chars = range.len(); // minimum 4 output bytes per char: \xNN let mut out = String::with_capacity(num_chars * 4); - for c in s_after_start.chars().take(num_chars) { - let c = c as u32; + for c in s_after_start.code_points().take(num_chars) { + let c = c.to_u32(); if c >= 0x10000 { write!(out, "\\U{c:08x}").unwrap(); } else if c >= 0x100 { @@ -470,12 +474,12 @@ fn namereplace_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(String let range = extract_unicode_error_range(&err, vm)?; let s = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; let s_after_start = - crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or(""); + crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); let num_chars = range.len(); let mut out = String::with_capacity(num_chars * 4); - for c in s_after_start.chars().take(num_chars) { - let c_u32 = c as u32; - if let Some(c_name) = unicode_names2::name(c) { + for c in s_after_start.code_points().take(num_chars) { + let c_u32 = c.to_u32(); + if let Some(c_name) = 
unicode_names2::name(c.to_char_lossy()) { write!(out, "\\N{{{c_name}}}").unwrap(); } else if c_u32 >= 0x10000 { write!(out, "\\U{c_u32:08x}").unwrap(); @@ -570,10 +574,11 @@ fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyOb return Err(err.downcast().unwrap()); } let s_after_start = - crate::common::str::try_get_chars(s.as_str(), range.start..).unwrap_or(""); + crate::common::str::try_get_codepoints(s.as_wtf8(), range.start..).unwrap_or_default(); let num_chars = range.len(); let mut out: Vec = Vec::with_capacity(num_chars * 4); - for c in s_after_start.chars().take(num_chars).map(|x| x as u32) { + for c in s_after_start.code_points().take(num_chars) { + let c = c.to_u32(); if !(0xd800..=0xdfff).contains(&c) { // Not a surrogate, fail with original exception return Err(err.downcast().unwrap()); @@ -671,7 +676,7 @@ fn surrogatepass_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(PyOb } Ok(( - vm.new_pyobj(format!("\\x{c:x?}")), + vm.new_pyobj(CodePoint::from_u32(c).unwrap()), range.start + byte_length, )) } else { @@ -683,11 +688,11 @@ fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(Py if err.fast_isinstance(vm.ctx.exceptions.unicode_encode_error) { let range = extract_unicode_error_range(&err, vm)?; let object = PyStrRef::try_from_object(vm, err.get_attr("object", vm)?)?; - let s_after_start = - crate::common::str::try_get_chars(object.as_str(), range.start..).unwrap_or(""); + let s_after_start = crate::common::str::try_get_codepoints(object.as_wtf8(), range.start..) 
+ .unwrap_or_default(); let mut out: Vec = Vec::with_capacity(range.len()); - for ch in s_after_start.chars().take(range.len()) { - let ch = ch as u32; + for ch in s_after_start.code_points().take(range.len()) { + let ch = ch.to_u32(); if !(0xdc80..=0xdcff).contains(&ch) { // Not a UTF-8b surrogate, fail with original exception return Err(err.downcast().unwrap()); @@ -702,14 +707,14 @@ fn surrogateescape_errors(err: PyObjectRef, vm: &VirtualMachine) -> PyResult<(Py let object = PyBytesRef::try_from_object(vm, object)?; let p = &object.as_bytes()[range.clone()]; let mut consumed = 0; - let mut replace = String::with_capacity(4 * range.len()); + let mut replace = Wtf8Buf::with_capacity(4 * range.len()); while consumed < 4 && consumed < range.len() { - let c = p[consumed] as u32; + let c = p[consumed] as u16; // Refuse to escape ASCII bytes if c < 128 { break; } - write!(replace, "#{}", 0xdc00 + c).unwrap(); + replace.push(CodePoint::from(0xdc00 + c)); consumed += 1; } if consumed == 0 { diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index b36485e86a..ab37b7dc85 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -12,6 +12,7 @@ use crate::{ common::{ hash, lock::{PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard}, + wtf8::{Wtf8, Wtf8Buf}, }, object::{Traverse, TraverseFn}, }; @@ -750,7 +751,7 @@ impl DictKey for Py { if self.is(other_key) { Ok(true) } else if let Some(pystr) = other_key.payload_if_exact::(vm) { - Ok(pystr.as_str() == self.as_str()) + Ok(self.as_wtf8() == pystr.as_wtf8()) } else { vm.bool_eq(self.as_object(), other_key) } @@ -834,7 +835,7 @@ impl DictKey for str { fn key_eq(&self, vm: &VirtualMachine, other_key: &PyObject) -> PyResult { if let Some(pystr) = other_key.payload_if_exact::(vm) { - Ok(pystr.as_str() == self) + Ok(pystr.as_wtf8() == self) } else { // Fall back to PyObjectRef implementation. 
let s = vm.ctx.new_str(self); @@ -871,6 +872,63 @@ impl DictKey for String { } } +impl DictKey for Wtf8 { + type Owned = Wtf8Buf; + #[inline(always)] + fn _to_owned(&self, _vm: &VirtualMachine) -> Self::Owned { + self.to_owned() + } + #[inline] + fn key_hash(&self, vm: &VirtualMachine) -> PyResult { + // follow a similar route as the hashing of PyStrRef + Ok(vm.state.hash_secret.hash_bytes(self.as_bytes())) + } + #[inline(always)] + fn key_is(&self, _other: &PyObject) -> bool { + // No matter who the other pyobject is, we are never the same thing, since + // we are a str, not a pyobject. + false + } + + fn key_eq(&self, vm: &VirtualMachine, other_key: &PyObject) -> PyResult { + if let Some(pystr) = other_key.payload_if_exact::(vm) { + Ok(pystr.as_wtf8() == self) + } else { + // Fall back to PyObjectRef implementation. + let s = vm.ctx.new_str(self); + s.key_eq(vm, other_key) + } + } + + fn key_as_isize(&self, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("'str' object cannot be interpreted as an integer".to_owned())) + } +} + +impl DictKey for Wtf8Buf { + type Owned = Wtf8Buf; + #[inline] + fn _to_owned(&self, _vm: &VirtualMachine) -> Self::Owned { + self.clone() + } + + fn key_hash(&self, vm: &VirtualMachine) -> PyResult { + (**self).key_hash(vm) + } + + fn key_is(&self, other: &PyObject) -> bool { + (**self).key_is(other) + } + + fn key_eq(&self, vm: &VirtualMachine, other_key: &PyObject) -> PyResult { + (**self).key_eq(vm, other_key) + } + + fn key_as_isize(&self, vm: &VirtualMachine) -> PyResult { + (**self).key_as_isize(vm) + } +} + impl DictKey for [u8] { type Owned = Vec; #[inline(always)] diff --git a/vm/src/format.rs b/vm/src/format.rs index 1dbfd46779..3349ee854e 100644 --- a/vm/src/format.rs +++ b/vm/src/format.rs @@ -7,6 +7,7 @@ use crate::{ }; use crate::common::format::*; +use crate::common::wtf8::{Wtf8, Wtf8Buf}; impl IntoPyException for FormatSpecError { fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef { @@ -62,18 
+63,18 @@ fn format_internal( vm: &VirtualMachine, format: &FormatString, field_func: &mut impl FnMut(FieldType) -> PyResult, -) -> PyResult { - let mut final_string = String::new(); +) -> PyResult { + let mut final_string = Wtf8Buf::new(); for part in &format.format_parts { let pystr; - let result_string: &str = match part { + let result_string: &Wtf8 = match part { FormatPart::Field { field_name, conversion_spec, format_spec, } => { let FieldName { field_type, parts } = - FieldName::parse(field_name.as_str()).map_err(|e| e.to_pyexception(vm))?; + FieldName::parse(field_name).map_err(|e| e.to_pyexception(vm))?; let mut argument = field_func(field_type)?; @@ -113,7 +114,7 @@ fn format_internal( } FormatPart::Literal(literal) => literal, }; - final_string.push_str(result_string); + final_string.push_wtf8(result_string); } Ok(final_string) } @@ -122,7 +123,7 @@ pub(crate) fn format( format: &FormatString, arguments: &FuncArgs, vm: &VirtualMachine, -) -> PyResult { +) -> PyResult { let mut auto_argument_index: usize = 0; let mut seen_index = false; format_internal(vm, format, &mut |field_type| match field_type { @@ -154,8 +155,10 @@ pub(crate) fn format( .cloned() .ok_or_else(|| vm.new_index_error("tuple index out of range".to_owned())) } - FieldType::Keyword(keyword) => arguments - .get_optional_kwarg(&keyword) + FieldType::Keyword(keyword) => keyword + .as_str() + .ok() + .and_then(|keyword| arguments.get_optional_kwarg(keyword)) .ok_or_else(|| vm.new_key_error(vm.ctx.new_str(keyword).into())), }) } @@ -164,7 +167,7 @@ pub(crate) fn format_map( format: &FormatString, dict: &PyObject, vm: &VirtualMachine, -) -> PyResult { +) -> PyResult { format_internal(vm, format, &mut |field_type| match field_type { FieldType::Auto | FieldType::Index(_) => { Err(vm.new_value_error("Format string contains positional fields".to_owned())) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index cf695cd87b..c4dc23c95f 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -21,6 +21,7 @@ 
use crate::{ }; use indexmap::IndexMap; use itertools::Itertools; +use rustpython_common::wtf8::Wtf8Buf; #[cfg(feature = "threading")] use std::sync::atomic; use std::{fmt, iter::zip}; @@ -697,8 +698,8 @@ impl ExecutingFrame<'_> { .pop_multiple(size.get(arg) as usize) .as_slice() .iter() - .map(|pyobj| pyobj.payload::().unwrap().as_ref()) - .collect::(); + .map(|pyobj| pyobj.payload::().unwrap()) + .collect::(); let str_obj = vm.ctx.new_str(s); self.push_value(str_obj.into()); Ok(None) @@ -1468,7 +1469,7 @@ impl ExecutingFrame<'_> { let kwarg_names = kwarg_names .as_slice() .iter() - .map(|pyobj| pyobj.payload::().unwrap().as_ref().to_owned()); + .map(|pyobj| pyobj.payload::().unwrap().as_str().to_owned()); FuncArgs::with_kwargs_names(args, kwarg_names) } diff --git a/vm/src/function/buffer.rs b/vm/src/function/buffer.rs index 91379e7a7f..80b36833e5 100644 --- a/vm/src/function/buffer.rs +++ b/vm/src/function/buffer.rs @@ -152,7 +152,7 @@ impl ArgStrOrBytesLike { pub fn borrow_bytes(&self) -> BorrowedValue<'_, [u8]> { match self { Self::Buf(b) => b.borrow_buf(), - Self::Str(s) => s.as_str().as_bytes().into(), + Self::Str(s) => s.as_bytes().into(), } } } @@ -195,7 +195,7 @@ impl ArgAsciiBuffer { #[inline] pub fn with_ref(&self, f: impl FnOnce(&[u8]) -> R) -> R { match self { - Self::String(s) => f(s.as_str().as_bytes()), + Self::String(s) => f(s.as_bytes()), Self::Buffer(buffer) => buffer.with_ref(f), } } diff --git a/vm/src/function/fspath.rs b/vm/src/function/fspath.rs index 69f11eb65d..83bd452151 100644 --- a/vm/src/function/fspath.rs +++ b/vm/src/function/fspath.rs @@ -5,7 +5,7 @@ use crate::{ function::PyStr, protocol::PyBuffer, }; -use std::{ffi::OsStr, path::PathBuf}; +use std::{borrow::Cow, ffi::OsStr, path::PathBuf}; #[derive(Clone)] pub enum FsPath { @@ -26,7 +26,7 @@ impl FsPath { let match1 = |obj: PyObjectRef| { let pathlike = match_class!(match obj { s @ PyStr => { - check_nul(s.as_str().as_bytes())?; + check_nul(s.as_bytes())?; FsPath::Str(s) } b @ 
PyBytes => { @@ -58,26 +58,26 @@ impl FsPath { }) } - pub fn as_os_str(&self, vm: &VirtualMachine) -> PyResult<&OsStr> { + pub fn as_os_str(&self, vm: &VirtualMachine) -> PyResult> { // TODO: FS encodings match self { - FsPath::Str(s) => Ok(s.as_str().as_ref()), - FsPath::Bytes(b) => Self::bytes_as_osstr(b.as_bytes(), vm), + FsPath::Str(s) => vm.fsencode(s), + FsPath::Bytes(b) => Self::bytes_as_osstr(b.as_bytes(), vm).map(Cow::Borrowed), } } pub fn as_bytes(&self) -> &[u8] { // TODO: FS encodings match self { - FsPath::Str(s) => s.as_str().as_bytes(), + FsPath::Str(s) => s.as_bytes(), FsPath::Bytes(b) => b.as_bytes(), } } - pub fn as_str(&self) -> &str { + pub fn to_string_lossy(&self) -> Cow<'_, str> { match self { - FsPath::Bytes(b) => std::str::from_utf8(b).unwrap(), - FsPath::Str(s) => s.as_str(), + FsPath::Str(s) => s.to_string_lossy(), + FsPath::Bytes(s) => String::from_utf8_lossy(s), } } diff --git a/vm/src/import.rs b/vm/src/import.rs index 860f0b8a16..2d86e47d08 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -204,15 +204,15 @@ fn remove_importlib_frames_inner( // TODO: This function should do nothing on verbose mode. 
// TODO: Fix this function after making PyTraceback.next mutable -pub fn remove_importlib_frames( - vm: &VirtualMachine, - exc: &PyBaseExceptionRef, -) -> PyBaseExceptionRef { +pub fn remove_importlib_frames(vm: &VirtualMachine, exc: &PyBaseExceptionRef) { + if vm.state.settings.verbose != 0 { + return; + } + let always_trim = exc.fast_isinstance(vm.ctx.exceptions.import_error); if let Some(tb) = exc.traceback() { let trimmed_tb = remove_importlib_frames_inner(vm, Some(tb), always_trim).0; exc.set_traceback(trimmed_tb); } - exc.clone() } diff --git a/vm/src/ospath.rs b/vm/src/ospath.rs index 9dda60d621..c1b1859164 100644 --- a/vm/src/ospath.rs +++ b/vm/src/ospath.rs @@ -21,28 +21,14 @@ pub(super) enum OutputMode { } impl OutputMode { - pub(super) fn process_path(self, path: impl Into, vm: &VirtualMachine) -> PyResult { - fn inner(mode: OutputMode, path: PathBuf, vm: &VirtualMachine) -> PyResult { - let path_as_string = |p: PathBuf| { - p.into_os_string().into_string().map_err(|_| { - vm.new_unicode_decode_error( - "Can't convert OS path to valid UTF-8 string".into(), - ) - }) - }; + pub(super) fn process_path(self, path: impl Into, vm: &VirtualMachine) -> PyObjectRef { + fn inner(mode: OutputMode, path: PathBuf, vm: &VirtualMachine) -> PyObjectRef { match mode { - OutputMode::String => path_as_string(path).map(|s| vm.ctx.new_str(s).into()), - OutputMode::Bytes => { - #[cfg(any(unix, target_os = "wasi"))] - { - use rustpython_common::os::ffi::OsStringExt; - Ok(vm.ctx.new_bytes(path.into_os_string().into_vec()).into()) - } - #[cfg(windows)] - { - path_as_string(path).map(|s| vm.ctx.new_bytes(s.into_bytes()).into()) - } - } + OutputMode::String => vm.fsdecode(path).into(), + OutputMode::Bytes => vm + .ctx + .new_bytes(path.into_os_string().into_encoded_bytes()) + .into(), } } inner(self, path.into(), vm) @@ -59,7 +45,7 @@ impl OsPath { } pub(crate) fn from_fspath(fspath: FsPath, vm: &VirtualMachine) -> PyResult { - let path = fspath.as_os_str(vm)?.to_owned(); + let 
path = fspath.as_os_str(vm)?.into_owned(); let mode = match fspath { FsPath::Str(_) => OutputMode::String, FsPath::Bytes(_) => OutputMode::Bytes, @@ -88,7 +74,7 @@ impl OsPath { widestring::WideCString::from_os_str(&self.path).map_err(|err| err.to_pyexception(vm)) } - pub fn filename(&self, vm: &VirtualMachine) -> PyResult { + pub fn filename(&self, vm: &VirtualMachine) -> PyObjectRef { self.mode.process_path(self.path.clone(), vm) } } @@ -133,7 +119,7 @@ impl From for OsPathOrFd { impl OsPathOrFd { pub fn filename(&self, vm: &VirtualMachine) -> PyObjectRef { match self { - OsPathOrFd::Path(path) => path.filename(vm).unwrap_or_else(|_| vm.ctx.none()), + OsPathOrFd::Path(path) => path.filename(vm), OsPathOrFd::Fd(fd) => vm.ctx.new_int(*fd).into(), } } diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index 9c2826a1e9..f778d02ba3 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -31,7 +31,9 @@ mod builtins { stdlib::sys, types::PyComparisonOp, }; + use itertools::Itertools; use num_traits::{Signed, ToPrimitive}; + use rustpython_common::wtf8::CodePoint; #[cfg(not(feature = "rustpython-compiler"))] const CODEGEN_NOT_SUPPORTED: &str = @@ -85,13 +87,13 @@ mod builtins { } #[pyfunction] - fn chr(i: PyIntRef, vm: &VirtualMachine) -> PyResult { + fn chr(i: PyIntRef, vm: &VirtualMachine) -> PyResult { let value = i .try_to_primitive::(vm)? 
.to_u32() - .and_then(char::from_u32) + .and_then(CodePoint::from_u32) .ok_or_else(|| vm.new_value_error("chr() arg not in range(0x110000)".to_owned()))?; - Ok(value.to_string()) + Ok(value) } #[derive(FromArgs)] @@ -153,7 +155,7 @@ mod builtins { return ast::compile( vm, args.source, - args.filename.as_str(), + &args.filename.to_string_lossy(), mode, Some(optimize), ); @@ -202,7 +204,7 @@ mod builtins { .compile_with_opts( source, mode, - args.filename.as_str().to_owned(), + args.filename.to_string_lossy().into_owned(), opts, ) .map_err(|err| (err, Some(source)).to_pyexception(vm))?; @@ -618,21 +620,15 @@ mod builtins { } Ok(u32::from(bytes[0])) }), - Either::B(string) => { - let string = string.as_str(); - let string_len = string.chars().count(); - if string_len != 1 { - return Err(vm.new_type_error(format!( + Either::B(string) => match string.as_wtf8().code_points().exactly_one() { + Ok(character) => Ok(character.to_u32()), + Err(_) => { + let string_len = string.char_len(); + Err(vm.new_type_error(format!( "ord() expected a character, but string of length {string_len} found" - ))); + ))) } - match string.chars().next() { - Some(character) => Ok(character as u32), - None => Err(vm.new_type_error( - "ord() could not guess the integer representing this character".to_owned(), - )), - } - } + }, } } diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index 976545f64b..664fe00616 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -3,6 +3,8 @@ pub(crate) use _codecs::make_module; #[pymodule] mod _codecs { use crate::common::encodings; + use crate::common::str::StrKind; + use crate::common::wtf8::{Wtf8, Wtf8Buf}; use crate::{ AsObject, PyObject, PyObjectRef, PyResult, TryFromBorrowedObject, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTuple}, @@ -103,8 +105,8 @@ mod _codecs { } } impl encodings::StrBuffer for PyStrRef { - fn is_ascii(&self) -> bool { - PyStr::is_ascii(self) + fn 
is_compatible_with(&self, kind: StrKind) -> bool { + self.kind() <= kind } } impl encodings::ErrorHandler for ErrorsHandler<'_> { @@ -114,7 +116,7 @@ mod _codecs { fn handle_encode_error( &self, - data: &str, + data: &Wtf8, char_range: Range, reason: &str, ) -> PyResult<(encodings::EncodeReplace, usize)> { @@ -217,7 +219,7 @@ mod _codecs { fn error_encoding( &self, - data: &str, + data: &Wtf8, char_range: Range, reason: &str, ) -> Self::Error { @@ -249,15 +251,15 @@ mod _codecs { #[inline] fn encode<'a, F>(self, name: &'a str, encode: F, vm: &'a VirtualMachine) -> EncodeResult where - F: FnOnce(&str, &ErrorsHandler<'a>) -> PyResult>, + F: FnOnce(&Wtf8, &ErrorsHandler<'a>) -> PyResult>, { let errors = ErrorsHandler::new(name, self.errors, vm); - let encoded = encode(self.s.as_str(), &errors)?; + let encoded = encode(self.s.as_wtf8(), &errors)?; Ok((encoded, self.s.char_len())) } } - type DecodeResult = PyResult<(String, usize)>; + type DecodeResult = PyResult<(Wtf8Buf, usize)>; #[derive(FromArgs)] struct DecodeArgs { @@ -310,6 +312,14 @@ mod _codecs { #[pyfunction] fn utf_8_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { + if args.s.is_utf8() + || args + .errors + .as_ref() + .is_some_and(|s| s.is(identifier!(vm, surrogatepass))) + { + return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); + } do_codec!(utf8::encode, args, vm) } @@ -321,7 +331,7 @@ mod _codecs { #[pyfunction] fn latin_1_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { if args.s.is_ascii() { - return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len())); + return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); } do_codec!(latin_1::encode, args, vm) } @@ -334,7 +344,7 @@ mod _codecs { #[pyfunction] fn ascii_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { if args.s.is_ascii() { - return Ok((args.s.as_str().as_bytes().to_vec(), args.s.byte_len())); + return Ok((args.s.as_bytes().to_vec(), args.s.byte_len())); } do_codec!(ascii::encode, 
args, vm) } diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index 0f472fd940..0b680251fa 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -129,6 +129,7 @@ mod _io { PyMappedThreadMutexGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard, PyThreadMutex, PyThreadMutexGuard, }, + common::wtf8::{Wtf8, Wtf8Buf}, convert::ToPyObject, function::{ ArgBytesLike, ArgIterable, ArgMemoryBuffer, ArgSize, Either, FuncArgs, IntoFuncArgs, @@ -1909,10 +1910,12 @@ mod _io { impl Newlines { /// returns position where the new line starts if found, otherwise position at which to /// continue the search after more is read into the buffer - fn find_newline(&self, s: &str) -> Result { + fn find_newline(&self, s: &Wtf8) -> Result { let len = s.len(); match self { - Newlines::Universal | Newlines::Lf => s.find('\n').map(|p| p + 1).ok_or(len), + Newlines::Universal | Newlines::Lf => { + s.find("\n".as_ref()).map(|p| p + 1).ok_or(len) + } Newlines::Passthrough => { let bytes = s.as_bytes(); memchr::memchr2(b'\n', b'\r', bytes) @@ -1927,7 +1930,7 @@ mod _io { }) .ok_or(len) } - Newlines::Cr => s.find('\n').map(|p| p + 1).ok_or(len), + Newlines::Cr => s.find("\n".as_ref()).map(|p| p + 1).ok_or(len), Newlines::Crlf => { // s[searched..] 
== remaining let mut searched = 0; @@ -1992,10 +1995,10 @@ mod _io { } } - fn len_str(s: &str) -> Self { + fn len_str(s: &Wtf8) -> Self { Utf8size { bytes: s.len(), - chars: s.chars().count(), + chars: s.code_points().count(), } } } @@ -2082,7 +2085,7 @@ mod _io { impl PendingWrite { fn as_bytes(&self) -> &[u8] { match self { - Self::Utf8(s) => s.as_str().as_bytes(), + Self::Utf8(s) => s.as_bytes(), Self::Bytes(b) => b.as_bytes(), } } @@ -2222,8 +2225,8 @@ mod _io { *data = None; let encoding = match args.encoding { - None if vm.state.settings.utf8_mode > 0 => PyStr::from("utf-8").into_ref(&vm.ctx), - Some(enc) if enc.as_str() != "locale" => enc, + None if vm.state.settings.utf8_mode > 0 => identifier!(vm, utf_8).to_owned(), + Some(enc) if enc.as_wtf8() != "locale" => enc, _ => { // None without utf8_mode or "locale" encoding vm.import("locale", 0)? @@ -2235,7 +2238,7 @@ mod _io { let errors = args .errors - .unwrap_or_else(|| PyStr::from("strict").into_ref(&vm.ctx)); + .unwrap_or_else(|| identifier!(vm, strict).to_owned()); let has_read1 = vm.get_attribute_opt(buffer.clone(), "read1")?.is_some(); let seekable = vm.call_method(&buffer, "seekable", ())?.try_to_bool(vm)?; @@ -2533,9 +2536,10 @@ mod _io { *snapshot = Some((cookie.dec_flags, input_chunk.clone())); let decoded = vm.call_method(decoder, "decode", (input_chunk, cookie.need_eof))?; let decoded = check_decoded(decoded, vm)?; - let pos_is_valid = decoded - .as_str() - .is_char_boundary(cookie.bytes_to_skip as usize); + let pos_is_valid = crate::common::wtf8::is_code_point_boundary( + decoded.as_wtf8(), + cookie.bytes_to_skip as usize, + ); textio.set_decoded_chars(Some(decoded)); if !pos_is_valid { return Err(vm.new_os_error("can't restore logical file position".to_owned())); @@ -2714,9 +2718,9 @@ mod _io { } else if chunks.len() == 1 { chunks.pop().unwrap() } else { - let mut ret = String::with_capacity(chunks_bytes); + let mut ret = Wtf8Buf::with_capacity(chunks_bytes); for chunk in chunks { - 
ret.push_str(chunk.as_str()) + ret.push_wtf8(chunk.as_wtf8()) } PyStr::from(ret).into_ref(&vm.ctx) } @@ -2743,7 +2747,7 @@ mod _io { let char_len = obj.char_len(); - let data = obj.as_str(); + let data = obj.as_wtf8(); let replace_nl = match textio.newline { Newlines::Lf => Some("\n"), @@ -2752,11 +2756,12 @@ mod _io { Newlines::Universal if cfg!(windows) => Some("\r\n"), _ => None, }; - let has_lf = (replace_nl.is_some() || textio.line_buffering) && data.contains('\n'); - let flush = textio.line_buffering && (has_lf || data.contains('\r')); + let has_lf = (replace_nl.is_some() || textio.line_buffering) + && data.contains_code_point('\n'.into()); + let flush = textio.line_buffering && (has_lf || data.contains_code_point('\r'.into())); let chunk = if let Some(replace_nl) = replace_nl { if has_lf { - PyStr::from(data.replace('\n', replace_nl)).into_ref(&vm.ctx) + PyStr::from(data.replace("\n".as_ref(), replace_nl.as_ref())).into_ref(&vm.ctx) } else { obj } @@ -2833,7 +2838,7 @@ mod _io { if self.is_full_slice() { self.0.char_len() } else { - self.slice().chars().count() + self.slice().code_points().count() } } #[inline] @@ -2841,8 +2846,8 @@ mod _io { self.1.len() >= self.0.byte_len() } #[inline] - fn slice(&self) -> &str { - &self.0.as_str()[self.1.clone()] + fn slice(&self) -> &Wtf8 { + &self.0.as_wtf8()[self.1.clone()] } #[inline] fn slice_pystr(self, vm: &VirtualMachine) -> PyStrRef { @@ -2893,7 +2898,7 @@ mod _io { Some(remaining) => { assert_eq!(textio.decoded_chars_used.bytes, 0); offset_to_buffer = remaining.utf8_len(); - let decoded_chars = decoded_chars.as_str(); + let decoded_chars = decoded_chars.as_wtf8(); let line = if remaining.is_full_slice() { let mut line = remaining.0; line.concat_in_place(decoded_chars, vm); @@ -2901,16 +2906,16 @@ mod _io { } else { let remaining = remaining.slice(); let mut s = - String::with_capacity(remaining.len() + decoded_chars.len()); - s.push_str(remaining); - s.push_str(decoded_chars); + 
Wtf8Buf::with_capacity(remaining.len() + decoded_chars.len()); + s.push_wtf8(remaining); + s.push_wtf8(decoded_chars); PyStr::from(s).into_ref(&vm.ctx) }; start = Utf8size::default(); line } }; - let line_from_start = &line.as_str()[start.bytes..]; + let line_from_start = &line.as_wtf8()[start.bytes..]; let nl_res = textio.newline.find_newline(line_from_start); match nl_res { Ok(p) | Err(p) => { @@ -2921,7 +2926,7 @@ mod _io { endpos = start + Utf8size { chars: limit - chunked.chars, - bytes: crate::common::str::char_range_end( + bytes: crate::common::str::codepoint_range_end( line_from_start, limit - chunked.chars, ) @@ -2962,9 +2967,9 @@ mod _io { chunked += cur_line.byte_len(); chunks.push(cur_line); } - let mut s = String::with_capacity(chunked); + let mut s = Wtf8Buf::with_capacity(chunked); for chunk in chunks { - s.push_str(chunk.slice()) + s.push_wtf8(chunk.slice()) } PyStr::from(s).into_ref(&vm.ctx) } else if let Some(cur_line) = cur_line { @@ -3099,7 +3104,7 @@ mod _io { return None; } let decoded_chars = self.decoded_chars.as_ref()?; - let avail = &decoded_chars.as_str()[self.decoded_chars_used.bytes..]; + let avail = &decoded_chars.as_wtf8()[self.decoded_chars_used.bytes..]; if avail.is_empty() { return None; } @@ -3111,7 +3116,7 @@ mod _io { (PyStr::from(avail).into_ref(&vm.ctx), avail_chars) } } else { - let s = crate::common::str::get_chars(avail, 0..n); + let s = crate::common::str::get_codepoints(avail, 0..n); (PyStr::from(s).into_ref(&vm.ctx), n) }; self.decoded_chars_used += Utf8size { @@ -3141,11 +3146,11 @@ mod _io { return decoded_chars; } // TODO: in-place editing of `str` when refcount == 1 - let decoded_chars_unused = &decoded_chars.as_str()[chars_pos..]; - let mut s = String::with_capacity(decoded_chars_unused.len() + append_len); - s.push_str(decoded_chars_unused); + let decoded_chars_unused = &decoded_chars.as_wtf8()[chars_pos..]; + let mut s = Wtf8Buf::with_capacity(decoded_chars_unused.len() + append_len); + 
s.push_wtf8(decoded_chars_unused); if let Some(append) = append { - s.push_str(append.as_str()) + s.push_wtf8(append.as_wtf8()) } PyStr::from(s).into_ref(&vm.ctx) } @@ -3305,14 +3310,14 @@ mod _io { }; let orig_output: PyStrRef = output.try_into_value(vm)?; // this being Cow::Owned means we need to allocate a new string - let mut output = Cow::Borrowed(orig_output.as_str()); + let mut output = Cow::Borrowed(orig_output.as_wtf8()); if self.pendingcr && (final_ || !output.is_empty()) { - output = ["\r", &*output].concat().into(); + output.to_mut().insert(0, '\r'.into()); self.pendingcr = false; } if !final_ { - if let Some(s) = output.strip_suffix('\r') { - output = s.to_owned().into(); + if let Some(s) = output.strip_suffix("\r".as_ref()) { + output = Cow::Owned(s.to_owned()); self.pendingcr = true; } } @@ -3321,19 +3326,21 @@ mod _io { return Ok(vm.ctx.empty_str.to_owned()); } - if (self.seennl == SeenNewline::LF || self.seennl.is_empty()) && !output.contains('\r') + if (self.seennl == SeenNewline::LF || self.seennl.is_empty()) + && !output.contains_code_point('\r'.into()) { - if self.seennl.is_empty() && output.contains('\n') { + if self.seennl.is_empty() && output.contains_code_point('\n'.into()) { self.seennl.insert(SeenNewline::LF); } } else if !self.translate { - let mut matches = output.match_indices(['\r', '\n']); + let output = output.as_bytes(); + let mut matches = memchr::memchr2_iter(b'\r', b'\n', output); while !self.seennl.is_all() { - let Some((i, c)) = matches.next() else { break }; - match c { - "\n" => self.seennl.insert(SeenNewline::LF), + let Some(i) = matches.next() else { break }; + match output[i] { + b'\n' => self.seennl.insert(SeenNewline::LF), // if c isn't \n, it can only be \r - _ if output[i + 1..].starts_with('\n') => { + _ if output.get(i + 1) == Some(&b'\n') => { matches.next(); self.seennl.insert(SeenNewline::CRLF); } @@ -3341,30 +3348,31 @@ mod _io { } } } else { - let mut chunks = output.match_indices(['\r', '\n']); - let mut 
new_string = String::with_capacity(output.len()); + let bytes = output.as_bytes(); + let mut matches = memchr::memchr2_iter(b'\r', b'\n', bytes); + let mut new_string = Wtf8Buf::with_capacity(output.len()); let mut last_modification_index = 0; - while let Some((cr_index, chunk)) = chunks.next() { - if chunk == "\r" { + while let Some(cr_index) = matches.next() { + if bytes[cr_index] == b'\r' { // skip copying the CR let mut next_chunk_index = cr_index + 1; - if output[cr_index + 1..].starts_with('\n') { - chunks.next(); + if bytes.get(cr_index + 1) == Some(&b'\n') { + matches.next(); self.seennl.insert(SeenNewline::CRLF); // skip the LF too next_chunk_index += 1; } else { self.seennl.insert(SeenNewline::CR); } - new_string.push_str(&output[last_modification_index..cr_index]); - new_string.push('\n'); + new_string.push_wtf8(&output[last_modification_index..cr_index]); + new_string.push_char('\n'); last_modification_index = next_chunk_index; } else { self.seennl.insert(SeenNewline::LF); } } - new_string.push_str(&output[last_modification_index..]); - output = new_string.into(); + new_string.push_wtf8(&output[last_modification_index..]); + output = Cow::Owned(new_string); } Ok(match output { @@ -3404,7 +3412,7 @@ mod _io { ) -> PyResult { let raw_bytes = object .flatten() - .map_or_else(Vec::new, |v| v.as_str().as_bytes().to_vec()); + .map_or_else(Vec::new, |v| v.as_bytes().to_vec()); StringIO { buffer: PyRwLock::new(BufferedIO::new(Cursor::new(raw_bytes))), @@ -3453,7 +3461,7 @@ mod _io { // write string to underlying vector #[pymethod] fn write(&self, data: PyStrRef, vm: &VirtualMachine) -> PyResult { - let bytes = data.as_str().as_bytes(); + let bytes = data.as_bytes(); self.buffer(vm)? 
.write(bytes) .ok_or_else(|| vm.new_type_error("Error Writing String".to_owned())) diff --git a/vm/src/stdlib/nt.rs b/vm/src/stdlib/nt.rs index 48f1ab668b..624577b5ce 100644 --- a/vm/src/stdlib/nt.rs +++ b/vm/src/stdlib/nt.rs @@ -249,7 +249,7 @@ pub(crate) mod module { .as_ref() .canonicalize() .map_err(|e| e.to_pyexception(vm))?; - path.mode.process_path(real, vm) + Ok(path.mode.process_path(real, vm)) } #[pyfunction] @@ -282,7 +282,7 @@ pub(crate) mod module { } } let buffer = widestring::WideCString::from_vec_truncate(buffer); - path.mode.process_path(buffer.to_os_string(), vm) + Ok(path.mode.process_path(buffer.to_os_string(), vm)) } #[pyfunction] @@ -297,7 +297,7 @@ pub(crate) mod module { return Err(errno_err(vm)); } let buffer = widestring::WideCString::from_vec_truncate(buffer); - path.mode.process_path(buffer.to_os_string(), vm) + Ok(path.mode.process_path(buffer.to_os_string(), vm)) } #[pyfunction] diff --git a/vm/src/stdlib/operator.rs b/vm/src/stdlib/operator.rs index d1a4b376e8..d8ff1715fa 100644 --- a/vm/src/stdlib/operator.rs +++ b/vm/src/stdlib/operator.rs @@ -328,7 +328,7 @@ mod _operator { "comparing strings with non-ASCII characters is not supported".to_owned(), )); } - cmp::timing_safe_cmp(a.as_str().as_bytes(), b.as_str().as_bytes()) + cmp::timing_safe_cmp(a.as_bytes(), b.as_bytes()) } (Either::B(a), Either::B(b)) => { a.with_ref(|a| b.with_ref(|b| cmp::timing_safe_cmp(a, b))) diff --git a/vm/src/stdlib/os.rs b/vm/src/stdlib/os.rs index 39701cb3a3..e1a5825b82 100644 --- a/vm/src/stdlib/os.rs +++ b/vm/src/stdlib/os.rs @@ -332,7 +332,7 @@ pub(super) mod _os { }; dir_iter .map(|entry| match entry { - Ok(entry_path) => path.mode.process_path(entry_path.file_name(), vm), + Ok(entry_path) => Ok(path.mode.process_path(entry_path.file_name(), vm)), Err(err) => Err(IOErrorBuilder::with_filename(&err, path.clone(), vm)), }) .collect::>()? 
@@ -352,22 +352,18 @@ pub(super) mod _os { let mut dir = nix::dir::Dir::from_fd(new_fd).map_err(|e| e.into_pyexception(vm))?; dir.iter() - .filter_map(|entry| { - entry - .map_err(|e| e.into_pyexception(vm)) - .and_then(|entry| { - let fname = entry.file_name().to_bytes(); - Ok(match fname { - b"." | b".." => None, - _ => Some( - OutputMode::String - .process_path(ffi::OsStr::from_bytes(fname), vm)?, - ), - }) - }) - .transpose() + .filter_map_ok(|entry| { + let fname = entry.file_name().to_bytes(); + match fname { + b"." | b".." => None, + _ => Some( + OutputMode::String + .process_path(ffi::OsStr::from_bytes(fname), vm), + ), + } }) - .collect::>()? + .collect::>() + .map_err(|e| e.into_pyexception(vm))? } } }; @@ -376,7 +372,7 @@ pub(super) mod _os { fn env_bytes_as_bytes(obj: &Either) -> &[u8] { match obj { - Either::A(s) => s.as_str().as_bytes(), + Either::A(s) => s.as_bytes(), Either::B(b) => b.as_bytes(), } } @@ -429,7 +425,7 @@ pub(super) mod _os { let [] = dir_fd.0; let path = fs::read_link(&path).map_err(|err| IOErrorBuilder::with_filename(&err, path, vm))?; - mode.process_path(path, vm) + Ok(mode.process_path(path, vm)) } #[pyattr] @@ -452,12 +448,12 @@ pub(super) mod _os { impl DirEntry { #[pygetset] fn name(&self, vm: &VirtualMachine) -> PyResult { - self.mode.process_path(&self.file_name, vm) + Ok(self.mode.process_path(&self.file_name, vm)) } #[pygetset] fn path(&self, vm: &VirtualMachine) -> PyResult { - self.mode.process_path(&self.pathval, vm) + Ok(self.mode.process_path(&self.pathval, vm)) } fn perform_on_metadata( @@ -908,12 +904,12 @@ pub(super) mod _os { #[pyfunction] fn getcwd(vm: &VirtualMachine) -> PyResult { - OutputMode::String.process_path(curdir_inner(vm)?, vm) + Ok(OutputMode::String.process_path(curdir_inner(vm)?, vm)) } #[pyfunction] fn getcwdb(vm: &VirtualMachine) -> PyResult { - OutputMode::Bytes.process_path(curdir_inner(vm)?, vm) + Ok(OutputMode::Bytes.process_path(curdir_inner(vm)?, vm)) } #[pyfunction] diff --git 
a/vm/src/stdlib/sre.rs b/vm/src/stdlib/sre.rs index 193976a62d..038ac9934a 100644 --- a/vm/src/stdlib/sre.rs +++ b/vm/src/stdlib/sre.rs @@ -9,6 +9,7 @@ mod _sre { PyCallableIterator, PyDictRef, PyGenericAlias, PyInt, PyList, PyListRef, PyStr, PyStrRef, PyTuple, PyTupleRef, PyTypeRef, }, + common::wtf8::{Wtf8, Wtf8Buf}, common::{ascii, hash::PyHash}, convert::ToPyObject, function::{ArgCallable, OptionalArg, PosArgs, PyComparisonValue}, @@ -66,10 +67,15 @@ mod _sre { } } - impl SreStr for &str { + impl SreStr for &Wtf8 { fn slice(&self, start: usize, end: usize, vm: &VirtualMachine) -> PyObjectRef { vm.ctx - .new_str(self.chars().take(end).skip(start).collect::()) + .new_str( + self.code_points() + .take(end) + .skip(start) + .collect::(), + ) .into() } } @@ -206,12 +212,12 @@ mod _sre { impl Pattern { fn with_str(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult where - F: FnOnce(&str) -> PyResult, + F: FnOnce(&Wtf8) -> PyResult, { let string = string.payload::().ok_or_else(|| { vm.new_type_error(format!("expected string got '{}'", string.class())) })?; - f(string.as_str()) + f(string.as_wtf8()) } fn with_bytes(string: &PyObject, vm: &VirtualMachine, f: F) -> PyResult @@ -425,7 +431,7 @@ mod _sre { let is_template = if zelf.isbytes { Self::with_bytes(&repl, vm, |x| Ok(x.contains(&b'\\')))? } else { - Self::with_str(&repl, vm, |x| Ok(x.contains('\\')))? + Self::with_str(&repl, vm, |x| Ok(x.contains("\\".as_ref())))? 
}; if is_template { diff --git a/vm/src/stdlib/string.rs b/vm/src/stdlib/string.rs index 3c399e1c37..576cae6277 100644 --- a/vm/src/stdlib/string.rs +++ b/vm/src/stdlib/string.rs @@ -9,6 +9,7 @@ mod _string { use crate::common::format::{ FieldName, FieldNamePart, FieldType, FormatPart, FormatString, FromTemplate, }; + use crate::common::wtf8::{CodePoint, Wtf8Buf}; use crate::{ PyObjectRef, PyResult, VirtualMachine, builtins::{PyList, PyStrRef}, @@ -18,10 +19,10 @@ mod _string { use std::mem; fn create_format_part( - literal: String, - field_name: Option, - format_spec: Option, - conversion_spec: Option, + literal: Wtf8Buf, + field_name: Option, + format_spec: Option, + conversion_spec: Option, vm: &VirtualMachine, ) -> PyObjectRef { let tuple = ( @@ -36,10 +37,10 @@ mod _string { #[pyfunction] fn formatter_parser(text: PyStrRef, vm: &VirtualMachine) -> PyResult { let format_string = - FormatString::from_str(text.as_str()).map_err(|e| e.to_pyexception(vm))?; + FormatString::from_str(text.as_wtf8()).map_err(|e| e.to_pyexception(vm))?; - let mut result = Vec::new(); - let mut literal = String::new(); + let mut result: Vec = Vec::new(); + let mut literal = Wtf8Buf::new(); for part in format_string.format_parts { match part { FormatPart::Field { @@ -55,7 +56,7 @@ mod _string { vm, )); } - FormatPart::Literal(text) => literal.push_str(&text), + FormatPart::Literal(text) => literal.push_wtf8(&text), } } if !literal.is_empty() { @@ -75,7 +76,7 @@ mod _string { text: PyStrRef, vm: &VirtualMachine, ) -> PyResult<(PyObjectRef, PyList)> { - let field_name = FieldName::parse(text.as_str()).map_err(|e| e.to_pyexception(vm))?; + let field_name = FieldName::parse(text.as_wtf8()).map_err(|e| e.to_pyexception(vm))?; let first = match field_name.field_type { FieldType::Auto => vm.ctx.new_str(ascii!("")).into(), diff --git a/vm/src/stdlib/sys.rs b/vm/src/stdlib/sys.rs index 39c803a01b..fdfe2faf69 100644 --- a/vm/src/stdlib/sys.rs +++ b/vm/src/stdlib/sys.rs @@ -458,21 +458,13 @@ mod 
sys { } #[pyfunction] - fn getfilesystemencoding(_vm: &VirtualMachine) -> String { - // TODO: implement non-utf-8 mode. - "utf-8".to_owned() + fn getfilesystemencoding(vm: &VirtualMachine) -> PyStrRef { + vm.fs_encoding().to_owned() } - #[cfg(not(windows))] #[pyfunction] - fn getfilesystemencodeerrors(_vm: &VirtualMachine) -> String { - "surrogateescape".to_owned() - } - - #[cfg(windows)] - #[pyfunction] - fn getfilesystemencodeerrors(_vm: &VirtualMachine) -> String { - "surrogatepass".to_owned() + fn getfilesystemencodeerrors(vm: &VirtualMachine) -> PyStrRef { + vm.fs_encode_errors().to_owned() } #[pyfunction] diff --git a/vm/src/vm/context.rs b/vm/src/vm/context.rs index 54605704a5..a61484e6bc 100644 --- a/vm/src/vm/context.rs +++ b/vm/src/vm/context.rs @@ -51,7 +51,7 @@ pub struct Context { } macro_rules! declare_const_name { - ($($name:ident,)*) => { + ($($name:ident$(: $s:literal)?,)*) => { #[derive(Debug, Clone, Copy)] #[allow(non_snake_case)] pub struct ConstName { @@ -61,11 +61,13 @@ macro_rules! declare_const_name { impl ConstName { unsafe fn new(pool: &StringPool, typ: &PyTypeRef) -> Self { Self { - $($name: unsafe { pool.intern(stringify!($name), typ.clone()) },)* + $($name: unsafe { pool.intern(declare_const_name!(@string $name $($s)?), typ.clone()) },)* } } } - } + }; + (@string $name:ident) => { stringify!($name) }; + (@string $name:ident $string:literal) => { $string }; } declare_const_name! { @@ -236,6 +238,15 @@ declare_const_name! 
{ flush, close, WarningMessage, + strict, + ignore, + replace, + xmlcharrefreplace, + backslashreplace, + namereplace, + surrogatepass, + surrogateescape, + utf_8: "utf-8", } // Basic objects: diff --git a/vm/src/vm/interpreter.rs b/vm/src/vm/interpreter.rs index a375dbedc1..cc669e0661 100644 --- a/vm/src/vm/interpreter.rs +++ b/vm/src/vm/interpreter.rs @@ -163,7 +163,7 @@ mod tests { let b = vm.new_pyobj(4_i32); let res = vm._mul(&a, &b).unwrap(); let value = res.payload::().unwrap(); - assert_eq!(value.as_ref(), "Hello Hello Hello Hello ") + assert_eq!(value.as_str(), "Hello Hello Hello Hello ") }) } } diff --git a/vm/src/vm/mod.rs b/vm/src/vm/mod.rs index dd647e36f8..493789f510 100644 --- a/vm/src/vm/mod.rs +++ b/vm/src/vm/mod.rs @@ -41,11 +41,12 @@ use nix::{ sys::signal::{SaFlags, SigAction, SigSet, Signal::SIGINT, kill, sigaction}, unistd::getpid, }; -use std::sync::atomic::AtomicBool; use std::{ borrow::Cow, cell::{Cell, Ref, RefCell}, collections::{HashMap, HashSet}, + ffi::{OsStr, OsString}, + sync::atomic::AtomicBool, }; pub use context::Context; @@ -610,7 +611,7 @@ impl VirtualMachine { let from_list = from_list.to_pyobject(self); import_func .call((module.to_owned(), globals, locals, from_list, level), self) - .map_err(|exc| import::remove_importlib_frames(self, &exc)) + .inspect_err(|exc| import::remove_importlib_frames(self, exc)) } } } @@ -901,6 +902,54 @@ impl VirtualMachine { run_module_as_main.call((module,), self)?; Ok(()) } + + pub fn fs_encoding(&self) -> &'static PyStrInterned { + identifier!(self, utf_8) + } + + pub fn fs_encode_errors(&self) -> &'static PyStrInterned { + if cfg!(windows) { + identifier!(self, surrogatepass) + } else { + identifier!(self, surrogateescape) + } + } + + pub fn fsdecode(&self, s: impl Into) -> PyStrRef { + match s.into().into_string() { + Ok(s) => self.ctx.new_str(s), + Err(s) => { + let bytes = self.ctx.new_bytes(s.into_encoded_bytes()); + let errors = self.fs_encode_errors().to_owned(); + let res = 
self.state.codec_registry.decode_text( + bytes.into(), + "utf-8", + Some(errors), + self, + ); + self.expect_pyresult(res, "fsdecode should be lossless and never fail") + } + } + } + + pub fn fsencode<'a>(&self, s: &'a Py) -> PyResult> { + if cfg!(windows) || s.is_utf8() { + // XXX: this is sketchy on windows; it's not guaranteed that the + // OsStr encoding will always be compatible with WTF-8. + let s = unsafe { OsStr::from_encoded_bytes_unchecked(s.as_bytes()) }; + return Ok(Cow::Borrowed(s)); + } + let errors = self.fs_encode_errors().to_owned(); + let bytes = self + .state + .codec_registry + .encode_text(s.to_owned(), "utf-8", Some(errors), self)? + .to_vec(); + // XXX: this is sketchy on windows; it's not guaranteed that the + // OsStr encoding will always be compatible with WTF-8. + let s = unsafe { OsString::from_encoded_bytes_unchecked(bytes) }; + Ok(Cow::Owned(s)) + } } impl AsRef for VirtualMachine { diff --git a/vm/sre_engine/Cargo.toml b/vm/sre_engine/Cargo.toml index 504652f3a7..b34b01a0e8 100644 --- a/vm/sre_engine/Cargo.toml +++ b/vm/sre_engine/Cargo.toml @@ -15,6 +15,7 @@ name = "benches" harness = false [dependencies] +rustpython-common = { workspace = true } num_enum = { workspace = true } bitflags = { workspace = true } optional = "0.5" diff --git a/vm/sre_engine/src/string.rs b/vm/sre_engine/src/string.rs index 77e0f3e772..20cacbfbec 100644 --- a/vm/sre_engine/src/string.rs +++ b/vm/sre_engine/src/string.rs @@ -1,3 +1,5 @@ +use rustpython_common::wtf8::Wtf8; + #[derive(Debug, Clone, Copy)] pub struct StringCursor { pub(crate) ptr: *const u8, @@ -148,6 +150,72 @@ impl StrDrive for &str { } } +impl StrDrive for &Wtf8 { + #[inline] + fn count(&self) -> usize { + self.code_points().count() + } + + #[inline] + fn create_cursor(&self, n: usize) -> StringCursor { + let mut cursor = StringCursor { + ptr: self.as_bytes().as_ptr(), + position: 0, + }; + Self::skip(&mut cursor, n); + cursor + } + + #[inline] + fn adjust_cursor(&self, cursor: &mut 
StringCursor, n: usize) { + if cursor.ptr.is_null() || cursor.position > n { + *cursor = Self::create_cursor(self, n); + } else if cursor.position < n { + Self::skip(cursor, n - cursor.position); + } + } + + #[inline] + fn advance(cursor: &mut StringCursor) -> u32 { + cursor.position += 1; + unsafe { next_code_point(&mut cursor.ptr) } + } + + #[inline] + fn peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point(&mut ptr) } + } + + #[inline] + fn skip(cursor: &mut StringCursor, n: usize) { + cursor.position += n; + for _ in 0..n { + unsafe { next_code_point(&mut cursor.ptr) }; + } + } + + #[inline] + fn back_advance(cursor: &mut StringCursor) -> u32 { + cursor.position -= 1; + unsafe { next_code_point_reverse(&mut cursor.ptr) } + } + + #[inline] + fn back_peek(cursor: &StringCursor) -> u32 { + let mut ptr = cursor.ptr; + unsafe { next_code_point_reverse(&mut ptr) } + } + + #[inline] + fn back_skip(cursor: &mut StringCursor, n: usize) { + cursor.position -= n; + for _ in 0..n { + unsafe { next_code_point_reverse(&mut cursor.ptr) }; + } + } +} + /// Reads the next code point out of a byte iterator (assuming a /// UTF-8-like encoding). ///