Skip to content

Commit

Permalink
Add client type option (async or sync)
Browse files Browse the repository at this point in the history
  • Loading branch information
overcat committed Nov 28, 2024
1 parent 3764d62 commit a7c428b
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions stellar_contract_bindings/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,23 @@ def render_info():
)


def render_imports():
def render_imports(client_type: str = "both"):
template = """
from enum import IntEnum, Enum
from typing import Dict, List, Tuple, Optional, Union
from stellar_sdk import scval, xdr, Address, MuxedAccount, Keypair
from stellar_sdk.contract import AssembledTransaction, AssembledTransactionAsync, ContractClient, ContractClientAsync
{%- if client_type == "sync" or client_type == "both" %}
from stellar_sdk.contract import AssembledTransaction, ContractClient
{%- endif %}
{%- if client_type == "async" or client_type == "both" %}
from stellar_sdk.contract import AssembledTransactionAsync, ContractClientAsync
{%- endif %}
NULL_ACCOUNT = "GAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAWHF"
"""
return template
rendered_code = Template(template).render(client_type=client_type)
return rendered_code


def render_enum(entry: xdr.SCSpecUDTEnumV0):
Expand Down Expand Up @@ -479,8 +485,9 @@ def __hash__(self) -> int:
return kind_enum_rendered_code + "\n" + union_rendered_code


def render_client(entries: List[xdr.SCSpecFunctionV0]):
def render_client(entries: List[xdr.SCSpecFunctionV0], client_type: str):
template = '''
{%- if client_type == "sync" or client_type == "both" %}
class Client(ContractClient):
{%- for entry in entries %}
def {{ entry.name.sc_symbol.decode() }}(self, {% for param in entry.inputs %}{{ param.name.decode() }}: {{ to_py_type(param.type, True) }}, {% endfor %} source: Union[str, MuxedAccount] = NULL_ACCOUNT, signer: Optional[Keypair] = None, base_fee: int = 100, transaction_timeout: int = 300, submit_timeout: int = 30, simulate: bool = True, restore: bool = True) -> AssembledTransaction[{{ parse_result_type(entry.outputs) }}]:
Expand All @@ -489,7 +496,9 @@ def {{ entry.name.sc_symbol.decode() }}(self, {% for param in entry.inputs %}{{
{%- endif %}
return self.invoke('{{ entry.name.sc_symbol.decode() }}', [{% for param in entry.inputs %}{{ to_scval(param.type, param.name.decode()) }}{% if not loop.last %}, {% endif %}{% endfor %}], parse_result_xdr_fn={{ parse_result_xdr_fn(entry.outputs) }}, source = source, signer = signer, base_fee = base_fee, transaction_timeout = transaction_timeout, submit_timeout = submit_timeout, simulate = simulate, restore = restore)
{%- endfor %}
{%- endif %}
{%- if client_type == "async" or client_type == "both" %}
class ClientAsync(ContractClientAsync):
{%- for entry in entries %}
async def {{ entry.name.sc_symbol.decode() }}(self, {% for param in entry.inputs %}{{ param.name.decode() }}: {{ to_py_type(param.type, True) }}, {% endfor %} source: Union[str, MuxedAccount] = NULL_ACCOUNT, signer: Optional[Keypair] = None, base_fee: int = 100, transaction_timeout: int = 300, submit_timeout: int = 30, simulate: bool = True, restore: bool = True) -> AssembledTransactionAsync[{{ parse_result_type(entry.outputs) }}]:
Expand All @@ -498,6 +507,7 @@ async def {{ entry.name.sc_symbol.decode() }}(self, {% for param in entry.inputs
{%- endif %}
return await self.invoke('{{ entry.name.sc_symbol.decode() }}', [{% for param in entry.inputs %}{{ to_scval(param.type, param.name.decode()) }}{% if not loop.last %}, {% endif %}{% endfor %}], parse_result_xdr_fn={{ parse_result_xdr_fn(entry.outputs) }}, source = source, signer = signer, base_fee = base_fee, transaction_timeout = transaction_timeout, submit_timeout = submit_timeout, simulate = simulate, restore = restore)
{%- endfor %}
{%- endif %}
'''

def parse_result_type(output: List[xdr.SCSpecTypeDef]):
Expand All @@ -524,18 +534,19 @@ def parse_result_xdr_fn(output: List[xdr.SCSpecTypeDef]):
to_scval=to_scval,
parse_result_type=parse_result_type,
parse_result_xdr_fn=parse_result_xdr_fn,
client_type=client_type,
)
return client_rendered_code


def generate_binding(wasm: bytes) -> str:
def generate_binding(wasm: bytes, client_type: str) -> str:
generated = []
generated.append(render_info())

metadata = parse_contract_metadata(wasm)
specs = metadata.spec

generated.append(render_imports())
generated.append(render_imports(client_type))

for spec in specs:
if spec.kind == xdr.SCSpecEntryKind.SC_SPEC_ENTRY_UDT_ENUM_V0:
Expand All @@ -556,7 +567,7 @@ def generate_binding(wasm: bytes) -> str:
if spec.kind == xdr.SCSpecEntryKind.SC_SPEC_ENTRY_FUNCTION_V0
and not spec.function_v0.name.sc_symbol.decode().startswith("__")
]
generated.append(render_client(function_specs))
generated.append(render_client(function_specs, client_type))
return "\n".join(generated)


Expand All @@ -572,7 +583,13 @@ def generate_binding(wasm: bytes) -> str:
default=None,
help="Output directory for generated bindings, defaults to current directory",
)
def command(contract_id: str, rpc_url: str, output: str):
@click.option(
"--client-type",
type=click.Choice(["sync", "async", "both"], case_sensitive=False),
default="both",
help="Client type to generate, defaults to both sync and async",
)
def command(contract_id: str, rpc_url: str, output: str, client_type: str):
"""Generate Python bindings for a Soroban contract"""
if not StrKey.is_valid_contract(contract_id):
click.echo(f"Invalid contract ID: {contract_id}", err=True)
Expand All @@ -591,7 +608,7 @@ def command(contract_id: str, rpc_url: str, output: str):
raise click.Abort()

click.echo("Generating Python bindings")
generated = generate_binding(wasm_code)
generated = generate_binding(wasm_code, client_type=client_type)
if not os.path.exists(output):
os.makedirs(output)
output_path = os.path.join(output, "bindings.py")
Expand All @@ -607,5 +624,5 @@ def command(contract_id: str, rpc_url: str, output: str):
wasm_file = "/Users/overcat/repo/lightsail/stellar-contract-bindings/tests/contracts/target/wasm32-unknown-unknown/release/python.wasm"
with open(wasm_file, "rb") as f:
wasm = f.read()
generated_code = generate_binding(wasm)
generated_code = generate_binding(wasm, "both")
print(generated_code)

0 comments on commit a7c428b

Please sign in to comment.