-
-
Notifications
You must be signed in to change notification settings - Fork 211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve literal function and module detection #1400
base: master
Are you sure you want to change the base?
Conversation
Previously, instrumentation would be thrown off by storing a function in a variable/SSA value. Likewise, module references like `Main.Base` would be ignored because they're lowered as GlobalRefs.
Buildkite error is spurious. Dealing with PyCall on CI is such a mess... |
isconst(ref.mod, ref.name) || return ref | ||
val = getproperty(ref.mod, ref.name) | ||
return val isa Module ? val : ref | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because this is probably going to be type unstable, you may want to write this as something like
function trylookup(ir::IR, v) # maybe @nospecialize(v) ?
while true
if v isa Variable
v = ir[v].expr
elseif v isa GlobalRef
isconst(ref.mod, ref.name) || return ref
val = getproperty(ref.mod, ref.name)
return val isa Module ? val : ref
else
return v
end
end
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Mason. I remember thinking about a loop and rejecting it, but I can't think of why because the snippet you posted looks just fine. Will try incorporating it when I return to this PR.
I looked into the current literal detection when I discovered that some of my It seems to me that the Zygote.jl/src/compiler/reverse.jl Lines 185 to 193 in 6d983d5
which could be replaced with: ignored_f(f) = f in (Base.not_int,
Core.:(===),
Core.apply_type,
Core.typeof,
Core.throw,
Base.kwerr,
Core.kwfunc,
Core.isdefined) This comes from the fact that if the symbols are unresolved, the julia> f(x) = x == 0 ? throw("invalid") : x
julia> @code_lowered f(1)
CodeInfo(
1 ─ %1 = x == 0
└── goto #3 if not %1
2 ─ %3 = Main.throw("invalid")
└── return %3
3 ─ return x
)
julia> (@code_lowered f(1)).code[3].args[1]
:(Main.throw)
julia> (@code_lowered f(1)).code[3].args[1] == GlobalRef(Core, :throw)
false Since we are already calling |
Interesting, I could've sworn I had a PR which tried to normalize source modules for this exact reason but I can't find it now. I'm also not sure why the original code used GlobalRefs instead of just checking strict equality. Did any tests break when you tried changing to use the latter? |
I have had problems with StackOverflows in the
Here are my changes for reference: https://github.com/FluxML/Zygote.jl/compare/master...Pangoraw:Zygote.jl:check_globalref?expand=1 |
Yes, I think I ran into similar issues at one point. This part of the AD transform is very finicky and difficult to debug... |
Previously, instrumentation would be thrown off by storing a function in a variable/SSA value. Likewise, module references like
Main.Base
would be ignored because they're lowered as GlobalRefs.This continues some of the work started in #1371 and should fix a couple of edge cases.
PR Checklist