diff --git a/lib/lambda_ethereum_consensus/state_transition/operations.ex b/lib/lambda_ethereum_consensus/state_transition/operations.ex index c2762a54e..3cb4263c9 100644 --- a/lib/lambda_ethereum_consensus/state_transition/operations.ex +++ b/lib/lambda_ethereum_consensus/state_transition/operations.ex @@ -852,7 +852,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do end defp check_matching_aggregation_bits_length(attestation, beacon_committee) do - if length_of_bitlist(attestation.aggregation_bits) == length(beacon_committee) do + if SszEx.length_of_bitlist(attestation.aggregation_bits) == length(beacon_committee) do :ok else {:error, "Mismatched aggregation bits length"} @@ -867,22 +867,6 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do end end - defp length_of_bitlist(bitlist) when is_binary(bitlist) do - bit_size = bit_size(bitlist) - <<_::size(bit_size - 8), last_byte>> = bitlist - bit_size - leading_zeros(<>) - 1 - end - - defp leading_zeros(<<1::1, _::7>>), do: 0 - defp leading_zeros(<<0::1, 1::1, _::6>>), do: 1 - defp leading_zeros(<<0::2, 1::1, _::5>>), do: 2 - defp leading_zeros(<<0::3, 1::1, _::4>>), do: 3 - defp leading_zeros(<<0::4, 1::1, _::3>>), do: 4 - defp leading_zeros(<<0::5, 1::1, _::2>>), do: 5 - defp leading_zeros(<<0::6, 1::1, _::1>>), do: 6 - defp leading_zeros(<<0::7, 1::1>>), do: 7 - defp leading_zeros(<<0::8>>), do: 8 - def process_bls_to_execution_change(state, signed_address_change) do address_change = signed_address_change.message diff --git a/lib/ssz_ex.ex b/lib/ssz_ex.ex index b029c9f91..6686985f5 100644 --- a/lib/ssz_ex.ex +++ b/lib/ssz_ex.ex @@ -2,6 +2,8 @@ defmodule LambdaEthereumConsensus.SszEx do @moduledoc """ SSZ library in Elixir """ + alias LambdaEthereumConsensus.Utils.BitVector + import alias LambdaEthereumConsensus.Utils.BitVector ################# ### Public API @@ -23,6 +25,15 @@ defmodule LambdaEthereumConsensus.SszEx do def encode(value, {:bytes, _}), do: {:ok, value} + def encode(value, {:bitlist, max_size}) when is_bitstring(value), + do: encode_bitlist(value, max_size) + + def encode(value, {:bitlist, max_size}) when is_integer(value), + do: encode_bitlist(:binary.encode_unsigned(value), max_size) + + def encode(value, {:bitvector, size}) when is_bitvector(value), + do: encode_bitvector(value, size) + def encode(container, module) when is_map(container), do: encode_container(container, module.schema()) @@ -36,6 +47,12 @@ defmodule LambdaEthereumConsensus.SszEx do else: decode_list(binary, basic_type, size) end + def decode(value, {:bitlist, max_size}) when is_bitstring(value), + do: decode_bitlist(value, max_size) + + def decode(value, {:bitvector, size}) when is_bitstring(value), + do: decode_bitvector(value, size) + def decode(binary, module) when is_atom(module), do: decode_container(binary, module) @spec hash_tree_root!(boolean, atom) :: Types.root() @@ -72,6 +89,23 @@ defmodule LambdaEthereumConsensus.SszEx do |> flatten_results_by(&Enum.join/1) end + defp encode_bitlist(bit_list, max_size) do + len = bit_size(bit_list) + + if len > max_size do + {:error, "excess bits"} + else + r = rem(len, @bits_per_byte) + <> = bit_list + {:ok, <>} + end + end + + defp encode_bitvector(bit_vector, size) when bit_vector_size(bit_vector) == size, + do: {:ok, BitVector.to_bytes(bit_vector)} + + defp encode_bitvector(_bit_vector, _size), do: {:error, "invalid bit_vector length"} + defp encode_variable_size_list(list, _basic_type, max_size) when length(list) > max_size, do: {:error, "invalid max_size of list"} @@ -114,6 +148,33 @@ defmodule LambdaEthereumConsensus.SszEx do end end + defp decode_bitlist(bit_list, max_size) do + num_bytes = byte_size(bit_list) + num_bits = bit_size(bit_list) + len = length_of_bitlist(bit_list) + <> = bit_list + decoded = <>)::bitstring>> + + cond do + len < 0 -> + {:error, "missing length information"} + + div(len, @bits_per_byte) + 1 != num_bytes -> + {:error, "invalid byte count"} + + len > max_size -> + {:error, "out of bounds"} + + true -> + {:ok, decoded} + end + end + + defp decode_bitvector(bit_vector, size) when bit_size(bit_vector) == size, + do: {:ok, BitVector.new(bit_vector, size)} + + defp decode_bitvector(_bit_vector, _size), do: {:error, "invalid bit_vector length"} + defp decode_list(binary, basic_type, size) do fixed_size = get_fixed_size(basic_type) @@ -407,6 +468,33 @@ defmodule LambdaEthereumConsensus.SszEx do |> Enum.any?() end + def length_of_bitlist(bitlist) when is_binary(bitlist) do + bit_size = bit_size(bitlist) + <<_::size(bit_size - 8), last_byte>> = bitlist + bit_size - leading_zeros(<>) - 1 + end + + defp leading_zeros(<<1::1, _::7>>), do: 0 + defp leading_zeros(<<0::1, 1::1, _::6>>), do: 1 + defp leading_zeros(<<0::2, 1::1, _::5>>), do: 2 + defp leading_zeros(<<0::3, 1::1, _::4>>), do: 3 + defp leading_zeros(<<0::4, 1::1, _::3>>), do: 4 + defp leading_zeros(<<0::5, 1::1, _::2>>), do: 5 + defp leading_zeros(<<0::6, 1::1, _::1>>), do: 6 + defp leading_zeros(<<0::7, 1::1>>), do: 7 + defp leading_zeros(<<0::8>>), do: 8 + + @spec remove_trailing_bit(binary()) :: bitstring() + defp remove_trailing_bit(<<1::1, rest::7>>), do: <> + defp remove_trailing_bit(<<0::1, 1::1, rest::6>>), do: <> + defp remove_trailing_bit(<<0::2, 1::1, rest::5>>), do: <> + defp remove_trailing_bit(<<0::3, 1::1, rest::4>>), do: <> + defp remove_trailing_bit(<<0::4, 1::1, rest::3>>), do: <> + defp remove_trailing_bit(<<0::5, 1::1, rest::2>>), do: <> + defp remove_trailing_bit(<<0::6, 1::1, rest::1>>), do: <> + defp remove_trailing_bit(<<0::7, 1::1>>), do: <<0::0>> + defp remove_trailing_bit(<<0::8>>), do: <<0::0>> + defp pack(value, size) when is_integer(value) and value >= 0 do pad = @bits_per_chunk - size <> diff --git a/lib/utils/bit_vector.ex b/lib/utils/bit_vector.ex index 5a636d2b0..a739c6807 100644 --- a/lib/utils/bit_vector.ex +++ b/lib/utils/bit_vector.ex @@ -33,7 +33,11 @@ defmodule LambdaEthereumConsensus.Utils.BitVector do # The internal representation is a bitstring, but we could evaluate # turning it into an integer to use bitwise operations instead. - @opaque t :: bitstring + @type t :: bitstring + + defguard is_bitvector(value) when is_bitstring(value) + + defguard bit_vector_size(value) when bit_size(value) @doc """ Creates a new bit_vector from an integer or a bitstring. diff --git a/test/unit/ssz_ex_test.exs b/test/unit/ssz_ex_test.exs index b788328e0..1e8417304 100644 --- a/test/unit/ssz_ex_test.exs +++ b/test/unit/ssz_ex_test.exs @@ -187,4 +187,52 @@ defmodule Unit.SSZExTest do assert_roundtrip(serialized, sync, Types.SyncCommittee) end + + test "serialize and deserialize bitlist" do + encoded_bytes = <<160, 92, 1>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16}) + + encoded_bytes = <<255, 1>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16}) + + encoded_bytes = <<31>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16}) + + encoded_bytes = <<1>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 31}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 31}) + + encoded_bytes = <<106, 141, 117, 7>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 31}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 31}) + + encoded_bytes = <<7>> + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1}) + + encoded_bytes = <<124, 3>> + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1}) + + encoded_bytes = <<0>> + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1}) + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 16}) + end + + test "serialize and deserialize bitvector" do + encoded_bytes = <<255, 255>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitvector, 16}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitvector, 16}) + + encoded_bytes = <<0, 0>> + assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitvector, 16}) + assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitvector, 16}) + + encoded_bytes = <<255, 255, 255, 255, 1>> + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitvector, 33}) + + encoded_bytes = <<0>> + assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitvector, 9}) + end end