Skip to content

Commit

Permalink
shiftRight
Browse files Browse the repository at this point in the history
  • Loading branch information
kushti committed Jul 16, 2024
1 parent ced229f commit 9519ef6
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 11 deletions.
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/sigma/SigmaDsl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ trait BigInt {
def xor(that: BigInt): BigInt

def shiftLeft(bits: Int): BigInt

def shiftRight(bits: Int): BigInt
}

/** Base class for points on elliptic curves. */
Expand Down
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/sigma/data/CBigInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,6 @@ case class CBigInt(override val wrappedValue: BigInteger) extends BigInt with Wr
override def xor(that: BigInt): BigInt = CBigInt(wrappedValue.xor(that.asInstanceOf[CBigInt].wrappedValue))

def shiftLeft(bits: Int): BigInt = CBigInt(wrappedValue.shiftLeft(bits).to256BitValueExact)

def shiftRight(bits: Int): BigInt = CBigInt(wrappedValue.shiftRight(bits).to256BitValueExact)
}
15 changes: 8 additions & 7 deletions data/shared/src/main/scala/sigma/ast/methods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,15 +363,15 @@ object SNumericTypeMethods extends MethodsContainer {
""".stripMargin)

val ShiftRightMethod: SMethod = SMethod(
this, "shiftRight", SFunc(Array(tNum, tNum), tNum), 13, BitwiseInverse_CostKind)
this, "shiftRight", SFunc(Array(tNum, SInt), tNum), 13, BitwiseInverse_CostKind)
.withIRInfo(MethodCallIrBuilder)
.withUserDefinedInvoke({ (m: SMethod, obj: Any, other: Array[Any]) =>
m.objType match {
case SByteMethods => ByteIsExactIntegral.bitwiseXor(obj.asInstanceOf[Byte], other.head.asInstanceOf[Byte])
case SShortMethods => ShortIsExactIntegral.bitwiseXor(obj.asInstanceOf[Short], other.head.asInstanceOf[Short])
case SIntMethods => IntIsExactIntegral.bitwiseXor(obj.asInstanceOf[Int], other.head.asInstanceOf[Int])
case SLongMethods => LongIsExactIntegral.bitwiseXor(obj.asInstanceOf[Long], other.head.asInstanceOf[Long])
case SBigIntMethods => BigIntIsExactIntegral.bitwiseXor(obj.asInstanceOf[BigInt], other.head.asInstanceOf[BigInt])
case SByteMethods => ByteIsExactIntegral.shiftRight(obj.asInstanceOf[Byte], other.head.asInstanceOf[Int])
case SShortMethods => ShortIsExactIntegral.shiftRight(obj.asInstanceOf[Short], other.head.asInstanceOf[Int])
case SIntMethods => IntIsExactIntegral.shiftRight(obj.asInstanceOf[Int], other.head.asInstanceOf[Int])
case SLongMethods => LongIsExactIntegral.shiftRight(obj.asInstanceOf[Long], other.head.asInstanceOf[Int])
case SBigIntMethods => BigIntIsExactIntegral.shiftRight(obj.asInstanceOf[BigInt], other.head.asInstanceOf[Int])
}
})
.withInfo(PropertyCall,
Expand All @@ -391,7 +391,8 @@ object SNumericTypeMethods extends MethodsContainer {
BitwiseOrMethod,
BitwiseAndMethod,
BitwiseXorMethod,
ShiftLeftMethod
ShiftLeftMethod,
ShiftRightMethod
)

/** Collection of names of numeric casting methods (like `toByte`, `toInt`, etc). */
Expand Down
2 changes: 2 additions & 0 deletions data/shared/src/main/scala/sigma/data/BigIntegerOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ object NumericOps {
override def bitwiseXor(x: BigInt, y: BigInt): BigInt = x.xor(y)

override def shiftLeft(x: BigInt, y: Int): BigInt = x.shiftLeft(y)

override def shiftRight(x: BigInt, y: Int): BigInt = x.shiftRight(y)
}

/** The instance of [[scalan.ExactOrdering]] typeclass for [[BigInt]]. */
Expand Down
9 changes: 7 additions & 2 deletions data/shared/src/main/scala/sigma/data/ExactIntegral.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ object ExactIntegral {
override def bitwiseOr(x: Byte, y: Byte): Byte = (x | y).toByte
override def bitwiseAnd(x: Byte, y: Byte): Byte = (x & y).toByte
override def bitwiseXor(x: Byte, y: Byte): Byte = (x ^ y).toByte
override def shiftLeft(x: Byte, y: Int): Byte = (x << y).toByte
override def shiftLeft(x: Byte, bits: Int): Byte = (x << bits).toByte
override def shiftRight(x: Byte, bits: Int): Byte = (x >> bits).toByte
}

implicit object ShortIsExactIntegral extends ExactIntegral[Short] {
val n = scala.math.Numeric.ShortIsIntegral
override def plus(x: Short, y: Short): Short = x.addExact(y)
override def plus(x: Short, y: Short):
Short = x.addExact(y)
override def minus(x: Short, y: Short): Short = x.subtractExact(y)
override def times(x: Short, y: Short): Short = x.multiplyExact(y)
override def toBigEndianBytes(x: Short): Coll[Byte] = Colls.fromItems((x >> 8).toByte, x.toByte)
Expand All @@ -57,6 +59,7 @@ object ExactIntegral {
override def bitwiseAnd(x: Short, y: Short): Short = (x & y).toShort
override def bitwiseXor(x: Short, y: Short): Short = (x ^ y).toShort
override def shiftLeft(x: Short, y: Int): Short = (x << y).toShort
override def shiftRight(x: Short, bits: Int): Short = (x >> bits).toShort
}

implicit object IntIsExactIntegral extends ExactIntegral[Int] {
Expand All @@ -71,6 +74,7 @@ object ExactIntegral {
override def bitwiseAnd(x: Int, y: Int): Int = x & y
override def bitwiseXor(x: Int, y: Int): Int = x ^ y
override def shiftLeft(x: Int, y: Int): Int = x << y
override def shiftRight(x: Int, bits: Int): Int = x >> bits
}

implicit object LongIsExactIntegral extends ExactIntegral[Long] {
Expand All @@ -85,5 +89,6 @@ object ExactIntegral {
override def bitwiseAnd(x: Long, y: Long): Long = x & y
override def bitwiseXor(x: Long, y: Long): Long = x ^ y
override def shiftLeft(x: Long, y: Int): Long = x << y
override def shiftRight(x: Long, bits: Int): Long = x >> bits
}
}
6 changes: 4 additions & 2 deletions data/shared/src/main/scala/sigma/data/ExactNumeric.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package sigma.data

import sigma.{Coll, Colls}
import sigma.{BigInt, Coll, Colls}
import sigma.data.ExactIntegral._

import scala.collection.mutable
Expand Down Expand Up @@ -62,7 +62,9 @@ trait ExactNumeric[T] {

def bitwiseXor(x: T, y: T): T

def shiftLeft(x: T, y: Int): T
def shiftLeft(x: T, bits: Int): T

def shiftRight(x: T, bits: Int): T

/** A value of type T which corresponds to integer 0. */
lazy val zero: T = fromInt(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,10 @@ trait GraphBuilding extends Base with DefRewriting { IR: IRContext =>
val y = asRep[Int](argsV(0))
val op = NumericShiftLeft(elemToExactNumeric(x.elem))(x.elem)
ApplyBinOpDiffArgs(op, x, y)
case SNumericTypeMethods.ShiftRightMethod.name =>
val y = asRep[Int](argsV(0))
val op = NumericShiftRight(elemToExactNumeric(x.elem))(x.elem)
ApplyBinOpDiffArgs(op, x, y)
case _ => throwError()
}
case _ => throwError(s"Type ${stypeToRType(obj.tpe).name} doesn't have methods")
Expand Down
6 changes: 6 additions & 0 deletions sc/shared/src/main/scala/sigma/compiler/ir/TreeBuilding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,12 @@ trait TreeBuilding extends Base { IR: IRContext =>
val m = SMethod.fromIds(receiverType.typeId, SNumericTypeMethods.ShiftLeftMethod.methodId)
builder.mkMethodCall(x.asNumValue, m, IndexedSeq(y))

case Def(ApplyBinOpDiffArgs(op, xSym, ySym)) if op.isInstanceOf[NumericShiftRight[_]] =>
val Seq(x, y) = Seq(xSym, ySym).map(recurse)
val receiverType = x.asNumValue.tpe.asNumTypeOrElse(error(s"Expected numeric type, got: ${x.tpe}"))
val m = SMethod.fromIds(receiverType.typeId, SNumericTypeMethods.ShiftRightMethod.methodId)
builder.mkMethodCall(x.asNumValue, m, IndexedSeq(y))


case Def(ApplyBinOp(IsArithOp(opCode), xSym, ySym)) =>
val Seq(x, y) = Seq(xSym, ySym).map(recurse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ trait NumericOps extends Base { self: IRContext =>
override def applySeq(x: T, y: Int): T = n.shiftLeft(x, y)
}

case class NumericShiftRight[T: Elem](n: ExactNumeric[T]) extends BinDiffArgsOp[T, Int](">>") {
override def applySeq(x: T, y: Int): T = n.shiftRight(x, y)
}

/** Base class for descriptors of binary division operations. */
abstract class DivOp[T: Elem](opName: String, n: ExactIntegral[T]) extends EndoBinOp[T](opName) {
override def shouldPropagate(lhs: T, rhs: T) = rhs != n.zero
Expand Down
121 changes: 121 additions & 0 deletions sc/shared/src/test/scala/sigmastate/utxo/BasicOpsSpecification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,127 @@ class BasicOpsSpecification extends CompilerTestingCommons
}
}

property("Byte.shiftRight") {
def shiftRightTest(): Assertion = test("Byte.shiftRight", env, ext,
s"""{
| val x = 8.toByte
| val y = 2
| x.shiftRight(y) == 2.toByte
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("Byte.shiftRight - neg") {
def shiftRightTest(): Assertion = test("Byte.shiftRight", env, ext,
s"""{
| val x = (-8).toByte
| val y = 2
| x.shiftRight(y) == (-2).toByte
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("Byte.shiftRight - neg - neg shift") {
def shiftRightTest(): Assertion = test("Byte.shiftRight", env, ext,
s"""{
| val x = (-8).toByte
| val y = -2
| x.shiftRight(y) == (-1).toByte
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("Long.shiftRight - neg") {
def shiftRightTest(): Assertion = test("Long.shiftRight", env, ext,
s"""{
| val x = -32L
| val y = 2
| x.shiftRight(y) == -8L
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("Long.shiftRight - neg - neg shift") {
def shiftRightTest(): Assertion = test("Long.shiftRight", env, ext,
s"""{
| val x = -32L
| val y = -2
| x.shiftRight(y) == -1L
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("BigInt.shiftRight") {
def shiftRightTest(): Assertion = test("BigInt.shiftRight", env, ext,
s"""{
| val x = bigInt("${CryptoConstants.groupOrder.divide(new BigInteger("2"))}")
| val y = 2
| val z = bigInt("${CryptoConstants.groupOrder.divide(new BigInteger("8"))}")
| x.shiftRight(y) == z
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("BigInt.shiftRight - neg shift") {
def shiftRightTest(): Assertion = test("BigInt.shiftRight", env, ext,
s"""{
| val x = bigInt("${CryptoConstants.groupOrder.divide(new BigInteger("2"))}")
| val y = -2
| val z = bigInt("${CryptoConstants.groupOrder.divide(new BigInteger("8"))}")
| z.shiftRight(y) == x
|}""".stripMargin,
null
)

if (VersionContext.current.isV6SoftForkActivated) {
shiftRightTest()
} else {
an[Exception] shouldBe thrownBy(shiftRightTest())
}
}

property("Unit register") {
// TODO frontend: implement missing Unit support in compiler
// https://github.com/ScorexFoundation/sigmastate-interpreter/issues/820
Expand Down

0 comments on commit 9519ef6

Please sign in to comment.