Module diem.serde_binary
Module describing the "binary" serialization formats.
Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future.
Expand source code
# Copyright (c) Facebook, Inc. and its affiliates
# SPDX-License-Identifier: MIT OR Apache-2.0
"""
Module describing the "binary" serialization formats.
Note: This internal module is currently only meant to share code between the BCS and bincode formats. Internal APIs could change in the future.
"""
import dataclasses
import collections
import io
import typing
from typing import get_type_hints
from diem import serde_types as st
@dataclasses.dataclass
class BinarySerializer:
    """Serialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they sort map entries (or not). Those three operations, together with the
    float and char primitives, raise `NotImplementedError` here and are left to subclasses.
    """

    # Buffer that receives the serialized bytes.
    output: io.BytesIO
    # Remaining struct/variant nesting levels allowed; `None` disables the check.
    container_depth_budget: typing.Optional[int]
    # Maps each primitive type to the bound method serializing it (set in `__post_init__`).
    primitive_type_serializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the primitive-type dispatch table used by `serialize_any`."""
        self.primitive_type_serializer = {
            bool: self.serialize_bool,
            st.uint8: self.serialize_u8,
            st.uint16: self.serialize_u16,
            st.uint32: self.serialize_u32,
            st.uint64: self.serialize_u64,
            st.uint128: self.serialize_u128,
            st.int8: self.serialize_i8,
            st.int16: self.serialize_i16,
            st.int32: self.serialize_i32,
            st.int64: self.serialize_i64,
            st.int128: self.serialize_i128,
            st.float32: self.serialize_f32,
            st.float64: self.serialize_f64,
            st.unit: self.serialize_unit,
            st.char: self.serialize_char,
            str: self.serialize_str,
            bytes: self.serialize_bytes,
        }

    def serialize_bytes(self, value: bytes):
        """Write the length of `value` with `serialize_len`, then the bytes themselves."""
        self.serialize_len(len(value))
        self.output.write(value)

    def serialize_str(self, value: str):
        """Encode `value` with the `str.encode()` defaults and write it as a byte string."""
        self.serialize_bytes(value.encode())

    def serialize_unit(self, value: st.unit):
        """Write nothing: the unit value produces no output."""

    def serialize_bool(self, value: bool):
        """Write `value` as a single byte holding `int(value)`."""
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    def serialize_u8(self, value: st.uint8):
        """Write `value` as 1 unsigned byte."""
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    def serialize_u16(self, value: st.uint16):
        """Write `value` as 2 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(2, "little", signed=False))

    def serialize_u32(self, value: st.uint32):
        """Write `value` as 4 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(4, "little", signed=False))

    def serialize_u64(self, value: st.uint64):
        """Write `value` as 8 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(8, "little", signed=False))

    def serialize_u128(self, value: st.uint128):
        """Write `value` as 16 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(16, "little", signed=False))

    def serialize_i8(self, value: st.int8):
        """Write `value` as 1 signed byte."""
        self.output.write(int(value).to_bytes(1, "little", signed=True))

    def serialize_i16(self, value: st.int16):
        """Write `value` as 2 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(2, "little", signed=True))

    def serialize_i32(self, value: st.int32):
        """Write `value` as 4 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(4, "little", signed=True))

    def serialize_i64(self, value: st.int64):
        """Write `value` as 8 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(8, "little", signed=True))

    def serialize_i128(self, value: st.int128):
        """Write `value` as 16 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(16, "little", signed=True))

    def serialize_f32(self, value: st.float32):
        """Serialize a 32-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def serialize_f64(self, value: st.float64):
        """Serialize a 64-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def serialize_char(self, value: st.char):
        """Serialize a character (not implemented in this base class)."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the current size in bytes of the output buffer."""
        return len(self.output.getbuffer())

    def get_buffer(self) -> bytes:
        """Return a copy of everything held by the output buffer."""
        return self.output.getvalue()

    def increase_container_depth(self):
        """Consume one level of the depth budget, if a budget is set.

        Raises:
            st.SerializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.SerializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back one level of the depth budget, if a budget is set."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def serialize_len(self, value: int):
        """Encode a sequence/map/bytes length (format-specific, abstract)."""
        raise NotImplementedError

    def serialize_variant_index(self, value: int):
        """Encode the index of an enum variant (format-specific, abstract)."""
        raise NotImplementedError

    def sort_map_entries(self, offsets: typing.List[int]):
        """Reorder serialized map entries (format-specific, abstract).

        `serialize_any` passes the output offset at which each entry of the map starts.
        """
        raise NotImplementedError

    def serialize_any(self, obj: typing.Any, obj_type):  # noqa: C901
        """Serialize `obj` following the type description `obj_type`.

        Supported descriptions are, in order of testing: primitive types of the dispatch
        table, generic aliases (`Sequence[T]`, `Tuple[...]`, `Optional[T]`, `Dict[K, V]`),
        enums (classes exposing `VARIANTS`, whose values expose `INDEX`) and dataclasses.

        Raises:
            st.SerializationError: if `obj_type` is not supported, if `obj` does not match
                it, or if the container depth budget is exhausted.
        """
        if obj_type in self.primitive_type_serializer:
            self.primitive_type_serializer[obj_type](obj)
            return

        if hasattr(obj_type, "__origin__"):  # Generic type
            origin = obj_type.__origin__
            types = obj_type.__args__

            if origin == collections.abc.Sequence:  # Sequence
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if len(types) != 1:
                    raise st.SerializationError("Unexpected type", obj_type)
                item_type = types[0]
                self.serialize_len(len(obj))
                for item in obj:
                    self.serialize_any(item, item_type)

            elif origin == tuple:  # Tuple
                # Before Python 3.11, `typing.Tuple[()]` reports its arguments as `((),)`.
                if types == ((),):
                    types = ()
                # A tuple of the wrong arity would otherwise be written truncated (or fail
                # with a bare IndexError), producing bytes that cannot be read back.
                if len(obj) != len(types):
                    raise st.SerializationError("Wrong Value for the type", obj, obj_type)
                for item, item_type in zip(obj, types):
                    self.serialize_any(item, item_type)

            elif origin == typing.Union:  # Option
                # Only `Optional[T]`, i.e. `Union[T, None]`, is supported.
                if len(types) != 2 or types[1] is not type(None):
                    raise st.SerializationError("Unexpected type", obj_type)
                if obj is None:
                    self.output.write(b"\x00")
                else:
                    self.output.write(b"\x01")
                    self.serialize_any(obj, types[0])

            elif origin == dict:  # Map
                if len(types) != 2:
                    raise st.SerializationError("Unexpected type", obj_type)
                key_type, value_type = types
                self.serialize_len(len(obj))
                offsets = []
                for key, value in obj.items():
                    # Remember where each entry starts so that the format can reorder them.
                    offsets.append(self.get_buffer_offset())
                    self.serialize_any(key, key_type)
                    self.serialize_any(value, value_type)
                self.sort_map_entries(offsets)

            else:
                raise st.SerializationError("Unexpected type", obj_type)
            return

        if not dataclasses.is_dataclass(obj_type):  # Enum
            if not hasattr(obj_type, "VARIANTS"):
                raise st.SerializationError("Unexpected type", obj_type)
            if not hasattr(obj, "INDEX"):
                raise st.SerializationError("Wrong Value for the type", obj, obj_type)
            variant_index = obj.__class__.INDEX
            self.serialize_variant_index(variant_index)
            # Proceed to variant
            obj_type = obj_type.VARIANTS[variant_index]
            if not dataclasses.is_dataclass(obj_type):
                raise st.SerializationError("Unexpected type", obj_type)

        # pyre-ignore
        if not isinstance(obj, obj_type):
            raise st.SerializationError("Wrong Value for the type", obj, obj_type)

        # Content of struct or variant
        fields = dataclasses.fields(obj_type)
        types = get_type_hints(obj_type)
        self.increase_container_depth()
        try:
            for field in fields:
                # `getattr` also works for dataclasses that do not have a `__dict__`.
                self.serialize_any(getattr(obj, field.name), types[field.name])
        finally:
            # Keep the budget balanced even when a field fails to serialize.
            self.decrease_container_depth()
@dataclasses.dataclass
class BinaryDeserializer:
    """Deserialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they verify the ordering of keys in map entries (or not). Those three
    operations, together with the float and char primitives, raise `NotImplementedError`
    here and are left to subclasses.
    """

    # Buffer holding the bytes to decode; reads advance its position.
    input: io.BytesIO
    # Remaining struct nesting levels allowed; `None` disables the check.
    container_depth_budget: typing.Optional[int]
    # Maps each primitive type to the bound method deserializing it (set in `__post_init__`).
    primitive_type_deserializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the primitive-type dispatch table used by `deserialize_any`."""
        self.primitive_type_deserializer = {
            bool: self.deserialize_bool,
            st.uint8: self.deserialize_u8,
            st.uint16: self.deserialize_u16,
            st.uint32: self.deserialize_u32,
            st.uint64: self.deserialize_u64,
            st.uint128: self.deserialize_u128,
            st.int8: self.deserialize_i8,
            st.int16: self.deserialize_i16,
            st.int32: self.deserialize_i32,
            st.int64: self.deserialize_i64,
            st.int128: self.deserialize_i128,
            st.float32: self.deserialize_f32,
            st.float64: self.deserialize_f64,
            st.unit: self.deserialize_unit,
            st.char: self.deserialize_char,
            str: self.deserialize_str,
            bytes: self.deserialize_bytes,
        }

    def read(self, length: int) -> bytes:
        """Read exactly `length` bytes from the input.

        Raises:
            st.DeserializationError: if `length` is negative or if fewer than `length`
                bytes remain.
        """
        # `BytesIO.read` treats a negative size as "read until EOF"; never allow that.
        if length < 0:
            raise st.DeserializationError("Invalid negative length:", length)
        value = self.input.read(length)
        if value is None or len(value) < length:
            raise st.DeserializationError("Input is too short")
        return value

    def deserialize_bytes(self) -> bytes:
        """Read a length with `deserialize_len`, then that many raw bytes."""
        length = self.deserialize_len()
        return self.read(length)

    def deserialize_str(self) -> str:
        """Read a byte string and decode it with the `bytes.decode()` defaults.

        Raises:
            st.DeserializationError: if the bytes cannot be decoded.
        """
        content = self.deserialize_bytes()
        try:
            return content.decode()
        except UnicodeDecodeError as err:
            raise st.DeserializationError("Invalid unicode string:", content) from err

    def deserialize_unit(self) -> st.unit:
        """Read nothing and return `None`: the unit value consumes no input."""
        return None

    def deserialize_bool(self) -> bool:
        """Read one byte and map 0 to `False` and 1 to `True`.

        Raises:
            st.DeserializationError: if the byte holds any other value.
        """
        b = int.from_bytes(self.read(1), byteorder="little", signed=False)
        if b == 0:
            return False
        elif b == 1:
            return True
        else:
            raise st.DeserializationError("Unexpected boolean value:", b)

    def deserialize_u8(self) -> st.uint8:
        """Read 1 byte as an unsigned integer."""
        return st.uint8(int.from_bytes(self.read(1), byteorder="little", signed=False))

    def deserialize_u16(self) -> st.uint16:
        """Read 2 little-endian bytes as an unsigned integer."""
        return st.uint16(int.from_bytes(self.read(2), byteorder="little", signed=False))

    def deserialize_u32(self) -> st.uint32:
        """Read 4 little-endian bytes as an unsigned integer."""
        return st.uint32(int.from_bytes(self.read(4), byteorder="little", signed=False))

    def deserialize_u64(self) -> st.uint64:
        """Read 8 little-endian bytes as an unsigned integer."""
        return st.uint64(int.from_bytes(self.read(8), byteorder="little", signed=False))

    def deserialize_u128(self) -> st.uint128:
        """Read 16 little-endian bytes as an unsigned integer."""
        return st.uint128(int.from_bytes(self.read(16), byteorder="little", signed=False))

    def deserialize_i8(self) -> st.int8:
        """Read 1 byte as a signed integer."""
        return st.int8(int.from_bytes(self.read(1), byteorder="little", signed=True))

    def deserialize_i16(self) -> st.int16:
        """Read 2 little-endian bytes as a signed integer."""
        return st.int16(int.from_bytes(self.read(2), byteorder="little", signed=True))

    def deserialize_i32(self) -> st.int32:
        """Read 4 little-endian bytes as a signed integer."""
        return st.int32(int.from_bytes(self.read(4), byteorder="little", signed=True))

    def deserialize_i64(self) -> st.int64:
        """Read 8 little-endian bytes as a signed integer."""
        return st.int64(int.from_bytes(self.read(8), byteorder="little", signed=True))

    def deserialize_i128(self) -> st.int128:
        """Read 16 little-endian bytes as a signed integer."""
        return st.int128(int.from_bytes(self.read(16), byteorder="little", signed=True))

    def deserialize_f32(self) -> st.float32:
        """Deserialize a 32-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def deserialize_f64(self) -> st.float64:
        """Deserialize a 64-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def deserialize_char(self) -> st.char:
        """Deserialize a character (not implemented in this base class)."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the current read position in the input buffer."""
        return self.input.tell()

    def get_remaining_buffer(self) -> bytes:
        """Return a copy of the input bytes located after the current read position."""
        buf = self.input.getbuffer()
        offset = self.get_buffer_offset()
        return bytes(buf[offset:])

    def increase_container_depth(self):
        """Consume one level of the depth budget, if a budget is set.

        Raises:
            st.DeserializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.DeserializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back one level of the depth budget, if a budget is set."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def deserialize_len(self) -> int:
        """Decode a sequence/map/bytes length (format-specific, abstract)."""
        raise NotImplementedError

    def deserialize_variant_index(self) -> int:
        """Decode the index of an enum variant (format-specific, abstract)."""
        raise NotImplementedError

    def check_that_key_slices_are_increasing(
        self, slice1: typing.Tuple[int, int], slice2: typing.Tuple[int, int]
    ) -> bool:
        """Verify the ordering of two consecutive serialized map keys (abstract).

        `deserialize_any` passes the (start, end) input offsets of the previous key and of
        the current key, and ignores the return value.
        """
        raise NotImplementedError

    def deserialize_any(self, obj_type) -> typing.Any:  # noqa: C901
        """Deserialize and return a value described by `obj_type`.

        Supported descriptions are, in order of testing: primitive types of the dispatch
        table, generic aliases (`Sequence[T]` read as a list, `Tuple[...]`, `Optional[T]`,
        `Dict[K, V]`), dataclasses, and enums (classes exposing `VARIANTS`).

        Raises:
            st.DeserializationError: if `obj_type` is not supported, if the input is
                malformed, or if the container depth budget is exhausted.
        """
        if obj_type in self.primitive_type_deserializer:
            return self.primitive_type_deserializer[obj_type]()

        if hasattr(obj_type, "__origin__"):  # Generic type
            origin = obj_type.__origin__
            types = obj_type.__args__

            if origin == collections.abc.Sequence:  # Sequence
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if len(types) != 1:
                    raise st.DeserializationError("Unexpected type", obj_type)
                item_type = types[0]
                length = self.deserialize_len()
                return [self.deserialize_any(item_type) for _ in range(length)]

            if origin == tuple:  # Tuple
                return tuple(self.deserialize_any(item_type) for item_type in types)

            if origin == typing.Union:  # Option
                # Only `Optional[T]`, i.e. `Union[T, None]`, is supported.
                if len(types) != 2 or types[1] is not type(None):
                    raise st.DeserializationError("Unexpected type", obj_type)
                tag = int.from_bytes(self.read(1), byteorder="little", signed=False)
                if tag == 0:
                    return None
                if tag == 1:
                    return self.deserialize_any(types[0])
                raise st.DeserializationError("Wrong tag for Option value")

            if origin == dict:  # Map
                if len(types) != 2:
                    raise st.DeserializationError("Unexpected type", obj_type)
                key_type, value_type = types
                length = self.deserialize_len()
                result = dict()
                previous_key_slice = None
                for _ in range(length):
                    key_start = self.get_buffer_offset()
                    key = self.deserialize_any(key_type)
                    key_end = self.get_buffer_offset()
                    value = self.deserialize_any(value_type)
                    # Let the format compare the raw bytes of consecutive keys.
                    key_slice = (key_start, key_end)
                    if previous_key_slice is not None:
                        self.check_that_key_slices_are_increasing(previous_key_slice, key_slice)
                    previous_key_slice = key_slice
                    result[key] = value
                return result

            raise st.DeserializationError("Unexpected type", obj_type)

        # handle structs
        if dataclasses.is_dataclass(obj_type):
            fields = dataclasses.fields(obj_type)
            typing_hints = get_type_hints(obj_type)
            self.increase_container_depth()
            try:
                values = [self.deserialize_any(typing_hints[field.name]) for field in fields]
            finally:
                # Keep the budget balanced even when a field fails to deserialize.
                self.decrease_container_depth()
            return obj_type(*values)

        # handle variant
        if hasattr(obj_type, "VARIANTS"):
            variant_index = self.deserialize_variant_index()
            if variant_index not in range(len(obj_type.VARIANTS)):
                raise st.DeserializationError("Unexpected variant index", variant_index)
            new_type = obj_type.VARIANTS[variant_index]
            return self.deserialize_any(new_type)

        raise st.DeserializationError("Unexpected type", obj_type)
Classes
class BinaryDeserializer (input: _io.BytesIO, container_depth_budget: Optional[int])
-
Deserialization primitives for binary formats (abstract class).
"Binary" serialization formats may differ in the way they encode sequence lengths, variant index, and how they verify the ordering of keys in map entries (or not).
Expand source code
@dataclasses.dataclass
class BinaryDeserializer:
    """Deserialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they verify the ordering of keys in map entries (or not). Those three
    operations, together with the float and char primitives, raise `NotImplementedError`
    here and are left to subclasses.
    """

    # Buffer holding the bytes to decode; reads advance its position.
    input: io.BytesIO
    # Remaining struct nesting levels allowed; `None` disables the check.
    container_depth_budget: typing.Optional[int]
    # Maps each primitive type to the bound method deserializing it (set in `__post_init__`).
    primitive_type_deserializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the primitive-type dispatch table used by `deserialize_any`."""
        self.primitive_type_deserializer = {
            bool: self.deserialize_bool,
            st.uint8: self.deserialize_u8,
            st.uint16: self.deserialize_u16,
            st.uint32: self.deserialize_u32,
            st.uint64: self.deserialize_u64,
            st.uint128: self.deserialize_u128,
            st.int8: self.deserialize_i8,
            st.int16: self.deserialize_i16,
            st.int32: self.deserialize_i32,
            st.int64: self.deserialize_i64,
            st.int128: self.deserialize_i128,
            st.float32: self.deserialize_f32,
            st.float64: self.deserialize_f64,
            st.unit: self.deserialize_unit,
            st.char: self.deserialize_char,
            str: self.deserialize_str,
            bytes: self.deserialize_bytes,
        }

    def read(self, length: int) -> bytes:
        """Read exactly `length` bytes from the input.

        Raises:
            st.DeserializationError: if `length` is negative or if fewer than `length`
                bytes remain.
        """
        # `BytesIO.read` treats a negative size as "read until EOF"; never allow that.
        if length < 0:
            raise st.DeserializationError("Invalid negative length:", length)
        value = self.input.read(length)
        if value is None or len(value) < length:
            raise st.DeserializationError("Input is too short")
        return value

    def deserialize_bytes(self) -> bytes:
        """Read a length with `deserialize_len`, then that many raw bytes."""
        length = self.deserialize_len()
        return self.read(length)

    def deserialize_str(self) -> str:
        """Read a byte string and decode it with the `bytes.decode()` defaults.

        Raises:
            st.DeserializationError: if the bytes cannot be decoded.
        """
        content = self.deserialize_bytes()
        try:
            return content.decode()
        except UnicodeDecodeError as err:
            raise st.DeserializationError("Invalid unicode string:", content) from err

    def deserialize_unit(self) -> st.unit:
        """Read nothing and return `None`: the unit value consumes no input."""
        return None

    def deserialize_bool(self) -> bool:
        """Read one byte and map 0 to `False` and 1 to `True`.

        Raises:
            st.DeserializationError: if the byte holds any other value.
        """
        b = int.from_bytes(self.read(1), byteorder="little", signed=False)
        if b == 0:
            return False
        elif b == 1:
            return True
        else:
            raise st.DeserializationError("Unexpected boolean value:", b)

    def deserialize_u8(self) -> st.uint8:
        """Read 1 byte as an unsigned integer."""
        return st.uint8(int.from_bytes(self.read(1), byteorder="little", signed=False))

    def deserialize_u16(self) -> st.uint16:
        """Read 2 little-endian bytes as an unsigned integer."""
        return st.uint16(int.from_bytes(self.read(2), byteorder="little", signed=False))

    def deserialize_u32(self) -> st.uint32:
        """Read 4 little-endian bytes as an unsigned integer."""
        return st.uint32(int.from_bytes(self.read(4), byteorder="little", signed=False))

    def deserialize_u64(self) -> st.uint64:
        """Read 8 little-endian bytes as an unsigned integer."""
        return st.uint64(int.from_bytes(self.read(8), byteorder="little", signed=False))

    def deserialize_u128(self) -> st.uint128:
        """Read 16 little-endian bytes as an unsigned integer."""
        return st.uint128(int.from_bytes(self.read(16), byteorder="little", signed=False))

    def deserialize_i8(self) -> st.int8:
        """Read 1 byte as a signed integer."""
        return st.int8(int.from_bytes(self.read(1), byteorder="little", signed=True))

    def deserialize_i16(self) -> st.int16:
        """Read 2 little-endian bytes as a signed integer."""
        return st.int16(int.from_bytes(self.read(2), byteorder="little", signed=True))

    def deserialize_i32(self) -> st.int32:
        """Read 4 little-endian bytes as a signed integer."""
        return st.int32(int.from_bytes(self.read(4), byteorder="little", signed=True))

    def deserialize_i64(self) -> st.int64:
        """Read 8 little-endian bytes as a signed integer."""
        return st.int64(int.from_bytes(self.read(8), byteorder="little", signed=True))

    def deserialize_i128(self) -> st.int128:
        """Read 16 little-endian bytes as a signed integer."""
        return st.int128(int.from_bytes(self.read(16), byteorder="little", signed=True))

    def deserialize_f32(self) -> st.float32:
        """Deserialize a 32-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def deserialize_f64(self) -> st.float64:
        """Deserialize a 64-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def deserialize_char(self) -> st.char:
        """Deserialize a character (not implemented in this base class)."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the current read position in the input buffer."""
        return self.input.tell()

    def get_remaining_buffer(self) -> bytes:
        """Return a copy of the input bytes located after the current read position."""
        buf = self.input.getbuffer()
        offset = self.get_buffer_offset()
        return bytes(buf[offset:])

    def increase_container_depth(self):
        """Consume one level of the depth budget, if a budget is set.

        Raises:
            st.DeserializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.DeserializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back one level of the depth budget, if a budget is set."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def deserialize_len(self) -> int:
        """Decode a sequence/map/bytes length (format-specific, abstract)."""
        raise NotImplementedError

    def deserialize_variant_index(self) -> int:
        """Decode the index of an enum variant (format-specific, abstract)."""
        raise NotImplementedError

    def check_that_key_slices_are_increasing(
        self, slice1: typing.Tuple[int, int], slice2: typing.Tuple[int, int]
    ) -> bool:
        """Verify the ordering of two consecutive serialized map keys (abstract).

        `deserialize_any` passes the (start, end) input offsets of the previous key and of
        the current key, and ignores the return value.
        """
        raise NotImplementedError

    def deserialize_any(self, obj_type) -> typing.Any:  # noqa: C901
        """Deserialize and return a value described by `obj_type`.

        Supported descriptions are, in order of testing: primitive types of the dispatch
        table, generic aliases (`Sequence[T]` read as a list, `Tuple[...]`, `Optional[T]`,
        `Dict[K, V]`), dataclasses, and enums (classes exposing `VARIANTS`).

        Raises:
            st.DeserializationError: if `obj_type` is not supported, if the input is
                malformed, or if the container depth budget is exhausted.
        """
        if obj_type in self.primitive_type_deserializer:
            return self.primitive_type_deserializer[obj_type]()

        if hasattr(obj_type, "__origin__"):  # Generic type
            origin = obj_type.__origin__
            types = obj_type.__args__

            if origin == collections.abc.Sequence:  # Sequence
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if len(types) != 1:
                    raise st.DeserializationError("Unexpected type", obj_type)
                item_type = types[0]
                length = self.deserialize_len()
                return [self.deserialize_any(item_type) for _ in range(length)]

            if origin == tuple:  # Tuple
                return tuple(self.deserialize_any(item_type) for item_type in types)

            if origin == typing.Union:  # Option
                # Only `Optional[T]`, i.e. `Union[T, None]`, is supported.
                if len(types) != 2 or types[1] is not type(None):
                    raise st.DeserializationError("Unexpected type", obj_type)
                tag = int.from_bytes(self.read(1), byteorder="little", signed=False)
                if tag == 0:
                    return None
                if tag == 1:
                    return self.deserialize_any(types[0])
                raise st.DeserializationError("Wrong tag for Option value")

            if origin == dict:  # Map
                if len(types) != 2:
                    raise st.DeserializationError("Unexpected type", obj_type)
                key_type, value_type = types
                length = self.deserialize_len()
                result = dict()
                previous_key_slice = None
                for _ in range(length):
                    key_start = self.get_buffer_offset()
                    key = self.deserialize_any(key_type)
                    key_end = self.get_buffer_offset()
                    value = self.deserialize_any(value_type)
                    # Let the format compare the raw bytes of consecutive keys.
                    key_slice = (key_start, key_end)
                    if previous_key_slice is not None:
                        self.check_that_key_slices_are_increasing(previous_key_slice, key_slice)
                    previous_key_slice = key_slice
                    result[key] = value
                return result

            raise st.DeserializationError("Unexpected type", obj_type)

        # handle structs
        if dataclasses.is_dataclass(obj_type):
            fields = dataclasses.fields(obj_type)
            typing_hints = get_type_hints(obj_type)
            self.increase_container_depth()
            try:
                values = [self.deserialize_any(typing_hints[field.name]) for field in fields]
            finally:
                # Keep the budget balanced even when a field fails to deserialize.
                self.decrease_container_depth()
            return obj_type(*values)

        # handle variant
        if hasattr(obj_type, "VARIANTS"):
            variant_index = self.deserialize_variant_index()
            if variant_index not in range(len(obj_type.VARIANTS)):
                raise st.DeserializationError("Unexpected variant index", variant_index)
            new_type = obj_type.VARIANTS[variant_index]
            return self.deserialize_any(new_type)

        raise st.DeserializationError("Unexpected type", obj_type)
Subclasses
Class variables
var container_depth_budget : Optional[int]
var input : _io.BytesIO
var primitive_type_deserializer : Mapping
Methods
def check_that_key_slices_are_increasing(self, slice1: Tuple[int, int], slice2: Tuple[int, int]) ‑> bool
-
Expand source code
def check_that_key_slices_are_increasing(
    self, slice1: typing.Tuple[int, int], slice2: typing.Tuple[int, int]
) -> bool:
    """Verify the ordering of two serialized map keys (abstract hook).

    Each argument is a pair of integer offsets. This base implementation always raises
    `NotImplementedError`.
    """
    raise NotImplementedError()
def decrease_container_depth(self)
-
Expand source code
def decrease_container_depth(self):
    """Give back one level of the depth budget; do nothing when the budget is `None`."""
    budget = self.container_depth_budget
    if budget is None:
        return
    self.container_depth_budget = budget + 1
def deserialize_any(self, obj_type) ‑> Any
-
Expand source code
def deserialize_any(self, obj_type) -> typing.Any:
    """Deserialize and return a value described by `obj_type`.

    Supported descriptions are, in order of testing: primitive types of the dispatch
    table, generic aliases (`Sequence[T]` read as a list, `Tuple[...]`, `Optional[T]`,
    `Dict[K, V]`), dataclasses, and enums (classes exposing `VARIANTS`).

    Raises:
        st.DeserializationError: if `obj_type` is not supported, if the input is
            malformed, or if the container depth budget is exhausted.
    """
    if obj_type in self.primitive_type_deserializer:
        return self.primitive_type_deserializer[obj_type]()

    if hasattr(obj_type, "__origin__"):  # Generic type
        origin = obj_type.__origin__
        types = obj_type.__args__

        if origin == collections.abc.Sequence:  # Sequence
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if len(types) != 1:
                raise st.DeserializationError("Unexpected type", obj_type)
            item_type = types[0]
            length = self.deserialize_len()
            return [self.deserialize_any(item_type) for _ in range(length)]

        if origin == tuple:  # Tuple
            return tuple(self.deserialize_any(item_type) for item_type in types)

        if origin == typing.Union:  # Option
            # Only `Optional[T]`, i.e. `Union[T, None]`, is supported.
            if len(types) != 2 or types[1] is not type(None):
                raise st.DeserializationError("Unexpected type", obj_type)
            tag = int.from_bytes(self.read(1), byteorder="little", signed=False)
            if tag == 0:
                return None
            if tag == 1:
                return self.deserialize_any(types[0])
            raise st.DeserializationError("Wrong tag for Option value")

        if origin == dict:  # Map
            if len(types) != 2:
                raise st.DeserializationError("Unexpected type", obj_type)
            key_type, value_type = types
            length = self.deserialize_len()
            result = dict()
            previous_key_slice = None
            for _ in range(length):
                key_start = self.get_buffer_offset()
                key = self.deserialize_any(key_type)
                key_end = self.get_buffer_offset()
                value = self.deserialize_any(value_type)
                # Let the format compare the raw bytes of consecutive keys.
                key_slice = (key_start, key_end)
                if previous_key_slice is not None:
                    self.check_that_key_slices_are_increasing(previous_key_slice, key_slice)
                previous_key_slice = key_slice
                result[key] = value
            return result

        raise st.DeserializationError("Unexpected type", obj_type)

    # handle structs
    if dataclasses.is_dataclass(obj_type):
        fields = dataclasses.fields(obj_type)
        typing_hints = get_type_hints(obj_type)
        self.increase_container_depth()
        try:
            values = [self.deserialize_any(typing_hints[field.name]) for field in fields]
        finally:
            # Keep the budget balanced even when a field fails to deserialize.
            self.decrease_container_depth()
        return obj_type(*values)

    # handle variant
    if hasattr(obj_type, "VARIANTS"):
        variant_index = self.deserialize_variant_index()
        if variant_index not in range(len(obj_type.VARIANTS)):
            raise st.DeserializationError("Unexpected variant index", variant_index)
        new_type = obj_type.VARIANTS[variant_index]
        return self.deserialize_any(new_type)

    raise st.DeserializationError("Unexpected type", obj_type)
def deserialize_bool(self) ‑> bool
-
Expand source code
def deserialize_bool(self) -> bool:
    """Read one byte and return `False` for 0 and `True` for 1.

    Raises:
        st.DeserializationError: if the byte holds any other value.
    """
    b = int.from_bytes(self.read(1), byteorder="little", signed=False)
    # The byte is decoded as unsigned, so anything above 1 is the only invalid case.
    if b > 1:
        raise st.DeserializationError("Unexpected boolean value:", b)
    return b == 1
def deserialize_bytes(self) ‑> bytes
-
Expand source code
def deserialize_bytes(self) -> bytes:
    """Read a length with `deserialize_len`, then return that many bytes from `read`."""
    return self.read(self.deserialize_len())
def deserialize_char(self) ‑> char
-
Expand source code
def deserialize_char(self) -> st.char:
    """Deserialize a character; this base implementation always raises `NotImplementedError`."""
    raise NotImplementedError()
def deserialize_f32(self) ‑> numpy.float32
-
Expand source code
def deserialize_f32(self) -> st.float32:
    """Deserialize a 32-bit float; this base implementation always raises `NotImplementedError`."""
    raise NotImplementedError()
def deserialize_f64(self) ‑> numpy.float64
-
Expand source code
def deserialize_f64(self) -> st.float64:
    """Deserialize a 64-bit float; this base implementation always raises `NotImplementedError`."""
    raise NotImplementedError()
def deserialize_i128(self) ‑> int128
-
Expand source code
def deserialize_i128(self) -> st.int128:
    """Read 16 little-endian bytes and return them as a signed `st.int128`."""
    raw = self.read(16)
    number = int.from_bytes(raw, byteorder="little", signed=True)
    return st.int128(number)
def deserialize_i16(self) ‑> numpy.int16
-
Expand source code
def deserialize_i16(self) -> st.int16:
    """Read 2 little-endian bytes and return them as a signed `st.int16`."""
    raw = self.read(2)
    number = int.from_bytes(raw, byteorder="little", signed=True)
    return st.int16(number)
def deserialize_i32(self) ‑> numpy.int32
-
Expand source code
def deserialize_i32(self) -> st.int32:
    """Read 4 little-endian bytes and return them as a signed `st.int32`."""
    raw = self.read(4)
    number = int.from_bytes(raw, byteorder="little", signed=True)
    return st.int32(number)
def deserialize_i64(self) ‑> numpy.int64
-
Expand source code
def deserialize_i64(self) -> st.int64:
    """Read 8 little-endian bytes and return them as a signed `st.int64`."""
    raw = self.read(8)
    number = int.from_bytes(raw, byteorder="little", signed=True)
    return st.int64(number)
def deserialize_i8(self) ‑> numpy.int8
-
Expand source code
def deserialize_i8(self) -> st.int8:
    """Read 1 byte and return it as a signed `st.int8`."""
    raw = self.read(1)
    number = int.from_bytes(raw, byteorder="little", signed=True)
    return st.int8(number)
def deserialize_len(self) ‑> int
-
Expand source code
def deserialize_len(self) -> int:
    """Decode a length; this base implementation always raises `NotImplementedError`."""
    raise NotImplementedError()
def deserialize_str(self) ‑> str
-
Expand source code
def deserialize_str(self) -> str:
    """Read a byte string and decode it with the `bytes.decode()` defaults.

    Raises:
        st.DeserializationError: if the bytes cannot be decoded; the original
            `UnicodeDecodeError` is attached as the cause.
    """
    content = self.deserialize_bytes()
    try:
        return content.decode()
    except UnicodeDecodeError as err:
        # Chain explicitly so the traceback shows the decoding failure as the cause.
        raise st.DeserializationError("Invalid unicode string:", content) from err
def deserialize_u128(self) ‑> uint128
-
Expand source code
def deserialize_u128(self) -> st.uint128:
    """Read 16 little-endian bytes and return them as an unsigned `st.uint128`."""
    raw = self.read(16)
    number = int.from_bytes(raw, byteorder="little", signed=False)
    return st.uint128(number)
def deserialize_u16(self) ‑> numpy.uint16
-
Expand source code
def deserialize_u16(self) -> st.uint16:
    """Read 2 little-endian bytes and return them as an unsigned `st.uint16`."""
    raw = self.read(2)
    number = int.from_bytes(raw, byteorder="little", signed=False)
    return st.uint16(number)
def deserialize_u32(self) ‑> numpy.uint32
-
Expand source code
def deserialize_u32(self) -> st.uint32:
    """Read 4 little-endian bytes and return them as an unsigned `st.uint32`."""
    raw = self.read(4)
    number = int.from_bytes(raw, byteorder="little", signed=False)
    return st.uint32(number)
def deserialize_u64(self) ‑> numpy.uint64
-
Expand source code
def deserialize_u64(self) -> st.uint64:
    """Read 8 little-endian bytes and return them as an unsigned `st.uint64`."""
    raw = self.read(8)
    number = int.from_bytes(raw, byteorder="little", signed=False)
    return st.uint64(number)
def deserialize_u8(self) ‑> numpy.uint8
-
Expand source code
def deserialize_u8(self) -> st.uint8:
    """Read 1 byte and return it as an unsigned `st.uint8`."""
    raw = self.read(1)
    number = int.from_bytes(raw, byteorder="little", signed=False)
    return st.uint8(number)
def deserialize_unit(self) ‑> Type[NoneType]
-
Expand source code
def deserialize_unit(self) -> st.unit:
    """Read nothing from the input and return `None`."""
    return None
def deserialize_variant_index(self) ‑> int
-
Expand source code
def deserialize_variant_index(self) -> int:
    """Decode a variant index; this base implementation always raises `NotImplementedError`."""
    raise NotImplementedError()
def get_buffer_offset(self) ‑> int
-
Expand source code
def get_buffer_offset(self) -> int:
    """Return the current position of the input stream, as reported by `tell()`."""
    position = self.input.tell()
    return position
def get_remaining_buffer(self) ‑> bytes
-
Expand source code
def get_remaining_buffer(self) -> bytes:
    """Return a copy of the input bytes from the current offset to the end of the buffer."""
    start = self.get_buffer_offset()
    view = self.input.getbuffer()
    return bytes(view[start:])
def increase_container_depth(self)
-
Expand source code
def increase_container_depth(self):
    """Consume one level of the depth budget; do nothing when the budget is `None`.

    Raises:
        st.DeserializationError: if the budget is already 0.
    """
    remaining = self.container_depth_budget
    if remaining is None:
        return
    if remaining == 0:
        raise st.DeserializationError("Exceeded maximum container depth")
    self.container_depth_budget = remaining - 1
def read(self, length: int) ‑> bytes
-
Expand source code
def read(self, length: int) -> bytes:
    """Read exactly `length` bytes from the input.

    Raises:
        st.DeserializationError: if `length` is negative or if fewer than `length`
            bytes remain.
    """
    # `BytesIO.read` treats a negative size as "read until EOF"; never allow that.
    if length < 0:
        raise st.DeserializationError("Invalid negative length:", length)
    value = self.input.read(length)
    if value is None or len(value) < length:
        raise st.DeserializationError("Input is too short")
    return value
class BinarySerializer (output: _io.BytesIO, container_depth_budget: Optional[int])
-
Serialization primitives for binary formats (abstract class).
"Binary" serialization formats may differ in the way they encode sequence lengths, variant index, and how they sort map entries (or not).
Expand source code
@dataclasses.dataclass
class BinarySerializer:
    """Serialization primitives for binary formats (abstract class).

    "Binary" serialization formats may differ in the way they encode sequence lengths, variant
    index, and how they sort map entries (or not). Those three operations, together with the
    float and char primitives, raise `NotImplementedError` here and are left to subclasses.
    """

    # Buffer that receives the serialized bytes.
    output: io.BytesIO
    # Remaining struct/variant nesting levels allowed; `None` disables the check.
    container_depth_budget: typing.Optional[int]
    # Maps each primitive type to the bound method serializing it (set in `__post_init__`).
    primitive_type_serializer: typing.Mapping = dataclasses.field(init=False)

    def __post_init__(self):
        """Build the primitive-type dispatch table used by `serialize_any`."""
        self.primitive_type_serializer = {
            bool: self.serialize_bool,
            st.uint8: self.serialize_u8,
            st.uint16: self.serialize_u16,
            st.uint32: self.serialize_u32,
            st.uint64: self.serialize_u64,
            st.uint128: self.serialize_u128,
            st.int8: self.serialize_i8,
            st.int16: self.serialize_i16,
            st.int32: self.serialize_i32,
            st.int64: self.serialize_i64,
            st.int128: self.serialize_i128,
            st.float32: self.serialize_f32,
            st.float64: self.serialize_f64,
            st.unit: self.serialize_unit,
            st.char: self.serialize_char,
            str: self.serialize_str,
            bytes: self.serialize_bytes,
        }

    def serialize_bytes(self, value: bytes):
        """Write the length of `value` with `serialize_len`, then the bytes themselves."""
        self.serialize_len(len(value))
        self.output.write(value)

    def serialize_str(self, value: str):
        """Encode `value` with the `str.encode()` defaults and write it as a byte string."""
        self.serialize_bytes(value.encode())

    def serialize_unit(self, value: st.unit):
        """Write nothing: the unit value produces no output."""

    def serialize_bool(self, value: bool):
        """Write `value` as a single byte holding `int(value)`."""
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    def serialize_u8(self, value: st.uint8):
        """Write `value` as 1 unsigned byte."""
        self.output.write(int(value).to_bytes(1, "little", signed=False))

    def serialize_u16(self, value: st.uint16):
        """Write `value` as 2 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(2, "little", signed=False))

    def serialize_u32(self, value: st.uint32):
        """Write `value` as 4 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(4, "little", signed=False))

    def serialize_u64(self, value: st.uint64):
        """Write `value` as 8 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(8, "little", signed=False))

    def serialize_u128(self, value: st.uint128):
        """Write `value` as 16 unsigned little-endian bytes."""
        self.output.write(int(value).to_bytes(16, "little", signed=False))

    def serialize_i8(self, value: st.int8):
        """Write `value` as 1 signed byte."""
        self.output.write(int(value).to_bytes(1, "little", signed=True))

    def serialize_i16(self, value: st.int16):
        """Write `value` as 2 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(2, "little", signed=True))

    def serialize_i32(self, value: st.int32):
        """Write `value` as 4 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(4, "little", signed=True))

    def serialize_i64(self, value: st.int64):
        """Write `value` as 8 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(8, "little", signed=True))

    def serialize_i128(self, value: st.int128):
        """Write `value` as 16 signed little-endian bytes."""
        self.output.write(int(value).to_bytes(16, "little", signed=True))

    def serialize_f32(self, value: st.float32):
        """Serialize a 32-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def serialize_f64(self, value: st.float64):
        """Serialize a 64-bit float (not implemented in this base class)."""
        raise NotImplementedError

    def serialize_char(self, value: st.char):
        """Serialize a character (not implemented in this base class)."""
        raise NotImplementedError

    def get_buffer_offset(self) -> int:
        """Return the current size in bytes of the output buffer."""
        return len(self.output.getbuffer())

    def get_buffer(self) -> bytes:
        """Return a copy of everything held by the output buffer."""
        return self.output.getvalue()

    def increase_container_depth(self):
        """Consume one level of the depth budget, if a budget is set.

        Raises:
            st.SerializationError: if the budget is already exhausted.
        """
        if self.container_depth_budget is not None:
            if self.container_depth_budget == 0:
                raise st.SerializationError("Exceeded maximum container depth")
            self.container_depth_budget -= 1

    def decrease_container_depth(self):
        """Give back one level of the depth budget, if a budget is set."""
        if self.container_depth_budget is not None:
            self.container_depth_budget += 1

    def serialize_len(self, value: int):
        """Encode a sequence/map/bytes length (format-specific, abstract)."""
        raise NotImplementedError

    def serialize_variant_index(self, value: int):
        """Encode the index of an enum variant (format-specific, abstract)."""
        raise NotImplementedError

    def sort_map_entries(self, offsets: typing.List[int]):
        """Reorder serialized map entries (format-specific, abstract).

        `serialize_any` passes the output offset at which each entry of the map starts.
        """
        raise NotImplementedError

    def serialize_any(self, obj: typing.Any, obj_type):  # noqa: C901
        """Serialize `obj` following the type description `obj_type`.

        Supported descriptions are, in order of testing: primitive types of the dispatch
        table, generic aliases (`Sequence[T]`, `Tuple[...]`, `Optional[T]`, `Dict[K, V]`),
        enums (classes exposing `VARIANTS`, whose values expose `INDEX`) and dataclasses.

        Raises:
            st.SerializationError: if `obj_type` is not supported, if `obj` does not match
                it, or if the container depth budget is exhausted.
        """
        if obj_type in self.primitive_type_serializer:
            self.primitive_type_serializer[obj_type](obj)
            return

        if hasattr(obj_type, "__origin__"):  # Generic type
            origin = obj_type.__origin__
            types = obj_type.__args__

            if origin == collections.abc.Sequence:  # Sequence
                # Explicit check rather than `assert`, which is stripped under `python -O`.
                if len(types) != 1:
                    raise st.SerializationError("Unexpected type", obj_type)
                item_type = types[0]
                self.serialize_len(len(obj))
                for item in obj:
                    self.serialize_any(item, item_type)

            elif origin == tuple:  # Tuple
                # Before Python 3.11, `typing.Tuple[()]` reports its arguments as `((),)`.
                if types == ((),):
                    types = ()
                # A tuple of the wrong arity would otherwise be written truncated (or fail
                # with a bare IndexError), producing bytes that cannot be read back.
                if len(obj) != len(types):
                    raise st.SerializationError("Wrong Value for the type", obj, obj_type)
                for item, item_type in zip(obj, types):
                    self.serialize_any(item, item_type)

            elif origin == typing.Union:  # Option
                # Only `Optional[T]`, i.e. `Union[T, None]`, is supported.
                if len(types) != 2 or types[1] is not type(None):
                    raise st.SerializationError("Unexpected type", obj_type)
                if obj is None:
                    self.output.write(b"\x00")
                else:
                    self.output.write(b"\x01")
                    self.serialize_any(obj, types[0])

            elif origin == dict:  # Map
                if len(types) != 2:
                    raise st.SerializationError("Unexpected type", obj_type)
                key_type, value_type = types
                self.serialize_len(len(obj))
                offsets = []
                for key, value in obj.items():
                    # Remember where each entry starts so that the format can reorder them.
                    offsets.append(self.get_buffer_offset())
                    self.serialize_any(key, key_type)
                    self.serialize_any(value, value_type)
                self.sort_map_entries(offsets)

            else:
                raise st.SerializationError("Unexpected type", obj_type)
            return

        if not dataclasses.is_dataclass(obj_type):  # Enum
            if not hasattr(obj_type, "VARIANTS"):
                raise st.SerializationError("Unexpected type", obj_type)
            if not hasattr(obj, "INDEX"):
                raise st.SerializationError("Wrong Value for the type", obj, obj_type)
            variant_index = obj.__class__.INDEX
            self.serialize_variant_index(variant_index)
            # Proceed to variant
            obj_type = obj_type.VARIANTS[variant_index]
            if not dataclasses.is_dataclass(obj_type):
                raise st.SerializationError("Unexpected type", obj_type)

        # pyre-ignore
        if not isinstance(obj, obj_type):
            raise st.SerializationError("Wrong Value for the type", obj, obj_type)

        # Content of struct or variant
        fields = dataclasses.fields(obj_type)
        types = get_type_hints(obj_type)
        self.increase_container_depth()
        try:
            for field in fields:
                # `getattr` also works for dataclasses that do not have a `__dict__`.
                self.serialize_any(getattr(obj, field.name), types[field.name])
        finally:
            # Keep the budget balanced even when a field fails to serialize.
            self.decrease_container_depth()
Subclasses
Class variables
var container_depth_budget : Optional[int]
var output : _io.BytesIO
var primitive_type_serializer : Mapping
Methods
def decrease_container_depth(self)
-
Expand source code
def decrease_container_depth(self):
    """Give back one level of the depth budget; do nothing when the budget is `None`."""
    current = self.container_depth_budget
    if current is None:
        return
    self.container_depth_budget = current + 1
def get_buffer(self) ‑> bytes
-
Expand source code
def get_buffer(self) -> bytes:
    """Return the whole content of the output buffer, as given by `getvalue()`."""
    contents = self.output.getvalue()
    return contents
def get_buffer_offset(self) ‑> int
-
Expand source code
def get_buffer_offset(self) -> int:
    """Return the total number of bytes currently held by ``self.output``.

    The count comes from the size of the underlying buffer, not from the
    current stream position.
    """
    # While a view returned by getbuffer() is alive, the BytesIO cannot be
    # resized. Release it explicitly instead of relying on the garbage
    # collector, so that the next write can never hit a BufferError.
    with self.output.getbuffer() as view:
        return len(view)
def increase_container_depth(self)
-
Expand source code
def increase_container_depth(self):
    """Consume one level of container depth budget.

    Does nothing when ``container_depth_budget`` is ``None``.

    Raises:
        st.SerializationError: if the budget is already 0.
    """
    remaining = self.container_depth_budget
    if remaining is None:
        return
    if remaining == 0:
        raise st.SerializationError("Exceeded maximum container depth")
    self.container_depth_budget = remaining - 1
def serialize_any(self, obj: Any, obj_type)
-
Expand source code
def serialize_any(self, obj: typing.Any, obj_type):
    """Serialize ``obj`` according to the type description ``obj_type``.

    Dispatch order:

    1. Types registered in ``primitive_type_serializer``.
    2. Generic types (anything with ``__origin__``): sequences, tuples,
       two-argument unions ending in ``NoneType`` (options), and dicts (maps).
    3. Everything else is treated as a dataclass struct, or as an enum whose
       class exposes ``VARIANTS`` and whose value's class exposes ``INDEX``.

    Raises:
        st.SerializationError: for an unsupported type, or a value that does
            not match the expected struct/variant type.
    """
    # Primitive types are handled by the dispatch table.
    if obj_type in self.primitive_type_serializer:
        self.primitive_type_serializer[obj_type](obj)
        return

    if hasattr(obj_type, "__origin__"):
        # Generic type
        type_args = getattr(obj_type, "__args__")
        origin = getattr(obj_type, "__origin__")

        if origin == collections.abc.Sequence:
            # Sequence: length prefix, then every element.
            assert len(type_args) == 1
            element_type = type_args[0]
            self.serialize_len(len(obj))
            for element in obj:
                self.serialize_any(element, element_type)

        elif origin == tuple:
            # Tuple: elements in order, no length prefix.
            for position, element in enumerate(obj):
                self.serialize_any(element, type_args[position])

        elif origin == typing.Union:
            # Option: one tag byte, then the payload if present.
            assert len(type_args) == 2 and type_args[1] == type(None)
            if obj is None:
                self.output.write(b"\x00")
            else:
                self.output.write(b"\x01")
                self.serialize_any(obj, type_args[0])

        elif origin == dict:
            # Map: length prefix, then entries. The start offset of each
            # entry is recorded and handed to sort_map_entries afterwards.
            assert len(type_args) == 2
            self.serialize_len(len(obj))
            entry_offsets = []
            for entry_key, entry_value in obj.items():
                entry_offsets.append(self.get_buffer_offset())
                self.serialize_any(entry_key, type_args[0])
                self.serialize_any(entry_value, type_args[1])
            self.sort_map_entries(entry_offsets)

        else:
            raise st.SerializationError("Unexpected type", obj_type)
        return

    if not dataclasses.is_dataclass(obj_type):
        # Enum
        if not hasattr(obj_type, "VARIANTS"):
            raise st.SerializationError("Unexpected type", obj_type)
        if not hasattr(obj, "INDEX"):
            raise st.SerializationError("Wrong Value for the type", obj, obj_type)
        variant_index = obj.__class__.INDEX
        self.serialize_variant_index(variant_index)
        # Proceed to variant
        obj_type = obj_type.VARIANTS[obj.__class__.INDEX]
        if not dataclasses.is_dataclass(obj_type):
            raise st.SerializationError("Unexpected type", obj_type)

    # pyre-ignore
    if not isinstance(obj, obj_type):
        raise st.SerializationError("Wrong Value for the type", obj, obj_type)

    # Content of struct or variant
    struct_fields = dataclasses.fields(obj_type)
    hints = get_type_hints(obj_type)
    self.increase_container_depth()
    for struct_field in struct_fields:
        name = struct_field.name
        self.serialize_any(obj.__dict__[name], hints[name])
    self.decrease_container_depth()
def serialize_bool(self, value: bool)
-
Expand source code
def serialize_bool(self, value: bool):
    """Write ``value`` to ``self.output`` as a single unsigned byte."""
    as_int = int(value)
    encoded = as_int.to_bytes(1, "little", signed=False)
    self.output.write(encoded)
def serialize_bytes(self, value: bytes)
-
Expand source code
def serialize_bytes(self, value: bytes):
    """Write a length prefix (via ``serialize_len``) followed by the raw bytes."""
    byte_count = len(value)
    self.serialize_len(byte_count)
    self.output.write(value)
def serialize_char(self, value: char)
-
Expand source code
def serialize_char(self, value: st.char):
    """Serialize a character; not implemented in this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def serialize_f32(self, value: numpy.float32)
-
Expand source code
def serialize_f32(self, value: st.float32):
    """Serialize a 32-bit float; not implemented in this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def serialize_f64(self, value: numpy.float64)
-
Expand source code
def serialize_f64(self, value: st.float64):
    """Serialize a 64-bit float; not implemented in this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def serialize_i128(self, value: uint128)
-
Expand source code
def serialize_i128(self, value: st.int128):
    """Write ``value`` to ``self.output`` as a 16-byte little-endian signed integer.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit in
    16 signed bytes.
    """
    # The parameter is the signed 128-bit type; it was previously annotated
    # with the unsigned one.
    self.output.write(int(value).to_bytes(16, "little", signed=True))
def serialize_i16(self, value: numpy.uint16)
-
Expand source code
def serialize_i16(self, value: st.int16):
    """Write ``value`` to ``self.output`` as a 2-byte little-endian signed integer.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit in
    2 signed bytes.
    """
    # The parameter is the signed 16-bit type; it was previously annotated
    # with the unsigned one.
    self.output.write(int(value).to_bytes(2, "little", signed=True))
def serialize_i32(self, value: numpy.uint32)
-
Expand source code
def serialize_i32(self, value: st.int32):
    """Write ``value`` to ``self.output`` as a 4-byte little-endian signed integer.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit in
    4 signed bytes.
    """
    # The parameter is the signed 32-bit type; it was previously annotated
    # with the unsigned one.
    self.output.write(int(value).to_bytes(4, "little", signed=True))
def serialize_i64(self, value: numpy.uint64)
-
Expand source code
def serialize_i64(self, value: st.int64):
    """Write ``value`` to ``self.output`` as an 8-byte little-endian signed integer.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit in
    8 signed bytes.
    """
    # The parameter is the signed 64-bit type; it was previously annotated
    # with the unsigned one.
    self.output.write(int(value).to_bytes(8, "little", signed=True))
def serialize_i8(self, value: numpy.uint8)
-
Expand source code
def serialize_i8(self, value: st.int8):
    """Write ``value`` to ``self.output`` as a single signed byte.

    ``int.to_bytes`` raises ``OverflowError`` when the value does not fit in
    one signed byte.
    """
    # The parameter is the signed 8-bit type; it was previously annotated
    # with the unsigned one.
    self.output.write(int(value).to_bytes(1, "little", signed=True))
def serialize_len(self, value: int)
-
Expand source code
def serialize_len(self, value: int):
    """Serialize a sequence length; not implemented in this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def serialize_str(self, value: str)
-
Expand source code
def serialize_str(self, value: str):
    """Encode ``value`` as UTF-8 and serialize it with ``serialize_bytes``."""
    encoded = value.encode("utf-8")
    self.serialize_bytes(encoded)
def serialize_u128(self, value: uint128)
-
Expand source code
def serialize_u128(self, value: st.uint128):
    """Write ``value`` to ``self.output`` as a 16-byte little-endian unsigned integer."""
    encoded = int(value).to_bytes(16, "little", signed=False)
    self.output.write(encoded)
def serialize_u16(self, value: numpy.uint16)
-
Expand source code
def serialize_u16(self, value: st.uint16):
    """Write ``value`` to ``self.output`` as a 2-byte little-endian unsigned integer."""
    encoded = int(value).to_bytes(2, "little", signed=False)
    self.output.write(encoded)
def serialize_u32(self, value: numpy.uint32)
-
Expand source code
def serialize_u32(self, value: st.uint32):
    """Write ``value`` to ``self.output`` as a 4-byte little-endian unsigned integer."""
    encoded = int(value).to_bytes(4, "little", signed=False)
    self.output.write(encoded)
def serialize_u64(self, value: numpy.uint64)
-
Expand source code
def serialize_u64(self, value: st.uint64):
    """Write ``value`` to ``self.output`` as an 8-byte little-endian unsigned integer."""
    encoded = int(value).to_bytes(8, "little", signed=False)
    self.output.write(encoded)
def serialize_u8(self, value: numpy.uint8)
-
Expand source code
def serialize_u8(self, value: st.uint8):
    """Write ``value`` to ``self.output`` as a single unsigned byte."""
    encoded = int(value).to_bytes(1, "little", signed=False)
    self.output.write(encoded)
def serialize_unit(self, value: Type[NoneType])
-
Expand source code
def serialize_unit(self, value: st.unit):
    """Serialize the unit value; nothing is written."""
    return None
def serialize_variant_index(self, value: int)
-
Expand source code
def serialize_variant_index(self, value: int):
    """Serialize an enum variant index; not implemented in this class.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
def sort_map_entries(self, offsets: List[int])
-
Expand source code
def sort_map_entries(self, offsets: typing.List[int]):
    """Order already-serialized map entries; not implemented in this class.

    ``offsets`` holds the buffer offset at which each entry starts.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()