Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6.0] GetVar(inputIndex, varId) for reading context variable from another input #1014

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,8 @@ trait Context {
*/
def getVar[T](id: Byte)(implicit cT: RType[T]): Option[T]

def getVarFromInput[T](inputId: Short, id: Byte)(implicit cT: RType[T]): Option[T]

def vars: Coll[AnyValue]

/** Maximum version of ErgoTree currently activated on the network.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ object ReflectionData {
mkMethod(clazz, "getVar", Array[Class[_]](classOf[Byte], classOf[RType[_]])) { (obj, args) =>
obj.asInstanceOf[Context].getVar(args(0).asInstanceOf[Byte])(args(1).asInstanceOf[RType[_]])
},
mkMethod(clazz, "getVarFromInput", Array[Class[_]](classOf[Short], classOf[Byte], classOf[RType[_]])) { (obj, args) =>
obj.asInstanceOf[Context].getVarFromInput(args(0).asInstanceOf[Byte], args(1).asInstanceOf[Byte])(args(2).asInstanceOf[RType[_]])
},
mkMethod(clazz, "headers", Array[Class[_]]()) { (obj, _) =>
obj.asInstanceOf[Context].headers
}
Expand Down
31 changes: 30 additions & 1 deletion data/shared/src/main/scala/sigma/ast/methods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sigma.ast

import org.ergoplatform._
import org.ergoplatform.validation._
import sigma.Evaluation.stypeToRType
import sigma._
import sigma.ast.SCollection.{SBooleanArray, SBoxArray, SByteArray, SByteArray2, SHeaderArray}
import sigma.ast.SMethod.{MethodCallIrBuilder, MethodCostFunc, javaMethodOf}
Expand Down Expand Up @@ -1418,16 +1419,44 @@ case object SContextMethods extends MonoTypeMethods {
lazy val selfBoxIndexMethod = propertyCall("selfBoxIndex", SInt, 8, FixedCost(JitCost(20)))
lazy val lastBlockUtxoRootHashMethod = property("LastBlockUtxoRootHash", SAvlTree, 9, LastBlockUtxoRootHash)
lazy val minerPubKeyMethod = property("minerPubKey", SByteArray, 10, MinerPubkey)

lazy val getVarMethod = SMethod(
this, "getVar", SFunc(ContextFuncDom, SOption(tT), Array(paramT)), 11, GetVar.costKind)
.withInfo(GetVar, "Get context variable with given \\lst{varId} and type.",
ArgInfo("varId", "\\lst{Byte} identifier of context variable"))

protected override def getMethods() = super.getMethods() ++ Seq(
// todo: costing, desc
lazy val getVarFromInputMethod = SMethod(
this, "getVarFromInput", SFunc(Array(SContext, SShort, SByte), SOption(tT), Array(paramT)), 12, GetVar.costKind, Seq(tT))
.withIRInfo(MethodCallIrBuilder)
.withInfo(MethodCall, "Multiply this number with \\lst{other} by module Q.", ArgInfo("other", "Number to multiply with this."))

def getVarFromInput_eval[T](mc: MethodCall, ctx: sigma.Context, inputId: Short, varId: Byte)
(implicit E: ErgoTreeEvaluator): Option[T] = {
// E.addCost(getVarFromInputMethod.costKind)
val rt = stypeToRType(mc.typeSubst.get(tT).get)
val res = ctx.getVarFromInput(inputId, varId)(rt).asInstanceOf[Option[T]]
res
}

private lazy val v5Methods = super.getMethods() ++ Seq(
dataInputsMethod, headersMethod, preHeaderMethod, inputsMethod, outputsMethod, heightMethod, selfMethod,
selfBoxIndexMethod, lastBlockUtxoRootHashMethod, minerPubKeyMethod, getVarMethod
)

private lazy val v6Methods = super.getMethods() ++ Seq(
dataInputsMethod, headersMethod, preHeaderMethod, inputsMethod, outputsMethod, heightMethod, selfMethod,
selfBoxIndexMethod, lastBlockUtxoRootHashMethod, minerPubKeyMethod, getVarMethod, getVarFromInputMethod
)

protected override def getMethods(): Seq[SMethod] = {
if(VersionContext.current.isV6SoftForkActivated) {
v6Methods
} else {
v5Methods
}
}

/** Names of methods which provide blockchain context.
* This value can be reused where necessary to avoid allocations. */
val BlockchainContextMethodNames: IndexedSeq[String] = Array(
Expand Down
2 changes: 1 addition & 1 deletion data/shared/src/main/scala/sigma/ast/values.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ case class MethodCall(
val objV = obj.evalTo[Any](env)
addCost(MethodCall.costKind) // MethodCall overhead
method.costKind match {
case fixed: FixedCost =>
case fixed: FixedCost if method.explicitTypeArgs.isEmpty =>
val extra = method.extraDescriptors
val extraLen = extra.length
val len = args.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ import sigma.serialization.{SigmaByteReader, SigmaByteWriter, SigmaSerializer}
* @param values internal container of the key-value pairs
*/
case class ContextExtension(values: scala.collection.Map[Byte, EvaluatedValue[_ <: SType]]) {
def add(bindings: VarBinding*): ContextExtension =
def add(bindings: VarBinding*): ContextExtension = {
ContextExtension(values ++ bindings)
}

def get(varId: Byte): Option[EvaluatedValue[_ <: SType]] = values.get(varId)
}

object ContextExtension {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class ErgoLikeContext(val lastBlockUtxoRoot: AvlTreeData,
syntax.error(s"Undefined context property: currentErgoTreeVersion"))
CContext(
dataInputs, headers, preHeader, inputs, outputs, preHeader.height, selfBox, selfIndex, avlTree,
preHeader.minerPk.getEncoded, vars, activatedScriptVersion, ergoTreeVersion)
preHeader.minerPk.getEncoded, vars, spendingTransaction, activatedScriptVersion, ergoTreeVersion)
}


Expand Down
12 changes: 12 additions & 0 deletions interpreter/shared/src/main/scala/sigmastate/eval/CContext.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package sigmastate.eval

import debox.cfor
import org.ergoplatform.{ErgoLikeTransactionTemplate, UnsignedInput}
import sigma.Evaluation.{stypeToRType, toDslTuple}
import sigma.Extensions.ArrayOps
import sigma._
import sigma.ast.SType
import sigma.data._
import sigma.exceptions.InvalidType

Expand All @@ -24,6 +27,7 @@ case class CContext(
lastBlockUtxoRootHash: AvlTree,
_minerPubKey: Coll[Byte],
vars: Coll[AnyValue],
spendingTransaction: ErgoLikeTransactionTemplate[_ <: UnsignedInput],
override val activatedScriptVersion: Byte,
override val currentErgoTreeVersion: Byte
) extends Context {
Expand Down Expand Up @@ -69,6 +73,14 @@ case class CContext(
} else None
}

override def getVarFromInput[T](inputId: Short, id: Byte)(implicit tT: RType[T]): Option[T] = {
spendingTransaction.inputs.unapply(inputId).flatMap(_.extension.get(id)) match {
case Some(v) if stypeToRType[SType](v.tpe) == tT => Some(v.value.asInstanceOf[T])
case _ =>
None
}
}

/** Return a new context instance with variables collection updated.
* @param bindings a new binding of the context variables with new values.
* @return a new instance (if `bindings` non-empty) with the specified bindings.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ trait ContractsTestkit {
new CContext(
noInputs.toColl, noHeaders, dummyPreHeader,
inputs.toColl, outputs.toColl, height, self, inputs.indexOf(self), tree,
minerPk.toColl, vars.toColl, activatedScriptVersion, currErgoTreeVersion)
minerPk.toColl, vars.toColl, null, activatedScriptVersion, currErgoTreeVersion)

def newContext(
height: Int,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sigma.compiler.ir

import org.ergoplatform._
import sigma.ast.SType.tT
import sigma.ast.TypeCodes.LastConstantCode
import sigma.ast.Value.Typed
import sigma.ast.syntax.{SValue, ValueOps}
Expand Down Expand Up @@ -928,7 +929,7 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
sigmaDslBuilder.decodePoint(bytes)

// fallback rule for MethodCall, should be the last case in the list
case sigma.ast.MethodCall(obj, method, args, _) =>
case sigma.ast.MethodCall(obj, method, args, typeSubst) =>
val objV = eval(obj)
val argsV = args.map(eval)
(objV, method.objType) match {
Expand Down Expand Up @@ -1040,6 +1041,11 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
ctx.LastBlockUtxoRootHash
case SContextMethods.minerPubKeyMethod.name =>
ctx.minerPubKey
case SContextMethods.getVarFromInputMethod.name =>
val c1 = asRep[Short](argsV(0))
val c2 = asRep[Byte](argsV(1))
val c3 = stypeToElem(typeSubst.apply(tT))
ctx.getVarFromInput(c1, c2)(c3)
case _ => throwError
}
case (tree: Ref[AvlTree]@unchecked, SAvlTreeMethods) => method.name match {
Expand Down
2 changes: 1 addition & 1 deletion sc/shared/src/main/scala/sigma/compiler/ir/IRContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ trait IRContext
override def invokeUnlifted(e: Elem[_], mc: MethodCall, dataEnv: DataEnv): Any = e match {
case _: CollElem[_,_] => mc match {
case CollMethods.map(_, f) =>
val newMC = mc.copy(args = mc.args :+ f.elem.eRange)(mc.resultType, mc.isAdapterCall)
val newMC = mc.copy(args = mc.args :+ f.elem.eRange)(mc.resultType, mc.isAdapterCall, mc.typeSubst)
super.invokeUnlifted(e, newMC, dataEnv)
case _ =>
super.invokeUnlifted(e, mc, dataEnv)
Expand Down
7 changes: 4 additions & 3 deletions sc/shared/src/main/scala/sigma/compiler/ir/MethodCalls.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sigma.compiler.ir

import debox.{cfor, Buffer => DBuffer}
import sigma.ast.{SType, STypeVar}
import sigma.compiler.DelayInvokeException
import sigma.reflection.RMethod
import sigma.util.CollectionUtil.TraversableOps
Expand All @@ -26,7 +27,7 @@ trait MethodCalls extends Base { self: IRContext =>
* given `method`.
*/
case class MethodCall private[MethodCalls](receiver: Sym, method: RMethod, args: Seq[AnyRef], neverInvoke: Boolean)
(val resultType: Elem[Any], val isAdapterCall: Boolean = false) extends Def[Any] {
(val resultType: Elem[Any], val isAdapterCall: Boolean = false, val typeSubst: Map[STypeVar, SType] = Map()) extends Def[Any] {

override def mirror(t: Transformer): Ref[Any] = {
val len = args.length
Expand Down Expand Up @@ -100,8 +101,8 @@ trait MethodCalls extends Base { self: IRContext =>

/** Creates new MethodCall node and returns its node ref. */
def mkMethodCall(receiver: Sym, method: RMethod, args: Seq[AnyRef],
neverInvoke: Boolean, isAdapterCall: Boolean, resultElem: Elem[_]): Sym = {
reifyObject(MethodCall(receiver, method, args, neverInvoke)(asElem[Any](resultElem), isAdapterCall))
neverInvoke: Boolean, isAdapterCall: Boolean, resultElem: Elem[_], typeSubst: Map[STypeVar, SType] = Map.empty): Sym = {
reifyObject(MethodCall(receiver, method, args, neverInvoke)(asElem[Any](resultElem), isAdapterCall, typeSubst))
}

@tailrec
Expand Down
7 changes: 4 additions & 3 deletions sc/shared/src/main/scala/sigma/compiler/ir/TreeBuilding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,14 @@ trait TreeBuilding extends Base { IR: IRContext =>
mkMultiplyGroup(obj.asGroupElement, arg.asGroupElement)

// Fallback MethodCall rule: should be the last in this list of cases
case Def(MethodCall(objSym, m, argSyms, _)) =>
case Def(mc @ MethodCall(objSym, m, argSyms, _)) =>
val obj = recurse[SType](objSym)
val args = argSyms.collect { case argSym: Sym => recurse[SType](argSym) }
MethodsContainer.getMethod(obj.tpe, m.getName) match {
case Some(method) =>
val specMethod = method.specializeFor(obj.tpe, args.map(_.tpe))
builder.mkMethodCall(obj, specMethod, args.toIndexedSeq, Map())
val typeSubst = mc.typeSubst
val specMethod = method.specializeFor(obj.tpe, args.map(_.tpe)).withConcreteTypes(typeSubst)
builder.mkMethodCall(obj, specMethod, args.toIndexedSeq, typeSubst)
case None =>
error(s"Cannot find method ${m.getName} in object $obj")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ import scalan._
def preHeader: Ref[PreHeader];
def minerPubKey: Ref[Coll[Byte]];
def getVar[T](id: Ref[Byte])(implicit cT: Elem[T]): Ref[WOption[T]];
def getVarFromInput[T](inputId: Ref[Short], id: Ref[Byte])(implicit cT: Elem[T]): Ref[WOption[T]];
};
trait SigmaDslBuilder extends Def[SigmaDslBuilder] {
def Colls: Ref[CollBuilder];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import sigma.compiler.ir.wrappers.sigma.impl.SigmaDslDefs
import scala.collection.compat.immutable.ArraySeq

package impl {
import sigma.Evaluation
import sigma.ast.SType.tT
import sigma.compiler.ir.meta.ModuleInfo
import sigma.compiler.ir.wrappers.sigma.SigmaDsl
import sigma.compiler.ir.{Base, GraphIRReflection, IRContext}
Expand Down Expand Up @@ -1614,6 +1616,14 @@ object Context extends EntityObject("Context") {
true, false, element[WOption[T]]))
}

override def getVarFromInput[T](inputId: Ref[Short], varId: Ref[Byte])(implicit cT: Elem[T]): Ref[WOption[T]] = {
val st = Evaluation.rtypeToSType(cT.sourceType)
asRep[WOption[T]](mkMethodCall(self,
ContextClass.getMethod("getVarFromInput", classOf[Sym], classOf[Sym], classOf[Elem[_]]),
Array[AnyRef](inputId, varId, cT),
true, false, element[WOption[T]], Map(tT -> st)))
}

}

implicit object LiftableContext
Expand Down Expand Up @@ -1710,6 +1720,14 @@ object Context extends EntityObject("Context") {
Array[AnyRef](id, cT),
true, true, element[WOption[T]]))
}

def getVarFromInput[T](inputId: Ref[Short], varId: Ref[Byte])(implicit cT: Elem[T]): Ref[WOption[T]] = {
val st = Evaluation.rtypeToSType(cT.sourceType)
asRep[WOption[T]](mkMethodCall(source,
ContextClass.getMethod("getVarFromInput", classOf[Sym], classOf[Sym], classOf[Elem[_]]),
Array[AnyRef](inputId, varId, cT),
true, true, element[WOption[T]], Map(tT -> st)))
}
}

// entityUnref: single unref method for each type family
Expand All @@ -1727,7 +1745,7 @@ object Context extends EntityObject("Context") {
override protected def collectMethods: Map[RMethod, MethodDesc] = {
super.collectMethods ++
Elem.declaredMethods(RClass(classOf[Context]), RClass(classOf[SContext]), Set(
"OUTPUTS", "INPUTS", "dataInputs", "HEIGHT", "SELF", "selfBoxIndex", "LastBlockUtxoRootHash", "headers", "preHeader", "minerPubKey", "getVar", "vars"
"OUTPUTS", "INPUTS", "dataInputs", "HEIGHT", "SELF", "selfBoxIndex", "LastBlockUtxoRootHash", "headers", "preHeader", "minerPubKey", "getVar", "getVarFromInput", "vars"
))
}
}
Expand Down
17 changes: 16 additions & 1 deletion sc/shared/src/main/scala/sigma/compiler/phases/SigmaTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,18 @@ class SigmaTyper(val builder: SigmaBuilder,
res

case Apply(ApplyTypes(sel @ Select(obj, n, _), Seq(rangeTpe)), args) =>
val nArgs = if (n == SContextMethods.getVarFromInputMethod.name &&
args.length == 2 &&
args(0).isInstanceOf[Constant[_]] &&
args(1).isInstanceOf[Constant[_]] &&
args(0).tpe.isNumType &&
args(1).tpe.isNumType) {
IndexedSeq(ShortConstant(SShort.downcast(args(0).asInstanceOf[Constant[SNumericType]].value.asInstanceOf[AnyVal])).withSrcCtx(args(0).sourceContext),
ByteConstant(SByte.downcast(args(1).asInstanceOf[Constant[SNumericType]].value.asInstanceOf[AnyVal])).withSrcCtx(args(1).sourceContext))
} else args

val newObj = assignType(env, obj)
val newArgs = args.map(assignType(env, _))
val newArgs = nArgs.map(assignType(env, _))
obj.tpe match {
case p: SProduct =>
MethodsContainer.getMethod(p, n) match {
Expand Down Expand Up @@ -221,6 +231,11 @@ class SigmaTyper(val builder: SigmaBuilder,
case (Ident(GetVarFunc.name | ExecuteFromVarFunc.name, _), Seq(id: Constant[SNumericType]@unchecked))
if id.tpe.isNumType =>
Seq(ByteConstant(SByte.downcast(id.value.asInstanceOf[AnyVal])).withSrcCtx(id.sourceContext))
case (Ident(SContextMethods.getVarFromInputMethod.name, _),
Seq(inputId: Constant[SNumericType]@unchecked, varId: Constant[SNumericType]@unchecked))
if inputId.tpe.isNumType && varId.tpe.isNumType =>
Seq(ShortConstant(SShort.downcast(inputId.value.asInstanceOf[AnyVal])).withSrcCtx(inputId.sourceContext),
ByteConstant(SByte.downcast(varId.value.asInstanceOf[AnyVal])).withSrcCtx(varId.sourceContext))
case _ => typedArgs
}
val actualTypes = adaptedTypedArgs.map(_.tpe)
Expand Down
1 change: 1 addition & 0 deletions sc/shared/src/test/scala/sigma/SigmaDslSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4762,6 +4762,7 @@ class SigmaDslSpecification extends SigmaDslTesting
.append(Coll[AnyValue](
CAnyValue(Helpers.decodeBytes("00")),
CAnyValue(true))),
spendingTransaction = null,
activatedScriptVersion = activatedVersionInTests,
currentErgoTreeVersion = ergoTreeVersionInTests
)
Expand Down
Loading
Loading