From 980279439b8053954fd7b31d20d5152557f18a08 Mon Sep 17 00:00:00 2001 From: Zheng Zhao Date: Wed, 22 May 2024 13:25:37 +0200 Subject: [PATCH] update readme --- .github/workflows/unittest.yml | 2 +- README.md | 8 +++++--- {datasets => experiments/datasets}/celebaHQ/convert.py | 0 {datasets => experiments/datasets}/make_cifar10.py | 2 +- {datasets => experiments/datasets}/make_cifar10.sh | 0 tests/test_gibbs.py | 2 +- 6 files changed, 8 insertions(+), 6 deletions(-) rename {datasets => experiments/datasets}/celebaHQ/convert.py (100%) rename {datasets => experiments/datasets}/make_cifar10.py (92%) rename {datasets => experiments/datasets}/make_cifar10.sh (100%) diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml index 28f8dd42..60319c68 100644 --- a/.github/workflows/unittest.yml +++ b/.github/workflows/unittest.yml @@ -29,7 +29,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ 3.9 ] + python-version: [ '3.10' ] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/README.md b/README.md index 7fda14ed..756b5db6 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,11 @@ This repository is concerned with Markov chain Monte Carlo (MCMC) method for con see, https://arxiv.org/placeholder. More specifically, our core contributions are as follows. -1. We develop a new and efficient particle Gibbs sampler for conditioning diffusion models. +1. We develop a new and efficient particle Gibbs sampler, and a pseudo-marginal sampler for conditioning diffusion models. 2. The proposed method is not only consistent but is also asymptotically exact, even when 1) using a finite number of particles, and 2) no access to the reference distribution. -To quickly see what our method can do while others cannot, please check the two animations below. +To quickly see what our method can do while others cannot, please check the two animations below +(you may wait for seconds for the animations to start). @@ -37,7 +38,8 @@ The scripts in `./experiments` are explained as follows. 4. `./experiments/sb_imgs`. This folder is concerned with the Schrödinger bridge experiments on MNIST super-resolution. 5. `./experiments/toy`. This folder is concerned with the Gaussian synthetic experiments. -You can download the CelebA-HQ dataset as per the instruction in https://github.com/Algolzw/daclip-uir. +You can download the CelebA-HQ dataset as per the instruction in https://github.com/Algolzw/daclip-uir, and the scripts +in `./experiments/datasets`. After you have run all the experiments, results will be saved in their corresponding directories. Then, simply run any file in `./experiments/tabulators` to produce the tables and figures in our paper. diff --git a/datasets/celebaHQ/convert.py b/experiments/datasets/celebaHQ/convert.py similarity index 100% rename from datasets/celebaHQ/convert.py rename to experiments/datasets/celebaHQ/convert.py diff --git a/datasets/make_cifar10.py b/experiments/datasets/make_cifar10.py similarity index 92% rename from datasets/make_cifar10.py rename to experiments/datasets/make_cifar10.py index 22831b8b..80a9380d 100644 --- a/datasets/make_cifar10.py +++ b/experiments/datasets/make_cifar10.py @@ -23,4 +23,4 @@ def load_batch(filename): train_data = train_data.astype('float32') / 255. test_data = test_data.astype('float32') / 255. -np.savez('./cifar10.npz', train_data=train_data, test_data=test_data) +np.savez('cifar10.npz', train_data=train_data, test_data=test_data) diff --git a/datasets/make_cifar10.sh b/experiments/datasets/make_cifar10.sh similarity index 100% rename from datasets/make_cifar10.sh rename to experiments/datasets/make_cifar10.sh diff --git a/tests/test_gibbs.py b/tests/test_gibbs.py index 4758002f..ce924e8c 100644 --- a/tests/test_gibbs.py +++ b/tests/test_gibbs.py @@ -119,5 +119,5 @@ def gibbs_kernel(key_, x0_, y0_, us_star_, bs_star_): x0s = x0s[burnin:] - npt.assert_allclose(jnp.mean(x0s), true_posterior_mean, rtol=1e-2) + npt.assert_allclose(jnp.mean(x0s), true_posterior_mean, rtol=5e-2) npt.assert_allclose(jnp.var(x0s), true_posterior_cov, rtol=2e-2)