From 72fc49f9c9d66a3e171100a8393073f77787fe2e Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 12 Apr 2024 18:23:23 -0400
Subject: [PATCH] Prettify the printing

---
 .typos.toml                            |  2 ++
 examples/GravitationalWaveForm/main.jl | 14 +++++++-------
 examples/NeuralODE/main.jl             |  6 +++---
 examples/SimpleRNN/main.jl             |  3 +--
 src/helpers/compact.jl                 | 24 +++++++++++++++++++-----
 test/helpers/compact_tests.jl          |  6 +++---
 6 files changed, 35 insertions(+), 20 deletions(-)
 create mode 100644 .typos.toml

diff --git a/.typos.toml b/.typos.toml
new file mode 100644
index 000000000..e2b3e6f9a
--- /dev/null
+++ b/.typos.toml
@@ -0,0 +1,2 @@
+[default.extend-words]
+numer = "numer"
\ No newline at end of file
diff --git a/examples/GravitationalWaveForm/main.jl b/examples/GravitationalWaveForm/main.jl
index 16658484e..7d63a5ce9 100644
--- a/examples/GravitationalWaveForm/main.jl
+++ b/examples/GravitationalWaveForm/main.jl
@@ -163,7 +163,7 @@ function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {
         m₁ = mass_ratio * m₂
 
         orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
-        waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)
+        waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
     else
         waveform = h_22_strain_one_body(dt, orbit)
     end
@@ -183,11 +183,11 @@ end
 function RelativisticOrbitModel(u, (p, M, e), t)
     χ, ϕ = u
 
-    number = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
+    numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
     denom = sqrt((p - 2)^2 - 4 * e^2)
 
-    χ̇ = number * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
-    ϕ̇ = number / (M * (p^(3 / 2)) * denom)
+    χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
+    ϕ̇ = numer / (M * (p^(3 / 2)) * denom)
 
     return [χ̇, ϕ̇]
 end
@@ -260,11 +260,11 @@ function ODE_model(u, nn_params, t)
     ## it, however, in general, we should use `st` to store the state of the neural network.
     y = 1 .+ nn_model([first(u)], nn_params)
 
-    number = (1 + e * cos(χ))^2
+    numer = (1 + e * cos(χ))^2
     denom = M * (p^(3 / 2))
 
-    χ̇ = (number / denom) * y[1]
-    ϕ̇ = (number / denom) * y[2]
+    χ̇ = (numer / denom) * y[1]
+    ϕ̇ = (numer / denom) * y[2]
 
     return [χ̇, ϕ̇]
 end
diff --git a/examples/NeuralODE/main.jl b/examples/NeuralODE/main.jl
index f8863017a..2901534a9 100644
--- a/examples/NeuralODE/main.jl
+++ b/examples/NeuralODE/main.jl
@@ -236,11 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3));
 
 @code_warntype model_stateful(x, ps_stateful, st_stateful)
 
+# Note that we still recommend using this layer internally and not exposing it as the
+# default API to users.
+
 # Finally checking the compact model
 model_compact, ps_compact, st_compact = create_model(NeuralODECompact)
 
 @code_warntype model_compact(x, ps_compact, st_compact)
-
-# Note, that we still recommend using this layer internally and not exposing this as the
-# default API to the users.
 
diff --git a/examples/SimpleRNN/main.jl b/examples/SimpleRNN/main.jl
index 6ea8f92bc..0f1791408 100644
--- a/examples/SimpleRNN/main.jl
+++ b/examples/SimpleRNN/main.jl
@@ -108,8 +108,7 @@ end
 function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
     lstm_cell = LSTMCell(in_dims => hidden_dims)
     classifier = Dense(hidden_dims => out_dims, sigmoid)
-    return @compact(; lstm_cell=lstm_cell,
-        classifier=classifier) do x::AbstractArray{T, 3} where {T}
+    return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
         x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
         y, carry = lstm_cell(x_init)
         for x in x_rest
diff --git a/src/helpers/compact.jl b/src/helpers/compact.jl
index b8e6088ca..399d80f9e 100644
--- a/src/helpers/compact.jl
+++ b/src/helpers/compact.jl
@@ -200,8 +200,6 @@ function __compact_macro_impl(_exs...)
     vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs)
     fex = supportself(fex, vars, splatted_kwargs)
 
-    display(fex)
-
     # assemble
     return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block),
         (($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...))))
@@ -346,9 +344,8 @@ function CompactLuxLayer{dispatch}(
         if is_lux_layer
             setup_strings = merge(setup_strings, NamedTuple((name => val,)))
         else
-            setup_strings = merge(setup_strings,
-                NamedTuple((name => sprint(
-                    show, val; context=(:compact => true, :limit => true)),)))
+            setup_strings = merge(
+                setup_strings, NamedTuple((name => __kwarg_descriptor(val),)))
         end
     end
 
@@ -410,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing
     end
     return
 end
+
+function __kwarg_descriptor(val)
+    val isa Number && return string(val)
+    val isa AbstractArray && return sprint(Base.array_summary, val, axes(val))
+    val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")"
+    if val isa NamedTuple
+        fields = fieldnames(typeof(val))
+        strs = []
+        for fname in fields[1:min(length(fields), 3)]
+            internal_val = getfield(val, fname)
+            push!(strs, "$fname = $(__kwarg_descriptor(internal_val))")
+        end
+        return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}"
+    end
+    val isa Function && return sprint(show, val; context=(:compact => true, :limit => true))
+    return lazy"$(nameof(typeof(val)))(...)"
+end
diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl
index bd7f35cf9..caa37fe3d 100644
--- a/test/helpers/compact_tests.jl
+++ b/test/helpers/compact_tests.jl
@@ -180,7 +180,7 @@ return w(x .* s)
         end
         expected_string = """@compact(
-            x = randn(32),
+            x = 32-element Vector{Float64},
             w = Dense(32 => 32),                # 1_056 parameters
         ) do s
             return w(x .* s)
         end
@@ -197,8 +197,8 @@ end
         expected_string = """@compact(
             w1 = Model(32)(),                   # 1_024 parameters
-            w2 = randn(32, 32),
-            w3 = randn(32),
+            w2 = 32×32 Matrix{Float64},
+            w3 = 32-element Vector{Float64},
         ) do x
             return w2 * w1(x)
         end       # Total: 2_080 parameters,
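
For reference, a minimal sketch of the prettified printing this patch produces. The
model definition and the rendered summary are taken from the updated `expected_string`
in test/helpers/compact_tests.jl; `using Lux` at the top level is assumed, and the
total-parameter footer of the printout is omitted here:

    using Lux

    # A toy compact layer mixing a Lux layer (`w`) with a plain array kwarg (`x`).
    model = @compact(x=randn(32), w=Dense(32 => 32)) do s
        return w(x .* s)
    end

    # Displaying `model` (e.g. in the REPL) now routes non-layer kwargs through
    # `__kwarg_descriptor`, so the vector is summarized by its shape and eltype
    # instead of being dumped element by element:
    #
    # @compact(
    #     x = 32-element Vector{Float64},
    #     w = Dense(32 => 32),                # 1_056 parameters
    # ) do s
    #     return w(x .* s)
    # end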