Skip to content

Commit

Permalink
add type annotations to generator python code (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
wfraser authored Apr 1, 2024
1 parent 6740072 commit 0e4b257
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cargo-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5.0.0
with:
python-version: 3.8
python-version: "3.10"

- name: Install Python dependencies
run: |
Expand Down
67 changes: 38 additions & 29 deletions generator/rust.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from abc import ABC
from contextlib import contextmanager
from typing import Optional, Iterator

from stone import ir
from stone.backend import CodeBackend
Expand Down Expand Up @@ -31,25 +33,32 @@
EXTRA_DISPLAY_TYPES = ["auth::RateLimitReason"]


class RustHelperBackend(CodeBackend):
def _arg_list(args: list[str]) -> str:
arg_list = ''
for arg in args:
arg_list += (', ' if arg_list != '' else '') + arg
return arg_list


class RustHelperBackend(CodeBackend, ABC):
"""
A superclass for RustGenerator and TestGenerator to contain some common rust-generation methods.
"""

def _dent_len(self):
def _dent_len(self) -> int:
if self.tabs_for_indents:
return 4 * self.cur_indent
else:
return self.cur_indent

def _arg_list(self, args):
arg_list = ''
for arg in args:
arg_list += (', ' if arg_list != '' else '') + arg
return arg_list

@contextmanager
def emit_rust_function_def(self, name, args=None, return_type=None, access=None):
def emit_rust_function_def(
self,
name: str,
args: Optional[list[str]] = None,
return_type: Optional[str] = None,
access: Optional[str] = None,
) -> Iterator[None]:
"""
A Rust function definition context manager.
"""
Expand All @@ -60,7 +69,7 @@ def emit_rust_function_def(self, name, args=None, return_type=None, access=None)
else:
access += ' '
ret = f' -> {return_type}' if return_type is not None else ''
one_line = f'{access}fn {name}({self._arg_list(args)}){ret} {{'
one_line = f'{access}fn {name}({_arg_list(args)}){ret} {{'
if self._dent_len() + len(one_line) < 100:
# one-line version
self.emit(one_line)
Expand All @@ -76,99 +85,99 @@ def emit_rust_function_def(self, name, args=None, return_type=None, access=None)
yield
self.emit('}')

def emit_rust_fn_call(self, func_name, args, end=None):
def emit_rust_fn_call(self, func_name: str, args: list[str], end: Optional[str] = None) -> None:
"""
Emit a Rust function call. Wraps arguments to multiple lines if it gets too long.
If `end` is None, the call ends without any semicolon.
"""
if end is None:
end = ''
one_line = f'{func_name}({self._arg_list(args)}){end}'
one_line = f'{func_name}({_arg_list(args)}){end}'
if self._dent_len() + len(one_line) < 100:
self.emit(one_line)
else:
self.emit(func_name + '(')
with self.indent():
for i, arg in enumerate(args):
self.emit(arg + (',' if i+1 < len(args) else (')' + end)))
self.emit(arg + (',' if i + 1 < len(args) else (')' + end)))

def is_enum_type(self, typ):
def is_enum_type(self, typ: ir.DataType) -> bool:
return isinstance(typ, ir.Union) or \
(isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes())

def is_nullary_struct(self, typ):
def is_nullary_struct(self, typ: ir.DataType) -> bool:
return isinstance(typ, ir.Struct) and not typ.all_fields

def is_closed_union(self, typ):
def is_closed_union(self, typ: ir.DataType) -> bool:
return (isinstance(typ, ir.Union) and typ.closed) \
or (isinstance(typ, ir.Struct)
and typ.has_enumerated_subtypes() and not typ.is_catch_all())

def get_enum_variants(self, typ):
def get_enum_variants(self, typ: ir.DataType) -> list[ir.StructField]:
if isinstance(typ, ir.Union):
return typ.all_fields
elif isinstance(typ, ir.Struct) and typ.has_enumerated_subtypes():
return typ.get_enumerated_subtypes()
else:
return []

def namespace_name(self, ns):
def namespace_name(self, ns: ir.ApiNamespace) -> str:
return self.namespace_name_raw(ns.name)

def namespace_name_raw(self, ns_name):
def namespace_name_raw(self, ns_name: str) -> str:
name = fmt_underscores(ns_name)
if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE:
name = 'dbx_' + name
return name

def struct_name(self, struct):
def struct_name(self, struct: ir.Struct) -> str:
name = fmt_pascal(struct.name)
if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE:
name += 'Struct'
return name

def enum_name(self, union):
def enum_name(self, union: ir.DataType) -> str:
name = fmt_pascal(union.name)
if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE:
name += 'Union'
return name

def field_name(self, field):
def field_name(self, field: ir.StructField) -> str:
return self.field_name_raw(field.name)

def field_name_raw(self, name):
def field_name_raw(self, name: str) -> str:
name = fmt_underscores(name)
if name in RUST_RESERVED_WORDS:
name += '_field'
return name

def enum_variant_name(self, field):
def enum_variant_name(self, field: ir.UnionField) -> str:
return self.enum_variant_name_raw(field.name)

def enum_variant_name_raw(self, name):
def enum_variant_name_raw(self, name: str) -> str:
name = fmt_pascal(name)
if name in RUST_RESERVED_WORDS:
name += 'Variant'
return name

def route_name(self, route):
def route_name(self, route: ir.ApiRoute) -> str:
return self.route_name_raw(route.name, route.version)

def route_name_raw(self, name, version):
def route_name_raw(self, name: str, version: int) -> str:
name = fmt_underscores(name)
if version > 1:
name = f'{name}_v{version}'
if name in RUST_RESERVED_WORDS:
name = 'do_' + name
return name

def alias_name(self, alias):
def alias_name(self, alias: ir.Alias) -> str:
name = fmt_pascal(alias.name)
if name in RUST_RESERVED_WORDS + RUST_GLOBAL_NAMESPACE:
name += 'Alias'
return name

def rust_type(self, typ, current_namespace, no_qualify=False, crate='crate'):
def rust_type(self, typ: ir.DataType, current_namespace: str, no_qualify: bool = False, crate: str ='crate') -> str:
if isinstance(typ, ir.Nullable):
t = self.rust_type(typ.data_type, current_namespace, no_qualify, crate)
return f'Option<{t}>'
Expand Down
Loading

0 comments on commit 0e4b257

Please sign in to comment.