Skip to content

Commit

Permalink
refactor(framework) Introduce UintList and SintList (#4267)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Sep 27, 2024
1 parent 205b20e commit e521c53
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 104 deletions.
16 changes: 6 additions & 10 deletions src/proto/flwr/proto/recordset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,9 @@ syntax = "proto3";

package flwr.proto;

message Int {
oneof int {
sint64 sint64 = 1;
uint64 uint64 = 2;
}
}

message DoubleList { repeated double vals = 1; }
message IntList { repeated Int vals = 1; }
message SintList { repeated sint64 vals = 1; }
message UintList { repeated uint64 vals = 1; }
message BoolList { repeated bool vals = 1; }
message StringList { repeated string vals = 1; }
message BytesList { repeated bytes vals = 1; }
Expand All @@ -46,7 +40,8 @@ message MetricsRecordValue {

// List types
DoubleList double_list = 21;
IntList int_list = 22;
SintList sint_list = 22;
UintList uint_list = 23;
}
}

Expand All @@ -62,7 +57,8 @@ message ConfigsRecordValue {

// List types
DoubleList double_list = 21;
IntList int_list = 22;
SintList sint_list = 22;
UintList uint_list = 23;
BoolList bool_list = 24;
StringList string_list = 25;
BytesList bytes_list = 26;
Expand Down
41 changes: 17 additions & 24 deletions src/py/flwr/common/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
from flwr.proto.recordset_pb2 import BoolList, BytesList
from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord
from flwr.proto.recordset_pb2 import ConfigsRecordValue as ProtoConfigsRecordValue
from flwr.proto.recordset_pb2 import DoubleList, Int, IntList
from flwr.proto.recordset_pb2 import DoubleList
from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord
from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue
from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet
from flwr.proto.recordset_pb2 import StringList
from flwr.proto.recordset_pb2 import SintList, StringList, UintList
from flwr.proto.run_pb2 import Run as ProtoRun
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import (
Expand Down Expand Up @@ -340,6 +340,7 @@ def metrics_from_proto(proto: Any) -> typing.Metrics:


# === Scalar messages ===
INT64_MAX_VALUE = 9223372036854775807 # (1 << 63) - 1


def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
Expand All @@ -354,9 +355,10 @@ def scalar_to_proto(scalar: typing.Scalar) -> Scalar:
return Scalar(double=scalar)

if isinstance(scalar, int):
if scalar >= 0:
return Scalar(uint64=scalar) # Use uint64 for non-negative integers
return Scalar(sint64=scalar) # Use sint64 for negative integers
# Use uint64 for integers larger than the maximum value of sint64
if scalar > INT64_MAX_VALUE:
return Scalar(uint64=scalar)
return Scalar(sint64=scalar)

if isinstance(scalar, str):
return Scalar(string=scalar)
Expand All @@ -378,32 +380,24 @@ def scalar_from_proto(scalar_msg: Scalar) -> typing.Scalar:

_type_to_field: dict[type, str] = {
float: "double",
int: "int",
int: "sint64",
bool: "bool",
str: "string",
bytes: "bytes",
}
_list_type_to_class_and_field: dict[type, tuple[type[GrpcMessage], str]] = {
float: (DoubleList, "double_list"),
int: (IntList, "int_list"),
int: (SintList, "sint_list"),
bool: (BoolList, "bool_list"),
str: (StringList, "string_list"),
bytes: (BytesList, "bytes_list"),
}
T = TypeVar("T")


def int_to_proto(value: int) -> Int:
"""Serialize a int to `Int`."""
if value >= 0:
return Int(uint64=value)
return Int(sint64=value)


def int_from_proto(value_proto: Int) -> int:
"""Deserialize a int from `Int`."""
fld = cast(str, value_proto.WhichOneof("int"))
return cast(int, getattr(value_proto, fld))
def _is_uint64(value: Any) -> bool:
"""Check if a value is uint64."""
return isinstance(value, int) and value > INT64_MAX_VALUE


def _record_value_to_proto(
Expand All @@ -419,15 +413,16 @@ def _record_value_to_proto(
# Note: `isinstance(False, int) == True`.
if isinstance(value, t):
fld = _type_to_field[t]
if t is int:
fld = "uint64" if cast(int, value) >= 0 else "sint64"
if t is int and _is_uint64(value):
fld = "uint64"
arg[fld] = value
return proto_class(**arg)
# List
if isinstance(value, list) and all(isinstance(item, t) for item in value):
list_class, fld = _list_type_to_class_and_field[t]
if t is int:
value = [int_to_proto(v) for v in value]
# Use UintList if any element is of type `uint64`.
if t is int and any(_is_uint64(v) for v in value):
list_class, fld = UintList, "uint_list"
arg[fld] = list_class(vals=value)
return proto_class(**arg)
# Invalid types
Expand All @@ -442,8 +437,6 @@ def _record_value_from_proto(value_proto: GrpcMessage) -> Any:
value_field = cast(str, value_proto.WhichOneof("value"))
if value_field.endswith("list"):
value = list(getattr(value_proto, value_field).vals)
if value_field == "int_list":
value = [int_from_proto(v) for v in value]
else:
value = getattr(value_proto, value_field)
return value
Expand Down
8 changes: 6 additions & 2 deletions src/py/flwr/common/serde_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,21 @@ def get_str(self, length: Optional[int] = None) -> str:
length = self.rng.randint(1, 10)
return "".join(self.rng.choices(char_pool, k=length))

def get_value(self, dtype: type[T]) -> T:
def get_value(self, dtype: Union[type[T], str]) -> T:
"""Create a value of a given type."""
ret: Any = None
if dtype == bool:
ret = self.rng.random() < 0.5
elif dtype == str:
ret = self.get_str(self.rng.randint(10, 100))
elif dtype == int:
ret = self.rng.randint(-1 << 63, (1 << 64) - 1)
ret = self.rng.randint(-1 << 63, (1 << 63) - 1)
elif dtype == float:
ret = (self.rng.random() - 0.5) * (2.0 ** self.rng.randint(0, 50))
elif dtype == bytes:
ret = self.randbytes(self.rng.randint(10, 100))
elif dtype == "uint":
ret = self.rng.randint(0, (1 << 64) - 1)
else:
raise NotImplementedError(f"Unsupported dtype: {dtype}")
return cast(T, ret)
Expand Down Expand Up @@ -317,6 +319,7 @@ def test_metrics_record_serialization_deserialization() -> None:
maker = RecordMaker()
original = maker.metrics_record()
original["uint64"] = (1 << 63) + 321
original["list of uint64"] = [maker.get_value("uint") for _ in range(30)]

# Execute
proto = metrics_record_to_proto(original)
Expand All @@ -333,6 +336,7 @@ def test_configs_record_serialization_deserialization() -> None:
maker = RecordMaker()
original = maker.configs_record()
original["uint64"] = (1 << 63) + 101
original["list of uint64"] = [maker.get_value("uint") for _ in range(100)]

# Execute
proto = configs_record_to_proto(original)
Expand Down
74 changes: 37 additions & 37 deletions src/py/flwr/proto/recordset_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit e521c53

Please sign in to comment.