Skip to content

Commit

Permalink
fix: load JLD2 for pretrained models
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 27, 2024
1 parent bfb2f9d commit 84ee1af
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
4 changes: 4 additions & 0 deletions docs/src/api/vision.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ Vision.ResNeXt

## Pretrained Models

!!! note "Load JLD2"

You need to load `JLD2` before being able to load pretrained weights.

!!! tip "Load Pretrained Weights"

Pass `pretrained=true` to the model constructor to load the pretrained weights.
Expand Down
2 changes: 2 additions & 0 deletions examples/GettingStarted/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Expand All @@ -9,6 +10,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
[compat]
Boltz = "0.4"
InteractiveUtils = "<0.0.1, 1"
JLD2 = "0.4.52"
Literate = "2.19"
Lux = "0.5.65"
Metalhead = "0.9.3"
Expand Down
6 changes: 6 additions & 0 deletions examples/GettingStarted/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ model

# We can also load pretrained ImageNet weights using

# !!! note "Load JLD2"
#
# You need to load `JLD2` before being able to load pretrained weights.

using JLD2

model, _, _ = Vision.VGG(13; pretrained=true)
model

Expand Down
6 changes: 4 additions & 2 deletions src/vision/vgg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ Create a VGG model [simonyan2014very](@citep).
function VGG(imsize; config, inchannels, batchnorm=false, nclasses, fcsize, dropout)
feature_extractor = vgg_convolutional_layers(config, batchnorm, inchannels)
# TODO: Use Lux.outputsize once it is ready
@show (imsize..., inchannels, 2)
outsize = Lux.compute_output_size(
feature_extractor, (imsize..., inchannels, 1), Random.default_rng())
classifier = vgg_classifier_layers(outsize[1:((end - 1))], nclasses, fcsize, dropout)
feature_extractor, (imsize..., inchannels, 2), Random.default_rng())
@show outsize
classifier = vgg_classifier_layers(outsize, nclasses, fcsize, dropout)
return Lux.Chain(feature_extractor, classifier)
end

Expand Down

0 comments on commit 84ee1af

Please sign in to comment.