Skip to content

Navigation Menu

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 960e86c

Browse files
committed
Fix more surrogate crashes
1 parent e9e116b commit 960e86c
Copy full SHA for 960e86c

20 files changed

+126
-122
lines changed

‎Lib/test/test_json/test_scanstring.py

Copy file name to clipboardExpand all lines: Lib/test/test_json/test_scanstring.py
-2Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def test_scanstring(self):
8686
scanstring('["Bad value", truth]', 2, True),
8787
('Bad value', 12))
8888

89-
# TODO: RUSTPYTHON
90-
@unittest.expectedFailure
9189
def test_surrogates(self):
9290
scanstring = self.json.decoder.scanstring
9391
def assertScan(given, expect):

‎Lib/test/test_stringprep.py

Copy file name to clipboardExpand all lines: Lib/test/test_stringprep.py
-2Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from stringprep import *
77

88
class StringprepTests(unittest.TestCase):
9-
# TODO: RUSTPYTHON
10-
@unittest.expectedFailure
119
def test(self):
1210
self.assertTrue(in_table_a1("\u0221"))
1311
self.assertFalse(in_table_a1("\u0222"))

‎Lib/test/test_subprocess.py

Copy file name to clipboardExpand all lines: Lib/test/test_subprocess.py
-2Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,8 +1198,6 @@ def test_universal_newlines_communicate_encodings(self):
11981198
stdout, stderr = popen.communicate(input='')
11991199
self.assertEqual(stdout, '1\n2\n3\n4')
12001200

1201-
# TODO: RUSTPYTHON
1202-
@unittest.expectedFailure
12031201
def test_communicate_errors(self):
12041202
for errors, expected in [
12051203
('ignore', ''),

‎Lib/test/test_tarfile.py

Copy file name to clipboardExpand all lines: Lib/test/test_tarfile.py
-14Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase):
20862086

20872087
format = tarfile.USTAR_FORMAT
20882088

2089-
# TODO: RUSTPYTHON
2090-
@unittest.expectedFailure
2091-
def test_uname_unicode(self):
2092-
super().test_uname_unicode()
2093-
20942089
# Test whether the utf-8 encoded version of a filename exceeds the 100
20952090
# bytes name field limit (every occurrence of '\xff' will be expanded to 2
20962091
# bytes).
@@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase):
21702165

21712166
format = tarfile.GNU_FORMAT
21722167

2173-
# TODO: RUSTPYTHON
2174-
@unittest.expectedFailure
2175-
def test_uname_unicode(self):
2176-
super().test_uname_unicode()
2177-
2178-
# TODO: RUSTPYTHON
2179-
@unittest.expectedFailure
21802168
def test_bad_pax_header(self):
21812169
# Test for issue #8633. GNU tar <= 1.23 creates raw binary fields
21822170
# without a hdrcharset=BINARY header.
@@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase):
21982186
# PAX_FORMAT ignores encoding in write mode.
21992187
test_unicode_filename_error = None
22002188

2201-
# TODO: RUSTPYTHON
2202-
@unittest.expectedFailure
22032189
def test_binary_header(self):
22042190
# Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field.
22052191
for encoding, name in (

‎Lib/test/test_unicode.py

Copy file name to clipboardExpand all lines: Lib/test/test_unicode.py
-8Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,6 @@ def test_bytes_comparison(self):
608608
self.assertEqual('abc' == bytearray(b'abc'), False)
609609
self.assertEqual('abc' != bytearray(b'abc'), True)
610610

611-
# TODO: RUSTPYTHON
612-
@unittest.expectedFailure
613611
def test_comparison(self):
614612
# Comparisons:
615613
self.assertEqual('abc', 'abc')
@@ -830,8 +828,6 @@ def test_isidentifier_legacy(self):
830828
warnings.simplefilter('ignore', DeprecationWarning)
831829
self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier())
832830

833-
# TODO: RUSTPYTHON
834-
@unittest.expectedFailure
835831
def test_isprintable(self):
836832
self.assertTrue("".isprintable())
837833
self.assertTrue(" ".isprintable())
@@ -847,8 +843,6 @@ def test_isprintable(self):
847843
self.assertTrue('\U0001F46F'.isprintable())
848844
self.assertFalse('\U000E0020'.isprintable())
849845

850-
# TODO: RUSTPYTHON
851-
@unittest.expectedFailure
852846
def test_surrogates(self):
853847
for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800',
854848
'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'):
@@ -1827,8 +1821,6 @@ def test_codecs_utf7(self):
18271821
'ill-formed sequence'):
18281822
b'+@'.decode('utf-7')
18291823

1830-
# TODO: RUSTPYTHON
1831-
@unittest.expectedFailure
18321824
def test_codecs_utf8(self):
18331825
self.assertEqual(''.encode('utf-8'), b'')
18341826
self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac')

‎Lib/test/test_userstring.py

Copy file name to clipboardExpand all lines: Lib/test/test_userstring.py
-4Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,13 @@ def __rmod__(self, other):
5353
str3 = ustr3('TEST')
5454
self.assertEqual(fmt2 % str3, 'value is TEST')
5555

56-
# TODO: RUSTPYTHON
57-
@unittest.expectedFailure
5856
def test_encode_default_args(self):
5957
self.checkequal(b'hello', 'hello', 'encode')
6058
# Check that encoding defaults to utf-8
6159
self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode')
6260
# Check that errors defaults to 'strict'
6361
self.checkraises(UnicodeError, '\ud800', 'encode')
6462

65-
# TODO: RUSTPYTHON
66-
@unittest.expectedFailure
6763
def test_encode_explicit_none_args(self):
6864
self.checkequal(b'hello', 'hello', 'encode', None, None)
6965
# Check that encoding defaults to utf-8

‎common/src/wtf8/mod.rs

Copy file name to clipboardExpand all lines: common/src/wtf8/mod.rs
+34-21Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,18 @@ impl CodePoint {
122122

123123
/// Returns the code point as a `LeadSurrogate` if it is a leading surrogate (U+D800..=U+DBFF), or `None` otherwise.
124124
#[inline]
125-
pub fn to_lead_surrogate(self) -> Option<u16> {
125+
pub fn to_lead_surrogate(self) -> Option<LeadSurrogate> {
126126
match self.value {
127-
lead @ 0xD800..=0xDBFF => Some(lead as u16),
127+
lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)),
128128
_ => None,
129129
}
130130
}
131131

132132
/// Returns the code point as a `TrailSurrogate` if it is a trailing surrogate (U+DC00..=U+DFFF), or `None` otherwise.
133133
#[inline]
134-
pub fn to_trail_surrogate(self) -> Option<u16> {
134+
pub fn to_trail_surrogate(self) -> Option<TrailSurrogate> {
135135
match self.value {
136-
trail @ 0xDC00..=0xDFFF => Some(trail as u16),
136+
trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)),
137137
_ => None,
138138
}
139139
}
@@ -216,6 +216,18 @@ impl PartialEq<CodePoint> for char {
216216
}
217217
}
218218

219+
#[derive(Clone, Copy)]
220+
pub struct LeadSurrogate(u16);
221+
222+
#[derive(Clone, Copy)]
223+
pub struct TrailSurrogate(u16);
224+
225+
impl LeadSurrogate {
226+
pub fn merge(self, trail: TrailSurrogate) -> char {
227+
decode_surrogate_pair(self.0, trail.0)
228+
}
229+
}
230+
219231
/// An owned, growable string of well-formed WTF-8 data.
220232
///
221233
/// Similar to `String`, but can additionally contain surrogate code points
@@ -291,6 +303,14 @@ impl Wtf8Buf {
291303
Wtf8Buf { bytes: value }
292304
}
293305

306+
/// Creates a WTF-8 string from a byte vec, handing the vec back as `Err` if `Wtf8::from_bytes` rejects it.
307+
pub fn from_bytes(value: Vec<u8>) -> Result<Self, Vec<u8>> {
308+
match Wtf8::from_bytes(&value) {
309+
Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }),
310+
None => Err(value),
311+
}
312+
}
313+
294314
/// Creates a WTF-8 string from a UTF-8 `String`.
295315
///
296316
/// This takes ownership of the `String` and does not copy.
@@ -750,15 +770,10 @@ impl Wtf8 {
750770
}
751771

752772
fn decode_surrogate(b: &[u8]) -> Option<CodePoint> {
753-
let [a, b, c, ..] = *b else { return None };
754-
if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 {
755-
// it's a three-byte code
756-
let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f);
757-
let 0xD800..=0xDFFF = c else { return None };
758-
Some(CodePoint { value: c })
759-
} else {
760-
None
761-
}
773+
let [0xed, b2 @ (0xa0..), b3, ..] = *b else {
774+
return None;
775+
};
776+
Some(decode_surrogate(b2, b3).into())
762777
}
763778

764779
/// Returns the length, in WTF-8 bytes.
@@ -914,14 +929,6 @@ impl Wtf8 {
914929
}
915930
}
916931

917-
#[inline]
918-
fn final_lead_surrogate(&self) -> Option<u16> {
919-
match self.bytes {
920-
[.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)),
921-
_ => None,
922-
}
923-
}
924-
925932
pub fn is_code_point_boundary(&self, index: usize) -> bool {
926933
is_code_point_boundary(self, index)
927934
}
@@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 {
12221229
0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F
12231230
}
12241231

1232+
#[inline]
1233+
fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
1234+
let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32);
1235+
unsafe { char::from_u32_unchecked(code_point) }
1236+
}
1237+
12251238
/// Copied from str::is_char_boundary
12261239
#[inline]
12271240
fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {

‎stdlib/src/json.rs

Copy file name to clipboardExpand all lines: stdlib/src/json.rs
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod _json {
1313
types::{Callable, Constructor},
1414
};
1515
use malachite_bigint::BigInt;
16+
use rustpython_common::wtf8::Wtf8Buf;
1617
use std::str::FromStr;
1718

1819
#[pyattr(name = "make_scanner")]
@@ -253,8 +254,8 @@ mod _json {
253254
end: usize,
254255
strict: OptionalArg<bool>,
255256
vm: &VirtualMachine,
256-
) -> PyResult<(String, usize)> {
257-
machinery::scanstring(s.as_str(), end, strict.unwrap_or(true))
257+
) -> PyResult<(Wtf8Buf, usize)> {
258+
machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true))
258259
.map_err(|e| py_decode_error(e, s, vm))
259260
}
260261
}

‎stdlib/src/json/machinery.rs

Copy file name to clipboardExpand all lines: stdlib/src/json/machinery.rs
+34-39Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
use std::io;
3030

31+
use itertools::Itertools;
32+
use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf};
33+
3134
static ESCAPE_CHARS: [&str; 0x20] = [
3235
"\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b",
3336
"\\t", "\\n", "\\u000", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012",
@@ -111,39 +114,39 @@ impl DecodeError {
111114
}
112115

113116
enum StrOrChar<'a> {
114-
Str(&'a str),
115-
Char(char),
117+
Str(&'a Wtf8),
118+
Char(CodePoint),
116119
}
117120
impl StrOrChar<'_> {
118121
fn len(&self) -> usize {
119122
match self {
120123
StrOrChar::Str(s) => s.len(),
121-
StrOrChar::Char(c) => c.len_utf8(),
124+
StrOrChar::Char(c) => c.len_wtf8(),
122125
}
123126
}
124127
}
125128
pub fn scanstring<'a>(
126-
s: &'a str,
129+
s: &'a Wtf8,
127130
end: usize,
128131
strict: bool,
129-
) -> Result<(String, usize), DecodeError> {
132+
) -> Result<(Wtf8Buf, usize), DecodeError> {
130133
let mut chunks: Vec<StrOrChar<'a>> = Vec::new();
131134
let mut output_len = 0usize;
132135
let mut push_chunk = |chunk: StrOrChar<'a>| {
133136
output_len += chunk.len();
134137
chunks.push(chunk);
135138
};
136139
let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1);
137-
let mut chars = s.char_indices().enumerate().skip(end).peekable();
140+
let mut chars = s.code_point_indices().enumerate().skip(end).peekable();
138141
let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?;
139142
while let Some((char_i, (byte_i, c))) = chars.next() {
140-
match c {
143+
match c.to_char_lossy() {
141144
'"' => {
142145
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
143-
let mut out = String::with_capacity(output_len);
146+
let mut out = Wtf8Buf::with_capacity(output_len);
144147
for x in chunks {
145148
match x {
146-
StrOrChar::Str(s) => out.push_str(s),
149+
StrOrChar::Str(s) => out.push_wtf8(s),
147150
StrOrChar::Char(c) => out.push(c),
148151
}
149152
}
@@ -152,7 +155,7 @@ pub fn scanstring<'a>(
152155
'\\' => {
153156
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
154157
let (_, (_, c)) = chars.next().ok_or_else(unterminated_err)?;
155-
let esc = match c {
158+
let esc = match c.to_char_lossy() {
156159
'"' => "\"",
157160
'\\' => "\\",
158161
'/' => "/",
@@ -162,41 +165,33 @@ pub fn scanstring<'a>(
162165
'r' => "\r",
163166
't' => "\t",
164167
'u' => {
165-
let surrogate_err = || DecodeError::new("unpaired surrogate", char_i);
166168
let mut uni = decode_unicode(&mut chars, char_i)?;
167169
chunk_start = byte_i + 6;
168-
if (0xd800..=0xdbff).contains(&uni) {
170+
if let Some(lead) = uni.to_lead_surrogate() {
169171
// uni is a lead surrogate -- try to find the trail surrogate that pairs with it
170-
if let Some(&(pos2, (_, '\\'))) = chars.peek() {
171-
// ok, the next char starts an escape
172-
chars.next();
173-
if let Some((_, (_, 'u'))) = chars.peek() {
174-
// ok, it's a unicode escape
175-
chars.next();
176-
let uni2 = decode_unicode(&mut chars, pos2)?;
172+
let mut chars2 = chars.clone();
173+
if let Some(((pos2, _), (_, _))) = chars2
174+
.next_tuple()
175+
.filter(|((_, (_, c1)), (_, (_, c2)))| *c1 == '\\' && *c2 == 'u')
176+
{
177+
let uni2 = decode_unicode(&mut chars2, pos2)?;
178+
if let Some(trail) = uni2.to_trail_surrogate() {
179+
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
180+
uni = lead.merge(trail).into();
177181
chunk_start = pos2 + 6;
178-
if (0xdc00..=0xdfff).contains(&uni2) {
179-
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
180-
uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00));
181-
} else {
182-
// if we don't find a matching surrogate, error -- until str
183-
// isn't utf8 internally, we can't parse surrogates
184-
return Err(surrogate_err());
185-
}
186-
} else {
187-
return Err(surrogate_err());
182+
chars = chars2;
188183
}
189184
}
190185
}
191-
push_chunk(StrOrChar::Char(
192-
std::char::from_u32(uni).ok_or_else(surrogate_err)?,
193-
));
186+
push_chunk(StrOrChar::Char(uni));
194187
continue;
195188
}
196-
_ => return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)),
189+
_ => {
190+
return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i));
191+
}
197192
};
198193
chunk_start = byte_i + 2;
199-
push_chunk(StrOrChar::Str(esc));
194+
push_chunk(StrOrChar::Str(esc.as_ref()));
200195
}
201196
'\x00'..='\x1f' if strict => {
202197
return Err(DecodeError::new(
@@ -211,16 +206,16 @@ pub fn scanstring<'a>(
211206
}
212207

213208
#[inline]
214-
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<u32, DecodeError>
209+
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<CodePoint, DecodeError>
215210
where
216-
I: Iterator<Item = (usize, (usize, char))>,
211+
I: Iterator<Item = (usize, (usize, CodePoint))>,
217212
{
218213
let err = || DecodeError::new("Invalid \\uXXXX escape", pos);
219214
let mut uni = 0;
220215
for x in (0..4).rev() {
221216
let (_, (_, c)) = it.next().ok_or_else(err)?;
222-
let d = c.to_digit(16).ok_or_else(err)?;
223-
uni += d * 16u32.pow(x);
217+
let d = c.to_char().and_then(|c| c.to_digit(16)).ok_or_else(err)? as u16;
218+
uni += d * 16u16.pow(x);
224219
}
225-
Ok(uni)
220+
Ok(uni.into())
226221
}

0 commit comments

Comments
0 (0)
Morty Proxy: this is a proxified and sanitized view of the page; visit the original site.