Skip to content

Commit

Permalink
Merge pull request #3 from LedgerHQ/classify-eip712-swap
Browse files Browse the repository at this point in the history
feat: classify and check a permit2
  • Loading branch information
ckorchane-ledger authored Sep 19, 2024
2 parents 0ce6bd9 + f4e5a9d commit a2116ac
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 11 deletions.
7 changes: 0 additions & 7 deletions demo-registry/uniswap/eip712-permit2.json
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,6 @@
"params": {
"tokenPath": "details.token"
}
},
"details.expiration": {
"label": "Valid Until",
"format": "date",
"params": {
"encoding": "timestamp"
}
}
},
"required": [
Expand Down
4 changes: 3 additions & 1 deletion src/erc7730/classifier/eip712_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@
class EIP712Classifier(Classifier[EIP712JsonSchema]):
@override
def classify(self, schema: EIP712JsonSchema) -> TxClass | None:
pass
if "permit" in schema.primaryType.lower():
return TxClass.PERMIT
return None
51 changes: 49 additions & 2 deletions src/erc7730/display_format_checker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,59 @@
from erc7730.classifier import TxClass
from erc7730.linter import Linter
from erc7730.model.display import Display
from erc7730.model.display import Display, Format


def _fields_contain(word: str, fields: set[str]) -> bool:
"""
to check if the provided keyword is contained in one of the fields (case insensitive)
"""
for field in fields:
if word.lower() in field.lower():
return True
return False


class DisplayFormatChecker:
def __init__(self, c: TxClass, d: Display):
self.c = c
self.d = d

def _get_all_displayed_fields(self, formats: dict[str, Format]) -> set[str]:
fields: set[str] = set()
for format in formats.values():
if format.fields is not None:
for field in format.fields.root.keys():
fields.add(str(field))
return fields

def check(self) -> list[Linter.Output]:
return []
res: list[Linter.Output] = []
match self.c:
case TxClass.PERMIT:
formats = self.d.formats
fields = self._get_all_displayed_fields(formats)
if not _fields_contain("spender", fields):
res.append(
Linter.Output(
title="Missing spender in displayed fields", message="", level=Linter.Output.Level.ERROR
)
)
if not _fields_contain("amount", fields):
res.append(
Linter.Output(
title="Missing amount in displayed fields", message="", level=Linter.Output.Level.ERROR
)
)
if (
not _fields_contain("valid until", fields)
and not _fields_contain("expiry", fields)
and not _fields_contain("expiration", fields)
):
res.append(
Linter.Output(
title="Missing expiration date in displayed fields for permit",
message="",
level=Linter.Output.Level.ERROR,
)
)
return res
8 changes: 8 additions & 0 deletions src/erc7730/linter/linter_transaction_type_classifier_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,15 @@ def lint(self, descriptor: ERC7730Descriptor, out: Linter.OutputAdder) -> None:
return None
c = determine_tx_class(descriptor)
if c is None:
out(
Linter.Output(
title="Transaction type: ",
message="could not determine transaction type",
level=Linter.Output.Level.WARNING,
)
)
return None
out(Linter.Output(title="Transaction type: ", message=str(c), level=Linter.Output.Level.INFO))
d: Display | None = descriptor.display
if d is None:
return None
Expand Down
2 changes: 1 addition & 1 deletion src/erc7730/model/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class StructFormats(BaseLibraryModel):
fields: ForwardRef("Fields") # type: ignore


class Fields(RootModel[dict[str, Union[Reference, Field, StructFormats]]]):
class Fields(RootModel[dict[str, Union[Reference, FieldDescription, Field, StructFormats]]]):
"""todo use StructFormats instead"""


Expand Down

0 comments on commit a2116ac

Please sign in to comment.