Skip to content

Commit

Permalink
Improve formatting for enums
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Hallett committed Oct 8, 2023
1 parent fa161ab commit ba4d6c2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 9 deletions.
2 changes: 1 addition & 1 deletion clientele/generators/standard/generators/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def generate_parameters(
if param.get("$ref"):
# Get the actual parameter it is referencing
param = utils.get_param_from_ref(spec=self.spec, param=param)
clean_key = utils.clean_prop(param["name"])
clean_key = utils.snake_case_prop(param["name"])
if clean_key in param_keys:
continue
in_ = param.get("in")
Expand Down
10 changes: 4 additions & 6 deletions clientele/generators/standard/generators/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate_enum_properties(self, properties: dict) -> str:
for arg, arg_details in properties.items():
content = (
content
+ f""" {utils.clean_prop(arg.upper())} = {utils.get_type(arg_details)}\n"""
+ f""" {utils.snake_case_prop(arg.upper())} = {utils.get_type(arg_details)}\n"""
)
return content

Expand All @@ -47,7 +47,7 @@ def generate_headers_class(self, properties: dict, func_name: str) -> str:
template = writer.templates.get_template("schema_class.jinja2")
class_name = f"{utils.class_name_titled(func_name)}Headers"
string_props = "\n".join(
f' {utils.clean_prop(k)}: {v} = pydantic.Field(serialization_alias="{k}")'
f' {utils.snake_case_prop(k)}: {v} = pydantic.Field(serialization_alias="{k}")'
for k, v in properties.items()
)
content = template.render(
Expand All @@ -69,10 +69,8 @@ def generate_class_properties(
for arg, arg_details in properties.items():
arg_type = utils.get_type(arg_details)
is_optional = required and arg not in required
content = (
content
+ f""" {utils.clean_prop(arg)}: {is_optional and f"typing.Optional[{arg_type}]" or arg_type}\n"""
)
type_string = is_optional and f"typing.Optional[{arg_type}]" or arg_type
content = content + f""" {utils.snake_case_prop(arg)}: {type_string}\n"""
return content

def generate_input_class(self, schema: dict) -> None:
Expand Down
16 changes: 14 additions & 2 deletions clientele/generators/standard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,37 @@ def class_name_titled(input_str: str) -> str:
"""
# Capitalize the first letter always
input_str = input_str[:1].title() + input_str[1:]
# Remove any bad characters with an empty space
for badstr in [".", "-", "_", ">", "<", "/"]:
input_str = input_str.replace(badstr, " ")
if " " in input_str:
# Capitalize all the spaces
input_str = input_str.title()
# Remove all the spaces!
input_str = input_str.replace(" ", "")
return input_str


def clean_prop(input_str: str) -> str:
def snake_case_prop(input_str: str) -> str:
"""
Clean a property to not have invalid characters
Clean a property to not have invalid characters.
Returns a "snake_case" version of the input string
"""
# These strings appear in some OpenAPI schemas
for dropchar in [">", "<"]:
input_str = input_str.replace(dropchar, "")
# These we can just convert to an underscore
for badstr in ["-", "."]:
input_str = input_str.replace(badstr, "_")
# python keywords need to be converted
reserved_words = ["from"]
if input_str in reserved_words:
input_str = input_str + "_"
# Retain all-uppercase strings, otherwise convert to camel case
if not input_str.isupper():
input_str = "".join(
["_" + i.lower() if i.isupper() else i for i in input_str]
).lstrip("_")
return input_str


Expand Down Expand Up @@ -86,6 +97,7 @@ def get_type(t):
if t_type is None:
# In this case, make it an "Any"
return "typing.Any"
# Note: enums have type {'type': '"EXAMPLE"'} so fall through here
return t_type


Expand Down
18 changes: 18 additions & 0 deletions tests/generators/standard/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from clientele.generators.standard import utils


@pytest.mark.parametrize(
"input,expected_output",
[
("Lowercase", "lowercase"),
("bad-string", "bad_string"),
(">badstring", "badstring"),
("FooBarBazz", "foo_bar_bazz"),
("RETAIN_ALL_UPPER", "RETAIN_ALL_UPPER"),
("RETAINALLUPPER", "RETAINALLUPPER"),
],
)
def test_snake_case_prop(input, expected_output):
assert utils.snake_case_prop(input_str=input) == expected_output

0 comments on commit ba4d6c2

Please sign in to comment.