Skip to content

Commit

Permalink
fix: pretty printing of MaxPool Layer
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 9, 2024
1 parent ec5841b commit 2c69ae7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.0.0"
version = "1.0.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 3 additions & 1 deletion src/layers/display.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ end
show_leaflike(x) = Functors.isleaf(x) # mostly follow Functors, except for:
show_leaflike(x::AbstractLuxLayer) = false

isa_printable_leaf(x) = false

function underscorise(n::Integer)
return join(reverse(join.(reverse.(Iterators.partition(digits(n), 3)))), '_')
end
Expand All @@ -27,7 +29,7 @@ function big_show(io::IO, obj, indent::Int=0, name=nothing)
return
end
children = printable_children(obj)
if all(show_leaflike, values(children))
if all(show_leaflike, values(children)) || isa_printable_leaf(obj)
layer_show(io, obj, indent, name)
else
println(io, " "^indent, isnothing(name) ? "" : "$name = ", display_name(obj), "(")
Expand Down
12 changes: 9 additions & 3 deletions src/layers/pooling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ for layer_op in (:Max, :Mean, :LP)
window; stride, pad, dilation, p))
end

function Base.show(io::IO, ::MIME"text/plain", m::$(layer_name))
function Base.show(io::IO, m::$(layer_name))
kernel_size = m.layer.mode.kernel_size
print(io, string($(Meta.quot(layer_name))), "($(kernel_size)")
pad = m.layer.mode.pad
Expand All @@ -213,6 +213,8 @@ for layer_op in (:Max, :Mean, :LP)
print(io, ")")
end

PrettyPrinting.isa_printable_leaf(::$(layer_name)) = true

# Global Pooling Layer
@doc $(global_pooling_docstring) @concrete struct $(global_layer_name) <:
AbstractLuxWrapperLayer{:layer}
Expand All @@ -223,14 +225,16 @@ for layer_op in (:Max, :Mean, :LP)
return $(global_layer_name)(PoolingLayer(static(:global), $(Meta.quot(op)); p))
end

function Base.show(io::IO, ::MIME"text/plain", g::$(global_layer_name))
function Base.show(io::IO, g::$(global_layer_name))
print(io, string($(Meta.quot(global_layer_name))), "(")
if $(Meta.quot(op)) == :lp
g.layer.op.p == 2 || print(io, ", p=", g.layer.op.p)
end
print(io, ")")
end

PrettyPrinting.isa_printable_leaf(::$(global_layer_name)) = true

# Adaptive Pooling Layer
@doc $(adaptive_pooling_docstring) @concrete struct $(adaptive_layer_name) <:
AbstractLuxWrapperLayer{:layer}
Expand All @@ -242,12 +246,14 @@ for layer_op in (:Max, :Mean, :LP)
static(:adaptive), $(Meta.quot(op)), out_size; p))
end

function Base.show(io::IO, ::MIME"text/plain", a::$(adaptive_layer_name))
function Base.show(io::IO, a::$(adaptive_layer_name))
print(io, string($(Meta.quot(adaptive_layer_name))), "(", a.layer.mode.out_size)
if $(Meta.quot(op)) == :lp
a.layer.op.p == 2 || print(io, ", p=", a.layer.op.p)
end
print(io, ")")
end

PrettyPrinting.isa_printable_leaf(::$(adaptive_layer_name)) = true
end
end

0 comments on commit 2c69ae7

Please sign in to comment.