Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Discrete MCMC with JAX and Numpyro #29

Draft
wants to merge 1 commit into
base: generative_models
Choose a base branch
from

Conversation

quantshah
Copy link
Collaborator

Added a simple discrete MCMC method that can work with any energy function. A new proposal now flips multiple spins that can vary (from flipping no spins to all of them). Previously we were only flipping one spin and then deciding to accept or reject the proposal. We can try both and see what works the best (the old approach is commented).

This numpyro implementation seems very fast so it should help us speedup training of generative models with CD. We can create a large number of initial states (chains) and sample them in parallel.

Also added some basic tests for a known energy function that just sums up all the spins. The posterior samples for such an energy function is simple, all the spins should be -1, -1, .... so it forms a basic test case.

Added MCMC implementation for arbitrary energy functions and simple
tests to make sure samples with least energies are most frequent.
@quantshah quantshah marked this pull request as draft August 22, 2024 21:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant