NNlib Support #171

Open
willtebbutt opened this issue May 28, 2024 · 11 comments
Labels
enhancement New feature or request good first issue Good for newcomers high priority

Comments

@willtebbutt
Member

#169 highlights that we need rules for a range of functionality which lives in NNlib.jl -- this is not surprising, and has largely already been done (see e.g. the Enzyme extension). Someone needs to systematically work through and test that Tapir works on everything in NNlib.

@willtebbutt willtebbutt added enhancement New feature or request good first issue Good for newcomers labels May 28, 2024
@yebai
Contributor

yebai commented May 28, 2024

Can we re-use the ChainRules implementation? See, e.g. the rule for dropout.

@willtebbutt
Member Author

In short: yes. We don't have a nice macro for doing it at the minute, but it ought to be quite doable provided that we impose some of our own constraints on the argument types.

In the case of dropout: this one actually might be a little tricky because it uses kwargs, and I've not looked at supporting kwargs yet (this should largely happen automatically), but it should otherwise be fine.

Something like batched_transpose should be quite straightforward. You would just do something along the lines of

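# Assumes rrule (ChainRulesCore), batched_transpose (NNlib), and the Tapir names used below are in scope.
function rrule!!(::CoDual{typeof(batched_transpose)}, A::CoDual{<:Array{<:Any,3}})
    # Re-use the ChainRules rule for the forwards pass.
    B, pb = rrule(batched_transpose, primal(A))
    function pb!!(dB)
        _, dA_inc = pb(dB)
        # Accumulate the cotangent into A's tangent in place.
        increment!!(tangent(A), dA_inc)
        return NoRData(), NoRData()
    end
    return zero_fcodual(B), pb!!
end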

Note that I've restricted the array type here -- we would just need to be careful with how we extend this rule to other array types, as it's only valid in Tapir.jl for arrays whose tangents are arrays, rather than Tangents.

@yebai
Contributor

yebai commented Jul 3, 2024

> Note that I've restricted the array type here -- we would just need to be careful with how we extend this rule to other array types, as it's only valid in Tapir.jl for arrays whose tangents are arrays, rather than Tangents.

Thanks @willtebbutt. A follow-up question on this topic: if we put mutation aside, is it true in general that ChainRules rules are a superset of Tapir rules, in the sense that one can always get a Tapir rule by restricting the argument types on a ChainRules rule?

@willtebbutt
Member Author

In the narrow sense that lots of ChainRules.rrules are written with abstract arguments: yes.

However, recall that this gives you an overly optimistic sense of how much more general a ChainRules.rrule is. It's not very hard to construct e.g. subtypes of AbstractArray{<:Any, 3} for which this won't work.
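As a minimal sketch of the kind of type meant here: the wrapper below is hypothetical (`ScaledArray3` is not part of NNlib, ChainRules, or Tapir) and uses only Base. It matches a rule signature written for `AbstractArray{<:Any, 3}`, yet its differentiable state lives in two fields, so a structural tangent for it would be a per-field record rather than a single dense array. The NamedTuple stands in for such a record purely for illustration.

```julia
# Hypothetical wrapper type, for illustration only.
struct ScaledArray3{T} <: AbstractArray{T,3}
    data::Array{T,3}
    scale::T
end

# Minimal AbstractArray interface: size and scalar indexing.
Base.size(A::ScaledArray3) = size(A.data)
Base.getindex(A::ScaledArray3, i::Int, j::Int, k::Int) = A.scale * A.data[i, j, k]

A = ScaledArray3(ones(2, 3, 4), 2.0)

# The type satisfies the abstract signature a generic rule would dispatch on.
matches_abstract_signature = A isa AbstractArray{<:Any,3}

# Its state is two fields, so a structural tangent is a record of per-field
# tangents (sketched here as a NamedTuple), not an Array{Float64,3}.
structural_tangent = (data = zeros(2, 3, 4), scale = 0.0)
tangent_is_dense_array = structural_tangent isa Array
```

A rule that assumes the tangent is a plain array and increments it element-wise would therefore need extra care, or a separate method, for a type like this.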

@yebai
Contributor

yebai commented Jul 3, 2024

> In the narrow sense that lots of ChainRules.rrules are written with abstract arguments: yes.

Was the current ChainRules design motivated by something, or is it an oversight?

@willtebbutt
Member Author

willtebbutt commented Jul 3, 2024

A bit of both.

It's motivated by trying to write rules for things like Zygote, where you really want your rule system to do a lot of work, because Zygote can't differentiate much itself due to lack of mutation support.

It's an oversight in the sense that we didn't realise how hard it is to write robust rules which work for lots of types. The upshot is the kinds of issues we've discussed previously (mainly around poor composition performance).

@yebai
Contributor

yebai commented Jul 3, 2024

But ChainRules doesn't prevent one from writing rules with concrete argument types in principle. Is that correct?

Or does writing rules with concrete argument types always require something like typed tangents? IIUC, the typed tangents in Tapir seem like an extension of ChainRules via concrete types rather than something incompatible.

@willtebbutt
Member Author

> But ChainRules doesn't prevent one from writing rules with concrete argument types in principle. Is that correct?

Correct. Equally, Tapir.jl doesn't prevent you from writing rules with abstract argument types, it's just not generally a good idea.

> Or does writing rules with concrete argument types always require something like typed tangents? IIUC, the typed tangents in Tapir seem like an extension of ChainRules via concrete types rather than something incompatible.

I'm not sure what you mean by "typed tangents" -- could you expand?

@yebai
Contributor

yebai commented Jul 5, 2024

> I'm not sure what you mean by "typed tangents" -- could you expand?

I am referring to the fact that, in Tapir, each primal type has a unique tangent type. Does writing ChainRules rules with concrete argument types always require this unique tangent-type assumption?

@willtebbutt
Member Author

Hmm okay. Could you provide a small code example showing what you mean? I'm still not quite following what you're trying to say / ask, so I think I need some examples.

@willtebbutt
Member Author

This is partially addressed by #254, but there remains work to be done on GPU integration, and on generalising the added functionality to AbstractArrays other than Arrays. Consequently, I'm going to keep this issue open.


2 participants