Skip to content

Commit

Permalink
Merge pull request #1011 from ergoplatform/i909
Browse files Browse the repository at this point in the history
[6.0] Improve collections equality
  • Loading branch information
kushti authored Oct 18, 2024
2 parents 83ba4a4 + 5cabedd commit 571e721
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 32 deletions.
93 changes: 72 additions & 21 deletions core/shared/src/main/scala/sigma/data/CollsOverArrays.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sigma.data

import debox.{Buffer, cfor}
import sigma.Evaluation.stypeToRType
import sigma.data.CollOverArray.equalsPairCollWithCollOverArray
import sigma.data.RType._
import sigma.util.{CollectionUtil, MaxArrayLength, safeConcatArrays_v5}
import sigma.{Coll, CollBuilder, PairColl, VersionContext, requireSameLength}
Expand All @@ -12,7 +14,9 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
s"Cannot create collection with size ${toArray.length} greater than $MaxArrayLength")

override def tItem: RType[A] = tA

@inline def length: Int = toArray.length

@inline def apply(i: Int): A = toArray.apply(i)

override def isEmpty: Boolean = length == 0
Expand All @@ -29,8 +33,11 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
}

def foreach(f: A => Unit): Unit = toArray.foreach(f)

def exists(p: A => Boolean): Boolean = toArray.exists(p)

def forall(p: A => Boolean): Boolean = toArray.forall(p)

def filter(p: A => Boolean): Coll[A] = builder.fromArray(toArray.filter(p))

def foldLeft[B](zero: B, op: ((B, A)) => B): B = toArray.foldLeft(zero)((b, a) => op((b, a)))
Expand Down Expand Up @@ -117,12 +124,14 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
override def unionSet(that: Coll[A]): Coll[A] = {
val set = debox.Set.ofSize[A](this.length)
val res = Buffer.ofSize[A](this.length)

@inline def addItemToSet(x: A) = {
if (!set(x)) {
set.add(x)
res += x
}
}

def addToSet(arr: Array[A]) = {
val limit = arr.length
cfor(0)(_ < limit, _ + 1) { i =>
Expand All @@ -139,14 +148,42 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil

override def equals(obj: scala.Any): Boolean = (this eq obj.asInstanceOf[AnyRef]) || (obj match {
case obj: CollOverArray[_] if obj.tItem == this.tItem =>
java.util.Objects.deepEquals(obj.toArray, toArray)
java.util.Objects.deepEquals(obj.toArray, this.toArray)
case obj: PairColl[Any, Any] if obj.tItem == this.tItem =>
if (VersionContext.current.isV6SoftForkActivated) {
equalsPairCollWithCollOverArray(obj, this.asInstanceOf[CollOverArray[Any]])
} else {
false
}
case _ => false
})

override def hashCode() = CollectionUtil.deepHashCode(toArray)
override def hashCode(): Int = CollectionUtil.deepHashCode(toArray)
}

object CollOverArray {

// comparing PairColl and CollOverArray instances
private[data] def equalsPairCollWithCollOverArray(pc: PairColl[Any, Any], coa: CollOverArray[Any]): Boolean = {
val ls = pc.ls
val rs = pc.rs
val ts = coa.toArray
if (ts.length == ls.length && ts.isInstanceOf[Array[(Any, Any)]]) {
val ta = ts.asInstanceOf[Array[(Any, Any)]]
var eq = true
cfor(0)(_ < ta.length && eq, _ + 1) { i =>
eq = java.util.Objects.deepEquals(ta(i)._1, ls(i)) && java.util.Objects.deepEquals(ta(i)._2, rs(i))
}
eq
} else {
false
}
}

}

private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
private[sigma] class CollOverArrayBuilder extends CollBuilder {
builder =>

@inline override def pairColl[@specialized A, @specialized B](as: Coll[A], bs: Coll[B]): PairColl[A, B] = {
if (VersionContext.current.isJitActivated) {
Expand All @@ -170,20 +207,20 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}
}

private def fromBoxedPairs[A, B](seq: Seq[(A, B)])(implicit tA: RType[A], tB: RType[B]): PairColl[A,B] = {
private def fromBoxedPairs[A, B](seq: Seq[(A, B)])(implicit tA: RType[A], tB: RType[B]): PairColl[A, B] = {
val len = seq.length
val resA = Array.ofDim[A](len)(tA.classTag)
val resB = Array.ofDim[B](len)(tB.classTag)
cfor(0)(_ < len, _ + 1) { i =>
val item = seq.apply(i).asInstanceOf[(A,B)]
val item = seq.apply(i).asInstanceOf[(A, B)]
resA(i) = item._1
resB(i) = item._2
}
pairCollFromArrays(resA, resB)(tA, tB)
}

override def fromItems[T](items: T*)(implicit cT: RType[T]): Coll[T] = cT match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
fromBoxedPairs(items)(tA, tB)
Expand All @@ -192,16 +229,16 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}

override def fromArray[@specialized T: RType](arr: Array[T]): Coll[T] = RType[T] match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
fromBoxedPairs[a,b](arr.asInstanceOf[Array[(a,b)]])(tA, tB)
fromBoxedPairs[a, b](arr.asInstanceOf[Array[(a, b)]])(tA, tB)
case _ =>
new CollOverArray(arr, builder)
}

override def replicate[@specialized T: RType](n: Int, v: T): Coll[T] = RType[T] match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
val tuple = v.asInstanceOf[(a, b)]
Expand All @@ -210,8 +247,8 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
fromArray(Array.fill(n)(v))
}

override def unzip[@specialized A, @specialized B](xs: Coll[(A,B)]): (Coll[A], Coll[B]) = xs match {
case pa: PairColl[_,_] => (pa.ls, pa.rs)
override def unzip[@specialized A, @specialized B](xs: Coll[(A, B)]): (Coll[A], Coll[B]) = xs match {
case pa: PairColl[_, _] => (pa.ls, pa.rs)
case _ =>
val limit = xs.length
implicit val tA = xs.tItem.tFst
Expand All @@ -230,7 +267,7 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
left.zip(right).map { case (l, r) => (l ^ r).toByte }

override def emptyColl[T](implicit cT: RType[T]): Coll[T] = cT match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val ls = emptyColl(pt.tFst)
val rs = emptyColl(pt.tSnd)
pairColl(ls, rs).asInstanceOf[Coll[T]]
Expand All @@ -239,24 +276,36 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}
}

class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R]) extends PairColl[L,R] {
class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R]) extends PairColl[L, R] {

override def equals(that: scala.Any) = (this eq that.asInstanceOf[AnyRef]) || (that match {
case that: PairColl[_,_] if that.tItem == this.tItem => ls == that.ls && rs == that.rs
override def equals(that: scala.Any): Boolean = (this eq that.asInstanceOf[AnyRef]) || (that match {
case that: PairColl[_, _] if that.tItem == this.tItem =>
ls == that.ls && rs == that.rs
case that: CollOverArray[Any] if that.tItem == this.tItem =>
if (VersionContext.current.isV6SoftForkActivated) {
equalsPairCollWithCollOverArray(this.asInstanceOf[PairColl[Any, Any]], that)
} else {
false
}
case _ => false
})

override def hashCode() = ls.hashCode() * 41 + rs.hashCode()

@inline implicit def tL: RType[L] = ls.tItem

@inline implicit def tR: RType[R] = rs.tItem

override lazy val tItem: RType[(L, R)] = {
RType.pairRType(tL, tR)
}

override def builder: CollBuilder = ls.builder

override def toArray: Array[(L, R)] = ls.toArray.zip(rs.toArray)

@inline override def length: Int = if (ls.length <= rs.length) ls.length else rs.length

@inline override def apply(i: Int): (L, R) = (ls(i), rs(i))

override def isEmpty: Boolean = length == 0
Expand Down Expand Up @@ -304,7 +353,7 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
true
}

override def filter(p: ((L, R)) => Boolean): Coll[(L,R)] = {
override def filter(p: ((L, R)) => Boolean): Coll[(L, R)] = {
val len = ls.length
val resL: Buffer[L] = Buffer.empty[L](ls.tItem.classTag)
val resR: Buffer[R] = Buffer.empty[R](rs.tItem.classTag)
Expand Down Expand Up @@ -333,9 +382,9 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
state
}

override def slice(from: Int, until: Int): PairColl[L,R] = builder.pairColl(ls.slice(from, until), rs.slice(from, until))
override def slice(from: Int, until: Int): PairColl[L, R] = builder.pairColl(ls.slice(from, until), rs.slice(from, until))

def append(other: Coll[(L, R)]): Coll[(L,R)] = {
def append(other: Coll[(L, R)]): Coll[(L, R)] = {
val arrs = builder.unzip(other)
builder.pairColl(ls.append(arrs._1), rs.append(arrs._2))
}
Expand All @@ -352,7 +401,7 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
}
}

def zip[@specialized B](ys: Coll[B]): PairColl[(L,R), B] = builder.pairColl(this, ys)
def zip[@specialized B](ys: Coll[B]): PairColl[(L, R), B] = builder.pairColl(this, ys)

def startsWith(ys: Coll[(L, R)]): Boolean = ys match {
case yp: PairOfCols[L, R] => ls.startsWith(yp.ls) && rs.startsWith(yp.rs)
Expand Down Expand Up @@ -408,18 +457,20 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
}

override def unionSet(that: Coll[(L, R)]): Coll[(L, R)] = {
val set = new java.util.HashSet[(L,R)](32)
val set = new java.util.HashSet[(L, R)](32)
implicit val ctL = ls.tItem.classTag
implicit val ctR = rs.tItem.classTag
val resL = Buffer.empty[L]
val resR = Buffer.empty[R]
def addToSet(item: (L,R)) = {

def addToSet(item: (L, R)) = {
if (!set.contains(item)) {
set.add(item)
resL += item._1
resR += item._2
}
}

var i = 0
val thisLen = math.min(ls.length, rs.length)
while (i < thisLen) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import org.ergoplatform.settings.ErgoAlgos
import scorex.util.encode.Base16
import scorex.util.{ModifierId, Random}
import sigma.Extensions._
import sigma.SigmaDslTesting
import sigma.ast.SCollection.SByteArray
import sigma.{SigmaDslTesting, VersionContext}
import sigma.ast.SType._
import sigma.ast.syntax.{ErgoBoxCandidateRType, TrueSigmaProp}
import sigma.ast._
Expand All @@ -20,9 +19,11 @@ import sigmastate.helpers.TestingHelpers.copyTransaction
import sigmastate.utils.Helpers
import sigma.SigmaDslTesting
import sigma.Extensions._
import sigma.ast.SCollection.SByteArray
import sigmastate.CrossVersionProps
import sigmastate.utils.Helpers.EitherOps // required for Scala 2.11

class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {
class ErgoLikeTransactionSpec extends SigmaDslTesting with CrossVersionProps with JsonCodecs {

property("ErgoBox test vectors") {
val token1 = "6e789ab7b2fffff12280a6cd01557f6fb22b7f80ff7aff8e1f7f15973d7f0001"
Expand Down Expand Up @@ -99,14 +100,24 @@ class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {

{ // test case for R2
val res = b1.get(ErgoBox.R2).get
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).map(identity).toConstant
// TODO v6.0 (16h): fix collections equality and remove map(identity)
// (PairOfColl should be equal CollOverArray but now it is not)

// We have versioned check here due to fixed collections equality in 6.0.0
// (PairOfColl equal CollOverArray now)
// see (https://github.com/ScorexFoundation/sigmastate-interpreter/issues/909)
res shouldBe exp
if(VersionContext.current.isV6SoftForkActivated) {
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).toConstant
res shouldBe exp
exp shouldBe res
} else {
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).map(identity).toConstant
res shouldBe exp
}
}

{ // test case for R3
Expand Down Expand Up @@ -470,7 +481,6 @@ class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {
// test equivalence of "from Json" and "from bytes" deserialization
tx2.id shouldBe tx.id
tx2.id shouldBe "d5c0a7908bbb8eefe72ad70a9f668dd47b748239fd34378d3588d5625dd75c82"
println(tx2.id)
}

property("Tuple in register test vector") {
Expand Down

0 comments on commit 571e721

Please sign in to comment.