Skip to content

Commit

Permalink
Merge branch 'microsoft:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DetachHead authored Mar 24, 2024
2 parents e5f24ec + fe41569 commit 1cfae11
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 14 deletions.
21 changes: 18 additions & 3 deletions packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -750,10 +750,15 @@ function narrowTypeBasedOnClassPattern(
}

// We might be able to narrow further based on arguments, but only
// if the types match exactly or the subtype is a final class and
// therefore cannot be subclassed.
// if the types match exactly, the subject subtype is a final class (and
// therefore cannot be subclassed), or the pattern class is a protocol
// class.
if (!evaluator.assignType(subjectSubtypeExpanded, classInstance)) {
if (isClass(subjectSubtypeExpanded) && !ClassType.isFinal(subjectSubtypeExpanded)) {
if (
isClass(subjectSubtypeExpanded) &&
!ClassType.isFinal(subjectSubtypeExpanded) &&
!ClassType.isProtocolClass(classInstance)
) {
return subjectSubtypeExpanded;
}
}
Expand Down Expand Up @@ -786,6 +791,16 @@ function narrowTypeBasedOnClassPattern(
pattern.className
);
return NeverType.createNever();
} else if (
isInstantiableClass(exprType) &&
ClassType.isProtocolClass(exprType) &&
!ClassType.isRuntimeCheckable(exprType)
) {
evaluator.addDiagnostic(
DiagnosticRule.reportGeneralTypeIssues,
LocAddendum.protocolRequiresRuntimeCheckable(),
pattern.className
);
}

return evaluator.mapSubtypesExpandTypeVars(
Expand Down
14 changes: 13 additions & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17691,6 +17691,14 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
annotatedType = getTypeOfParameterAnnotation(paramTypeNode, param.category);
}

if (annotatedType) {
addTypeVarsToListIfUnique(
typeParametersSeen,
getTypeVarArgumentsRecursive(annotatedType),
functionType.details.typeVarScopeId
);
}

if (isVariadicTypeVar(annotatedType) && !annotatedType.isVariadicUnpacked) {
addError(
LocMessage.unpackedTypeVarTupleExpected().format({
Expand Down Expand Up @@ -17834,7 +17842,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
FunctionType.addParameter(functionType, functionParam);

if (functionParam.hasDeclaredType) {
addTypeVarsToListIfUnique(typeParametersSeen, getTypeVarArgumentsRecursive(functionParam.type));
addTypeVarsToListIfUnique(
typeParametersSeen,
getTypeVarArgumentsRecursive(functionParam.type),
functionType.details.typeVarScopeId
);
}

if (param.name) {
Expand Down
6 changes: 5 additions & 1 deletion packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1944,8 +1944,12 @@ export function getClassFieldsRecursive(classType: ClassType): Map<string, Class

// Combines two lists of type var types, maintaining the combined order
// but removing any duplicates.
export function addTypeVarsToListIfUnique(list1: TypeVarType[], list2: TypeVarType[]) {
export function addTypeVarsToListIfUnique(list1: TypeVarType[], list2: TypeVarType[], typeVarScopeId?: TypeVarScopeId) {
for (const type2 of list2) {
if (typeVarScopeId && type2.scopeId !== typeVarScopeId) {
continue;
}

if (!list1.find((type1) => isTypeSame(convertToInstance(type1), convertToInstance(type2)))) {
list1.push(type2);
}
Expand Down
29 changes: 28 additions & 1 deletion packages/pyright-internal/src/tests/samples/matchClass1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
# This sample tests type checking for match statements (as
# described in PEP 634) that contain class patterns.

from typing import Any, Generic, Literal, NamedTuple, TypeVar
from typing import (
Any,
Generic,
Literal,
NamedTuple,
Protocol,
TypeVar,
runtime_checkable,
)
from typing_extensions import ( # pyright: ignore[reportMissingModuleSource]
LiteralString,
)
Expand Down Expand Up @@ -468,3 +476,22 @@ def func20(x: T6) -> T6:

reveal_type(x, expected_text="float* | int*")
return x


@runtime_checkable
class Proto1(Protocol):
x: int


class Proto2(Protocol):
x: int


def func21(subj: object):
match subj:
case Proto1():
pass

# This should generate an error because Proto2 isn't runtime checkable.
case Proto2():
pass
41 changes: 34 additions & 7 deletions packages/pyright-internal/src/tests/samples/matchClass3.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,49 @@
# This sample tests class-based pattern matching when the class is
# marked final and can be discriminated based on the argument patterns.

from typing import final
from typing import final, Protocol, runtime_checkable
from dataclasses import dataclass


class A:
title: str


class B:
name: str


class C:
name: str


def func1(r: A | B | C):
match r:
case object(title=_):
reveal_type(r, expected_text='A | B | C')
reveal_type(r, expected_text="A | B | C")

case object(name=_):
reveal_type(r, expected_text='A | B | C')
reveal_type(r, expected_text="A | B | C")

case _:
reveal_type(r, expected_text='A | B | C')
reveal_type(r, expected_text="A | B | C")


@final
class AFinal:
title: str


@final
class BFinal:
name: str


@final
class CFinal:
name: str


@final
class DFinal:
nothing: str
Expand All @@ -44,10 +52,29 @@ class DFinal:
def func2(r: AFinal | BFinal | CFinal | DFinal):
match r:
case object(title=_):
reveal_type(r, expected_text='AFinal')
reveal_type(r, expected_text="AFinal")

case object(name=_):
reveal_type(r, expected_text='BFinal | CFinal')
reveal_type(r, expected_text="BFinal | CFinal")

case _:
reveal_type(r, expected_text='DFinal')
reveal_type(r, expected_text="DFinal")


@runtime_checkable
class ProtoE(Protocol):
__match_args__ = ("x",)
x: int


@dataclass
class E:
x: int


match E(1):
case ProtoE(x):
pass

case y:
reveal_type(y, expected_text="Never")
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ def y(self) -> T5: ...
b1 = ClassB()
reveal_type(b1.x, expected_text="int")
reveal_type(b1.y, expected_text="int")


T6 = TypeVar("T6", default=int)
T7 = TypeVar("T7", default=T6)
T8 = TypeVar("T8", default=int | None)


class ClassC(Generic[T6, T7, T8]):
def __new__(cls, x: T7, /) -> Self: ...
def method1(self) -> T7: ...
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1323,7 +1323,7 @@ test('MatchClass1', () => {

configOptions.defaultPythonVersion = pythonVersion3_10;
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['matchClass1.py'], configOptions);
TestUtils.validateResults(analysisResults, 4);
TestUtils.validateResults(analysisResults, 5);
});

test('MatchClass2', () => {
Expand Down

0 comments on commit 1cfae11

Please sign in to comment.