This repository has been archived by the owner on Jan 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
LibDiamond.sol
212 lines (194 loc) · 9.12 KB
/
LibDiamond.sol
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// SPDX-License-Identifier: GPL-3.0-or-later
pragma solidity >=0.8.18;
/******************************************************************************\
* Author: Nick Mudge <nick@perfectabstractions.com> (https://twitter.com/mudgen)
* EIP-2535 Diamonds: https://eips.ethereum.org/EIPS/eip-2535
/******************************************************************************/
import { IDiamondCut } from "../interfaces/IDiamondCut.sol";
import { IDiamondLoupe } from "../interfaces/IDiamondLoupe.sol";
import { IERC165 } from "../interfaces/IERC165.sol";
/// @title LibDiamond
/// @notice Shared storage layout and internal helpers for an EIP-2535 diamond: owner
///         bookkeeping, selector-to-facet routing, and the `diamondCut` upgrade routine.
/// @dev Every function is `internal` and operates on the struct located at
///      `DIAMOND_STORAGE_POSITION` in the calling contract's storage.
library LibDiamond {
    /// @dev Storage slot at which the `DiamondStorage` struct is anchored.
    bytes32 public constant DIAMOND_STORAGE_POSITION =
        keccak256("diamond.standard.diamond.storage");

    /// @dev Routing entry for one selector: the facet that implements it and the
    ///      selector's index inside `DiamondStorage.selectors`.
    struct FacetAddressAndSelectorPosition {
        address facetAddress;
        uint16 selectorPosition;
    }

    struct DiamondStorage {
        // function selector => facet address and selector position in selectors array
        mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
        // every selector currently registered; order changes on removal (swap-and-pop)
        bytes4[] selectors;
        // interface id => supported flag (not written by any function in this library)
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the diamond's shared state.
    /// @return ds Reference to the `DiamondStorage` struct at `DIAMOND_STORAGE_POSITION`.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        assembly {
            ds.slot := position
        }
    }

    event OwnershipTransferred(address indexed previousOwner, address indexed newOwner);

    /// @notice Overwrites the stored owner and emits `OwnershipTransferred`.
    /// @dev Performs no authorization check itself; callers are expected to guard it.
    /// @param _newOwner Address recorded as the new owner (not validated here).
    function setContractOwner(address _newOwner) internal {
        DiamondStorage storage ds = diamondStorage();
        address previousOwner = ds.contractOwner;
        ds.contractOwner = _newOwner;
        emit OwnershipTransferred(previousOwner, _newOwner);
    }

    /// @notice Reads the stored owner.
    /// @return contractOwner_ Current value of `DiamondStorage.contractOwner`.
    function contractOwner() internal view returns (address contractOwner_) {
        contractOwner_ = diamondStorage().contractOwner;
    }

    /// @notice Reverts unless `msg.sender` is the stored owner or the contract itself.
    function enforceIsOwnerOrContract() internal view {
        require(
            msg.sender == diamondStorage().contractOwner || msg.sender == address(this),
            "LibDiamond: Must be contract or owner"
        );
    }

    /// @notice Reverts unless `msg.sender` is the stored owner.
    function enforceIsContractOwner() internal view {
        require(msg.sender == diamondStorage().contractOwner, "LibDiamond: Must be contract owner");
    }

    event DiamondCut(IDiamondCut.FacetCut[] _diamondCut, address _init, bytes _calldata);

    /// @notice Internal function version of diamondCut: applies each cut in order, emits
    ///         `DiamondCut`, then runs the optional initializer.
    /// @dev The event is emitted before the initializer is invoked; a revert in the
    ///      initializer reverts the whole call, including the selector changes.
    /// @param _diamondCut Add / Replace / Remove instructions, processed sequentially.
    /// @param _init Address to delegatecall after the cuts, or address(0) for none.
    /// @param _calldata Calldata forwarded to `_init`.
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            IDiamondCut.FacetCutAction action = _diamondCut[facetIndex].action;
            if (action == IDiamondCut.FacetCutAction.Add) {
                addFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Replace) {
                replaceFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else if (action == IDiamondCut.FacetCutAction.Remove) {
                removeFunctions(
                    _diamondCut[facetIndex].facetAddress,
                    _diamondCut[facetIndex].functionSelectors
                );
            } else {
                revert("LibDiamondCut: Incorrect FacetCutAction");
            }
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Registers new selectors and routes them to `_facetAddress`.
    /// @dev Reverts if the selector list is empty, the facet is address(0) or has no code,
    ///      or any selector is already registered.
    /// @param _facetAddress Facet that will implement the selectors.
    /// @param _functionSelectors Selectors to add.
    function addFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        DiamondStorage storage ds = diamondStorage();
        // positions are stored as uint16, so the running count is tracked in that width
        uint16 selectorCount = uint16(ds.selectors.length);
        require(_facetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
        enforceHasContractCode(_facetAddress, "LibDiamondCut: Add facet has no code");
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
            require(
                oldFacetAddress == address(0),
                "LibDiamondCut: Can't add function that already exists"
            );
            // new selector is appended, so its position is the current count
            ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(
                _facetAddress,
                selectorCount
            );
            ds.selectors.push(selector);
            selectorCount++;
        }
    }

    /// @notice Re-routes already-registered selectors to a different facet.
    /// @dev Reverts if the selector list is empty, the facet is address(0) or has no code,
    ///      or any selector is immutable, already routed to `_facetAddress`, or unregistered.
    ///      Selector positions are left unchanged.
    /// @param _facetAddress Facet that will implement the selectors from now on.
    /// @param _functionSelectors Selectors to re-route.
    function replaceFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        DiamondStorage storage ds = diamondStorage();
        require(_facetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
        enforceHasContractCode(_facetAddress, "LibDiamondCut: Replace facet has no code");
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
            // can't replace immutable functions -- functions defined directly in the diamond
            require(
                oldFacetAddress != address(this),
                "LibDiamondCut: Can't replace immutable function"
            );
            require(
                oldFacetAddress != _facetAddress,
                "LibDiamondCut: Can't replace function with same function"
            );
            require(
                oldFacetAddress != address(0),
                "LibDiamondCut: Can't replace function that doesn't exist"
            );
            // replace old facet address
            ds.facetAddressAndSelectorPosition[selector].facetAddress = _facetAddress;
        }
    }

    /// @notice Unregisters selectors from the diamond.
    /// @dev Reverts if the selector list is empty, `_facetAddress` is not address(0), or any
    ///      selector is unregistered or immutable. Uses swap-and-pop on `selectors`, so the
    ///      order of the remaining selectors may change.
    /// @param _facetAddress Must be address(0) for a removal.
    /// @param _functionSelectors Selectors to remove.
    function removeFunctions(address _facetAddress, bytes4[] memory _functionSelectors) internal {
        require(_functionSelectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        DiamondStorage storage ds = diamondStorage();
        uint256 selectorCount = ds.selectors.length;
        require(
            _facetAddress == address(0),
            "LibDiamondCut: Remove facet address must be address(0)"
        );
        for (uint256 selectorIndex; selectorIndex < _functionSelectors.length; selectorIndex++) {
            bytes4 selector = _functionSelectors[selectorIndex];
            FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds
                .facetAddressAndSelectorPosition[selector];
            require(
                oldFacetAddressAndSelectorPosition.facetAddress != address(0),
                "LibDiamondCut: Can't remove function that doesn't exist"
            );
            // can't remove immutable functions -- functions defined directly in the diamond
            require(
                oldFacetAddressAndSelectorPosition.facetAddress != address(this),
                "LibDiamondCut: Can't remove immutable function."
            );
            // replace selector with last selector
            selectorCount--;
            if (oldFacetAddressAndSelectorPosition.selectorPosition != selectorCount) {
                bytes4 lastSelector = ds.selectors[selectorCount];
                ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
                ds
                    .facetAddressAndSelectorPosition[lastSelector]
                    .selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
            }
            // delete last selector
            ds.selectors.pop();
            delete ds.facetAddressAndSelectorPosition[selector];
        }
    }

    /// @notice Runs the optional initializer of a diamond cut via delegatecall.
    /// @dev With `_init == address(0)` the calldata must be empty and nothing is called.
    ///      Otherwise the calldata must be non-empty and, unless `_init` is this contract,
    ///      `_init` must have code. If the delegatecall fails with return data, that data is
    ///      re-thrown unchanged; if it fails without data, a fixed reason string is used.
    /// @param _init Initializer address, or address(0) for none.
    /// @param _calldata Calldata for the delegatecall.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            // NOTE: message text is kept verbatim (including the missing space) because
            // external tooling may match on the exact revert string.
            require(
                _calldata.length == 0,
                "LibDiamondCut: _init is address(0) but_calldata is not empty"
            );
        } else {
            require(
                _calldata.length > 0,
                "LibDiamondCut: _calldata is empty but _init is not address(0)"
            );
            if (_init != address(this)) {
                enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
            }
            (bool success, bytes memory returnData) = _init.delegatecall(_calldata);
            if (!success) {
                if (returnData.length > 0) {
                    // Bubble up the callee's revert data byte-for-byte. `returnData` is
                    // already an ABI-encoded error payload; `revert(string(returnData))`
                    // would nest it inside a fresh Error(string), garbling reason strings
                    // and making custom errors undecodable for the caller.
                    assembly {
                        // bytes layout in memory: 32-byte length word, then the data
                        revert(add(returnData, 32), mload(returnData))
                    }
                } else {
                    revert("LibDiamondCut: _init function reverted");
                }
            }
        }
    }

    /// @notice Reverts with `_errorMessage` when `_contract` has no deployed code.
    /// @param _contract Address whose code size is checked via `extcodesize`.
    /// @param _errorMessage Revert reason used when the code size is zero.
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        require(contractSize > 0, _errorMessage);
    }
}