From 2c09692d02295d7e3013313483a6f999ce6f60bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 27 Jul 2024 16:46:21 -0700 Subject: [PATCH] feat: easy mechanism to set preferences --- LocalPreferences.toml | 2 -- Project.toml | 4 ++- docs/src/api/Lux/utilities.md | 6 ++++ docs/src/manual/performance_pitfalls.md | 7 +++++ docs/src/manual/preferences.md | 11 +++++++ src/Lux.jl | 3 +- src/preferences.jl | 39 +++++++++++++++++++++++++ test/runtests.jl | 23 +++++++++++++++ test/shared_testsetup.jl | 2 ++ 9 files changed, 93 insertions(+), 4 deletions(-) delete mode 100644 LocalPreferences.toml diff --git a/LocalPreferences.toml b/LocalPreferences.toml deleted file mode 100644 index bfc941cb4..000000000 --- a/LocalPreferences.toml +++ /dev/null @@ -1,2 +0,0 @@ -[LuxTestUtils] -target_modules = ["Lux", "LuxCore", "LuxLib"] diff --git a/Project.toml b/Project.toml index 52f436bbb..80f79722c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -75,6 +76,7 @@ Compat = "4.15" ComponentArrays = "0.15.11" ConcreteStructs = "0.2.3" ConstructionBase = "1.5" +DispatchDoctor = "0.4.12" Documenter = "1.4" DynamicExpressions = "0.16, 0.17, 0.18" Enzyme = "0.12.24" @@ -95,7 +97,7 @@ LossFunctions = "0.11.1" LuxCore = "0.1.16" LuxDeviceUtils = "0.1.26" LuxLib = "0.3.33" -LuxTestUtils = "0.1.15" +LuxTestUtils = "0.1.18" MLUtils = "0.4.3" MPI = "0.20.19" MacroTools = "0.5.13" diff --git a/docs/src/api/Lux/utilities.md b/docs/src/api/Lux/utilities.md index 53faad239..0369fc5b8 100644 --- a/docs/src/api/Lux/utilities.md +++ b/docs/src/api/Lux/utilities.md @@ -114,6 +114,12 @@ StatefulLuxLayer @non_trainable ``` +## Preferences + +```@docs +Lux.set_dispatch_doctor_preferences! +``` + ## Truncated Stacktraces (Deprecated) ```@docs diff --git a/docs/src/manual/performance_pitfalls.md b/docs/src/manual/performance_pitfalls.md index 55ae1da59..24a17dc14 100644 --- a/docs/src/manual/performance_pitfalls.md +++ b/docs/src/manual/performance_pitfalls.md @@ -61,3 +61,10 @@ using: using GPUArraysCore GPUArraysCore.allowscalar(false) ``` + +## Type Instabilities + +`Lux.jl` is integrated with `DispatchDoctor.jl` to catch type instabilities. You can easily +enable it by setting the `instability_check` preference. This will help you catch type +instabilities in your code. For more information on how to set preferences, check out +[`set_dispatch_doctor_preferences`](@ref). diff --git a/docs/src/manual/preferences.md b/docs/src/manual/preferences.md index 0496fdcb6..88117b2ad 100644 --- a/docs/src/manual/preferences.md +++ b/docs/src/manual/preferences.md @@ -46,3 +46,14 @@ By default, both of these preferences are set to `false`. 1. `eltype_mismatch_handling` - Preference controlling what happens when layers get different eltypes as input. See the documentation on [`match_eltype`](@ref) for more details. + +## [Dispatch Doctor](@id dispatch-doctor-preference) + +1. `instability_check` - Preference controlling the dispatch doctor. See the documentation + on [`set_dispatch_doctor_preferences!`](@ref) for more details. The preferences need to + be set for `LuxCore` and `LuxLib` packages. Both of them default to `disable`. + - Setting the `LuxCore` preference sets the check at the level of `LuxCore.apply`. This + essentially activates the dispatch doctor for all Lux layers. + - Setting the `LuxLib` preference sets the check at the level of functional layer of + Lux, for example, [`fused_dense_bias_activation`](@ref). These functions are supposed + to be type stable for common input types and can be used to guarantee type stability. diff --git a/src/Lux.jl b/src/Lux.jl index 73cddcbc9..fdd32cb7b 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -20,7 +20,7 @@ using MacroTools: MacroTools, block, combinedef, splitdef using Markdown: @doc_str using NNlib: NNlib using Optimisers: Optimisers -using Preferences: load_preference, has_preference +using Preferences: load_preference, has_preference, set_preferences! using Random: Random, AbstractRNG using Reexport: @reexport using Statistics: mean @@ -133,6 +133,7 @@ export MPIBackend, NCCLBackend, DistributedUtils # Unexported functions that are part of the public API @compat public Experimental @compat public xlogx, xlogy +@compat public set_dispatch_doctor_preferences! @compat(public, (recursive_add!!, recursive_copyto!, recursive_eltype, recursive_make_zero, recursive_map, recursive_make_zero!!)) diff --git a/src/preferences.jl b/src/preferences.jl index 513d2de65..6356f009a 100644 --- a/src/preferences.jl +++ b/src/preferences.jl @@ -38,3 +38,42 @@ const MPI_ROCM_AWARE = @deprecate_preference("LuxDistributedMPIROCMAware", "rocm # Eltype Auto Conversion const ELTYPE_MISMATCH_HANDLING = @load_preference_with_choices("eltype_mismatch_handling", "none", ("none", "warn", "convert", "error")) + +# Dispatch Doctor +""" + set_dispatch_doctor_preferences!(mode::String) + set_dispatch_doctor_preferences!(; luxcore::String="disable", luxlib::String="disable") + +Set the dispatch doctor preference for `LuxCore` and `LuxLib` packages. + +`mode` can be `"disable"`, `"warn"`, or `"error"`. For details on the different modes, see +the [DispatchDoctor.jl](https://astroautomata.com/DispatchDoctor.jl/dev/) documentation. + +If the preferences are already set, then no action is taken. Otherwise the preference is +set. For changes to take effect, the Julia session must be restarted. +""" +function set_dispatch_doctor_preferences!(mode::String) + return set_dispatch_doctor_preferences!(; luxcore=mode, luxlib=mode) +end + +function set_dispatch_doctor_preferences!(; + luxcore::String="disable", luxlib::String="disable") + _set_dispatch_doctor_preferences!(LuxCore, luxcore) + _set_dispatch_doctor_preferences!(LuxLib, luxlib) + return +end + +function _set_dispatch_doctor_preferences!(package, mode::String) + @argcheck mode in ("disable", "warn", "error") + if has_preference(package, "dispatch_doctor") + orig_pref = load_preference(package, "dispatch_doctor") + if orig_pref == mode + @info "Dispatch Doctor preference for $(package) is already set to $mode." + return + end + end + set_preferences!(package, "instability_check" => mode; force=true) + @info "Dispatch Doctor preference for $(package) set to $mode. Please restart Julia \ + for this change to take effect." + return +end diff --git a/test/runtests.jl b/test/runtests.jl index 2427e03fe..08ac3e676 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -131,3 +131,26 @@ if ("all" in LUX_TEST_GROUP || "eltype_match" in LUX_TEST_GROUP) Test.@test true end end + +# Set preferences tests +if ("all" in LUX_TEST_GROUP || "others" in LUX_TEST_GROUP) + @testset "DispatchDoctor Preferences" begin + @testset "set_dispatch_doctor_preferences!" begin + @test_throws ArgumentError Lux.set_dispatch_doctor_preferences!("invalid") + @test_throws ArgumentError Lux.set_dispatch_doctor_preferences!(; + luxcore="invalid") + + Lux.set_dispatch_doctor_preferences!("disable") + @test Preferences.load_preference(LuxCore, "instability_check") == "disable" + @test Preferences.load_preference(LuxLib, "instability_check") == "disable" + + Lux.set_dispatch_doctor_preferences!(; luxcore="warn", luxlib="error") + @test Preferences.load_preference(LuxCore, "instability_check") == "warn" + @test Preferences.load_preference(LuxLib, "instability_check") == "error" + + Lux.set_dispatch_doctor_preferences!(; luxcore="error") + @test Preferences.load_preference(LuxCore, "instability_check") == "error" + @test Preferences.load_preference(LuxLib, "instability_check") == "disable" + end + end +end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index 3fcfdda67..2aa61c059 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -9,6 +9,8 @@ using Lux, Functors Zygote, Statistics using LuxTestUtils: @jet, @test_gradients, check_approx +LuxTestUtils.jet_target_modules!(["Lux", "LuxCore", "LuxLib"]) + # Some Helper Functions function get_default_rng(mode::String) dev = mode == "cpu" ? LuxCPUDevice() :