diff --git a/README.rst b/README.rst index b0ddfa7..815e32e 100644 --- a/README.rst +++ b/README.rst @@ -190,6 +190,8 @@ second usage. Save the result to a list if the result is needed multiple times. **B034**: Calls to `re.sub`, `re.subn` or `re.split` should pass `flags` or `count`/`maxsplit` as keyword arguments. It is commonly assumed that `flags` is the third positional parameter, forgetting about `count`/`maxsplit`, since many other `re` module functions are of the form `f(pattern, string, flags)`. +**B035**: Found dict comprehension with a static key - either a constant value or variable not from the comprehension expression. This will result in a dict with a single key that was repeatedly overwritten. + Opinionated warnings ~~~~~~~~~~~~~~~~~~~~ diff --git a/bugbear.py b/bugbear.py index 88fbd2b..3b4fc3a 100644 --- a/bugbear.py +++ b/bugbear.py @@ -483,6 +483,7 @@ def visit_SetComp(self, node): def visit_DictComp(self, node): self.check_for_b023(node) + self.check_for_b035(node) self.generic_visit(node) def visit_GeneratorExp(self, node): @@ -954,6 +955,36 @@ def check_for_b031(self, loop_node): # noqa: C901 B031(node.lineno, node.col_offset, vars=(node.id,)) ) + def _get_names_from_tuple(self, node: ast.Tuple): + for dim in node.elts: + if isinstance(dim, ast.Name): + yield dim.id + elif isinstance(dim, ast.Tuple): + yield from self._get_names_from_tuple(dim) + + def _get_dict_comp_loop_var_names(self, node: ast.DictComp): + for gen in node.generators: + if isinstance(gen.target, ast.Name): + yield gen.target.id + elif isinstance(gen.target, ast.Tuple): + yield from self._get_names_from_tuple(gen.target) + + def check_for_b035(self, node: ast.DictComp): + """Check that a static key isn't used in a dict comprehension. + + Emit a warning if a likely unchanging key is used - either a constant, + or a variable that isn't coming from the generator expression. + """ + if isinstance(node.key, ast.Constant): + self.errors.append( + B035(node.key.lineno, node.key.col_offset, vars=(node.key.value,)) + ) + elif isinstance(node.key, ast.Name): + if node.key.id not in self._get_dict_comp_loop_var_names(node): + self.errors.append( + B035(node.key.lineno, node.key.col_offset, vars=(node.key.id,)) + ) + def _get_assigned_names(self, loop_node): loop_targets = (ast.For, ast.AsyncFor, ast.comprehension) for node in children_in_scope(loop_node): @@ -1884,6 +1915,8 @@ def visit_Lambda(self, node): " due to unintuitive argument positions." ) ) +B035 = Error(message="B035 Static key in dict comprehension {!r}.") + # Warnings disabled by default. B901 = Error( diff --git a/tests/b035.py b/tests/b035.py new file mode 100644 index 0000000..1a451e3 --- /dev/null +++ b/tests/b035.py @@ -0,0 +1,35 @@ +# OK - consts in regular dict +regular_dict = {"a": 1, "b": 2} +regular_nested_dict = {"a": 1, "nested": {"b": 2, "c": "three"}} + +# bad - const key in dict comprehension +bad_const_key_str = {"a": i for i in range(3)} +bad_const_key_int = {1: i for i in range(3)} + +# OK - const value in dict comp +const_val = {i: "a" for i in range(3)} + +# OK - expression with const in dict comp +key_expr_with_const = {i * i: i for i in range(3)} +key_expr_with_const2 = {"a" * i: i for i in range(3)} + +# nested +nested_bad_and_good = { + "good": {"a": 1, "b": 2}, + "bad": {"a": i for i in range(3)}, +} + +CONST_KEY_VAR = "KEY" + +# bad +bad_const_key_var = {CONST_KEY_VAR: i for i in range(3)} + +# OK - variable from tuple +var_from_tuple = {k: v for k, v in {}.items()} + +# OK - variable from nested tuple +var_from_nested_tuple = {v2: k for k, (v1, v2) in {"a": (1, 2)}.items()} + +# bad - variabe not from generator +v3 = 1 +bad_var_not_from_nested_tuple = {v3: k for k, (v1, v2) in {"a": (1, 2)}.items()} diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 22ccd6a..3ea3e42 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -43,6 +43,7 @@ B032, B033, B034, + B035, B901, B902, B903, @@ -521,6 +522,19 @@ def test_b034(self): ) self.assertEqual(errors, expected) + def test_b035(self): + filename = Path(__file__).absolute().parent / "b035.py" + bbc = BugBearChecker(filename=str(filename)) + errors = list(bbc.run()) + expected = self.errors( + B035(6, 21, vars=("a",)), + B035(7, 21, vars=(1,)), + B035(19, 12, vars=("a",)), + B035(25, 21, vars=("CONST_KEY_VAR",)), + B035(35, 33, vars=("v3",)), + ) + self.assertEqual(errors, expected) + def test_b908(self): filename = Path(__file__).absolute().parent / "b908.py" bbc = BugBearChecker(filename=str(filename))