diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala index cc2df9921b11..8f907ad780cf 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstCreator.scala @@ -215,7 +215,7 @@ class AstCreator( val node = createIdentifierWithScope(ctx, varSymbol.getText, varSymbol.getText, Defines.Any, List(Defines.Any)) Seq(Ast(node)) case _ => - logger.error("astForSingleLeftHandSideContext() All contexts mismatched.") + logger.error(s"astForSingleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -340,7 +340,7 @@ class AstCreator( case ctx: ChainedInvocationWithoutArgumentsPrimaryContext => astForChainedInvocationWithoutArgumentsPrimaryContext(ctx) case _ => - logger.error("astForPrimaryContext() All contexts mismatched.") + logger.error(s"astForPrimaryContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -364,7 +364,7 @@ class AstCreator( case ctx: MultipleAssignmentExpressionContext => astForMultipleAssignmentExpressionContext(ctx) case ctx: IsDefinedExpressionContext => Seq(astForIsDefinedExpression(ctx)) case _ => - logger.error("astForExpressionContext() All contexts mismatched.") + logger.error(s"astForExpressionContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -415,7 +415,7 @@ class AstCreator( case ctx: RubyParser.SplattingOnlyIndexingArgumentsContext => astForSplattingArgumentContext(ctx.splattingArgument()) case _ => - logger.error("astForIndexingArgumentsContext() All contexts mismatched.") + logger.error(s"astForIndexingArgumentsContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -627,7 +627,7 @@ class AstCreator( case ctx: GroupedLeftHandSideOnlyMultipleLeftHandSideContext => astForGroupedLeftHandSideContext(ctx.groupedLeftHandSide()) case _ => - logger.error("astForMultipleLeftHandSideContext() All contexts mismatched.") + logger.error(s"astForMultipleLeftHandSideContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -735,7 +735,7 @@ class AstCreator( .withChildren(astForArguments(ctx.arguments())) ) case _ => - logger.error("astForInvocationWithoutParenthesesContext() All contexts mismatched.") + logger.error(s"astForInvocationWithoutParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -973,7 +973,7 @@ class AstCreator( case ctx: SimpleMethodNamePartContext => astForSimpleMethodNamePartContext(ctx) case ctx: SingletonMethodNamePartContext => astForSingletonMethodNamePartContext(ctx) case _ => - logger.error("astForMethodNamePartContext() All contexts mismatched.") + logger.error(s"astForMethodNamePartContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -1050,7 +1050,7 @@ class AstCreator( } def astForBodyStatementContext(ctx: BodyStatementContext, addReturnNode: Boolean = false): Seq[Ast] = { - val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement()) + val compoundStatementAsts = astForCompoundStatement(ctx.compoundStatement(), !addReturnNode) val compoundStatementAstsWithReturn = if (addReturnNode && compoundStatementAsts.size > 0) { @@ -1320,7 +1320,7 @@ class AstCreator( val primaryAsts = astForPrimaryContext(ctx.primary()) primaryAsts ++ methodNameAsts ++ argsAsts ++ doBlockAsts case _ => - logger.error("astForCommandWithDoBlockContext() All contexts mismatched.") + logger.error(s"astForCommandWithDoBlockContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } @@ -1354,7 +1354,7 @@ class AstCreator( case ctx: ChainedCommandWithDoBlockOnlyArgumentsWithParenthesesContext => astForChainedCommandWithDoBlockContext(ctx.chainedCommandWithDoBlock()) case _ => - logger.error("astForArgumentsWithParenthesesContext() All contexts mismatched.") + logger.error(s"astForArgumentsWithParenthesesContext() $filename, ${ctx.getText} All contexts mismatched.") Seq(Ast()) } diff --git a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala index 4e0803b7bf23..986ab76feac0 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/main/scala/io/joern/rubysrc2cpg/astcreation/AstForStatementsCreator.scala @@ -7,6 +7,7 @@ import io.joern.x2cpg.Ast import io.joern.x2cpg.Defines.DynamicCallUnknownFullName import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.{NewBlock, NewCall, NewControlStructure, NewImport, NewLiteral} +import org.slf4j.LoggerFactory import org.antlr.v4.runtime.ParserRuleContext import scala.jdk.CollectionConverters.CollectionHasAsScala @@ -14,6 +15,7 @@ import scala.jdk.CollectionConverters.CollectionHasAsScala trait AstForStatementsCreator { this: AstCreator => + private val logger = LoggerFactory.getLogger(this.getClass) protected def astForAliasStatement(ctx: AliasStatementContext): Ast = { val aliasName = ctx.definedMethodNameOrSymbol(0).getText.substring(1) val methodName = ctx.definedMethodNameOrSymbol(1).getText.substring(1) @@ -80,9 +82,13 @@ trait AstForStatementsCreator { controlStructureAst(throwNode, rhs.headOption, lhs) } - protected def astForCompoundStatement(ctx: CompoundStatementContext): Seq[Ast] = { + protected def astForCompoundStatement(ctx: CompoundStatementContext, packInBlock: Boolean = true): Seq[Ast] = { val stmtAsts = Option(ctx.statements()).map(astForStatements).getOrElse(Seq()) - Seq(blockAst(blockNode(ctx), stmtAsts.toList)) + if (packInBlock) { + Seq(blockAst(blockNode(ctx), stmtAsts.toList)) + } else { + stmtAsts + } } protected def astForStatements(ctx: StatementsContext): Seq[Ast] = { @@ -110,7 +116,9 @@ trait AstForStatementsCreator { case ctx: NotExpressionOrCommandContext => Seq(astForNotKeywordExpressionOrCommand(ctx)) case ctx: OrAndExpressionOrCommandContext => Seq(astForOrAndExpressionOrCommand(ctx)) case ctx: ExpressionExpressionOrCommandContext => astForExpressionContext(ctx.expression()) - case _ => Seq(Ast()) + case _ => + logger.error(s"astForExpressionOrCommand() $filename, ${ctx.getText} All contexts mismatched.") + Seq(Ast()) } protected def astForNotKeywordExpressionOrCommand(ctx: NotExpressionOrCommandContext): Ast = { diff --git a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala index bc2d41c65424..f5b064293a9b 100644 --- a/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala +++ b/joern-cli/frontends/rubysrc2cpg/src/test/scala/io/joern/rubysrc2cpg/dataflow/DataFlowTests.scala @@ -944,6 +944,29 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { } } + "Data flow for begin/rescue with sink in else" should { + val cpg = code(""" + |x = 1 + |begin + | puts "In begin" + |rescue SomeException + | puts "SomeException occurred" + |rescue => exceptionVar + | puts "Caught exception in variable #{exceptionVar}" + |rescue + | puts "Catch-all block" + |else + | puts x + |end + |""".stripMargin) + + "find flows to the sink" in { + val source = cpg.identifier.name("x").l + val sink = cpg.call.name("puts").l + sink.reachableByFlows(source).size shouldBe 2 + } + } + "Data flow for begin/rescue with sink in rescue" should { val cpg = code(""" |x = 1 @@ -1010,6 +1033,70 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { } } + "Data flow for begin/rescue with sink in ensure" should { + val cpg = code(""" + |x = 1 + |begin + | puts "in begin" + |rescue SomeException + | puts "SomeException occurred" + |rescue => exceptionVar + | puts "Caught exception in variable #{exceptionVar}" + |rescue + | puts "In rescue all" + |ensure + | puts x + |end + | + |""".stripMargin) + + "find flows to the sink" in { + val source = cpg.identifier.name("x").l + val sink = cpg.call.name("puts").l + sink.reachableByFlows(source).size shouldBe 2 + } + } + + // parsing issue. comment out when fixed + "Data flow for begin/rescue with data flow through the exception" ignore { + val cpg = code(""" + |x = "Exception message: " + |begin + |1/0 + |rescue ZeroDivisionError => e + | y = x + e.message + | puts y + |end + | + |""".stripMargin) + + "find flows to the sink" in { + val source = cpg.identifier.name("x").l + val sink = cpg.call.name("puts").l + sink.reachableByFlows(source).size shouldBe 2 + } + } + + "Data flow for begin/rescue with data flow through block with multiple exceptions being caught" should { + val cpg = code(""" + |x = 1 + |y = 10 + |begin + |1/0 + |rescue SystemCallError, ZeroDivisionError + | y = x + 100 + |end + | + |puts y + |""".stripMargin) + + "find flows to the sink" in { + val source = cpg.identifier.name("x").l + val sink = cpg.call.name("puts").l + sink.reachableByFlows(source).size shouldBe 2 + } + } + "Data flow for begin/rescue with sink in function without begin" ignore { val cpg = code(""" |def foo(arg)