Skip to content

Commit

Permalink
Update to PyTorch 2.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunk committed Oct 29, 2023
1 parent 05078a8 commit 72d7e30
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ ThisBuild / tlSitePublishBranch := Some("main")
ThisBuild / apiURL := Some(new URL("https://storch.dev/api/"))

val scrImageVersion = "4.0.34"
val pytorchVersion = "2.0.1"
val cudaVersion = "12.1-8.9"
val pytorchVersion = "2.1.0"
val cudaVersion = "12.3-8.9"
val openblasVersion = "0.3.23"
val mklVersion = "2023.1"
ThisBuild / scalaVersion := "3.3.1"
Expand Down
3 changes: 3 additions & 0 deletions core/src/main/scala/torch/internal/NativeConverters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.bytedeco.pytorch.{
DeviceOptional,
DoubleOptional,
BoolOptional,
LongArrayRefOptional,
LongOptional,
TensorOptional
}
Expand Down Expand Up @@ -60,6 +61,8 @@ private[torch] object NativeConverters:
pytorch.ScalarOptional(scalar)
)

extension (i: Int) def toScalarOptional = pytorch.ScalarOptional(pytorch.Scalar(i))

extension [D <: DType](t: Tensor[D] | Option[Tensor[D]])
def toOptional: TensorOptional =
convertToOptional(t, t => pytorch.TensorOptional(t.native))
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/ops/PointwiseOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ private[torch] trait PointwiseOps {
*
* @group pointwise_ops
*/
def nanToNum[D <: RealNN](
def nanToNum[D <: DType](
input: Tensor[D],
nan: Option[Double] = None,
posinf: Option[Double] = None,
Expand Down
8 changes: 4 additions & 4 deletions core/src/main/scala/torch/ops/ReductionOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ private[torch] trait ReductionOps {
torchNative.std(
input.native,
dim.toArray,
correction.toOptional,
correction.toScalarOptional,
keepdim
)
)
Expand Down Expand Up @@ -722,7 +722,7 @@ private[torch] trait ReductionOps {
torchNative.std_mean(
input.native,
dim.toArray,
correction.toOptional,
correction.toScalarOptional,
keepdim
)
(fromNative[D](nativeTuple.get0), fromNative[D](nativeTuple.get1))
Expand Down Expand Up @@ -828,7 +828,7 @@ private[torch] trait ReductionOps {
torchNative.`var`(
input.native,
dim.toArray,
correction.toOptional,
correction.toScalarOptional,
keepdim
)
)
Expand Down Expand Up @@ -867,7 +867,7 @@ private[torch] trait ReductionOps {
torchNative.var_mean(
input.native,
dim.toArray,
correction.toOptional,
correction.toScalarOptional,
keepdim
)
(fromNative[D](nativeTuple.get0), fromNative[D](nativeTuple.get1))
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/torch/special/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ package object special:
fromNative(torchNative.exp2(input.native))

/** Computes the exponential of the elements minus 1 of `input`. */
def expm1[D <: RealNN](input: Tensor[D]): Tensor[FloatPromoted[D]] =
def expm1[D <: DType](input: Tensor[D]): Tensor[FloatPromoted[D]] =
fromNative(torchNative.expm1(input.native))

/** Computes the zeroth order modified Bessel function of the first kind for each element of
Expand Down

0 comments on commit 72d7e30

Please sign in to comment.