diff --git a/ifgen/enum/__init__.py b/ifgen/enum/__init__.py index b076983..2b5098a 100644 --- a/ifgen/enum/__init__.py +++ b/ifgen/enum/__init__.py @@ -14,9 +14,11 @@ def create_enum(task: GenerateTask) -> None: """Create a header file based on an enum definition.""" - includes = [""] - if not task.instance["use_map"]: - includes.append("") + includes = [] + if task.is_cpp: + includes.append("") + if not task.instance["use_map"]: + includes.append("") with task.boilerplate( includes=includes, json=task.instance["json"] diff --git a/ifgen/enum/source.py b/ifgen/enum/source.py index ca891e5..958a1e5 100644 --- a/ifgen/enum/source.py +++ b/ifgen/enum/source.py @@ -13,9 +13,6 @@ def enum_source(task: GenerateTask, writer: IndentedFileWriter) -> None: """Create a source file for an enumeration.""" - if task.is_python: - return - enum_to_string_function(task, writer, task.instance["use_map"]) writer.empty() string_to_enum_function(task, writer, task.instance["use_map"]) @@ -24,6 +21,9 @@ def enum_source(task: GenerateTask, writer: IndentedFileWriter) -> None: def create_enum_source(task: GenerateTask) -> None: """Create a source file based on an enum definition.""" + if task.is_python: + return + if task.instance["use_map"]: with task.source_boilerplate(["", ""]) as writer: enum_source(task, writer) diff --git a/ifgen/generation/interface.py b/ifgen/generation/interface.py index 48e57a3..0eeb237 100644 --- a/ifgen/generation/interface.py +++ b/ifgen/generation/interface.py @@ -66,6 +66,11 @@ def is_python(self) -> bool: """Determine if this task's language is set to Python.""" return self.language is Language.PYTHON + @property + def is_cpp(self) -> bool: + """Determine if this task's language is set to C++.""" + return self.language is Language.CPP + @property def stream_implementation(self) -> bool: """ @@ -174,11 +179,18 @@ def javadoc_header( ) -> None: """Write a standard file header.""" - with writer.javadoc(): - writer.write(self.command("file")) - writer.write( - self.command("brief", f"Generated by {PKG_NAME} ({VERSION}).") - ) + description = f"Generated by {PKG_NAME} ({VERSION})." + + with ExitStack() as stack: + if self.is_python: + stack.enter_context( + writer.scope(opener='"""', closer='"""', indent=0) + ) + writer.write(description) + else: + stack.enter_context(writer.javadoc()) + writer.write(self.command("file")) + writer.write(self.command("brief", description)) if json: writer.write(dumps(self.instance, indent=2)) @@ -188,6 +200,9 @@ def handle_namespace( ) -> None: """Write namespace boilerplate to a file.""" + if self.is_python: + return + # Write namespace. namespace = self.namespace() writer.write(f"namespace {namespace}") @@ -208,13 +223,24 @@ def write_includes( if include: writer.write(f"#include {include}") + def resolve_description( + self, description: Optional[str] = None + ) -> Optional[str]: + """TODO.""" + + if description is None: + if "description" in self.instance and self.instance["description"]: + description = self.instance["description"] + + return description + @contextmanager def boilerplate( self, includes: Iterable[str] = None, is_test: bool = False, use_namespace: bool = True, - description: Union[bool, None, str] = None, + description: Optional[str] = None, json: bool = False, ) -> Iterator[IndentedFileWriter]: """ @@ -232,7 +258,7 @@ def boilerplate( # Write file header. self.javadoc_header(writer, json=json) - if not is_test: + if not self.is_python and not is_test: writer.write("#pragma once") self.write_includes(writer, includes=includes) @@ -240,17 +266,11 @@ def boilerplate( if use_namespace: self.handle_namespace(stack, writer) - if description is None: - if ( - "description" in self.instance - and self.instance["description"] - ): - description = self.instance["description"] - - # Write struct definition. - if description: - with writer.javadoc(): - writer.write(description) # type: ignore + if not self.is_python: + description = self.resolve_description(description=description) + if description: + with writer.javadoc(): + writer.write(description) yield writer diff --git a/ifgen/generation/test.py b/ifgen/generation/test.py index c8d90da..fe2c1a5 100644 --- a/ifgen/generation/test.py +++ b/ifgen/generation/test.py @@ -92,7 +92,6 @@ def unit_test_boilerplate( + includes, is_test=True, use_namespace=False, - description=False, ) )