Skip to content

Commit

Permalink
test: check for imagenet accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 14, 2024
1 parent c30b4da commit 67ad6f7
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
Expand All @@ -22,19 +24,22 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Aqua = "0.8.7"
Bumper = "0.6, 0.7"
ComponentArrays = "0.15.16"
DataInterpolations = "< 5.3"
Downloads = "1.6"
DynamicExpressions = "0.16, 0.17, 0.18, 0.19"
Enzyme = "0.12"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
GPUArraysCore = "0.1.6"
Hwloc = "3.2.0"
Images = "0.26"
InteractiveUtils = "<0.0.1, 1"
Lux = "1"
LuxLib = "1"
Expand All @@ -48,4 +53,5 @@ ReTestItems = "1.24.0"
Reexport = "1.2.2"
StableRNGs = "1.0.2"
Test = "1.10"
TestImages = "1.8"
Zygote = "0.6.70"
70 changes: 64 additions & 6 deletions test/vision_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,38 @@
@testitem "AlexNet" setup=[SharedTestSetup] tags=[:vision] begin
@testsetup module PretrainedWeightsTestSetup

using Images, TestImages
using Downloads

function normalize_imagenet(data)
cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
return (data .- cmean) ./ cstd
end

const TEST_IMG = imresize(testimage("monarch_color_256"), (224, 224))

const TEST_X = let img_array = convert(Array{Float32}, channelview(TEST_IMG))
permutedims(img_array, (3, 2, 1)) |> normalize_imagenet
end

const TEST_LBLS = readlines(Downloads.download(
"https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
))

function imagenet_acctest(model, ps, st, dev)
ps = ps |> dev
st = Lux.testmode(st) |> dev
x = TEST_X |> dev
ypred = first(model(x, ps, st)) |> collect |> vec
top5 = TEST_LBLS[sortperm(ypred; rev=true)]
return "monarch" in top5
end

export imagenet_acctest

end

@testitem "AlexNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
@testset "pretrained: $(pretrained)" for pretrained in [true, false]
model = Vision.AlexNet(; pretrained)
Expand All @@ -9,6 +43,10 @@
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
Expand Down Expand Up @@ -56,7 +94,7 @@ end
end
end

@testitem "ResNet" setup=[SharedTestSetup] tags=[:vision] begin
@testitem "ResNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152]
@testset for pretrained in [false, true]
model = Vision.ResNet(depth; pretrained)
Expand All @@ -67,12 +105,16 @@ end
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
end

@testitem "ResNeXt" setup=[SharedTestSetup] tags=[:vision] begin
@testitem "ResNeXt" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
@testset for (depth, cardinality, base_width) in [
(50, 32, 4), (101, 32, 8), (101, 64, 4), (152, 64, 4)]
Expand All @@ -87,13 +129,17 @@ end
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
end
end

@testitem "WideResNet" setup=[SharedTestSetup] tags=[:vision] begin
@testitem "WideResNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152]
@testset for pretrained in [false, true]
depth == 152 && pretrained && continue
Expand All @@ -106,12 +152,16 @@ end
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
end

@testitem "SqueezeNet" setup=[SharedTestSetup] tags=[:vision] begin
@testitem "SqueezeNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES
@testset for pretrained in [false, true]
model = Vision.SqueezeNet(; pretrained)
Expand All @@ -122,12 +172,16 @@ end
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
end

@testitem "VGG" setup=[SharedTestSetup] tags=[:vision] begin
@testitem "VGG" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin
for (mode, aType, dev, ongpu) in MODES, depth in [11, 13, 16, 19]
@testset for pretrained in [false, true], batchnorm in [false, true]
model = Vision.VGG(depth; batchnorm, pretrained)
Expand All @@ -138,6 +192,10 @@ end
@jet model(img, ps, st)
@test size(first(model(img, ps, st))) == (1000, 2)

if pretrained
@test imagenet_acctest(model, ps, st, dev)
end

GC.gc(true)
end
end
Expand Down

0 comments on commit 67ad6f7

Please sign in to comment.