Skip to content

Commit

Permalink
feat: Add SimpleGuardianModule to SimplePlusAccount
Browse files Browse the repository at this point in the history
  • Loading branch information
jpgonzalezra committed May 20, 2024
1 parent 7858b34 commit 0bd4f76
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 57 deletions.
87 changes: 87 additions & 0 deletions src/SimpleGuardianModule.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// SPDX-License-Identifier: GPL-3.0
pragma solidity ^0.8.25;

import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
// import { console2 } from "forge-std/src/console2.sol";

abstract contract SimpleGuardianModule {
using ECDSA for bytes32;

bytes32 public constant _RECOVER_TYPEHASH =
keccak256("Recover(address currentOwner, address newOwner, uint256 nonce)");

event NonceConsumed(address indexed owner, uint256 idx);
event GuardianUpdated(address indexed previousGuardian, address indexed newGuardian);

address public guardian;
mapping(address => uint256) private _nonces;

modifier onlyGuardian() {
require(msg.sender == guardian, "Not the guardian");
_;
}

/**
* @notice Retuns a nonce for a given address.
* @param from Address.
* @return uint256 Nonce Value.
*/
function getNonce(address from) external view virtual returns (uint256) {
return _nonces[from];
}

function _verifyAndConsumeNonce(address owner, uint256 nonde) internal virtual {
require(nonde == _nonces[owner]++, "invalid nonce");
emit NonceConsumed(owner, nonde);
}

function initGuardian(address newGuardian) external {
require(guardian == address(0));
_updateGuardian(newGuardian);
}

function updateGuardian(address newGuardian) external {
require(_onlyAuthorized(), "Not authorized");
_updateGuardian(newGuardian);
}

function _updateGuardian(address newGuardian) internal {
require(
newGuardian != address(0) && guardian != newGuardian && newGuardian != address(this),
"Invalid guardian address"
);
address oldGuardian = guardian;
guardian = newGuardian;
emit GuardianUpdated(oldGuardian, newGuardian);
}

function recoverAccount(address newOwner, uint256 nonce, bytes calldata signature) external {
require(
newOwner != address(0) && _owner() != newOwner && newOwner != address(this), "Invalid new owner address"
);

_verifyAndConsumeNonce(newOwner, nonce);
bytes32 structHash = keccak256(abi.encode(_RECOVER_TYPEHASH, _owner(), newOwner, nonce));
bytes32 digest = _hashTypedDataV4(structHash);

address recoveredAddress = digest.recover(signature);

require(recoveredAddress == guardian, "Invalid guardian signature");

_transferOwnership(newOwner);
}

function _transferOwnership(address newOwner) internal virtual;

function _hashTypedDataV4(bytes32 structHash) internal view virtual returns (bytes32);

function _onlyAuthorized() internal view virtual returns (bool);

function _owner() internal view virtual returns (address);

/**
* @dev This empty reserved space is put in place to allow future versions to add new
* variables without shifting down storage in the inheritance chain.
*/
uint256[49] private __gap;
}
131 changes: 74 additions & 57 deletions src/SimplePlusAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@ import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/Mes
import { ECDSA } from "@openzeppelin/contracts/utils/cryptography/ECDSA.sol";
import { IERC1271 } from "@openzeppelin/contracts/interfaces/IERC1271.sol";
import { EIP712 } from "@openzeppelin/contracts/utils/cryptography/EIP712.sol"; // TODO: use upgradable version
import { SimpleGuardianModule } from "./SimpleGuardianModule.sol";
// import { console2 } from "forge-std/src/console2.sol";

contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 {
contract SimplePlusAccount is SimpleAccount, SimpleGuardianModule, IERC1271, EIP712 {
using ECDSA for bytes32;
using MessageHashUtils for bytes32;

bytes32 internal constant _MESSAGE_TYPEHASH = keccak256("SimplePlusAccount(bytes message)");

modifier onlyAuthorized() {
_onlyAuthorized();
_;
}
bytes32 public constant _MESSAGE_TYPEHASH = keccak256("SimplePlusAccount(bytes message)");

// @notice Signature types used for user operation validation and ERC-1271 signature validation.
enum SignatureType {
Expand Down Expand Up @@ -48,16 +45,18 @@ contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 {
/// 1. The entry point
/// 2. The account itself (when redirected through `execute`, etc.)
/// 3. An owner
function _onlyAuthorized() internal view {
function _onlyAuthorized() internal view virtual override returns (bool) {
if (msg.sender != address(entryPoint()) && msg.sender != address(this) && msg.sender != owner) {
revert NotAuthorized();
}
return true;
}

/// @notice Transfers ownership of the contract to a new account (`newOwner`). Can only be called by the current
/// owner or from the entry point via a user operation signed by the current owner.
/// @param newOwner The new owner.
function transferOwnership(address newOwner) external onlyAuthorized {
function transferOwnership(address newOwner) public {
require(_onlyAuthorized());
if (newOwner == address(0) || newOwner == address(this) || owner == newOwner) {
revert InvalidOwner(newOwner);
}
Expand All @@ -80,62 +79,80 @@ contract SimplePlusAccount is SimpleAccount, IERC1271, EIP712 {
* "Ethereum Signed Message" envelope before checking the signature for the EOA-owner case.
*/
function isValidSignature(bytes32 hash, bytes calldata _signature) public view virtual returns (bytes4) {
if (_signature.length == 0) {
revert InvalidSignatureType();
if (_signature.length == 0) {
revert InvalidSignatureType();
}

bytes32 structHash = keccak256(abi.encode(_MESSAGE_TYPEHASH, keccak256(abi.encode(hash))));
bytes32 replaySafeHash = MessageHashUtils.toTypedDataHash(_domainSeparatorV4(), structHash);

return _validateSignatureWithType(uint8(_signature[0]), replaySafeHash, _signature[1:])
? this.isValidSignature.selector
: bytes4(0xffffffff);
}

bytes32 structHash = keccak256(abi.encode(_MESSAGE_TYPEHASH, keccak256(abi.encode(hash))));
bytes32 replaySafeHash = MessageHashUtils.toTypedDataHash(_domainSeparatorV4(), structHash);
function _validateSignature(
PackedUserOperation calldata userOp,
bytes32 userOpHash
)
internal
virtual
override
returns (uint256 validationData)
{
if (userOp.signature.length == 0) {
revert InvalidSignatureType();
}

return _validateSignatureWithType(uint8(_signature[0]), replaySafeHash, _signature[1:])
? this.isValidSignature.selector
: bytes4(0xffffffff);
}
return _validateSignatureWithType(
uint8(userOp.signature[0]), userOpHash.toEthSignedMessageHash(), userOp.signature[1:]
) ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED;
}

function _validateSignatureWithType(
uint8 signatureType,
bytes32 hash,
bytes memory signature
)
private
view
returns (bool)
{
if (signatureType == uint8(SignatureType.EOA)) {
return _validateEOASignature(hash, signature) == SIG_VALIDATION_SUCCESS;
} else if (signatureType == uint8(SignatureType.CONTRACT)) {
return _validateContractSignature(hash, signature) == SIG_VALIDATION_SUCCESS;
} else {
revert InvalidSignatureType();
}
}

function _validateSignature(
PackedUserOperation calldata userOp,
bytes32 userOpHash
)
internal
virtual
override
returns (uint256 validationData)
{
if (userOp.signature.length == 0) {
revert InvalidSignatureType();
function _validateEOASignature(bytes32 hash, bytes memory signature) private view returns (uint256) {
address recovered = hash.recover(signature);
return recovered == owner ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED;
}

return _validateSignatureWithType(uint8(userOp.signature[0]), userOpHash.toEthSignedMessageHash(), userOp.signature[1:])
? SIG_VALIDATION_SUCCESS
: SIG_VALIDATION_FAILED;
}
function _validateContractSignature(bytes32 userOpHash, bytes memory signature) private view returns (uint256) {
return SignatureChecker.isValidERC1271SignatureNow(owner, userOpHash, signature)
? SIG_VALIDATION_SUCCESS
: SIG_VALIDATION_FAILED;
}

function _validateSignatureWithType(
uint8 signatureType,
bytes32 hash,
bytes memory signature
)
private
view
returns (bool)
{
if (signatureType == uint8(SignatureType.EOA)) {
return _validateEOASignature(hash, signature) == SIG_VALIDATION_SUCCESS;
} else if (signatureType == uint8(SignatureType.CONTRACT)) {
return _validateContractSignature(hash, signature) == SIG_VALIDATION_SUCCESS;
} else {
revert InvalidSignatureType();
function _transferOwnership(address newOwner) internal virtual override {
this.transferOwnership(newOwner);
}
}

function _validateEOASignature(bytes32 hash, bytes memory signature) private view returns (uint256) {
address recovered = hash.recover(signature);
return recovered == owner ? SIG_VALIDATION_SUCCESS : SIG_VALIDATION_FAILED;
}
function _hashTypedDataV4(bytes32 structHash)
internal
view
virtual
override(EIP712, SimpleGuardianModule)
returns (bytes32)
{
return super._hashTypedDataV4(structHash);
}

function _validateContractSignature(bytes32 userOpHash, bytes memory signature) private view returns (uint256) {
return SignatureChecker.isValidERC1271SignatureNow(owner, userOpHash, signature)
? SIG_VALIDATION_SUCCESS
: SIG_VALIDATION_FAILED;
}
function _owner() internal view virtual override returns (address) {
return owner;
}
}
25 changes: 25 additions & 0 deletions test/SimplePlusAccount.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,25 @@ import { SimplePlusAccountFactory } from "../src/SimplePlusAccountFactory.sol";
import { EntryPoint } from "@account-abstraction/contracts/core/EntryPoint.sol";
import { SimpleAccount } from "@account-abstraction/contracts/samples/SimpleAccount.sol";
import { AccountTest } from "./AccountTest.sol";
import { MessageHashUtils } from "@openzeppelin/contracts/utils/cryptography/MessageHashUtils.sol";
// import { console2 } from "forge-std/src/console2.sol";

contract SimplePlusAccountTest is AccountTest {
uint256 public constant EOA_PRIVATE_KEY = 1;
uint256 public constant GUARDIAN_PRIVATE_KEY = 2;
address payable public constant BENEFICIARY = payable(address(0xbe9ef1c1a2ee));
address public eoaAddress;
address public guardianAddress;

SimplePlusAccount public account;
EntryPoint public entryPoint;

event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);
event GuardianUpdated(address indexed previousOwner, address indexed newOwner);

function setUp() public {
eoaAddress = vm.addr(EOA_PRIVATE_KEY);
guardianAddress = vm.addr(GUARDIAN_PRIVATE_KEY);
entryPoint = new EntryPoint();
SimplePlusAccountFactory factory = new SimplePlusAccountFactory(entryPoint);
account = factory.createAccount(eoaAddress, 1);
Expand Down Expand Up @@ -97,6 +104,24 @@ contract SimplePlusAccountTest is AccountTest {
assertEq(account.isValidSignature(message, signature), bytes4(keccak256("isValidSignature(bytes32,bytes)")));
}

function testGuardianCanTransferOwnership() public {
vm.prank(guardianAddress);
emit GuardianUpdated(address(0), guardianAddress);
account.initGuardian(guardianAddress);
uint256 nonce = account.getNonce(eoaAddress);

address newOwner = address(0x100);
bytes32 structHash = keccak256(abi.encode(account._RECOVER_TYPEHASH(), eoaAddress, newOwner, nonce));
bytes32 digest = MessageHashUtils.toTypedDataHash(domainSeparator(address(account)), structHash);

bytes memory signature = sign(GUARDIAN_PRIVATE_KEY, digest);

vm.expectEmit(true, true, false, false);
emit OwnershipTransferred(eoaAddress, newOwner);
account.recoverAccount(newOwner, nonce, signature);
assertEq(account.owner(), newOwner);
}

function _transferOwnership(address currentOwner, address newOwner) internal {
vm.prank(currentOwner);
vm.expectEmit(true, true, false, false);
Expand Down

0 comments on commit 0bd4f76

Please sign in to comment.