From 111aa36cee24645da11684513a32010071981ef6 Mon Sep 17 00:00:00 2001 From: Alisander Qoshqosh Date: Wed, 10 Jul 2024 12:41:38 +0400 Subject: [PATCH] refactor generically typed checkpoints --- .../token/erc721/extensions/consecutive.rs | 27 ++-- .../src/utils/structs/checkpoints/mod.rs | 97 +++++++++++++ .../{checkpoints.rs => checkpoints/trace.rs} | 128 ++++++++---------- 3 files changed, 169 insertions(+), 83 deletions(-) create mode 100644 contracts/src/utils/structs/checkpoints/mod.rs rename contracts/src/utils/structs/{checkpoints.rs => checkpoints/trace.rs} (84%) diff --git a/contracts/src/token/erc721/extensions/consecutive.rs b/contracts/src/token/erc721/extensions/consecutive.rs index 35a88440..a6eb14bd 100644 --- a/contracts/src/token/erc721/extensions/consecutive.rs +++ b/contracts/src/token/erc721/extensions/consecutive.rs @@ -43,18 +43,20 @@ use crate::{ structs::{ bitmap::BitMap, checkpoints, - checkpoints::{Trace160, U96}, + checkpoints::{trace::Trace, Size, S160}, }, }, }; +type U96 = ::Key; + sol_storage! { /// State of an [`Erc721Consecutive`] token. pub struct Erc721Consecutive { /// Erc721 contract storage. Erc721 erc721; /// Checkpoint library contract for sequential ownership. - Trace160 _sequential_ownership; + Trace _sequential_ownership; /// BitMap library contract for sequential burn of tokens. BitMap _sequential_burn; /// Used to offset the first token id in @@ -113,8 +115,8 @@ sol! { pub enum Error { /// Error type from [`Erc721`] contract [`erc721::Error`]. Erc721(erc721::Error), - /// Error type from checkpoint contract [`checkpoints::Error`]. - Checkpoints(checkpoints::Error), + /// Error type from checkpoint contract [`checkpoints::trace::Error`]. + Checkpoints(checkpoints::trace::Error), /// Batch mint is restricted to the constructor. /// Any batch mint not emitting the [`Transfer`] event outside of /// the constructor is non ERC-721 compliant. @@ -793,18 +795,15 @@ mod tests { use alloy_primitives::{address, uint, Address, U256}; use stylus_sdk::msg; - use crate::{ - token::{ - erc721, - erc721::{ - extensions::consecutive::{ - ERC721ExceededMaxBatchMint, Erc721Consecutive, Error, - }, - tests::random_token_id, - ERC721InvalidReceiver, ERC721NonexistentToken, IErc721, + use crate::token::{ + erc721, + erc721::{ + extensions::consecutive::{ + ERC721ExceededMaxBatchMint, Erc721Consecutive, Error, U96, }, + tests::random_token_id, + ERC721InvalidReceiver, ERC721NonexistentToken, IErc721, }, - utils::structs::checkpoints::U96, }; const BOB: Address = address!("F4EaCDAbEf3c8f1EdE91b6f2A6840bc2E4DD3526"); diff --git a/contracts/src/utils/structs/checkpoints/mod.rs b/contracts/src/utils/structs/checkpoints/mod.rs new file mode 100644 index 00000000..979216ad --- /dev/null +++ b/contracts/src/utils/structs/checkpoints/mod.rs @@ -0,0 +1,97 @@ +//! Contract for checkpointing values as they change at different points in +//! time, to looking up past values by block number later. +//! +//! To create a history of checkpoints, define a variable type [`trace::Trace`] +//! in your contract. +//! Types [`S160`], [`S160`] and [`S160`] can be used to +//! define sizes for key and value. +//! Then store a new checkpoint for the current +//! transaction block using the [`trace::Trace::push`] function. +pub mod trace; + +use core::ops::{Add, Div, Mul, Sub}; + +use alloy_primitives::Uint; +use stylus_sdk::prelude::*; + +/// Trait that associates types of specific size for checkpoints key and value. +pub trait Size { + /// Type of the key in abi. + type Key: Num; + + /// Type of the key in storage. + type KeyStorage: for<'a> StorageType = Self::Key> + + Accessor; + + /// Type of the value in abi. + type Value: Num; + + /// Type of the value in storage. + type ValueStorage: for<'a> StorageType = Self::Value> + + Accessor; +} + +/// Size of checkpoint storage contract corresponding to the size of 96 bits of +/// the key and size 160 bits of the value. +pub type S160 = SpecificSize<96, 2, 160, 3>; + +/// Size of checkpoint storage contract corresponding to the size of 32 bits of +/// the key and size 224 bits of the value. +pub type S224 = SpecificSize<32, 1, 224, 4>; + +/// Size of checkpoint storage contract corresponding to the size of 48 bits of +/// the key and size 208 bits of the value. +pub type S208 = SpecificSize<48, 1, 208, 4>; + +/// Contains the size of checkpoint's key and value in bits. +pub struct SpecificSize< + const KEY_BITS: usize, + const KEY_LIMBS: usize, + const VALUE_BITS: usize, + const VALUE_LIMBS: usize, +>; + +impl Size + for SpecificSize +{ + type Key = Uint; + type KeyStorage = stylus_sdk::storage::StorageUint; + type Value = Uint; + type ValueStorage = stylus_sdk::storage::StorageUint; +} + +/// Abstracts number inside the checkpoint contract. +pub trait Num: Add + Sub + Mul + Div + Ord + Sized + Copy { + /// Zero value of the number. + const ZERO: Self; +} + +impl Num for Uint { + const ZERO: Self = Self::ZERO; +} + +/// Abstracts accessor inside the checkpoint contract +pub trait Accessor { + /// Type of the number associated with the storage type. + type Wrap: Num; + + /// Gets underlying element [`Self::Wrap`] from persistent storage. + fn get(&self) -> Self::Wrap; + + /// Sets underlying element [`Self::Wrap`] in persistent storage. + fn set(&mut self, value: Self::Wrap); +} + +impl Accessor + for stylus_sdk::storage::StorageUint +{ + type Wrap = Uint; + + fn get(&self) -> Self::Wrap { + self.get() + } + + fn set(&mut self, value: Self::Wrap) { + self.set(value); + } +} diff --git a/contracts/src/utils/structs/checkpoints.rs b/contracts/src/utils/structs/checkpoints/trace.rs similarity index 84% rename from contracts/src/utils/structs/checkpoints.rs rename to contracts/src/utils/structs/checkpoints/trace.rs index 9e055756..d6a61a55 100644 --- a/contracts/src/utils/structs/checkpoints.rs +++ b/contracts/src/utils/structs/checkpoints/trace.rs @@ -1,34 +1,24 @@ //! Contract for checkpointing values as they change at different points in -//! time, and later looking up and later looking up past values by block number. -//! -//! To create a history of checkpoints, define a variable type [`Trace160`] -//! in your contract, and store a new checkpoint for the current transaction -//! block using the [`Trace160::push`] function. -use alloy_primitives::{uint, Uint, U256, U32}; +//! time, to looking up past values by block number later. + +use alloy_primitives::{uint, U256, U32}; use alloy_sol_types::sol; -use stylus_proc::{sol_storage, SolidityError}; +use stylus_proc::{storage, SolidityError}; use stylus_sdk::{ call::MethodError, - storage::{StorageGuard, StorageGuardMut}, + storage::{StorageGuard, StorageGuardMut, StorageVec}, }; +use super::{Accessor, Num, Size}; use crate::utils::math::alloy::Math; -// TODO: add generics for other pairs (uint32, uint224) and (uint48, uint208). -// Logic should be the same. -/// [`Uint`] for 96 bits. -pub type U96 = Uint<96, 2>; - -/// [`Uint`] for 160 bits. -pub type U160 = Uint<160, 3>; - sol! { /// A value was attempted to be inserted into a past checkpoint. #[derive(Debug)] error CheckpointUnorderedInsertion(); } -/// An error that occurred while calling the [`Trace160`] checkpoint contract. +/// An error that occurred while calling the [`Trace`] checkpoint contract. #[derive(SolidityError, Debug)] pub enum Error { /// A value was attempted to be inserted into a past checkpoint. @@ -41,24 +31,24 @@ impl MethodError for Error { } } -sol_storage! { - /// State of the checkpoint library contract. - pub struct Trace160 { - /// Stores checkpoints in a dynamic array sorted by key. - Checkpoint160[] _checkpoints; - } +/// State of the checkpoint library contract. +#[storage] +pub struct Trace { + /// Stores checkpoints in a dynamic array sorted by key. + _checkpoints: StorageVec>, +} - /// State of a single checkpoint. - pub struct Checkpoint160 { - /// The key of the checkpoint. Used as a sorting key. - uint96 _key; - /// The value corresponding to the key. - uint160 _value; - } +/// State of a single checkpoint. +#[storage] +pub struct Checkpoint { + /// The key of the checkpoint. Used as a sorting key. + _key: T::KeyStorage, + /// The value corresponding to the key. + _value: T::ValueStorage, } -impl Trace160 { - /// Pushes a (`key`, `value`) pair into a `Trace160` so that it is +impl Trace { + /// Pushes a (`key`, `value`) pair into a `Trace` so that it is /// stored as the checkpoint. /// /// Returns the previous value and the new value as an ordered pair. @@ -79,48 +69,49 @@ impl Trace160 { /// maintain sorted order). pub fn push( &mut self, - key: U96, - value: U160, - ) -> Result<(U160, U160), Error> { + key: T::Key, + value: T::Value, + ) -> Result<(T::Value, T::Value), Error> { self._insert(key, value) } /// Returns the value in the first (oldest) checkpoint with key greater or - /// equal than the search key, or `U160::ZERO` if there is none. + /// equal than the search key, or `T::Value::ZERO` if there is none. /// /// # Arguments /// /// * `&self` - Read access to the checkpoint's state. /// * `key` - Checkpoint's key to lookup. - pub fn lower_lookup(&self, key: U96) -> U160 { + pub fn lower_lookup(&self, key: T::Key) -> T::Value { let len = self.length(); let pos = self._lower_binary_lookup(key, U256::ZERO, len); if pos == len { - U160::ZERO + T::Value::ZERO } else { self._index(pos)._value.get() } } /// Returns the value in the last (most recent) checkpoint with key - /// lower or equal than the search key, or `U160::ZERO` if there is none. + /// lower or equal than the search key, or `T::Value::ZERO` if there is + /// none. /// /// # Arguments /// /// * `&self` - Read access to the checkpoint's state. /// * `key` - Checkpoint's key to lookup. - pub fn upper_lookup(&self, key: U96) -> U160 { + pub fn upper_lookup(&self, key: T::Key) -> T::Value { let len = self.length(); let pos = self._upper_binary_lookup(key, U256::ZERO, len); if pos == U256::ZERO { - U160::ZERO + T::Value::ZERO } else { self._index(pos - uint!(1_U256))._value.get() } } /// Returns the value in the last (most recent) checkpoint with key lower or - /// equal than the search key, or `U160::ZERO` if there is none. + /// equal than the search key, or `T::Value::ZERO` if there is none. /// /// This is a variant of [`Self::upper_lookup`] that is optimized to find /// "recent" checkpoints (checkpoints with high keys). @@ -129,7 +120,7 @@ impl Trace160 { /// /// * `&self` - Read access to the checkpoint's state. /// * `key` - Checkpoint's key to query. - pub fn upper_lookup_recent(&self, key: U96) -> U160 { + pub fn upper_lookup_recent(&self, key: T::Key) -> T::Value { let len = self.length(); let mut low = U256::ZERO; @@ -147,22 +138,22 @@ impl Trace160 { let pos = self._upper_binary_lookup(key, low, high); if pos == U256::ZERO { - U160::ZERO + T::Value::ZERO } else { self._index(pos - uint!(1_U256))._value.get() } } - /// Returns the value in the most recent checkpoint, or `U160::ZERO` if + /// Returns the value in the most recent checkpoint, or `T::Value::ZERO` if /// there are no checkpoints. /// /// # Arguments /// /// * `&self` - Read access to the checkpoint's state. - pub fn latest(&self) -> U160 { + pub fn latest(&self) -> T::Value { let pos = self.length(); if pos == U256::ZERO { - U160::ZERO + T::Value::ZERO } else { self._index(pos - uint!(1_U256))._value.get() } @@ -175,7 +166,7 @@ impl Trace160 { /// # Arguments /// /// * `&self` - Read access to the checkpoint's state. - pub fn latest_checkpoint(&self) -> Option<(U96, U160)> { + pub fn latest_checkpoint(&self) -> Option<(T::Key, T::Value)> { let pos = self.length(); if pos == U256::ZERO { None @@ -204,7 +195,7 @@ impl Trace160 { /// /// * `&self` - Read access to the checkpoint's state. /// * `pos` - Index of the checkpoint. - pub fn at(&self, pos: U32) -> (U96, U160) { + pub fn at(&self, pos: U32) -> (T::Key, T::Value) { let guard = self._checkpoints.get(pos).unwrap_or_else(|| { panic!("should get checkpoint at index `{pos}`") }); @@ -228,9 +219,9 @@ impl Trace160 { /// returned. fn _insert( &mut self, - key: U96, - value: U160, - ) -> Result<(U160, U160), Error> { + key: T::Key, + value: T::Value, + ) -> Result<(T::Value, T::Value), Error> { let pos = self.length(); if pos > U256::ZERO { let last = self._index(pos - uint!(1_U256)); @@ -251,7 +242,7 @@ impl Trace160 { Ok((last_value, value)) } else { self._unchecked_push(key, value); - Ok((U160::ZERO, value)) + Ok((T::Value::ZERO, value)) } } @@ -271,7 +262,7 @@ impl Trace160 { /// * `high` - Exclusive index where search ends. fn _upper_binary_lookup( &self, - key: U96, + key: T::Key, mut low: U256, mut high: U256, ) -> U256 { @@ -302,7 +293,7 @@ impl Trace160 { /// * `high` - Exclusive index where search ends. fn _lower_binary_lookup( &self, - key: U96, + key: T::Key, mut low: U256, mut high: U256, ) -> U256 { @@ -328,7 +319,7 @@ impl Trace160 { /// /// * `&self` - Read access to the checkpoint's state. /// * `pos` - Index of the checkpoint. - fn _index(&self, pos: U256) -> StorageGuard { + fn _index(&self, pos: U256) -> StorageGuard> { self._checkpoints .get(pos) .unwrap_or_else(|| panic!("should get checkpoint at index `{pos}`")) @@ -345,7 +336,7 @@ impl Trace160 { /// /// * `&mut self` - Write access to the checkpoint's state. /// * `pos` - Index of the checkpoint. - fn _index_mut(&mut self, pos: U256) -> StorageGuardMut { + fn _index_mut(&mut self, pos: U256) -> StorageGuardMut> { self._checkpoints .setter(pos) .unwrap_or_else(|| panic!("should get checkpoint at index `{pos}`")) @@ -358,7 +349,7 @@ impl Trace160 { /// * `&mut self` - Write access to the checkpoint's state. /// * `key` - Checkpoint key to insert. /// * `value` - Checkpoint value corresponding to insertion `key`. - fn _unchecked_push(&mut self, key: U96, value: U160) { + fn _unchecked_push(&mut self, key: T::Key, value: T::Value) { let mut new_checkpoint = self._checkpoints.grow(); new_checkpoint._key.set(key); new_checkpoint._value.set(value); @@ -369,12 +360,11 @@ impl Trace160 { mod tests { use alloy_primitives::uint; - use crate::utils::structs::checkpoints::{ - CheckpointUnorderedInsertion, Error, Trace160, - }; + use super::{CheckpointUnorderedInsertion, Error, Trace}; + use crate::utils::structs::checkpoints::S160; #[motsu::test] - fn push(checkpoint: Trace160) { + fn push(checkpoint: Trace) { let first_key = uint!(1_U96); let first_value = uint!(11_U160); @@ -396,7 +386,7 @@ mod tests { } #[motsu::test] - fn push_same_value(checkpoint: Trace160) { + fn push_same_value(checkpoint: Trace) { let first_key = uint!(1_U96); let first_value = uint!(11_U160); @@ -421,7 +411,7 @@ mod tests { } #[motsu::test] - fn lower_lookup(checkpoint: Trace160) { + fn lower_lookup(checkpoint: Trace) { checkpoint.push(uint!(1_U96), uint!(11_U160)).expect("push first"); checkpoint.push(uint!(3_U96), uint!(33_U160)).expect("push second"); checkpoint.push(uint!(5_U96), uint!(55_U160)).expect("push third"); @@ -433,7 +423,7 @@ mod tests { } #[motsu::test] - fn upper_lookup(checkpoint: Trace160) { + fn upper_lookup(checkpoint: Trace) { checkpoint.push(uint!(1_U96), uint!(11_U160)).expect("push first"); checkpoint.push(uint!(3_U96), uint!(33_U160)).expect("push second"); checkpoint.push(uint!(5_U96), uint!(55_U160)).expect("push third"); @@ -445,7 +435,7 @@ mod tests { } #[motsu::test] - fn upper_lookup_recent(checkpoint: Trace160) { + fn upper_lookup_recent(checkpoint: Trace) { // `upper_lookup_recent` has different optimizations for "short" (<=5) // and "long" (>5) checkpoint arrays. // @@ -489,7 +479,7 @@ mod tests { } #[motsu::test] - fn latest(checkpoint: Trace160) { + fn latest(checkpoint: Trace) { assert_eq!(checkpoint.latest(), uint!(0_U160)); checkpoint.push(uint!(1_U96), uint!(11_U160)).expect("push first"); checkpoint.push(uint!(3_U96), uint!(33_U160)).expect("push second"); @@ -498,7 +488,7 @@ mod tests { } #[motsu::test] - fn latest_checkpoint(checkpoint: Trace160) { + fn latest_checkpoint(checkpoint: Trace) { assert_eq!(checkpoint.latest_checkpoint(), None); checkpoint.push(uint!(1_U96), uint!(11_U160)).expect("push first"); checkpoint.push(uint!(3_U96), uint!(33_U160)).expect("push second"); @@ -510,7 +500,7 @@ mod tests { } #[motsu::test] - fn error_when_unordered_insertion(checkpoint: Trace160) { + fn error_when_unordered_insertion(checkpoint: Trace) { checkpoint.push(uint!(1_U96), uint!(11_U160)).expect("push first"); checkpoint.push(uint!(3_U96), uint!(33_U160)).expect("push second"); let err = checkpoint