1use super::{
2 builtins_iter, tuple::tuple_hash, PyInt, PyIntRef, PySlice, PyTupleRef, PyType, PyTypeRef,
3};
4use crate::{
5 atomic_func,
6 class::PyClassImpl,
7 common::hash::PyHash,
8 function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue},
9 protocol::{PyIterReturn, PyMappingMethods, PySequenceMethods},
10 types::{
11 AsMapping, AsSequence, Comparable, Hashable, IterNext, Iterable, PyComparisonOp,
12 Representable, SelfIter, Unconstructible,
13 },
14 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
15 VirtualMachine,
16};
17use crossbeam_utils::atomic::AtomicCell;
18use malachite_bigint::{BigInt, Sign};
19use num_integer::Integer;
20use num_traits::{One, Signed, ToPrimitive, Zero};
21use once_cell::sync::Lazy;
22use std::cmp::max;
23
/// Selects which question `iter_search` answers while scanning an iterable.
enum SearchType {
    /// Report the position of the first matching element.
    Index,
    /// Report whether any element matches (`1`) or none does (`0`).
    Contains,
    /// Report how many elements match.
    Count,
}
30
31#[inline]
34fn iter_search(
35 obj: &PyObject,
36 item: &PyObject,
37 flag: SearchType,
38 vm: &VirtualMachine,
39) -> PyResult<usize> {
40 let mut count = 0;
41 let iter = obj.get_iter(vm)?;
42 for element in iter.iter_without_hint::<PyObjectRef>(vm)? {
43 if vm.bool_eq(item, &*element?)? {
44 match flag {
45 SearchType::Index => return Ok(count),
46 SearchType::Contains => return Ok(1),
47 SearchType::Count => count += 1,
48 }
49 }
50 }
51 match flag {
52 SearchType::Count => Ok(count),
53 SearchType::Contains => Ok(0),
54 SearchType::Index => Err(vm.new_value_error(format!(
55 "{} not in range",
56 item.repr(vm)
57 .map(|v| v.as_str().to_owned())
58 .unwrap_or_else(|_| "value".to_owned())
59 ))),
60 }
61}
62
// The builtin `range` type: an immutable arithmetic progression of integers
// described by `start`, `stop` (exclusive) and `step`, each held as a Python int.
// `new_from` rejects a zero step, and `compute_length` treats a zero step as
// unreachable.
// (Plain `//` comments are used so the `#[pyclass]` macro does not pick them up
// as the Python-visible docstring.)
#[pyclass(module = false, name = "range")]
#[derive(Debug, Clone)]
pub struct PyRange {
    // First value of the progression; exposed as `range.start`.
    pub start: PyIntRef,
    // Exclusive bound; exposed as `range.stop`.
    pub stop: PyIntRef,
    // Difference between consecutive values; exposed as `range.step`.
    pub step: PyIntRef,
}
70
71impl PyPayload for PyRange {
72 fn class(ctx: &Context) -> &'static Py<PyType> {
73 ctx.types.range_type
74 }
75}
76
77impl PyRange {
78 #[inline]
79 fn offset(&self, value: &BigInt) -> Option<BigInt> {
80 let start = self.start.as_bigint();
81 let stop = self.stop.as_bigint();
82 let step = self.step.as_bigint();
83 match step.sign() {
84 Sign::Plus if value >= start && value < stop => Some(value - start),
85 Sign::Minus if value <= self.start.as_bigint() && value > stop => Some(start - value),
86 _ => None,
87 }
88 }
89
90 #[inline]
91 pub fn index_of(&self, value: &BigInt) -> Option<BigInt> {
92 let step = self.step.as_bigint();
93 match self.offset(value) {
94 Some(ref offset) if offset.is_multiple_of(step) => Some((offset / step).abs()),
95 Some(_) | None => None,
96 }
97 }
98
99 #[inline]
100 pub fn is_empty(&self) -> bool {
101 self.compute_length().is_zero()
102 }
103
104 #[inline]
105 pub fn forward(&self) -> bool {
106 self.start.as_bigint() < self.stop.as_bigint()
107 }
108
109 #[inline]
110 pub fn get(&self, index: &BigInt) -> Option<BigInt> {
111 let start = self.start.as_bigint();
112 let step = self.step.as_bigint();
113 let stop = self.stop.as_bigint();
114 if self.is_empty() {
115 return None;
116 }
117
118 if index.is_negative() {
119 let length = self.compute_length();
120 let index: BigInt = &length + index;
121 if index.is_negative() {
122 return None;
123 }
124
125 Some(if step.is_one() {
126 start + index
127 } else {
128 start + step * index
129 })
130 } else {
131 let index = if step.is_one() {
132 start + index
133 } else {
134 start + step * index
135 };
136
137 if (step.is_positive() && stop > &index) || (step.is_negative() && stop < &index) {
138 Some(index)
139 } else {
140 None
141 }
142 }
143 }
144
145 #[inline]
146 fn compute_length(&self) -> BigInt {
147 let start = self.start.as_bigint();
148 let stop = self.stop.as_bigint();
149 let step = self.step.as_bigint();
150
151 match step.sign() {
152 Sign::Plus if start < stop => {
153 if step.is_one() {
154 stop - start
155 } else {
156 (stop - start - 1usize) / step + 1
157 }
158 }
159 Sign::Minus if start > stop => (start - stop - 1usize) / (-step) + 1,
160 Sign::Plus | Sign::Minus => BigInt::zero(),
161 Sign::NoSign => unreachable!(),
162 }
163 }
164}
165
166pub fn init(context: &Context) {
171 PyRange::extend_class(context, context.types.range_type);
172 PyLongRangeIterator::extend_class(context, context.types.long_range_iterator_type);
173 PyRangeIterator::extend_class(context, context.types.range_iterator_type);
174}
175
176#[pyclass(with(
177 Py,
178 AsMapping,
179 AsSequence,
180 Hashable,
181 Comparable,
182 Iterable,
183 Representable
184))]
185impl PyRange {
186 fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
187 PyRange {
188 start: vm.ctx.new_pyref(0),
189 stop: stop.into(),
190 step: vm.ctx.new_pyref(1),
191 }
192 .into_ref_with_type(vm, cls)
193 }
194
195 fn new_from(
196 cls: PyTypeRef,
197 start: PyObjectRef,
198 stop: PyObjectRef,
199 step: OptionalArg<ArgIndex>,
200 vm: &VirtualMachine,
201 ) -> PyResult<PyRef<Self>> {
202 let step = step.map_or_else(|| vm.ctx.new_int(1), |step| step.into());
203 if step.as_bigint().is_zero() {
204 return Err(vm.new_value_error("range() arg 3 must not be zero".to_owned()));
205 }
206 PyRange {
207 start: start.try_index(vm)?,
208 stop: stop.try_index(vm)?,
209 step,
210 }
211 .into_ref_with_type(vm, cls)
212 }
213
214 #[pygetset]
215 fn start(&self) -> PyIntRef {
216 self.start.clone()
217 }
218
219 #[pygetset]
220 fn stop(&self) -> PyIntRef {
221 self.stop.clone()
222 }
223
224 #[pygetset]
225 fn step(&self) -> PyIntRef {
226 self.step.clone()
227 }
228
229 #[pymethod(magic)]
230 fn reversed(&self, vm: &VirtualMachine) -> PyResult {
231 let start = self.start.as_bigint();
232 let step = self.step.as_bigint();
233
234 let length = self.len();
236 let new_stop = start - step;
237 let start = &new_stop + length.clone() * step;
238 let step = -step;
239
240 Ok(
241 if let (Some(start), Some(step), Some(_)) =
242 (start.to_isize(), step.to_isize(), new_stop.to_isize())
243 {
244 PyRangeIterator {
245 index: AtomicCell::new(0),
246 start,
247 step,
248 length: length.to_usize().unwrap_or(0),
251 }
252 .into_pyobject(vm)
253 } else {
254 PyLongRangeIterator {
255 index: AtomicCell::new(0),
256 start,
257 step,
258 length,
259 }
260 .into_pyobject(vm)
261 },
262 )
263 }
264
265 #[pymethod(magic)]
266 fn len(&self) -> BigInt {
267 self.compute_length()
268 }
269
270 #[pymethod(magic)]
271 fn bool(&self) -> bool {
272 !self.is_empty()
273 }
274
275 #[pymethod(magic)]
276 fn reduce(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
277 let range_parameters: Vec<PyObjectRef> = [&self.start, &self.stop, &self.step]
278 .iter()
279 .map(|x| x.as_object().to_owned())
280 .collect();
281 let range_parameters_tuple = vm.ctx.new_tuple(range_parameters);
282 (vm.ctx.types.range_type.to_owned(), range_parameters_tuple)
283 }
284
285 #[pymethod(magic)]
286 fn getitem(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult {
287 match RangeIndex::try_from_object(vm, subscript)? {
288 RangeIndex::Slice(slice) => {
289 let (mut sub_start, mut sub_stop, mut sub_step) =
290 slice.inner_indices(&self.compute_length(), vm)?;
291 let range_step = &self.step;
292 let range_start = &self.start;
293
294 sub_step *= range_step.as_bigint();
295 sub_start = (sub_start * range_step.as_bigint()) + range_start.as_bigint();
296 sub_stop = (sub_stop * range_step.as_bigint()) + range_start.as_bigint();
297
298 Ok(PyRange {
299 start: vm.ctx.new_pyref(sub_start),
300 stop: vm.ctx.new_pyref(sub_stop),
301 step: vm.ctx.new_pyref(sub_step),
302 }
303 .into_ref(&vm.ctx)
304 .into())
305 }
306 RangeIndex::Int(index) => match self.get(index.as_bigint()) {
307 Some(value) => Ok(vm.ctx.new_int(value).into()),
308 None => Err(vm.new_index_error("range object index out of range".to_owned())),
309 },
310 }
311 }
312
313 #[pyslot]
314 fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
315 let range = if args.args.len() <= 1 {
316 let stop = args.bind(vm)?;
317 PyRange::new(cls, stop, vm)
318 } else {
319 let (start, stop, step) = args.bind(vm)?;
320 PyRange::new_from(cls, start, stop, step, vm)
321 }?;
322
323 Ok(range.into())
324 }
325}
326
327#[pyclass]
328impl Py<PyRange> {
329 fn contains_inner(&self, needle: &PyObject, vm: &VirtualMachine) -> bool {
330 if let Some(int) = needle.downcast_ref_if_exact::<PyInt>(vm) {
332 match self.offset(int.as_bigint()) {
333 Some(ref offset) => offset.is_multiple_of(self.step.as_bigint()),
334 None => false,
335 }
336 } else {
337 iter_search(self.as_object(), needle, SearchType::Contains, vm).unwrap_or(0) != 0
338 }
339 }
340
341 #[pymethod(magic)]
342 fn contains(&self, needle: PyObjectRef, vm: &VirtualMachine) -> bool {
343 self.contains_inner(&needle, vm)
344 }
345
346 #[pymethod]
347 fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
348 if let Ok(int) = needle.clone().downcast::<PyInt>() {
349 match self.index_of(int.as_bigint()) {
350 Some(idx) => Ok(idx),
351 None => Err(vm.new_value_error(format!("{int} is not in range"))),
352 }
353 } else {
354 Ok(BigInt::from_bytes_be(
356 Sign::Plus,
357 &iter_search(self.as_object(), &needle, SearchType::Index, vm)?.to_be_bytes(),
358 ))
359 }
360 }
361
362 #[pymethod]
363 fn count(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
364 if let Ok(int) = item.clone().downcast::<PyInt>() {
365 let count = if self.index_of(int.as_bigint()).is_some() {
366 1
367 } else {
368 0
369 };
370 Ok(count)
371 } else {
372 iter_search(self.as_object(), &item, SearchType::Count, vm)
375 }
376 }
377}
378
379impl PyRange {
380 fn protocol_length(&self, vm: &VirtualMachine) -> PyResult<usize> {
381 PyInt::from(self.len())
382 .try_to_primitive::<isize>(vm)
383 .map(|x| x as usize)
384 }
385}
386
387impl AsMapping for PyRange {
388 fn as_mapping() -> &'static PyMappingMethods {
389 static AS_MAPPING: Lazy<PyMappingMethods> = Lazy::new(|| PyMappingMethods {
390 length: atomic_func!(
391 |mapping, vm| PyRange::mapping_downcast(mapping).protocol_length(vm)
392 ),
393 subscript: atomic_func!(|mapping, needle, vm| {
394 PyRange::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
395 }),
396 ..PyMappingMethods::NOT_IMPLEMENTED
397 });
398 &AS_MAPPING
399 }
400}
401
402impl AsSequence for PyRange {
403 fn as_sequence() -> &'static PySequenceMethods {
404 static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
405 length: atomic_func!(|seq, vm| PyRange::sequence_downcast(seq).protocol_length(vm)),
406 item: atomic_func!(|seq, i, vm| {
407 PyRange::sequence_downcast(seq)
408 .get(&i.into())
409 .map(|x| PyInt::from(x).into_ref(&vm.ctx).into())
410 .ok_or_else(|| vm.new_index_error("index out of range".to_owned()))
411 }),
412 contains: atomic_func!(|seq, needle, vm| {
413 Ok(PyRange::sequence_downcast(seq).contains_inner(needle, vm))
414 }),
415 ..PySequenceMethods::NOT_IMPLEMENTED
416 });
417 &AS_SEQUENCE
418 }
419}
420
421impl Hashable for PyRange {
422 fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
423 let length = zelf.compute_length();
424 let elements = if length.is_zero() {
425 [vm.ctx.new_int(length).into(), vm.ctx.none(), vm.ctx.none()]
426 } else if length.is_one() {
427 [
428 vm.ctx.new_int(length).into(),
429 zelf.start().into(),
430 vm.ctx.none(),
431 ]
432 } else {
433 [
434 vm.ctx.new_int(length).into(),
435 zelf.start().into(),
436 zelf.step().into(),
437 ]
438 };
439 tuple_hash(&elements, vm)
440 }
441}
442
443impl Comparable for PyRange {
444 fn cmp(
445 zelf: &Py<Self>,
446 other: &PyObject,
447 op: PyComparisonOp,
448 _vm: &VirtualMachine,
449 ) -> PyResult<PyComparisonValue> {
450 op.eq_only(|| {
451 if zelf.is(other) {
452 return Ok(true.into());
453 }
454 let rhs = class_or_notimplemented!(Self, other);
455 let lhs_len = zelf.compute_length();
456 let eq = if lhs_len != rhs.compute_length() {
457 false
458 } else if lhs_len.is_zero() {
459 true
460 } else if zelf.start.as_bigint() != rhs.start.as_bigint() {
461 false
462 } else if lhs_len.is_one() {
463 true
464 } else {
465 zelf.step.as_bigint() == rhs.step.as_bigint()
466 };
467 Ok(eq.into())
468 })
469 }
470}
471
472impl Iterable for PyRange {
473 fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
474 let (start, stop, step, length) = (
475 zelf.start.as_bigint(),
476 zelf.stop.as_bigint(),
477 zelf.step.as_bigint(),
478 zelf.len(),
479 );
480 if let (Some(start), Some(step), Some(_), Some(_)) = (
481 start.to_isize(),
482 step.to_isize(),
483 stop.to_isize(),
484 (start + step).to_isize(),
485 ) {
486 Ok(PyRangeIterator {
487 index: AtomicCell::new(0),
488 start,
489 step,
490 length: length.to_usize().unwrap_or(0),
493 }
494 .into_pyobject(vm))
495 } else {
496 Ok(PyLongRangeIterator {
497 index: AtomicCell::new(0),
498 start: start.clone(),
499 step: step.clone(),
500 length,
501 }
502 .into_pyobject(vm))
503 }
504 }
505}
506
507impl Representable for PyRange {
508 #[inline]
509 fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
510 let repr = if zelf.step.as_bigint().is_one() {
511 format!("range({}, {})", zelf.start, zelf.stop)
512 } else {
513 format!("range({}, {}, {})", zelf.start, zelf.stop, zelf.step)
514 };
515 Ok(repr)
516 }
517}
518
// Iterator over a `range` whose values do not all fit in `isize`; values are
// computed with arbitrary-precision arithmetic. (`//` comments keep the
// `#[pyclass]` macro from treating this as a Python docstring.)
#[pyclass(module = false, name = "longrange_iterator")]
#[derive(Debug)]
pub struct PyLongRangeIterator {
    // Position of the next value to produce; advanced with `fetch_add`.
    index: AtomicCell<usize>,
    // First value of the progression.
    start: BigInt,
    // Difference between consecutive values.
    step: BigInt,
    // Total number of values the iterator yields.
    length: BigInt,
}
535
536impl PyPayload for PyLongRangeIterator {
537 fn class(ctx: &Context) -> &'static Py<PyType> {
538 ctx.types.long_range_iterator_type
539 }
540}
541
542#[pyclass(with(Unconstructible, IterNext, Iterable))]
543impl PyLongRangeIterator {
544 #[pymethod(magic)]
545 fn length_hint(&self) -> BigInt {
546 let index = BigInt::from(self.index.load());
547 if index < self.length {
548 self.length.clone() - index
549 } else {
550 BigInt::zero()
551 }
552 }
553
554 #[pymethod(magic)]
555 fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
556 self.index.store(range_state(&self.length, state, vm)?);
557 Ok(())
558 }
559
560 #[pymethod(magic)]
561 fn reduce(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
562 range_iter_reduce(
563 self.start.clone(),
564 self.length.clone(),
565 self.step.clone(),
566 self.index.load(),
567 vm,
568 )
569 }
570}
571impl Unconstructible for PyLongRangeIterator {}
572
573impl SelfIter for PyLongRangeIterator {}
574impl IterNext for PyLongRangeIterator {
575 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
576 let index = BigInt::from(zelf.index.fetch_add(1));
580 let r = if index < zelf.length {
581 let value = zelf.start.clone() + index * zelf.step.clone();
582 PyIterReturn::Return(vm.ctx.new_int(value).into())
583 } else {
584 PyIterReturn::StopIteration(None)
585 };
586 Ok(r)
587 }
588}
589
// Fast iterator over a `range` whose bounds and step fit in `isize` (the
// constructors in `Iterable::iter` and `reversed` check this before using it).
// (`//` comments keep the `#[pyclass]` macro from treating this as a docstring.)
#[pyclass(module = false, name = "range_iterator")]
#[derive(Debug)]
pub struct PyRangeIterator {
    // Position of the next value to produce; advanced with `fetch_add`.
    index: AtomicCell<usize>,
    // First value of the progression.
    start: isize,
    // Difference between consecutive values.
    step: isize,
    // Total number of values the iterator yields.
    length: usize,
}
600
601impl PyPayload for PyRangeIterator {
602 fn class(ctx: &Context) -> &'static Py<PyType> {
603 ctx.types.range_iterator_type
604 }
605}
606
607#[pyclass(with(Unconstructible, IterNext, Iterable))]
608impl PyRangeIterator {
609 #[pymethod(magic)]
610 fn length_hint(&self) -> usize {
611 let index = self.index.load();
612 if index < self.length {
613 self.length - index
614 } else {
615 0
616 }
617 }
618
619 #[pymethod(magic)]
620 fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
621 self.index
622 .store(range_state(&BigInt::from(self.length), state, vm)?);
623 Ok(())
624 }
625
626 #[pymethod(magic)]
627 fn reduce(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
628 range_iter_reduce(
629 BigInt::from(self.start),
630 BigInt::from(self.length),
631 BigInt::from(self.step),
632 self.index.load(),
633 vm,
634 )
635 }
636}
637impl Unconstructible for PyRangeIterator {}
638
639impl SelfIter for PyRangeIterator {}
640impl IterNext for PyRangeIterator {
641 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
642 let index = zelf.index.fetch_add(1);
646 let r = if index < zelf.length {
647 let value = zelf.start + (index as isize) * zelf.step;
648 PyIterReturn::Return(vm.ctx.new_int(value).into())
649 } else {
650 PyIterReturn::StopIteration(None)
651 };
652 Ok(r)
653 }
654}
655
656fn range_iter_reduce(
657 start: BigInt,
658 length: BigInt,
659 step: BigInt,
660 index: usize,
661 vm: &VirtualMachine,
662) -> PyResult<PyTupleRef> {
663 let iter = builtins_iter(vm).to_owned();
664 let stop = start.clone() + length * step.clone();
665 let range = PyRange {
666 start: PyInt::from(start).into_ref(&vm.ctx),
667 stop: PyInt::from(stop).into_ref(&vm.ctx),
668 step: PyInt::from(step).into_ref(&vm.ctx),
669 };
670 Ok(vm.new_tuple((iter, (range,), index)))
671}
672
673fn range_state(length: &BigInt, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
675 if let Some(i) = state.payload::<PyInt>() {
676 let mut index = i.as_bigint();
677 let max_usize = BigInt::from(usize::MAX);
678 if index > length {
679 index = max(length, &max_usize);
680 }
681 Ok(index.to_usize().unwrap_or(0))
682 } else {
683 Err(vm.new_type_error("an integer is required.".to_owned()))
684 }
685}
686
/// A decoded `range` subscript: a single integer position or a slice.
pub enum RangeIndex {
    /// Integer position, taken directly or obtained via `__index__`; may be negative.
    Int(PyIntRef),
    /// Slice subscript, producing a sub-range.
    Slice(PyRef<PySlice>),
}
691
692impl TryFromObject for RangeIndex {
693 fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
694 match_class!(match obj {
695 i @ PyInt => Ok(RangeIndex::Int(i)),
696 s @ PySlice => Ok(RangeIndex::Slice(s)),
697 obj => {
698 let val = obj.try_index(vm).map_err(|_| vm.new_type_error(format!(
699 "sequence indices be integers or slices or classes that override __index__ operator, not '{}'",
700 obj.class().name()
701 )))?;
702 Ok(RangeIndex::Int(val))
703 }
704 })
705 }
706}