diff --git a/protoletariat/rewrite.py b/protoletariat/rewrite.py index bbf07ec3..5a4d942f 100644 --- a/protoletariat/rewrite.py +++ b/protoletariat/rewrite.py @@ -5,7 +5,7 @@ import collections.abc import typing from ast import AST -from typing import Any, Callable, NamedTuple, Sequence, Union +from typing import Any, Callable, MutableSet, NamedTuple, Sequence, Union try: from ast import unparse as astunparse @@ -205,12 +205,18 @@ class ImportNodeTransformer(ast.NodeTransformer): def __init__(self, ast_rewriter: ASTRewriter) -> None: self.ast_rewriter = ast_rewriter + # track the results we've produced to avoid duplication of imports + self.seen: MutableSet[str] = set() - def visit_Import(self, node: ast.Import) -> AST: - return self.ast_rewriter.rewrite(node) + def visit_Import(self, node: ast.AST) -> AST | None: + result = self.ast_rewriter.rewrite(node) + code = astunparse(result) + if code not in self.seen: + self.seen.add(code) + return result + return None - def visit_ImportFrom(self, node: ast.ImportFrom) -> AST: - return self.ast_rewriter.rewrite(node) + visit_ImportFrom = visit_Import class ASTImportRewriter: @@ -233,4 +239,5 @@ def _rewrite(_: AST, repl: AST = new_node) -> AST: ), f"more than one rewrite rule found for pattern `{replacement.old}`" def rewrite(self, src: str) -> str: + self.node_transformer.seen.clear() return astunparse(self.node_transformer.visit(ast.parse(src)))