Skip to content

Commit

Permalink
Merge pull request #434 from ScorexFoundation/better-costing
Browse files Browse the repository at this point in the history
Better costing
  • Loading branch information
catena2w authored Mar 14, 2019
2 parents eb4c982 + 314c8a4 commit 9cc6610
Show file tree
Hide file tree
Showing 26 changed files with 313 additions and 200 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ val scorexUtil = "org.scorexfoundation" %% "scorex-util" % "0.1.3"
val macroCompat = "org.typelevel" %% "macro-compat" % "1.1.1"
val paradise = "org.scalamacros" %% "paradise" % "2.1.0" cross CrossVersion.full

val specialVersion = "new-costing-c39c6058-SNAPSHOT"
val specialVersion = "master-5ffd1bf8-SNAPSHOT"
val specialCommon = "io.github.scalan" %% "common" % specialVersion
val specialCore = "io.github.scalan" %% "core" % specialVersion
val specialLibrary = "io.github.scalan" %% "library" % specialVersion
Expand Down Expand Up @@ -137,7 +137,7 @@ credentials ++= (for {

def libraryDefSettings = commonSettings ++ testSettings ++ Seq(
scalacOptions ++= Seq(
// s"-Xplugin:${file(".").absolutePath }/scalanizer/target/scala-2.12/scalanizer-assembly-new-costing-54f98214-SNAPSHOT.jar"
// s"-Xplugin:${file(".").absolutePath }/scalanizer/target/scala-2.12/scalanizer-assembly-better-costing-2a66ed5c-SNAPSHOT.jar"
)
)

Expand Down
7 changes: 3 additions & 4 deletions sigma-api/src/main/resources/special/sigma/SigmaDsl.scalan
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ package special.sigma {
};
@Liftable trait GroupElement extends Def[GroupElement] {
def isInfinity: Rep[Boolean];
def multiply(k: Rep[BigInt]): Rep[GroupElement];
def add(that: Rep[GroupElement]): Rep[GroupElement];
def exp(k: Rep[BigInt]): Rep[GroupElement];
def multiply(that: Rep[GroupElement]): Rep[GroupElement];
def negate: Rep[GroupElement];
//todo remove compressed flag, use GroupElementSerializer
def getEncoded(compressed: Rep[Boolean]): Rep[Coll[Byte]]
def getEncoded: Rep[Coll[Byte]]
};
@Liftable trait SigmaProp extends Def[SigmaProp] {
def isValid: Rep[Boolean];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package special.sigma.wrappers {
import WSigmaPredef._;
import WrapSpecBase._;
trait ECPointWrapSpec extends WrapSpecBase {
//todo remove compressed flag, use GroupElementSerializer
def getEncoded[A](g: Rep[WECPoint], compressed: Rep[Boolean]): Rep[WArray[Byte]] = g.getEncoded(compressed);
def multiply(l: Rep[WECPoint], r: Rep[WBigInteger]): Rep[WECPoint] = l.multiply(r);
def add(l: Rep[WECPoint], r: Rep[WECPoint]): Rep[WECPoint] = l.add(r)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ package wrappers.org.bouncycastle.math.ec {
@External("ECPoint") @Liftable trait WECPoint extends Def[WECPoint] { self =>
@External def add(x$1: Rep[WECPoint]): Rep[WECPoint];
@External def multiply(x$1: Rep[WBigInteger]): Rep[WECPoint];
//todo remove compressed flag, use GroupElementSerializer
@External def getEncoded(x$1: Rep[Boolean]): Rep[WArray[Byte]]
};
trait WECPointCompanion
Expand Down
1 change: 1 addition & 0 deletions sigma-api/src/main/scala/sigma/types/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package types {

case class PrimViewType[T, Val](classTag: ClassTag[T], tVal: RType[Val]) extends ViewType[T, Val] {
override def name: String = tVal.name
override def isConstantSize: scala.Boolean = tVal.isConstantSize
}

object IsPrimView {
Expand Down
18 changes: 13 additions & 5 deletions sigma-api/src/main/scala/special/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ trait GroupElement {

def isInfinity: Boolean

/** Multiplies this <code>GroupElement</code> by the given number.
* @param k The multiplicator.
* @return <code>k * this</code>.
/** Exponentiate this <code>GroupElement</code> to the given number.
* @param k The power.
* @return <code>this to the power of k</code>.
* @since 2.0
*/
def multiply(k: BigInt): GroupElement
def exp(k: BigInt): GroupElement

/** Group operation. */
def add(that: GroupElement): GroupElement
def multiply(that: GroupElement): GroupElement

/** Inverse element in the group. */
def negate: GroupElement
Expand Down Expand Up @@ -590,6 +590,14 @@ trait SigmaContract {

def groupGenerator: GroupElement = this.builder.groupGenerator

def decodePoint(encoded: Coll[Byte]): GroupElement = this.builder.decodePoint(encoded)

@Reified("T")
def substConstants[T](scriptBytes: Coll[Byte],
positions: Coll[Int],
newValues: Coll[T])
(implicit cT: RType[T]): Coll[Byte] = this.builder.substConstants(scriptBytes, positions, newValues)

@clause def canOpen(ctx: Context): Boolean

def asFunction: Context => Boolean = (ctx: Context) => this.canOpen(ctx)
Expand Down
28 changes: 21 additions & 7 deletions sigma-api/src/main/scala/special/sigma/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@ import java.math.BigInteger

import org.bouncycastle.math.ec.ECPoint
import scalan.RType
import scalan.RType.GeneralType

import scala.reflect.{classTag, ClassTag}
import scala.reflect.{ClassTag, classTag}

package sigma {

case class WrapperType[Wrapper](cWrapper: ClassTag[Wrapper]) extends RType[Wrapper] {
override def classTag: ClassTag[Wrapper] = cWrapper
override def toString: String = cWrapper.toString
override def name: String = cWrapper.runtimeClass.getSimpleName
override def isConstantSize: Boolean = false // pessimistic but safe default
}

}
Expand All @@ -21,17 +23,25 @@ package object sigma {
def wrapperType[W: ClassTag]: RType[W] = WrapperType(classTag[W])

// TODO make these types into GeneralType (same as Header and PreHeader)
implicit val BigIntRType: RType[BigInt] = wrapperType[BigInt]
implicit val GroupElementRType: RType[GroupElement] = wrapperType[GroupElement]
implicit val BigIntRType: RType[BigInt] = new WrapperType(classTag[BigInt]) {
override def isConstantSize: Boolean = true
}
implicit val GroupElementRType: RType[GroupElement] = new WrapperType(classTag[GroupElement]) {
override def isConstantSize: Boolean = true
}
implicit val SigmaPropRType: RType[SigmaProp] = wrapperType[SigmaProp]
implicit val BoxRType: RType[Box] = wrapperType[Box]
implicit val AvlTreeRType: RType[AvlTree] = wrapperType[AvlTree]
implicit val ContextRType: RType[Context] = wrapperType[Context]

// these are not wrapper types since they are used directly in ErgoTree values (e.g. Constants)
// and no conversion is necessary
implicit val HeaderRType: RType[Header] = RType.fromClassTag(classTag[Header])
implicit val PreHeaderRType: RType[PreHeader] = RType.fromClassTag(classTag[PreHeader])
implicit val HeaderRType: RType[Header] = new GeneralType(classTag[Header]) {
override def isConstantSize: Boolean = true
}
implicit val PreHeaderRType: RType[PreHeader] = new GeneralType(classTag[PreHeader]) {
override def isConstantSize: Boolean = true
}

implicit val AnyValueRType: RType[AnyValue] = RType.fromClassTag(classTag[AnyValue])
implicit val CostModelRType: RType[CostModel] = RType.fromClassTag(classTag[CostModel])
Expand All @@ -40,8 +50,12 @@ package object sigma {
implicit val SigmaContractRType: RType[SigmaContract] = RType.fromClassTag(classTag[SigmaContract])
implicit val SigmaDslBuilderRType: RType[SigmaDslBuilder] = RType.fromClassTag(classTag[SigmaDslBuilder])

implicit val BigIntegerRType: RType[BigInteger] = RType.fromClassTag(classTag[BigInteger])
implicit val ECPointRType: RType[ECPoint] = RType.fromClassTag(classTag[ECPoint])
implicit val BigIntegerRType: RType[BigInteger] = new GeneralType(classTag[BigInteger]) {
override def isConstantSize: Boolean = true
}
implicit val ECPointRType: RType[ECPoint] = new GeneralType(classTag[ECPoint]) {
override def isConstantSize: Boolean = true
}


implicit val SizeAnyValueRType: RType[SizeAnyValue] = RType.fromClassTag(classTag[SizeAnyValue])
Expand Down
20 changes: 3 additions & 17 deletions sigma-impl/src/main/scala/special/sigma/SigmaDslCosted.scala
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package special.sigma

import special.SpecialPredef
import special.collection.{Coll, CCostedPrim, _}

import scala.reflect.ClassTag
import scalan.RType
import scalan.{NeverInline, Reified}
import special.collection._
import scalan.{RType, NeverInline}

class CSizeAnyValue(val tVal: RType[Any], val valueSize: Size[Any]) extends SizeAnyValue {
@NeverInline
Expand Down Expand Up @@ -34,17 +30,7 @@ class CSizeBox(

@NeverInline
override def getReg[T](id: Byte)(implicit tT: RType[T]): Size[Option[T]] = {
val varSize = registers.asInstanceOf[SizeColl[Option[AnyValue]]].sizes(id.toInt)
val foundSize = varSize.asInstanceOf[SizeOption[AnyValue]].sizeOpt
val regSize = foundSize match {
case Some(varSize: SizeAnyValue) =>
assert(varSize.tVal == tT, s"Unexpected register type found ${varSize.tVal}: expected $tT")
val regSize = varSize.valueSize.asInstanceOf[Size[T]]
regSize
case _ =>
new CSizePrim(0L, tT)
}
new CSizeOption[T](Some(regSize))
sys.error(s"Shouldn't be called and must be overriden by the class in sigmastate.eval package")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ abstract class TestGroupElement(private[sigma] val value: ECPoint) extends Group

override def isInfinity: Boolean = value.isInfinity

override def multiply(k: BigInt): GroupElement = dsl.GroupElement(value.multiply(k.value))
override def exp(k: BigInt): GroupElement = dsl.GroupElement(value.multiply(k.value))

override def add(that: GroupElement): GroupElement = dsl.GroupElement(value.add(that.value))
override def multiply(that: GroupElement): GroupElement = dsl.GroupElement(value.add(that.value))

override def negate: GroupElement = dsl.GroupElement(value.negate())

Expand Down
7 changes: 3 additions & 4 deletions sigma-library/src/main/scala/special/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,10 @@ package special.sigma {
};
@Liftable trait GroupElement extends Def[GroupElement] {
def isInfinity: Rep[Boolean];
def multiply(k: Rep[BigInt]): Rep[GroupElement];
def add(that: Rep[GroupElement]): Rep[GroupElement];
def exp(k: Rep[BigInt]): Rep[GroupElement];
def multiply(that: Rep[GroupElement]): Rep[GroupElement];
def negate: Rep[GroupElement];
//todo remove compressed flag, use GroupElementSerializer
def getEncoded(compressed: Rep[Boolean]): Rep[Coll[Byte]]
def getEncoded: Rep[Coll[Byte]]
};
@Liftable trait SigmaProp extends Def[SigmaProp] {
def isValid: Rep[Boolean];
Expand Down
50 changes: 24 additions & 26 deletions sigma-library/src/main/scala/special/sigma/impl/SigmaDslImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1184,16 +1184,16 @@ object GroupElement extends EntityObject("GroupElement") {
true, false, element[Boolean]))
}

override def multiply(k: Rep[BigInt]): Rep[GroupElement] = {
override def exp(k: Rep[BigInt]): Rep[GroupElement] = {
asRep[GroupElement](mkMethodCall(self,
GroupElementClass.getMethod("multiply", classOf[Sym]),
GroupElementClass.getMethod("exp", classOf[Sym]),
List(k),
true, false, element[GroupElement]))
}

override def add(that: Rep[GroupElement]): Rep[GroupElement] = {
override def multiply(that: Rep[GroupElement]): Rep[GroupElement] = {
asRep[GroupElement](mkMethodCall(self,
GroupElementClass.getMethod("add", classOf[Sym]),
GroupElementClass.getMethod("multiply", classOf[Sym]),
List(that),
true, false, element[GroupElement]))
}
Expand All @@ -1205,11 +1205,10 @@ object GroupElement extends EntityObject("GroupElement") {
true, false, element[GroupElement]))
}

//todo remove compressed flag, use GroupElementSerializer
override def getEncoded(compressed: Rep[Boolean]): Rep[Coll[Byte]] = {
override def getEncoded: Rep[Coll[Byte]] = {
asRep[Coll[Byte]](mkMethodCall(self,
GroupElementClass.getMethod("getEncoded", classOf[Sym]),
List(compressed),
GroupElementClass.getMethod("getEncoded"),
List(),
true, false, element[Coll[Byte]]))
}
}
Expand Down Expand Up @@ -1242,16 +1241,16 @@ object GroupElement extends EntityObject("GroupElement") {
true, true, element[Boolean]))
}

def multiply(k: Rep[BigInt]): Rep[GroupElement] = {
def exp(k: Rep[BigInt]): Rep[GroupElement] = {
asRep[GroupElement](mkMethodCall(source,
thisClass.getMethod("multiply", classOf[Sym]),
thisClass.getMethod("exp", classOf[Sym]),
List(k),
true, true, element[GroupElement]))
}

def add(that: Rep[GroupElement]): Rep[GroupElement] = {
def multiply(that: Rep[GroupElement]): Rep[GroupElement] = {
asRep[GroupElement](mkMethodCall(source,
thisClass.getMethod("add", classOf[Sym]),
thisClass.getMethod("multiply", classOf[Sym]),
List(that),
true, true, element[GroupElement]))
}
Expand All @@ -1263,11 +1262,10 @@ object GroupElement extends EntityObject("GroupElement") {
true, true, element[GroupElement]))
}

//todo remove compressed flag, use GroupElementSerializer
def getEncoded(compressed: Rep[Boolean]): Rep[Coll[Byte]] = {
def getEncoded: Rep[Coll[Byte]] = {
asRep[Coll[Byte]](mkMethodCall(source,
thisClass.getMethod("getEncoded", classOf[Sym]),
List(compressed),
thisClass.getMethod("getEncoded"),
List(),
true, true, element[Coll[Byte]]))
}
}
Expand All @@ -1287,7 +1285,7 @@ object GroupElement extends EntityObject("GroupElement") {
override protected def collectMethods: Map[java.lang.reflect.Method, MethodDesc] = {
super.collectMethods ++
Elem.declaredMethods(classOf[GroupElement], classOf[SGroupElement], Set(
"isInfinity", "multiply", "add", "negate", "getEncoded"
"isInfinity", "exp", "multiply", "negate", "getEncoded"
))
}

Expand Down Expand Up @@ -1343,9 +1341,9 @@ object GroupElement extends EntityObject("GroupElement") {
}
}

object multiply {
object exp {
def unapply(d: Def[_]): Nullable[(Rep[GroupElement], Rep[BigInt])] = d match {
case MethodCall(receiver, method, args, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "multiply" =>
case MethodCall(receiver, method, args, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "exp" =>
val res = (receiver, args(0))
Nullable(res).asInstanceOf[Nullable[(Rep[GroupElement], Rep[BigInt])]]
case _ => Nullable.None
Expand All @@ -1356,9 +1354,9 @@ object GroupElement extends EntityObject("GroupElement") {
}
}

object add {
object multiply {
def unapply(d: Def[_]): Nullable[(Rep[GroupElement], Rep[GroupElement])] = d match {
case MethodCall(receiver, method, args, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "add" =>
case MethodCall(receiver, method, args, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "multiply" =>
val res = (receiver, args(0))
Nullable(res).asInstanceOf[Nullable[(Rep[GroupElement], Rep[GroupElement])]]
case _ => Nullable.None
Expand All @@ -1383,13 +1381,13 @@ object GroupElement extends EntityObject("GroupElement") {
}

object getEncoded {
def unapply(d: Def[_]): Nullable[(Rep[GroupElement], Rep[Boolean])] = d match {
case MethodCall(receiver, method, args, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "getEncoded" =>
val res = (receiver, args(0))
Nullable(res).asInstanceOf[Nullable[(Rep[GroupElement], Rep[Boolean])]]
def unapply(d: Def[_]): Nullable[Rep[GroupElement]] = d match {
case MethodCall(receiver, method, _, _) if receiver.elem.isInstanceOf[GroupElementElem[_]] && method.getName == "getEncoded" =>
val res = receiver
Nullable(res).asInstanceOf[Nullable[Rep[GroupElement]]]
case _ => Nullable.None
}
def unapply(exp: Sym): Nullable[(Rep[GroupElement], Rep[Boolean])] = exp match {
def unapply(exp: Sym): Nullable[Rep[GroupElement]] = exp match {
case Def(d) => unapply(d)
case _ => Nullable.None
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ package special.sigma.wrappers {
import WSigmaPredef._;
import WrapSpecBase._;
trait ECPointWrapSpec extends WrapSpecBase {
//todo remove compressed flag, use GroupElementSerializer
def getEncoded[A](g: Rep[WECPoint], compressed: Rep[Boolean]): Rep[WArray[Byte]] = g.getEncoded(compressed);
def multiply(l: Rep[WECPoint], r: Rep[WBigInteger]): Rep[WECPoint] = l.multiply(r);
def add(l: Rep[WECPoint], r: Rep[WECPoint]): Rep[WECPoint] = l.add(r)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ package wrappers.org.bouncycastle.math.ec {
@External("ECPoint") @Liftable trait WECPoint extends Def[WECPoint] {
@External def add(x$1: Rep[WECPoint]): Rep[WECPoint];
@External def multiply(x$1: Rep[WBigInteger]): Rep[WECPoint];
//todo remove compressed flag, use GroupElementSerializer
@External def getEncoded(x$1: Rep[Boolean]): Rep[WArray[Byte]]
};
trait WECPointCompanion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ object WECPoint extends EntityObject("WECPoint") {
true, false, element[WECPoint]))
}

//todo remove compressed flag, use GroupElementSerializer
override def getEncoded(x$1: Rep[Boolean]): Rep[WArray[Byte]] = {
asRep[WArray[Byte]](mkMethodCall(self,
WECPointClass.getMethod("getEncoded", classOf[Sym]),
Expand Down Expand Up @@ -98,7 +97,6 @@ object WECPoint extends EntityObject("WECPoint") {
true, true, element[WECPoint]))
}

//todo remove compressed flag, use GroupElementSerializer
def getEncoded(x$1: Rep[Boolean]): Rep[WArray[Byte]] = {
asRep[WArray[Byte]](mkMethodCall(source,
thisClass.getMethod("getEncoded", classOf[Sym]),
Expand Down
Loading

0 comments on commit 9cc6610

Please sign in to comment.