Skip to content

Commit

Permalink
fix: change default cardinality/width to allow pretrained weights
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 13, 2024
1 parent d17892a commit e89017c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
6 changes: 4 additions & 2 deletions ext/BoltzMetalheadExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ function Vision.ResNetMetalhead(depth::Int; pretrained::Bool=false)
depth; pretrain=pretrained).layers)
end

function Vision.ResNeXtMetalhead(depth::Int; pretrained::Bool=false)
function Vision.ResNeXtMetalhead(
depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false)
@argcheck depth in (50, 101, 152)
base_width = base_width === nothing ? (depth == 101 ? 8 : 4) : base_width
return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNeXt(
depth; pretrain=pretrained).layers)
depth; pretrain=pretrained, cardinality, base_width).layers)
end

function Vision.GoogLeNetMetalhead(; pretrained::Bool=false)
Expand Down
9 changes: 6 additions & 3 deletions src/vision/extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Create a ResNet model [he2016deep](@citep).
function ResNet end

"""
ResNeXt(depth::Int; pretrained::Bool=false)
ResNeXt(depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false)
Create a ResNeXt model [xie2017aggregated](@citep).
Expand All @@ -27,6 +27,9 @@ Create a ResNeXt model [xie2017aggregated](@citep).
- `pretrained::Bool=false`: If `true`, loads pretrained weights when `LuxCore.setup` is
called.
- `cardinality`: The cardinality of the ResNeXt model. Defaults to 32.
- `base_width`: The base width of the ResNeXt model. Defaults to 8 for depth 101 and 4
otherwise.
"""
function ResNeXt end

Expand Down Expand Up @@ -132,12 +135,12 @@ for f in [:ResNet, :ResNeXt, :GoogLeNet, :DenseNet,
f_metalhead = Symbol(f, :Metalhead)
@eval begin
function $(f_metalhead) end
function $(f)(args...; pretrained::Bool=false)
function $(f)(args...; pretrained::Bool=false, kwargs...)
if !is_extension_loaded(Val(:Metalhead))
error("`Metalhead.jl` is not loaded. Please load `Metalhead.jl` to use \
this function.")
end
model = $(f_metalhead)(args...; pretrained)
model = $(f_metalhead)(args...; pretrained, kwargs...)
return MetalheadWrapperLayer(model, :metalhead, false)
end
end
Expand Down

0 comments on commit e89017c

Please sign in to comment.