Skip to content

Commit

Permalink
Optimize process_attestation
Browse files Browse the repository at this point in the history
  • Loading branch information
MegaRedHand committed Nov 30, 2023
1 parent b65d0cf commit 239c8db
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 149 deletions.
79 changes: 37 additions & 42 deletions lib/lambda_ethereum_consensus/state_transition/accessors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do

alias LambdaEthereumConsensus.SszEx
alias LambdaEthereumConsensus.StateTransition.{Math, Misc, Predicates}
alias LambdaEthereumConsensus.Utils
alias SszTypes.{Attestation, BeaconState, IndexedAttestation, SyncCommittee, Validator}

@doc """
Expand Down Expand Up @@ -314,62 +315,56 @@ defmodule LambdaEthereumConsensus.StateTransition.Accessors do
) ::
{:ok, list(SszTypes.uint64())} | {:error, binary()}
def get_attestation_participation_flag_indices(state, data, inclusion_delay) do
with :ok <- check_valid_source(state, data),
{:ok, target_root} <-
get_block_root(state, data.target.epoch) |> Utils.map_err("invalid target"),
{:ok, head_root} <-
get_block_root_at_slot(state, data.slot) |> Utils.map_err("invalid head") do
is_matching_target = data.target.root == target_root
is_matching_head = is_matching_target and data.beacon_block_root == head_root

source_indices = compute_source_indices(inclusion_delay)
target_indices = compute_target_indices(is_matching_target, inclusion_delay)
head_indices = compute_head_indices(is_matching_head, inclusion_delay)

{:ok, Enum.concat([source_indices, target_indices, head_indices])}
end
end

defp check_valid_source(state, data) do
justified_checkpoint =
if data.target.epoch == get_current_epoch(state) do
state.current_justified_checkpoint
else
state.previous_justified_checkpoint
end

is_matching_source = data.source == justified_checkpoint

case {get_block_root(state, data.target.epoch), get_block_root_at_slot(state, data.slot)} do
{{:ok, block_root}, {:ok, block_root_at_slot}} ->
if is_matching_source do
is_matching_target = is_matching_source && data.target.root == block_root
source_indices = compute_source_indices(data, justified_checkpoint, inclusion_delay)

target_indices =
compute_target_indices(data, block_root, inclusion_delay, is_matching_source)

head_indices =
compute_head_indices(data, block_root_at_slot, inclusion_delay, is_matching_target)

{:ok, Enum.concat([source_indices, target_indices, head_indices])}
else
{:error, "Attestation source does not match justified checkpoint"}
end

_ ->
{:error, "Failed to get block roots"}
if data.source == justified_checkpoint do
:ok
else
{:error, "invalid source"}
end
end

defp compute_source_indices(data, justified_checkpoint, inclusion_delay) do
if data.source == justified_checkpoint &&
inclusion_delay <= Math.integer_squareroot(ChainSpec.get("SLOTS_PER_EPOCH")) do
[Constants.timely_source_flag_index()]
else
[]
end
defp compute_source_indices(inclusion_delay) do
max_delay = ChainSpec.get("SLOTS_PER_EPOCH") |> Math.integer_squareroot()
if inclusion_delay <= max_delay, do: [Constants.timely_source_flag_index()], else: []
end

defp compute_target_indices(data, block_root, inclusion_delay, is_matching_source) do
if is_matching_source && data.target.root == block_root &&
inclusion_delay <= ChainSpec.get("SLOTS_PER_EPOCH") do
[Constants.timely_target_flag_index()]
else
[]
end
defp compute_target_indices(is_matching_target, inclusion_delay) do
max_delay = ChainSpec.get("SLOTS_PER_EPOCH")

if is_matching_target and inclusion_delay <= max_delay,
do: [Constants.timely_target_flag_index()],
else: []
end

defp compute_head_indices(data, block_root_at_slot, inclusion_delay, is_matching_target) do
if is_matching_target && data.beacon_block_root == block_root_at_slot &&
inclusion_delay == ChainSpec.get("MIN_ATTESTATION_INCLUSION_DELAY") do
[Constants.timely_head_flag_index()]
else
[]
end
defp compute_head_indices(is_matching_head, inclusion_delay) do
min_inclusion_delay = ChainSpec.get("MIN_ATTESTATION_INCLUSION_DELAY")

if is_matching_head and inclusion_delay == min_inclusion_delay,
do: [Constants.timely_head_flag_index()],
else: []
end

@doc """
Expand Down
162 changes: 56 additions & 106 deletions lib/lambda_ethereum_consensus/state_transition/operations.ex
Original file line number Diff line number Diff line change
Expand Up @@ -645,18 +645,10 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
@spec process_attestation(BeaconState.t(), Attestation.t()) ::
{:ok, BeaconState.t()} | {:error, binary()}
def process_attestation(state, attestation) do
case verify_attestation_for_process(state, attestation) do
{:ok, _} ->
data = attestation.data
aggregation_bits = attestation.aggregation_bits

case process_attestation(state, data, aggregation_bits) do
{:ok, updated_state} -> {:ok, updated_state}
{:error, reason} -> {:error, reason}
end

{:error, reason} ->
{:error, reason}
# TODO: optimize (takes ~3s)
with :ok <- verify_attestation_for_process(state, attestation) do
# TODO: optimize (takes ~1s)
process_attestation(state, attestation.data, attestation.aggregation_bits)
end
end

Expand Down Expand Up @@ -684,17 +676,10 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do

{:ok, proposer_index} = Accessors.get_beacon_proposer_index(state)

bal_updated_state =
Mutators.increase_balance(
state,
proposer_index,
proposer_reward
)

updated_state =
update_state(bal_updated_state, is_current_epoch, updated_epoch_participation)

{:ok, updated_state}
state
|> Mutators.increase_balance(proposer_index, proposer_reward)
|> update_state(is_current_epoch, updated_epoch_participation)
|> then(&{:ok, &1})
else
{:error, reason} -> {:error, reason}
end
Expand All @@ -721,21 +706,25 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
|> Stream.with_index()
|> Enum.map_reduce(0, fn {{validator, participation}, i}, acc ->
if MapSet.member?(attesting_indices, i) do
bv_participation = BitVector.new(participation, 8)
base_reward = Accessors.get_base_reward(validator, base_reward_per_increment)

weights
|> Stream.reject(&BitVector.set?(bv_participation, elem(&1, 1)))
|> Enum.reduce({bv_participation, acc}, fn {weight, index}, {bv_participation, acc} ->
{bv_participation |> BitVector.set(index), acc + base_reward * weight}
end)
|> then(fn {p, acc} -> {BitVector.to_integer(p), acc} end)
update_participation(participation, acc, base_reward, weights)
else
{participation, acc}
end
end)
end

defp update_participation(participation, acc, base_reward, weights) do
bv_participation = BitVector.new(participation, 8)

weights
|> Stream.reject(&BitVector.set?(bv_participation, elem(&1, 1)))
|> Enum.reduce({bv_participation, acc}, fn {weight, index}, {bv_participation, acc} ->
{bv_participation |> BitVector.set(index), acc + base_reward * weight}
end)
|> then(fn {p, acc} -> {BitVector.to_integer(p), acc} end)
end

defp compute_proposer_reward(proposer_reward_numerator) do
proposer_reward_denominator =
((Constants.weight_denominator() - Constants.proposer_weight()) *
Expand All @@ -751,16 +740,31 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
defp update_state(state, false, updated_epoch_participation),
do: %{state | previous_epoch_participation: updated_epoch_participation}

def verify_attestation_for_process(state, attestation) do
data = attestation.data
def verify_attestation_for_process(state, %Attestation{data: data} = attestation) do
with {:ok, beacon_committee} <- Accessors.get_beacon_committee(state, data.slot, data.index),
{:ok, indexed_attestation} <- Accessors.get_indexed_attestation(state, attestation) do
cond do
invalid_target_epoch?(data, state) ->
{:error, "Invalid target epoch"}

beacon_committee = fetch_beacon_committee(state, data)
indexed_attestation = fetch_indexed_attestation(state, attestation)
epoch_mismatch?(data) ->
{:error, "Epoch mismatch"}

if has_invalid_conditions?(data, state, beacon_committee, indexed_attestation, attestation) do
{:error, get_error_message(data, state, beacon_committee, indexed_attestation, attestation)}
else
{:ok, "Valid"}
invalid_slot_range?(data, state) ->
{:error, "Invalid slot range"}

exceeds_committee_count?(data, state) ->
{:error, "Index exceeds committee count"}

mismatched_aggregation_bits_length?(attestation, beacon_committee) ->
{:error, "Mismatched aggregation bits length"}

not valid_signature?(state, indexed_attestation) ->
{:error, "Invalid signature"}

true ->
:ok
end
end
end

Expand Down Expand Up @@ -831,55 +835,6 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
end
end

defp has_invalid_conditions?(data, state, beacon_committee, indexed_attestation, attestation) do
invalid_target_epoch?(data, state) ||
epoch_mismatch?(data) ||
invalid_slot_range?(data, state) ||
exceeds_committee_count?(data, state) ||
!beacon_committee || !indexed_attestation ||
mismatched_aggregation_bits_length?(attestation, beacon_committee) ||
invalid_signature?(state, indexed_attestation)
end

defp get_error_message(data, state, beacon_committee, indexed_attestation, attestation) do
cond do
invalid_target_epoch?(data, state) ->
"Invalid target epoch"

epoch_mismatch?(data) ->
"Epoch mismatch"

invalid_slot_range?(data, state) ->
"Invalid slot range"

exceeds_committee_count?(data, state) ->
"Index exceeds committee count"

!beacon_committee || !indexed_attestation ->
"Indexing error at beacon committee"

mismatched_aggregation_bits_length?(attestation, beacon_committee) ->
"Mismatched aggregation bits length"

invalid_signature?(state, indexed_attestation) ->
"Invalid signature"
end
end

defp fetch_beacon_committee(state, data) do
case Accessors.get_beacon_committee(state, data.slot, data.index) do
{:ok, committee} -> committee
{:error, _reason} -> nil
end
end

defp fetch_indexed_attestation(state, attestation) do
case Accessors.get_indexed_attestation(state, attestation) do
{:ok, indexed_attestation} -> indexed_attestation
{:error, _reason} -> nil
end
end

defp invalid_target_epoch?(data, state) do
data.target.epoch < Accessors.get_previous_epoch(state) ||
data.target.epoch > Accessors.get_current_epoch(state)
Expand All @@ -902,8 +857,8 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
length_of_bitstring(attestation.aggregation_bits) - 1 != length(beacon_committee)
end

defp invalid_signature?(state, indexed_attestation) do
not Predicates.is_valid_indexed_attestation(state, indexed_attestation)
defp valid_signature?(state, indexed_attestation) do
Predicates.is_valid_indexed_attestation(state, indexed_attestation)
end

defp length_of_bitstring(binary) when is_binary(binary) do
Expand Down Expand Up @@ -996,25 +951,20 @@ defmodule LambdaEthereumConsensus.StateTransition.Operations do
def process_operations(state, body) do
# Ensure that outstanding deposits are processed up to the maximum number of deposits
with :ok <- verify_deposits(state, body) do
# Define a function that iterates over a list of operations and applies a given function to each element
updated_state =
state
|> for_ops(body.proposer_slashings, &process_proposer_slashing/2)
|> for_ops(body.attester_slashings, &process_attester_slashing/2)
|> for_ops(body.attestations, &process_attestation/2)
|> for_ops(body.deposits, &process_deposit/2)
|> for_ops(body.voluntary_exits, &process_voluntary_exit/2)
|> for_ops(body.bls_to_execution_changes, &process_bls_to_execution_change/2)

{:ok, updated_state}
{:ok, state}
|> for_ops(body.proposer_slashings, &process_proposer_slashing/2)
|> for_ops(body.attester_slashings, &process_attester_slashing/2)
|> for_ops(body.attestations, &process_attestation/2)
|> for_ops(body.deposits, &process_deposit/2)
|> for_ops(body.voluntary_exits, &process_voluntary_exit/2)
|> for_ops(body.bls_to_execution_changes, &process_bls_to_execution_change/2)
end
end

defp for_ops(state, operations, func) do
Enum.reduce(operations, state, fn operation, acc ->
with {:ok, state} <- func.(acc, operation) do
state
end
defp for_ops(acc, operations, func) do
Enum.reduce_while(operations, acc, fn
operation, {:ok, state} -> {:cont, func.(state, operation)}
_, {:error, reason} -> {:halt, {:error, reason}}
end)
end

Expand Down
8 changes: 8 additions & 0 deletions lib/lambda_ethereum_consensus/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,12 @@ defmodule LambdaEthereumConsensus.Utils do
@spec map({:ok | :error, any()}, (any() -> any())) :: any() | {:error, any()}
def map({:ok, value}, fun), do: fun.(value)
def map({:error, _} = err, _fun), do: err

@doc """
If first arg is an ``{:error, reason}`` tuple, replace ``reason`` with
``new_reason``. Else, return the first arg unmodified.
"""
@spec map_err(any() | {:error, String.t()}, String.t()) :: any() | {:error, String.t()}
def map_err({:error, _}, reason), do: {:error, reason}
def map_err(v, _), do: v
end
2 changes: 1 addition & 1 deletion lib/utils/bit_vector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ defmodule LambdaEthereumConsensus.Utils.BitVector do
@doc """
Turns the bit_vector into an integer.
"""
@spec to_integer(t) :: t
@spec to_integer(t) :: non_neg_integer()
def to_integer(bit_vector) do
<<int::unsigned-size(bit_size(bit_vector))>> = bit_vector
int
Expand Down

0 comments on commit 239c8db

Please sign in to comment.