diff --git a/examples/dump.py b/examples/dump.py index e0cd9f7..57b2eee 100644 --- a/examples/dump.py +++ b/examples/dump.py @@ -10,9 +10,10 @@ from loguru import logger sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from pythclient.pythclient import PythClient, V2_PROGRAM_KEY, V2_FIRST_MAPPING_ACCOUNT_KEY # noqa +from pythclient.pythclient import PythClient # noqa from pythclient.ratelimit import RateLimit # noqa from pythclient.pythaccounts import PythPriceAccount # noqa +from pythclient.utils import get_key # noqa logger.enable("pythclient") @@ -32,9 +33,11 @@ def set_to_exit(sig: Any, frame: Any): async def main(): global to_exit use_program = len(sys.argv) >= 2 and sys.argv[1] == "program" + v2_first_mapping_account_key = get_key("devnet", "mapping") + v2_program_key = get_key("devnet", "program") async with PythClient( - first_mapping_account_key=V2_FIRST_MAPPING_ACCOUNT_KEY, - program_key=V2_PROGRAM_KEY if use_program else None, + first_mapping_account_key=v2_first_mapping_account_key, + program_key=v2_program_key if use_program else None, ) as c: await c.refresh_all_prices() products = await c.get_products() @@ -57,7 +60,7 @@ async def main(): await ws.connect() if use_program: print("Subscribing to program account") - await ws.program_subscribe(V2_PROGRAM_KEY, await c.get_all_accounts()) + await ws.program_subscribe(v2_program_key, await c.get_all_accounts()) else: print("Subscribing to all prices") for account in all_prices: @@ -88,7 +91,7 @@ async def main(): print("Unsubscribing...") if use_program: - await ws.program_unsubscribe(V2_PROGRAM_KEY) + await ws.program_unsubscribe(v2_program_key) else: for account in all_prices: await ws.unsubscribe(account) diff --git a/pythclient/pythclient.py b/pythclient/pythclient.py index da75ce2..74279b5 100644 --- a/pythclient/pythclient.py +++ b/pythclient/pythclient.py @@ -14,17 +14,13 @@ from .pythaccounts import PythAccount, PythMappingAccount, PythProductAccount, PythPriceAccount from . import exceptions, config, ratelimit -V1_FIRST_MAPPING_ACCOUNT_KEY = "ArppEFcsybCLE8CRtQJLQ9tLv2peGmQoKWFuiUWm4KBP" -V2_FIRST_MAPPING_ACCOUNT_KEY = "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2" -V2_PROGRAM_KEY = "gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s" - class PythClient: def __init__(self, *, solana_client: Optional[SolanaClient] = None, solana_endpoint: str = SOLANA_DEVNET_HTTP_ENDPOINT, solana_ws_endpoint: str = SOLANA_DEVNET_WS_ENDPOINT, - first_mapping_account_key: str = V1_FIRST_MAPPING_ACCOUNT_KEY, + first_mapping_account_key: str, program_key: Optional[str] = None, aiohttp_client_session: Optional[aiohttp.ClientSession] = None) -> None: self._first_mapping_account_key = SolanaPublicKey(first_mapping_account_key) diff --git a/pythclient/utils.py b/pythclient/utils.py new file mode 100644 index 0000000..36384ff --- /dev/null +++ b/pythclient/utils.py @@ -0,0 +1,35 @@ +import ast +import dns.resolver +from loguru import logger +from typing import Optional + +DEFAULT_VERSION = "v2" + + +# Retrieving keys via DNS TXT records should not be considered secure and is provided as a convenience only. +# Accounts should be stored locally and verified before being used for production. +def get_key(network: str, type: str, version: str = DEFAULT_VERSION) -> Optional[str]: + """ + Get the program or mapping keys from dns TXT records. + + Example dns records: + + devnet-program-v2.pyth.network + mainnet-program-v2.pyth.network + testnet-mapping-v2.pyth.network + """ + url = f"{network}-{type}-{version}.pyth.network" + try: + answer = dns.resolver.resolve(url, "TXT") + except dns.resolver.NXDOMAIN: + logger.error("TXT record for {} not found", url) + return "" + if len(answer) != 1: + logger.error("Invalid number of records returned for {}!", url) + return "" + # Example of the raw_key: + # "program=FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH" + raw_key = ast.literal_eval(list(answer)[0].to_text()) + # program=FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH" + _, key = raw_key.split("=", 1) + return key diff --git a/setup.py b/setup.py index 10c20d5..e4d66c1 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from setuptools import setup -requirements = ['aiodns', 'aiohttp>=3.7.4', 'backoff', 'base58', 'flake8', 'loguru'] +requirements = ['aiodns', 'aiohttp>=3.7.4', 'backoff', 'base58', 'dnspython', 'flake8', 'loguru'] setup( name='pythclient', diff --git a/tests/test_pyth_client.py b/tests/test_pyth_client.py index b6c22a2..7c467e3 100644 --- a/tests/test_pyth_client.py +++ b/tests/test_pyth_client.py @@ -6,7 +6,7 @@ _ACCOUNT_HEADER_BYTES, _VERSION_2, PythMappingAccount, PythPriceType, PythProductAccount, PythPriceAccount ) -from pythclient.pythclient import PythClient, V2_FIRST_MAPPING_ACCOUNT_KEY, V2_PROGRAM_KEY, WatchSession +from pythclient.pythclient import PythClient, WatchSession from pythclient.solana import ( SolanaClient, SolanaCommitment, @@ -24,6 +24,9 @@ # 2) these values are used in get_account_info_resp() and get_program_accounts_resp() # and so if they are passed in as fixtures, the functions will complain for the args # while mocking the respective functions +V2_FIRST_MAPPING_ACCOUNT_KEY = 'BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2' +V2_PROGRAM_KEY = 'gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s' + BCH_PRODUCT_ACCOUNT_KEY = '89GseEmvNkzAMMEXcW9oTYzqRPXTsJ3BmNerXmgA1osV' BCH_PRICE_ACCOUNT_KEY = '4EQrNZYk5KR1RnjyzbaaRbHsv8VqZWzSUtvx58wLsZbj' diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..357817b --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,89 @@ +from _pytest.logging import LogCaptureFixture +import pytest + +from pytest_mock import MockerFixture + +from mock import Mock + +import dns.resolver +import dns.rdatatype +import dns.rdataclass +import dns.message +import dns.rrset +import dns.flags + +from pythclient.utils import get_key + + +@pytest.fixture() +def answer_program() -> dns.resolver.Answer: + qname = dns.name.Name(labels=(b'devnet-program-v2', b'pyth', b'network', b'')) + rdtype = dns.rdatatype.TXT + rdclass = dns.rdataclass.IN + response = dns.message.QueryMessage(id=0) + response.flags = dns.flags.QR + rrset_qn = dns.rrset.from_text(qname, 100, rdclass, rdtype) + rrset_ans = dns.rrset.from_text(qname, 100, rdclass, rdtype, '"program=gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s"') + response.question = [rrset_qn] + response.answer = [rrset_ans] + answer = dns.resolver.Answer( + qname=qname, rdtype=rdtype, rdclass=rdclass, response=response) + answer.rrset = rrset_ans + return answer + + +@pytest.fixture() +def answer_mapping() -> dns.resolver.Answer: + qname = dns.name.Name(labels=(b'devnet-mapping-v2', b'pyth', b'network', b'')) + rdtype = dns.rdatatype.TXT + rdclass = dns.rdataclass.IN + response = dns.message.QueryMessage(id=0) + response.flags = dns.flags.QR + rrset_qn = dns.rrset.from_text(qname, 100, rdclass, rdtype) + rrset_ans = dns.rrset.from_text(qname, 100, rdclass, rdtype, '"mapping=BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2"') + response.question = [rrset_qn] + response.answer = [rrset_ans] + answer = dns.resolver.Answer( + qname=qname, rdtype=rdtype, rdclass=rdclass, response=response) + answer.rrset = rrset_ans + return answer + + +@pytest.fixture() +def mock_dns_resolver_resolve(mocker: MockerFixture) -> Mock: + mock = Mock() + mocker.patch('dns.resolver.resolve', side_effect=mock) + return mock + + +def test_utils_get_program_key(mock_dns_resolver_resolve: Mock, answer_program: dns.resolver.Answer) -> None: + mock_dns_resolver_resolve.return_value = answer_program + program_key = get_key("devnet", "program") + assert program_key == "gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s" + + +def test_utils_get_mapping_key(mock_dns_resolver_resolve: Mock, answer_mapping: dns.resolver.Answer) -> None: + mock_dns_resolver_resolve.return_value = answer_mapping + mapping_key = get_key("devnet", "mapping") + assert mapping_key == "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2" + + +def test_utils_get_mapping_key_not_found(mock_dns_resolver_resolve: Mock, + answer_mapping: dns.resolver.Answer, + caplog: LogCaptureFixture) -> None: + mock_dns_resolver_resolve.side_effect = dns.resolver.NXDOMAIN + exc_message = f'TXT record for {str(answer_mapping.response.canonical_name())[:-1]} not found' + key = get_key("devnet", "mapping") + assert exc_message in caplog.text + assert key == "" + + +def test_utils_get_mapping_key_invalid_number(mock_dns_resolver_resolve: Mock, + answer_mapping: dns.resolver.Answer, + caplog: LogCaptureFixture) -> None: + answer_mapping.rrset = None + mock_dns_resolver_resolve.return_value = answer_mapping + exc_message = f'Invalid number of records returned for {str(answer_mapping.response.canonical_name())[:-1]}!' + key = get_key("devnet", "mapping") + assert exc_message in caplog.text + assert key == ""