diff --git a/sbe/__init__.py b/sbe/__init__.py index 6d3e769..73a1081 100644 --- a/sbe/__init__.py +++ b/sbe/__init__.py @@ -7,6 +7,7 @@ import lxml import lxml.etree + class PrimitiveType(enum.Enum): CHAR = 'char' UINT8 = 'uint8' @@ -281,10 +282,10 @@ def encode(self, vals: Iterable[str]) -> int: return bitstring.BitArray(v.name in vals for v in reversed(self.choices)).uint def decode(self, val: int) -> List[str]: - if isinstance(self.encodingType, SetEncodingType): - length = FORMAT_SIZES[PrimitiveType[self.encodingType.name]] * 8 - else: - length = FORMAT_SIZES[self.encodingType.primitiveType] * 8 + # if isinstance(self.encodingType, SetEncodingType): + # length = FORMAT_SIZES[PrimitiveType[self.encodingType.name]] * 8 + # else: + # length = FORMAT_SIZES[self.encodingType.primitiveType] * 8 return [c.name for c in self.choices if (1 << c.value) & val] @@ -559,7 +560,7 @@ def wrap(self, buf: Union[bytes, memoryview], header_only=False) -> WrappedMessa return WrappedMessage(buf, header, body) - def create_wrappers(self) -> WrappedMessage: + def create_wrappers(self): cursor = Cursor(0) pointers = {} @@ -573,7 +574,6 @@ def create_wrappers(self) -> WrappedMessage: _walk_fields_wrap(self, pointers, m.fields, cursor) self.message_wrappers[i] = WrappedComposite(m.name, pointers, None, 0) - def _unpack_format( schema: Schema, type_: Union[Field, Group, PrimitiveType, Type, RefType, Set, Enum, Composite], @@ -640,6 +640,7 @@ def _unpack_format( if isinstance(type_, Composite): return prefix + ''.join(_unpack_format(schema, t, '', buffer, buffer_cursor) for t in type_.types) + assert False, f"unreachable: {type_}" def _pack_format(_schema: Schema, composite: Composite): fmt = [] @@ -793,7 +794,7 @@ def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], if t == PrimitiveType.CHAR: vals.append(obj[f.name].encode()) else: - vals.append(f.type.nullValue) if obj[f.name] is None else vals.append(obj[f.name]) + vals.append(obj[f.name] if obj[f.name] is not None else f.type.nullValue) cursor.val += FORMAT_SIZES[t] elif isinstance(f.type, Set): @@ -822,7 +823,7 @@ def _walk_fields_encode(schema: Schema, fields: List[Union[Group, Field]], elif isinstance(f.type, PrimitiveType): fmt.append(FORMAT[f.type]) - vals.append(obj[f.name].encode()) if f.type == PrimitiveType.CHAR else vals.append(obj[f.name]) + vals.append(obj[f.name].encode() if f.type == PrimitiveType.CHAR else obj[f.name]) cursor.val += FORMAT_SIZES[f.type] else: assert 0