diff --git a/docs/changelog.md b/docs/changelog.md index e7346875..332b8630 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +## Unreleased + +- Add concept of ownership: only containers owned by calling + code may be mutated (#542) + ## Version 0.8.0 (November 5, 2022) Release highlights: diff --git a/pyanalyze/attributes.py b/pyanalyze/attributes.py index 9dcffe2f..af27347c 100644 --- a/pyanalyze/attributes.py +++ b/pyanalyze/attributes.py @@ -15,6 +15,7 @@ import qcore from .annotations import Context, type_from_annotations, type_from_runtime +from .extensions import Mutable from .options import Options, PyObjectSequenceOption from .safe import safe_isinstance, safe_issubclass from .signature import MaybeSignature @@ -30,6 +31,7 @@ KnownValue, KnownValueWithTypeVars, MultiValuedValue, + make_mutable, set_self, SubclassValue, TypedValue, @@ -89,9 +91,11 @@ def get_generic_bases( def get_attribute(ctx: AttrContext) -> Value: root_value = ctx.root_value + should_own = False if isinstance(root_value, TypeVarValue): root_value = root_value.get_fallback_value() elif isinstance(root_value, AnnotatedValue): + should_own = any(root_value.get_custom_check_of_type(Mutable)) root_value = root_value.value if isinstance(root_value, KnownValue): attribute_value = _get_attribute_from_known(root_value.val, ctx) @@ -137,7 +141,10 @@ def get_attribute(ctx: AttrContext) -> Value: ) and isinstance(ctx.root_value, AnnotatedValue): for guard in ctx.root_value.get_metadata_of_type(HasAttrExtension): if guard.attribute_name == KnownValue(ctx.attr): - return guard.attribute_type + attribute_value = guard.attribute_type + break + if should_own and attribute_value is not UNINITIALIZED_VALUE: + attribute_value = make_mutable(attribute_value) return attribute_value diff --git a/pyanalyze/error_code.py b/pyanalyze/error_code.py index 34825365..913acfd5 100644 --- a/pyanalyze/error_code.py +++ b/pyanalyze/error_code.py @@ -99,6 +99,7 @@ class ErrorCode(enum.Enum): invalid_annotated_assignment = 79 unused_assignment = 80 incompatible_yield = 81 + disallowed_mutation = 82 # Allow testing unannotated functions without too much fuss @@ -219,6 +220,7 @@ class ErrorCode(enum.Enum): ErrorCode.invalid_annotated_assignment: "Invalid annotated assignment", ErrorCode.unused_assignment: "Assigned value is never used", ErrorCode.incompatible_yield: "Incompatible yield type", + ErrorCode.disallowed_mutation: "Mutation of object that does not allow mutation", } diff --git a/pyanalyze/extensions.py b/pyanalyze/extensions.py index 8b853d28..188bdeb0 100644 --- a/pyanalyze/extensions.py +++ b/pyanalyze/extensions.py @@ -30,7 +30,7 @@ ) import typing_extensions -from typing_extensions import Literal, NoReturn +from typing_extensions import Annotated, Literal, NoReturn import pyanalyze @@ -39,6 +39,8 @@ if TYPE_CHECKING: from .value import AnySource, CanAssign, CanAssignContext, TypeVarMap, Value +_T = TypeVar("_T") + class CustomCheck: """A mechanism for extending the type system with user-defined checks. @@ -146,6 +148,33 @@ def _is_disallowed(self, value: "Value") -> bool: ) +@dataclass(frozen=True) +class Mutable(CustomCheck): + """Custom check that indicates that a mutable value is mutated. For example, + a function that mutates a list should accept an argument of type + ``Annotated[List[T], Mutable()]``. + """ + + def can_assign(self, value: "Value", ctx: "CanAssignContext") -> "CanAssign": + for val in pyanalyze.value.flatten_values(value, unwrap_annotated=False): + if isinstance(val, pyanalyze.value.AnnotatedValue): + if any(val.get_custom_check_of_type(Mutable)): + continue + val = val.value + if isinstance(val, pyanalyze.value.AnyValue): + continue + return pyanalyze.value.CanAssignError( + f"Value {val} is not owned and may not be mutated", + error_code=pyanalyze.error_code.ErrorCode.disallowed_mutation, + ) + return {} + + +def make_mutable(obj: _T) -> Annotated[_T, Mutable()]: + """Unsafely mark an object as mutable.""" + return obj + + class _AsynqCallableMeta(type): def __getitem__( self, params: Tuple[Union[Literal[Ellipsis], List[object]], object] @@ -372,9 +401,6 @@ def __call__(self) -> NoReturn: raise NotImplementedError("just here to fool typing._type_check") -_T = TypeVar("_T") - - def reveal_type(value: _T) -> _T: """Inspect the inferred type of an expression. diff --git a/pyanalyze/functions.py b/pyanalyze/functions.py index a69dc0b7..e49b8daa 100644 --- a/pyanalyze/functions.py +++ b/pyanalyze/functions.py @@ -37,6 +37,7 @@ SubclassValue, TypedValue, TypeVarValue, + make_mutable, unite_values, UnpackedValue, Value, @@ -341,6 +342,7 @@ def compute_parameters( else: # normal method value = enclosing_class + value = make_mutable(value) else: # This is meant to exclude methods in nested classes. It's a bit too # conservative for cases such as a function nested in a method nested in a diff --git a/pyanalyze/implementation.py b/pyanalyze/implementation.py index dfb11609..baa0eb93 100644 --- a/pyanalyze/implementation.py +++ b/pyanalyze/implementation.py @@ -50,8 +50,10 @@ HasAttrGuardExtension, KNOWN_MUTABLE_TYPES, KnownValue, + is_owned, kv_pairs_from_mapping, KVPair, + make_mutable, MultiValuedValue, NO_RETURN_VALUE, ParameterTypeGuardExtension, @@ -97,6 +99,20 @@ def flatten_unions( return ImplReturn.unite_impl_rets(results) +def inherit_ownership( + impl: Callable[[CallContext], ImplReturn] +) -> Callable[[CallContext], ImplReturn]: + def wrapper(ctx: CallContext) -> ImplReturn: + ret = impl(ctx) + if is_owned(ctx.vars["self"]): + return ImplReturn( + make_mutable(ret.return_value), ret.constraint, ret.no_return_unless + ) + return ret + + return wrapper + + # Implementations of some important functions for use in their ExtendedArgSpecs (see above). These # are called when the test_scope checker encounters call to these functions. @@ -326,8 +342,11 @@ def _set_impl(ctx: CallContext) -> ImplReturn: def _sequence_impl(typ: type, ctx: CallContext) -> ImplReturn: iterable = ctx.vars["iterable"] + maybe_owned: Callable[[Value], Value] = ( + (lambda x: x) if typ is tuple else make_mutable + ) if iterable is _NO_ARG_SENTINEL: - return ImplReturn(KnownValue(typ())) + return ImplReturn(maybe_owned(KnownValue(typ()))) def inner(iterable: Value) -> Value: cvi = concrete_values_from_iterable(iterable, ctx.visitor) @@ -338,12 +357,14 @@ def inner(iterable: Value) -> Value: arg="iterable", detail=str(cvi), ) - return TypedValue(typ) + return maybe_owned(TypedValue(typ)) elif isinstance(cvi, Value): - return GenericValue(typ, [cvi]) + return maybe_owned(GenericValue(typ, [cvi])) else: # TODO: Consider changing concrete_values_from_iterable to preserve unpacked bits - return SequenceValue.make_or_known(typ, [(False, elt) for elt in cvi]) + return maybe_owned( + SequenceValue.make_or_known(typ, [(False, elt) for elt in cvi]) + ) return flatten_unions(inner, iterable) @@ -358,7 +379,9 @@ def _list_append_impl(ctx: CallContext) -> ImplReturn: varname, ConstraintType.is_value_object, True, - SequenceValue.make_or_known(list, (*lst.members, (False, element))), + make_mutable( + SequenceValue.make_or_known(list, (*lst.members, (False, element))) + ), ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) if isinstance(lst, GenericValue): @@ -424,8 +447,10 @@ def inner(key: Value) -> Value: if isinstance(self_value, SequenceValue): members = self_value.get_member_sequence() if members is not None: - return SequenceValue.make_or_known( - typ, [(False, m) for m in members[key.val]] + return make_mutable( + SequenceValue.make_or_known( + typ, [(False, m) for m in members[key.val]] + ) ) else: # If the value contains unpacked values, we don't attempt @@ -440,7 +465,7 @@ def inner(key: Value) -> Value: # __getitem__, but then we wouldn't get here). # TODO return a more precise type if the class inherits # from a generic list/tuple. - return TypedValue(typ) + return make_mutable(TypedValue(typ)) else: ctx.show_error(f"Invalid {typ.__name__} key {key}") return AnyValue(AnySource.error) @@ -462,6 +487,7 @@ def inner(key: Value) -> Value: return flatten_unions(inner, ctx.vars["obj"], unwrap_annotated=True) +@inherit_ownership def _list_getitem_impl(ctx: CallContext) -> ImplReturn: return _sequence_getitem_impl(ctx, list) @@ -519,6 +545,7 @@ def _dict_setitem_impl(ctx: CallContext) -> ImplReturn: return _add_pairs_to_dict(ctx.vars["self"], [pair], ctx, varname) +@inherit_ownership def _dict_getitem_impl(ctx: CallContext) -> ImplReturn: def inner(key: Value) -> Value: self_value = ctx.vars["self"] @@ -582,6 +609,7 @@ def inner(key: Value) -> Value: return flatten_unions(inner, ctx.vars["k"]) +@inherit_ownership def _dict_get_impl(ctx: CallContext) -> ImplReturn: default = ctx.vars["default"] @@ -649,6 +677,7 @@ def inner(key: Value) -> Value: return flatten_unions(inner, ctx.vars["key"]) +@inherit_ownership def _dict_pop_impl(ctx: CallContext) -> ImplReturn: key = ctx.vars["key"] default = ctx.vars["default"] @@ -691,9 +720,11 @@ def _dict_pop_impl(ctx: CallContext) -> ImplReturn: existing_value = self_value.get_value(key, ctx.visitor) is_present = existing_value is not UNINITIALIZED_VALUE if varname is not None and isinstance(key, KnownValue): - new_value = DictIncompleteValue( - self_value.typ, - [pair for pair in self_value.kv_pairs if pair.key != key], + new_value = make_mutable( + DictIncompleteValue( + self_value.typ, + [pair for pair in self_value.kv_pairs if pair.key != key], + ) ) no_return_unless = Constraint( varname, ConstraintType.is_value_object, True, new_value @@ -733,6 +764,7 @@ def _maybe_unite(value: Value, default: Value) -> Value: return unite_values(value, default) +@inherit_ownership def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn: key = ctx.vars["key"] default = ctx.vars["default"] @@ -776,9 +808,14 @@ def _dict_setdefault_impl(ctx: CallContext) -> ImplReturn: elif isinstance(self_value, DictIncompleteValue): existing_value = self_value.get_value(key, ctx.visitor) is_present = existing_value is not UNINITIALIZED_VALUE - new_value = DictIncompleteValue( - self_value.typ, - [*self_value.kv_pairs, KVPair(key, default, is_required=not is_present)], + new_value = make_mutable( + DictIncompleteValue( + self_value.typ, + [ + *self_value.kv_pairs, + KVPair(key, default, is_required=not is_present), + ], + ) ) if varname is not None: no_return_unless = Constraint( @@ -845,8 +882,10 @@ def _update_incomplete_dict( varname, ConstraintType.is_value_object, True, - DictIncompleteValue( - self_val.typ if isinstance(self_val, TypedValue) else dict, pairs + make_mutable( + DictIncompleteValue( + self_val.typ if isinstance(self_val, TypedValue) else dict, pairs + ) ), ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) @@ -957,13 +996,14 @@ def inner(left: Value, right: Value) -> Value: left = replace_known_sequence_value(left) right = replace_known_sequence_value(right) if isinstance(left, SequenceValue) and isinstance(right, SequenceValue): - return SequenceValue.make_or_known(list, [*left.members, *right.members]) + val = SequenceValue.make_or_known(list, [*left.members, *right.members]) elif isinstance(left, TypedValue) and isinstance(right, TypedValue): left_arg = left.get_generic_arg_for_type(list, ctx.visitor, 0) right_arg = right.get_generic_arg_for_type(list, ctx.visitor, 0) - return GenericValue(list, [unite_values(left_arg, right_arg)]) + val = GenericValue(list, [unite_values(left_arg, right_arg)]) else: - return TypedValue(list) + val = TypedValue(list) + return make_mutable(val) return flatten_unions(inner, ctx.vars["self"], ctx.vars["x"]) @@ -991,6 +1031,7 @@ def inner(lst: Value, iterable: Value) -> ImplReturn: constrained_value = SequenceValue( list, [*cleaned_lst.members, (True, arg_type)] ) + constrained_value = make_mutable(constrained_value) if return_container: return ImplReturn(constrained_value) if varname is not None: @@ -1060,8 +1101,10 @@ def _set_add_impl(ctx: CallContext) -> ImplReturn: varname, ConstraintType.is_value_object, True, - SequenceValue.make_or_known( - set, (*set_value.members, (False, element)) + make_mutable( + SequenceValue.make_or_known( + set, (*set_value.members, (False, element)) + ) ), ) return ImplReturn(KnownValue(None), no_return_unless=no_return_unless) @@ -1468,7 +1511,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(list)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(list)) + ), SigParameter("object", _POS_ONLY), ], callable=list.append, @@ -1484,7 +1529,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(list)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(list)) + ), SigParameter( "x", _POS_ONLY, annotation=TypedValue(collections.abc.Iterable) ), @@ -1494,7 +1541,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(list)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(list)) + ), SigParameter( "iterable", _POS_ONLY, @@ -1522,7 +1571,7 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(set)), + SigParameter("self", _POS_ONLY, annotation=make_mutable(TypedValue(set))), SigParameter("object", _POS_ONLY), ], callable=set.add, @@ -1530,7 +1579,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(dict)) + ), SigParameter("k", _POS_ONLY), SigParameter("v", _POS_ONLY), ], @@ -1556,7 +1607,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(dict)) + ), SigParameter("key", _POS_ONLY), SigParameter("default", _POS_ONLY, default=KnownValue(None)), ], @@ -1565,7 +1618,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(dict)) + ), SigParameter("key", _POS_ONLY), SigParameter("default", _POS_ONLY, default=_NO_ARG_SENTINEL), ], @@ -1574,7 +1629,9 @@ def get_default_argspecs() -> Dict[object, Signature]: ), Signature.make( [ - SigParameter("self", _POS_ONLY, annotation=TypedValue(dict)), + SigParameter( + "self", _POS_ONLY, annotation=make_mutable(TypedValue(dict)) + ), SigParameter("m", _POS_ONLY, default=_NO_ARG_SENTINEL), SigParameter("kwargs", ParameterKind.VAR_KEYWORD), ], @@ -1590,8 +1647,10 @@ def get_default_argspecs() -> Dict[object, Signature]: annotation=GenericValue(dict, [TypeVarValue(K), TypeVarValue(V)]), ) ], - DictIncompleteValue( - dict, [KVPair(TypeVarValue(K), TypeVarValue(V), is_many=True)] + make_mutable( + DictIncompleteValue( + dict, [KVPair(TypeVarValue(K), TypeVarValue(V), is_many=True)] + ) ), callable=dict.copy, ), diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 92baafed..2eface00 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -113,6 +113,7 @@ KWARGS, MaybeSignature, OverloadedSignature, + ParameterKind, Signature, SigParameter, ) @@ -179,6 +180,7 @@ NoReturnConstraintExtension, ReferencingValue, SequenceValue, + make_mutable, set_self, SubclassValue, TypedValue, @@ -1918,11 +1920,12 @@ def _visit_function_body(self, function_info: FunctionInfo) -> FunctionResult: "%first_arg", VisitorState.check_names, ) + if info.param.kind is ParameterKind.VAR_KEYWORD: + annotation = make_mutable(info.param.annotation) + else: + annotation = info.param.annotation self.scopes.set( - info.param.name, - info.param.annotation, - info.node, - VisitorState.check_names, + info.param.name, annotation, info.node, VisitorState.check_names ) with qcore.override( @@ -2400,13 +2403,13 @@ def _handle_imports( # Comprehensions def visit_DictComp(self, node: ast.DictComp) -> Value: - return self._visit_sequence_comp(node, dict) + return make_mutable(self._visit_sequence_comp(node, dict)) def visit_ListComp(self, node: ast.ListComp) -> Value: - return self._visit_sequence_comp(node, list) + return make_mutable(self._visit_sequence_comp(node, list)) def visit_SetComp(self, node: ast.SetComp) -> Value: - return self._visit_sequence_comp(node, set) + return make_mutable(self._visit_sequence_comp(node, set)) def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Value: return self._visit_sequence_comp(node, types.GeneratorType) @@ -2660,7 +2663,7 @@ def visit_Dict(self, node: ast.Dict) -> Value: ErrorCode.unsupported_operation, detail=str(new_pairs), ) - return TypedValue(dict) + return make_mutable(TypedValue(dict)) all_pairs += new_pairs continue key_val = self.visit(key_node) @@ -2700,15 +2703,18 @@ def visit_Dict(self, node: ast.Dict) -> Value: ret[key] = value if has_non_literal: - return DictIncompleteValue(dict, all_pairs) + return make_mutable(DictIncompleteValue(dict, all_pairs)) else: - return KnownValue(ret) + return make_mutable(KnownValue(ret)) def visit_Set(self, node: ast.Set) -> Value: - return self._visit_display_read(node, set) + return make_mutable(self._visit_display_read(node, set)) def visit_List(self, node: ast.List) -> Optional[Value]: - return self._visit_display(node, list) + val = self._visit_display(node, list) + if val is not None: + return make_mutable(val) + return None def visit_Tuple(self, node: ast.Tuple) -> Optional[Value]: return self._visit_display(node, tuple) @@ -3986,7 +3992,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign) -> None: ) # We set the declared type on initial assignment, so that the # annotation can be used to adjust pyanalyze's type inference. - value = expected_type + if isinstance(value, AnnotatedValue): + value = annotate_value(expected_type, value.metadata) + else: + value = expected_type else: is_yield = False @@ -4330,12 +4339,13 @@ def composite_from_attribute(self, node: ast.Attribute) -> Composite: composite.get_varname(), self.being_assigned, node, self.state ) - if isinstance(root_composite.value, TypedValue): - typ = root_composite.value.typ - if isinstance(typ, type): - self._record_type_attr_set( - typ, node.attr, node, self.being_assigned - ) + for root_val in flatten_values(root_composite.value, unwrap_annotated=True): + if isinstance(root_val, TypedValue): + typ = root_val.typ + if isinstance(typ, type): + self._record_type_attr_set( + typ, node.attr, node, self.being_assigned + ) return Composite(self.being_assigned, composite, node) elif self._is_read_ctx(node.ctx): if self._is_checking(): diff --git a/pyanalyze/test_ownership.py b/pyanalyze/test_ownership.py new file mode 100644 index 00000000..98bc74e3 --- /dev/null +++ b/pyanalyze/test_ownership.py @@ -0,0 +1,21 @@ +# static analysis: ignore +from .test_name_check_visitor import TestNameCheckVisitorBase +from .test_node_visitor import assert_passes + + +class TestOwnership(TestNameCheckVisitorBase): + @assert_passes() + def test(self): + from typing import List + + def capybara(x: List[str]): + x.append("x") # E: disallowed_mutation + y = list(x) + y.append("x") + + z = [a for a in x] + z.append("x") + + # TODO make it work for an all-literal list + alpha = ["a", "b", "c", str(x)] + alpha.append("x") diff --git a/pyanalyze/typeshed.py b/pyanalyze/typeshed.py index d7f8ebf5..608fa3c5 100644 --- a/pyanalyze/typeshed.py +++ b/pyanalyze/typeshed.py @@ -73,6 +73,7 @@ TypeVarValue, UNINITIALIZED_VALUE, Value, + make_mutable, ) @@ -755,6 +756,7 @@ def _get_signature_from_info( sig = sig.replace_return_value(self_val) else: self_val = SubclassValue(self_val) + sig = sig.replace_return_value(make_mutable(sig.return_value)) bound_sig = make_bound_method(sig, Composite(self_val)) if bound_sig is None: return None diff --git a/pyanalyze/value.py b/pyanalyze/value.py index 79f884c5..85b580f6 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -50,7 +50,7 @@ def function(x: int, y: list[int], z: Any): import pyanalyze from pyanalyze.error_code import ErrorCode -from pyanalyze.extensions import CustomCheck +from pyanalyze.extensions import CustomCheck, Mutable from .safe import all_of_type, safe_equals, safe_isinstance, safe_issubclass @@ -2652,6 +2652,14 @@ def is_overlapping(left: Value, right: Value, ctx: CanAssignContext) -> bool: return left.is_assignable(right, ctx) or right.is_assignable(left, ctx) +def make_mutable(typ: Value) -> Value: + return AnnotatedValue(typ, [CustomCheckExtension(Mutable())]) + + +def is_owned(val: Value) -> bool: + return isinstance(val, AnnotatedValue) and any(val.get_custom_check_of_type(Mutable)) + + def make_coro_type(return_type: Value) -> GenericValue: return GenericValue( collections.abc.Coroutine,