Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Connecting to ChainRulesCore for JuliaLang AD compat #652

Open
wants to merge 5 commits into
base: main
Choose a base branch
from

Conversation

oxinabox
Copy link
Contributor

@oxinabox oxinabox commented Sep 19, 2021

Early WIP.

  • Calling Atoms
    • dexize helper function
    • 1-arg Atoms
    • Multiarg Atoms
  • NativeFunctions
    • having dexfunc compile AD helper functions at parse time.
    • frule
      • 1-arg functions without implicts
      • functions with implicts
      • multiarg functions
    • rrule
  • Hygiene around insert calls (probably make something gensym-like)
  • Clean up of old environments around insert calls.
  • Destruction of enviroments created by DexModule (currently commented out)

What works right now:

using Test, DexCall, Zygote
const double_dex = evaluate(raw"\x:Float. 2.0 * x")
const double_via_dex = juliaize  double_dex  dexize
y, pb= Zygote.pullback(double_via_dex, 1.5f0)
@test y == 3f0
@test pb(1f0) == (2f0,)

@google-cla google-cla bot added the cla: yes label Sep 19, 2021
@@ -87,6 +87,8 @@ julia> typeof(convert(Int64, m.x))
Int64
```

The inverse of `juliaize` is `dexize`, it is currently very limited.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't given it proper docs yet, because it only does Float32.
It's mostly just for testing purposes.
We need a proper API for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can extend it if you'd like. Just let me know what would be helpful to have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it is something like create_literal that works like insert except instead of taking an Atom it takes a C compatible value for a Int/Float32/Float64/Array.
Possibly ctypes doesn't allow that directly, so maybe it needs to be wrapped into a tagged union?
I guess maybe accepting the same tagged union that comes out of the atom's pointer makes sense.
(Except right now that doesn't support arrays)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a function that converts a CAtom into an Atom be sufficient? That's what I would imagine. And then we can add more cases to CAtom if that would be useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think that is basically what I said badly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, which cases would you like to have? I'm happy to add them for you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float32, Float64 (particularly since can't input those as literals #497)
Arrays would be nice, but given we can't currently convert Atom to CAtom for arrays anyway, that doesn't matter so much.
Integer types would be nice, for completeness, but not particularly interesting for AD.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that in #657


NB: this is currently a hack that goes via string processing.
"""
function dexize(x::Float32, _module=PRELUDE, env=_module)
Copy link
Contributor Author

@oxinabox oxinabox Sep 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we do not need to take a env and a _module argument.
We are making literals, we just need one for where the literal will exist?
I don't really understand the difference between them

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _module and env parameters are a bit of a hack that I added for Python bindings. The rough idea was that module is the module an output atom declares itself to be defined in, while env is the scope that's really used to evaluate the expression. This is used in the __call__ implementation where we temporarily extend the prelude with new names that refer to arguments, but then we want to pretend that the result is still defined in the original module that doesn't have those dummies. But now that I think about it, it's only well defined for non-dependent functions, so we should find a different workaround...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, I think you can safely ignore env and just use the _module.

env
)

# It is important that we close over `m` as otherwise the env may be destroyed by GC
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer the case since right now that finalizers is commented out.
And before this is merged we may way move to working out how to attach the finalizer to the context more directly.

Comment on lines +96 to +102
# TODO: Undo commenting this out. But for now this causes a lot of problems.
# DexModule will often go out of scope, while a Atom attached to that context still
# exists. Possibly we need to make the ctx a mutable struct everywhere, and then
# attach the finalizer there.
#(also will let us delete some manual destroys in other palces)

#destroy_context(getfield(_m, :ctx))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs resolving


include("api_types.jl")
include("api.jl")
include("evaluate.jl")
include("native_function.jl")
include("chainrules.jl")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include("chainrules.jl")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants