From 48db43217fb1d3bf8846fb310e6520c4d30a2cb3 Mon Sep 17 00:00:00 2001 From: Alexander Chepurnoy Date: Wed, 11 Sep 2024 15:02:27 +0300 Subject: [PATCH] shiftLeft/shiftRight shift limit --- .../main/scala/sigma/data/BigIntegerOps.scala | 18 ++++++-- .../main/scala/sigma/data/ExactIntegral.scala | 42 +++++++++++++++---- 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/data/shared/src/main/scala/sigma/data/BigIntegerOps.scala b/data/shared/src/main/scala/sigma/data/BigIntegerOps.scala index 2e1d2f62ce..8d272439f4 100644 --- a/data/shared/src/main/scala/sigma/data/BigIntegerOps.scala +++ b/data/shared/src/main/scala/sigma/data/BigIntegerOps.scala @@ -101,9 +101,21 @@ object NumericOps { override def bitwiseXor(x: BigInt, y: BigInt): BigInt = x.xor(y) - override def shiftLeft(x: BigInt, y: Int): BigInt = x.shiftLeft(y) - - override def shiftRight(x: BigInt, y: Int): BigInt = x.shiftRight(y) + override def shiftLeft(x: BigInt, bits: Int): BigInt = { + if (bits < 0 || bits >= 256) { + throw new IllegalArgumentException(s"Wrong argument in BigInt.shiftRight: bits < 0 || bits >= 256 ($bits)") + } else { + x.shiftLeft(bits) + } + } + + override def shiftRight(x: BigInt, bits: Int): BigInt = { + if (bits < 0 || bits >= 256) { + throw new IllegalArgumentException(s"Wrong argument in BigInt.shiftRight: bits < 0 || bits >= 256 ($bits)") + } else { + x.shiftRight(bits) + } + } } /** The instance of [[scalan.ExactOrdering]] typeclass for [[BigInt]]. */ diff --git a/data/shared/src/main/scala/sigma/data/ExactIntegral.scala b/data/shared/src/main/scala/sigma/data/ExactIntegral.scala index 4b0d9cb720..2bd7fbe341 100644 --- a/data/shared/src/main/scala/sigma/data/ExactIntegral.scala +++ b/data/shared/src/main/scala/sigma/data/ExactIntegral.scala @@ -43,9 +43,15 @@ object ExactIntegral { override def bitwiseOr(x: Byte, y: Byte): Byte = (x | y).toByte override def bitwiseAnd(x: Byte, y: Byte): Byte = (x & y).toByte override def bitwiseXor(x: Byte, y: Byte): Byte = (x ^ y).toByte - override def shiftLeft(x: Byte, bits: Int): Byte = (x << bits).toByte + override def shiftLeft(x: Byte, bits: Int): Byte = { + if (bits < 0 || bits >= 8) { + throw new IllegalArgumentException(s"Wrong argument in Byte.shiftRight: bits < 0 || bits >= 8 ($bits)") + } else { + (x << bits).toByte + } + } override def shiftRight(x: Byte, bits: Int): Byte = { - if (bits < 0 || bits >= 8){ + if (bits < 0 || bits >= 8) { throw new IllegalArgumentException(s"Wrong argument in Byte.shiftRight: bits < 0 || bits >= 8 ($bits)") } else { (x >> bits).toByte @@ -64,7 +70,13 @@ object ExactIntegral { override def bitwiseOr(x: Short, y: Short): Short = (x | y).toShort override def bitwiseAnd(x: Short, y: Short): Short = (x & y).toShort override def bitwiseXor(x: Short, y: Short): Short = (x ^ y).toShort - override def shiftLeft(x: Short, y: Int): Short = (x << y).toShort + override def shiftLeft(x: Short, bits: Int): Short = { + if (bits < 0 || bits >= 16) { + throw new IllegalArgumentException(s"Wrong argument in Short.shiftRight: bits < 0 || bits >= 16 ($bits)") + } else { + (x << bits).toShort + } + } override def shiftRight(x: Short, bits: Int): Short = { if (bits < 0 || bits >= 16){ throw new IllegalArgumentException(s"Wrong argument in Short.shiftRight: bits < 0 || bits >= 16 ($bits)") @@ -85,9 +97,17 @@ object ExactIntegral { override def bitwiseOr(x: Int, y: Int): Int = x | y override def bitwiseAnd(x: Int, y: Int): Int = x & y override def bitwiseXor(x: Int, y: Int): Int = x ^ y - override def shiftLeft(x: Int, y: Int): Int = x << y + + override def shiftLeft(x: Int, bits: Int): Int = { + if (bits < 0 || bits >= 32) { + throw new IllegalArgumentException(s"Wrong argument in Byte.shiftRight: bits < 0 || bits >= 32 ($bits)") + } else { + x << bits + } + } + override def shiftRight(x: Int, bits: Int): Int = { - if (bits < 0 || bits >= 32){ + if (bits < 0 || bits >= 32) { throw new IllegalArgumentException(s"Wrong argument in Int.shiftRight: bits < 0 || bits >= 32 ($bits)") } else { x >> bits @@ -106,9 +126,17 @@ object ExactIntegral { override def bitwiseOr(x: Long, y: Long): Long = x | y override def bitwiseAnd(x: Long, y: Long): Long = x & y override def bitwiseXor(x: Long, y: Long): Long = x ^ y - override def shiftLeft(x: Long, y: Int): Long = x << y + + override def shiftLeft(x: Long, bits: Int): Long = { + if (bits < 0 || bits >= 64) { + throw new IllegalArgumentException(s"Wrong argument in Long.shiftRight: bits < 0 || bits >= 64 ($bits)") + } else { + x << bits + } + } + override def shiftRight(x: Long, bits: Int): Long = { - if (bits < 0 || bits >= 64){ + if (bits < 0 || bits >= 64) { throw new IllegalArgumentException(s"Wrong argument in Long.shiftRight: bits < 0 || bits >= 64 ($bits)") } else { x >> bits