From 67ad6f7cb7d5ae3be051ae264e1bd6a5e129f493 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 13 Sep 2024 20:26:06 -0400 Subject: [PATCH] test: check for imagenet accuracy --- test/Project.toml | 6 ++++ test/vision_tests.jl | 70 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index a724ae5..644730e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,12 +3,14 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" +Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d" +Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" @@ -22,6 +24,7 @@ ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -29,12 +32,14 @@ Aqua = "0.8.7" Bumper = "0.6, 0.7" ComponentArrays = "0.15.16" DataInterpolations = "< 5.3" +Downloads = "1.6" DynamicExpressions = "0.16, 0.17, 0.18, 0.19" Enzyme = "0.12" ExplicitImports = "1.9.0" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" Hwloc = "3.2.0" +Images = "0.26" InteractiveUtils = "<0.0.1, 1" Lux = "1" LuxLib = "1" @@ -48,4 +53,5 @@ ReTestItems = "1.24.0" Reexport = "1.2.2" StableRNGs = "1.0.2" Test = "1.10" +TestImages = "1.8" Zygote = "0.6.70" diff --git a/test/vision_tests.jl b/test/vision_tests.jl index fec2b5c..d02fcb2 100644 --- a/test/vision_tests.jl +++ b/test/vision_tests.jl @@ -1,4 +1,38 @@ -@testitem "AlexNet" setup=[SharedTestSetup] tags=[:vision] begin +@testsetup module PretrainedWeightsTestSetup + +using Images, TestImages +using Downloads + +function normalize_imagenet(data) + cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1)) + cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1)) + return (data .- cmean) ./ cstd +end + +const TEST_IMG = imresize(testimage("monarch_color_256"), (224, 224)) + +const TEST_X = let img_array = convert(Array{Float32}, channelview(TEST_IMG)) + permutedims(img_array, (3, 2, 1)) |> normalize_imagenet +end + +const TEST_LBLS = readlines(Downloads.download( + "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" +)) + +function imagenet_acctest(model, ps, st, dev) + ps = ps |> dev + st = Lux.testmode(st) |> dev + x = TEST_X |> dev + ypred = first(model(x, ps, st)) |> collect |> vec + top5 = TEST_LBLS[sortperm(ypred; rev=true)] + return "monarch" in top5 +end + +export imagenet_acctest + +end + +@testitem "AlexNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES @testset "pretrained: $(pretrained)" for pretrained in [true, false] model = Vision.AlexNet(; pretrained) @@ -9,6 +43,10 @@ @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end @@ -56,7 +94,7 @@ end end end -@testitem "ResNet" setup=[SharedTestSetup] tags=[:vision] begin +@testitem "ResNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, depth in [18, 34, 50, 101, 152] @testset for pretrained in [false, true] model = Vision.ResNet(depth; pretrained) @@ -67,12 +105,16 @@ end @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end end -@testitem "ResNeXt" setup=[SharedTestSetup] tags=[:vision] begin +@testitem "ResNeXt" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES @testset for (depth, cardinality, base_width) in [ (50, 32, 4), (101, 32, 8), (101, 64, 4), (152, 64, 4)] @@ -87,13 +129,17 @@ end @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end end end -@testitem "WideResNet" setup=[SharedTestSetup] tags=[:vision] begin +@testitem "WideResNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, depth in [50, 101, 152] @testset for pretrained in [false, true] depth == 152 && pretrained && continue @@ -106,12 +152,16 @@ end @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end end -@testitem "SqueezeNet" setup=[SharedTestSetup] tags=[:vision] begin +@testitem "SqueezeNet" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES @testset for pretrained in [false, true] model = Vision.SqueezeNet(; pretrained) @@ -122,12 +172,16 @@ end @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end end -@testitem "VGG" setup=[SharedTestSetup] tags=[:vision] begin +@testitem "VGG" setup=[SharedTestSetup, PretrainedWeightsTestSetup] tags=[:vision] begin for (mode, aType, dev, ongpu) in MODES, depth in [11, 13, 16, 19] @testset for pretrained in [false, true], batchnorm in [false, true] model = Vision.VGG(depth; batchnorm, pretrained) @@ -138,6 +192,10 @@ end @jet model(img, ps, st) @test size(first(model(img, ps, st))) == (1000, 2) + if pretrained + @test imagenet_acctest(model, ps, st, dev) + end + GC.gc(true) end end