Update to Scala 3.4.0
Due to multiple new warnings:

- Added the `infix` modifier to commonly used infix operators and types
- Replaced `: _*` with `*` (see the sketch just below for both fixes)
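A minimal, hedged sketch of the two warning fixes, assuming Scala 3.4; the names below are illustrative and not taken from this repository:

```scala
// Minimal sketch, assuming Scala 3.4 -- illustrative names only, not from the repo.

// Scala 3.4 warns when an alphanumeric method is used in infix position
// unless it is declared with the `infix` modifier.
final case class Meters(value: Double):
  infix def plus(other: Meters): Meters = Meters(value + other.value)

// The same applies to type-level infix usage such as `Int Or String`.
infix type Or[A, B] = A | B

// The `xs: _*` vararg splice is deprecated in favour of the shorter `xs*`.
def sum(xs: Int*): Int = xs.sum

@main def warningFixes(): Unit =
  val total: Meters = Meters(1.0) plus Meters(2.5) // no "alphanumeric infix" warning
  val tagged: Int Or String = 42                   // no infix-type warning
  val values = Seq(1, 2, 3)
  println(sum(values*))                            // was `sum(values: _*)`
  println((total, tagged))
```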

Due to new compilation errors:

- Added the case `(T, DType) => T` to the `Promoted` match type -- match types behave
  slightly differently now (see the simplified sketch below)
  - Although I'm still unsure what the implications of this change are, everything
    runs properly
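For context, here is a heavily simplified, hypothetical model of the `Promoted` match type from `core/src/main/scala/torch/DType.scala` (the real definition has many more cases and a richer dtype hierarchy). It only illustrates how the new first case keeps the type reducing when the right-hand operand is known only as `DType`:

```scala
// Simplified sketch -- not the library's actual dtype hierarchy.
sealed trait DType
final class Float32 extends DType
final class Int32 extends DType

type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
  case (T, DType) => T // case added in this commit: keep T when U is only known as DType
  case (T, T)     => T // both operands already share the first dtype
  case (U, U)     => U // both operands already share the second dtype

object PromotedChecks:
  // Still reduces when one side is the abstract DType upper bound:
  summon[Promoted[Float32, DType] =:= Float32]
  // And reduces as before when both sides agree:
  summon[Promoted[Float32, Float32] =:= Float32]
```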
davoclavo committed Jun 27, 2024
1 parent 2dfa388 commit db2dba6
Showing 12 changed files with 21 additions and 21 deletions.
2 changes: 1 addition & 1 deletion build.sbt
@@ -28,7 +28,7 @@ val pytorchVersion = "2.1.2"
val cudaVersion = "12.3-8.9"
val openblasVersion = "0.3.26"
val mklVersion = "2024.0"
ThisBuild / scalaVersion := "3.3.1"
ThisBuild / scalaVersion := "3.4.0"
ThisBuild / javaCppVersion := "1.5.10"
ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots")

1 change: 1 addition & 0 deletions core/src/main/scala/torch/DType.scala
@@ -413,6 +413,7 @@ type DTypeOrDeriveArange[
* rules](https://github.com/pytorch/pytorch/blob/fb6749d977e33b5f463c2d0a1b56a939428105e5/c10/core/ScalarType.h#L423-L444)
*/
type Promoted[T <: DType, U <: DType] <: DType = (T, U) match
case (T, DType) => T
case (T, T) => T
case (U, U) => U
case (Undefined, U) | (T, Undefined) => Undefined
4 changes: 2 additions & 2 deletions core/src/main/scala/torch/Tensor.scala
@@ -281,7 +281,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto
case _ => false

/** True if `other` has the same size and elements as this tensor, false otherwise. */
def equal(other: Tensor[D]): Boolean = native.equal(other.native)
infix def equal(other: Tensor[D]): Boolean = native.equal(other.native)

/** Returns the tensor with elements exponentiated. */
def exp: Tensor[D] = fromNative(native.exp())
@@ -415,7 +415,7 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto

def <(other: ScalaType): Tensor[Bool] = lt(other)

def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
infix def matmul[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] =
fromNative(native.matmul(u.native))

def `@`[D2 <: DType](u: Tensor[D2]): Tensor[Promoted[D, D2]] = matmul(u)
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/Types.scala
@@ -56,4 +56,4 @@ type AtLeastOneFloatOrComplex[A <: DType, B <: DType] = A <:< (FloatNN | Complex
B <:< (FloatNN | ComplexNN)

/* Evidence that two dtypes are not the same */
type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
infix type NotEqual[D <: DType, D2 <: DType] = NotGiven[D =:= D2]
6 changes: 3 additions & 3 deletions core/src/main/scala/torch/nn/functional/Convolution.scala
@@ -123,7 +123,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

@@ -151,7 +151,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

@@ -179,7 +179,7 @@ private[torch] trait Convolution {
toArray(padding),
toArray(outputPadding),
groups,
toArray(dilation): _*
toArray(dilation)*
)
)

@@ -77,7 +77,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
// TODO: not in Python code. Note other modules retain index, so we have repeats
this.register(module)(using Name(index.toString()))
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

/** Appends a given module to the end of the list.
*
@@ -94,7 +94,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
this.register(module)(using Name(index.toString()))
val all = modules.appended(module)
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

/** Appends modules from a Python iterable to the end of the list.
*
@@ -115,7 +115,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*)
this.register(module)(using Name(index.toString()))
)
// TODO: make modules list mutable?
ModuleList(all: _*)
ModuleList(all*)

override def hasBias(): Boolean = modules.exists(_.hasBias())

@@ -92,7 +92,7 @@ final class LayerNorm[ParamType <: FloatNN | ComplexNN: Default](
) extends HasWeight[ParamType]
with TensorModule[ParamType]:

private val shape: LongVector = LongVector(normalizedShape.map(_.toLong): _*)
private val shape: LongVector = LongVector(normalizedShape.map(_.toLong)*)
private val options: LayerNormOptions = LayerNormOptions(shape)
options.eps().put(eps)
options.elementwise_affine().put(elementWiseAffine)
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/ops/ReductionOps.scala
@@ -916,6 +916,6 @@ private[torch] trait ReductionOps {
val nativeDim = dim.toArray
fromNative(
if nativeDim.isEmpty then torchNative.count_nonzero(input.native)
else torchNative.count_nonzero(input.native, nativeDim: _*)
else torchNative.count_nonzero(input.native, nativeDim*)
)
}
2 changes: 1 addition & 1 deletion core/src/test/scala/TrainingSuite.scala
@@ -17,7 +17,7 @@
package torch
import torch.data.*

class TraininSuite extends munit.FunSuite {
class TrainingSuite extends munit.FunSuite {
test("training") {

val xTrain = torch.arange(end = 10, dtype = float32) // .reshape(10, 1)
6 changes: 3 additions & 3 deletions core/src/test/scala/torch/ops/RandomSamplingOpsSuite.scala
@@ -37,13 +37,13 @@ class RandomSamplingOpsSuite extends TensorCheckSuite {

val g1 = torch.Generator()
g1.manualSeed(0)
val t1 = torch.randint(high = 100, Seq(2, 2), generator = g1)
val t2 = torch.randint(high = 100, Seq(2, 2), generator = g1)
val t1 = torch.randint(high = 100, size = Seq(2, 2), generator = g1)
val t2 = torch.randint(high = 100, size = Seq(2, 2), generator = g1)
assertNotEquals(t1, t2)

val g2 = torch.Generator()
g2.manualSeed(0)
val t3 = torch.randint(high = 100, Seq(2, 2), generator = g2)
val t3 = torch.randint(high = 100, size = Seq(2, 2), generator = g2)
assertEquals(t1, t3)

}
6 changes: 3 additions & 3 deletions examples/src/main/scala/gpt/V2.scala
@@ -151,7 +151,7 @@ object V2:
*/

// here are all the unique characters that occur in this text
val chars = SortedSet(text: _*)
val chars = SortedSet(text*)
println(s"chars = ${chars.mkString(", ")}")
val vocab_size = chars.size
println(s"vocab_size = $vocab_size")
@@ -413,7 +413,7 @@ object V2:
Utils.register_i(this, Head_2(nEmbed, headSize, blockSize), i)
}
// val hs = 0 until numHeads map{ i => Utils.register_i(this, Head(nEmbed, headSize, blockSize, dropout), i) }
val heads = register(nn.ModuleList(hs: _*))
val heads = register(nn.ModuleList(hs*))
// TODO: BUG - self.proj = nn.Linear(head_size * num_heads, n_embd)
val proj = register(nn.Linear(headSize * numHeads, nEmbed))
// val proj = register( nn.Linear(nEmbed, nEmbed) )
@@ -603,7 +603,7 @@ object V2:
val token_embedding_table = register(nn.Embedding(vocabSize, nEmbed))
val position_embedding_table = register(nn.Embedding(blockSize, nEmbed))
val blocks_i = 0 until nBlocks map { i => Block(nEmbed, nHead, blockSize, vocabSize, dropout) }
val blocks = register(nn.Sequential(blocks_i: _*))
val blocks = register(nn.Sequential(blocks_i*))
val ln_f = register(nn.LayerNorm(Seq(nEmbed)))
val lm_head = register(nn.Linear(nEmbed, vocabSize))

3 changes: 1 addition & 2 deletions vision/src/main/scala/torchvision/transforms/presets.scala
@@ -19,8 +19,7 @@ package transforms

import com.sksamuel.scrimage.ImmutableImage
import com.sksamuel.scrimage.ScaleMethod
import torch.Tensor
import torch.Float32
import torch.{DType, Float32, Promoted, Tensor}
import torchvision.transforms.functional.toTensor

object Presets:
