From 71751da75eea52370396989c02d3e5c8e5d106fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Sun, 4 Feb 2024 21:50:32 +0100 Subject: [PATCH 1/3] Add Scala Days video and update website - update sbt-typelevel plugin - adapt to changes from update to Laika 1.0 - remove custom html templates and use css/js overrides instead - add Apple Silicon docs --- .github/workflows/ci.yml | 163 +++++++++--------- docs/faq.md | 4 +- docs/installation.md | 6 +- examples/src/main/scala/ImageClassifier.scala | 9 +- examples/src/main/scala/LeNet.scala | 9 +- project/SiteSettings.scala | 57 ++++-- project/plugins.sbt | 8 +- site/src/css/custom.css | 34 ++++ site/src/default.template.html | 121 ------------- site/src/js/render-katex.js | 12 ++ site/src/landing-page.md | 13 ++ site/src/landing.template.html | 83 --------- 12 files changed, 203 insertions(+), 316 deletions(-) create mode 100644 site/src/css/custom.css delete mode 100644 site/src/default.template.html create mode 100644 site/src/js/render-katex.js create mode 100644 site/src/landing-page.md delete mode 100644 site/src/landing.template.html diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0eeb5fe2..e6bff67d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,55 +15,41 @@ on: tags: [v*] env: - PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} - SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} - SONATYPE_CREDENTIAL_HOST: ${{ secrets.SONATYPE_CREDENTIAL_HOST }} - SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} - PGP_SECRET: ${{ secrets.PGP_SECRET }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + +concurrency: + group: ${{ github.workflow }} @ ${{ github.ref }} + cancel-in-progress: true + jobs: build: name: Build and Test strategy: matrix: os: [ubuntu-latest] - scala: [3.3.1] + scala: [3] java: [temurin@11] runs-on: ${{ matrix.os }} + timeout-minutes: 60 steps: - name: Checkout current branch (full) - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Download Java (temurin@11) - id: download-java-temurin-11 - if: matrix.java == 'temurin@11' - uses: typelevel/download-java@v2 - with: - distribution: temurin - java-version: 11 - - name: Setup Java (temurin@11) + id: setup-java-temurin-11 if: matrix.java == 'temurin@11' - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: - distribution: jdkfile + distribution: temurin java-version: 11 - jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} + cache: sbt - - name: Cache sbt - uses: actions/cache@v3 - with: - path: | - ~/.sbt - ~/.ivy2/cache - ~/.coursier/cache/v1 - ~/.cache/coursier/v1 - ~/AppData/Local/Coursier/Cache/v1 - ~/Library/Caches/Coursier/v1 - key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }} + - name: sbt update + if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + run: sbt +update - name: Check that workflows are up to date run: sbt githubWorkflowCheck @@ -85,15 +71,15 @@ jobs: - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: mkdir -p examples/target target site/target vision/target core/target project/target + run: mkdir -p vision/target core/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - run: tar cf targets.tar examples/target target 
site/target vision/target core/target project/target + run: tar cf targets.tar vision/target core/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: target-${{ matrix.os }}-${{ matrix.java }}-${{ matrix.scala }} path: targets.tar @@ -109,64 +95,60 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Download Java (temurin@11) - id: download-java-temurin-11 + - name: Setup Java (temurin@11) + id: setup-java-temurin-11 if: matrix.java == 'temurin@11' - uses: typelevel/download-java@v2 + uses: actions/setup-java@v4 with: distribution: temurin java-version: 11 + cache: sbt - - name: Setup Java (temurin@11) - if: matrix.java == 'temurin@11' - uses: actions/setup-java@v3 - with: - distribution: jdkfile - java-version: 11 - jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} + - name: sbt update + if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + run: sbt +update - - name: Cache sbt - uses: actions/cache@v3 + - name: Download target directories (3) + uses: actions/download-artifact@v4 with: - path: | - ~/.sbt - ~/.ivy2/cache - ~/.coursier/cache/v1 - ~/.cache/coursier/v1 - ~/AppData/Local/Coursier/Cache/v1 - ~/Library/Caches/Coursier/v1 - key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }} - - - name: Download target directories (3.3.1) - uses: actions/download-artifact@v3 - with: - name: target-${{ matrix.os }}-${{ matrix.java }}-3.3.1 + name: target-${{ matrix.os }}-${{ matrix.java }}-3 - - name: Inflate target directories (3.3.1) + - name: Inflate target directories (3) run: | tar xf targets.tar rm targets.tar - name: Import signing key if: env.PGP_SECRET != '' && env.PGP_PASSPHRASE == '' - run: echo $PGP_SECRET | base64 -di | gpg --import + env: + PGP_SECRET: ${{ secrets.PGP_SECRET }} + PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} + run: echo $PGP_SECRET | base64 -d -i - | gpg --import - name: Import signing key and strip passphrase if: env.PGP_SECRET != '' && env.PGP_PASSPHRASE != '' + env: + PGP_SECRET: ${{ secrets.PGP_SECRET }} + PGP_PASSPHRASE: ${{ secrets.PGP_PASSPHRASE }} run: | - echo "$PGP_SECRET" | base64 -di > /tmp/signing-key.gpg + echo "$PGP_SECRET" | base64 -d -i - > /tmp/signing-key.gpg echo "$PGP_PASSPHRASE" | gpg --pinentry-mode loopback --passphrase-fd 0 --import /tmp/signing-key.gpg (echo "$PGP_PASSPHRASE"; echo; echo) | gpg --command-fd 0 --pinentry-mode loopback --change-passphrase $(gpg --list-secret-keys --with-colons 2> /dev/null | grep '^sec:' | cut --delimiter ':' --fields 5 | tail -n 1) - name: Publish + env: + SONATYPE_USERNAME: ${{ secrets.SONATYPE_USERNAME }} + SONATYPE_PASSWORD: ${{ secrets.SONATYPE_PASSWORD }} + SONATYPE_CREDENTIAL_HOST: ${{ secrets.SONATYPE_CREDENTIAL_HOST }} run: sbt tlCiRelease - site: - name: Generate Site + dependency-submission: + name: Submit Dependencies + if: github.event_name != 'pull_request' strategy: matrix: os: [ubuntu-latest] @@ -174,44 +156,61 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 - - name: Download Java (temurin@11) - id: download-java-temurin-11 + - name: Setup Java (temurin@11) + id: 
setup-java-temurin-11
        if: matrix.java == 'temurin@11'
-        uses: typelevel/download-java@v2
+        uses: actions/setup-java@v4
         with:
           distribution: temurin
           java-version: 11
+          cache: sbt
+
+      - name: sbt update
+        if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false'
+        run: sbt +update
+
+      - name: Submit Dependencies
+        uses: scalacenter/sbt-dependency-submission@v2
+        with:
+          modules-ignore: examples_3 root_3 docs_3
+          configs-ignore: test scala-tool scala-doc-tool test-internal
+
+  site:
+    name: Generate Site
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        java: [temurin@11]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - name: Checkout current branch (full)
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 0

       - name: Setup Java (temurin@11)
+        id: setup-java-temurin-11
         if: matrix.java == 'temurin@11'
-        uses: actions/setup-java@v3
+        uses: actions/setup-java@v4
         with:
-          distribution: jdkfile
+          distribution: temurin
           java-version: 11
-          jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }}
+          cache: sbt

-      - name: Cache sbt
-        uses: actions/cache@v3
-        with:
-          path: |
-            ~/.sbt
-            ~/.ivy2/cache
-            ~/.coursier/cache/v1
-            ~/.cache/coursier/v1
-            ~/AppData/Local/Coursier/Cache/v1
-            ~/Library/Caches/Coursier/v1
-          key: ${{ runner.os }}-sbt-cache-v2-${{ hashFiles('**/*.sbt') }}-${{ hashFiles('project/build.properties') }}
+      - name: sbt update
+        if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false'
+        run: sbt +update

       - name: Generate site
         run: sbt docs/tlSite

       - name: Publish site
         if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
-        uses: peaceiris/actions-gh-pages@v3.9.0
+        uses: peaceiris/actions-gh-pages@v3.9.3
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
           publish_dir: site/target/docs/site
diff --git a/docs/faq.md b/docs/faq.md
index f7eea873..de630bc1 100644
--- a/docs/faq.md
+++ b/docs/faq.md
@@ -24,5 +24,5 @@ export CUDA_CACHE_MAXSIZE=4294967296
 Recent PyTorch versions provide a new backend based on Apple’s Metal Performance Shaders (MPS).
 The MPS backend enables GPU-accelerated training on the M1/M2 architecture.
 
-Right now, there's no ARM build of PyTorch in JavaCPP and MPS ist not enabled.
-If you have an M1/M2 machine and want to help, check the umbrella [issue for macosx-aarch64 support](https://github.com/bytedeco/javacpp-presets/issues/1069).
\ No newline at end of file
+While we have an ARM build of PyTorch in JavaCPP as of version `1.5.10`, MPS is not enabled, as the CI runners currently run on a macOS version that is too old.
+If you want to help get this working, check out [the corresponding issue](https://github.com/bytedeco/javacpp-presets/issues/1464).
\ No newline at end of file
diff --git a/docs/installation.md b/docs/installation.md
index f68bfbb3..b3481371 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -77,13 +77,13 @@ fork := true
 There is one downside to this approach. Because `pytorch-platform` depends on the native libraries for all supported
 platforms, it will download and cache **all** these libraries, no matter on which platform you actually are.
 
-One way to avoid the overhead, is to explicitly depend on the native libraries for **your** platform instead of using
-`pytorch-platform`.
+One way to avoid the overhead is to explicitly depend on the native libraries for **your** platform instead of using
+`pytorch-platform`.
As of JavaCPP `1.5.10` the platform approach also doesn't work for `macosx-arm64` because the native dependencies are missing in the current version of `pytorch-platform` (should be fixed in the next release). ### Via classifier This can be done by providing dependency classifiers specifically for your platform. -Currently supported are `linux-x86_64`, `macosx-x86_64` and `windows-x86_64`. +Currently supported are `linux-x86_64`, `macosx-x86_64`, `macosx-arm64` and `windows-x86_64`. @:select(build-tool) diff --git a/examples/src/main/scala/ImageClassifier.scala b/examples/src/main/scala/ImageClassifier.scala index ed2b26db..d68e37b7 100644 --- a/examples/src/main/scala/ImageClassifier.scala +++ b/examples/src/main/scala/ImageClassifier.scala @@ -17,14 +17,17 @@ //> using scala "3.3" //> using repository "sonatype:snapshots" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-bbdc238-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-2fff591-SNAPSHOT" //> using lib "me.tongfei:progressbar:0.9.5" //> using lib "com.github.alexarchambault::case-app:2.1.0-M24" //> using lib "org.scala-lang.modules::scala-parallel-collections:1.0.4" // replace with pytorch-platform-gpu if you have a CUDA capable GPU -//> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.10-SNAPSHOT" +//> using lib "org.bytedeco:pytorch-platform:2.1.2-1.5.10" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.10-SNAPSHOT" +////> using lib "org.bytedeco:cuda-platform-redist:12.3-8.9-1.5.10" +// enable for native Apple Silicon support +// will not be needed with newer versions of pytorch-platform +////> using lib "org.bytedeco:pytorch:2.1.2-1.5.10,classifier=macosx-arm64" import Commands.* import ImageClassifier.{Prediction, predict, train} diff --git a/examples/src/main/scala/LeNet.scala b/examples/src/main/scala/LeNet.scala index 8599ace4..dc947b65 100644 --- a/examples/src/main/scala/LeNet.scala +++ b/examples/src/main/scala/LeNet.scala @@ -17,11 +17,14 @@ //> using scala "3.3" //> using repository "sonatype:snapshots" //> using repository "sonatype-s01:snapshots" -//> using lib "dev.storch::vision:0.0-bbdc238-SNAPSHOT" +//> using lib "dev.storch::vision:0.0-2fff591-SNAPSHOT" // replace with pytorch-platform-gpu if you have a CUDA capable GPU -//> using lib "org.bytedeco:pytorch-platform:2.0.1-1.5.10-SNAPSHOT" +//> using lib "org.bytedeco:pytorch-platform:2.1.2-1.5.10" // enable for CUDA support -////> using lib "org.bytedeco:cuda-platform-redist:12.1-8.9-1.5.10-SNAPSHOT" +////> using lib "org.bytedeco:cuda-platform-redist:12.3-8.9-1.5.10" +// enable for native Apple Silicon support +// will not be needed with newer versions of pytorch-platform +////> using lib "org.bytedeco:pytorch:2.1.2-1.5.10,classifier=macosx-arm64" import torch.* import torch.nn.functional as F diff --git a/project/SiteSettings.scala b/project/SiteSettings.scala index 46f29b92..a490c24c 100644 --- a/project/SiteSettings.scala +++ b/project/SiteSettings.scala @@ -18,30 +18,21 @@ import laika.helium.config.Favicon import laika.helium.config.HeliumIcon import laika.helium.config.IconLink import laika.helium.config.ImageLink -import laika.rewrite.nav.{ChoiceConfig, Selections, SelectionConfig} -import laika.rewrite.link.{ApiLinks, LinkConfig} +import laika.config.{ApiLinks, LinkConfig, ChoiceConfig, SelectionConfig, Selections} import laika.sbt.LaikaPlugin import laika.theme.ThemeProvider - -import java.net.URL +import laika.theme.config.{CrossOrigin, ScriptAttributes, StyleAttributes} 
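As a concrete illustration of the classifier approach described in the installation docs above, the dependencies could look like this in a `build.sbt` (a minimal sketch; the coordinates follow the `2.1.2-1.5.10` versions used in the examples in this patch, and the `javacpp` line is an assumption about the matching core artifact):

```scala
// Per-platform native dependencies instead of the fat pytorch-platform artifact.
// Pick the classifier matching your machine, e.g. macosx-arm64 for Apple Silicon.
libraryDependencies ++= Seq(
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10",
  "org.bytedeco" % "pytorch" % "2.1.2-1.5.10" classifier "macosx-arm64",
  "org.bytedeco" % "javacpp" % "1.5.10" classifier "macosx-arm64" // assumed core artifact
)
```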
object StorchSitePlugin extends AutoPlugin { override def requires = TypelevelSitePlugin override def projectSettings = Seq( - tlSiteRelatedProjects := Seq( - "PyTorch" -> new URL("https://pytorch.org/"), - "JavaCPP" -> new URL("https://github.com/bytedeco/javacpp") - ), laikaConfig := LaikaConfig.defaults.withRawContent .withConfigValue( - LinkConfig(apiLinks = - Seq( - // ApiLinks(baseUri = "http://localhost:4242/api/") - ApiLinks(baseUri = "https://storch.dev/api/") - ) - ) + LinkConfig.empty + // .addApiLinks(ApiLinks(baseUri = "http://localhost:4242/api/") + .addApiLinks(ApiLinks(baseUri = "https://storch.dev/api/")) ) .withConfigValue( Selections( @@ -52,7 +43,7 @@ object StorchSitePlugin extends AutoPlugin { ).withSeparateEbooks ) ), - tlSiteHeliumConfig := Helium.defaults.site + tlSiteHelium := tlSiteHelium.value.site .metadata( title = Some("Storch"), authors = developers.value.map(_.name), @@ -90,6 +81,16 @@ object StorchSitePlugin extends AutoPlugin { ) ) .site + .mainNavigation( + appendLinks = Seq( + ThemeNavigationSection( + "Related Projects", + TextLink.external("https://pytorch.org/", "PyTorch"), + TextLink.external("https://github.com/bytedeco/javacpp", "JavaCPP") + ) + ) + ) + .site .landingPage( logo = Some( Image.internal(Root / "img" / "storch.svg", height = Some(Length(300, LengthUnit.px))) @@ -142,5 +143,31 @@ object StorchSitePlugin extends AutoPlugin { ) ) ) + .site + .internalCSS(Root / "css") // custom styles + // KaTeX + .site + .externalCSS( + url = "https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.css", + attributes = StyleAttributes.defaults + .withIntegrity("sha384-n8MVd4RsNIU0tAv4ct0nTaAbDJwPJzDEaqSD1odI+WdtXRGWt2kTvGFasHpSy3SV") + .withCrossOrigin(CrossOrigin.Anonymous) + ) + .site + .externalJS( + url = "https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.js", + attributes = ScriptAttributes.defaults.defer + .withIntegrity("sha384-XjKyOOlGwcjNTAIQHIpgOno0Hl1YQqzUOEleOLALmuqehneUG+vnGctmUb0ZY0l8") + .withCrossOrigin(CrossOrigin.Anonymous) + ) + .site + .externalJS( + url = "https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/contrib/auto-render.min.js", + attributes = ScriptAttributes.defaults.defer + .withIntegrity("sha384-+VBxd3r6XgURycqtZ117nYw44OOcIax56Z4dCRWbxyPt0Koah1uHoK0o4+/RRE05") + .withCrossOrigin(CrossOrigin.Anonymous) + ) + .site + .internalJS(Root / "js" / "render-katex.js") ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index a1bcc1e4..b21dfa51 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,6 +1,6 @@ -addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.0") +addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.5.2") addSbtPlugin("org.bytedeco" % "sbt-javacpp" % "1.17") -addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.3.7") +addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2") addSbtPlugin("com.github.sbt" % "sbt-unidoc" % "0.5.0") -addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.4.22") -addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.4.22") +addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.6.5") +addSbtPlugin("org.typelevel" % "sbt-typelevel-site" % "0.6.5") diff --git a/site/src/css/custom.css b/site/src/css/custom.css new file mode 100644 index 00000000..f9e17131 --- /dev/null +++ b/site/src/css/custom.css @@ -0,0 +1,34 @@ +/* Adjust a few styles from landing-page.css */ + +#header { + padding-top: 20px; +} + +#header-left h1, #header-left h2 { + color: var(--component-color); + line-height: 1; + margin-bottom: 5px; +} + +#header-left h1 { + font-size: 40px; +} + +#header-left h2 { + 
font-size: 22px; + margin-top: 0.7em; +} + +.teaser h2 { + font-size: 20px; + margin-bottom: 0.25em; + margin-top: 0; +} + +.teaser p { + font-size: 15px; +} + +.teasers { + margin: 15px auto 0 auto; +} \ No newline at end of file diff --git a/site/src/default.template.html b/site/src/default.template.html deleted file mode 100644 index 31606124..00000000 --- a/site/src/default.template.html +++ /dev/null @@ -1,121 +0,0 @@ - - - - - - - - ${cursor.currentDocument.title} - @:for(laika.site.metadata.authors) - - @:@ - @:for(laika.site.metadata.description) - - @:@ - @:for(helium.favIcons) - - @:@ - @:for(helium.webFonts) - - @:@ - @:linkCSS { paths = ${helium.site.includeCSS} } - @:linkJS { paths = ${helium.site.includeJS} } - @:heliumInitVersions - @:heliumInitPreview(container) - - - - - - - - - - - -
- - - - ${?helium.topBar.home} - - ${?helium.topBar.links} - -
- - - -
- - - -
- - ${cursor.currentDocument.content} - -
- -
- - - diff --git a/site/src/js/render-katex.js b/site/src/js/render-katex.js new file mode 100644 index 00000000..88bb4205 --- /dev/null +++ b/site/src/js/render-katex.js @@ -0,0 +1,12 @@ +document.addEventListener("DOMContentLoaded", function() { + renderMathInElement(document.body, { + // customised options + // • auto-render specific keys, e.g.: + delimiters: [ + {left: '$$', right: '$$', display: true}, + {left: '$', right: '$', display: false}, + ], + // • rendering keys, e.g.: + throwOnError : false + }); +}); \ No newline at end of file diff --git a/site/src/landing-page.md b/site/src/landing-page.md new file mode 100644 index 00000000..ee587e23 --- /dev/null +++ b/site/src/landing-page.md @@ -0,0 +1,13 @@ +
+ +
+ +
+ +
+ +
+ +
+ +

Torch by Mailtoanton / CC BY-SA 3.0

\ No newline at end of file diff --git a/site/src/landing.template.html b/site/src/landing.template.html deleted file mode 100644 index 728b5623..00000000 --- a/site/src/landing.template.html +++ /dev/null @@ -1,83 +0,0 @@ - - - - - - - - ${?laika.site.metadata.title} - @:for(laika.site.metadata.authors) - - @:@ - @:for(laika.site.metadata.description) - - @:@ - @:for(helium.favIcons) - - @:@ - @:for(helium.webFonts) - - @:@ - @:linkCSS { paths = ${helium.site.includeCSS} } - @:for(laika.versions) - - @:empty - - @:@ - @:linkJS { paths = ${helium.site.includeJS} } - - - - - - - ${?cursor.currentDocument.fragments.header} - @:for(helium.landingPage.teaserRows) -
- @:for(_.teasers) -
-

${_.title}

-

${_.description}

-
- @:@ -
- @:@ - ${cursor.currentDocument.content} - -
- -
- -

Torch by Mailtoanton / CC BY-SA 3.0

- - - From 11eb92cdf818de7d8fd609708f98da4e1431a0b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Mon, 5 Feb 2024 08:35:18 +0100 Subject: [PATCH 2/3] Fix warnings with new compiler options Enable CI on windows and macos --- .github/workflows/ci.yml | 35 ++++++- build.sbt | 10 +- core/src/main/scala/torch/DType.scala | 17 +--- core/src/main/scala/torch/Tensor.scala | 39 +++----- core/src/main/scala/torch/hub.scala | 2 +- .../torch/internal/NativeConverters.scala | 3 - .../torch/nn/functional/Activations.scala | 5 +- .../torch/nn/functional/Convolution.scala | 1 - .../scala/torch/nn/functional/Linear.scala | 2 - .../main/scala/torch/nn/functional/Loss.scala | 5 +- .../scala/torch/nn/functional/Pooling.scala | 4 +- .../scala/torch/nn/functional/package.scala | 2 - core/src/main/scala/torch/nn/init.scala | 96 +++++++++++-------- .../main/scala/torch/nn/modules/Module.scala | 8 +- .../nn/modules/batchnorm/BatchNorm1d.scala | 2 - .../nn/modules/batchnorm/BatchNorm2d.scala | 2 - .../nn/modules/container/ModuleList.scala | 3 +- .../nn/modules/container/Sequential.scala | 1 - .../scala/torch/nn/modules/conv/Conv2d.scala | 2 - .../torch/nn/modules/linear/Identity.scala | 4 +- .../torch/nn/modules/linear/Linear.scala | 9 +- .../nn/modules/normalization/LayerNorm.scala | 1 - .../modules/pooling/AdaptiveAvgPool2d.scala | 3 +- .../torch/nn/modules/pooling/MaxPool2d.scala | 1 - .../nn/modules/regularization/Dropout.scala | 3 - .../torch/nn/modules/sparse/Embedding.scala | 9 +- core/src/main/scala/torch/ops/OtherOps.scala | 2 - .../scala/torch/ops/RandomSamplingOps.scala | 1 - .../main/scala/torch/ops/ReductionOps.scala | 1 - core/src/main/scala/torch/ops/package.scala | 3 +- core/src/main/scala/torch/optim/Adam.scala | 6 +- core/src/main/scala/torch/optim/AdamW.scala | 6 +- .../main/scala/torch/optim/Optimizer.scala | 5 +- core/src/main/scala/torch/optim/SGD.scala | 4 +- .../optim/lr_scheduler/LRScheduler.scala | 2 - examples/src/main/scala/ImageClassifier.scala | 9 +- examples/src/main/scala/gpt/V2.scala | 27 ++---- .../scala/torchvision/datasets/MNIST.scala | 4 +- .../scala/torchvision/models/resnet.scala | 18 +--- .../torchvision/transforms/presets.scala | 9 +- 40 files changed, 158 insertions(+), 208 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e6bff67d..1cf534e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,12 +27,17 @@ jobs: name: Build and Test strategy: matrix: - os: [ubuntu-latest] + os: [macos-latest, ubuntu-latest, windows-latest] scala: [3] java: [temurin@11] runs-on: ${{ matrix.os }} timeout-minutes: 60 steps: + - name: Ignore line ending differences in git + if: contains(runner.os, 'windows') + shell: bash + run: git config --global core.autocrlf false + - name: Checkout current branch (full) uses: actions/checkout@v4 with: @@ -49,32 +54,40 @@ jobs: - name: sbt update if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + shell: bash run: sbt +update - name: Check that workflows are up to date + shell: bash run: sbt githubWorkflowCheck - name: Check headers and formatting - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'macos-latest' + shell: bash run: sbt '++ ${{ matrix.scala }}' headerCheckAll scalafmtCheckAll 'project /' scalafmtSbtCheck - name: Test + shell: bash run: sbt '++ ${{ matrix.scala }}' test - name: Check binary compatibility - if: matrix.java == 'temurin@11' && matrix.os == 
'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'macos-latest' + shell: bash run: sbt '++ ${{ matrix.scala }}' mimaReportBinaryIssues - name: Generate API documentation - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'macos-latest' + shell: bash run: sbt '++ ${{ matrix.scala }}' doc - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') + shell: bash run: mkdir -p vision/target core/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') + shell: bash run: tar cf targets.tar vision/target core/target project/target - name: Upload target directories @@ -90,10 +103,14 @@ jobs: if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main') strategy: matrix: - os: [ubuntu-latest] + os: [macos-latest] java: [temurin@11] runs-on: ${{ matrix.os }} steps: + - name: Ignore line ending differences in git + if: contains(runner.os, 'windows') + run: git config --global core.autocrlf false + - name: Checkout current branch (full) uses: actions/checkout@v4 with: @@ -155,6 +172,10 @@ jobs: java: [temurin@11] runs-on: ${{ matrix.os }} steps: + - name: Ignore line ending differences in git + if: contains(runner.os, 'windows') + run: git config --global core.autocrlf false + - name: Checkout current branch (full) uses: actions/checkout@v4 with: @@ -187,6 +208,10 @@ jobs: java: [temurin@11] runs-on: ${{ matrix.os }} steps: + - name: Ignore line ending differences in git + if: contains(runner.os, 'windows') + run: git config --global core.autocrlf false + - name: Checkout current branch (full) uses: actions/checkout@v4 with: diff --git a/build.sbt b/build.sbt index 7bae4bce..2b5b0609 100644 --- a/build.sbt +++ b/build.sbt @@ -33,6 +33,7 @@ ThisBuild / javaCppVersion := "1.5.10" ThisBuild / resolvers ++= Resolver.sonatypeOssRepos("snapshots") ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11")) +ThisBuild / githubWorkflowOSes := Seq("macos-latest", "ubuntu-latest", "windows-latest") val enableGPU = settingKey[Boolean]("enable or disable GPU support") @@ -46,10 +47,11 @@ val hasMKL = { lazy val commonSettings = Seq( Compile / doc / scalacOptions ++= Seq("-groups", "-snippet-compiler:compile"), javaCppVersion := (ThisBuild / javaCppVersion).value, - javaCppPlatform := Seq() + javaCppPlatform := Seq(), // This is a hack to avoid depending on the native libs when publishing // but conveniently have them on the classpath during development. // There's probably a cleaner way to do this. 
+ tlJdkRelease := Some(11) ) ++ tlReplaceCommandAlias( "tlReleaseLocal", List( @@ -111,7 +113,11 @@ lazy val vision = project lazy val examples = project .in(file("examples")) .enablePlugins(NoPublishPlugin) - .settings(commonSettings) + .settings( + commonSettings, + // disable discarded non-Unit value warnings in examples for now + scalacOptions ~= (_.filterNot(Set("-Wvalue-discard"))) + ) .settings( fork := true, libraryDependencies ++= Seq( diff --git a/core/src/main/scala/torch/DType.scala b/core/src/main/scala/torch/DType.scala index 14f8e4a5..6d45c534 100644 --- a/core/src/main/scala/torch/DType.scala +++ b/core/src/main/scala/torch/DType.scala @@ -16,22 +16,9 @@ package torch -import org.bytedeco.javacpp.{DoublePointer, FloatPointer} import org.bytedeco.pytorch.global.torch.ScalarType -import org.bytedeco.pytorch.Scalar - -import java.nio.{ - Buffer, - ByteBuffer, - CharBuffer, - DoubleBuffer, - FloatBuffer, - IntBuffer, - LongBuffer, - ShortBuffer -} -import scala.annotation.{targetName, unused} -import scala.reflect.ClassTag + +import java.nio.{Buffer, ByteBuffer, DoubleBuffer, FloatBuffer, IntBuffer, LongBuffer, ShortBuffer} import spire.math.{Complex, UByte} import scala.compiletime.{erasedValue, summonFrom} diff --git a/core/src/main/scala/torch/Tensor.scala b/core/src/main/scala/torch/Tensor.scala index 2e7830cd..d1c38b97 100644 --- a/core/src/main/scala/torch/Tensor.scala +++ b/core/src/main/scala/torch/Tensor.scala @@ -25,43 +25,22 @@ import org.bytedeco.javacpp.{ LongPointer, ShortPointer } -import org.bytedeco.javacpp.indexer.{Indexer, IntIndexer, LongIndexer} import org.bytedeco.pytorch -import org.bytedeco.pytorch.{LongOptional, Scalar, TensorIndexArrayRef} +import org.bytedeco.pytorch.TensorIndexArrayRef import org.bytedeco.pytorch.global.torch as torchNative import Tensor.* import org.bytedeco.pytorch.global.torch.ScalarType -import org.bytedeco.pytorch.NoGradGuard - -import java.nio.{ - Buffer, - ByteBuffer, - CharBuffer, - DoubleBuffer, - FloatBuffer, - IntBuffer, - LongBuffer, - ShortBuffer -} + +import java.nio.{Buffer, ByteBuffer, DoubleBuffer, FloatBuffer, IntBuffer, LongBuffer, ShortBuffer} import scala.collection.immutable.ArraySeq import scala.reflect.ClassTag -import scala.annotation.{targetName, unused} -import org.bytedeco.pytorch.global.torch.DeviceType import internal.NativeConverters.{toOptional, toScalar} import spire.math.{Complex, UByte} -import scala.reflect.Typeable import internal.NativeConverters import internal.NativeConverters.toArray import Device.CPU import Layout.Strided -import org.bytedeco.pytorch.ByteArrayRef -import org.bytedeco.pytorch.ShortArrayRef -import org.bytedeco.pytorch.BoolArrayRef -import org.bytedeco.pytorch.IntArrayRef -import org.bytedeco.pytorch.LongArrayRef -import org.bytedeco.pytorch.FloatArrayRef -import org.bytedeco.pytorch.DoubleArrayRef import org.bytedeco.pytorch.EllipsisIndexType import org.bytedeco.pytorch.SymInt import org.bytedeco.pytorch.SymIntOptional @@ -675,7 +654,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto */ def unsqueeze(dim: Int): Tensor[D] = fromNative(native.unsqueeze(dim)) - def zero(): Unit = native.zero_() + def zero_(): this.type = + native.zero_() + this private def nativeIndices[T <: Boolean | Long: ClassTag]( indices: (Slice | Int | Long | Tensor[Bool] | Tensor[UInt8] | Tensor[Int64] | Seq[T] | @@ -731,7 +712,9 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto def requiresGrad: Boolean = native.requires_grad() - def 
requiresGrad_=(requiresGrad: Boolean): Unit = native.requires_grad_(requiresGrad) + def requiresGrad_=(requiresGrad: Boolean): this.type = + native.requires_grad_(requiresGrad) + this def split( splitSize: Int | Seq[Int], @@ -807,7 +790,8 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto get: (Array[A], TypedBuffer[A]) => TypedBuffer[A] ): Array[A] = val a = new Array[A](numel.toInt) - if numel > 0 then get(a, tensor.native.contiguous.createBuffer[TypedBuffer[A]]) + if numel > 0 then + val _ = get(a, tensor.native.contiguous.createBuffer[TypedBuffer[A]]) a import ScalarType.* @@ -867,7 +851,6 @@ sealed abstract class Tensor[D <: DType]( /* private[torch] */ val native: pyto flattened: Boolean = false, includeInfo: Boolean = true ): String = - if dtype == int32 then max() def format(x: Any): String = x match case x: Float => "%1.4f".format(x) diff --git a/core/src/main/scala/torch/hub.scala b/core/src/main/scala/torch/hub.scala index 3c811f30..4f1a27c1 100644 --- a/core/src/main/scala/torch/hub.scala +++ b/core/src/main/scala/torch/hub.scala @@ -17,7 +17,6 @@ package torch import dev.dirs.BaseDirectories -import scala.io.Source import scala.util.Using import java.nio.file.Files import java.net.URL @@ -36,5 +35,6 @@ object hub: System.err.println(s"Downloading: $url to $cachedFile") Using.resource(URL(url).openStream()) { inputStream => Files.copy(inputStream, cachedFile.toNIO) + () } torch.pickleLoad(cachedFile.toNIO) diff --git a/core/src/main/scala/torch/internal/NativeConverters.scala b/core/src/main/scala/torch/internal/NativeConverters.scala index 18d21a9b..4c3d5edb 100644 --- a/core/src/main/scala/torch/internal/NativeConverters.scala +++ b/core/src/main/scala/torch/internal/NativeConverters.scala @@ -24,15 +24,12 @@ import org.bytedeco.pytorch.{ DeviceOptional, DoubleOptional, BoolOptional, - LongArrayRefOptional, LongOptional, TensorOptional } -import scala.reflect.Typeable import org.bytedeco.javacpp.{LongPointer, DoublePointer} import org.bytedeco.pytorch.GenericDict -import org.bytedeco.pytorch.GenericDictIterator import spire.math.Complex import spire.math.UByte import scala.annotation.targetName diff --git a/core/src/main/scala/torch/nn/functional/Activations.scala b/core/src/main/scala/torch/nn/functional/Activations.scala index 1c9ac05d..a3e19bf1 100644 --- a/core/src/main/scala/torch/nn/functional/Activations.scala +++ b/core/src/main/scala/torch/nn/functional/Activations.scala @@ -21,9 +21,8 @@ package functional import Derive.derive import org.bytedeco.pytorch import org.bytedeco.pytorch.global.torch as torchNative -import org.bytedeco.javacpp.LongPointer -import torch.internal.NativeConverters.{fromNative, toNative, toOptional} -import org.bytedeco.pytorch.{ScalarTypeOptional, TensorOptional} +import torch.internal.NativeConverters.fromNative +import org.bytedeco.pytorch.ScalarTypeOptional private[torch] trait Activations { diff --git a/core/src/main/scala/torch/nn/functional/Convolution.scala b/core/src/main/scala/torch/nn/functional/Convolution.scala index e744276a..15f7fad6 100644 --- a/core/src/main/scala/torch/nn/functional/Convolution.scala +++ b/core/src/main/scala/torch/nn/functional/Convolution.scala @@ -19,7 +19,6 @@ package nn package functional import org.bytedeco.pytorch -import org.bytedeco.pytorch.TensorOptional import org.bytedeco.pytorch.global.torch as torchNative import torch.internal.NativeConverters.* diff --git a/core/src/main/scala/torch/nn/functional/Linear.scala 
b/core/src/main/scala/torch/nn/functional/Linear.scala index 0d4a09af..6d657830 100644 --- a/core/src/main/scala/torch/nn/functional/Linear.scala +++ b/core/src/main/scala/torch/nn/functional/Linear.scala @@ -18,9 +18,7 @@ package torch package nn package functional -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch -import org.bytedeco.pytorch.TensorOptional import org.bytedeco.pytorch.global.torch as torchNative import torch.internal.NativeConverters.{fromNative, toOptional} diff --git a/core/src/main/scala/torch/nn/functional/Loss.scala b/core/src/main/scala/torch/nn/functional/Loss.scala index f2dbb913..b9322b12 100644 --- a/core/src/main/scala/torch/nn/functional/Loss.scala +++ b/core/src/main/scala/torch/nn/functional/Loss.scala @@ -18,11 +18,10 @@ package torch package nn package functional -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch -import org.bytedeco.pytorch.{BCEWithLogitsLossOptions, TensorOptional} +import org.bytedeco.pytorch.BCEWithLogitsLossOptions import org.bytedeco.pytorch.global.torch as torchNative -import torch.internal.NativeConverters.{fromNative, toOptional} +import torch.internal.NativeConverters.fromNative // Loss functions private[torch] trait Loss { diff --git a/core/src/main/scala/torch/nn/functional/Pooling.scala b/core/src/main/scala/torch/nn/functional/Pooling.scala index 78007b43..6504a8d4 100644 --- a/core/src/main/scala/torch/nn/functional/Pooling.scala +++ b/core/src/main/scala/torch/nn/functional/Pooling.scala @@ -18,7 +18,6 @@ package torch package nn package functional -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch import org.bytedeco.pytorch.{ AvgPool1dOptions, @@ -26,8 +25,7 @@ import org.bytedeco.pytorch.{ AvgPool3dOptions, MaxPool1dOptions, MaxPool2dOptions, - MaxPool3dOptions, - TensorOptional + MaxPool3dOptions } import org.bytedeco.pytorch.global.torch as torchNative import torch.internal.NativeConverters.* diff --git a/core/src/main/scala/torch/nn/functional/package.scala b/core/src/main/scala/torch/nn/functional/package.scala index 6d14a1a2..c01fdfe1 100644 --- a/core/src/main/scala/torch/nn/functional/package.scala +++ b/core/src/main/scala/torch/nn/functional/package.scala @@ -17,8 +17,6 @@ package torch package nn -import functional.* - /** @groupname nn_conv Convolution functions * @groupname nn_pooling Pooling functions * @groupname nn_attention Attention mechanisms diff --git a/core/src/main/scala/torch/nn/init.scala b/core/src/main/scala/torch/nn/init.scala index 6135790f..99f5656a 100644 --- a/core/src/main/scala/torch/nn/init.scala +++ b/core/src/main/scala/torch/nn/init.scala @@ -94,12 +94,13 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ab0aeccca28b2225ee9aab809ec38a801.html#exhale-function-namespacetorch-1-1nn-1-1init-1ab0aeccca28b2225ee9aab809ec38a801 */ - def uniform_( - t: Tensor[?], + def uniform_[D <: DType]( + t: Tensor[D], a: Double = 0, b: Double = 1 - ): Unit = + ): Tensor[D] = torchNative.uniform_(t.native, a, b) + t /** Fills the he given 2-dimensional input Tensor with values drawn from the normal distribution * $N(\text{mean},\text{std}^2)$. No gradient will be recorded for this operation. 
@@ -113,12 +114,13 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a105c2a8ef81c6faa82a01cf35ce9f3b1.html#exhale-function-namespacetorch-1-1nn-1-1init-1a105c2a8ef81c6faa82a01cf35ce9f3b1 */ - def normal_( - t: Tensor[?], + def normal_[D <: DType]( + t: Tensor[D], mean: Double = 0, std: Double = 0 - ): Unit = + ): t.type = torchNative.normal_(t.native, mean, std) + t // TODO valid for all scala types /** Fills the input Tensor with the value valval. @@ -131,8 +133,9 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a9c886724aac3a487553dc0a406565c83.html#exhale-function-namespacetorch-1-1nn-1-1init-1a9c886724aac3a487553dc0a406565c83 */ - def constant_(t: Tensor[?], fillValue: Double): Unit = - torchNative.constant_(t.native, Scalar(fillValue)): Unit + def constant_[D <: DType](t: Tensor[D], fillValue: Double): t.type = + torchNative.constant_(t.native, Scalar(fillValue)) + t /** Fills the input Tensor with the scalar value 1. No gradient will be recorded for this * operation. @@ -142,10 +145,11 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a9dcc2051aadbe8ddb37d58bbd2b7943a.html#exhale-function-namespacetorch-1-1nn-1-1init-1a9dcc2051aadbe8ddb37d58bbd2b7943a */ - def ones_( - t: Tensor[?] - ): Unit = + def ones_[D <: DType]( + t: Tensor[D] + ): t.type = torchNative.ones_(t.native) + t /** Fills the input Tensor with the scalar value 0. No gradient will be recorded for this * operation. @@ -155,10 +159,11 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1af7e7736ba2d050adc0523d84285564e8.html#exhale-function-namespacetorch-1-1nn-1-1init-1af7e7736ba2d050adc0523d84285564e8 */ - def zeros_( - t: Tensor[?] - ): Unit = + def zeros_[D <: DType]( + t: Tensor[D] + ): t.type = torchNative.zeros_(t.native) + t /** Fills the given 2-dimensional matrix with an identity matrix. No gradient will be recorded for * this operation. @@ -171,10 +176,11 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a77eb9bba76a93da5b33e7770f9113015.html#exhale-function-namespacetorch-1-1nn-1-1init-1a77eb9bba76a93da5b33e7770f9113015 */ - def eye_( - t: Tensor[?] - ): Unit = + def eye_[D <: DType]( + t: Tensor[D] + ): t.type = torchNative.eye_(t.native) + t // TODO: no groups available /** From libTorch Fills the given tensor with the Dirac delta function in-place, and returns it. @@ -193,10 +199,11 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ab9fa9ea51c05df8a5c9dcca7a54dd628.html#exhale-function-namespacetorch-1-1nn-1-1init-1ab9fa9ea51c05df8a5c9dcca7a54dd628 */ - def dirac_( - t: Tensor[?] - ): Unit = + def dirac_[D <: DType]( + t: Tensor[D] + ): t.type = torchNative.dirac_(t.native) + t /** Fills the input [[Tensor]] with values according to the method described in "Understanding the * difficulty of training deep feedforward neural networks"" - Glorot, X. & Bengio, Y. 
(2010), @@ -212,11 +219,12 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a86191a828a085e1c720dbce185d6c307.html#exhale-function-namespacetorch-1-1nn-1-1init-1a86191a828a085e1c720dbce185d6c307 */ - def xavierNormal_( - t: Tensor[?], + def xavierNormal_[D <: DType]( + t: Tensor[D], gain: Double = 1.0 - ): Unit = + ): t.type = torchNative.xavier_normal_(t.native, gain) + t /** Fills the input [[Tensor]] with values according to the method described in "Understanding the * difficulty of training deep feedforward neural networks"" - Glorot, X. & Bengio, Y. (2010), @@ -231,11 +239,12 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ace282f75916a862c9678343dfd4d5ffe.html#exhale-function-namespacetorch-1-1nn-1-1init-1ace282f75916a862c9678343dfd4d5ffe */ - def xavierUniform_( - t: Tensor[?], + def xavierUniform_[D <: DType]( + t: Tensor[D], gain: Double = 1.0 - ): Unit = + ): t.type = torchNative.xavier_uniform_(t.native, gain) + t /** Fills the input Tensor with values according to the method described in Delving deep into * rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. @@ -260,13 +269,14 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a5e807af188fc8542c487d50d81cb1aa1.html#exhale-function-namespacetorch-1-1nn-1-1init-1a5e807af188fc8542c487d50d81cb1aa1 */ - def kaimingUniform_( - t: Tensor[?], + def kaimingUniform_[D <: DType]( + t: Tensor[D], a: Double = 0, mode: Mode = Mode.FanIn, nonlinearity: NonLinearity = NonLinearity.LeakyReLU - ): Unit = + ): t.type = torchNative.kaiming_uniform_(t.native, a, mode.toNative, nonlinearity.toNative) + t /** Fills the input Tensor with values according to the method described in "Delving deep into * rectifiers: Surpassing human-level performance on ImageNet classification - He, K. et al. @@ -291,13 +301,14 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1ac8a913c051976a3f41f20df7d6126e57.html#exhale-function-namespacetorch-1-1nn-1-1init-1ac8a913c051976a3f41f20df7d6126e57 */ - def kaimingNormal_( - t: Tensor[?], + def kaimingNormal_[D <: DType]( + t: Tensor[D], a: Double = 0, mode: Mode = Mode.FanIn, nonlinearity: NonLinearity = NonLinearity.LeakyReLU - ): Unit = + ): t.type = torchNative.kaiming_normal_(t.native, a, mode.toNative, nonlinearity.toNative) + t // TODO: no trunc normal as per the PyTorch API. C++ docs not commented. Not part of init but of function // /** @@ -313,10 +324,11 @@ object init: // * @param b (float) – the maximum cutoff value // * @see https://pytorch.org/cppdocs/api/function_namespaceat_1aa604fcef7ea09fc379dc92c5d92a06ab.html // */ - def trunc_( - t: Tensor[?] - ): Unit = + def trunc_[D <: DType]( + t: Tensor[D] + ): t.type = torchNative.trunc_(t.native) + t /** Fills the input Tensor with a (semi) orthogonal matrix, as described in Exact solutions to the * nonlinear dynamics of learning in deep linear neural networks - Saxe, A. et al. (2013). 
The @@ -330,11 +342,12 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a5978fcc257460475f635b5960e892a8e.html#exhale-function-namespacetorch-1-1nn-1-1init-1a5978fcc257460475f635b5960e892a8e */ - def orthogonal_( - t: Tensor[?], + def orthogonal_[D <: DType]( + t: Tensor[D], gain: Double = 1.0 - ): Unit = + ): t.type = torchNative.orthogonal_(t.native, gain) + t /** Fills the 2D input Tensor as a sparse matrix, where the non-zero elements will be drawn from * the normal distribution $N(0,0.01)$, as described in "Deep learning via Hessian-free @@ -352,12 +365,13 @@ object init: * @see * https://pytorch.org/cppdocs/api/function_namespacetorch_1_1nn_1_1init_1a82f2e5810880c7cc60c84516eb283be6.html#exhale-function-namespacetorch-1-1nn-1-1init-1a82f2e5810880c7cc60c84516eb283be6 */ - def sparse_( - t: Tensor[?], + def sparse_[D <: DType]( + t: Tensor[D], sparsity: Double, std: Double = 0.01 - ): Unit = + ): t.type = torchNative.sparse_(t.native, sparsity, std) + t enum Mode: case FanIn, FanOut diff --git a/core/src/main/scala/torch/nn/modules/Module.scala b/core/src/main/scala/torch/nn/modules/Module.scala index c9b08d29..590cd127 100644 --- a/core/src/main/scala/torch/nn/modules/Module.scala +++ b/core/src/main/scala/torch/nn/modules/Module.scala @@ -18,15 +18,11 @@ package torch package nn package modules -import org.bytedeco.javacpp.CharPointer import org.bytedeco.pytorch -import org.bytedeco.pytorch.{Conv2dImpl, InputArchive, OutputArchive} +import org.bytedeco.pytorch.{InputArchive, OutputArchive} import Tensor.fromNative -import java.nio.CharBuffer import scala.collection.immutable.{ArraySeq, SeqMap, TreeSeqMap} -import scala.reflect.ClassTag -import scala.annotation.targetName abstract class Module { @@ -69,7 +65,7 @@ abstract class Module { def namedChildren: SeqMap[String, Module] = childModules def namedModules: SeqMap[String, Module] = - namedChildren.flatMap((name, module) => module.namedModules) + namedChildren.flatMap((_, module) => module.namedModules) def apply(fn: Module => Unit): this.type = for (_, module) <- namedModules diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala index 9650da93..39aea8c7 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm1d.scala @@ -19,10 +19,8 @@ package nn package modules package batchnorm -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch.{BatchNorm1dImpl, BatchNormOptions} import org.bytedeco.pytorch -import sourcecode.Name import torch.internal.NativeConverters.fromNative /** Applies Batch Normalization over a 2D or 3D input as described in the paper [Batch diff --git a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala index 1803ffc9..7b75c10d 100644 --- a/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala +++ b/core/src/main/scala/torch/nn/modules/batchnorm/BatchNorm2d.scala @@ -19,10 +19,8 @@ package nn package modules package batchnorm -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch.{BatchNorm2dImpl, BatchNormOptions} import org.bytedeco.pytorch -import sourcecode.Name import torch.internal.NativeConverters.fromNative /** Applies Batch Normalization over a 4D input as described in the paper [Batch Normalization: diff --git 
a/core/src/main/scala/torch/nn/modules/container/ModuleList.scala b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala index 3e76e15e..d1d338f4 100644 --- a/core/src/main/scala/torch/nn/modules/container/ModuleList.scala +++ b/core/src/main/scala/torch/nn/modules/container/ModuleList.scala @@ -20,7 +20,6 @@ package modules package container import sourcecode.Name -import scala.util.Random /** Holds submodules in a list. * @@ -109,7 +108,7 @@ final class ModuleList[D <: DType](override val modules: TensorModule[D]*) // for i, module in enumerate(modules): // self.add_module(str(offset + i), module) // return self - val offset = modules.length + // val offset = modules.length val all = modules ++ newModules // Not in Python newModules.zipWithIndex.foreach((module, index) => diff --git a/core/src/main/scala/torch/nn/modules/container/Sequential.scala b/core/src/main/scala/torch/nn/modules/container/Sequential.scala index b455f386..9e4bead4 100644 --- a/core/src/main/scala/torch/nn/modules/container/Sequential.scala +++ b/core/src/main/scala/torch/nn/modules/container/Sequential.scala @@ -20,7 +20,6 @@ package modules package container import sourcecode.Name -import scala.util.Random final class Sequential[D <: DType](override val modules: TensorModule[D]*) extends Module diff --git a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala index cb1feaff..369bf89a 100644 --- a/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala +++ b/core/src/main/scala/torch/nn/modules/conv/Conv2d.scala @@ -19,10 +19,8 @@ package nn package modules package conv -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch import org.bytedeco.pytorch.{Conv2dImpl, Conv2dOptions, kZeros, kReflect, kReplicate, kCircular} -import sourcecode.Name import torch.internal.NativeConverters.{fromNative, toNative} import torch.nn.modules.conv.Conv2d.PaddingMode diff --git a/core/src/main/scala/torch/nn/modules/linear/Identity.scala b/core/src/main/scala/torch/nn/modules/linear/Identity.scala index 7ebf9e84..5c24c272 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Identity.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Identity.scala @@ -22,12 +22,14 @@ package linear import org.bytedeco.pytorch import org.bytedeco.pytorch.IdentityImpl import torch.internal.NativeConverters.fromNative +import scala.annotation.nowarn /** A placeholder identity operator that is argument-insensitive. 
* * @group nn_linear */ -final class Identity[D <: DType: Default](args: Any*) extends TensorModule[D]: +final class Identity[D <: DType: Default](@nowarn("msg=unused explicit parameter") args: Any*) + extends TensorModule[D]: override val nativeModule: IdentityImpl = IdentityImpl() override def hasBias(): Boolean = false diff --git a/core/src/main/scala/torch/nn/modules/linear/Linear.scala b/core/src/main/scala/torch/nn/modules/linear/Linear.scala index ba7bbb7f..213bfb2f 100644 --- a/core/src/main/scala/torch/nn/modules/linear/Linear.scala +++ b/core/src/main/scala/torch/nn/modules/linear/Linear.scala @@ -21,7 +21,6 @@ package linear import org.bytedeco.pytorch import org.bytedeco.pytorch.{LinearImpl, LinearOptions} -import torch.nn.modules.{HasParams} import internal.NativeConverters.fromNative /** Applies a linear transformation to the incoming data: $y = xA^T + b$ @@ -65,10 +64,14 @@ final class Linear[ParamType <: FloatNN: Default]( override def hasBias(): Boolean = options.bias().get() def weight = fromNative[ParamType](nativeModule.weight()) - def weight_=(t: Tensor[ParamType]): Unit = nativeModule.weight(t.native) + def weight_=(t: Tensor[ParamType]): Tensor[ParamType] = + nativeModule.weight(t.native) + t def bias = fromNative[ParamType](nativeModule.bias()) - def bias_=(t: Tensor[ParamType]): Unit = nativeModule.bias(t.native) + def bias_=(t: Tensor[ParamType]): Tensor[ParamType] = + nativeModule.bias(t.native) + t def apply(input: Tensor[ParamType]): Tensor[ParamType] = fromNative( nativeModule.forward(input.native) diff --git a/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala index 71c6ed76..5e89ab2c 100644 --- a/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala +++ b/core/src/main/scala/torch/nn/modules/normalization/LayerNorm.scala @@ -23,7 +23,6 @@ package normalization import org.bytedeco.pytorch import org.bytedeco.pytorch.{LayerNormImpl, LayerNormOptions, LongVector} -import torch.nn.modules.TensorModule import internal.NativeConverters.fromNative // format: off diff --git a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala index 49859905..4e85ea24 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/AdaptiveAvgPool2d.scala @@ -20,10 +20,9 @@ package modules package pooling import org.bytedeco.pytorch.AdaptiveAvgPool2dImpl -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch -import torch.internal.NativeConverters.{fromNative, toNative, toOptional} +import torch.internal.NativeConverters.{fromNative, toOptional} import org.bytedeco.pytorch.LongOptionalVector import org.bytedeco.pytorch.LongOptional diff --git a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala index 6cfdbb02..dac4a65e 100644 --- a/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala +++ b/core/src/main/scala/torch/nn/modules/pooling/MaxPool2d.scala @@ -19,7 +19,6 @@ package nn package modules package pooling -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch import org.bytedeco.pytorch.{MaxPool2dImpl, MaxPool2dOptions} import torch.internal.NativeConverters.{fromNative, toNative} diff --git a/core/src/main/scala/torch/nn/modules/regularization/Dropout.scala 
b/core/src/main/scala/torch/nn/modules/regularization/Dropout.scala index 484fcb14..85b019fc 100644 --- a/core/src/main/scala/torch/nn/modules/regularization/Dropout.scala +++ b/core/src/main/scala/torch/nn/modules/regularization/Dropout.scala @@ -21,12 +21,9 @@ package nn package modules package regularization -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch -import sourcecode.Name import org.bytedeco.pytorch.DropoutImpl import org.bytedeco.pytorch.DropoutOptions -import torch.nn.modules.{HasParams, HasWeight, TensorModule} import torch.internal.NativeConverters.fromNative // format: off diff --git a/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala index 0750ce84..b05a3232 100644 --- a/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala +++ b/core/src/main/scala/torch/nn/modules/sparse/Embedding.scala @@ -19,13 +19,10 @@ package nn package modules package sparse -import org.bytedeco.javacpp.LongPointer import org.bytedeco.pytorch -import sourcecode.Name import org.bytedeco.pytorch.EmbeddingImpl import org.bytedeco.pytorch.EmbeddingOptions -import torch.nn.modules.{HasParams, HasWeight, TensorModule} -import torch.internal.NativeConverters.{fromNative, toNative, doubleToDoublePointer} +import torch.internal.NativeConverters.{fromNative, toNative} // format: off /** A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -85,7 +82,9 @@ final class Embedding[ParamType <: FloatNN | ComplexNN: Default]( override def hasBias(): Boolean = false def weight: Tensor[ParamType] = fromNative(nativeModule.weight) - def weight_=(w: Tensor[ParamType]): Unit = nativeModule.weight(w.native) + def weight_=(w: Tensor[ParamType]): Tensor[ParamType] = + nativeModule.weight(w.native) + w def apply(t: Tensor[Int64]): Tensor[ParamType] = fromNative(nativeModule.forward(t.native)) diff --git a/core/src/main/scala/torch/ops/OtherOps.scala b/core/src/main/scala/torch/ops/OtherOps.scala index 4172a348..cda891f5 100644 --- a/core/src/main/scala/torch/ops/OtherOps.scala +++ b/core/src/main/scala/torch/ops/OtherOps.scala @@ -20,8 +20,6 @@ package ops import internal.NativeConverters.* import org.bytedeco.javacpp.BytePointer import org.bytedeco.pytorch.global.torch as torchNative -import org.bytedeco.pytorch.LongArrayRef -import org.bytedeco.pytorch.ScalarTypeOptional /** Other Ops * diff --git a/core/src/main/scala/torch/ops/RandomSamplingOps.scala b/core/src/main/scala/torch/ops/RandomSamplingOps.scala index df6aa9d1..7d0c8f18 100644 --- a/core/src/main/scala/torch/ops/RandomSamplingOps.scala +++ b/core/src/main/scala/torch/ops/RandomSamplingOps.scala @@ -23,7 +23,6 @@ import internal.NativeConverters import NativeConverters.* import org.bytedeco.pytorch.global.torch as torchNative -import scala.util.Using /** Random Sampling * diff --git a/core/src/main/scala/torch/ops/ReductionOps.scala b/core/src/main/scala/torch/ops/ReductionOps.scala index 53f65d36..701305ed 100644 --- a/core/src/main/scala/torch/ops/ReductionOps.scala +++ b/core/src/main/scala/torch/ops/ReductionOps.scala @@ -20,7 +20,6 @@ package ops import internal.NativeConverters.* import org.bytedeco.pytorch.global.torch as torchNative -import org.bytedeco.pytorch.LongArrayRef import org.bytedeco.pytorch.ScalarTypeOptional /** Reduction Ops diff --git a/core/src/main/scala/torch/ops/package.scala b/core/src/main/scala/torch/ops/package.scala index 986931db..035c3526 100644 --- a/core/src/main/scala/torch/ops/package.scala +++ 
b/core/src/main/scala/torch/ops/package.scala @@ -17,9 +17,8 @@ package torch import internal.NativeConverters.{fromNative, tensorOptions} -import org.bytedeco.pytorch.global.torch as torchNative import org.bytedeco.pytorch -import org.bytedeco.pytorch.{MemoryFormatOptional, TensorArrayRef, TensorVector} +import org.bytedeco.pytorch.MemoryFormatOptional package object ops { diff --git a/core/src/main/scala/torch/optim/Adam.scala b/core/src/main/scala/torch/optim/Adam.scala index e07fa94b..de84daf4 100644 --- a/core/src/main/scala/torch/optim/Adam.scala +++ b/core/src/main/scala/torch/optim/Adam.scala @@ -14,11 +14,11 @@ * limitations under the License. */ -package torch.optim +package torch +package optim import org.bytedeco.pytorch -import org.bytedeco.pytorch.{AdamOptions, SGDOptions, TensorVector} -import torch.{DType, Tensor} +import org.bytedeco.pytorch.{AdamOptions, TensorVector} import scala.collection.immutable.Iterable diff --git a/core/src/main/scala/torch/optim/AdamW.scala b/core/src/main/scala/torch/optim/AdamW.scala index ee6b78fd..969f479e 100644 --- a/core/src/main/scala/torch/optim/AdamW.scala +++ b/core/src/main/scala/torch/optim/AdamW.scala @@ -14,11 +14,11 @@ * limitations under the License. */ -package torch.optim +package torch +package optim import org.bytedeco.pytorch -import org.bytedeco.pytorch.{AdamWOptions, SGDOptions, TensorVector} -import torch.{DType, Tensor} +import org.bytedeco.pytorch.{AdamWOptions, TensorVector} import scala.collection.immutable.Iterable diff --git a/core/src/main/scala/torch/optim/Optimizer.scala b/core/src/main/scala/torch/optim/Optimizer.scala index c89637d1..73e2da34 100644 --- a/core/src/main/scala/torch/optim/Optimizer.scala +++ b/core/src/main/scala/torch/optim/Optimizer.scala @@ -29,7 +29,10 @@ abstract class Optimizer { * Unless otherwise specified, this function should not modify the ``.grad`` field of the * parameters. */ - def step(): Unit = native.step() + def step(): Unit = + native.step() + // TODO check what tensor is returned by step + () /** Sets the gradients of all optimized `Tensor`s to zero. */ def zeroGrad(): Unit = native.zero_grad() diff --git a/core/src/main/scala/torch/optim/SGD.scala b/core/src/main/scala/torch/optim/SGD.scala index a928e082..a350418e 100644 --- a/core/src/main/scala/torch/optim/SGD.scala +++ b/core/src/main/scala/torch/optim/SGD.scala @@ -14,11 +14,11 @@ * limitations under the License. 
 */

-package torch.optim
+package torch
+package optim

 import org.bytedeco.pytorch
 import org.bytedeco.pytorch.{SGDOptions, TensorVector}
-import torch.{DType, Tensor}

 import scala.collection.immutable.Iterable

diff --git a/core/src/main/scala/torch/optim/lr_scheduler/LRScheduler.scala b/core/src/main/scala/torch/optim/lr_scheduler/LRScheduler.scala
index b76580b5..28c797ca 100644
--- a/core/src/main/scala/torch/optim/lr_scheduler/LRScheduler.scala
+++ b/core/src/main/scala/torch/optim/lr_scheduler/LRScheduler.scala
@@ -18,7 +18,5 @@ package torch
 package optim
 package lr_scheduler
 
-import org.bytedeco.pytorch
-
 trait LRScheduler:
   def step(): Unit
diff --git a/examples/src/main/scala/ImageClassifier.scala b/examples/src/main/scala/ImageClassifier.scala
index d68e37b7..5a789806 100644
--- a/examples/src/main/scala/ImageClassifier.scala
+++ b/examples/src/main/scala/ImageClassifier.scala
@@ -34,8 +34,8 @@ import ImageClassifier.{Prediction, predict, train}
 import caseapp.*
 import caseapp.core.argparser.{ArgParser, SimpleArgParser}
 import caseapp.core.app.CommandsEntryPoint
-import com.sksamuel.scrimage.{ImmutableImage, ScaleMethod}
-import me.tongfei.progressbar.{ProgressBar, ProgressBarBuilder}
+import com.sksamuel.scrimage.ImmutableImage
+import me.tongfei.progressbar.ProgressBarBuilder
 import org.bytedeco.javacpp.PointerScope
 import org.bytedeco.pytorch.{InputArchive, OutputArchive}
 import os.Path
@@ -46,9 +46,6 @@ import torchvision.models.resnet.{ResNet, ResNetVariant}
 import java.nio.file.Paths
 
 import scala.collection.parallel.CollectionConverters.ImmutableSeqIsParallelizable
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration.DurationInt
-import scala.concurrent.{Await, Future}
 import scala.util.{Random, Try, Using}
 
 // format: off
@@ -133,7 +130,7 @@ object ImageClassifier extends CommandsEntryPoint:
       // Don't load the classification head weights, as they are specific to the ImageNet classes
      // and their output size (1000) usually won't match the number of classes of our dataset.
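      // Filtering the head out below restores only the pretrained backbone weights;
      // the `fc` layer keeps its fresh random initialization, sized for our own
      // classes, and is then trained from scratch on the new dataset.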
model.loadStateDict( - weights.filterNot((k, v) => Set("fc.weight", "fc.bias").contains(k)) + weights.filterNot((k, _) => Set("fc.weight", "fc.bias").contains(k)) ) model.to(device) diff --git a/examples/src/main/scala/gpt/V2.scala b/examples/src/main/scala/gpt/V2.scala index fdcbb22a..f628ef00 100644 --- a/examples/src/main/scala/gpt/V2.scala +++ b/examples/src/main/scala/gpt/V2.scala @@ -30,17 +30,12 @@ from torch.nn import functional as F // cSpell: ignore xbow, xprev, isinstance, idx, tok_emb // cSpell: ignore Andrej, Karpathy -import java.nio.file.Paths import java.nio.file.Files -import java.net.URL import java.net.URI -import scala.annotation.targetName -import scala.util.Random import scala.util.Using import scala.collection.immutable.SortedSet -import org.bytedeco.pytorch.OutputArchive import org.bytedeco.javacpp.PointerScope import torch.* @@ -48,25 +43,12 @@ import torch.Device.CUDA import torch.Device.CPU import torch.nn.functional as F import torch.nn.modules.Module -import torch.nn.modules.HasParams -import torch.nn.modules.HasWeight -import torch.{---, Slice} -import torch.optim.Adam -import torch.DType.float32 +import torch.Slice import org.bytedeco.javacpp.Pointer -import org.bytedeco.pytorch.cuda.Stat -import org.bytedeco.pytorch.cuda.CheckpointDelta -import org.bytedeco.pytorch.cuda.SnapshotInfo -import org.bytedeco.pytorch.cuda.CUDAAllocator -import org.bytedeco.pytorch.cuda.SegmentInfo -import org.bytedeco.pytorch.cuda.BlockInfo -import org.bytedeco.pytorch.cuda.DeviceStats -import org.bytedeco.javacpp.BoolPointer -import org.bytedeco.pytorch.global.torch_cuda - -import gpt.Utils + import gpt.Utils.Modules as UtilsM import gpt.Utils.CUDAMemory as Mem +import scala.annotation.nowarn /** This is code translated from Andrej Karpathy's * [[video https://www.youtube.com/watch?v=kCc8FmEb1nY]] titled "Let's build GPT: from scratch, in @@ -512,6 +494,8 @@ object V2: nEmbed: Int, nHead: Int, blockSize: Int, + // TODO check if we need this + @nowarn("msg=unused explicit parameter") vocabSize: Int, dropout: Double ) extends torch.nn.modules.TensorModule[D]: @@ -636,6 +620,7 @@ object V2: elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) */ + @nowarn("msg=unused private member") private def init_weights(m: Module): Unit = m match case lm: nn.Linear[?] 
=> diff --git a/vision/src/main/scala/torchvision/datasets/MNIST.scala b/vision/src/main/scala/torchvision/datasets/MNIST.scala index 9304474a..49e93bd4 100644 --- a/vision/src/main/scala/torchvision/datasets/MNIST.scala +++ b/vision/src/main/scala/torchvision/datasets/MNIST.scala @@ -41,7 +41,7 @@ trait MNISTBase( private def downloadAndExtractArchive(url: URL, target: Path): Unit = println(s"downloading from $url") Using.resource(url.openStream()) { inputStream => - Files.copy(GZIPInputStream(inputStream), target) + val _ = Files.copy(GZIPInputStream(inputStream), target) } if download then { @@ -50,7 +50,7 @@ trait MNISTBase( val finalPath = root.resolve(filename.stripSuffix(".gz")) if !Files.exists(finalPath) then println(s"$finalPath not found") - mirrors.iterator + val _ = mirrors.iterator .map { mirror => Try(downloadAndExtractArchive(URL(s"$mirror$filename"), finalPath)) } diff --git a/vision/src/main/scala/torchvision/models/resnet.scala b/vision/src/main/scala/torchvision/models/resnet.scala index 3f4ca09f..0ce65411 100644 --- a/vision/src/main/scala/torchvision/models/resnet.scala +++ b/vision/src/main/scala/torchvision/models/resnet.scala @@ -17,21 +17,9 @@ package torchvision package models -import torch.{ - BFloat16, - ComplexNN, - DType, - Default, - Float32, - Float32Tensor, - Float64, - FloatNN, - Tensor, - nn -} +import torch.{BFloat16, DType, Default, Float32, Float64, FloatNN, Tensor, nn} import torch.nn.init.{Mode, NonLinearity, constant_, kaimingNormal_} -import scala.collection.mutable import torch.nn.modules.batchnorm.BatchNorm2d import torch.nn.modules.container.Sequential import torch.nn.modules.linear.Linear @@ -42,9 +30,6 @@ import torch.nn.modules.pooling.{AdaptiveAvgPool2d, MaxPool2d} import torch.nn.modules.{HasWeight, Module} import torchvision.transforms.* -import scala.util.Using -import com.sksamuel.scrimage.ImmutableImage -import torch.Int32 import torch.nn.modules.TensorModule /** ResNet architecture implementations @@ -110,7 +95,6 @@ object resnet: dilation: Int = 1, normLayer: => (Int => TensorModule[D]) ) extends TensorModule[D] { - import BasicBlock.expansion if groups != 1 || baseWidth != 64 then throw new IllegalArgumentException("BasicBlock only supports groups=1 and baseWidth=64") diff --git a/vision/src/main/scala/torchvision/transforms/presets.scala b/vision/src/main/scala/torchvision/transforms/presets.scala index 93c86ede..3cb95499 100644 --- a/vision/src/main/scala/torchvision/transforms/presets.scala +++ b/vision/src/main/scala/torchvision/transforms/presets.scala @@ -19,7 +19,6 @@ package transforms import com.sksamuel.scrimage.ImmutableImage import com.sksamuel.scrimage.ScaleMethod -import torch.Int32 import torch.Tensor import torch.Float32 import torchvision.transforms.functional.toTensor @@ -39,13 +38,13 @@ object Presets: image.scaleTo( (resizeSize * (image.width / image.height.toDouble)).toInt, resizeSize, - ScaleMethod.Bilinear + interpolation ) else image.scaleTo( resizeSize, (resizeSize * (image.height / image.width.toDouble)).toInt, - ScaleMethod.Bilinear + interpolation ) val croppedImage = scaledImage.resizeTo(cropSize, cropSize) toTensor(croppedImage) @@ -53,6 +52,6 @@ object Presets: def batchTransforms(input: Tensor[Float32]): Tensor[Float32] = torchvision.transforms.functional.normalize( input, - mean = Seq(0.485f, 0.456f, 0.406f), - std = Seq(0.229f, 0.224f, 0.225f) + mean = mean, + std = std ) From 0be74e7b99f7f73207353dae3945259e20576a66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=B6ren=20Brunk?= Date: Mon, 5 Feb 2024 
17:28:34 +0100
Subject: [PATCH 3/3] Fix tests

---
 core/src/main/scala/torch/hub.scala              | 3 +--
 core/src/test/scala/TrainingSuite.scala          | 8 ++++----
 core/src/test/scala/torch/DeviceSuite.scala      | 8 +-------
 core/src/test/scala/torch/TensorCheckSuite.scala | 2 +-
 core/src/test/scala/torch/TensorSuite.scala      | 3 ---
 5 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/core/src/main/scala/torch/hub.scala b/core/src/main/scala/torch/hub.scala
index 4f1a27c1..f1a49df4 100644
--- a/core/src/main/scala/torch/hub.scala
+++ b/core/src/main/scala/torch/hub.scala
@@ -34,7 +34,6 @@ object hub:
     if !os.exists(cachedFile) then
       System.err.println(s"Downloading: $url to $cachedFile")
       Using.resource(URL(url).openStream()) { inputStream =>
-        Files.copy(inputStream, cachedFile.toNIO)
-        ()
+        val _ = Files.copy(inputStream, cachedFile.toNIO)
       }
     torch.pickleLoad(cachedFile.toNIO)
diff --git a/core/src/test/scala/TrainingSuite.scala b/core/src/test/scala/TrainingSuite.scala
index 808669a0..796f441d 100644
--- a/core/src/test/scala/TrainingSuite.scala
+++ b/core/src/test/scala/TrainingSuite.scala
@@ -28,8 +28,8 @@ class TrainingSuite extends munit.FunSuite {
 
     torch.manualSeed(1)
 
-    var weight = torch.randn(Seq(1), requiresGrad = true)
-    var bias = torch.zeros(Seq(1), requiresGrad = true)
+    val weight = torch.randn(Seq(1), requiresGrad = true)
+    val bias = torch.zeros(Seq(1), requiresGrad = true)
 
     def model(xb: Tensor[Float32]): Tensor[Float32] = (xb matmul weight) + bias
 
@@ -57,11 +57,11 @@
     noGrad {
       weight.grad.foreach { grad =>
         weight -= grad * learningRate
-        grad.zero()
+        grad.zero_()
       }
       bias.grad.foreach { grad =>
         bias -= grad * learningRate
-        grad.zero()
+        grad.zero_()
      }
     }
     loss
diff --git a/core/src/test/scala/torch/DeviceSuite.scala b/core/src/test/scala/torch/DeviceSuite.scala
index 2642a875..81866c1c 100644
--- a/core/src/test/scala/torch/DeviceSuite.scala
+++ b/core/src/test/scala/torch/DeviceSuite.scala
@@ -17,15 +17,9 @@ package torch
 
 import munit.ScalaCheckSuite
-import torch.DeviceType.CUDA
 import org.scalacheck.Prop.*
-import org.bytedeco.pytorch.global.torch as torch_native
-import org.scalacheck.{Arbitrary, Gen}
 import org.scalacheck._
-import Gen._
-import Arbitrary.arbitrary
-import DeviceType.CPU
-import Generators.{*, given}
+import Generators.given
 
 class DeviceSuite extends ScalaCheckSuite {
   test("device native roundtrip") {
diff --git a/core/src/test/scala/torch/TensorCheckSuite.scala b/core/src/test/scala/torch/TensorCheckSuite.scala
index da620258..81c86902 100644
--- a/core/src/test/scala/torch/TensorCheckSuite.scala
+++ b/core/src/test/scala/torch/TensorCheckSuite.scala
@@ -19,7 +19,7 @@ package torch
 import munit.ScalaCheckSuite
 import shapeless3.typeable.{TypeCase, Typeable}
 import shapeless3.typeable.syntax.typeable.*
-import Generators.{*, given}
+import Generators.*
 import org.scalacheck.Prop.*
 
 import scala.util.Try
diff --git a/core/src/test/scala/torch/TensorSuite.scala b/core/src/test/scala/torch/TensorSuite.scala
index a6c3d746..a0141b33 100644
--- a/core/src/test/scala/torch/TensorSuite.scala
+++ b/core/src/test/scala/torch/TensorSuite.scala
@@ -16,9 +16,6 @@
 
 package torch
 
-import org.scalacheck.Prop.*
-import Generators.given
-
 class TensorSuite extends TensorCheckSuite {

   test("tensor properties") {