Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve literal function and module detection #1400

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

ToucheSir
Copy link
Member

Previously, instrumentation would be thrown off by storing a function in a variable/SSA value. Likewise, module references like Main.Base would be ignored because they're lowered as GlobalRefs.

This continues some of the work started in #1371 and should fix a couple of edge cases.

PR Checklist

  • Tests are added
  • [N/A] Documentation, if applicable

Previously, instrumentation would be thrown off by storing a function in a variable/SSA value.
Likewise, module references like `Main.Base` would be ignored because they're lowered as GlobalRefs.
@ToucheSir
Copy link
Member Author

Buildkite error is spurious. Dealing with PyCall on CI is such a mess...

isconst(ref.mod, ref.name) || return ref
val = getproperty(ref.mod, ref.name)
return val isa Module ? val : ref
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this is probably going to be type unstable, you may want to write this as something like

function trylookup(ir::IR, v) # maybe @nospecialize(v) ?
    while true
        if v isa Variable
            v = ir[v].expr
        elseif v isa GlobalRef
           isconst(ref.mod, ref.name) || return ref
           val = getproperty(ref.mod, ref.name)
           return val isa Module ? val : ref
        else
           return v
        end
    end
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Mason. I remember thinking about a loop and rejecting it, but I can't think of why because the snippet you posted looks just fine. Will try incorporating it when I return to this PR.

@Pangoraw
Copy link
Contributor

Pangoraw commented Jun 9, 2023

I looked into the current literal detection when I discovered that some of my throw call were not ignored as they should and found this PR!

It seems to me that the trylookup can be generalized to checks to any GlobalRef for example in the following function:

ignored_f(f) = f in (GlobalRef(Base, :not_int),
GlobalRef(Core.Intrinsics, :not_int),
GlobalRef(Core, :(===)),
GlobalRef(Core, :apply_type),
GlobalRef(Core, :typeof),
GlobalRef(Core, :throw),
GlobalRef(Base, :kwerr),
GlobalRef(Core, :kwfunc),
GlobalRef(Core, :isdefined))

which could be replaced with:

ignored_f(f) = f in (Base.not_int,
                     Core.:(===),
                     Core.apply_type,
                     Core.typeof,
                     Core.throw,
                     Base.kwerr,
                     Core.kwfunc,
                     Core.isdefined)

This comes from the fact that if the symbols are unresolved, the GlobalRef to a call will be with respect to the calling module and these checks will fail. Example where throw is unresolved and placed as GlobalRef(Main,:throw) in the lowered IR:

julia> f(x) = x == 0 ? throw("invalid") : x

julia> @code_lowered f(1)
CodeInfo(
1%1 = x == 0
└──      goto #3 if not %1
2%3 = Main.throw("invalid")
└──      return %3
3return x
)

julia> (@code_lowered f(1)).code[3].args[1]
:(Main.throw)

julia> (@code_lowered f(1)).code[3].args[1] == GlobalRef(Core, :throw)
false

Since we are already calling isconst(::GlobalRef) for any called GlobalRef, Zygote already has the side-effect of resolving the bindings even if they are not used which the compiler will do anyway later (see JuliaLang/julia#44604). Therefore, checking for strict equality (isconst(ref) && getproperty(ref.mod, ref.name) === Base.getproperty for example in isliteral_getproperty) could be beneficial.

@ToucheSir
Copy link
Member Author

ToucheSir commented Jun 9, 2023

Interesting, I could've sworn I had a PR which tried to normalize source modules for this exact reason but I can't find it now. I'm also not sure why the original code used GlobalRefs instead of just checking strict equality. Did any tests break when you tried changing to use the latter?

@Pangoraw
Copy link
Contributor

Pangoraw commented Jun 15, 2023

I have had problems with StackOverflows in the _pullback(::typeof(literal_...)) (I added a fallback to _pullback(typeof(...))) and the following test failures in forward tests where _pushforward(::typeof(literal_indexed_literal)) appears in the stacktrace:

  Expression: D((x->begin
                D((y->begin
                                x = y
                            end), x) * x
            end), 1) == 1
  MethodError: no method matching *(::Nothing, ::Int64)
  
  Closest candidates are:
    *(::Any, ::Any, ::Any, ::Any...)
     @ Base operators.jl:578
    *(::T, ::T) where T<:Union{Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8}
     @ Base int.jl:88
    *(::StridedArray{P}, ::Real) where P<:Dates.Period
     @ Dates ~/.julia/juliaup/julia-1.9.1+0.x64.linux.gnu/share/julia/stdlib/v1.9/Dates/src/deprecated.jl:44

Here are my changes for reference:

https://github.com/FluxML/Zygote.jl/compare/master...Pangoraw:Zygote.jl:check_globalref?expand=1

@ToucheSir
Copy link
Member Author

Yes, I think I ran into similar issues at one point. This part of the AD transform is very finicky and difficult to debug...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants