Skip to content

Commit

Permalink
[Enhanced Switch][Record Pattern] Simplify exhaustiveness checking fo…
Browse files Browse the repository at this point in the history
…r switch statements with record patterns (#3172)

* Fixes #3458
  • Loading branch information
srikanth-sankaran authored Dec 16, 2024
1 parent 57fb407 commit 9096700
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,14 @@ public static record SingletonBootstrap(String id, char[] selector, char[] signa
/* package */ List<Pattern> caseLabelElements = new ArrayList<>(0);//TODO: can we remove this?
public List<TypeBinding> caseLabelElementTypes = new ArrayList<>(0);

class Node {
abstract class Node {
TypeBinding type;
boolean hasError = false;
public void traverse(NodeVisitor visitor) {
visitor.visit(this);
visitor.endVisit(this);
}
public abstract void traverse(CoverageCheckerVisitor visitor);
}

class RNode extends Node {

TNode firstComponent;

RNode(TypeBinding rec) {
Expand All @@ -147,37 +146,27 @@ class RNode extends Node {
this.firstComponent = new TNode(comp.type);
}
}

void addPattern(Pattern p) {
if (p instanceof RecordPattern)
addPattern((RecordPattern)p);
}
void addPattern(RecordPattern rp) {
if (!TypeBinding.equalsEquals(this.type, rp.type.resolvedType))
return;
if (this.firstComponent == null)
return;
this.firstComponent.addPattern(rp, 0);
if (p instanceof RecordPattern rp && TypeBinding.equalsEquals(this.type, rp.type.resolvedType) && this.firstComponent != null)
this.firstComponent.addPattern(rp, 0);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[RNode] {\n"); //$NON-NLS-1$
sb.append(" type:"); //$NON-NLS-1$
sb.append(this.type != null ? this.type.toString() : "null"); //$NON-NLS-1$
sb.append(" firstComponent:"); //$NON-NLS-1$
sb.append(this.firstComponent != null ? this.firstComponent.toString() : "null"); //$NON-NLS-1$
sb.append("\n}\n"); //$NON-NLS-1$
return sb.toString();
return "[RNode] {\n type:" + this.type + " firstComponent:" + this.firstComponent + "\n}\n"; //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$
}

@Override
public void traverse(NodeVisitor visitor) {
public void traverse(CoverageCheckerVisitor visitor) {
if (this.firstComponent != null) {
visitor.visit(this.firstComponent);
}
visitor.endVisit(this);
}
}

class TNode extends Node {

List<PatternNode> children;

TNode(TypeBinding type) {
Expand All @@ -199,8 +188,7 @@ public void addPattern(RecordPattern rp, int i) {
}
}
if (child == null) {
child = childType.isRecord() ?
new RecordPatternNode(childType) : new PatternNode(childType);
child = new PatternNode(childType);
if (this.type.isSubtypeOf(childType, false))
this.children.add(0, child);
else
Expand All @@ -210,38 +198,28 @@ public void addPattern(RecordPattern rp, int i) {
child.addPattern(rp, i + 1);
}
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[TNode] {\n"); //$NON-NLS-1$
sb.append(" type:"); //$NON-NLS-1$
sb.append(this.type != null ? this.type.toString() : "null"); //$NON-NLS-1$
sb.append(" children:"); //$NON-NLS-1$
StringBuilder sb = new StringBuilder("[TNode] {\n type:" + this.type + " children:"); //$NON-NLS-1$ //$NON-NLS-2$
if (this.children == null) {
sb.append("null"); //$NON-NLS-1$
} else {
for (Node child : this.children) {
sb.append(child.toString());
sb.append(child);
}
}
sb.append("\n}\n"); //$NON-NLS-1$
return sb.toString();
return sb.append("\n}\n").toString(); //$NON-NLS-1$
}

@Override
public void traverse(NodeVisitor visitor) {
if (visitor.visit(this)) {
if (this.children != null) {
for (PatternNode child : this.children) {
if (!visitor.visit(child)) {
break;
}
}
}
}
visitor.endVisit(this);
public void traverse(CoverageCheckerVisitor visitor) {
visitor.visit(this);
}
}

class PatternNode extends Node {

TNode next; // next component

PatternNode(TypeBinding type) {
Expand All @@ -250,102 +228,31 @@ class PatternNode extends Node {

public void addPattern(RecordPattern rp, int i) {
TypeBinding ref = SwitchStatement.this.expression.resolvedType;
if (!(ref instanceof ReferenceBinding))
return;
RecordComponentBinding[] comps = ref.components();
if (comps == null || comps.length <= i) // safety-net for incorrect code.
return;
if (this.next == null)
this.next = new TNode(comps[i].type);
this.next.addPattern(rp, i);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[Pattern node] {\n"); //$NON-NLS-1$
sb.append(" type:"); //$NON-NLS-1$
sb.append(this.type != null ? this.type.toString() : "null"); //$NON-NLS-1$
sb.append(" next:"); //$NON-NLS-1$
sb.append(this.next != null ? this.next.toString() : "null"); //$NON-NLS-1$
sb.append("\n}\n"); //$NON-NLS-1$
return sb.toString();
}
@Override
public void traverse(NodeVisitor visitor) {
if (visitor.visit(this)) {
if (this.next != null) {
visitor.visit(this.next);
}
}
visitor.endVisit(this);
}
}
class RecordPatternNode extends PatternNode {
RNode rNode;
RecordPatternNode(TypeBinding type) {
super(type);
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[RecordPattern node] {\n"); //$NON-NLS-1$
sb.append(" type:"); //$NON-NLS-1$
sb.append(this.type != null ? this.type.toString() : "null"); //$NON-NLS-1$
sb.append(" next:"); //$NON-NLS-1$
sb.append(this.next != null ? this.next.toString() : "null"); //$NON-NLS-1$
sb.append(" rNode:"); //$NON-NLS-1$
sb.append(this.rNode != null ? this.rNode.toString() : "null"); //$NON-NLS-1$
sb.append("\n}\n"); //$NON-NLS-1$
return sb.toString();
return "[" + (this.type.isRecord() ? "Record" : "") + "Pattern node] {\n type:" + this.type + " next:" + this.next + "\n}\n"; //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$ //$NON-NLS-5$ //$NON-NLS-6$
}

@Override
public void traverse(NodeVisitor visitor) {
if (visitor.visit(this)) {
if (visitor.visit(this.rNode)) {
if (this.next != null) {
visitor.visit(this.next);
}
}
public void traverse(CoverageCheckerVisitor visitor) {
if (this.next != null) {
visitor.visit(this.next);
}
visitor.endVisit(this);
}
}

abstract class NodeVisitor {
public void endVisit(Node node) {
// do nothing by default
}
public void endVisit(PatternNode node) {
// do nothing by default
}
public void endVisit(RecordPatternNode node) {
// do nothing by default
}
public void endVisit(RNode node) {
// do nothing by default
}
public void endVisit(TNode node) {
// do nothing by default
}
public boolean visit(Node node) {
return true;
}
public boolean visit(PatternNode node) {
return true;
}
public boolean visit(RecordPatternNode node) {
return true;
}
public boolean visit(RNode node) {
return true;
}
public boolean visit(TNode node) {
return true;
}
}
class CoverageCheckerVisitor extends NodeVisitor {
class CoverageCheckerVisitor {

public boolean covers = true;
@Override

public boolean visit(TNode node) {
if (node.hasError)
return false;
Expand All @@ -365,8 +272,7 @@ public boolean visit(TNode node) {
this.covers &= caseElementsCoverSealedType(ref, availableTypes);
return this.covers;
}
this.covers = false;
return false; // no need to visit further.
return this.covers = false; // no need to visit further.
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4788,4 +4788,137 @@ record Box<T>(T contents) { }
},
"Contents");
}
public void testJEP440Example() {
runNegativeTest(new String[] {
"X.java",
"""
class A {
}
class B extends A {
}
sealed interface I permits C, D {
}
final class C implements I {
}
final class D implements I {
}
record Pair<T>(T x, T y) {
}
public class X {
static Pair<A> p1;
static Pair<I> p2;
public static void main(String[] args) {
// As of Java 21
switch (p1) { // Error!
case Pair<A>(A a, B b) -> System.out.println();
case Pair<A>(B b, A a) -> System.out.println();
}
switch (p2) {
case Pair<I>(I i, C c) -> System.out.println();
case Pair<I>(I i, D d) -> System.out.println();
}
switch (p2) {
case Pair<I>(C c, I i) -> System.out.println();
case Pair<I>(D d, C c) -> System.out.println();
case Pair<I>(D d1, D d2) -> System.out.println();
}
switch (p2) { // Error!
case Pair<I>(C fst, D snd) -> System.out.println();
case Pair<I>(D fst, C snd) -> System.out.println();
case Pair<I>(I fst, C snd) -> System.out.println();
}
}
}
"""
},
"----------\n" +
"1. ERROR in X.java (at line 25)\n" +
" switch (p1) { // Error!\n" +
" ^^\n" +
"An enhanced switch statement should be exhaustive; a default label expected\n" +
"----------\n" +
"2. ERROR in X.java (at line 41)\n" +
" switch (p2) { // Error!\n" +
" ^^\n" +
"An enhanced switch statement should be exhaustive; a default label expected\n" +
"----------\n");
}

public void testRecordCoverage() {
runConformTest(new String[] {
"X.java",
"""
sealed interface I permits A, B, C {
}
final class A implements I {
}
final class B implements I {
}
record C(int j) implements I {
} // Implicitly final
record Box(I i) {
}
public class X {
int testExhaustiveRecordPatterns(Box b) {
return switch (b) { // Exhaustive!
case Box(A aa) -> 0;
case Box(B bb) -> 1;
case Box(C cc) -> 2;
};
}
record IPair(I i, I j) {
}
int testExhaustiveRecordPatterns(IPair p) {
return switch (p) { // Exhaustive!
case IPair(A a1, A a2) -> 0;
case IPair(A a1, B b2) -> 1;
case IPair(A a1, C c3) -> 2;
case IPair(B b1, A b2) -> 3;
case IPair(B b1, B b2) -> 4;
case IPair(B b1, C b2) -> 5;
case IPair(C c1, A c2) -> 6;
case IPair(C c1, B c2) -> 7;
case IPair(C c1, C c2) -> 8;
};
}
public static void main(String [] args) {
X x = new X();
System.out.print(x.testExhaustiveRecordPatterns(new Box(new A())));
System.out.print(x.testExhaustiveRecordPatterns(new Box(new B())));
System.out.print(x.testExhaustiveRecordPatterns(new Box(new C(42))));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new A(), new A())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new A(), new B())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new A(), new C(42))));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new B(), new A())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new B(), new B())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new B(), new C(42))));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new C(42), new A())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new C(42), new B())));
System.out.print(x.testExhaustiveRecordPatterns(new IPair(new C(42), new C(42))));
}
}
"""
},
"012012345678");
}
}

0 comments on commit 9096700

Please sign in to comment.