Prettify the printing
avik-pal committed Apr 12, 2024
1 parent 3981a2c commit 72fc49f
Showing 6 changed files with 35 additions and 20 deletions.
2 changes: 2 additions & 0 deletions .typos.toml
@@ -0,0 +1,2 @@
+[default.extend-words]
+numer = "numer"
14 changes: 7 additions & 7 deletions examples/GravitationalWaveForm/main.jl
@@ -163,7 +163,7 @@ function compute_waveform(dt::T, soln, mass_ratio, model_params=nothing) where {
m₁ = mass_ratio * m₂

orbit₁, orbit₂ = one2two(orbit, m₁, m₂)
-waveform = h_22_strain_two_body(dt, orbit1, mass1, orbit2, mass2)
+waveform = h_22_strain_two_body(dt, orbit₁, m₁, orbit₂, m₂)
else
waveform = h_22_strain_one_body(dt, orbit)
end
@@ -183,11 +183,11 @@ end
function RelativisticOrbitModel(u, (p, M, e), t)
χ, ϕ = u

-number = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
+numer = (p - 2 - 2 * e * cos(χ)) * (1 + e * cos(χ))^2
denom = sqrt((p - 2)^2 - 4 * e^2)

-χ̇ = number * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
-ϕ̇ = number / (M * (p^(3 / 2)) * denom)
+χ̇ = numer * sqrt(p - 6 - 2 * e * cos(χ)) / (M * (p^2) * denom)
+ϕ̇ = numer / (M * (p^(3 / 2)) * denom)

return [χ̇, ϕ̇]
end
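For context, the model above can be integrated directly once defined. Below is a minimal sketch of doing so; the initial condition, the (p, M, e) values, the solver, and the time span are illustrative assumptions, not the tutorial's actual settings:

using OrdinaryDiffEq

u0 = [pi, 0.0]                 # assumed initial (χ, ϕ)
params = (100.0, 1.0, 0.5)     # assumed (p, M, e)
tspan = (0.0, 1000.0)
prob = ODEProblem(ODEFunction{false}(RelativisticOrbitModel), u0, tspan, params)
soln = solve(prob, Tsit5(); saveat=10.0)   # soln[1, :] is χ(t), soln[2, :] is ϕ(t)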
@@ -260,11 +260,11 @@ function ODE_model(u, nn_params, t)
## it, however, in general, we should use `st` to store the state of the neural network.
y = 1 .+ nn_model([first(u)], nn_params)

-number = (1 + e * cos(χ))^2
+numer = (1 + e * cos(χ))^2
denom = M * (p^(3 / 2))

-χ̇ = (number / denom) * y[1]
-ϕ̇ = (number / denom) * y[2]
+χ̇ = (numer / denom) * y[1]
+ϕ̇ = (numer / denom) * y[2]

return [χ̇, ϕ̇]
end
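As an aside on the `st` comment above: the pattern it recommends is to thread the layer state through every call rather than closing over it. A minimal sketch of that pattern, with the layer sizes and RNG here chosen purely for illustration:

using Lux, Random

nn = Chain(Dense(1 => 32, tanh), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), nn)
y, st = nn([1.0f0], ps, st)   # each call returns the output and the updated state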
6 changes: 3 additions & 3 deletions examples/NeuralODE/main.jl
@@ -236,11 +236,11 @@ x = gpu_device()(ones(Float32, 28, 28, 1, 3));

@code_warntype model_stateful(x, ps_stateful, st_stateful)

-# Note that we still recommend using this layer internally and not exposing this as the
-# default API to the users.

# Finally checking the compact model

model_compact, ps_compact, st_compact = create_model(NeuralODECompact)

@code_warntype model_compact(x, ps_compact, st_compact)

+# Note that we still recommend using this layer internally and not exposing this as the
+# default API to the users.
3 changes: 1 addition & 2 deletions examples/SimpleRNN/main.jl
@@ -108,8 +108,7 @@ end
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
lstm_cell = LSTMCell(in_dims => hidden_dims)
classifier = Dense(hidden_dims => out_dims, sigmoid)
-return @compact(; lstm_cell=lstm_cell,
-    classifier=classifier) do x::AbstractArray{T, 3} where {T}
+return @compact(; lstm_cell, classifier) do x::AbstractArray{T, 3} where {T}
x_init, x_rest = Iterators.peel(Lux._eachslice(x, Val(2)))
y, carry = lstm_cell(x_init)
for x in x_rest
24 changes: 19 additions & 5 deletions src/helpers/compact.jl
@@ -200,8 +200,6 @@ function __compact_macro_impl(_exs...)
vars = map(first ∘ Base.Fix2(getproperty, :args), kwexs)
fex = supportself(fex, vars, splatted_kwargs)

-display(fex)
-
# assemble
return esc(:($CompactLuxLayer{$dispatch}($fex, $name, ($layer, $input, $block),
(($(Meta.quot.(splatted_kwargs)...),), ($(splatted_kwargs...),)); $(kwexs...))))
@@ -346,9 +344,8 @@ function CompactLuxLayer{dispatch}(
if is_lux_layer
setup_strings = merge(setup_strings, NamedTuple((name => val,)))
else
-setup_strings = merge(setup_strings,
-    NamedTuple((name => sprint(
-        show, val; context=(:compact => true, :limit => true)),)))
+setup_strings = merge(
+    setup_strings, NamedTuple((name => __kwarg_descriptor(val),)))
end
end

@@ -410,3 +407,20 @@ function Lux._big_show(io::IO, obj::CompactLuxLayer, indent::Int=0, name=nothing
end
return
end

+function __kwarg_descriptor(val)
+    val isa Number && return string(val)
+    val isa AbstractArray && return sprint(Base.array_summary, val, axes(val))
+    val isa Tuple && return "(" * join(map(__kwarg_descriptor, val), ", ") * ")"
+    if val isa NamedTuple
+        fields = fieldnames(typeof(val))
+        strs = []
+        for fname in fields[1:min(length(fields), 3)]
+            internal_val = getfield(val, fname)
+            push!(strs, "$fname = $(__kwarg_descriptor(internal_val))")
+        end
+        return "@NamedTuple{$(join(strs, ", "))" * (length(fields) > 3 ? ", ..." : "") * "}"
+    end
+    val isa Function && return sprint(show, val; context=(:compact => true, :limit => true))
+    return lazy"$(nameof(typeof(val)))(...)"
+end
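For reference, a few example outputs of the helper above; the array cases line up with the updated test expectations below, while the tuple and NamedTuple inputs are illustrative:

__kwarg_descriptor(randn(32))              # "32-element Vector{Float64}"
__kwarg_descriptor(randn(32, 32))          # "32×32 Matrix{Float64}"
__kwarg_descriptor((1, 2.0f0))             # "(1, 2.0)"
__kwarg_descriptor((; a=1, b=2, c=3, d=4)) # "@NamedTuple{a = 1, b = 2, c = 3, ...}"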
6 changes: 3 additions & 3 deletions test/helpers/compact_tests.jl
@@ -180,7 +180,7 @@
return w(x .* s)
end
expected_string = """@compact(
-x = randn(32),
+x = 32-element Vector{Float64},
w = Dense(32 => 32), # 1_056 parameters
) do s
return w(x .* s)
@@ -197,8 +197,8 @@
end
expected_string = """@compact(
w1 = Model(32)(), # 1_024 parameters
-w2 = randn(32, 32),
-w3 = randn(32),
+w2 = 32×32 Matrix{Float64},
+w3 = 32-element Vector{Float64},
) do x
return w2 * w1(x)
end # Total: 2_080 parameters,