From c30b4da1719d1425cbf6992654757f8293459327 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Sep 2024 12:47:14 -0400 Subject: [PATCH] fix: try disabling force_preserve --- ext/BoltzMetalheadExt.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ext/BoltzMetalheadExt.jl b/ext/BoltzMetalheadExt.jl index f98fa41..05f04c0 100644 --- a/ext/BoltzMetalheadExt.jl +++ b/ext/BoltzMetalheadExt.jl @@ -11,7 +11,7 @@ Utils.is_extension_loaded(::Val{:Metalhead}) = true function Vision.ResNetMetalhead(depth::Int; pretrained::Bool=false) @argcheck depth in (18, 34, 50, 101, 152) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNet( + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ResNet( depth; pretrain=pretrained).layers) end @@ -19,24 +19,24 @@ function Vision.ResNeXtMetalhead( depth::Int; cardinality=32, base_width=nothing, pretrained::Bool=false) @argcheck depth in (50, 101, 152) base_width = base_width === nothing ? (depth == 101 ? 8 : 4) : base_width - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ResNeXt( + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ResNeXt( depth; pretrain=pretrained, cardinality, base_width).layers) end function Vision.GoogLeNetMetalhead(; pretrained::Bool=false) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.GoogLeNet(; + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.GoogLeNet(; pretrain=pretrained).layers) end function Vision.DenseNetMetalhead(depth::Int; pretrained::Bool=false) @argcheck depth in (121, 161, 169, 201) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.DenseNet( + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.DenseNet( depth; pretrain=pretrained).layers) end function Vision.MobileNetMetalhead(name::Symbol; pretrained::Bool=false) @argcheck name in (:v1, :v2, :v3_small, :v3_large) - adaptor = FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true) + adaptor = FromFluxAdaptor(; preserve_ps_st=pretrained) model = if name == :v1 adaptor(Metalhead.MobileNetv1(; pretrain=pretrained).layers) elseif name == :v2 @@ -51,18 +51,18 @@ end function Vision.ConvMixerMetalhead(name::Symbol; pretrained::Bool=false) @argcheck name in (:base, :large, :small) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.ConvMixer( + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.ConvMixer( name; pretrain=pretrained).layers) end function Vision.SqueezeNetMetalhead(; pretrained::Bool=false) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.SqueezeNet(; + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.SqueezeNet(; pretrain=pretrained).layers) end function Vision.WideResNetMetalhead(depth::Int; pretrained::Bool=false) @argcheck depth in (18, 34, 50, 101, 152) - return FromFluxAdaptor(; preserve_ps_st=pretrained, force_preserve=true)(Metalhead.WideResNet( + return FromFluxAdaptor(; preserve_ps_st=pretrained)(Metalhead.WideResNet( depth; pretrain=pretrained).layers) end