Skip to content

Commit

Permalink
Merge pull request #793 from ScorexFoundation/fix-i778
Browse files Browse the repository at this point in the history
Fix possible race condition in PrecompiledScriptReducer
  • Loading branch information
aslesarenko authored May 4, 2022
2 parents 44253ac + 60c5e7b commit 727b7b7
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 24 deletions.
47 changes: 37 additions & 10 deletions core/src/main/scala/scalan/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,14 @@ abstract class Base { scalan: Scalan =>
cfor(0)(_ < delta, _ + 1) { _ => tab.append(null) }
val sym = if (s == null) new SingleRef(d) else s
tab += sym
assert(tab.length == id + 1)
assert(tab.length == id + 1,
s"""tab.length == id + 1:
|tab.length = ${tab.length}
|id = $id
|s = $s
|d = $d
|sym = $sym
|""".stripMargin)
sym
}
}
Expand Down Expand Up @@ -781,7 +788,6 @@ abstract class Base { scalan: Scalan =>
if (sym == null) {
sym = createDefinition(optScope, newSym, d)
}
// assert(te.rhs == d, s"${if (te !) "Found" else "Created"} unequal definition ${te.rhs} with symbol ${te.sym.toStringWithType} for $d")
sym
}

Expand All @@ -792,15 +798,36 @@ abstract class Base { scalan: Scalan =>
* @return reference to `d` (which is `s`)
*/
protected def createDefinition[T](optScope: Nullable[ThunkScope], s: Ref[T], d: Def[T]): Ref[T] = {
assert(_symbolTable(d.nodeId).node.nodeId == d.nodeId)
assert(s.node eq d, s"Inconsistent Sym -> Def pair $s -> $d")
optScope match {
case Nullable(scope) =>
scope += s
case _ =>
_globalDefs.put(d, d)
try {
val nodeId = d.nodeId
val tableSym = _symbolTable(nodeId)
assert(tableSym.node.nodeId == nodeId)
assert(s.node eq d, s"Inconsistent Sym -> Def pair $s -> $d")
optScope match {
case Nullable(scope) =>
scope += s
case _ =>
_globalDefs.put(d, d)
}
s
} catch { case t: Throwable =>
val msg = new mutable.StringBuilder(
s"""optScope = $optScope
|s = $s
|d = $d""".stripMargin)
if (d != null) {
msg ++= s"\nd.nodeId = ${d.nodeId}"
val tableSym = _symbolTable(d.nodeId)
msg ++= s"\n_symbolTable(d.nodeId) = $tableSym"
if (tableSym != null) {
msg ++= s"\ntableSym.node = ${tableSym.node}"
if (tableSym.node != null) {
msg ++= s"\ntableSym.node.nodeId = ${tableSym.node.nodeId}"
}
}
}
throw new RuntimeException(msg.result(), t)
}
s
}

/**
Expand Down
17 changes: 12 additions & 5 deletions sigmastate/src/main/scala/sigmastate/eval/IRContext.scala
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
package sigmastate.eval

import java.lang.{Math => JMath}
import sigmastate.SType
import sigmastate.Values.{Value, SValue}
import sigmastate.Values.{SValue, Value}
import sigmastate.interpreter.Interpreter.ScriptEnv
import sigmastate.lang.TransformingSigmaBuilder
import sigmastate.lang.exceptions.CostLimitException
import sigmastate.utils.Helpers
import sigmastate.utxo.CostTable

import java.util.concurrent.locks.ReentrantLock
import scala.util.Try

trait IRContext extends Evaluation with TreeBuilding {

override val builder = TransformingSigmaBuilder

/** Can be used to synchronize access to this IR object from multiple threads. */
val lock = new ReentrantLock()

/** Pass configuration which is used to turn-off constant propagation.
* @see `beginPass(noCostPropagationPass)` */
lazy val noConstPropagationPass = new DefaultPass(
Expand Down Expand Up @@ -125,9 +129,12 @@ trait IRContext extends Evaluation with TreeBuilding {
*/
def checkCostWithContext(ctx: SContext,
costF: Ref[((Context, (Int, Size[Context]))) => Int], maxCost: Long, initCost: Long): Try[Int] = Try {
val costFun = compile[(SContext, (Int, SSize[SContext])), Int, (Context, (Int, Size[Context])), Int](
getDataEnv, costF, Some(maxCost))
val (estimatedCost, accCost) = costFun((ctx, (0, Sized.sizeOf(ctx))))

val (estimatedCost, accCost) = Helpers.withReentrantLock(lock) { // protect mutable access to this IR
val costFun = compile[(SContext, (Int, SSize[SContext])), Int, (Context, (Int, Size[Context])), Int](
getDataEnv, costF, Some(maxCost))
costFun((ctx, (0, Sized.sizeOf(ctx))))
}

if (debugModeSanityChecks) {
if (estimatedCost != accCost)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ package sigmastate.interpreter

import java.util.concurrent.ExecutionException
import java.util.concurrent.atomic.AtomicInteger

import com.google.common.cache.{CacheBuilder, RemovalNotification, RemovalListener, LoadingCache, CacheLoader, CacheStats}
import com.google.common.cache.{CacheBuilder, CacheLoader, CacheStats, LoadingCache, RemovalListener, RemovalNotification}
import org.ergoplatform.settings.ErgoAlgos
import org.ergoplatform.validation.SigmaValidationSettings
import org.ergoplatform.validation.ValidationRules.{CheckCostFunc, CheckCalcFunc, trySoftForkable}
import org.ergoplatform.validation.ValidationRules.{CheckCalcFunc, CheckCostFunc, trySoftForkable}
import scalan.{AVHashMap, Nullable}
import sigmastate.Values
import sigmastate.Values.ErgoTree
import sigmastate.eval.{RuntimeIRContext, IRContext}
import sigmastate.eval.{IRContext, RuntimeIRContext}
import sigmastate.interpreter.Interpreter.{ReductionResult, WhenSoftForkReductionResult}
import sigmastate.serialization.ErgoTreeSerializer
import sigmastate.utils.Helpers
import sigmastate.utils.Helpers._
import spire.syntax.all.cfor

Expand Down Expand Up @@ -84,9 +84,11 @@ case class PrecompiledScriptReducer(scriptBytes: Seq[Byte])(implicit val IR: IRC
val estimatedCost = IR.checkCostWithContext(costingCtx, costF, maxCost, initCost).getOrThrow

// check calc
val calcF = costingRes.calcF
val calcCtx = context.toSigmaContext(isCost = false)
val res = Interpreter.calcResult(IR)(calcCtx, calcF)
val res = Helpers.withReentrantLock(IR.lock) { // protecting mutable access to IR instance
val calcF = costingRes.calcF
Interpreter.calcResult(IR)(calcCtx, calcF)
}
ReductionResult(SigmaDsl.toSigmaBoolean(res), estimatedCost)
}
}
Expand Down
27 changes: 24 additions & 3 deletions sigmastate/src/main/scala/sigmastate/utils/Helpers.scala
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package sigmastate.utils

import java.util

import io.circe.Decoder
import org.ergoplatform.settings.ErgoAlgos
import sigmastate.eval.{Colls, SigmaDsl}
import sigmastate.interpreter.CryptoConstants.EcPointType
import special.collection.Coll
import special.sigma.GroupElement

import java.util
import java.util.concurrent.locks.Lock
import scala.reflect.ClassTag
import scala.util.{Failure, Try, Either, Success, Right}
import scala.util.{Either, Failure, Right, Success, Try}

object Helpers {
def xor(ba1: Array[Byte], ba2: Array[Byte]): Array[Byte] = ba1.zip(ba2).map(t => (t._1 ^ t._2).toByte)
Expand Down Expand Up @@ -158,6 +158,27 @@ object Helpers {
val bytes = ErgoAlgos.decodeUnsafe(base16String)
Colls.fromArray(bytes)
}

/**
* Executes the given block with a reentrant mutual exclusion Lock with the same basic
* behavior and semantics as the implicit monitor lock accessed using synchronized
* methods and statements in Java.
*
* Note, using this method has an advantage of having this method in a stack trace in case of
* an exception in the block.
* @param l lock object which should be acquired by the current thread before block can start executing
* @param block block of code which will be executed retaining the lock
* @return the value produced by the block
*/
def withReentrantLock[A](l: Lock)(block: => A): A = {
l.lock()
val res = try
block
finally {
l.unlock()
}
res
}
}

object Overloading {
Expand Down

0 comments on commit 727b7b7

Please sign in to comment.