From d4f5633ef2b32531bb9a21be842c61579585f347 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Arjovsky?= Date: Tue, 20 Feb 2024 15:49:18 +0100 Subject: [PATCH] feat: add bitlists and bitvectors natively in SSZ nif library (#785) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tomás Grüner <47506558+MegaRedHand@users.noreply.github.com> --- .../p2p/gossip/handler.ex | 6 +-- .../state_transition/accessors.ex | 9 +--- .../state_transition/epoch_processing.ex | 25 +++-------- .../state_transition/operations.ex | 11 ++--- lib/ssz_ex.ex | 16 +++---- lib/types/beacon_chain/attestation.ex | 10 +++++ lib/types/beacon_chain/beacon_state.ex | 9 +++- lib/types/beacon_chain/pending_attestation.ex | 10 +++++ lib/types/beacon_chain/sync_aggregate.ex | 14 +++++- lib/utils/bit_list.ex | 44 +++++++------------ test/spec/utils.ex | 2 +- test/unit/bit_list_test.exs | 19 ++++---- 12 files changed, 90 insertions(+), 85 deletions(-) diff --git a/lib/lambda_ethereum_consensus/p2p/gossip/handler.ex b/lib/lambda_ethereum_consensus/p2p/gossip/handler.ex index 370641023..18d4ef71d 100644 --- a/lib/lambda_ethereum_consensus/p2p/gossip/handler.ex +++ b/lib/lambda_ethereum_consensus/p2p/gossip/handler.ex @@ -7,7 +7,7 @@ defmodule LambdaEthereumConsensus.P2P.Gossip.Handler do alias LambdaEthereumConsensus.Beacon.BeaconChain alias LambdaEthereumConsensus.Beacon.PendingBlocks - alias LambdaEthereumConsensus.Utils.BitVector + alias LambdaEthereumConsensus.Utils.BitField alias Types.{AggregateAndProof, SignedAggregateAndProof, SignedBeaconBlock} def handle_beacon_block(%SignedBeaconBlock{message: block} = signed_block) do @@ -25,11 +25,11 @@ defmodule LambdaEthereumConsensus.P2P.Gossip.Handler do def handle_beacon_aggregate_and_proof(%SignedAggregateAndProof{ message: %AggregateAndProof{aggregate: aggregate} }) do - votes = BitVector.count(aggregate.aggregation_bits) + votes = BitField.count(aggregate.aggregation_bits) slot = aggregate.data.slot root = aggregate.data.beacon_block_root |> Base.encode16() - # We are getting ~500 attestations in half a second. This is overwheling the store GenServer at the moment. + # We are getting ~500 attestations in half a second. This is overwhelming the store GenServer at the moment. # Store.on_attestation(aggregate) Logger.debug( diff --git a/lib/lambda_ethereum_consensus/state_transition/accessors.ex b/lib/lambda_ethereum_consensus/state_transition/accessors.ex index 0c4d8442b..77b015b46 100644 --- a/lib/lambda_ethereum_consensus/state_transition/accessors.ex +++ b/lib/lambda_ethereum_consensus/state_transition/accessors.ex @@ -6,6 +6,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do alias LambdaEthereumConsensus.SszEx alias LambdaEthereumConsensus.StateTransition.{Cache, Math, Misc, Predicates} alias LambdaEthereumConsensus.Utils + alias LambdaEthereumConsensus.Utils.BitList alias LambdaEthereumConsensus.Utils.Randao alias Types.{Attestation, BeaconState, IndexedAttestation, SyncCommittee, Validator} @@ -510,13 +511,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do |> Enum.sort() end - defp participated?(bits, index) do - # The bit order inside the byte is reversed (e.g. bits[0] is the 8th bit). - # Here we keep the byte index the same, but reverse the bit index. - bit_index = index + 7 - 2 * rem(index, 8) - <<_::size(bit_index), flag::1, _::bits>> = bits - flag == 1 - end + defp participated?(bits, index), do: BitList.set?(bits, index) @doc """ Return the combined effective balance of the ``indices``. diff --git a/lib/lambda_ethereum_consensus/state_transition/epoch_processing.ex b/lib/lambda_ethereum_consensus/state_transition/epoch_processing.ex index 0bea9ce40..6baa08d55 100644 --- a/lib/lambda_ethereum_consensus/state_transition/epoch_processing.ex +++ b/lib/lambda_ethereum_consensus/state_transition/epoch_processing.ex @@ -358,16 +358,10 @@ defmodule LambdaEthereumConsensus.StateTransition.EpochProcessing do end defp update_first_bit(state) do - bits = - state.justification_bits - |> BitVector.new(4) - |> BitVector.shift_higher(1) - |> BitVector.to_bytes() - %BeaconState{ state | previous_justified_checkpoint: state.current_justified_checkpoint, - justification_bits: bits + justification_bits: BitVector.shift_higher(state.justification_bits, 1) } end @@ -377,13 +371,11 @@ defmodule LambdaEthereumConsensus.StateTransition.EpochProcessing do with {:ok, block_root} <- Accessors.get_block_root(state, epoch) do new_checkpoint = %Types.Checkpoint{epoch: epoch, root: block_root} - bits = - state.justification_bits - |> BitVector.new(4) - |> BitVector.set(index) - |> BitVector.to_bytes() - - %{state | current_justified_checkpoint: new_checkpoint, justification_bits: bits} + %{ + state + | current_justified_checkpoint: new_checkpoint, + justification_bits: BitVector.set(state.justification_bits, index) + } |> then(&{:ok, &1}) end end @@ -395,10 +387,7 @@ defmodule LambdaEthereumConsensus.StateTransition.EpochProcessing do range, offset ) do - bits_set = - state.justification_bits - |> BitVector.new(4) - |> BitVector.all?(range) + bits_set = BitVector.all?(state.justification_bits, range) if bits_set and old_justified_checkpoint.epoch + offset == current_epoch do %BeaconState{state | finalized_checkpoint: old_justified_checkpoint} diff --git a/lib/lambda_ethereum_consensus/state_transition/operations.ex b/lib/lambda_ethereum_consensus/state_transition/operations.ex index 1bf76cc18..3844688b1 100644 --- a/lib/lambda_ethereum_consensus/state_transition/operations.ex +++ b/lib/lambda_ethereum_consensus/state_transition/operations.ex @@ -117,13 +117,10 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do # Verify sync committee aggregate signature signing over the previous slot block root committee_pubkeys = state.current_sync_committee.pubkeys - sync_committee_bits = - BitVector.new(aggregate.sync_committee_bits, ChainSpec.get("SYNC_COMMITTEE_SIZE")) - participant_pubkeys = committee_pubkeys |> Enum.with_index() - |> Enum.filter(fn {_, index} -> BitVector.set?(sync_committee_bits, index) end) + |> Enum.filter(fn {_, index} -> BitVector.set?(aggregate.sync_committee_bits, index) end) |> Enum.map(fn {public_key, _} -> public_key end) previous_slot = max(state.slot, 1) - 1 @@ -138,7 +135,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do # Compute participant and proposer rewards {participant_reward, proposer_reward} = compute_sync_aggregate_rewards(state) - total_proposer_reward = BitVector.count(sync_committee_bits) * proposer_reward + total_proposer_reward = BitVector.count(aggregate.sync_committee_bits) * proposer_reward # PERF: make Map with committee_index by pubkey, then # Enum.map validators -> new balance all in place, without map_reduce @@ -146,7 +143,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do |> get_sync_committee_indices(committee_pubkeys) |> Stream.with_index() |> Stream.map(fn {validator_index, committee_index} -> - if BitVector.set?(sync_committee_bits, committee_index), + if BitVector.set?(aggregate.sync_committee_bits, committee_index), do: {validator_index, participant_reward}, else: {validator_index, -participant_reward} end) @@ -845,7 +842,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do end defp check_matching_aggregation_bits_length(attestation, beacon_committee) do - if BitList.length_of_bitlist(attestation.aggregation_bits) == length(beacon_committee) do + if BitList.length(attestation.aggregation_bits) == length(beacon_committee) do :ok else {:error, "Mismatched aggregation bits length"} diff --git a/lib/ssz_ex.ex b/lib/ssz_ex.ex index 2c3bc5796..1452c6f0b 100644 --- a/lib/ssz_ex.ex +++ b/lib/ssz_ex.ex @@ -300,8 +300,7 @@ defmodule LambdaEthereumConsensus.SszEx do end def pack_bits(value, :bitlist) do - len = value |> bit_size() - {value, len} |> BitList.to_packed_bytes() |> pack_bytes() + value |> BitList.to_packed_bytes() |> pack_bytes() end def chunk_count({:list, type, max_size}) do @@ -354,7 +353,7 @@ defmodule LambdaEthereumConsensus.SszEx do if len > max_size do {:error, "excess bits"} else - {:ok, BitList.to_bytes({bit_list, len})} + {:ok, BitList.to_bytes(bit_list)} end end @@ -407,11 +406,12 @@ defmodule LambdaEthereumConsensus.SszEx do defp decode_bitlist(bit_list, max_size) when bit_size(bit_list) > 0 do num_bytes = byte_size(bit_list) - {decoded, len} = BitList.new(bit_list) + decoded = BitList.new(bit_list) + len = BitList.length(decoded) cond do - len < 0 -> - {:error, "missing length information"} + match?(<<_::binary-size(num_bytes - 1), 0>>, bit_list) -> + {:error, "BitList has no length information."} div(len, @bits_per_byte) + 1 != num_bytes -> {:error, "invalid byte count"} @@ -652,7 +652,7 @@ defmodule LambdaEthereumConsensus.SszEx do defp check_first_offset([{offset, _} | _rest], items_index, _binary_size) do cond do - offset < items_index -> {:error, "OffsetIntoFixedPortion"} + offset < items_index -> {:error, "OffsetIntoFixedPortion (#{offset})"} offset > items_index -> {:error, "OffsetSkipsVariableBytes"} true -> :ok end @@ -738,7 +738,7 @@ defmodule LambdaEthereumConsensus.SszEx do defp sanitize_offset(offset, previous_offset, _num_bytes, num_fixed_bytes) do cond do offset < num_fixed_bytes -> - {:error, "OffsetIntoFixedPortion"} + {:error, "OffsetIntoFixedPortion #{offset}"} previous_offset == nil && offset != num_fixed_bytes -> {:error, "OffsetSkipsVariableBytes"} diff --git a/lib/types/beacon_chain/attestation.ex b/lib/types/beacon_chain/attestation.ex index da565156a..8ca6a6d29 100644 --- a/lib/types/beacon_chain/attestation.ex +++ b/lib/types/beacon_chain/attestation.ex @@ -3,6 +3,8 @@ defmodule Types.Attestation do Struct definition for `AttestationMainnet`. Related definitions in `native/ssz_nif/src/types/`. """ + alias LambdaEthereumConsensus.Utils.BitList + @behaviour LambdaEthereumConsensus.Container fields = [ @@ -29,4 +31,12 @@ defmodule Types.Attestation do {:signature, TypeAliases.bls_signature()} ] end + + def encode(%__MODULE__{} = map) do + Map.update!(map, :aggregation_bits, &BitList.to_bytes/1) + end + + def decode(%__MODULE__{} = map) do + Map.update!(map, :aggregation_bits, &BitList.new/1) + end end diff --git a/lib/types/beacon_chain/beacon_state.ex b/lib/types/beacon_chain/beacon_state.ex index 8f51a19be..ccfb2ac49 100644 --- a/lib/types/beacon_chain/beacon_state.ex +++ b/lib/types/beacon_chain/beacon_state.ex @@ -3,9 +3,10 @@ defmodule Types.BeaconState do Struct definition for `BeaconState`. Related definitions in `native/ssz_nif/src/types/`. """ - @behaviour LambdaEthereumConsensus.Container alias LambdaEthereumConsensus.Utils.BitVector + @behaviour LambdaEthereumConsensus.Container + fields = [ :genesis_time, :genesis_validators_root, @@ -114,6 +115,7 @@ defmodule Types.BeaconState do |> Map.update!(:previous_epoch_participation, &Aja.Vector.to_list/1) |> Map.update!(:current_epoch_participation, &Aja.Vector.to_list/1) |> Map.update!(:latest_execution_payload_header, &Types.ExecutionPayloadHeader.encode/1) + |> Map.update!(:justification_bits, &BitVector.to_bytes/1) end def decode(%__MODULE__{} = map) do @@ -124,6 +126,9 @@ defmodule Types.BeaconState do |> Map.update!(:previous_epoch_participation, &Aja.Vector.new/1) |> Map.update!(:current_epoch_participation, &Aja.Vector.new/1) |> Map.update!(:latest_execution_payload_header, &Types.ExecutionPayloadHeader.decode/1) + |> Map.update!(:justification_bits, fn bits -> + BitVector.new(bits, Constants.justification_bits_length()) + end) end @doc """ @@ -261,7 +266,7 @@ defmodule Types.BeaconState do {:list, TypeAliases.participation_flags(), ChainSpec.get("VALIDATOR_REGISTRY_LIMIT")}}, {:current_epoch_participation, {:list, TypeAliases.participation_flags(), ChainSpec.get("VALIDATOR_REGISTRY_LIMIT")}}, - {:justification_bits, {:bitvector, ChainSpec.get("JUSTIFICATION_BITS_LENGTH")}}, + {:justification_bits, {:bitvector, Constants.justification_bits_length()}}, {:previous_justified_checkpoint, Types.Checkpoint}, {:current_justified_checkpoint, Types.Checkpoint}, {:finalized_checkpoint, Types.Checkpoint}, diff --git a/lib/types/beacon_chain/pending_attestation.ex b/lib/types/beacon_chain/pending_attestation.ex index d04d22815..a072d7713 100644 --- a/lib/types/beacon_chain/pending_attestation.ex +++ b/lib/types/beacon_chain/pending_attestation.ex @@ -3,6 +3,8 @@ defmodule Types.PendingAttestation do Struct definition for `PendingAttestation`. Related definitions in `native/ssz_nif/src/types/`. """ + alias LambdaEthereumConsensus.Utils.BitList + @behaviour LambdaEthereumConsensus.Container fields = [ @@ -32,4 +34,12 @@ defmodule Types.PendingAttestation do {:proposer_index, TypeAliases.validator_index()} ] end + + def encode(%__MODULE__{} = map) do + Map.update!(map, :aggregation_bits, &BitList.to_bytes/1) + end + + def decode(%__MODULE__{} = map) do + Map.update!(map, :aggregation_bits, &BitList.new/1) + end end diff --git a/lib/types/beacon_chain/sync_aggregate.ex b/lib/types/beacon_chain/sync_aggregate.ex index 38240ae6a..8194cbf2d 100644 --- a/lib/types/beacon_chain/sync_aggregate.ex +++ b/lib/types/beacon_chain/sync_aggregate.ex @@ -3,6 +3,8 @@ defmodule Types.SyncAggregate do Struct definition for `SyncAggregate`. Related definitions in `native/ssz_nif/src/types/`. """ + alias LambdaEthereumConsensus.Utils.BitVector + @behaviour LambdaEthereumConsensus.Container fields = [ @@ -15,7 +17,7 @@ defmodule Types.SyncAggregate do @type t :: %__MODULE__{ # max size SYNC_COMMITTEE_SIZE - sync_committee_bits: Types.bitvector(), + sync_committee_bits: BitVector.t(), sync_committee_signature: Types.bls_signature() } @@ -26,4 +28,14 @@ defmodule Types.SyncAggregate do {:sync_committee_signature, TypeAliases.bls_signature()} ] end + + def encode(%__MODULE__{} = map) do + Map.update!(map, :sync_committee_bits, &BitVector.to_bytes/1) + end + + def decode(%__MODULE__{} = map) do + Map.update!(map, :sync_committee_bits, fn bits -> + BitVector.new(bits, ChainSpec.get("SYNC_COMMITTEE_SIZE")) + end) + end end diff --git a/lib/utils/bit_list.ex b/lib/utils/bit_list.ex index fa64ca4dd..2cecb5a15 100644 --- a/lib/utils/bit_list.ex +++ b/lib/utils/bit_list.ex @@ -1,9 +1,9 @@ defmodule LambdaEthereumConsensus.Utils.BitList do @moduledoc """ - Set of utilities to interact with BitList, represented as {bitstring, len}. + Set of utilities to interact with BitList, represented as a bitstring. """ alias LambdaEthereumConsensus.Utils.BitField - @type t :: {bitstring, integer()} + @type t :: bitstring @bits_per_byte 8 @sentinel_bit 1 @bits_in_sentinel_bit 1 @@ -15,22 +15,19 @@ defmodule LambdaEthereumConsensus.Utils.BitList do def new(bitstring) when is_bitstring(bitstring) do # Change the byte order from little endian to big endian (reverse bytes). num_bits = bit_size(bitstring) - len = length_of_bitlist(bitstring) <> = bitstring - decoded = - <>)::bitstring, - pre::integer-size(num_bits - @bits_per_byte)>> - - {decoded, len} + <>)::bitstring, + pre::integer-size(num_bits - @bits_per_byte)>> end @spec to_bytes(t) :: bitstring - def to_bytes({bit_list, len}) do + def to_bytes(bit_list) do # Change the byte order from big endian to little endian (reverse bytes). + len = bit_size(bit_list) r = rem(len, @bits_per_byte) <> = bit_list @@ -40,8 +37,9 @@ defmodule LambdaEthereumConsensus.Utils.BitList do end @spec to_packed_bytes(t) :: bitstring - def to_packed_bytes({bit_list, len}) do + def to_packed_bytes(bit_list) do # Change the byte order from big endian to little endian (reverse bytes). + len = bit_size(bit_list) r = rem(len, @bits_per_byte) <> = bit_list @@ -55,37 +53,27 @@ defmodule LambdaEthereumConsensus.Utils.BitList do Equivalent to bit_list[index] == 1. """ @spec set?(t, non_neg_integer) :: boolean - def set?({bit_list, _}, index), do: BitField.set?(bit_list, index) + def set?(bit_list, index), do: BitField.set?(bit_list, index) @doc """ Sets a bit (turns it to 1). Equivalent to bit_list[index] = 1. """ @spec set(t, non_neg_integer) :: t - def set({bit_list, len}, index), do: {BitField.set(bit_list, index), len} + def set(bit_list, index), do: BitField.set(bit_list, index) @doc """ Clears a bit (turns it to 0). Equivalent to bit_list[index] = 0. """ @spec clear(t, non_neg_integer) :: t - def clear({bit_list, len}, index), do: {BitField.clear(bit_list, index), len} - - def length_of_bitlist(bitlist) when is_binary(bitlist) do - bit_size = bit_size(bitlist) - <<_::size(bit_size - @bits_per_byte), last_byte>> = bitlist - bit_size - leading_zeros(<>) - @bits_in_sentinel_bit - end + def clear(bit_list, index), do: BitField.clear(bit_list, index) - defp leading_zeros(<<@sentinel_bit::@bits_in_sentinel_bit, _::7>>), do: 0 - defp leading_zeros(<<0::1, @sentinel_bit::@bits_in_sentinel_bit, _::6>>), do: 1 - defp leading_zeros(<<0::2, @sentinel_bit::@bits_in_sentinel_bit, _::5>>), do: 2 - defp leading_zeros(<<0::3, @sentinel_bit::@bits_in_sentinel_bit, _::4>>), do: 3 - defp leading_zeros(<<0::4, @sentinel_bit::@bits_in_sentinel_bit, _::3>>), do: 4 - defp leading_zeros(<<0::5, @sentinel_bit::@bits_in_sentinel_bit, _::2>>), do: 5 - defp leading_zeros(<<0::6, @sentinel_bit::@bits_in_sentinel_bit, _::1>>), do: 6 - defp leading_zeros(<<0::7, @sentinel_bit::@bits_in_sentinel_bit>>), do: 7 - defp leading_zeros(<<0::8>>), do: 8 + @doc """ + Calculates the length of the bit_list. + """ + @spec length(t) :: non_neg_integer() + def length(bit_list), do: bit_size(bit_list) @spec remove_trailing_bit(binary()) :: bitstring() defp remove_trailing_bit(<<@sentinel_bit::@bits_in_sentinel_bit, rest::7>>), do: <> diff --git a/test/spec/utils.ex b/test/spec/utils.ex index 928a9a677..cf7597aaa 100644 --- a/test/spec/utils.ex +++ b/test/spec/utils.ex @@ -132,7 +132,7 @@ defmodule SpecTestUtils do def sanitize_ssz(vector_elements, {:vector, module, _size} = _schema) when is_atom(module), do: Enum.map(vector_elements, &struct!(module, &1)) - def sanitize_ssz(bitlist, {:bitlist, _size} = _schema), do: elem(BitList.new(bitlist), 0) + def sanitize_ssz(bitlist, {:bitlist, _size} = _schema), do: BitList.new(bitlist) def sanitize_ssz(bitvector, {:bitvector, size} = _schema), do: BitVector.new(bitvector, size) def sanitize_ssz(0, {:list, {:int, 8}, _size} = _schema), do: [] diff --git a/test/unit/bit_list_test.exs b/test/unit/bit_list_test.exs index b065e7c97..db18fdfbc 100644 --- a/test/unit/bit_list_test.exs +++ b/test/unit/bit_list_test.exs @@ -1,25 +1,24 @@ defmodule BitListTest do use ExUnit.Case - alias LambdaEthereumConsensus.SszEx alias LambdaEthereumConsensus.Utils.BitList describe "Sub-byte BitList" do test "build from binary" do - input_encoded = <<237, 7>> - {:ok, decoded} = SszEx.decode(input_encoded, {:bitlist, 10}) - assert BitList.set?({decoded, 10}, 0) == true - assert BitList.set?({decoded, 10}, 1) == false - assert BitList.set?({decoded, 10}, 4) == false - assert BitList.set?({decoded, 10}, 9) == true + decoded = BitList.new(<<237, 7>>) - {updated_bitlist, _} = - {decoded, 10} + assert BitList.set?(decoded, 0) == true + assert BitList.set?(decoded, 1) == false + assert BitList.set?(decoded, 4) == false + assert BitList.set?(decoded, 9) == true + + updated_bitlist = + decoded |> BitList.set(1) |> BitList.set(4) |> BitList.clear(0) |> BitList.clear(9) - {:ok, <<254, 5>>} = SszEx.encode(updated_bitlist, {:bitlist, 10}) + <<254, 5>> = BitList.to_bytes(updated_bitlist) end test "sets a single bit" do