
Add a LogDensityProblemsAD extension so we can support Turing #95

Closed · yebai opened this issue Mar 22, 2024 · 26 comments
Labels: enhancement (testing) [Would improve the tests], enhancement [New feature or request]


yebai commented Mar 22, 2024

Let's create a LogDensityProblemsAD extension within Taped.jl. That should let us experiment with more Turing models before this package is officially registered.

See, e.g.: https://github.com/tpapp/LogDensityProblemsAD.jl/blob/master/ext/LogDensityProblemsADZygoteExt.jl
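
For illustration, a minimal sketch of what such an extension might look like, modelled on the Zygote extension linked above. The `Tapir.build_rrule` / `Tapir.value_and_gradient!!` calls follow Tapir's documented interface, but the module name, wrapper type, and other details here are assumptions, not a final design:

```julia
module TapedLogDensityProblemsADExt

using LogDensityProblemsAD: LogDensityProblemsAD, ADGradientWrapper
using Tapir: Tapir

# Hypothetical wrapper holding the log-density problem and a pre-built Tapir rule.
struct TapirGradientLogDensity{L,R} <: ADGradientWrapper
    ℓ::L
    rule::R
end

# Entry point used by LogDensityProblemsAD's backend-selection machinery.
function LogDensityProblemsAD.ADgradient(::Val{:Tapir}, ℓ)
    x = zeros(LogDensityProblemsAD.dimension(ℓ))
    rule = Tapir.build_rrule(Base.Fix1(LogDensityProblemsAD.logdensity, ℓ), x)
    return TapirGradientLogDensity(ℓ, rule)
end

function LogDensityProblemsAD.logdensity_and_gradient(
    ∇ℓ::TapirGradientLogDensity, x::AbstractVector
)
    # value_and_gradient!! returns the value and a tuple of gradients,
    # the first entry being the tangent of the closure itself.
    y, (_, dx) = Tapir.value_and_gradient!!(
        ∇ℓ.rule, Base.Fix1(LogDensityProblemsAD.logdensity, ∇ℓ.ℓ), x
    )
    return y, dx
end

end # module
```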

willtebbutt added the "enhancement" [New feature or request] label Apr 29, 2024

yebai commented Apr 29, 2024

Fixed by #123

yebai closed this as completed Apr 29, 2024
willtebbutt reopened this Apr 29, 2024
willtebbutt (Member) commented

This actually isn't quite done -- we still need to specify how ADTypes.AutoTapir maps to an instance of the LogDensityProblemsAD gradient wrapper. This is presently blocked by Turing.jl.
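
For reference, the missing piece is roughly a one-line dispatch; a sketch, assuming an extension defines the `Val(:Tapir)` entry point as above:

```julia
using ADTypes: AutoTapir
import LogDensityProblemsAD

# Route the ADTypes backend marker to the Val-based extension entry point.
LogDensityProblemsAD.ADgradient(::AutoTapir, ℓ) =
    LogDensityProblemsAD.ADgradient(Val(:Tapir), ℓ)
```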


yebai commented Apr 30, 2024

> This is presently blocked by Turing.jl.

This should be fixed in Turing v0.31.3.

willtebbutt (Member) commented

@yebai I think this is now sorted on Tapir v0.2.3.

@yebai @torfjelde is there any reason to add an integration test that runs more of the Turing.jl pipeline than my current integration tests do? E.g. running sampling on a model through the interface we expect users to work with?


yebai commented Apr 30, 2024

I don't think we need to run sampling here, but it would be good to add a test for LogDensityProblemsAD on a Turing model:

https://github.com/withbayes/Tapir.jl/blob/7d00e19d7b097f2c5fd4f6be7d07ae549605e147/test/integration_testing/turing.jl#L78
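
A sketch of what such a test might look like, assuming `DynamicPPL.LogDensityFunction` implements the LogDensityProblems interface (the specific checks are illustrative):

```julia
using Turing, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, Test

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

f = DynamicPPL.LogDensityFunction(demo(1.0))
∇f = LogDensityProblemsAD.ADgradient(AutoTapir(), f)

θ = randn(LogDensityProblems.dimension(f))
y, g = LogDensityProblems.logdensity_and_gradient(∇f, θ)

# The AD-wrapped problem must agree with the plain one on the value,
# and return a gradient of matching length.
@test y ≈ LogDensityProblems.logdensity(f, θ)
@test length(g) == length(θ)
```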


yebai commented May 1, 2024

Here is an example forcing Turing to use SimpleVarInfo through AbstractMCMC:

```julia
using Turing, DynamicPPL, AbstractMCMC, Tapir, ADTypes,
      LogDensityProblems, LogDensityProblemsAD

@model function demo(x)
    m ~ Normal()
    x ~ Normal(m, 1)
end

# Wrap the model so that the log density is evaluated with SimpleVarInfo
# rather than the default VarInfo.
function AbstractMCMC.LogDensityModel(
    m::DynamicPPL.Model, adtype::ADTypes.AbstractADType
)
    f = DynamicPPL.LogDensityFunction(m, DynamicPPL.SimpleVarInfo(m))
    return AbstractMCMC.LogDensityModel(LogDensityProblemsAD.ADgradient(adtype, f))
end

f = AbstractMCMC.LogDensityModel(demo(1.0), AutoTapir());

# Compute the log density and its gradient at a random point.
initial_params = rand(LogDensityProblems.dimension(f.logdensity))
LogDensityProblems.logdensity_and_gradient(f.logdensity, initial_params)

# Sampling using NUTS from AdvancedHMC.
using AdvancedHMC

n_samples, n_adapts, δ = 1_000, 2_000, 0.8
samples = AbstractMCMC.sample(
    f,
    AdvancedHMC.NUTS(δ),
    n_adapts + n_samples;
    nadapts = n_adapts,
    initial_params = initial_params,
);
```

willtebbutt (Member) commented

I'll look at incorporating this in the near future.

willtebbutt added the "enhancement (testing)" [Would improve the tests] label May 1, 2024

gdalle commented May 1, 2024

I've been meaning to do a LogDensityProblemsAD PR that switches everything over to ADTypes + DifferentiationInterface. Maybe this is a sign.
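
For anyone unfamiliar with DI: it gives a single generic code path over every ADTypes backend, e.g. (a sketch; ForwardDiff stands in for any backend):

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff  # the chosen backend package must be loaded

f(x) = sum(abs2, x)

# One generic call path, parameterised only by the ADTypes backend object:
y, g = value_and_gradient(f, AutoForwardDiff(), [1.0, 2.0, 3.0])
```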

willtebbutt (Member) commented

> I've been meaning to do a LogDensityProblemsAD PR that switches everything over to ADTypes + DifferentiationInterface. Maybe this is a sign.

I, for one, would be in favour of having fewer things to maintain.


gdalle commented May 1, 2024

Well, here you go: tpapp/LogDensityProblemsAD.jl#29


gdalle commented May 1, 2024

If either of you can spare the time to review, it might help Tamas, who is not familiar with DI (and me, who is not familiar with Turing and its inner workings ^^)

torfjelde commented

Does Tapir.jl not work with VarInfo?

willtebbutt (Member) commented

I've only really tested it with SimpleVarInfo because I assumed it would be easy for Turing users to opt into SimpleVarInfo. Since it isn't, I really need to get it working with VarInfo.


gdalle commented May 8, 2024

The PR to LogDensityProblemsAD is nearly complete if you wanna help

torfjelde commented

When you're talking about SimpleVarInfo, do you mean specifically the one using a NamedTuple under the hood, @willtebbutt?

> I've only really tested it with SimpleVarInfo because I assumed it would be easy for Turing users to opt into SimpleVarInfo. Since it isn't, I really need to get it working with VarInfo.

Even if we allow switching, SimpleVarInfo using a NamedTuple is more restrictive in what it can handle automatically.

It's also just that SimpleVarInfo is not properly compatible with many models. For example, if I have a model like the following:

```julia
@model function demo()
    x = Vector{Float64}(undef, 2)
    # The left-hand side of ~ can be an arbitrary nested view into x.
    x[1:2][:][1] ~ Normal()
    x[2] ~ Normal()
end
```

reconstructing this in a way that is compatible with SimpleVarInfo using a NamedTuple is very non-trivial, since in Turing.jl we only really see the expressions involved in a ~ statement, not the surrounding code.

torfjelde commented

> If either of you can spare the time to review, it might help Tamas, who is not familiar with DI (and me, who is not familiar with Turing and its inner workings ^^)

Other than the Julia compat issue mentioned by @devmotion, Turing.jl doesn't really need much here, I think :) As in, it doesn't "really matter" to Turing.jl what LogDensityProblemsAD.jl uses under the hood, as long as ADTypes.jl remains the way to specify which backend to use 👍


gdalle commented May 8, 2024

I know, and I fixed the Julia compat issue (with a lot of blood, sweat and tears). My main challenge now is convincing the Turing people that the switch to DI is a good idea, so that it can get merged ^^

And there's a bug on Tracker that I still haven't figured out

torfjelde commented

> I know, and I fixed the Julia compat issue (with a lot of blood, sweat and tears).

Lovely ❤️

> My main challenge now is convincing the Turing people that the switch to DI is a good idea, so that it can get merged ^^

You mean the LogDensityProblems "people"? I think from our side there are no other concerns?

devmotion commented

As @torfjelde said, for Turing it does not matter what LogDensityProblemsAD does under the hood (provided it does not cause any regressions), as long as the API and compatibility are not broken.


gdalle commented May 8, 2024

I did my best to keep them working in that PR, so hopefully this won't be a problem.


gdalle commented May 8, 2024

But come to think of it, it would be nice to have downstream tests. Maybe I can open a Turing PR that uses this new branch of LogDensityProblemsAD, just to see what might break?


yebai commented May 8, 2024

> It's also just that SimpleVarInfo is not properly compatible with many models.

@torfjelde @willtebbutt It would be good to focus efforts on bringing SimpleVarInfo to feature parity with VarInfo, e.g. through TuringLang/DynamicPPL.jl#555, rather than spending more time dealing with VarInfo issues.

torfjelde commented

> It would be good to focus efforts on bringing SimpleVarInfo to feature parity with VarInfo, e.g. through TuringLang/DynamicPPL.jl#555, rather than spending more time dealing with VarInfo issues.

This is probably best discussed in a separate issue. But I don't think Tapir.jl will work with the VarNameVector approach if it doesn't work with the current VarInfo.

What's the actual reason for why Tapir.jl doesn't work with VarInfo?

willtebbutt (Member) commented

> What's the actual reason for why Tapir.jl doesn't work with VarInfo?

The current reason (I think) is Tapir's lack of support for Core._apply_iterate -- but I'll only really know once I've got it working. Either way, I'm definitely pushing towards supporting VarInfo.
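
For context, splatting calls lower to `Core._apply_iterate`, so any code path that splats a runtime-length collection (as VarInfo's internals presumably do) hits it. A quick way to see this, with toy functions:

```julia
using InteractiveUtils: @code_lowered

f(x, y, z) = x + y + z
g(args) = f(args...)  # this splat lowers to Core._apply_iterate(Base.iterate, f, args)

# Both of these display lowered code containing Core._apply_iterate:
Meta.@lower f(args...)
@code_lowered g((1, 2, 3))
```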

> When you're talking about SimpleVarInfo, do you mean specifically the one using a NamedTuple under the hood, @willtebbutt?

I suspect so -- I'm guessing we don't get type stability with the Dict approach?

torfjelde commented

> The current reason (I think) is Tapir's lack of support for Core._apply_iterate -- but I'll only really know once I've got it working. Either way, I'm definitely pushing towards supporting VarInfo.

Ah gotcha 👍 Lovely :)

> I suspect so -- I'm guessing we don't get type stability with the Dict approach?

Exactly. But once I've reworked some of the internals of VarInfo, we'll have something that is as flexible as the current VarInfo but type-stable in a wider range of scenarios than it is currently.
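
To illustrate the trade-off, a sketch using DynamicPPL's two SimpleVarInfo backings (constructors as in current DynamicPPL; the variable `m` is from the earlier demo model):

```julia
using DynamicPPL

# NamedTuple-backed: concretely typed, so downstream code is type-stable,
# but the set of variables is fixed at construction time.
svi_nt = SimpleVarInfo((m = 0.0,))

# Dict-backed: can hold whatever variables the model produces, but values
# sit behind an abstract element type, so type stability is lost.
svi_dict = SimpleVarInfo(Dict{VarName,Any}(@varname(m) => 0.0))
```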

willtebbutt added this to the "A Milestone" milestone May 13, 2024
willtebbutt (Member) commented

I'm closing this in favour of #132 as I believe we've covered all of the ground that is Tapir.jl-specific. Please do re-open if you think I'm missing something!
