Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve literal function and module detection #1400

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs)
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, dxs)
# This produces Tangent{Any} since it does not get to see the primal, `x`.
Expand Down
82 changes: 48 additions & 34 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,38 @@ unwrapquote(x::QuoteNode) = x.value

is_getproperty(ex) = iscall(ex, Base, :getproperty)

# Allows us to resolve constants which have been stored in Variables.
# e.g. `%1 = 1; %2 = %1``, or `%1 = identity; %1(...)`.
trylookup(ir::IR, @nospecialize(v)) = v
trylookup(ir::IR, v::Variable) = haskey(ir, v) ? trylookup(ir, ir[v].expr) : v
# Only resolve GlobalRefs to traverse module hierarchies
function trylookup(ir::IR, ref::GlobalRef)
isconst(ref.mod, ref.name) || return ref
val = getproperty(ref.mod, ref.name)
return val isa Module ? val : ref
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is probably going to be type unstable, you may want to write this as something like

function trylookup(ir::IR, v) # maybe @nospecialize(v) ?
    while true
        if v isa Variable
            v = ir[v].expr
        elseif v isa GlobalRef
           isconst(ref.mod, ref.name) || return ref
           val = getproperty(ref.mod, ref.name)
           return val isa Module ? val : ref
        else
           return v
        end
    end
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Mason. I remember thinking about a loop and rejecting it, but I can't think of why because the snippet you posted looks just fine. Will try incorporating it when I return to this PR.



# The initial premise of literal_getproperty was in some ways inherently flawed, because for
# getproperty it was intended that _pullback falls back to literal_getproperty, but we actually
# want the opposite to happen, since Zygote should fall back to recursing into the getproperty
# implementation by default. Users still want to define custom adjoints using only
# literal_getproperty, though. We can't really have mutually recursive definitions here, so we
# now always instrument getproperty as literal_getproperty, no matter whether the second
# argument is a literal or not.
function instrument_getproperty!(ir, v, ex)
if is_getproperty(ex)
obj, prop = ex.args[2], ex.args[3]
if obj isa Module && prop isa QuoteNode && isconst(obj, unwrapquote(prop))
function instrument_getproperty!(ir::Pipe, v, ex)
func = trylookup(ir.from, ex.args[1])
if func == GlobalRef(Base, :getproperty) && length(ex.args) >= 3
obj, prop = ex.args[2], trylookup(ir.from, ex.args[3])
original = trylookup(ir.from, obj)
if original isa Module && prop isa QuoteNode && isconst(original, unwrapquote(prop))
# Metaprogramming can generate getproperty(::Module, ...) calls.
# Like other types, these are type unstable without constprop.
# However, literal_getproperty's heuristic is also not general enough for modules.
# Thankfully, we can skip instrumenting these if they're const properties.
ex
elseif prop isa Union{QuoteNode,Integer}
return ex
end
if prop isa Union{QuoteNode,Integer}
ir[v] = xcall(Zygote, :literal_getproperty, obj, Val(unwrapquote(prop)))
else
f = insert!(ir, v, :(Val($(prop))))
Expand All @@ -55,45 +70,48 @@ function instrument_getproperty!(ir, v, ex)
end
end

is_literal_getfield(ex) =
(iscall(ex, Core, :getfield) || iscall(ex, Base, :getfield)) &&
ex.args[3] isa Union{QuoteNode,Integer}

# Here, only instrumenting getfield with literals is fine, since users should never have to
# define custom adjoints for literal_getfield
function instrument_getfield!(ir, v, ex)
if is_literal_getfield(ex)
ir[v] = xcall(Zygote, :literal_getfield, ex.args[2], Val(unwrapquote(ex.args[3])))
else
ex
func = trylookup(ir.from, ex.args[1])
if func == GlobalRef(Core, :getfield) || func == GlobalRef(Base, :getfield)
obj, field = ex.args[2], trylookup(ir.from, ex.args[3])
if field isa Union{QuoteNode,Integer}
call = xcall(Zygote, :literal_getfield, obj, Val(unwrapquote(field)))
return ir[v] = call
end
end
return ex
end

is_literal_getindex(ex) =
iscall(ex, Base, :getindex) && length(ex.args) == 3 && ex.args[3] isa Union{Integer,QuoteNode}

# TODO: is this always correct for user defined getindex methods?
function instrument_getindex!(ir, v, ex)
if is_literal_getindex(ex)
ir[v] = xcall(Zygote, :literal_getindex, ex.args[2], Val(unwrapquote(ex.args[3])))
else
ex
func = trylookup(ir.from, ex.args[1])
if func == GlobalRef(Base, :getindex) && length(ex.args) == 3
obj, idx = ex.args[2], trylookup(ir.from, ex.args[3])
if idx isa Union{QuoteNode,Integer}
call = xcall(Zygote, :literal_getindex, obj, Val(unwrapquote(idx)))
return ir[v] = call
end
end
return ex
end

is_literal_iterate(ex) =
iscall(ex, Base, :indexed_iterate) && length(ex.args) >= 3 && ex.args[3] isa Union{Integer,QuoteNode}

function instrument_iterate!(ir, v, ex)
if is_literal_iterate(ex)
ir[v] = xcall(Zygote, :literal_indexed_iterate, ex.args[2],
Val(unwrapquote(ex.args[3])), ex.args[4:end]...)
else
ex
func = ex.args[1]
if func == GlobalRef(Base, :indexed_iterate) && length(ex.args) >= 3
obj, idx, rest = ex.args[2], trylookup(ir.from, ex.args[3]), ex.args[4:end]
if idx isa Union{QuoteNode,Integer}
call = xcall(Zygote, :literal_indexed_iterate, obj, Val(unwrapquote(idx)), rest...)
return ir[v] = call
end
end
return ex
end

function instrument_literals!(ir, v, ex)
isexpr(ex, :call) || return ex
ex = instrument_getproperty!(ir, v, ex)
ex = instrument_getfield!(ir, v, ex)
ex = instrument_getindex!(ir, v, ex)
Expand Down Expand Up @@ -177,14 +195,10 @@ ignored_f(ir, f) = ignored_f(f)
ignored_f(ir, f::Variable) = ignored_f(get(ir, f, nothing))

function ignored(ir, ex)
isexpr(ex, :call) || return false
f = ex.args[1]
f = trylookup(ir, ex.args[1])
ignored_f(ir, f) && return true
if f isa Variable && haskey(ir, f)
f = ir[f].expr
end
if f == GlobalRef(Base, :getproperty) && length(ex.args) >= 3
obj, prop = ex.args[2], ex.args[3]
obj, prop = trylookup(ir, ex.args[2]), trylookup(ir, ex.args[3])
# Metaprogramming can generate getproperty(::Module, ...) calls.
# These are type unstable without constprop, which transforming to _pullback breaks.
# However, we can skip differentiating these if they're const properties.
Expand Down
2 changes: 2 additions & 0 deletions test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ module MyMod
end

@eval usesmod(x) = Base.getproperty($MyMod, :func)(x, Base.getproperty($MyMod, :C))
usesmod2(x) = Base.getproperty(MyMod, :func)(x, Base.getproperty(MyMod, :C))

@testset "inference for `getproperty`" begin
Gaussian = _Gaussian(:getproperty)
Expand Down Expand Up @@ -221,6 +222,7 @@ end

# Const properties on modules should be lowered as-is (not differentiated)
@test @inferred gradient(usesmod, 1)[1] == 1.0
@test @inferred gradient(usesmod2, 1)[1] == 1.0
end

# issue 897
Expand Down