diff --git a/tests/analysis/snapshots/__init__.py b/tests/analysis/snapshots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/snapshots/snap_test_currency_changes.py b/tests/analysis/snapshots/snap_test_currency_changes.py new file mode 100644 index 0000000..ecd1907 --- /dev/null +++ b/tests/analysis/snapshots/snap_test_currency_changes.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot + + +snapshots = Snapshot() + +snapshots["test_currency_changes_extractor currency changes"] = [ + ( + GenericRepr(""), + { + "change": -10, + "owner": "0x000000000000000000000000000000000000aaaa", + "token_address": None, + "type": "ETHER", + }, + ), + ( + GenericRepr(""), + { + "change": 10, + "owner": "0x000000000000000000000000000000000000bbbb", + "token_address": None, + "type": "ETHER", + }, + ), + ( + GenericRepr(""), + { + "change": -20, + "owner": "0x000000000000000000000000000000000000aaaa", + "token_address": None, + "type": "ETHER", + }, + ), + ( + GenericRepr(""), + { + "change": 20, + "owner": "0x000000000000000000000000000000000000cccc", + "token_address": None, + "type": "ETHER", + }, + ), + ( + GenericRepr(""), + { + "change": -10, + "owner": "0x000000000000000000000000000000000000bbbb", + "token_address": None, + "type": "ETHER", + }, + ), + ( + GenericRepr(""), + { + "change": 10, + "owner": "0x000000000000000000000000000000000000cccc", + "token_address": None, + "type": "ETHER", + }, + ), +] diff --git a/tests/analysis/test_currency_changes.py b/tests/analysis/test_currency_changes.py new file mode 100644 index 0000000..4c706b2 --- /dev/null +++ b/tests/analysis/test_currency_changes.py @@ -0,0 +1,29 @@ +from tests.test_utils.test_utils import ( + _test_callcode, +) +from traces_analyzer.features.extractors.currency_changes import ( + CurrencyChangesFeatureExtractor, +) +from 
tests.test_utils.test_utils import _test_call + +from snapshottest.pytest import PyTestSnapshotTest + + +def test_currency_changes_extractor(snapshot: PyTestSnapshotTest): + instructions = [ + # A -= 10 B += 10 + _test_call("0xaaaa", 1, "0xa", "0xbbbb"), + # reverted + _test_call("0xaaaa", 1, "0xaa", "0xbbbb", reverted=True), + # A -= 20 C += 20 + _test_callcode("0xaaaa", 1, "0x14", "0xcccc"), + # B -= 10 C += 10 + _test_call("0xbbbb", 1, "0xa", "0xcccc"), + ] + + extractor = CurrencyChangesFeatureExtractor() + for instruction in instructions: + extractor.on_instruction(instruction) + changes = extractor.currency_changes + + snapshot.assert_match(changes, "currency changes") diff --git a/tests/e2e/test_sample_traces_analysis.py b/tests/e2e/test_sample_traces_analysis.py index 5e32902..c9cfd8c 100644 --- a/tests/e2e/test_sample_traces_analysis.py +++ b/tests/e2e/test_sample_traces_analysis.py @@ -81,8 +81,13 @@ def test_sample_traces_analysis_e2e( ) # Instruction usage has found 17 contracts - assert len(instruction_usage_analyzer.normal.get_used_opcodes_per_contract()) == 17 - assert len(instruction_usage_analyzer.reverse.get_used_opcodes_per_contract()) == 17 + assert ( + len(instruction_usage_analyzer.normal.get_used_opcodes_per_contract()) == 17 + ) + assert ( + len(instruction_usage_analyzer.reverse.get_used_opcodes_per_contract()) + == 17 + ) # TOD source tod_source = tod_source_analyzer.get_tod_source() diff --git a/tests/evaluation/snapshots/snap_test_financial_gain_loss.py b/tests/evaluation/snapshots/snap_test_financial_gain_loss.py new file mode 100644 index 0000000..e1e7724 --- /dev/null +++ b/tests/evaluation/snapshots/snap_test_financial_gain_loss.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import Snapshot + + +snapshots = Snapshot() + +snapshots["test_financial_gain_loss_evaluation evaluation_dict"] = { + "evaluation_type": "financial_gain_loss", 
+ "report": { + "gains": { + "0x000000000000000000000000000000000000cccc": { + "ETHER": { + "change": 10, + "owner": "0x000000000000000000000000000000000000cccc", + "token_address": None, + "type": "ETHER", + } + } + }, + "losses": { + "0x000000000000000000000000000000000000aaaa": { + "ETHER": { + "change": -8, + "owner": "0x000000000000000000000000000000000000aaaa", + "token_address": None, + "type": "ETHER", + } + }, + "0x000000000000000000000000000000000000bbbb": { + "ETHER": { + "change": -2, + "owner": "0x000000000000000000000000000000000000bbbb", + "token_address": None, + "type": "ETHER", + } + }, + }, + }, +} + +snapshots[ + "test_financial_gain_loss_evaluation evaluation_str" +] = """=== Evaluation: Financial gains and losses === +Losses in normal compared to reverse scenario: +> 0x000000000000000000000000000000000000cccc lost 10 ETHER (in Wei) + + +""" diff --git a/tests/evaluation/test_financial_gain_loss.py b/tests/evaluation/test_financial_gain_loss.py new file mode 100644 index 0000000..4634861 --- /dev/null +++ b/tests/evaluation/test_financial_gain_loss.py @@ -0,0 +1,61 @@ +import json +from typing import Sequence +from tests.test_utils.test_utils import ( + _test_callcode, +) +from traces_analyzer.evaluation.financial_gain_loss_evaluation import ( + FinancialGainLossEvaluation, +) +from traces_analyzer.features.extractors.currency_changes import ( + CurrencyChangesFeatureExtractor, +) +from tests.test_utils.test_utils import _test_call +from traces_parser.parser.instructions.instruction import Instruction + +from snapshottest.pytest import PyTestSnapshotTest + + +def get_changes(instructions: Sequence[Instruction]): + extractor = CurrencyChangesFeatureExtractor() + for instruction in instructions: + extractor.on_instruction(instruction) + return extractor.currency_changes + + +def test_financial_gain_loss_evaluation(snapshot: PyTestSnapshotTest): + changes_normal = get_changes( + [ + # A -= 10 B += 10 + _test_call("0xaaaa", 1, "0xa", "0xbbbb"), + # 
reverted + _test_call("0xaaaa", 1, "0xaa", "0xbbbb", reverted=True), + # A -= 20 C += 20 + _test_callcode("0xaaaa", 1, "0x14", "0xcccc"), + # B -= 10 C += 10 + _test_call("0xbbbb", 1, "0xa", "0xcccc"), + ] + ) + changes_reverse = get_changes( + [ + # A -= 2 B += 2 + _test_call("0xaaaa", 1, "0x2", "0xbbbb"), + # reverted + _test_call("0xaaaa", 1, "0xaa", "0xbbbb", reverted=True), + # A -= 20 C += 20 + _test_callcode("0xaaaa", 1, "0x14", "0xcccc"), + ] + ) + + evaluation = FinancialGainLossEvaluation( + changes_normal, + changes_reverse, + ) + + evaluation_dict = evaluation.dict_report() + snapshot.assert_match(evaluation_dict, "evaluation_dict") + + evaluation_str = evaluation.cli_report() + snapshot.assert_match(evaluation_str, "evaluation_str") + + # check if it's serializable + json.dumps(evaluation_dict) diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 0b325d2..47d77eb 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -21,6 +21,7 @@ SLOAD, CallInstruction, CALL, + CALLCODE, ) from traces_parser.parser.storage.address_key_storage import AddressKeyStorage from traces_parser.parser.storage.balances import Balances @@ -405,11 +406,38 @@ def _test_sstore( ) -def _test_call(current_address: str, pc: int, value: str, address: str): +def _test_call(current_address: str, pc: int, value: str, address: str, reverted=False): return _test_instruction( CALL, pc=pc, - call_context=_test_call_context(code_address=_test_addr(current_address)), + call_context=_test_call_context( + code_address=_test_addr(current_address), + storage_address=_test_addr(current_address), + reverted=reverted, + ), + flow=_test_flow( + accesses=StorageAccesses( + stack=_test_stack_accesses( + ["0x1234", address, value, "0x0", "0x4", "0x0", "0x0"] + ), + memory=[_test_mem_access("11111111")], + ), + writes=StorageWrites(calldata=CalldataWrite(_test_group("11111111"))), + ), + ) + + +def _test_callcode( + current_address: str, pc: int, 
value: str, address: str, reverted=False +): + return _test_instruction( + CALLCODE, + pc=pc, + call_context=_test_call_context( + code_address=_test_addr(current_address), + storage_address=_test_addr(current_address), + reverted=reverted, + ), flow=_test_flow( accesses=StorageAccesses( stack=_test_stack_accesses( diff --git a/traces_analyzer/cli.py b/traces_analyzer/cli.py index d3e29d3..7f6050a 100644 --- a/traces_analyzer/cli.py +++ b/traces_analyzer/cli.py @@ -6,10 +6,12 @@ from typing import Iterable from importlib.metadata import version -from networkx import ancestors from tqdm import tqdm from traces_analyzer.evaluation.evaluation import Evaluation +from traces_analyzer.evaluation.financial_gain_loss_evaluation import ( + FinancialGainLossEvaluation, +) from traces_analyzer.evaluation.instruction_differences_evaluation import ( InstructionDifferencesEvaluation, ) @@ -20,6 +22,9 @@ SecurifyPropertiesEvaluation, ) from traces_analyzer.evaluation.tod_source_evaluation import TODSourceEvaluation +from traces_analyzer.features.extractors.currency_changes import ( + CurrencyChangesFeatureExtractor, +) from traces_analyzer.features.extractors.instruction_differences import ( InstructionDifferencesFeatureExtractor, ) @@ -40,17 +45,12 @@ from traces_analyzer.loader.directory_loader import DirectoryLoader from traces_analyzer.loader.event_parser import VmTraceEventsParser from traces_analyzer.loader.loader import PotentialAttack -from traces_parser.parser.environment.call_context import CallContext from traces_parser.parser.events_parser import TraceEvent from traces_parser.parser.information_flow.information_flow_graph import ( build_information_flow_graph, ) from traces_parser.parser.instructions.instructions import ( CALL, - LOG0, - LOG1, - LOG2, - LOG3, STATICCALL, ) from traces_parser.parser.instructions_parser import ( @@ -58,7 +58,6 @@ parse_transaction, ) from traces_parser.datatypes import HexString -from traces_analyzer.utils.signatures.signature_registry 
import SignatureRegistry def main(): @@ -138,6 +137,9 @@ def compare_traces( instruction_usage_analyzers = SingleToDoubleInstructionFeatureExtractor( InstructionUsagesFeatureExtractor(), InstructionUsagesFeatureExtractor() ) + currency_changes_analyzer = SingleToDoubleInstructionFeatureExtractor( + CurrencyChangesFeatureExtractor(), CurrencyChangesFeatureExtractor() + ) calls_grouper = SingleToDoubleInstructionFeatureExtractor( InstructionLocationsGrouperFeatureExtractor([CALL.opcode]), InstructionLocationsGrouperFeatureExtractor([CALL.opcode]), @@ -158,6 +160,7 @@ def compare_traces( tod_source_analyzer, instruction_changes_analyzer, instruction_usage_analyzers, + currency_changes_analyzer, calls_grouper, ], transactions=(transaction_one, transaction_two), @@ -165,90 +168,90 @@ def compare_traces( ) runner.run() - information_flow_graph_one = build_information_flow_graph( - transaction_one.instructions - ) - information_flow_graph_two = build_information_flow_graph( - transaction_two.instructions - ) + build_information_flow_graph(transaction_one.instructions) + build_information_flow_graph(transaction_two.instructions) - if verbose: - call_tree_normal, call_tree_reverse = runner.get_call_trees() - print(f"Transaction: {hash}") - print("Call tree actual") - print(call_tree_normal) - print("Call tree reverse") - print(call_tree_reverse) + # if verbose: + # call_tree_normal, call_tree_reverse = runner.get_call_trees() + # print(f"Transaction: {hash}") + # print("Call tree actual") + # print(call_tree_normal) + # print("Call tree reverse") + # print(call_tree_reverse) - print("Source to Sink") - print() - all_instructions = transaction_one.instructions - tod_source_instruction = tod_source_analyzer.get_tod_source().instruction_one - changed_instructions = ( - instruction_changes_analyzer.get_instructions_with_different_inputs() - ) - # only memory input changes for CALL/LOGs - potential_sinks = [ - i - for i in changed_instructions - if i.opcode - in [CALL.opcode, 
LOG0.opcode, LOG1.opcode, LOG2.opcode, LOG3.opcode] - and i.memory_input_changes - ] - potential_sink_instructions = [ - change.instruction_one for change in potential_sinks - ] + # print("Source to Sink") + # print() + # all_instructions = transaction_one.instructions + # tod_source_instruction = tod_source_analyzer.get_tod_source().instruction_one + # changed_instructions = ( + # instruction_changes_analyzer.get_instructions_with_different_inputs() + # ) + # # only memory input changes for CALL/LOGs + # potential_sinks = [ + # i + # for i in changed_instructions + # if i.opcode + # in [CALL.opcode, LOG0.opcode, LOG1.opcode, LOG2.opcode, LOG3.opcode] + # and i.memory_input_changes + # ] + # potential_sink_instructions = [ + # change.instruction_one for change in potential_sinks + # ] - tod_source_instruction_index = all_instructions.index(tod_source_instruction) - potential_sink_instruction_indexes = [ - all_instructions.index(instr) for instr in potential_sink_instructions - ] - sink_instruction_index = min(potential_sink_instruction_indexes) - sink_instruction = all_instructions[sink_instruction_index] + # tod_source_instruction_index = all_instructions.index(tod_source_instruction) + # potential_sink_instruction_indexes = [ + # all_instructions.index(instr) for instr in potential_sink_instructions + # ] + # sink_instruction_index = min(potential_sink_instruction_indexes) + # sink_instruction = all_instructions[sink_instruction_index] - print(information_flow_graph_one) - print(information_flow_graph_two) - print(list(ancestors(information_flow_graph_one, sink_instruction.step_index))) + # print(information_flow_graph_one) + # print(information_flow_graph_two) + # print(list(ancestors(information_flow_graph_one, sink_instruction.step_index))) - source_to_sink_contexts: list[CallContext] = [] + # source_to_sink_contexts: list[CallContext] = [] - # NOTE: call contexts will go up and down and repeat themselves - for instr in all_instructions[ - 
tod_source_instruction_index : sink_instruction_index + 1 - ]: - if ( - not source_to_sink_contexts - or instr.call_context is not source_to_sink_contexts[-1] - ): - source_to_sink_contexts.append(instr.call_context) + # # NOTE: call contexts will go up and down and repeat themselves + # for instr in all_instructions[ + # tod_source_instruction_index : sink_instruction_index + 1 + # ]: + # if ( + # not source_to_sink_contexts + # or instr.call_context is not source_to_sink_contexts[-1] + # ): + # source_to_sink_contexts.append(instr.call_context) - """ - TODO: - - the source instruction is not necessarily related to the sink - -> should display all source instructions and the human can match it - - with information flow analysis, we could check which instruction is responsible for the change. - However, this would need to include stack, memory, tcache, calldata, returndata, and storage writes+reads - """ - signature_lookup = SignatureRegistry("http://localhost:8000") - min_depth = min(context.depth for context in source_to_sink_contexts) - source_indent = " " * (tod_source_instruction.call_context.depth - min_depth) - sink_indent = " " * (sink_instruction.call_context.depth - min_depth) - print(f"{source_indent}> {tod_source_instruction}") - for context in source_to_sink_contexts: - # print(context) - signature = ( - signature_lookup.lookup_by_hex(context.calldata[:8].get_hexstring()) - or context.calldata[:8] - ) - indent = " " * (context.depth - min_depth) - print(f"{indent}> {context.code_address}.{signature}") - print(f"{sink_indent}> {sink_instruction}") + # """ + # TODO: + # - the source instruction is not necessarily related to the sink + # -> should display all source instructions and the human can match it + # - with information flow analysis, we could check which instruction is responsible for the change. 
+ # However, this would need to include stack, memory, tcache, calldata, returndata, and storage writes+reads + # """ + # signature_lookup = SignatureRegistry("http://localhost:8000") + # min_depth = min(context.depth for context in source_to_sink_contexts) + # source_indent = " " * (tod_source_instruction.call_context.depth - min_depth) + # sink_indent = " " * (sink_instruction.call_context.depth - min_depth) + # print(f"{source_indent}> {tod_source_instruction}") + # for context in source_to_sink_contexts: + # # print(context) + # signature = ( + # signature_lookup.lookup_by_hex(context.calldata[:8].get_hexstring()) + # or context.calldata[:8] + # ) + # indent = " " * (context.depth - min_depth) + # print(f"{indent}> {context.code_address}.{signature}") + # print(f"{sink_indent}> {sink_instruction}") evaluations: list[Evaluation] = [ SecurifyPropertiesEvaluation( calls_grouper.normal.instruction_groups, # type: ignore calls_grouper.reverse.instruction_groups, # type: ignore ), + FinancialGainLossEvaluation( + currency_changes_analyzer.normal.currency_changes, + currency_changes_analyzer.reverse.currency_changes, + ), TODSourceEvaluation(tod_source_analyzer.get_tod_source()), InstructionDifferencesEvaluation( occurrence_changes=instruction_changes_analyzer.get_instructions_only_executed_by_one_trace(), diff --git a/traces_analyzer/evaluation/financial_gain_loss_evaluation.py b/traces_analyzer/evaluation/financial_gain_loss_evaluation.py new file mode 100644 index 0000000..6963acb --- /dev/null +++ b/traces_analyzer/evaluation/financial_gain_loss_evaluation.py @@ -0,0 +1,109 @@ +from copy import deepcopy +from typing_extensions import override + +from collections import defaultdict +from typing import Sequence, TypedDict +from traces_analyzer.features.extractors.currency_changes import CurrencyChange +from traces_analyzer.evaluation.evaluation import Evaluation +from traces_parser.parser.instructions.instructions import Instruction + +CURRENCY_CHANGES_BY_ADDR = 
class GainsAndLosses(TypedDict):
    """Net currency changes split by sign.

    Both members map an owner address to a currency key to the net change.
    """

    # Entries whose net "change" is strictly positive.
    gains: CURRENCY_CHANGES_BY_ADDR
    # Entries whose net "change" is strictly negative.
    losses: CURRENCY_CHANGES_BY_ADDR


class FinancialGainLossEvaluation(Evaluation):
    """Evaluate who gained and who lost currency in the normal scenario
    compared to the reverse scenario.

    The net difference is computed once in ``__init__`` (see
    ``compute_gains_and_losses``) and is served by both report methods.
    """

    @property
    @override
    def _type_key(self):
        return "financial_gain_loss"

    @property
    @override
    def _type_name(self):
        return "Financial gains and losses"

    def __init__(
        self,
        currency_changes_normal: Sequence[tuple[Instruction, CurrencyChange]],
        currency_changes_reverse: Sequence[tuple[Instruction, CurrencyChange]],
    ):
        """Compute the net gains and losses of normal relative to reverse.

        Args:
            currency_changes_normal: (instruction, change) pairs recorded for
                the normal scenario.
            currency_changes_reverse: (instruction, change) pairs recorded for
                the reverse scenario.
        """
        super().__init__()
        self._gains_and_losses = compute_gains_and_losses(
            currency_changes_normal, currency_changes_reverse
        )

    @override
    def _dict_report(self) -> dict:
        """Return the gains and losses as a plain top-level dict."""
        return dict(self._gains_and_losses)

    @override
    def _cli_report(self) -> str:
        """Render the gains section followed by the losses section.

        Every header and every change occupies its own newline-terminated
        line. Loss entries carry negative amounts; they are printed as
        magnitudes because the line already says "lost".
        """
        lines = ["Gains in normal compared to reverse scenario:\n"]
        for addr, gains in self._gains_and_losses["gains"].items():
            for change in gains.values():
                lines.append(
                    f'> {addr} gained {change["change"]} {change["type"]} {change["token_address"] or "(in Wei)"}\n'
                )

        # Append to the report instead of starting over, and read the
        # "losses" group rather than "gains" a second time.
        lines.append("Losses in normal compared to reverse scenario:\n")
        for addr, losses in self._gains_and_losses["losses"].items():
            for change in losses.values():
                lines.append(
                    f'> {addr} lost {abs(change["change"])} {change["type"]} {change["token_address"] or "(in Wei)"}\n'
                )

        return "".join(lines)
def compute_gains_and_losses(
    changes_normal: Sequence[tuple[Instruction, CurrencyChange]],
    changes_reverse: Sequence[tuple[Instruction, CurrencyChange]],
) -> GainsAndLosses:
    """Split the net difference (normal minus reverse) into gains and losses.

    Args:
        changes_normal: (instruction, change) pairs of the normal scenario.
        changes_reverse: (instruction, change) pairs of the reverse scenario.

    Returns:
        A mapping with the keys "gains" and "losses". Each maps an owner
        address to a currency key to the net change. Entries whose net
        change is exactly zero appear in neither group.
    """
    net_changes = subtract_changes(
        group_by_address(changes_normal),
        group_by_address(changes_reverse),
    )

    gains: dict[str, dict[str, CurrencyChange]] = defaultdict(dict)
    losses: dict[str, dict[str, CurrencyChange]] = defaultdict(dict)
    for addr, changes in net_changes.items():
        for key, change in changes.items():
            if change["change"] > 0:
                gains[addr][key] = change
            elif change["change"] < 0:
                losses[addr][key] = change

    return {
        "gains": gains,
        "losses": losses,
    }


def group_by_address(
    changes: Sequence[tuple[Instruction, CurrencyChange]],
) -> CURRENCY_CHANGES_BY_ADDR:
    """Sum up the changes per owner address and per currency.

    The currency key is the change's "type" concatenated with its
    "token_address" (or the empty string when that is None). The
    instruction of each pair is ignored. Input changes are copied, never
    mutated. A sum of zero is kept as an entry.
    """
    groups: CURRENCY_CHANGES_BY_ADDR = defaultdict(dict)

    for _, change in changes:
        per_currency = groups[change["owner"]]
        key = change["type"] + (change["token_address"] or "")
        if key in per_currency:
            per_currency[key]["change"] += change["change"]
        else:
            # Copy so that accumulating does not modify the caller's dict.
            per_currency[key] = deepcopy(change)

    return groups


def subtract_changes(
    base: CURRENCY_CHANGES_BY_ADDR, operand: CURRENCY_CHANGES_BY_ADDR
) -> CURRENCY_CHANGES_BY_ADDR:
    """Return ``base - operand`` per address and currency key.

    Neither argument is mutated. An entry that exists only in ``operand``
    shows up in the result with its amount negated.
    """
    result = deepcopy(base)

    for addr, changes in operand.items():
        # setdefault instead of result[addr]: ``base`` may be a plain dict
        # that does not contain ``addr``, not necessarily a defaultdict.
        per_currency = result.setdefault(addr, {})
        for key, change in changes.items():
            if key in per_currency:
                per_currency[key]["change"] -= change["change"]
            else:
                negated = deepcopy(change)
                negated["change"] *= -1
                per_currency[key] = negated

    return result
class CURRENCY:
    """Names of the currency kinds a change can refer to."""

    ETHER = "ETHER"


class CurrencyChange(TypedDict):
    """One signed change of a single owner's holdings in a single currency."""

    # Type of the currency, e.g. ETHER or ERC-20, ...
    type: str
    # ID for the currency. For Ether this is None, for tokens this is the
    # storage address that emitted the LOG
    token_address: str | None
    # Address for which a change occurred
    owner: str
    # Positive or negative change
    change: int


class CurrencyChangesFeatureExtractor(SingleInstructionFeatureExtractor):
    """Track all currency changes.

    Each non-reverted CALL or CALLCODE contributes two entries to
    ``currency_changes``: a negative one for the caller, then a positive one
    for the called code address. LOG instructions are recognised but not
    evaluated yet.
    """

    def __init__(self) -> None:
        super().__init__()
        # (instruction, change) pairs in the order they were recorded.
        self.currency_changes: list[tuple[Instruction, CurrencyChange]] = []

    @override
    def on_instruction(self, instruction: Instruction):
        """Record the value transfer of ``instruction``, if it performs one.

        Instructions whose call context is marked as reverted are skipped.
        """
        if instruction.call_context.reverted:
            return

        if isinstance(instruction, (CALL, CALLCODE)):
            payer = instruction.child_caller
            payee = instruction.child_code_address
            amount = instruction.child_value.get_hexstring().as_int()
            # The payer entry is recorded first, then the payee entry.
            for party, signed_amount in ((payer, -amount), (payee, amount)):
                entry: CurrencyChange = {
                    "type": CURRENCY.ETHER,
                    "token_address": None,
                    "owner": party.with_prefix(),
                    "change": signed_amount,
                }
                self.currency_changes.append((instruction, entry))

        if isinstance(instruction, (LOG0, LOG1, LOG2, LOG3, LOG4)):
            # TODO
            pass