Skip to content

Commit

Permalink
use dns resolver to get account key (#12)
Browse files Browse the repository at this point in the history
* use dns resolver to get account key

* add comment for get_key()

* remove default value for first_mapping_account_key in PythClient

* remove unused import Mock

* add test for utils

* add dnspython dependency

* change fixture name

* change fixture name in arguments
  • Loading branch information
cctdaniel authored Dec 6, 2021
1 parent d017528 commit bdd32b4
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 12 deletions.
13 changes: 8 additions & 5 deletions examples/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
from loguru import logger

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pythclient.pythclient import PythClient, V2_PROGRAM_KEY, V2_FIRST_MAPPING_ACCOUNT_KEY # noqa
from pythclient.pythclient import PythClient # noqa
from pythclient.ratelimit import RateLimit # noqa
from pythclient.pythaccounts import PythPriceAccount # noqa
from pythclient.utils import get_key # noqa

logger.enable("pythclient")

Expand All @@ -32,9 +33,11 @@ def set_to_exit(sig: Any, frame: Any):
async def main():
global to_exit
use_program = len(sys.argv) >= 2 and sys.argv[1] == "program"
v2_first_mapping_account_key = get_key("devnet", "mapping")
v2_program_key = get_key("devnet", "program")
async with PythClient(
first_mapping_account_key=V2_FIRST_MAPPING_ACCOUNT_KEY,
program_key=V2_PROGRAM_KEY if use_program else None,
first_mapping_account_key=v2_first_mapping_account_key,
program_key=v2_program_key if use_program else None,
) as c:
await c.refresh_all_prices()
products = await c.get_products()
Expand All @@ -57,7 +60,7 @@ async def main():
await ws.connect()
if use_program:
print("Subscribing to program account")
await ws.program_subscribe(V2_PROGRAM_KEY, await c.get_all_accounts())
await ws.program_subscribe(v2_program_key, await c.get_all_accounts())
else:
print("Subscribing to all prices")
for account in all_prices:
Expand Down Expand Up @@ -88,7 +91,7 @@ async def main():

print("Unsubscribing...")
if use_program:
await ws.program_unsubscribe(V2_PROGRAM_KEY)
await ws.program_unsubscribe(v2_program_key)
else:
for account in all_prices:
await ws.unsubscribe(account)
Expand Down
6 changes: 1 addition & 5 deletions pythclient/pythclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@
from .pythaccounts import PythAccount, PythMappingAccount, PythProductAccount, PythPriceAccount
from . import exceptions, config, ratelimit

V1_FIRST_MAPPING_ACCOUNT_KEY = "ArppEFcsybCLE8CRtQJLQ9tLv2peGmQoKWFuiUWm4KBP"
V2_FIRST_MAPPING_ACCOUNT_KEY = "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2"
V2_PROGRAM_KEY = "gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s"


class PythClient:
def __init__(self, *,
solana_client: Optional[SolanaClient] = None,
solana_endpoint: str = SOLANA_DEVNET_HTTP_ENDPOINT,
solana_ws_endpoint: str = SOLANA_DEVNET_WS_ENDPOINT,
first_mapping_account_key: str = V1_FIRST_MAPPING_ACCOUNT_KEY,
first_mapping_account_key: str,
program_key: Optional[str] = None,
aiohttp_client_session: Optional[aiohttp.ClientSession] = None) -> None:
self._first_mapping_account_key = SolanaPublicKey(first_mapping_account_key)
Expand Down
35 changes: 35 additions & 0 deletions pythclient/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import ast
import dns.resolver
from loguru import logger
from typing import Optional

DEFAULT_VERSION = "v2"


# Retrieving keys via DNS TXT records should not be considered secure and is provided as a convenience only.
# Accounts should be stored locally and verified before being used for production.
def get_key(network: str, type: str, version: str = DEFAULT_VERSION) -> Optional[str]:
"""
Get the program or mapping keys from dns TXT records.
Example dns records:
devnet-program-v2.pyth.network
mainnet-program-v2.pyth.network
testnet-mapping-v2.pyth.network
"""
url = f"{network}-{type}-{version}.pyth.network"
try:
answer = dns.resolver.resolve(url, "TXT")
except dns.resolver.NXDOMAIN:
logger.error("TXT record for {} not found", url)
return ""
if len(answer) != 1:
logger.error("Invalid number of records returned for {}!", url)
return ""
# Example of the raw_key:
# "program=FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH"
raw_key = ast.literal_eval(list(answer)[0].to_text())
# program=FsJ3A3u2vn5cTVofAjvy6y5kwABJAqYWpe4975bi2epH"
_, key = raw_key.split("=", 1)
return key
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from setuptools import setup

requirements = ['aiodns', 'aiohttp>=3.7.4', 'backoff', 'base58', 'flake8', 'loguru']
requirements = ['aiodns', 'aiohttp>=3.7.4', 'backoff', 'base58', 'dnspython', 'flake8', 'loguru']

setup(
name='pythclient',
Expand Down
5 changes: 4 additions & 1 deletion tests/test_pyth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
_ACCOUNT_HEADER_BYTES, _VERSION_2, PythMappingAccount, PythPriceType, PythProductAccount, PythPriceAccount
)

from pythclient.pythclient import PythClient, V2_FIRST_MAPPING_ACCOUNT_KEY, V2_PROGRAM_KEY, WatchSession
from pythclient.pythclient import PythClient, WatchSession
from pythclient.solana import (
SolanaClient,
SolanaCommitment,
Expand All @@ -24,6 +24,9 @@
# 2) these values are used in get_account_info_resp() and get_program_accounts_resp()
# and so if they are passed in as fixtures, the functions will complain for the args
# while mocking the respective functions
V2_FIRST_MAPPING_ACCOUNT_KEY = 'BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2'
V2_PROGRAM_KEY = 'gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s'

BCH_PRODUCT_ACCOUNT_KEY = '89GseEmvNkzAMMEXcW9oTYzqRPXTsJ3BmNerXmgA1osV'
BCH_PRICE_ACCOUNT_KEY = '4EQrNZYk5KR1RnjyzbaaRbHsv8VqZWzSUtvx58wLsZbj'

Expand Down
89 changes: 89 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from _pytest.logging import LogCaptureFixture
import pytest

from pytest_mock import MockerFixture

from mock import Mock

import dns.resolver
import dns.rdatatype
import dns.rdataclass
import dns.message
import dns.rrset
import dns.flags

from pythclient.utils import get_key


@pytest.fixture()
def answer_program() -> dns.resolver.Answer:
qname = dns.name.Name(labels=(b'devnet-program-v2', b'pyth', b'network', b''))
rdtype = dns.rdatatype.TXT
rdclass = dns.rdataclass.IN
response = dns.message.QueryMessage(id=0)
response.flags = dns.flags.QR
rrset_qn = dns.rrset.from_text(qname, 100, rdclass, rdtype)
rrset_ans = dns.rrset.from_text(qname, 100, rdclass, rdtype, '"program=gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s"')
response.question = [rrset_qn]
response.answer = [rrset_ans]
answer = dns.resolver.Answer(
qname=qname, rdtype=rdtype, rdclass=rdclass, response=response)
answer.rrset = rrset_ans
return answer


@pytest.fixture()
def answer_mapping() -> dns.resolver.Answer:
qname = dns.name.Name(labels=(b'devnet-mapping-v2', b'pyth', b'network', b''))
rdtype = dns.rdatatype.TXT
rdclass = dns.rdataclass.IN
response = dns.message.QueryMessage(id=0)
response.flags = dns.flags.QR
rrset_qn = dns.rrset.from_text(qname, 100, rdclass, rdtype)
rrset_ans = dns.rrset.from_text(qname, 100, rdclass, rdtype, '"mapping=BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2"')
response.question = [rrset_qn]
response.answer = [rrset_ans]
answer = dns.resolver.Answer(
qname=qname, rdtype=rdtype, rdclass=rdclass, response=response)
answer.rrset = rrset_ans
return answer


@pytest.fixture()
def mock_dns_resolver_resolve(mocker: MockerFixture) -> Mock:
mock = Mock()
mocker.patch('dns.resolver.resolve', side_effect=mock)
return mock


def test_utils_get_program_key(mock_dns_resolver_resolve: Mock, answer_program: dns.resolver.Answer) -> None:
mock_dns_resolver_resolve.return_value = answer_program
program_key = get_key("devnet", "program")
assert program_key == "gSbePebfvPy7tRqimPoVecS2UsBvYv46ynrzWocc92s"


def test_utils_get_mapping_key(mock_dns_resolver_resolve: Mock, answer_mapping: dns.resolver.Answer) -> None:
mock_dns_resolver_resolve.return_value = answer_mapping
mapping_key = get_key("devnet", "mapping")
assert mapping_key == "BmA9Z6FjioHJPpjT39QazZyhDRUdZy2ezwx4GiDdE2u2"


def test_utils_get_mapping_key_not_found(mock_dns_resolver_resolve: Mock,
answer_mapping: dns.resolver.Answer,
caplog: LogCaptureFixture) -> None:
mock_dns_resolver_resolve.side_effect = dns.resolver.NXDOMAIN
exc_message = f'TXT record for {str(answer_mapping.response.canonical_name())[:-1]} not found'
key = get_key("devnet", "mapping")
assert exc_message in caplog.text
assert key == ""


def test_utils_get_mapping_key_invalid_number(mock_dns_resolver_resolve: Mock,
answer_mapping: dns.resolver.Answer,
caplog: LogCaptureFixture) -> None:
answer_mapping.rrset = None
mock_dns_resolver_resolve.return_value = answer_mapping
exc_message = f'Invalid number of records returned for {str(answer_mapping.response.canonical_name())[:-1]}!'
key = get_key("devnet", "mapping")
assert exc_message in caplog.text
assert key == ""

0 comments on commit bdd32b4

Please sign in to comment.