Skip to content

Commit

Permalink
feat: add bitlist and bitvector ssz support (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
f3r10 authored Jan 4, 2024
1 parent f288571 commit f2fe509
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 18 deletions.
18 changes: 1 addition & 17 deletions lib/lambda_ethereum_consensus/state_transition/operations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
end

defp check_matching_aggregation_bits_length(attestation, beacon_committee) do
if length_of_bitlist(attestation.aggregation_bits) == length(beacon_committee) do
if SszEx.length_of_bitlist(attestation.aggregation_bits) == length(beacon_committee) do
:ok
else
{:error, "Mismatched aggregation bits length"}
Expand All @@ -867,22 +867,6 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
end
end

defp length_of_bitlist(bitlist) when is_binary(bitlist) do
bit_size = bit_size(bitlist)
<<_::size(bit_size - 8), last_byte>> = bitlist
bit_size - leading_zeros(<<last_byte>>) - 1
end

defp leading_zeros(<<1::1, _::7>>), do: 0
defp leading_zeros(<<0::1, 1::1, _::6>>), do: 1
defp leading_zeros(<<0::2, 1::1, _::5>>), do: 2
defp leading_zeros(<<0::3, 1::1, _::4>>), do: 3
defp leading_zeros(<<0::4, 1::1, _::3>>), do: 4
defp leading_zeros(<<0::5, 1::1, _::2>>), do: 5
defp leading_zeros(<<0::6, 1::1, _::1>>), do: 6
defp leading_zeros(<<0::7, 1::1>>), do: 7
defp leading_zeros(<<0::8>>), do: 8

def process_bls_to_execution_change(state, signed_address_change) do
address_change = signed_address_change.message

Expand Down
88 changes: 88 additions & 0 deletions lib/ssz_ex.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ defmodule LambdaEthereumConsensus.SszEx do
@moduledoc """
SSZ library in Elixir
"""
alias LambdaEthereumConsensus.Utils.BitVector
import alias LambdaEthereumConsensus.Utils.BitVector

#################
### Public API
Expand All @@ -23,6 +25,15 @@ defmodule LambdaEthereumConsensus.SszEx do

def encode(value, {:bytes, _}), do: {:ok, value}

def encode(value, {:bitlist, max_size}) when is_bitstring(value),
do: encode_bitlist(value, max_size)

def encode(value, {:bitlist, max_size}) when is_integer(value),
do: encode_bitlist(:binary.encode_unsigned(value), max_size)

def encode(value, {:bitvector, size}) when is_bitvector(value),
do: encode_bitvector(value, size)

def encode(container, module) when is_map(container),
do: encode_container(container, module.schema())

Expand All @@ -36,6 +47,12 @@ defmodule LambdaEthereumConsensus.SszEx do
else: decode_list(binary, basic_type, size)
end

def decode(value, {:bitlist, max_size}) when is_bitstring(value),
do: decode_bitlist(value, max_size)

def decode(value, {:bitvector, size}) when is_bitstring(value),
do: decode_bitvector(value, size)

def decode(binary, module) when is_atom(module), do: decode_container(binary, module)

@spec hash_tree_root!(boolean, atom) :: Types.root()
Expand Down Expand Up @@ -72,6 +89,23 @@ defmodule LambdaEthereumConsensus.SszEx do
|> flatten_results_by(&Enum.join/1)
end

defp encode_bitlist(bit_list, max_size) do
len = bit_size(bit_list)

if len > max_size do
{:error, "excess bits"}
else
r = rem(len, @bits_per_byte)
<<pre::bitstring-size(len - r), post::bitstring-size(r)>> = bit_list
{:ok, <<pre::bitstring, 1::size(@bits_per_byte - r), post::bitstring>>}
end
end

defp encode_bitvector(bit_vector, size) when bit_vector_size(bit_vector) == size,
do: {:ok, BitVector.to_bytes(bit_vector)}

defp encode_bitvector(_bit_vector, _size), do: {:error, "invalid bit_vector length"}

defp encode_variable_size_list(list, _basic_type, max_size) when length(list) > max_size,
do: {:error, "invalid max_size of list"}

Expand Down Expand Up @@ -114,6 +148,33 @@ defmodule LambdaEthereumConsensus.SszEx do
end
end

defp decode_bitlist(bit_list, max_size) do
num_bytes = byte_size(bit_list)
num_bits = bit_size(bit_list)
len = length_of_bitlist(bit_list)
<<pre::size(num_bits - 8), last_byte::8>> = bit_list
decoded = <<pre::size(num_bits - 8), remove_trailing_bit(<<last_byte>>)::bitstring>>

cond do
len < 0 ->
{:error, "missing length information"}

div(len, @bits_per_byte) + 1 != num_bytes ->
{:error, "invalid byte count"}

len > max_size ->
{:error, "out of bounds"}

true ->
{:ok, decoded}
end
end

defp decode_bitvector(bit_vector, size) when bit_size(bit_vector) == size,
do: {:ok, BitVector.new(bit_vector, size)}

defp decode_bitvector(_bit_vector, _size), do: {:error, "invalid bit_vector length"}

defp decode_list(binary, basic_type, size) do
fixed_size = get_fixed_size(basic_type)

Expand Down Expand Up @@ -407,6 +468,33 @@ defmodule LambdaEthereumConsensus.SszEx do
|> Enum.any?()
end

def length_of_bitlist(bitlist) when is_binary(bitlist) do
bit_size = bit_size(bitlist)
<<_::size(bit_size - 8), last_byte>> = bitlist
bit_size - leading_zeros(<<last_byte>>) - 1
end

defp leading_zeros(<<1::1, _::7>>), do: 0
defp leading_zeros(<<0::1, 1::1, _::6>>), do: 1
defp leading_zeros(<<0::2, 1::1, _::5>>), do: 2
defp leading_zeros(<<0::3, 1::1, _::4>>), do: 3
defp leading_zeros(<<0::4, 1::1, _::3>>), do: 4
defp leading_zeros(<<0::5, 1::1, _::2>>), do: 5
defp leading_zeros(<<0::6, 1::1, _::1>>), do: 6
defp leading_zeros(<<0::7, 1::1>>), do: 7
defp leading_zeros(<<0::8>>), do: 8

@spec remove_trailing_bit(binary()) :: bitstring()
defp remove_trailing_bit(<<1::1, rest::7>>), do: <<rest::7>>
defp remove_trailing_bit(<<0::1, 1::1, rest::6>>), do: <<rest::6>>
defp remove_trailing_bit(<<0::2, 1::1, rest::5>>), do: <<rest::5>>
defp remove_trailing_bit(<<0::3, 1::1, rest::4>>), do: <<rest::4>>
defp remove_trailing_bit(<<0::4, 1::1, rest::3>>), do: <<rest::3>>
defp remove_trailing_bit(<<0::5, 1::1, rest::2>>), do: <<rest::2>>
defp remove_trailing_bit(<<0::6, 1::1, rest::1>>), do: <<rest::1>>
defp remove_trailing_bit(<<0::7, 1::1>>), do: <<0::0>>
defp remove_trailing_bit(<<0::8>>), do: <<0::0>>

defp pack(value, size) when is_integer(value) and value >= 0 do
pad = @bits_per_chunk - size
<<value::size(size)-little, 0::size(pad)>>
Expand Down
6 changes: 5 additions & 1 deletion lib/utils/bit_vector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ defmodule LambdaEthereumConsensus.Utils.BitVector do

# The internal representation is a bitstring, but we could evaluate
# turning it into an integer to use bitwise operations instead.
@opaque t :: bitstring
@type t :: bitstring

defguard is_bitvector(value) when is_bitstring(value)

defguard bit_vector_size(value) when bit_size(value)

@doc """
Creates a new bit_vector from an integer or a bitstring.
Expand Down
48 changes: 48 additions & 0 deletions test/unit/ssz_ex_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,52 @@ defmodule Unit.SSZExTest do

assert_roundtrip(serialized, sync, Types.SyncCommittee)
end

test "serialize and deserialize bitlist" do
encoded_bytes = <<160, 92, 1>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16})

encoded_bytes = <<255, 1>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16})

encoded_bytes = <<31>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 16})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 16})

encoded_bytes = <<1>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 31})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 31})

encoded_bytes = <<106, 141, 117, 7>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitlist, 31})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitlist, 31})

encoded_bytes = <<7>>
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1})

encoded_bytes = <<124, 3>>
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1})

encoded_bytes = <<0>>
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 1})
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitlist, 16})
end

test "serialize and deserialize bitvector" do
encoded_bytes = <<255, 255>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitvector, 16})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitvector, 16})

encoded_bytes = <<0, 0>>
assert {:ok, decoded_bytes} = SszEx.decode(encoded_bytes, {:bitvector, 16})
assert {:ok, ^encoded_bytes} = SszEx.encode(decoded_bytes, {:bitvector, 16})

encoded_bytes = <<255, 255, 255, 255, 1>>
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitvector, 33})

encoded_bytes = <<0>>
assert {:error, _msg} = SszEx.decode(encoded_bytes, {:bitvector, 9})
end
end

0 comments on commit f2fe509

Please sign in to comment.