Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m30m committed Oct 8, 2024
1 parent 4541f94 commit a6cac64
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 91 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import base64
from typing import Any, Annotated
from typing import Any, Annotated, ClassVar

from pydantic import (
GetCoreSchemaHandler,
GetJsonSchemaHandler,
BaseModel,
model_validator,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
Expand All @@ -12,6 +14,11 @@
from solders.pubkey import Pubkey as _SvmAddress
from solders.transaction import Transaction as _SvmTransaction

from express_relay.express_relay_types import (
IntString,
UUIDString,
UnsupportedOpportunityVersionException,
)
from express_relay.svm.generated.limo.accounts import Order


Expand Down Expand Up @@ -171,3 +178,83 @@ def __get_pydantic_json_schema__(
SvmAddress = Annotated[_SvmAddress, _SvmAddressPydanticAnnotation]
SvmHash = Annotated[_SvmHash, _HashPydanticAnnotation]
SvmSignature = Annotated[_SvmSignature, _SignaturePydanticAnnotation]


class _OrderPydanticAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_str(value: str) -> Order:
return Order.decode(base64.b64decode(value))

from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)

return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(Order),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: base64.b64encode(Order.layout.build(instance)).decode(
"utf-8"
)
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Use the same schema that would be used for `str`
return handler(core_schema.str_schema())


class OpportunitySvm(BaseModel):
"""
Attributes:
chain_id: The chain ID to bid on.
version: The version of the opportunity.
creation_time: The creation time of the opportunity.
opportunity_id: The ID of the opportunity.
blockHash: The block hash to use for execution.
slot: The slot where this order was created or updated
program: The program which handles this opportunity
order: The order to be executed.
order_address: The address of the order.
"""

chain_id: str
version: str
creation_time: IntString
opportunity_id: UUIDString

blockHash: SvmHash
slot: int

program: str
order: Annotated[Order, _OrderPydanticAnnotation]
order_address: SvmAddress

supported_versions: ClassVar[list[str]] = ["v1"]
supported_programs: ClassVar[list[str]] = ["limo"]

@model_validator(mode="before")
@classmethod
def check_version(cls, data):
if data["version"] not in cls.supported_versions:
raise UnsupportedOpportunityVersionException(
f"Cannot handle opportunity version: {data['version']}. Please upgrade your client."
)
return data
90 changes: 1 addition & 89 deletions express_relay/sdk/python/express_relay/express_relay_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import base64
from datetime import datetime
from enum import Enum
from pydantic import (
BaseModel,
model_validator,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
Tag,
Discriminator,
RootModel,
Expand All @@ -16,19 +13,15 @@
import web3
from typing import Union, ClassVar, Any
from pydantic import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import Literal, Annotated
import warnings
import string
from eth_account.datastructures import SignedMessage

from express_relay.svm.generated.limo.accounts import Order
from express_relay.express_relay_svm_types import (
SvmTransaction,
SvmHash,
SvmAddress,
SvmSignature,
OpportunitySvm,
)


Expand Down Expand Up @@ -331,87 +324,6 @@ class OpportunityParams(BaseModel):
params: Union[OpportunityParamsV1] = Field(..., discriminator="version")


class _OrderPydanticAnnotation:
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
def validate_from_str(value: str) -> Order:
return Order.decode(base64.b64decode(value))

from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(validate_from_str),
]
)

return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(Order),
from_str_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: base64.b64encode(Order.layout.build(instance)).decode(
"utf-8"
)
),
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Use the same schema that would be used for `str`
return handler(core_schema.str_schema())


class OpportunitySvm(BaseModel):
"""
Attributes:
target_calldata: The calldata for the contract call.
chain_id: The chain ID to bid on.
target_contract: The contract address to call.
permission_key: The permission key to bid on.
buy_tokens: The tokens to receive in the opportunity.
sell_tokens: The tokens to spend in the opportunity.
target_call_value: The value to send with the contract call.
version: The version of the opportunity.
creation_time: The creation time of the opportunity.
opportunity_id: The ID of the opportunity.
"""

chain_id: str
version: str
creation_time: IntString
opportunity_id: UUIDString

blockHash: SvmHash
slot: int

program: str
order: Annotated[Order, _OrderPydanticAnnotation]
order_address: SvmAddress

supported_versions: ClassVar[list[str]] = ["v1"]
supported_programs: ClassVar[list[str]] = ["limo"]

@model_validator(mode="before")
@classmethod
def check_version(cls, data):
if data["version"] not in cls.supported_versions:
raise UnsupportedOpportunityVersionException(
f"Cannot handle opportunity version: {data['version']}. Please upgrade your client."
)
return data


class OpportunityEvm(BaseModel):
"""
Attributes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
BidSvm,
Opportunity,
OpportunityEvm,
OpportunitySvm,
)
from express_relay.express_relay_svm_types import OpportunitySvm
from express_relay.svm.generated.express_relay.accounts import ExpressRelayMetadata
from express_relay.svm.generated.express_relay.program_id import (
PROGRAM_ID as SVM_EXPRESS_RELAY_PROGRAM_ID,
Expand Down

0 comments on commit a6cac64

Please sign in to comment.