Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KE2: Extract when expressions #18058

Merged
merged 4 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6274,7 +6274,12 @@ open class KotlinFileExtractor(
val bLocId = tw.getLocation(b)
tw.writeStmts_whenbranch(bId, id, i, callable)
tw.writeHasLocation(bId, bLocId)
extractExpressionExpr(b.condition, callable, bId, 0, bId)

val condId = tw.getFreshIdLabel<DbWhenbranchcondition>()
tw.writeStmts_whenbranchcondition(condId, bId, 0, callable)
tw.writeHasLocation(condId, bLocId)
tw.writeWhen_branch_condition_with_expr(condId)
extractExpressionExpr(b.condition, callable, condId, 0, condId)
extractExpressionStmt(b.result, callable, bId, 1)
if (b is IrElseBranch) {
tw.writeWhen_branch_else(bId)
Expand Down
120 changes: 94 additions & 26 deletions java/kotlin-extractor2/src/main/kotlin/entities/Expression.kt
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,10 @@ private fun KotlinFileExtractor.extractExpression(
return extractIf(e, parent, callable)
}

is KtWhenExpression -> {
return extractWhen(e, parent, callable)
}

is KtWhileExpression -> {
extractLoopWithCondition(e, parent, callable)
}
Expand Down Expand Up @@ -1160,32 +1164,6 @@ private fun KotlinFileExtractor.extractExpression(
return
}

val exprParent = parent.expr(e, callable)
val id = tw.getFreshIdLabel<DbWhenexpr>()
val type = useType(e.type)
val locId = tw.getLocation(e)
tw.writeExprs_whenexpr(
id,
type.javaResult.id,
exprParent.parent,
exprParent.idx
)
tw.writeExprsKotlinType(id, type.kotlinResult.id)
extractExprContext(id, locId, callable, exprParent.enclosingStmt)
if (e.origin == IrStatementOrigin.IF) {
tw.writeWhen_if(id)
}
e.branches.forEachIndexed { i, b ->
val bId = tw.getFreshIdLabel<DbWhenbranch>()
val bLocId = tw.getLocation(b)
tw.writeStmts_whenbranch(bId, id, i, callable)
tw.writeHasLocation(bId, bLocId)
extractExpressionExpr(b.condition, callable, bId, 0, bId)
extractExpressionStmt(b.result, callable, bId, 1)
if (b is IrElseBranch) {
tw.writeWhen_branch_else(bId)
}
}
}
is IrGetClass -> {
val exprParent = parent.expr(e, callable)
Expand Down Expand Up @@ -1737,6 +1715,96 @@ private fun KotlinFileExtractor.extractLoop(
return id
}

context(KaSession)
private fun KotlinFileExtractor.extractWhen(
e: KtWhenExpression,
parent: StmtExprParent,
callable: Label<out DbCallable>
): Label<out DbExpr>? {
val exprParent = parent.expr(e, callable)
val id = tw.getFreshIdLabel<DbWhenexpr>()
val type = useType(e.expressionType)
val locId = tw.getLocation(e)
tw.writeExprs_whenexpr(
id,
type.javaResult.id,
exprParent.parent,
exprParent.idx
)
tw.writeExprsKotlinType(id, type.kotlinResult.id)
extractExprContext(id, locId, callable, exprParent.enclosingStmt)

val subjectVariable = e.subjectVariable
val subjectExpression = e.subjectExpression

if (subjectVariable != null) {
extractVariableExpr(subjectVariable, callable, id, -1, exprParent.enclosingStmt)
} else if (subjectExpression != null) {
extractExpressionExpr(subjectExpression, callable, id, -1, exprParent.enclosingStmt)
}

e.entries.forEachIndexed { i, b ->
val bId = tw.getFreshIdLabel<DbWhenbranch>()
val bLocId = tw.getLocation(b)
tw.writeStmts_whenbranch(bId, id, i, callable)
tw.writeHasLocation(bId, bLocId)
for ((idx, cond) in b.conditions.withIndex()) {
val condId = tw.getFreshIdLabel<DbWhenbranchcondition>()
val locId = tw.getLocation(cond)
tw.writeStmts_whenbranchcondition(condId, bId, -1 * idx, callable)
tw.writeHasLocation(id, locId)

when (cond) {
is KtWhenConditionWithExpression -> {
tw.writeWhen_branch_condition_with_expr(condId)
extractExpressionExpr(
cond.expression!!,
callable,
condId,
0,
condId
)
}

is KtWhenConditionInRange -> {
// [!]in 1..10
tw.writeWhen_branch_condition_with_range(condId, cond.isNegated)
extractExpressionExpr(
cond.rangeExpression!!,
callable,
condId,
0,
condId
)
}

is KtWhenConditionIsPattern -> {
// [!]is Type
val type = useType(cond.typeReference?.type)
tw.writeWhen_branch_condition_with_pattern(
condId,
cond.isNegated,
type.javaResult.id,
type.kotlinResult.id
)
}
}
}

extractExpressionStmt(b.expression!!, callable, bId, 1)
val guardExpr = b.guard?.getExpression()
if (guardExpr != null) {
extractExpressionStmt(guardExpr, callable, bId, 2)
}

if (b.isElse) {
tw.writeWhen_branch_else(bId)
}
}

return id
}

context(KaSession)
private fun KotlinFileExtractor.extractIf(
ifStmt: KtIfExpression,
Expand Down
20 changes: 20 additions & 0 deletions java/ql/lib/config/semmlecode.dbscheme
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ case @stmt.kind of
| 23 = @yieldstmt
| 24 = @errorstmt
| 25 = @whenbranch
| 26 = @whenbranchcondition
;

#keyset[parent,idx]
Expand Down Expand Up @@ -1047,6 +1048,25 @@ when_if(unique int id: @whenexpr ref);
/** Holds if this `when` branch was written as an `else` branch. */
when_branch_else(unique int id: @whenbranch ref);

/** Holds if this `when` branch condition has an expression. */
when_branch_condition_with_expr(
unique int id: @whenbranchcondition ref
);

/** Holds if this `when` branch condition has a range. */
when_branch_condition_with_range(
unique int id: @whenbranchcondition ref,
boolean isNegated: boolean ref
);

/** Holds if this `when` branch condition has a type pattern. */
when_branch_condition_with_pattern(
unique int id: @whenbranchcondition ref,
boolean isNegated: boolean ref,
int typeid: @type ref,
int kttypeid: @kt_type ref
);

@classinstancexpr = @newexpr | @lambdaexpr | @memberref | @propertyref

@annotation = @declannotation | @typeannotation
Expand Down
125 changes: 122 additions & 3 deletions java/ql/lib/semmle/code/java/Expr.qll
Original file line number Diff line number Diff line change
Expand Up @@ -2566,16 +2566,74 @@ class WhenExpr extends Expr, StmtParent, @whenexpr {
override string getAPrimaryQlClass() { result = "WhenExpr" }

/** Gets the `i`th branch. */
WhenBranch getBranch(int i) { result.isNthChildOf(this, i) }
WhenBranch getBranch(int i) { result.isNthChildOf(this, i) and i >= 0 }

/** Holds if this was written as an `if` expression. */
predicate isIf() { when_if(this) }

/**
* Gets the expression of this `when` expression, if any; such as `foo()` in the below sample.
*
* ```
* when (foo()) {
* 1 -> ...
* 2 -> ...
* }
*/
Expr getExpr() { result.isNthChildOf(this, -1) }

/**
* Gets the local variable declaration of this `when` expression, if any; such as
* `val x = foo()` in the below sample.
*
* ```
* when (val x = foo()) {
* 1 -> ...
* 2 -> ...
* }
* ```
*/
LocalVariableDeclExpr getAVariableDeclExpr() { result.isNthChildOf(this, -1) }
}

/** A Kotlin `when` branch. */
class WhenBranch extends Stmt, @whenbranch {
/** Gets the condition of this branch. */
Expr getCondition() { result.isNthChildOf(this, 0) }
/**
* DEPRECATED: Use `getACondition` or `getCondition/1` instead.
*
* Gets the condition of this branch.
*/
deprecated Expr getCondition() {
result = this.getCondition(0).(WhenBranchConditionWithExpression).getExpression()
}

/**
* Gets the `i`th condition of this branch. The first branch in the below sample has two conditions:
*
* ```
* when (foo()) {
* 1, !in 4..10 -> ...
* 3 -> ...
* }
* ```
*/
WhenBranchCondition getCondition(int i) { i <= 0 and result.isNthChildOf(this, i) }
tamasvajk marked this conversation as resolved.
Show resolved Hide resolved

/** Gets a condition of this branch. */
WhenBranchCondition getACondition() { result = this.getCondition(_) }

/**
* Gets the guard applicable to this branch, if any. Guards are currently experimental Kotlin features.
* In the below sample, the first branch has a guard: `bar() == 42`.
*
* ```
* when (foo()) {
* 1 if bar() == 42 -> ...
* else -> ..
* }
* ```
*/
Expr getGuard() { result.isNthChildOf(this, 2) }

/** Gets the result of this branch. */
Stmt getRhs() { result.isNthChildOf(this, 1) }
Expand All @@ -2594,6 +2652,67 @@ class WhenBranch extends Stmt, @whenbranch {
override string getAPrimaryQlClass() { result = "WhenBranch" }
}

/**
* A Kotlin `when` branch condition. Sample conditions are shown below:
*
* ```
* fun foo(): Number = ...
*
* when (foo()) {
* 1 -> ...
* in 2..10 -> ...
* is Int -> ...
* !is Int -> ...
* }
* ```
*/
abstract class WhenBranchCondition extends Stmt, @whenbranchcondition { }

/** A Kotlin `when` branch condition with an expression. */
class WhenBranchConditionWithExpression extends WhenBranchCondition {
WhenBranchConditionWithExpression() { when_branch_condition_with_expr(this) }

/** Gets the expression of this branch condition. */
Expr getExpression() { result.isNthChildOf(this, 0) }

override string toString() { result = "... ->" }

override string getAPrimaryQlClass() { result = "WhenBranchConditionWithExpression" }
}

/** A Kotlin `when` branch condition with a range. */
class WhenBranchConditionWithRange extends WhenBranchCondition {
WhenBranchConditionWithRange() { when_branch_condition_with_range(this, _) }

/** Holds if this is a negated range condition. */
predicate isNegated() { when_branch_condition_with_range(this, true) }

/**
* Gets the range of this branch condition.
* Ranges are represented by calls to `operator fun <T : Comparable<T>> T.rangeTo(that: T): ClosedRange<T>`.
*/
MethodCall getRange() { result.isNthChildOf(this, 0) }

override string toString() { result = "in ... ->" }

override string getAPrimaryQlClass() { result = "WhenBranchConditionWithRange" }
}

/** A Kotlin `when` branch condition with a pattern. */
class WhenBranchConditionWithPattern extends WhenBranchCondition {
WhenBranchConditionWithPattern() { when_branch_condition_with_pattern(this, _, _, _) }

/** Holds if this is a negated pattern condition. */
predicate isNegated() { when_branch_condition_with_pattern(this, true, _, _) }

/** Gets the type pattern of this branch condition. */
Type getType() { when_branch_condition_with_pattern(this, _, result, _) }

override string toString() { result = "is ... ->" }

override string getAPrimaryQlClass() { result = "WhenBranchConditionWithPattern" }
}

// TODO: This might need more cases. It might be better as a predicate
// on Stmt, overridden in each subclass.
private Expr getAResult(Stmt s) {
Expand Down
Loading