Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6.0] Improve collections equality #1011

Merged
merged 7 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 72 additions & 21 deletions core/shared/src/main/scala/sigma/data/CollsOverArrays.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package sigma.data

import debox.{Buffer, cfor}
import sigma.Evaluation.stypeToRType
import sigma.data.CollOverArray.equalsPairCollWithCollOverArray
import sigma.data.RType._
import sigma.util.{CollectionUtil, MaxArrayLength, safeConcatArrays_v5}
import sigma.{Coll, CollBuilder, PairColl, VersionContext, requireSameLength}
Expand All @@ -12,7 +14,9 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
s"Cannot create collection with size ${toArray.length} greater than $MaxArrayLength")

override def tItem: RType[A] = tA

@inline def length: Int = toArray.length

@inline def apply(i: Int): A = toArray.apply(i)

override def isEmpty: Boolean = length == 0
Expand All @@ -29,8 +33,11 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
}

def foreach(f: A => Unit): Unit = toArray.foreach(f)

def exists(p: A => Boolean): Boolean = toArray.exists(p)

def forall(p: A => Boolean): Boolean = toArray.forall(p)

def filter(p: A => Boolean): Coll[A] = builder.fromArray(toArray.filter(p))

def foldLeft[B](zero: B, op: ((B, A)) => B): B = toArray.foldLeft(zero)((b, a) => op((b, a)))
Expand Down Expand Up @@ -117,12 +124,14 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil
override def unionSet(that: Coll[A]): Coll[A] = {
val set = debox.Set.ofSize[A](this.length)
val res = Buffer.ofSize[A](this.length)

@inline def addItemToSet(x: A) = {
if (!set(x)) {
set.add(x)
res += x
}
}

def addToSet(arr: Array[A]) = {
val limit = arr.length
cfor(0)(_ < limit, _ + 1) { i =>
Expand All @@ -139,14 +148,42 @@ class CollOverArray[@specialized A](val toArray: Array[A], val builder: CollBuil

override def equals(obj: scala.Any): Boolean = (this eq obj.asInstanceOf[AnyRef]) || (obj match {
case obj: CollOverArray[_] if obj.tItem == this.tItem =>
java.util.Objects.deepEquals(obj.toArray, toArray)
java.util.Objects.deepEquals(obj.toArray, this.toArray)
case obj: PairColl[Any, Any] if obj.tItem == this.tItem =>
if (VersionContext.current.isV6SoftForkActivated) {
equalsPairCollWithCollOverArray(obj, this.asInstanceOf[CollOverArray[Any]])
} else {
false
}
case _ => false
})

override def hashCode() = CollectionUtil.deepHashCode(toArray)
override def hashCode(): Int = CollectionUtil.deepHashCode(toArray)
}

object CollOverArray {

// comparing PairColl and CollOverArray instances
private[data] def equalsPairCollWithCollOverArray(pc: PairColl[Any, Any], coa: CollOverArray[Any]): Boolean = {
val ls = pc.ls
val rs = pc.rs
val ts = coa.toArray
if (ts.length == ls.length && ts.isInstanceOf[Array[(Any, Any)]]) {
val ta = ts.asInstanceOf[Array[(Any, Any)]]
var eq = true
cfor(0)(_ < ta.length && eq, _ + 1) { i =>
eq = java.util.Objects.deepEquals(ta(i)._1, ls(i)) && java.util.Objects.deepEquals(ta(i)._2, rs(i))
}
eq
} else {
false
}
}

}

private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
private[sigma] class CollOverArrayBuilder extends CollBuilder {
builder =>

@inline override def pairColl[@specialized A, @specialized B](as: Coll[A], bs: Coll[B]): PairColl[A, B] = {
if (VersionContext.current.isJitActivated) {
Expand All @@ -170,20 +207,20 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}
}

private def fromBoxedPairs[A, B](seq: Seq[(A, B)])(implicit tA: RType[A], tB: RType[B]): PairColl[A,B] = {
private def fromBoxedPairs[A, B](seq: Seq[(A, B)])(implicit tA: RType[A], tB: RType[B]): PairColl[A, B] = {
val len = seq.length
val resA = Array.ofDim[A](len)(tA.classTag)
val resB = Array.ofDim[B](len)(tB.classTag)
cfor(0)(_ < len, _ + 1) { i =>
val item = seq.apply(i).asInstanceOf[(A,B)]
val item = seq.apply(i).asInstanceOf[(A, B)]
resA(i) = item._1
resB(i) = item._2
}
pairCollFromArrays(resA, resB)(tA, tB)
}

override def fromItems[T](items: T*)(implicit cT: RType[T]): Coll[T] = cT match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
fromBoxedPairs(items)(tA, tB)
Expand All @@ -192,16 +229,16 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}

override def fromArray[@specialized T: RType](arr: Array[T]): Coll[T] = RType[T] match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
fromBoxedPairs[a,b](arr.asInstanceOf[Array[(a,b)]])(tA, tB)
fromBoxedPairs[a, b](arr.asInstanceOf[Array[(a, b)]])(tA, tB)
case _ =>
new CollOverArray(arr, builder)
}

override def replicate[@specialized T: RType](n: Int, v: T): Coll[T] = RType[T] match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val tA = pt.tFst
val tB = pt.tSnd
val tuple = v.asInstanceOf[(a, b)]
Expand All @@ -210,8 +247,8 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
fromArray(Array.fill(n)(v))
}

override def unzip[@specialized A, @specialized B](xs: Coll[(A,B)]): (Coll[A], Coll[B]) = xs match {
case pa: PairColl[_,_] => (pa.ls, pa.rs)
override def unzip[@specialized A, @specialized B](xs: Coll[(A, B)]): (Coll[A], Coll[B]) = xs match {
case pa: PairColl[_, _] => (pa.ls, pa.rs)
case _ =>
val limit = xs.length
implicit val tA = xs.tItem.tFst
Expand All @@ -230,7 +267,7 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
left.zip(right).map { case (l, r) => (l ^ r).toByte }

override def emptyColl[T](implicit cT: RType[T]): Coll[T] = cT match {
case pt: PairType[a,b] =>
case pt: PairType[a, b] =>
val ls = emptyColl(pt.tFst)
val rs = emptyColl(pt.tSnd)
pairColl(ls, rs).asInstanceOf[Coll[T]]
Expand All @@ -239,24 +276,36 @@ private[sigma] class CollOverArrayBuilder extends CollBuilder { builder =>
}
}

class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R]) extends PairColl[L,R] {
class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R]) extends PairColl[L, R] {

override def equals(that: scala.Any) = (this eq that.asInstanceOf[AnyRef]) || (that match {
case that: PairColl[_,_] if that.tItem == this.tItem => ls == that.ls && rs == that.rs
override def equals(that: scala.Any): Boolean = (this eq that.asInstanceOf[AnyRef]) || (that match {
case that: PairColl[_, _] if that.tItem == this.tItem =>
ls == that.ls && rs == that.rs
case that: CollOverArray[Any] if that.tItem == this.tItem =>
if (VersionContext.current.isV6SoftForkActivated) {
equalsPairCollWithCollOverArray(this.asInstanceOf[PairColl[Any, Any]], that)
} else {
false
}
case _ => false
})

override def hashCode() = ls.hashCode() * 41 + rs.hashCode()

@inline implicit def tL: RType[L] = ls.tItem

@inline implicit def tR: RType[R] = rs.tItem

override lazy val tItem: RType[(L, R)] = {
RType.pairRType(tL, tR)
}

override def builder: CollBuilder = ls.builder

override def toArray: Array[(L, R)] = ls.toArray.zip(rs.toArray)

@inline override def length: Int = if (ls.length <= rs.length) ls.length else rs.length

@inline override def apply(i: Int): (L, R) = (ls(i), rs(i))

override def isEmpty: Boolean = length == 0
Expand Down Expand Up @@ -304,7 +353,7 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
true
}

override def filter(p: ((L, R)) => Boolean): Coll[(L,R)] = {
override def filter(p: ((L, R)) => Boolean): Coll[(L, R)] = {
val len = ls.length
val resL: Buffer[L] = Buffer.empty[L](ls.tItem.classTag)
val resR: Buffer[R] = Buffer.empty[R](rs.tItem.classTag)
Expand Down Expand Up @@ -333,9 +382,9 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
state
}

override def slice(from: Int, until: Int): PairColl[L,R] = builder.pairColl(ls.slice(from, until), rs.slice(from, until))
override def slice(from: Int, until: Int): PairColl[L, R] = builder.pairColl(ls.slice(from, until), rs.slice(from, until))

def append(other: Coll[(L, R)]): Coll[(L,R)] = {
def append(other: Coll[(L, R)]): Coll[(L, R)] = {
val arrs = builder.unzip(other)
builder.pairColl(ls.append(arrs._1), rs.append(arrs._2))
}
Expand All @@ -352,7 +401,7 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
}
}

def zip[@specialized B](ys: Coll[B]): PairColl[(L,R), B] = builder.pairColl(this, ys)
def zip[@specialized B](ys: Coll[B]): PairColl[(L, R), B] = builder.pairColl(this, ys)

def startsWith(ys: Coll[(L, R)]): Boolean = ys match {
case yp: PairOfCols[L, R] => ls.startsWith(yp.ls) && rs.startsWith(yp.rs)
Expand Down Expand Up @@ -408,18 +457,20 @@ class PairOfCols[@specialized L, @specialized R](val ls: Coll[L], val rs: Coll[R
}

override def unionSet(that: Coll[(L, R)]): Coll[(L, R)] = {
val set = new java.util.HashSet[(L,R)](32)
val set = new java.util.HashSet[(L, R)](32)
implicit val ctL = ls.tItem.classTag
implicit val ctR = rs.tItem.classTag
val resL = Buffer.empty[L]
val resR = Buffer.empty[R]
def addToSet(item: (L,R)) = {

def addToSet(item: (L, R)) = {
if (!set.contains(item)) {
set.add(item)
resL += item._1
resR += item._2
}
}

var i = 0
val thisLen = math.min(ls.length, rs.length)
while (i < thisLen) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ import org.ergoplatform.settings.ErgoAlgos
import scorex.util.encode.Base16
import scorex.util.{ModifierId, Random}
import sigma.Extensions._
import sigma.SigmaDslTesting
import sigma.ast.SCollection.SByteArray
import sigma.{SigmaDslTesting, VersionContext}
import sigma.ast.SType._
import sigma.ast.syntax.{ErgoBoxCandidateRType, TrueSigmaProp}
import sigma.ast._
Expand All @@ -20,9 +19,11 @@ import sigmastate.helpers.TestingHelpers.copyTransaction
import sigmastate.utils.Helpers
import sigma.SigmaDslTesting
import sigma.Extensions._
import sigma.ast.SCollection.SByteArray
import sigmastate.CrossVersionProps
import sigmastate.utils.Helpers.EitherOps // required for Scala 2.11

class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {
class ErgoLikeTransactionSpec extends SigmaDslTesting with CrossVersionProps with JsonCodecs {

property("ErgoBox test vectors") {
val token1 = "6e789ab7b2fffff12280a6cd01557f6fb22b7f80ff7aff8e1f7f15973d7f0001"
Expand Down Expand Up @@ -99,14 +100,24 @@ class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {

{ // test case for R2
val res = b1.get(ErgoBox.R2).get
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).map(identity).toConstant
// TODO v6.0 (16h): fix collections equality and remove map(identity)
// (PairOfColl should be equal CollOverArray but now it is not)

// We have versioned check here due to fixed collections equality in 6.0.0
// (PairOfColl equal CollOverArray now)
// see (https://github.com/ScorexFoundation/sigmastate-interpreter/issues/909)
res shouldBe exp
if(VersionContext.current.isV6SoftForkActivated) {
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).toConstant
res shouldBe exp
exp shouldBe res
} else {
val exp = Coll(
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token1).toColl) -> 10000000L,
(Digest32Coll @@ ErgoAlgos.decodeUnsafe(token2).toColl) -> 500L
).map(identity).toConstant
res shouldBe exp
}
}

{ // test case for R3
Expand Down Expand Up @@ -470,7 +481,6 @@ class ErgoLikeTransactionSpec extends SigmaDslTesting with JsonCodecs {
// test equivalence of "from Json" and "from bytes" deserialization
tx2.id shouldBe tx.id
tx2.id shouldBe "d5c0a7908bbb8eefe72ad70a9f668dd47b748239fd34378d3588d5625dd75c82"
println(tx2.id)
}

property("Tuple in register test vector") {
Expand Down
Loading