1pub(crate) use decl::make_module;
2
3#[pymodule(name = "marshal")]
4mod decl {
5 use crate::builtins::code::{CodeObject, Literal, PyObjBag};
6 use crate::class::StaticType;
7 use crate::{
8 builtins::{
9 PyBool, PyByteArray, PyBytes, PyCode, PyComplex, PyDict, PyEllipsis, PyFloat,
10 PyFrozenSet, PyInt, PyList, PyNone, PySet, PyStopIteration, PyStr, PyTuple,
11 },
12 convert::ToPyObject,
13 function::{ArgBytesLike, OptionalArg},
14 object::AsObject,
15 protocol::PyBuffer,
16 PyObjectRef, PyResult, TryFromObject, VirtualMachine,
17 };
18 use malachite_bigint::BigInt;
19 use num_complex::Complex64;
20 use num_traits::Zero;
21 use rustpython_compiler_core::marshal;
22
23 #[pyattr(name = "version")]
24 use marshal::FORMAT_VERSION;
25
26 pub struct DumpError;
27
28 impl marshal::Dumpable for PyObjectRef {
29 type Error = DumpError;
30 type Constant = Literal;
31 fn with_dump<R>(
32 &self,
33 f: impl FnOnce(marshal::DumpableValue<'_, Self>) -> R,
34 ) -> Result<R, Self::Error> {
35 use marshal::DumpableValue::*;
36 if self.is(PyStopIteration::static_type()) {
37 return Ok(f(StopIter));
38 }
39 let ret = match_class!(match self {
40 PyNone => f(None),
41 PyEllipsis => f(Ellipsis),
42 ref pyint @ PyInt => {
43 if self.class().is(PyBool::static_type()) {
44 f(Boolean(!pyint.as_bigint().is_zero()))
45 } else {
46 f(Integer(pyint.as_bigint()))
47 }
48 }
49 ref pyfloat @ PyFloat => {
50 f(Float(pyfloat.to_f64()))
51 }
52 ref pycomplex @ PyComplex => {
53 f(Complex(pycomplex.to_complex64()))
54 }
55 ref pystr @ PyStr => {
56 f(Str(pystr.as_str()))
57 }
58 ref pylist @ PyList => {
59 f(List(&pylist.borrow_vec()))
60 }
61 ref pyset @ PySet => {
62 let elements = pyset.elements();
63 f(Set(&elements))
64 }
65 ref pyfrozen @ PyFrozenSet => {
66 let elements = pyfrozen.elements();
67 f(Frozenset(&elements))
68 }
69 ref pytuple @ PyTuple => {
70 f(Tuple(pytuple.as_slice()))
71 }
72 ref pydict @ PyDict => {
73 let entries = pydict.into_iter().collect::<Vec<_>>();
74 f(Dict(&entries))
75 }
76 ref bytes @ PyBytes => {
77 f(Bytes(bytes.as_bytes()))
78 }
79 ref bytes @ PyByteArray => {
80 f(Bytes(&bytes.borrow_buf()))
81 }
82 ref co @ PyCode => {
83 f(Code(co))
84 }
85 _ => return Err(DumpError),
86 });
87 Ok(ret)
88 }
89 }
90
91 #[pyfunction]
92 fn dumps(
93 value: PyObjectRef,
94 _version: OptionalArg<i32>,
95 vm: &VirtualMachine,
96 ) -> PyResult<PyBytes> {
97 use marshal::Dumpable;
98 let mut buf = Vec::new();
99 value
100 .with_dump(|val| marshal::serialize_value(&mut buf, val))
101 .unwrap_or_else(Err)
102 .map_err(|DumpError| {
103 vm.new_not_implemented_error(
104 "TODO: not implemented yet or marshal unsupported type".to_owned(),
105 )
106 })?;
107 Ok(PyBytes::from(buf))
108 }
109
110 #[pyfunction]
111 fn dump(
112 value: PyObjectRef,
113 f: PyObjectRef,
114 version: OptionalArg<i32>,
115 vm: &VirtualMachine,
116 ) -> PyResult<()> {
117 let dumped = dumps(value, version, vm)?;
118 vm.call_method(&f, "write", (dumped,))?;
119 Ok(())
120 }
121
122 #[derive(Copy, Clone)]
123 struct PyMarshalBag<'a>(&'a VirtualMachine);
124
125 impl<'a> marshal::MarshalBag for PyMarshalBag<'a> {
126 type Value = PyObjectRef;
127 fn make_bool(&self, value: bool) -> Self::Value {
128 self.0.ctx.new_bool(value).into()
129 }
130 fn make_none(&self) -> Self::Value {
131 self.0.ctx.none()
132 }
133 fn make_ellipsis(&self) -> Self::Value {
134 self.0.ctx.ellipsis()
135 }
136 fn make_float(&self, value: f64) -> Self::Value {
137 self.0.ctx.new_float(value).into()
138 }
139 fn make_complex(&self, value: Complex64) -> Self::Value {
140 self.0.ctx.new_complex(value).into()
141 }
142 fn make_str(&self, value: &str) -> Self::Value {
143 self.0.ctx.new_str(value).into()
144 }
145 fn make_bytes(&self, value: &[u8]) -> Self::Value {
146 self.0.ctx.new_bytes(value.to_vec()).into()
147 }
148 fn make_int(&self, value: BigInt) -> Self::Value {
149 self.0.ctx.new_int(value).into()
150 }
151 fn make_tuple(&self, elements: impl Iterator<Item = Self::Value>) -> Self::Value {
152 let elements = elements.collect();
153 self.0.ctx.new_tuple(elements).into()
154 }
155 fn make_code(&self, code: CodeObject) -> Self::Value {
156 self.0.ctx.new_code(code).into()
157 }
158 fn make_stop_iter(&self) -> Result<Self::Value, marshal::MarshalError> {
159 Ok(self.0.ctx.exceptions.stop_iteration.to_owned().into())
160 }
161 fn make_list(
162 &self,
163 it: impl Iterator<Item = Self::Value>,
164 ) -> Result<Self::Value, marshal::MarshalError> {
165 Ok(self.0.ctx.new_list(it.collect()).into())
166 }
167 fn make_set(
168 &self,
169 it: impl Iterator<Item = Self::Value>,
170 ) -> Result<Self::Value, marshal::MarshalError> {
171 let vm = self.0;
172 let set = PySet::new_ref(&vm.ctx);
173 for elem in it {
174 set.add(elem, vm).unwrap()
175 }
176 Ok(set.into())
177 }
178 fn make_frozenset(
179 &self,
180 it: impl Iterator<Item = Self::Value>,
181 ) -> Result<Self::Value, marshal::MarshalError> {
182 let vm = self.0;
183 Ok(PyFrozenSet::from_iter(vm, it).unwrap().to_pyobject(vm))
184 }
185 fn make_dict(
186 &self,
187 it: impl Iterator<Item = (Self::Value, Self::Value)>,
188 ) -> Result<Self::Value, marshal::MarshalError> {
189 let vm = self.0;
190 let dict = vm.ctx.new_dict();
191 for (k, v) in it {
192 dict.set_item(&*k, v, vm).unwrap()
193 }
194 Ok(dict.into())
195 }
196 type ConstantBag = PyObjBag<'a>;
197 fn constant_bag(self) -> Self::ConstantBag {
198 PyObjBag(&self.0.ctx)
199 }
200 }
201
202 #[pyfunction]
203 fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
204 let buf = pybuffer.as_contiguous().ok_or_else(|| {
205 vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned())
206 })?;
207 marshal::deserialize_value(&mut &buf[..], PyMarshalBag(vm)).map_err(|e| match e {
208 marshal::MarshalError::Eof => vm.new_exception_msg(
209 vm.ctx.exceptions.eof_error.to_owned(),
210 "marshal data too short".to_owned(),
211 ),
212 marshal::MarshalError::InvalidBytecode => {
213 vm.new_value_error("Couldn't deserialize python bytecode".to_owned())
214 }
215 marshal::MarshalError::InvalidUtf8 => {
216 vm.new_value_error("invalid utf8 in marshalled string".to_owned())
217 }
218 marshal::MarshalError::InvalidLocation => {
219 vm.new_value_error("invalid location in marshalled object".to_owned())
220 }
221 marshal::MarshalError::BadType => {
222 vm.new_value_error("bad marshal data (unknown type code)".to_owned())
223 }
224 })
225 }
226
227 #[pyfunction]
228 fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
229 let read_res = vm.call_method(&f, "read", ())?;
230 let bytes = ArgBytesLike::try_from_object(vm, read_res)?;
231 loads(PyBuffer::from(bytes), vm)
232 }
233}