rustpython_vm/builtins/range.rs

use super::{
    builtins_iter, tuple::tuple_hash, PyInt, PyIntRef, PySlice, PyTupleRef, PyType, PyTypeRef,
};
use crate::{
    atomic_func,
    class::PyClassImpl,
    common::hash::PyHash,
    function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue},
    protocol::{PyIterReturn, PyMappingMethods, PySequenceMethods},
    types::{
        AsMapping, AsSequence, Comparable, Hashable, IterNext, Iterable, PyComparisonOp,
        Representable, SelfIter, Unconstructible,
    },
    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
    VirtualMachine,
};
use crossbeam_utils::atomic::AtomicCell;
use malachite_bigint::{BigInt, Sign};
use num_integer::Integer;
use num_traits::{One, Signed, ToPrimitive, Zero};
use once_cell::sync::Lazy;
use std::cmp::max;

// Search flag passed to iter_search
enum SearchType {
    Count,
    Contains,
    Index,
}

// Note: might be a good idea to merge with _membership_iter_search or generalize (_sequence_iter_check?)
// and place in vm.rs for all sequences to be able to use it.
#[inline]
fn iter_search(
    obj: &PyObject,
    item: &PyObject,
    flag: SearchType,
    vm: &VirtualMachine,
) -> PyResult<usize> {
    let mut count = 0;
    let iter = obj.get_iter(vm)?;
    for element in iter.iter_without_hint::<PyObjectRef>(vm)? {
        if vm.bool_eq(item, &*element?)? {
            match flag {
                SearchType::Index => return Ok(count),
                SearchType::Contains => return Ok(1),
                SearchType::Count => count += 1,
            }
        }
    }
    match flag {
        SearchType::Count => Ok(count),
        SearchType::Contains => Ok(0),
        SearchType::Index => Err(vm.new_value_error(format!(
            "{} not in range",
            item.repr(vm)
                .map(|v| v.as_str().to_owned())
                .unwrap_or_else(|_| "value".to_owned())
        ))),
    }
}
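
// Worked example of the slow path: `3.0 in range(5)` ends up here because 3.0 is not an
// exact int, so the O(1) integer math below cannot be used. `vm.bool_eq` reports 3.0 == 3,
// and the Contains search returns 1.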

#[pyclass(module = false, name = "range")]
#[derive(Debug, Clone)]
pub struct PyRange {
    pub start: PyIntRef,
    pub stop: PyIntRef,
    pub step: PyIntRef,
}

impl PyPayload for PyRange {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.range_type
    }
}

impl PyRange {
    #[inline]
    fn offset(&self, value: &BigInt) -> Option<BigInt> {
        let start = self.start.as_bigint();
        let stop = self.stop.as_bigint();
        let step = self.step.as_bigint();
        match step.sign() {
            Sign::Plus if value >= start && value < stop => Some(value - start),
            Sign::Minus if value <= self.start.as_bigint() && value > stop => Some(start - value),
            _ => None,
        }
    }

    #[inline]
    pub fn index_of(&self, value: &BigInt) -> Option<BigInt> {
        let step = self.step.as_bigint();
        match self.offset(value) {
            Some(ref offset) if offset.is_multiple_of(step) => Some((offset / step).abs()),
            Some(_) | None => None,
        }
    }
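
    // Worked example: for range(2, 20, 3), the offset of 8 is Some(6) because 8 lies in
    // [start, stop) for a positive step, and index_of then yields Some(2) since 6 is a
    // multiple of the step (8 is the third element: 2, 5, 8, ...).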

    #[inline]
    pub fn is_empty(&self) -> bool {
        self.compute_length().is_zero()
    }

    #[inline]
    pub fn forward(&self) -> bool {
        self.start.as_bigint() < self.stop.as_bigint()
    }

    #[inline]
    pub fn get(&self, index: &BigInt) -> Option<BigInt> {
        let start = self.start.as_bigint();
        let step = self.step.as_bigint();
        let stop = self.stop.as_bigint();
        if self.is_empty() {
            return None;
        }

        if index.is_negative() {
            let length = self.compute_length();
            let index: BigInt = &length + index;
            if index.is_negative() {
                return None;
            }

            Some(if step.is_one() {
                start + index
            } else {
                start + step * index
            })
        } else {
            let index = if step.is_one() {
                start + index
            } else {
                start + step * index
            };

            if (step.is_positive() && stop > &index) || (step.is_negative() && stop < &index) {
                Some(index)
            } else {
                None
            }
        }
    }
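
    // Worked example for get(): range(1, 10, 2) has the elements [1, 3, 5, 7, 9] (length 5).
    // A negative index is first shifted by the length, so get(-1) becomes index 4 and
    // returns Some(9) via start + step * index = 1 + 2 * 4.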

    #[inline]
    fn compute_length(&self) -> BigInt {
        let start = self.start.as_bigint();
        let stop = self.stop.as_bigint();
        let step = self.step.as_bigint();

        match step.sign() {
            Sign::Plus if start < stop => {
                if step.is_one() {
                    stop - start
                } else {
                    (stop - start - 1usize) / step + 1
                }
            }
            Sign::Minus if start > stop => (start - stop - 1usize) / (-step) + 1,
            Sign::Plus | Sign::Minus => BigInt::zero(),
            Sign::NoSign => unreachable!(),
        }
    }
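
    // Worked example for compute_length(): range(1, 10, 3) has (10 - 1 - 1) / 3 + 1 = 3
    // elements (1, 4, 7), and range(10, 1, -3) has (10 - 1 - 1) / 3 + 1 = 3 elements
    // (10, 7, 4). When start and stop are inverted relative to the step's sign, the
    // length is zero.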
}

// pub fn get_value(obj: &PyObject) -> PyRange {
//     obj.payload::<PyRange>().unwrap().clone()
// }

pub fn init(context: &Context) {
    PyRange::extend_class(context, context.types.range_type);
    PyLongRangeIterator::extend_class(context, context.types.long_range_iterator_type);
    PyRangeIterator::extend_class(context, context.types.range_iterator_type);
}

#[pyclass(with(
    Py,
    AsMapping,
    AsSequence,
    Hashable,
    Comparable,
    Iterable,
    Representable
))]
impl PyRange {
    fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
        PyRange {
            start: vm.ctx.new_pyref(0),
            stop: stop.into(),
            step: vm.ctx.new_pyref(1),
        }
        .into_ref_with_type(vm, cls)
    }

    fn new_from(
        cls: PyTypeRef,
        start: PyObjectRef,
        stop: PyObjectRef,
        step: OptionalArg<ArgIndex>,
        vm: &VirtualMachine,
    ) -> PyResult<PyRef<Self>> {
        let step = step.map_or_else(|| vm.ctx.new_int(1), |step| step.into());
        if step.as_bigint().is_zero() {
            return Err(vm.new_value_error("range() arg 3 must not be zero".to_owned()));
        }
        PyRange {
            start: start.try_index(vm)?,
            stop: stop.try_index(vm)?,
            step,
        }
        .into_ref_with_type(vm, cls)
    }

    #[pygetset]
    fn start(&self) -> PyIntRef {
        self.start.clone()
    }

    #[pygetset]
    fn stop(&self) -> PyIntRef {
        self.stop.clone()
    }

    #[pygetset]
    fn step(&self) -> PyIntRef {
        self.step.clone()
    }

    #[pymethod(magic)]
    fn reversed(&self, vm: &VirtualMachine) -> PyResult {
        let start = self.start.as_bigint();
        let step = self.step.as_bigint();

        // Use CPython calculation for this:
        let length = self.len();
        let new_stop = start - step;
        let start = &new_stop + length.clone() * step;
        let step = -step;

        Ok(
            if let (Some(start), Some(step), Some(_)) =
                (start.to_isize(), step.to_isize(), new_stop.to_isize())
            {
                PyRangeIterator {
                    index: AtomicCell::new(0),
                    start,
                    step,
                    // Cannot fail: if start, stop and step all successfully convert to isize, then the
                    // result of self.len() will always fit in a usize.
                    length: length.to_usize().unwrap_or(0),
                }
                .into_pyobject(vm)
            } else {
                PyLongRangeIterator {
                    index: AtomicCell::new(0),
                    start,
                    step,
                    length,
                }
                .into_pyobject(vm)
            },
        )
    }
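
    // Worked example of the reversed() arithmetic above: for range(1, 10, 3) (elements 1, 4, 7,
    // length 3) the new stop is 1 - 3 = -2, the new start is -2 + 3 * 3 = 7, and the step
    // becomes -3, so the iterator yields 7, 4, 1, matching CPython's reversed(range(1, 10, 3)).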

    #[pymethod(magic)]
    fn len(&self) -> BigInt {
        self.compute_length()
    }

    #[pymethod(magic)]
    fn bool(&self) -> bool {
        !self.is_empty()
    }

    #[pymethod(magic)]
    fn reduce(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
        let range_parameters: Vec<PyObjectRef> = [&self.start, &self.stop, &self.step]
            .iter()
            .map(|x| x.as_object().to_owned())
            .collect();
        let range_parameters_tuple = vm.ctx.new_tuple(range_parameters);
        (vm.ctx.types.range_type.to_owned(), range_parameters_tuple)
    }

    #[pymethod(magic)]
    fn getitem(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult {
        match RangeIndex::try_from_object(vm, subscript)? {
            RangeIndex::Slice(slice) => {
                let (mut sub_start, mut sub_stop, mut sub_step) =
                    slice.inner_indices(&self.compute_length(), vm)?;
                let range_step = &self.step;
                let range_start = &self.start;

                sub_step *= range_step.as_bigint();
                sub_start = (sub_start * range_step.as_bigint()) + range_start.as_bigint();
                sub_stop = (sub_stop * range_step.as_bigint()) + range_start.as_bigint();

                Ok(PyRange {
                    start: vm.ctx.new_pyref(sub_start),
                    stop: vm.ctx.new_pyref(sub_stop),
                    step: vm.ctx.new_pyref(sub_step),
                }
                .into_ref(&vm.ctx)
                .into())
            }
            RangeIndex::Int(index) => match self.get(index.as_bigint()) {
                Some(value) => Ok(vm.ctx.new_int(value).into()),
                None => Err(vm.new_index_error("range object index out of range".to_owned())),
            },
        }
    }
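
    // Illustrative behaviour of the slice branch above: slicing a range stays a range. For
    // range(10)[::2] the slice indices (0, 10, 2) are rescaled by the parent's step and start,
    // producing range(0, 10, 2); likewise range(1, 10, 2)[1:3] becomes range(3, 7, 2).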

    #[pyslot]
    fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
        let range = if args.args.len() <= 1 {
            let stop = args.bind(vm)?;
            PyRange::new(cls, stop, vm)
        } else {
            let (start, stop, step) = args.bind(vm)?;
            PyRange::new_from(cls, start, stop, step, vm)
        }?;

        Ok(range.into())
    }
}

#[pyclass]
impl Py<PyRange> {
    fn contains_inner(&self, needle: &PyObject, vm: &VirtualMachine) -> bool {
        // Only accept ints, not subclasses.
        if let Some(int) = needle.downcast_ref_if_exact::<PyInt>(vm) {
            match self.offset(int.as_bigint()) {
                Some(ref offset) => offset.is_multiple_of(self.step.as_bigint()),
                None => false,
            }
        } else {
            iter_search(self.as_object(), needle, SearchType::Contains, vm).unwrap_or(0) != 0
        }
    }

    #[pymethod(magic)]
    fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> bool {
        self.contains_inner(&needle, vm)
    }

    #[pymethod]
    fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
        if let Ok(int) = needle.clone().downcast::<PyInt>() {
            match self.index_of(int.as_bigint()) {
                Some(idx) => Ok(idx),
                None => Err(vm.new_value_error(format!("{int} is not in range"))),
            }
        } else {
            // Fallback to iteration.
            Ok(BigInt::from_bytes_be(
                Sign::Plus,
                &iter_search(self.as_object(), &needle, SearchType::Index, vm)?.to_be_bytes(),
            ))
        }
    }

    #[pymethod]
    fn count(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
        if let Ok(int) = item.clone().downcast::<PyInt>() {
            let count = if self.index_of(int.as_bigint()).is_some() {
                1
            } else {
                0
            };
            Ok(count)
        } else {
            // Dealing with classes that might compare equal to ints in their
            // __eq__; fall back to the slow search.
            iter_search(self.as_object(), &item, SearchType::Count, vm)
        }
    }
}

impl PyRange {
    fn protocol_length(&self, vm: &VirtualMachine) -> PyResult<usize> {
        PyInt::from(self.len())
            .try_to_primitive::<isize>(vm)
            .map(|x| x as usize)
    }
}

impl AsMapping for PyRange {
    fn as_mapping() -> &'static PyMappingMethods {
        static AS_MAPPING: Lazy<PyMappingMethods> = Lazy::new(|| PyMappingMethods {
            length: atomic_func!(
                |mapping, vm| PyRange::mapping_downcast(mapping).protocol_length(vm)
            ),
            subscript: atomic_func!(|mapping, needle, vm| {
                PyRange::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
            }),
            ..PyMappingMethods::NOT_IMPLEMENTED
        });
        &AS_MAPPING
    }
}

impl AsSequence for PyRange {
    fn as_sequence() -> &'static PySequenceMethods {
        static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
            length: atomic_func!(|seq, vm| PyRange::sequence_downcast(seq).protocol_length(vm)),
            item: atomic_func!(|seq, i, vm| {
                PyRange::sequence_downcast(seq)
                    .get(&i.into())
                    .map(|x| PyInt::from(x).into_ref(&vm.ctx).into())
                    .ok_or_else(|| vm.new_index_error("index out of range".to_owned()))
            }),
            contains: atomic_func!(|seq, needle, vm| {
                Ok(PyRange::sequence_downcast(seq).contains_inner(needle, vm))
            }),
            ..PySequenceMethods::NOT_IMPLEMENTED
        });
        &AS_SEQUENCE
    }
}

impl Hashable for PyRange {
    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
        let length = zelf.compute_length();
        let elements = if length.is_zero() {
            [vm.ctx.new_int(length).into(), vm.ctx.none(), vm.ctx.none()]
        } else if length.is_one() {
            [
                vm.ctx.new_int(length).into(),
                zelf.start().into(),
                vm.ctx.none(),
            ]
        } else {
            [
                vm.ctx.new_int(length).into(),
                zelf.start().into(),
                zelf.step().into(),
            ]
        };
        tuple_hash(&elements, vm)
    }
}
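
// Hashing mirrors CPython: a range hashes like the tuple (length, start-or-None, step-or-None),
// so hash(range(0)) == hash(range(5, 5, 2)) (both empty) and
// hash(range(1, 10, 3)) == hash(range(1, 8, 3)) (both describe 1, 4, 7), consistent with equality.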

impl Comparable for PyRange {
    fn cmp(
        zelf: &Py<Self>,
        other: &PyObject,
        op: PyComparisonOp,
        _vm: &VirtualMachine,
    ) -> PyResult<PyComparisonValue> {
        op.eq_only(|| {
            if zelf.is(other) {
                return Ok(true.into());
            }
            let rhs = class_or_notimplemented!(Self, other);
            let lhs_len = zelf.compute_length();
            let eq = if lhs_len != rhs.compute_length() {
                false
            } else if lhs_len.is_zero() {
                true
            } else if zelf.start.as_bigint() != rhs.start.as_bigint() {
                false
            } else if lhs_len.is_one() {
                true
            } else {
                zelf.step.as_bigint() == rhs.step.as_bigint()
            };
            Ok(eq.into())
        })
    }
}
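
// Equality compares the sequences that ranges describe, not the raw fields:
// range(0, 3, 2) == range(0, 4, 2) (both are [0, 2]), and all empty ranges compare equal,
// e.g. range(0) == range(2, 1, 3).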

impl Iterable for PyRange {
    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
        let (start, stop, step, length) = (
            zelf.start.as_bigint(),
            zelf.stop.as_bigint(),
            zelf.step.as_bigint(),
            zelf.len(),
        );
        if let (Some(start), Some(step), Some(_), Some(_)) = (
            start.to_isize(),
            step.to_isize(),
            stop.to_isize(),
            (start + step).to_isize(),
        ) {
            Ok(PyRangeIterator {
                index: AtomicCell::new(0),
                start,
                step,
                // Cannot fail. If start, stop and step all successfully convert to isize, then result of zelf.len will
                // always fit in a usize.
                length: length.to_usize().unwrap_or(0),
            }
            .into_pyobject(vm))
        } else {
            Ok(PyLongRangeIterator {
                index: AtomicCell::new(0),
                start: start.clone(),
                step: step.clone(),
                length,
            }
            .into_pyobject(vm))
        }
    }
}

impl Representable for PyRange {
    #[inline]
    fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
        let repr = if zelf.step.as_bigint().is_one() {
            format!("range({}, {})", zelf.start, zelf.stop)
        } else {
            format!("range({}, {}, {})", zelf.start, zelf.stop, zelf.step)
        };
        Ok(repr)
    }
}

// Semantically, this is the same as the previous representation.
//
// Unfortunately, since AtomicCell requires a Copy type, no BigInt implementations can
// generally be used. As such, usize::MAX is the upper bound on number of elements (length)
// the range can contain in RustPython.
//
// This doesn't preclude the range from containing large values: since start and step
// can be BigInts, we can store any arbitrary range of values.
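//
// For example, iter(range(2**100, 2**100 + 5)) must use this representation: the start
// does not fit in an isize even though the length (5) easily fits in a usize.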
#[pyclass(module = false, name = "longrange_iterator")]
#[derive(Debug)]
pub struct PyLongRangeIterator {
    index: AtomicCell<usize>,
    start: BigInt,
    step: BigInt,
    length: BigInt,
}

impl PyPayload for PyLongRangeIterator {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.long_range_iterator_type
    }
}

#[pyclass(with(Unconstructible, IterNext, Iterable))]
impl PyLongRangeIterator {
    #[pymethod(magic)]
    fn length_hint(&self) -> BigInt {
        let index = BigInt::from(self.index.load());
        if index < self.length {
            self.length.clone() - index
        } else {
            BigInt::zero()
        }
    }

    #[pymethod(magic)]
    fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        self.index.store(range_state(&self.length, state, vm)?);
        Ok(())
    }

    #[pymethod(magic)]
    fn reduce(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
        range_iter_reduce(
            self.start.clone(),
            self.length.clone(),
            self.step.clone(),
            self.index.load(),
            vm,
        )
    }
}
impl Unconstructible for PyLongRangeIterator {}

impl SelfIter for PyLongRangeIterator {}
impl IterNext for PyLongRangeIterator {
    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        // TODO: In pathological case (index == usize::MAX) this can wrap around
        // (since fetch_add wraps). This would result in the iterator spinning again
        // from the beginning.
        let index = BigInt::from(zelf.index.fetch_add(1));
        let r = if index < zelf.length {
            let value = zelf.start.clone() + index * zelf.step.clone();
            PyIterReturn::Return(vm.ctx.new_int(value).into())
        } else {
            PyIterReturn::StopIteration(None)
        };
        Ok(r)
    }
}

// When start, stop and step are isize, we can use a faster, more compact representation
// that only operates on isize values.
#[pyclass(module = false, name = "range_iterator")]
#[derive(Debug)]
pub struct PyRangeIterator {
    index: AtomicCell<usize>,
    start: isize,
    step: isize,
    length: usize,
}

impl PyPayload for PyRangeIterator {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.range_iterator_type
    }
}

#[pyclass(with(Unconstructible, IterNext, Iterable))]
impl PyRangeIterator {
    #[pymethod(magic)]
    fn length_hint(&self) -> usize {
        let index = self.index.load();
        if index < self.length {
            self.length - index
        } else {
            0
        }
    }

    #[pymethod(magic)]
    fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        self.index
            .store(range_state(&BigInt::from(self.length), state, vm)?);
        Ok(())
    }

    #[pymethod(magic)]
    fn reduce(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
        range_iter_reduce(
            BigInt::from(self.start),
            BigInt::from(self.length),
            BigInt::from(self.step),
            self.index.load(),
            vm,
        )
    }
}
impl Unconstructible for PyRangeIterator {}

impl SelfIter for PyRangeIterator {}
impl IterNext for PyRangeIterator {
    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        // TODO: In pathological case (index == usize::MAX) this can wrap around
        // (since fetch_add wraps). This would result in the iterator spinning again
        // from the beginning.
        let index = zelf.index.fetch_add(1);
        let r = if index < zelf.length {
            let value = zelf.start + (index as isize) * zelf.step;
            PyIterReturn::Return(vm.ctx.new_int(value).into())
        } else {
            PyIterReturn::StopIteration(None)
        };
        Ok(r)
    }
}

fn range_iter_reduce(
    start: BigInt,
    length: BigInt,
    step: BigInt,
    index: usize,
    vm: &VirtualMachine,
) -> PyResult<PyTupleRef> {
    let iter = builtins_iter(vm).to_owned();
    let stop = start.clone() + length * step.clone();
    let range = PyRange {
        start: PyInt::from(start).into_ref(&vm.ctx),
        stop: PyInt::from(stop).into_ref(&vm.ctx),
        step: PyInt::from(step).into_ref(&vm.ctx),
    };
    Ok(vm.new_tuple((iter, (range,), index)))
}
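
// Worked example of the reduce tuple built by range_iter_reduce: a range_iterator over
// range(0, 10, 2) that has already yielded two items reduces to (iter, (range(0, 10, 2),), 2),
// so unpickling rebuilds the range, calls iter() on it, and restores the index via __setstate__.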

// Silently clips the state (i.e. the index) to the range [0, usize::MAX].
fn range_state(length: &BigInt, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
    if let Some(i) = state.payload::<PyInt>() {
        let mut index = i.as_bigint();
        let max_usize = BigInt::from(usize::MAX);
        if index > length {
            index = max(length, &max_usize);
        }
        Ok(index.to_usize().unwrap_or(0))
    } else {
        Err(vm.new_type_error("an integer is required.".to_owned()))
    }
}
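
// Example of the clipping above: __setstate__(100) on an iterator over range(10) stores an
// index at or beyond the length, so the iterator is immediately exhausted, while a negative
// state fails the usize conversion and resets the index to 0.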

pub enum RangeIndex {
    Int(PyIntRef),
    Slice(PyRef<PySlice>),
}

impl TryFromObject for RangeIndex {
    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
        match_class!(match obj {
            i @ PyInt => Ok(RangeIndex::Int(i)),
            s @ PySlice => Ok(RangeIndex::Slice(s)),
            obj => {
                let val = obj.try_index(vm).map_err(|_| vm.new_type_error(format!(
                    "sequence indices be integers or slices or classes that override __index__ operator, not '{}'",
                    obj.class().name()
                )))?;
                Ok(RangeIndex::Int(val))
            }
        })
    }
}
Morty Proxy This is a proxified and sanitized view of the page, visit original site.