Skip to content

Commit

Permalink
moving to Global, finalizing costing
Browse files Browse the repository at this point in the history
  • Loading branch information
kushti committed Jul 31, 2024
1 parent aaa2aa9 commit a20c04f
Show file tree
Hide file tree
Showing 13 changed files with 87 additions and 95 deletions.
14 changes: 12 additions & 2 deletions core/shared/src/main/scala/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ trait BigInt {
*/
def or(that: BigInt): BigInt
def |(that: BigInt): BigInt = or(that)

def nbits: Long
}

/** Base class for points on elliptic curves. */
Expand Down Expand Up @@ -697,6 +695,18 @@ trait SigmaDslBuilder {
*/
def groupGenerator: GroupElement

/**
* @return big integer provided as input approximately encoded using NBits,
* see (https://bitcoin.stackexchange.com/questions/57184/what-does-the-nbits-value-represent)
* for format details
*/
def encodeNbits(bi: BigInt): Long

/**
* @return big integer decoded from NBits value provided,
* see (https://bitcoin.stackexchange.com/questions/57184/what-does-the-nbits-value-represent)
* for format details
*/
def decodeNbits(l: Long): BigInt

/**
Expand Down
3 changes: 0 additions & 3 deletions core/shared/src/main/scala/sigma/data/CBigInt.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package sigma.data

import sigma.util.Extensions.BigIntegerOps
import sigma.util.NBitsUtils
import sigma.{BigInt, Coll, Colls}

import java.math.BigInteger
Expand Down Expand Up @@ -50,6 +49,4 @@ case class CBigInt(override val wrappedValue: BigInteger) extends BigInt with Wr
override def and(that: BigInt): BigInt = CBigInt(wrappedValue.and(that.asInstanceOf[CBigInt].wrappedValue))

override def or(that: BigInt): BigInt = CBigInt(wrappedValue.or(that.asInstanceOf[CBigInt].wrappedValue))

override def nbits: Long = NBitsUtils.encodeCompactBits(wrappedValue)
}
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ object ReflectionData {
},
mkMethod(clazz, "divide", paramTypes) { (obj, args) =>
obj.asInstanceOf[BigInt].divide(args(0).asInstanceOf[BigInt])
},
mkMethod(clazz, "nbits", paramTypes) { (obj, _) =>
obj.asInstanceOf[BigInt].nbits
}
)
)
Expand Down Expand Up @@ -447,6 +444,9 @@ object ReflectionData {
mkMethod(clazz, "decodePoint", Array[Class[_]](cColl)) { (obj, args) =>
obj.asInstanceOf[SigmaDslBuilder].decodePoint(args(0).asInstanceOf[Coll[Byte]])
},
mkMethod(clazz, "encodeNbits", Array[Class[_]](cColl)) { (obj, args) =>
obj.asInstanceOf[SigmaDslBuilder].encodeNbits(args(0).asInstanceOf[BigInt])
},
mkMethod(clazz, "decodeNbits", Array[Class[_]](cColl)) { (obj, args) =>
obj.asInstanceOf[SigmaDslBuilder].decodeNbits(args(0).asInstanceOf[Long])
}
Expand Down
73 changes: 34 additions & 39 deletions data/shared/src/main/scala/sigma/ast/methods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,13 @@ case object SIntMethods extends SNumericTypeMethods {
case object SLongMethods extends SNumericTypeMethods {
/** Type for which this container defines methods. */
override def ownerType: SMonoType = SLong

protected override def getMethods(): Seq[SMethod] = super.getMethods()

}

/** Methods of BigInt type. Implemented using [[java.math.BigInteger]]. */
case object SBigIntMethods extends SNumericTypeMethods {
/** Type for which this container defines methods. */
override def ownerType: SMonoType = SBigInt

final val ToNBitsCostInfo = OperationCostInfo(
FixedCost(JitCost(5)), NamedDesc("NBitsMethodCall"))

//id = 8 to make it after toBits
val ToNBits = SMethod(this, "nbits", SFunc(this.ownerType, SLong), 8, ToNBitsCostInfo.costKind)
.withInfo(ModQ, "Encode this big integer value as NBits")

/** The following `modQ` methods are not fully implemented in v4.x and this descriptors.
* This descritors are remain here in the code and are waiting for full implementation
* is upcoming soft-forks at which point the cost parameters should be calculated and
Expand All @@ -337,7 +327,7 @@ case object SBigIntMethods extends SNumericTypeMethods {

protected override def getMethods(): Seq[SMethod] = {
if (VersionContext.current.isV6SoftForkActivated) {
super.getMethods() ++ Seq(ToNBits)
super.getMethods()
// ModQMethod,
// PlusModQMethod,
// MinusModQMethod,
Expand All @@ -347,14 +337,6 @@ case object SBigIntMethods extends SNumericTypeMethods {
super.getMethods()
}
}

/**
*
*/
def nbits_eval(mc: MethodCall, bi: sigma.BigInt)(implicit E: ErgoTreeEvaluator): Long = {
E.nbits(mc, bi)
}

}

/** Methods of type `String`. */
Expand Down Expand Up @@ -1523,37 +1505,50 @@ case object SGlobalMethods extends MonoTypeMethods {
Xor.xorWithCosting(ls, rs)
}

private lazy val EnDecodeNBitsCost = FixedCost(JitCost(5)) // the same cost for nbits encoding and decoding

lazy val encodeNBitsMethod: SMethod = SMethod(
this, "encodeNbits", SFunc(Array(SGlobal, SBigInt), SLong), 3, EnDecodeNBitsCost)
.withIRInfo(MethodCallIrBuilder)
.withInfo(MethodCall, "Encode big integer number as nbits", ArgInfo("bigInt", "Big integer"))

lazy val decodeNBitsMethod: SMethod = SMethod(
this, "decodeNbits", SFunc(Array(SGlobal, SLong), SBigInt), 3, FixedCost(JitCost(5)))
this, "decodeNbits", SFunc(Array(SGlobal, SLong), SBigInt), 4, EnDecodeNBitsCost)
.withIRInfo(MethodCallIrBuilder)
.withInfo(Xor, "Byte-wise XOR of two collections of bytes", ArgInfo("left", "left operand"))
.withInfo(MethodCall, "Decode nbits-encoded big integer number", ArgInfo("nbits", "NBits-encoded argument"))

/**
*
* encodeNBits evaluation with costing
*/
def encodeNbits_eval(mc: MethodCall, G: SigmaDslBuilder, bigInt: BigInt)(implicit E: ErgoTreeEvaluator): Long = {
E.addFixedCost(EnDecodeNBitsCost, encodeNBitsMethod.opDesc) {
NBitsUtils.encodeCompactBits(bigInt.asInstanceOf[CBigInt].wrappedValue)
}
}

/**
* decodeNBits evaluation with costing
*/
def decodeNbits_eval(mc: MethodCall, G: SigmaDslBuilder, l: Long)(implicit E: ErgoTreeEvaluator): BigInt = {
CBigInt(NBitsUtils.decodeCompactBits(l).bigInteger) // todo: costing is ignored here
E.addFixedCost(EnDecodeNBitsCost, decodeNBitsMethod.opDesc) {
CBigInt(NBitsUtils.decodeCompactBits(l).bigInteger)
}
}

{
protected override def getMethods() = {
if (VersionContext.current.isV6SoftForkActivated) {
super.getMethods() ++ Seq(decodeNBitsMethod)
super.getMethods() ++ Seq(
groupGeneratorMethod,
xorMethod,
encodeNBitsMethod,
decodeNBitsMethod
)
} else {
super.getMethods()
super.getMethods() ++ Seq(
groupGeneratorMethod,
xorMethod
)
}
}

protected override def getMethods() = if (VersionContext.current.isV6SoftForkActivated) {
super.getMethods() ++ Seq(
groupGeneratorMethod,
xorMethod,
decodeNBitsMethod
)
} else {
super.getMethods() ++ Seq(
groupGeneratorMethod,
xorMethod
)
}
}

4 changes: 4 additions & 0 deletions data/shared/src/main/scala/sigma/data/CSigmaDslBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ class CSigmaDslBuilder extends SigmaDslBuilder { dsl =>

override def groupGenerator: GroupElement = _generatorElement

def encodeNbits(bi: BigInt): Long = {
NBitsUtils.encodeCompactBits(bi.asInstanceOf[CBigInt].wrappedValue)
}

def decodeNbits(l: Long): BigInt = {
CBigInt(NBitsUtils.decodeCompactBits(l).bigInteger)
}
Expand Down
3 changes: 0 additions & 3 deletions data/shared/src/main/scala/sigma/eval/ErgoTreeEvaluator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,6 @@ abstract class ErgoTreeEvaluator {
/** Represents blockchain data context for ErgoTree evaluation. */
def context: Context

/** Implements evaluation of BigInt.nbits method call ErgoTree node. */
def nbits(mc: MethodCall, bi: sigma.BigInt): Long

/** Create an instance of [[AvlTreeVerifier]] for the given tree and proof. */
def createTreeVerifier(tree: AvlTree, proof: Coll[Byte]): AvlTreeVerifier

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ class CErgoTreeEvaluator(
override def createTreeVerifier(tree: AvlTree, proof: Coll[Byte]): AvlTreeVerifier =
CAvlTreeVerifier(tree, proof)

def nbits(mc: MethodCall, bi: sigma.BigInt): Long = {
addFixedCost(SBigIntMethods.ToNBitsCostInfo) {
bi.nbits
}
}

/** Creates [[sigma.eval.AvlTreeVerifier]] for the given tree and proof. */
def createVerifier(tree: AvlTree, proof: Coll[Byte]) = {
// the cost of tree reconstruction from proof is O(proof.length)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ class MethodCallSerializerSpecification extends SerializationSpecification {
roundTripTest(expr)
}

property("MethodCall deserialization round trip for BigInt.nbits") {
property("MethodCall deserialization round trip for Global.encodeNBits") {
def code = {
val bi = BigIntConstant(5)
val expr = MethodCall(bi,
SBigIntMethods.ToNBits,
Vector(),
val expr = MethodCall(Global,
SGlobalMethods.encodeNBitsMethod,
Vector(bi),
Map()
)
roundTripTest(expr)
Expand All @@ -46,12 +46,12 @@ class MethodCallSerializerSpecification extends SerializationSpecification {
})
}

property("MethodCall deserialization round trip for BigInt.nbits") {
property("MethodCall deserialization round trip for Global.decodeNBits") {
def code = {
val bi = BigIntConstant(5)
val expr = MethodCall(bi,
SBigIntMethods.ToNBits,
Vector(),
val l = LongConstant(5)
val expr = MethodCall(Global,
SGlobalMethods.decodeNBitsMethod,
Vector(l),
Map()
)
roundTripTest(expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package sigmastate.eval

import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import sigma.ast.{BigIntConstant, ErgoTree, JitCost, MethodCall, SBigIntMethods}
import sigma.ast.{BigIntConstant, ErgoTree, Global, JitCost, MethodCall, SBigIntMethods, SGlobalMethods}
import sigma.crypto.SecP256K1Group
import sigma.data.{CBigInt, CSigmaDslBuilder, TrivialProp}
import sigma.data.{CBigInt, TrivialProp}
import sigma.eval.SigmaDsl
import sigma.util.Extensions.SigmaBooleanOps
import sigma.util.NBitsUtils

import java.math.BigInteger
import sigma.{ContractsTestkit, SigmaDslBuilder, SigmaProp}
import sigma.{ContractsTestkit, SigmaProp}
import sigmastate.interpreter.{CErgoTreeEvaluator, CostAccumulator}
import sigmastate.interpreter.CErgoTreeEvaluator.DefaultProfiler

Expand Down Expand Up @@ -72,7 +72,7 @@ class BasicOpsTests extends AnyFunSuite with ContractsTestkit with Matchers {
* Checks BigInt.nbits evaluation for SigmaDSL as well as AST interpreter (MethodCall) layers
*/
test("nbits evaluation") {
SigmaDsl.BigInt(BigInteger.valueOf(0)).nbits should be
SigmaDsl.encodeNbits(CBigInt(BigInteger.valueOf(0))) should be
(NBitsUtils.encodeCompactBits(0))

val es = CErgoTreeEvaluator.DefaultEvalSettings
Expand All @@ -84,7 +84,7 @@ class BasicOpsTests extends AnyFunSuite with ContractsTestkit with Matchers {
constants = ErgoTree.EmptyConstants,
coster = accumulator, DefaultProfiler, es)

val res = MethodCall(BigIntConstant(BigInteger.valueOf(0)), SBigIntMethods.ToNBits, IndexedSeq.empty, Map.empty)
val res = MethodCall(Global, SGlobalMethods.encodeNBitsMethod, IndexedSeq(BigIntConstant(BigInteger.valueOf(0))), Map.empty)
.evalTo[Long](Map.empty)(evaluator)

res should be (NBitsUtils.encodeCompactBits(0))
Expand Down
10 changes: 3 additions & 7 deletions sc/shared/src/main/scala/sigma/compiler/ir/GraphBuilding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -504,9 +504,6 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
else
error(s"The type of $obj is expected to be Collection to select 'size' property", obj.sourceContext.toOption)

case Select(obj, SBigIntMethods.ToNBits.name, _) if obj.tpe == SBigInt && VersionContext.current.isV6SoftForkActivated =>
eval(sigma.ast.MethodCall(obj, SBigIntMethods.ToNBits, IndexedSeq.empty, Map.empty))

// Rule: proof.isProven --> IsValid(proof)
case Select(p, SSigmaPropMethods.IsProven, _) if p.tpe == SSigmaProp =>
eval(SigmaPropIsProven(p.asSigmaProp))
Expand Down Expand Up @@ -936,10 +933,6 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
val objV = eval(obj)
val argsV = args.map(eval)
(objV, method.objType) match {
case (bi: Ref[BigInt]@unchecked, SBigIntMethods) => method.name match {
case SBigIntMethods.ToNBits.name =>
bi.nbits
}
case (xs: RColl[t]@unchecked, SCollectionMethods) => method.name match {
case SCollectionMethods.IndicesMethod.name =>
xs.indices
Expand Down Expand Up @@ -1154,6 +1147,9 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
val c1 = asRep[Coll[Byte]](argsV(0))
val c2 = asRep[Coll[Byte]](argsV(1))
g.xor(c1, c2)
case SGlobalMethods.encodeNBitsMethod.name if VersionContext.current.isV6SoftForkActivated =>
val c1 = asRep[BigInt](argsV(0))
g.encodeNbits(c1)
case SGlobalMethods.decodeNBitsMethod.name if VersionContext.current.isV6SoftForkActivated =>
val c1 = asRep[Long](argsV(0))
g.decodeNbits(c1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import scalan._
def mod(m: Ref[BigInt]): Ref[BigInt];
def min(that: Ref[BigInt]): Ref[BigInt];
def max(that: Ref[BigInt]): Ref[BigInt];
def nbits: Ref[Long]
};
trait GroupElement extends Def[GroupElement] {
def exp(k: Ref[BigInt]): Ref[GroupElement];
Expand Down Expand Up @@ -115,6 +114,7 @@ import scalan._
/** This method will be used in v6.0 to handle CreateAvlTree operation in GraphBuilding */
def avlTree(operationFlags: Ref[Byte], digest: Ref[Coll[Byte]], keyLength: Ref[Int], valueLengthOpt: Ref[WOption[Int]]): Ref[AvlTree];
def xor(l: Ref[Coll[Byte]], r: Ref[Coll[Byte]]): Ref[Coll[Byte]]
def encodeNbits(bi: Ref[BigInt]): Ref[Long]
def decodeNbits(l: Ref[Long]): Ref[BigInt]
};
trait CostModelCompanion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@ object BigInt extends EntityObject("BigInt") {
Array[AnyRef](that),
true, false, element[BigInt]))
}

override def nbits: Ref[Long] = {
asRep[Long](mkMethodCall(self,
BigIntClass.getMethod("nbits"),
Array[AnyRef](),
neverInvoke = true, isAdapterCall = false, element[Long]))
}
}

implicit object LiftableBigInt
Expand Down Expand Up @@ -172,13 +165,6 @@ object BigInt extends EntityObject("BigInt") {
Array[AnyRef](that),
true, true, element[BigInt]))
}

def nbits: Ref[Long] = {
asRep[Long](mkMethodCall(source,
BigIntClass.getMethod("nbits", classOf[Sym]),
Array[AnyRef](),
neverInvoke = true, isAdapterCall = true, element[Long]))
}
}

// entityUnref: single unref method for each type family
Expand Down Expand Up @@ -1960,10 +1946,17 @@ object SigmaDslBuilder extends EntityObject("SigmaDslBuilder") {
true, false, element[Coll[Byte]]))
}

override def encodeNbits(bi: Ref[BigInt]): Ref[Long] = {
asRep[Long](mkMethodCall(self,
SigmaDslBuilderClass.getMethod("encodeNbits", classOf[Sym]),
Array[AnyRef](bi),
true, false, element[Long]))
}

override def decodeNbits(l: Ref[Long]): Ref[BigInt] = {
asRep[BigInt](mkMethodCall(self,
SigmaDslBuilderClass.getMethod("decodeNbits", classOf[Sym]),
Array[AnyRef](),
Array[AnyRef](l),
true, false, element[BigInt]))
}
}
Expand Down Expand Up @@ -2126,6 +2119,13 @@ object SigmaDslBuilder extends EntityObject("SigmaDslBuilder") {
true, true, element[Coll[Byte]]))
}

override def encodeNbits(bi: Ref[BigInt]): Ref[Long] = {
asRep[Long](mkMethodCall(source,
SigmaDslBuilderClass.getMethod("encodeNbits", classOf[Sym]),
Array[AnyRef](bi),
true, true, element[Long]))
}

override def decodeNbits(l: Ref[Long]): Ref[BigInt] = {
asRep[BigInt](mkMethodCall(source,
SigmaDslBuilderClass.getMethod("decodeNbits", classOf[Sym]),
Expand Down
Loading

0 comments on commit a20c04f

Please sign in to comment.