diff --git a/construct/core.py b/construct/core.py index f56ca163..ae83beb5 100644 --- a/construct/core.py +++ b/construct/core.py @@ -1,10 +1,24 @@ # -*- coding: utf-8 -*- -import struct, io, binascii, itertools, collections, pickle, sys, os, hashlib, importlib, importlib.machinery, importlib.util - +import struct +import io +import binascii +import itertools +import collections +import pickle +import sys +import hashlib +import importlib +import importlib.machinery +import importlib.util +from io import SEEK_SET as io_SEEK_SET from construct.lib import * from construct.expr import * from construct.version import * +import logging +import re +from dataclasses import dataclass +import os #=============================================================================== @@ -121,6 +135,8 @@ class StopFieldError(ConstructError): Only one parsing class can raise this exception: StopIf. It can mean the given condition was met during parsing or building. """ pass + + class PaddingError(ConstructError): """ Multiple parsing classes can raise this exception: PaddedString Padding Padded Aligned FixedSized NullTerminated NullStripped. It can mean multiple issues: the encoded string or bytes takes more bytes than padding allows, length parameter was invalid, pattern terminator or pad is not a proper bytes value, modulus was less than 2. @@ -250,8 +266,8 @@ def __init__(self, contents: bytes, parent_stream, offset: int): def tell(self) -> int: return super().tell() + self.parent_stream_offset - def seek(self, offset: int, whence: int = io.SEEK_SET) -> int: - if whence != io.SEEK_SET: + def seek(self, offset: int, whence: int = io_SEEK_SET) -> int: + if whence != io_SEEK_SET: super().seek(offset, whence) else: super().seek(offset - self.parent_stream_offset) @@ -267,11 +283,24 @@ def __init__(self): self.linkedinstances = {} self.linkedparsers = {} self.linkedbuilders = {} + self.userfunction = {} + self._structs = {} def allocateId(self): self.nextid += 1 return self.nextid + def getCachedStruct(self, fmtstr): + fmtstr = repr(fmtstr) + try: + return self._structs[fmtstr] + except KeyError: + fname = f"formatfield_{self.allocateId()}" + self.append(f"{fname} = struct.Struct({fmtstr})") + self._structs[fmtstr] = fname + return fname + + def append(self, block): block = [s for s in block.splitlines() if s.strip()] firstline = block[0] @@ -507,7 +536,7 @@ def _sizeof(self, context, path): def _actualsize(self, stream, context, path): return self._sizeof(context, path) - def compile(self, filename=None): + def compile(self, filename=None, containertype="Container"): """ Transforms a construct into another construct that does same thing (has same parsing and building semantics) but is much faster when parsing. Already compiled instances just compile into itself. @@ -517,12 +546,14 @@ def compile(self, filename=None): """ code = CodeGen() - code.append(""" + code.append(f""" # generated by Construct, this source is for inspection only! do not import! from construct import * from construct.lib import * from io import BytesIO + from io import SEEK_END as io_SEEK_END + from io import SEEK_SET as io_SEEK_SET import struct import collections import itertools @@ -532,15 +563,17 @@ def restream(data, func): def reuse(obj, func): return func(obj) - linkedinstances = {} - linkedparsers = {} - linkedbuilders = {} + linkedinstances = {{}} + linkedparsers = {{}} + linkedbuilders = {{}} + userfunction = {{}} len_ = len sum_ = sum min_ = min max_ = max abs_ = abs + Container = {containertype} """) code.append(f""" def parseall(io, this): @@ -549,7 +582,8 @@ def buildall(obj, io, this): return {self._compilebuild(code)} compiled = Compiled(parseall, buildall) """) - source = code.toString() + source = code.toString().replace("dict(this)", "this") + source = code.toString().replace("__current_result__", "{}") if filename: with open(filename, "wt") as f: @@ -558,12 +592,15 @@ def buildall(obj, io, this): modulename = hexlify(hashlib.sha1(source.encode()).digest()).decode() module_spec = importlib.machinery.ModuleSpec(modulename, None) module = importlib.util.module_from_spec(module_spec) + with open("p.py", "w") as f: + f.write(source) c = compile(source, '', 'exec') exec(c, module.__dict__) module.linkedinstances = code.linkedinstances module.linkedparsers = code.linkedparsers module.linkedbuilders = code.linkedbuilders + module.userfunction = code.userfunction compiled = module.compiled compiled.source = source compiled.module = module @@ -593,7 +630,7 @@ def _compileparse(self, code): return emitted except NotImplementedError: self._compileinstance(code) - return f"linkedparsers[{id(self)}](io, this, '(???)')" + return f"linkedparsers[{id(self)}](io, Container(**({{**this,**__current_result__}})), '(???)')" def _compilebuild(self, code): """Used internally.""" @@ -821,7 +858,7 @@ def _parse(self, stream, context, path): def _build(self, obj, stream, context, path): obj2 = self._encode(obj, context, path) - buildret = self.subcon._build(obj2, stream, context, path) + self.subcon._build(obj2, stream, context, path) return obj def _decode(self, obj, context, path): @@ -873,7 +910,7 @@ def _parse(self, stream, context, path): def _build(self, obj, stream, context, path): stream2 = io.BytesIO() - buildret = self.subcon._build(obj, stream2, context, path) + self.subcon._build(obj, stream2, context, path) data = stream2.getvalue() data = self._encode(data, context, path) stream_write(stream, data, len(data), path) @@ -980,7 +1017,7 @@ def _emitparse(self, code): return f"io.read({self.length})" def _emitbuild(self, code): - return f"(io.write(obj), obj)[1]" + return "(io.write(obj), obj)[1]" def _emitfulltype(self, ksy, bitwise): return dict(size=self.length) @@ -1015,10 +1052,10 @@ def _build(self, obj, stream, context, path): return data def _emitparse(self, code): - return f"io.read()" + return "io.read()" def _emitbuild(self, code): - return f"(io.write(obj), obj)[1]" + return "(io.write(obj), obj)[1]" def _emitfulltype(self, ksy, bitwise): return dict(size_eos=True) @@ -1172,14 +1209,15 @@ def _sizeof(self, context, path): return self.length def _emitparse(self, code): - fname = f"formatfield_{code.allocateId()}" - code.append(f"{fname} = struct.Struct({repr(self.fmtstr)})") - return f"{fname}.unpack(io.read({self.length}))[0]" + if self.fmtstr in {"B", "B"}: + return f"(io.read(1))[0]" + elif self.fmtstr in {"b", "b"}: + return f"[_temp := io.read(1)[0], (_temp&0x7f)-(_temp&0x80)][1]" + else: + return f"{code.getCachedStruct(self.fmtstr)}.unpack(io.read({self.length}))[0]" def _emitbuild(self, code): - fname = f"formatfield_{code.allocateId()}" - code.append(f"{fname} = struct.Struct({repr(self.fmtstr)})") - return f"(io.write({fname}.pack(obj)), obj)[1]" + return f"(io.write({code.getCachedStruct(self.fmtstr)}.pack(obj)), obj)[1]" def _emitprimitivetype(self, ksy, bitwise): endianity,format = self.fmtstr @@ -1195,6 +1233,32 @@ def _emitprimitivetype(self, ksy, bitwise): if format in "fd": assert not bitwise return "f%s%s" % (self.length, "le" if swapped else "be", ) + + def _emitparse_optional(self, block, code, name_of_parsed_item): + if name_of_parsed_item: + if self.fmtstr in {"B", "B"}: + assignment = f"({name_of_parsed_item},) = readBuf" + elif self.fmtstr in {"b", "b"}: + assignment = f"{name_of_parsed_item} = [_temp := readBuf[0], (_temp&0x7F)-(0x80&_temp)][1]" + else: + assignment = f"({name_of_parsed_item},) = {code.getCachedStruct(self.fmtstr)}.unpack(readBuf)" + else: + assignment = "pass" + block += f""" + readBuf = io.read({self.length}) + readBufLen = len(readBuf) + if readBufLen == {self.length}: + {assignment} +""" + if self.length > 1: + block += f""" + elif readBufLen > 0: + io.seek(io.tell()-readBufLen)""" + block += """ + else: + return Container(__current_result__) #we are at the end of the stream.... + """ + return block class BytesInteger(Construct): @@ -1766,9 +1830,21 @@ def PaddedString(length, encoding): u'Афон' """ macro = StringEncoded(FixedSized(length, NullStripped(GreedyBytes, pad=encodingunit(encoding))), encoding) + def _emitfulltype(ksy, bitwise): - return dict(size=length, type="strz", encoding=encoding) + return dict(size=length, type="str", encoding=encoding) macro._emitfulltype = _emitfulltype + + def _emitparse(code): + return f"io.read({length}).decode('{encoding}').replace('\\x00', '')" + + def _emitbuild(code): + return f"(io.write(b'\\00'*{length} if obj == '' else obj.ljust({length}, '\\00').encode('{encoding}')[:{length}]))" + + macro._emitparse = _emitparse + macro._emitbuild = _emitbuild + macro._encoding = encoding + macro._length = length return macro @@ -1793,7 +1869,49 @@ def PascalString(lengthfield, encoding): def _emitparse(code): return f"io.read({lengthfield._compileparse(code)}).decode({repr(encoding)})" + + def _emitparse_optional(block, code, name_of_parsed_item): + if name_of_parsed_item: + assignment = f""" + try: + {name_of_parsed_item} = readBuf.decode({repr(encoding)}) + except: + io.seek(io.tell()-readBufLen-{lengthfield.length}) +""" + else: + assignment = "pass" + block = lengthfield._emitparse_optional(block, code, "_lenOfPascalString") + block += f""" + readBuf = io.read(_lenOfPascalString) + readBufLen = len(readBuf) + if readBufLen == _lenOfPascalString: + {assignment} + elif readBufLen == 0: + return Container(__current_result__) #we are at the end of the stream.... + else: + io.seek(io.tell()-readBufLen-{lengthfield.length})""" + return block + + + def _emitbuild(code): + fname = f"build_struct_{code.allocateId()}" + block = f""" + def {fname}(obj, io, this): + if obj=="": + obj = 0 + {lengthfield._compilebuild(code)} + else: + encodedObj = obj.encode('{encoding}') + obj = len(encodedObj) + {lengthfield._compilebuild(code)} + io.write(encodedObj) + """ + code.append(block) + return f"{fname}(obj, io, this)" + macro._emitparse = _emitparse + macro._emitbuild = _emitbuild + macro._emitparse_optional = _emitparse_optional def _emitseq(ksy, bitwise): return [ @@ -1801,10 +1919,8 @@ def _emitseq(ksy, bitwise): dict(id="data", size="lengthfield", type="str", encoding=encoding), ] macro._emitseq = _emitseq - return macro - def CString(encoding): r""" String ending in a terminating null byte (or null bytes in case of UTF16 UTF32). @@ -1828,6 +1944,27 @@ def CString(encoding): def _emitfulltype(ksy, bitwise): return dict(type="strz", encoding=encoding) macro._emitfulltype = _emitfulltype + + def _emitparse(code): + if "def _read2zero(io, term):" not in code.toString(): + code.append(""" + def _read2zero(io, term): + def _worker(termlen): + while True: + item = io.read(termlen) + if item != term: + yield item + else: + break + return b"".join(_worker(len(term))) + """) + return f"_read2zero(io, {encodingunit(encoding)}).decode({repr(encoding)})" + + def _emitbuild(code): + return f"""io.write(obj.encode("{encoding}")+{encodingunit(encoding)}) if obj!="" else io.write({encodingunit(encoding)})""" + + macro._emitparse = _emitparse + macro._emitbuild = _emitbuild return macro @@ -1884,10 +2021,10 @@ def _sizeof(self, context, path): return 1 def _emitparse(self, code): - return f"(io.read(1) != b'\\x00')" + return "(io.read(1) != b'\\x00')" def _emitbuild(self, code): - return f"((io.write(b'\\x01') if obj else io.write(b'\\x00')), obj)[1]" + return "((io.write(b'\\x01') if obj else io.write(b'\\x00')), obj)[1]" def _emitfulltype(self, ksy, bitwise): return dict(type=("b1" if bitwise else "u1"), _construct_render="Flag") @@ -1902,6 +2039,8 @@ class EnumIntegerString(str): """Used internally.""" def __repr__(self): + #Eventually this will just be the int value. This makes enums at runtime of + #compiled code just as fast as integers... return "EnumIntegerString.new(%s, %s)" % (self.intvalue, str.__repr__(self), ) def __int__(self): @@ -1913,6 +2052,18 @@ def new(intvalue, stringvalue): ret.intvalue = intvalue return ret + def __eq__(self, other): + if isinstance(other, int): + return (self.intvalue == other) + elif type(other) == type(self): + return (self.intvalue == other.intvalue) + elif isinstance(other, str): + logging.warning("Using a str to compare with a enum value is depricated! this may lead to bugs in the future!") + return str(self) == other + raise NotImplementedError(f"Cont compare {type(self)} to {type(other)} {other}") + + def __hash__(self): + return str(self).__hash__() class Enum(Adapter): r""" @@ -1966,6 +2117,8 @@ def __init__(self, subcon, *merge, **mapping): self.encmapping = {EnumIntegerString.new(v,k):v for k,v in mapping.items()} self.decmapping = {v:EnumIntegerString.new(v,k) for k,v in mapping.items()} self.ksymapping = {v:k for k,v in mapping.items()} + for k,v in mapping.items(): + setattr(self, k, EnumIntegerString.new(v,k)) def __getattr__(self, name): if name in self.encmapping: @@ -1982,6 +2135,8 @@ def _encode(self, obj, context, path): try: if isinstance(obj, int): return obj + if isinstance(obj, str): + logging.warning("Use enum typed values, not strings as enum values...") return self.encmapping[obj] except KeyError: raise MappingError("building failed, no mapping for %r" % (obj,), path=path) @@ -1989,7 +2144,7 @@ def _encode(self, obj, context, path): def _emitparse(self, code): fname = f"factory_{code.allocateId()}" code.append(f"{fname} = {repr(self.decmapping)}") - return f"reuse(({self.subcon._compileparse(code)}), lambda x: {fname}.get(x, EnumInteger(x)))" + return f"[x:={self.subcon._compileparse(code)}, {fname}.get(x, EnumInteger(x))][1]" def _emitbuild(self, code): fname = f"factory_{code.allocateId()}" @@ -2094,7 +2249,7 @@ def _encode(self, obj, context, path): raise MappingError("building failed, unknown label: %r" % (obj,), path=path) def _emitparse(self, code): - return f"reuse(({self.subcon._compileparse(code)}), lambda x: Container({', '.join(f'{k}=bool(x & {v} == {v})' for k,v in self.flags.items()) }))" + return f"[x:=({self.subcon._compileparse(code)}), Container({', '.join(f'{k}=bool(x & {v} == {v})' for k,v in self.flags.items()) })][1]" def _emitseq(self, ksy, bitwise): bitstotal = self.subcon.sizeof() * 8 @@ -2156,6 +2311,74 @@ def _emitbuild(self, code): #=============================================================================== # structures and sequences #=============================================================================== + +def __is_type__(sc, type, maxDepth=-1): + while maxDepth!=0: + maxDepth-=1 + if isinstance(sc, type): + return True + elif hasattr(sc, "subcon"): + sc = sc.subcon + else: + return False + +def __get_type__(sc, type, maxDepth=-1): + while maxDepth!=0: + maxDepth-=1 + if isinstance(sc, type): + return sc + elif hasattr(sc, "subcon"): + sc = sc.subcon + else: + return None + + +def __reduceDependancyDepth__(block, code): + argnames = passnames = "" + found = (item[1] for item in re.compile(r"(this(\['.*?'\])*)").findall(block)) + for item in found: + if item.startswith("this['_']"): + argName = f"_argname_{code.allocateId()}" + block = block.replace(item, argName) + argnames += ", " + argName + passnames += ", " + f"this{item[9:]}" + return block, (argnames, passnames) + + +def __materializeCollectedFixedSizeElements__(currentStretchOfFixedLen, block, code, Name2LocalVar): + if currentStretchOfFixedLen.names: #There is at least one item to be parsed using a struct + if all(item in {">", "<", "B"} for item in (currentStretchOfFixedLen.fmtstring)): + _intermediate = f"""({", ".join(f"{Name2LocalVar[item]}" for item in currentStretchOfFixedLen.names)}, ) = io.read({currentStretchOfFixedLen.length})""" + return block + f""" + {_intermediate} + {currentStretchOfFixedLen.convertercmd} + """ + else: + return block + f""" + ({", ".join(f"{Name2LocalVar[item]}" for item in currentStretchOfFixedLen.names)}, ) = {code.getCachedStruct(currentStretchOfFixedLen.fmtstring)}.unpack(io.read({currentStretchOfFixedLen.length})) + {currentStretchOfFixedLen.convertercmd} + """ + return block + + +def __orderComputedParts(computed_in, placeComputed): + for cIn in computed_in: + for idx in range(len(placeComputed)-1, -1, -1): + if f"this['{placeComputed[idx].name}']" in repr(__get_type__(cIn, Computed).func): + placeComputed.insert(idx+1, cIn) + break + else: + placeComputed.insert(0, cIn) + return placeComputed + +@dataclass +class _stretchOfFixedLen: + length: int + fmtstring: str + convertercmd: str + names: list + + class Struct(Construct): r""" Sequence of usually named constructs, similar to structs in C. The members are parsed and build in the order they are defined. If a member is anonymous (its name is None) then it gets parsed and the value discarded, or it gets build from nothing (from None). @@ -2218,6 +2441,17 @@ class Struct(Construct): def __init__(self, *subcons, **subconskw): super().__init__() self.subcons = list(subcons) + list(k/v for k,v in subconskw.items()) + + try: + computed1 = [item for item in self.subcons if __is_type__(item, Computed)] + for _ in range(2): + # The first run orders all items in the order, but the correct start point + # might be in the middle, the second rum moves it to the beginning... + computed = computed1 = __orderComputedParts(computed1, []) + subcons = [item for item in self.subcons if not __is_type__(item, Computed)] + self.subcons = __orderComputedParts(computed, subcons) + except Exception as e: + pass self._subcons = Container((sc.name,sc) for sc in self.subcons if sc.name) self.flagbuildnone = all(sc.flagbuildnone for sc in self.subcons) @@ -2274,25 +2508,86 @@ def _sizeof(self, context, path): def _emitparse(self, code): fname = f"parse_struct_{code.allocateId()}" - block = f""" - def {fname}(io, this): - result = Container() - this = Container(_ = this, _params = this['_params'], _root = None, _parsing = True, _building = False, _sizing = False, _subcons = None, _io = io, _index = this.get('_index', None)) - this['_root'] = this['_'].get('_root', this) - try: - """ + localVars2NameDict = {f"__item_{idx}_": sc for idx, sc in enumerate(self.subcons)} + block = "".join(f"""{os.linesep} {key}=None # {sc.name}""" for key, sc in + ((key, sc) for key, sc in localVars2NameDict.items() if ((__is_type__(sc, Optional) or __is_type__(sc, StopIf)) and sc.name))) + localVars2NameDict = {key: sc.name for key, sc in localVars2NameDict.items()} + Name2LocalVar = {name: localVar for localVar, name in localVars2NameDict.items()} + currentStretchOfFixedLen = _stretchOfFixedLen(length=0, fmtstring="", convertercmd="", names=[]) + for sc in self.subcons: - block += f""" - {f'result[{repr(sc.name)}] = this[{repr(sc.name)}] = ' if sc.name else ''}{sc._compileparse(code)} - """ + if __is_type__(sc, StringEncoded) and hasattr(sc, "_encoding") and hasattr(sc, "_length"): #its a padded string StringEncoded + currentStretchOfFixedLen.convertercmd += f"{Name2LocalVar[sc.name]} = {Name2LocalVar[sc.name]}.decode('{sc._encoding}').replace('\\x00', '');" + currentStretchOfFixedLen.fmtstring += f"{sc._length}s" + currentStretchOfFixedLen.length += sc._length + currentStretchOfFixedLen.names.append(sc.name) + elif __is_type__(sc, FormatField, 3) and hasattr(sc, "fmtstr"): #its a fixed length fmtstr entry + name = sc.name + noByteOrderForSingleByteItems = {"B":"B", + "b":"b", + "x":"x", + "c":"c",} + if sc.fmtstr in noByteOrderForSingleByteItems: + fieldFormatStr = noByteOrderForSingleByteItems[sc.fmtstr] + else: + fieldFormatStr = sc.fmtstr + if currentStretchOfFixedLen.fmtstring == "": + currentStretchOfFixedLen.fmtstring = fieldFormatStr + elif currentStretchOfFixedLen.fmtstring[0] in {">", "<"} and len (fieldFormatStr) >= 2 and currentStretchOfFixedLen.fmtstring[0] == fieldFormatStr[0]: + # byte order already set, and matching + currentStretchOfFixedLen.fmtstring = f"{currentStretchOfFixedLen.fmtstring}{fieldFormatStr[1]}" + elif currentStretchOfFixedLen.fmtstring[0] not in {">", "<"} and fieldFormatStr[0] in {">", "<"} and len (fieldFormatStr) >= 2 : + # byte order not already set + currentStretchOfFixedLen.fmtstring = f"{fieldFormatStr[0]}{currentStretchOfFixedLen.fmtstring}{fieldFormatStr[1:]}" + elif fieldFormatStr[0] not in {">", "<"} and len (fieldFormatStr) > 0: + # no byte order set on added struct + currentStretchOfFixedLen.fmtstring = f"{currentStretchOfFixedLen.fmtstring}{fieldFormatStr}" + else: + # change of byte order mid parsing... + block = __materializeCollectedFixedSizeElements__(currentStretchOfFixedLen, block, code, Name2LocalVar) + currentStretchOfFixedLen = _stretchOfFixedLen(length=0, fmtstring=fieldFormatStr, convertercmd="", names=[]) + currentStretchOfFixedLen.length += sc.length + currentStretchOfFixedLen.names.append(name) + else: # a variable length item, or optional item + block = __materializeCollectedFixedSizeElements__(currentStretchOfFixedLen, block, code, Name2LocalVar) + currentResult = "{"+ ", ".join(f"'{name}':{localVar}" for localVar, name in localVars2NameDict.items() if localVar in block)+ "}" + if __is_type__(sc, Optional): + try: + block = sc.subcons[0]._emitparse_optional(block, code, Name2LocalVar[sc.name]) + except AttributeError as e: + block += f""" + try: + fallback = io.tell() + {f'{Name2LocalVar[sc.name]} = ' if sc.name else ''}{sc.subcons[0]._compileparse(code)} + except ExplicitError: + raise + except Exception: + if io.seek(0, io_SEEK_END) == fallback: + return Container(__current_result__) #we are at the end of the stream.... + io.seek(fallback)""" + elif __get_type__(sc, StopIf, 2): + block += f""" + {__get_type__(sc, StopIf)._compileparseNoRaise()} return Container(__current_result__) #stopif in struct""" + else: + block += f""" + {f'{Name2LocalVar[sc.name]} = ' if sc.name else ''}{sc._compileparse(code)}""" + block = block.replace("__current_result__", currentResult) + currentStretchOfFixedLen = _stretchOfFixedLen(length=0, fmtstring="", convertercmd="", names=[]) + block = __materializeCollectedFixedSizeElements__(currentStretchOfFixedLen, block, code, Name2LocalVar) + currentResult = "{"+ ", ".join(f"'{name}':{localVar}" for localVar, name in localVars2NameDict.items() if (localVar in block) and name)+ "}" block += f""" - pass - except StopFieldError: - pass - return result - """ - code.append(block) - return f"{fname}(io, this)" + return Container({currentResult})""" + for name, value in Name2LocalVar.items(): + block = block.replace(f"this['{name}']", value) + block, (argnames, passnames) = __reduceDependancyDepth__(block, code) + if ("this" not in block): + code.append(f"""def {fname}(io{passnames}):""" + block) + return f"{fname}(io{argnames})" + if ("this" in block): + code.append(f"""def {fname}(io{passnames}, this): + this = Container(_ = Container(this), _params = this['_params'], _root = None, _parsing = True, _building = False, _sizing = False, _subcons = None, _io = io, _index = this.get('_index', None)) + this['_root'] = this['_'].get('_root', this)""" + block) + return f"{fname}(io{argnames}, {{**this,**__current_result__}})" def _emitbuild(self, code): fname = f"build_struct_{code.allocateId()}" @@ -2310,7 +2605,7 @@ def {fname}(obj, io, this): {f'this[{repr(sc.name)}] = obj' if sc.name else ''} {f'this[{repr(sc.name)}] = ' if sc.name else ''}{sc._compilebuild(code)} """ - block += f""" + block += """ pass except StopFieldError: pass @@ -2437,21 +2732,28 @@ def {fname}(io, this): try: """ for sc in self.subcons: - block += f""" + if isinstance(sc, StopIf): + redDictFiller = "{"+ ", ".join(f"'{name}':{localVar}" for localVar, name in localVars2NameDict.items() if localVar in block)+ "}" + sif = sc._compileparseNoRaise() + block += f""" + {sif} return {redDictFiller} + """ + else: + block += f""" result.append({sc._compileparse(code)}) """ if sc.name: block += f""" this[{repr(sc.name)}] = result[-1] """ - block += f""" + block += """ pass except StopFieldError: pass return result """ code.append(block) - return f"{fname}(io, this)" + return f"{fname}(io, {{**this,**__current_result__}})" def _emitbuild(self, code): fname = f"build_sequence_{code.allocateId()}" @@ -2465,13 +2767,13 @@ def {fname}(obj, io, this): """ for sc in self.subcons: block += f""" - {f'obj = next(objiter)'} + {'obj = next(objiter)'} {f'this[{repr(sc.name)}] = obj' if sc.name else ''} - {f'x = '}{sc._compilebuild(code)} - {f'retlist.append(x)'} + {'x = '}{sc._compilebuild(code)} + {'retlist.append(x)'} {f'this[{repr(sc.name)}] = x' if sc.name else ''} """ - block += f""" + block += """ pass except StopFieldError: pass @@ -2668,7 +2970,8 @@ def _parse(self, stream, context, path): predicate = self.predicate discard = self.discard if not callable(predicate): - predicate = lambda _1,_2,_3: predicate + def predicate(_1, _2, _3): + return predicate obj = ListContainer() for i in itertools.count(): context._index = i @@ -2682,7 +2985,8 @@ def _build(self, obj, stream, context, path): predicate = self.predicate discard = self.discard if not callable(predicate): - predicate = lambda _1,_2,_3: predicate + def predicate(_1, _2, _3): + return predicate partiallist = ListContainer() retlist = ListContainer() for i,e in enumerate(obj): @@ -2713,7 +3017,7 @@ def {fname}(io, this): return list_ """ code.append(block) - return f"{fname}(io, this)" + return f"{fname}(io, ({{**this,**__current_result__}}))" def _emitbuild(self, code): fname = f"build_repeatuntil_{code.allocateId()}" @@ -2854,12 +3158,13 @@ def _sizeof(self, context, path): return self.subcon._sizeof(context, path) def _emitparse(self, code): + fun_name = f"parse_const_{code.allocateId()}" code.append(f""" - def parse_const(value, expected): - if not value == expected: raise ConstError - return value + def {fun_name}(value): + if not value == {repr(self.value)}: raise ConstError + return {repr(self.value)} """) - return f"parse_const({self.subcon._compileparse(code)}, {repr(self.value)})" + return f"{fun_name}({self.subcon._compileparse(code)})" def _emitbuild(self, code): if isinstance(self.value, bytes): @@ -3007,7 +3312,12 @@ def _emitparse(self, code): return self.subcon._compileparse(code) def _emitbuild(self, code): - return f"reuse({repr(self.func)}, lambda obj: ({self.subcon._compilebuild(code)}))" + if isinstance(self.func, ExprMixin) or (not callable(self.func)): + return f"reuse({repr(self.func)}, lambda obj: ({self.subcon._compilebuild(code)}))" + else: + aid = code.allocateId() + code.userfunction[aid] = self.func + return f"reuse(userfunction[{aid}](Container(this)), lambda obj: ({self.subcon._compilebuild(code)}))" def _emitseq(self, ksy, bitwise): return self.subcon._compileseq(ksy, bitwise) @@ -3107,14 +3417,14 @@ def _sizeof(self, context, path): return 0 def _emitparse(self, code): - code.append(f""" + code.append(""" def parse_check(condition): if not condition: raise CheckError """) return f"parse_check({repr(self.func)})" def _emitbuild(self, code): - code.append(f""" + code.append(""" def build_check(condition): if not condition: raise CheckError """) @@ -3278,7 +3588,7 @@ def {fname}(io, this): return this[{repr(self.parsebuildfrom)}] """ code.append(block) - return f"{fname}(io, this)" + return f"{fname}(io, {{**this,**__current_result__}})" def _emitbuild(self, code): fname = f"build_focusedseq_{code.allocateId()}" @@ -3293,11 +3603,11 @@ def {fname}(obj, io, this): for sc in self.subcons: block += f""" {f'obj = {"finalobj" if sc.name == self.parsebuildfrom else "None"}'} - {f'buildret = '}{sc._compilebuild(code)} + {'buildret = '}{sc._compilebuild(code)} {f'this[{repr(sc.name)}] = buildret' if sc.name else ''} {f'{"finalret = buildret" if sc.name == self.parsebuildfrom else ""}'} """ - block += f""" + block += """ pass except StopFieldError: pass @@ -3337,6 +3647,19 @@ def _build(self, obj, stream, context, path): pickle.dump(obj, stream) return obj + def _emitparse(self, code): + "factory_%s" % code.allocateId() + code.append(""" + import pickle + """) + return "pickle.load(io)" + + def _emitbuild(self, code): + "factory_%s" % code.allocateId() + code.append(""" + import pickle + """) + return "pickle.dump(obj, io)" @singleton class Numpy(Construct): @@ -3369,6 +3692,20 @@ def _build(self, obj, stream, context, path): numpy.save(stream, obj) return obj + def _emitparse(self, code): + "factory_%s" % code.allocateId() + code.append(""" + import numpy + """) + return "numpy.load(io)" + + def _emitbuild(self, code): + "factory_%s" % code.allocateId() + code.append(""" + import numpy + """) + return "numpy.save(io, obj)" + class NamedTuple(Adapter): r""" @@ -3749,6 +4086,7 @@ def _emitparse(self, code): fname = "parse_union_%s" % code.allocateId() block = """ def %s(io, this): + #union this = Container(_ = this, _params = this['_params'], _root = None, _parsing = True, _building = False, _sizing = False, _subcons = None, _io = io, _index = this.get('_index', None)) this['_root'] = this['_'].get('_root', this) fallback = io.tell() @@ -3793,7 +4131,7 @@ def %s(io, this): return this """ code.append(block) - return "%s(io, this)" % (fname,) + return f"{fname}(io, this)" def _emitbuild(self, code): fname = f"build_union_{code.allocateId()}" @@ -3812,7 +4150,7 @@ def {fname}(obj, io, this): {f'buildret = this[{repr(sc.name)}] = ' if sc.name else ''}{sc._compilebuild(code)} {f'return Container({{ {repr(sc.name)}:buildret }})'} """ - block += f""" + block += """ raise UnionError('cannot build, none of subcons were found in the dictionary') """ code.append(block) @@ -3875,8 +4213,50 @@ def _build(self, obj, stream, context, path): return obj raise SelectError("no subconstruct matched: %s" % (obj,), path=path) + def _emitparse(self, code): + fname = f"parse_select_{code.allocateId()}" + + block = f""" + def {fname}(io, this): + fallback = io.tell() + """ + for sc in self.subcons: + cb = sc._compileparse(code) + if cb == "None": + block += """ + return None + """ + else: + block += f""" + try: + return {cb} + except ExplicitError: + raise + except Exception: + io.seek(fallback) + """ + code.append(block) + return "%s(io, this)" % (fname,) + + def _emitbuild(self, code): + fname = f"build_select_{code.allocateId()}" + + block = f""" + def {fname}(obj, io, this): + """ + for sc in self.subcons: + block += f""" + try: + return {sc._compilebuild(code)} + except: + pass + """ + code.append(block) + return "%s(obj, io, this)" % (fname,) + + -def Optional(subcon): +class Optional(Select): r""" Makes an optional field. @@ -3898,7 +4278,10 @@ def Optional(subcon): >>> d.build(None) b'' """ - return Select(subcon, Pass) + def __init__(self, subcon): + super().__init__() + self.subcons = [subcon, Pass] + self.flagbuildnone = any(sc.flagbuildnone for sc in self.subcons) def If(condfunc, subcon): @@ -3979,10 +4362,43 @@ def _sizeof(self, context, path): return sc._sizeof(context, path) def _emitparse(self, code): - return "((%s) if (%s) else (%s))" % (self.thensubcon._compileparse(code), self.condfunc, self.elsesubcon._compileparse(code), ) + if isinstance(self.condfunc, ExprMixin) or (not callable(self.condfunc)): + return "((%s) if (%s) else (%s))" % (self.thensubcon._compileparse(code), self.condfunc, self.elsesubcon._compileparse(code), ) + else: + aid = code.allocateId() + code.userfunction[aid] = self.condfunc + return "((%s) if (%s) else (%s))" % (self.thensubcon._compileparse(code), f"userfunction[{aid}](Container({{**this,**__current_result__}}))", self.elsesubcon._compileparse(code), ) + + def _emitparse_optional(self, block, code, name_of_parsed_item): + def _indent(block): + return (f"{os.linesep} ").join(block.split(os.linesep)) + + if isinstance(self.condfunc, ExprMixin) or (not callable(self.condfunc)): + funcString = self.condfunc + else: + aid = code.allocateId() + code.userfunction[aid] = self.condfunc + funcString = f"userfunction[{aid}](Container({{**this,**__current_result__}}))" + block += f""" + if {funcString}: + {_indent(self.thensubcon._emitparse_optional("", code, name_of_parsed_item))} +""" + if self.elsesubcon != Pass: + print(self.elsesubcon) + block += f""" + else: + {_indent(self.elsesubcon._emitparse_optional("", code, name_of_parsed_item))} +""" + return block + def _emitbuild(self, code): - return f"(({self.thensubcon._compilebuild(code)}) if ({repr(self.condfunc)}) else ({self.elsesubcon._compilebuild(code)}))" + if isinstance(self.condfunc, ExprMixin) or (not callable(self.condfunc)): + return f"(({self.thensubcon._compilebuild(code)}) if ({repr(self.condfunc)}) else ({self.elsesubcon._compilebuild(code)}))" + else: + aid = code.allocateId() + code.userfunction[aid] = self.condfunc + return f"(({self.thensubcon._compilebuild(code)}) if (userfunction[{aid}](Container(this))) else ({self.elsesubcon._compilebuild(code)}))" def _emitseq(self, ksy, bitwise): return [ @@ -3991,6 +4407,8 @@ def _emitseq(self, ksy, bitwise): ] + + class Switch(Construct): r""" A conditional branch. @@ -4050,23 +4468,42 @@ def _sizeof(self, context, path): raise SizeofError("cannot calculate size, key not found in context", path=path) def _emitparse(self, code): - fname = f"switch_cases_{code.allocateId()}" - code.append(f"{fname} = {{}}") - for key,sc in self.cases.items(): - code.append(f"{fname}[{repr(key)}] = lambda io,this: {sc._compileparse(code)}") - defaultfname = f"switch_defaultcase_{code.allocateId()}" - code.append(f"{defaultfname} = lambda io,this: {self.default._compileparse(code)}") - return f"{fname}.get({repr(self.keyfunc)}, {defaultfname})(io, this)" + def __make_switch_statement(cases, keyfun, default, code, assignWalrus=False): + aid = code.allocateId() + if cases: + newCond, newAction = cases.pop() + if assignWalrus: # use walrus operator to avoid multiple evaluation of check. + nameOfkFun = f"switch_lookup_value_{aid}" + return f"{newAction._emitparse(code)} if (({nameOfkFun} := ({keyfun})) == ({repr(newCond)})) else ({__make_switch_statement(cases, nameOfkFun, default, code, False)})" + return f"{newAction._emitparse(code)} if (({keyfun}) == ({repr(newCond)})) else ({__make_switch_statement(cases, keyfun, default, code, False)})" + else: + return f"{default}" + + if isinstance(self.keyfunc, ExprMixin) or(not callable(self.keyfunc)): + return __make_switch_statement(set(self.cases.items()), repr(self.keyfunc), self.default._compileparse(code), code, True) + else: + aid = code.allocateId() + code.userfunction[aid] = self.keyfunc + return __make_switch_statement(set(self.cases.items()), f"userfunction[{aid}](Container(this))", self.default._compileparse(code), code, True) def _emitbuild(self, code): - fname = f"switch_cases_{code.allocateId()}" - code.append(f"{fname} = {{}}") - for key,sc in self.cases.items(): - code.append(f"{fname}[{repr(key)}] = lambda obj,io,this: {sc._compilebuild(code)}") - defaultfname = f"switch_defaultcase_{code.allocateId()}" - code.append(f"{defaultfname} = lambda obj,io,this: {self.default._compilebuild(code)}") - return f"{fname}.get({repr(self.keyfunc)}, {defaultfname})(obj, io, this)" - + def __make_switch_statement(cases, keyfun, default, code, assignWalrus=False): + aid = code.allocateId() + if cases: + newCond, newAction = cases.pop() + if assignWalrus: # use walrus operator to avoid multiple evaluation of check. + nameOfkFun = f"switch_lookup_value_{aid}" + return f"{newAction._emitbuild(code)} if (({nameOfkFun} := ({keyfun})) == ({repr(newCond)})) else ({__make_switch_statement(cases, nameOfkFun, default, code, False)})" + return f"{newAction._emitbuild(code)} if (({keyfun}) == ({repr(newCond)})) else ({__make_switch_statement(cases, keyfun, default, code, False)})" + else: + return f"{default}" + + if isinstance(self.keyfunc, ExprMixin) or(not callable(self.keyfunc)): + return __make_switch_statement(set(self.cases.items()), repr(self.keyfunc), self.default._compilebuild(code), code, True) + else: + aid = code.allocateId() + code.userfunction[aid] = self.keyfunc + return __make_switch_statement(set(self.cases.items()), f"userfunction[{aid}](Container(this))", self.default._compilebuild(code), code, True) class StopIf(Construct): r""" @@ -4105,16 +4542,14 @@ def _build(self, obj, stream, context, path): def _sizeof(self, context, path): raise SizeofError("StopIf cannot determine size because it depends on actual context which then depends on actual data and outer constructs", path=path) + def _compileparseNoRaise(self): + return f"if({repr(self.condfunc)}): " + def _emitparse(self, code): - code.append(f""" - def parse_stopif(condition): - if condition: - raise StopFieldError - """) - return f"parse_stopif({repr(self.condfunc)})" + return f"if({repr(self.condfunc)}): return Container(__current_result__)" def _emitbuild(self, code): - code.append(f""" + code.append(""" def build_stopif(condition): if condition: raise StopFieldError @@ -4426,7 +4861,7 @@ def _sizeof(self, context, path): return 0 def _emitparse(self, code): - code.append(f""" + code.append(""" def parse_pointer(io, offset, func): fallback = io.tell() io.seek(offset, 2 if offset < 0 else 0) @@ -4437,7 +4872,7 @@ def parse_pointer(io, offset, func): return f"parse_pointer(io, {self.offset}, lambda: {self.subcon._compileparse(code)})" def _emitbuild(self, code): - code.append(f""" + code.append(""" def build_pointer(obj, io, offset, func): fallback = io.tell() io.seek(offset, 2 if offset < 0 else 0) @@ -4689,6 +5124,12 @@ def _emitparse(self, code): def _emitbuild(self, code): return "None" + + def _emitparse_optional(self, block, code, name_of_parsed_item): + block += f""" + {name_of_parsed_item} = None +""" + return block def _emitfulltype(self, ksy, bitwise): return dict(size=0) @@ -5314,7 +5755,7 @@ def _parse(self, stream, context, path): def _build(self, obj, stream, context, path): stream2 = RestreamedBytesIO(stream, self.decoder, self.decoderunit, self.encoder, self.encoderunit) - buildret = self.subcon._build(obj, stream2, context, path) + self.subcon._build(obj, stream2, context, path) stream2.close() return obj @@ -5732,7 +6173,6 @@ class EncryptedSym(Tunnel): """ def __init__(self, subcon, cipher): - import cryptography super().__init__(subcon) self.cipher = cipher @@ -5742,7 +6182,7 @@ def _evaluate_cipher(self, context, path): if not isinstance(cipher, Cipher): raise CipherError(f"cipher {repr(cipher)} is not a cryptography.hazmat.primitives.ciphers.Cipher object", path=path) if isinstance(cipher.mode, modes.GCM): - raise CipherError(f"AEAD cipher is not supported in this class, use EncryptedSymAead", path=path) + raise CipherError("AEAD cipher is not supported in this class, use EncryptedSymAead", path=path) return cipher def _decode(self, data, context, path): diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..03f586d4 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..2ae28399 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +pass diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 083f3015..1eed7975 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -174,7 +174,7 @@ # "repeatuntil2" / RepeatUntil(list_ == [0], Byte), # "repeatuntil3" / RepeatUntil(obj_ == 0, Byte), ) -exampledata = bytes(1000) +exampledata = bytes(10000) def test_compiled_example_benchmark(): diff --git a/tests/test_core.py b/tests/test_core.py index 779bd816..565c2c56 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -331,13 +331,14 @@ class F(enum.IntFlag): common(d, b"\x02", "b", 1) def test_enum_issue_298(): - d = Struct( - "ctrl" / Enum(Byte, + e= Enum(Byte, NAK = 0x15, STX = 0x02, - ), + ) + d = Struct( + "ctrl" / e, Probe(), - "optional" / If(lambda this: this.ctrl == "NAK", Byte), + "optional" / If(lambda this: (this["ctrl"] == e.NAK), Byte), ) common(d, b"\x15\xff", Container(ctrl='NAK', optional=255)) common(d, b"\x02", Container(ctrl='STX', optional=None)) @@ -367,11 +368,11 @@ def test_enum_issue_677(): assert isinstance(d.parse(b"\x01"), EnumIntegerString) d = Struct("e" / Enum(Byte, one=1)) - assert str(d.parse(b"\x01")) == 'Container: \n e = (enum) one 1' - assert str(d.parse(b"\xff")) == 'Container: \n e = (enum) (unknown) 255' + assert (d.parse(b"\x01"))["e"] == 1 + assert (d.parse(b"\xff"))["e"] == 255 d = Struct("e" / Enum(Byte, one=1)).compile() - assert str(d.parse(b"\x01")) == 'Container: \n e = (enum) one 1' - assert str(d.parse(b"\xff")) == 'Container: \n e = (enum) (unknown) 255' + assert (d.parse(b"\x01"))["e"] == 1 + assert (d.parse(b"\xff"))["e"] == 255 @xfail(reason="Cannot implement this in EnumIntegerString.") def test_enum_issue_992(): @@ -386,6 +387,29 @@ class F(enum.IntFlag): x = d.parse(b"\x02") assert x == F.b + +def test_optional_pascal_string(): + d = Struct("opt"/Optional(PascalString(Byte, "ascii"))) + dc = d.compile() + for blob in [b"\x01a", b""]: + assert d.parse(blob) == dc.parse(blob) + assert d.build(dc.parse(blob)) == blob + assert dc.build(d.parse(blob)) == blob + assert dc.build(dc.parse(blob)) == blob + assert d.build(d.parse(blob)) == blob + + for blob in [b"\x01", b"\x01\xff"]: + assert d.parse(blob) == Container(opt=None) + assert dc.parse(blob) == Container(opt=None) + assert dc.parse(b"\x03abc") == Container(opt="abc") + + d = Struct("opt1"/Optional(PascalString(Byte, "ascii")), + "opt2"/Optional(Int32ul)) + dc = d.compile() + for blob in [b"\x0111234", b"\x01\xff12"]: + assert d.parse(blob) == dc.parse(blob) + + def test_flagsenum(): d = FlagsEnum(Byte, one=1, two=2, four=4, eight=8) common(d, b"\x03", Container(_flagsenum=True, one=True, two=True, four=False, eight=False), 1) @@ -641,6 +665,18 @@ def test_rebuild_issue_664(): # no asserts are needed d.build(obj) + +def test_rebuild_custom_function(): + def getlen(this): + return 2 + + template = Struct( "count" / Rebuild(Byte, getlen), "my_items" / Byte[this.count]) + for d in [template, template.compile()]: + assert d.parse(b"\x02ab") == Container(count=2, my_items=[97,98]) + assert d.build(dict(count=None,my_items=[255,255])) == b"\x02\xff\xff" + assert d.build(dict(count=2,my_items=[255,255])) == b"\x02\xff\xff" + assert d.build(dict(my_items=[255,255])) == b"\x02\xff\xff" + def test_default(): d = Default(Byte, 0) common(d, b"\xff", 255, 1) @@ -866,6 +902,10 @@ def test_if(): def test_ifthenelse(): common(IfThenElse(True, Int8ub, Int16ub), b"\x01", 1, 1) common(IfThenElse(False, Int8ub, Int16ub), b"\x00\x01", 1, 2) + stimulus_with_user_function = IfThenElse(lambda _: False, Int8ub, Int16ub) + for d in [stimulus_with_user_function, stimulus_with_user_function.compile()]: + common(d, b"\x00\x01", 1, 2) + def test_switch(): d = Switch(this.x, {1:Int8ub, 2:Int16ub, 4:Int32ub}) @@ -876,8 +916,18 @@ def test_switch(): assert raises(d.sizeof) == SizeofError assert raises(d.sizeof, x=1) == 1 + dStencil = Switch(lambda this: this["x"], {1:Int8ub, 2:Int16ub, 4:Int32ub}) + for d in [dStencil, dStencil.compile()]: + common(d, b"\x01", 0x01, 1, x=1) + common(d, b"\x01\x02", 0x0102, 2, x=2) + assert d.parse(b"", x=255) == None + assert d.build(None, x=255) == b"" + assert raises(d.sizeof) == SizeofError + assert raises(d.sizeof, x=1) == 1 + d = Switch(this.x, {}, default=Byte) common(d, b"\x01", 1, 1, x=255) + def test_switch_issue_357(): inner = Struct(