From 0be74e7b99f7f73207353dae3945259e20576a66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Mon, 5 Feb 2024 17:28:34 +0100 Subject: [PATCH] Fix tests --- core/src/main/scala/torch/hub.scala | 3 +-- core/src/test/scala/TrainingSuite.scala | 8 ++++---- core/src/test/scala/torch/DeviceSuite.scala | 8 +------- core/src/test/scala/torch/TensorCheckSuite.scala | 2 +- core/src/test/scala/torch/TensorSuite.scala | 3 --- 5 files changed, 7 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/torch/hub.scala b/core/src/main/scala/torch/hub.scala index 4f1a27c1..f1a49df4 100644 --- a/core/src/main/scala/torch/hub.scala +++ b/core/src/main/scala/torch/hub.scala @@ -34,7 +34,6 @@ object hub: if !os.exists(cachedFile) then System.err.println(s"Downloading: $url to $cachedFile") Using.resource(URL(url).openStream()) { inputStream => - Files.copy(inputStream, cachedFile.toNIO) - () + val _ = Files.copy(inputStream, cachedFile.toNIO) } torch.pickleLoad(cachedFile.toNIO) diff --git a/core/src/test/scala/TrainingSuite.scala b/core/src/test/scala/TrainingSuite.scala index 808669a0..796f441d 100644 --- a/core/src/test/scala/TrainingSuite.scala +++ b/core/src/test/scala/TrainingSuite.scala @@ -28,8 +28,8 @@ class TraininSuite extends munit.FunSuite { torch.manualSeed(1) - var weight = torch.randn(Seq(1), requiresGrad = true) - var bias = torch.zeros(Seq(1), requiresGrad = true) + val weight = torch.randn(Seq(1), requiresGrad = true) + val bias = torch.zeros(Seq(1), requiresGrad = true) def model(xb: Tensor[Float32]): Tensor[Float32] = (xb matmul weight) + bias @@ -57,11 +57,11 @@ class TraininSuite extends munit.FunSuite { noGrad { weight.grad.foreach { grad => weight -= grad * learningRate - grad.zero() + grad.zero_() } bias.grad.foreach { grad => weight -= grad * learningRate - grad.zero() + grad.zero_() } } loss diff --git a/core/src/test/scala/torch/DeviceSuite.scala b/core/src/test/scala/torch/DeviceSuite.scala index 2642a875..81866c1c 100644 --- a/core/src/test/scala/torch/DeviceSuite.scala +++ b/core/src/test/scala/torch/DeviceSuite.scala @@ -17,15 +17,9 @@ package torch import munit.ScalaCheckSuite -import torch.DeviceType.CUDA import org.scalacheck.Prop.* -import org.bytedeco.pytorch.global.torch as torch_native -import org.scalacheck.{Arbitrary, Gen} import org.scalacheck._ -import Gen._ -import Arbitrary.arbitrary -import DeviceType.CPU -import Generators.{*, given} +import Generators.given class DeviceSuite extends ScalaCheckSuite { test("device native roundtrip") { diff --git a/core/src/test/scala/torch/TensorCheckSuite.scala b/core/src/test/scala/torch/TensorCheckSuite.scala index da620258..81c86902 100644 --- a/core/src/test/scala/torch/TensorCheckSuite.scala +++ b/core/src/test/scala/torch/TensorCheckSuite.scala @@ -19,7 +19,7 @@ package torch import munit.ScalaCheckSuite import shapeless3.typeable.{TypeCase, Typeable} import shapeless3.typeable.syntax.typeable.* -import Generators.{*, given} +import Generators.* import org.scalacheck.Prop.* import scala.util.Try diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala index a6c3d746..a0141b33 100644 --- a/core/src/test/scala/torch/TensorSuite.scala +++ b/core/src/test/scala/torch/TensorSuite.scala @@ -16,9 +16,6 @@ package torch -import org.scalacheck.Prop.* -import Generators.given - class TensorSuite extends TensorCheckSuite { test("tensor properties") {