Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added publishing workflow #104

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
max-line-length = 120
ignore = W291,W503,W504,E123,E126,E203,E402,E701
per-file-ignores = __init__.py: F401
99 changes: 99 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
name: Publish
on:
release:
types: [published]
branches: [master]

jobs:
build_and_test:
strategy:
matrix:
python-version: [ 3.6, 3.8 ]
os: [ macos-latest, ubuntu-latest, windows-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Check version
shell: bash
run: |
python -m pip install --upgrade pip

python -m pip install torchsde
pypi_info=$(pip list | grep torchsde)
pypi_version=$(echo ${pypi_info} | cut -d " " -f2)
python -m pip uninstall -y torchsde

python setup.py install
master_info=$(pip list | grep torchsde)
master_version=$(echo ${master_info} | cut -d " " -f2)
python -m pip uninstall -y torchsde

python -c "import itertools as it
import sys
_, pypi_version, master_version = sys.argv
pypi_version_ = [int(i) for i in pypi_version.split('.')]
master_version_ = [int(i) for i in master_version.split('.')]
pypi_version__ = tuple(p for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0))
master_version__ = tuple(m for m, p in it.zip_longest(master_version_, pypi_version_, fillvalue=0))
sys.exit(master_version__ <= pypi_version__)" ${pypi_version} ${master_version}

- name: Install dependencies
run: |
python -m pip install flake8 pytest wheel

- name: Lint with flake8
run: |
python -m flake8 .

- name: Build
shell: bash
run: |
python setup.py sdist bdist_wheel
rm -f dist/*.egg

- name: Run sdist tests
shell: bash
run: |
python -m pip install dist/*.tar.gz
python -m pytest
python -m pip uninstall -y torchsde

- name: Run bdist_wheel tests
shell: bash
run: |
python -m pip install dist/*.whl
python -m pytest
python -m pip uninstall -y torchsde

- name: Upload builds
if: matrix.python-version == '3.8' && matrix.os == 'ubuntu-latest'
uses: actions/upload-artifact@v2
with:
name: build-artifact
path: dist/

publish:
needs: [ build_and_test ]
strategy:
matrix:
os: [ ubuntu-latest ]
runs-on: ${{ matrix.os }}
steps:
- name: Download builds
uses: actions/download-artifact@v2
with:
name: build-artifact

- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@v1.4.2
with:
user: ${{ secrets.pypi_username }}
password: ${{ secrets.pypi_password }}
4 changes: 2 additions & 2 deletions examples/cont_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ def plot(imgs, path):
if global_step % pause_every == 0:
logging.warning(f'global_step: {global_step:06d}, loss: {loss:.4f}')

img_path = os.path.join(train_dir, f'ode_samples', f'global_step_{global_step:07d}.png')
img_path = os.path.join(train_dir, 'ode_samples', f'global_step_{global_step:07d}.png')
ode_samples = reverse.ode_sample_final(tau=tau)
plot(ode_samples, img_path)

img_path = os.path.join(train_dir, f'sde_samples', f'global_step_{global_step:07d}.png')
img_path = os.path.join(train_dir, 'sde_samples', f'global_step_{global_step:07d}.png')
sde_samples = reverse.sde_sample_final(tau=tau)
plot(sde_samples, img_path)

Expand Down
1 change: 0 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

3 changes: 2 additions & 1 deletion tests/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def _methods():
yield SDE_TYPES.stratonovich, METHODS.reversible_heun, None


@pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive, problems.NeuralGeneral])
@pytest.mark.parametrize("sde_cls", [problems.ExDiagonal, problems.ExScalar, problems.ExAdditive,
problems.NeuralGeneral])
@pytest.mark.parametrize("sde_type, method, options", _methods())
@pytest.mark.parametrize('adaptive', (False,))
def test_against_numerical(sde_cls, sde_type, method, options, adaptive):
Expand Down
8 changes: 4 additions & 4 deletions torchsde/_brownian/brownian_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ def _check_tensor_info(*tensors, size, dtype, device):
devices += [t.device for t in tensors]

if len(sizes) == 0:
raise ValueError(f"Must either specify `size` or pass in `W` or `H` to implicitly define the size.")
raise ValueError("Must either specify `size` or pass in `W` or `H` to implicitly define the size.")

if not all(i == sizes[0] for i in sizes):
raise ValueError(f"Multiple sizes found. Make sure `size` and `W` or `H` are consistent.")
raise ValueError("Multiple sizes found. Make sure `size` and `W` or `H` are consistent.")
if not all(i == dtypes[0] for i in dtypes):
raise ValueError(f"Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.")
raise ValueError("Multiple dtypes found. Make sure `dtype` and `W` or `H` are consistent.")
if not all(i == devices[0] for i in devices):
raise ValueError(f"Multiple devices found. Make sure `device` and `W` or `H` are consistent.")
raise ValueError("Multiple devices found. Make sure `device` and `W` or `H` are consistent.")

# Make sure size is a tuple (not a torch.Size) for neat repr-printing purposes.
return tuple(sizes[0]), dtypes[0], devices[0]
Expand Down
3 changes: 2 additions & 1 deletion torchsde/_core/adjoint_sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,4 +373,5 @@ def g_prod_and_gdg_prod_diagonal(self, t, y_aug, v1, v2): # For Ito/Stratonovic
create_graph=requires_grad
)
vjp_y_and_params = misc.seq_sub(prod_partials_adj_y_and_params, mixed_partials_adj_y_and_params)
return self._g_prod(g_prod, y, adj_y, requires_grad), misc.flatten((vg_dg_vjp, *vjp_y_and_params)).unsqueeze(0)
return self._g_prod(g_prod, y, adj_y, requires_grad), misc.flatten((vg_dg_vjp,
*vjp_y_and_params)).unsqueeze(0)
6 changes: 3 additions & 3 deletions torchsde/_core/methods/log_ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ class LogODEMidpoint(base_solver.BaseSDESolver):

def __init__(self, sde, **kwargs):
if isinstance(sde, adjoint_sde.AdjointSDE):
raise ValueError(f"Log-ODE schemes cannot be used for adjoint SDEs, because they require "
f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
f"diffusion-vector product. Use a different method instead.")
raise ValueError("Log-ODE schemes cannot be used for adjoint SDEs, because they require "
"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
"diffusion-vector product. Use a different method instead.")
self.strong_order = 0.5 if sde.noise_type == NOISE_TYPES.general else 1.0
super(LogODEMidpoint, self).__init__(sde=sde, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion torchsde/_core/methods/reversible_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from .. import adjoint_sde
from .. import base_solver
from .. import misc
from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHODS, METHOD_OPTIONS
from ...settings import SDE_TYPES, NOISE_TYPES, LEVY_AREA_APPROXIMATIONS, METHODS


class ReversibleHeun(base_solver.BaseSDESolver):
Expand Down
6 changes: 3 additions & 3 deletions torchsde/_core/methods/srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(self, sde, **kwargs):
self.step = self.diagonal_or_scalar_step

if isinstance(sde, adjoint_sde.AdjointSDE):
raise ValueError(f"Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires "
f"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
f"diffusion-vector product. Use a different method instead.")
raise ValueError("Stochastic Runge–Kutta methods cannot be used for adjoint SDEs, because it requires "
"direct access to the diffusion, whilst adjoint SDEs rely on a more efficient "
"diffusion-vector product. Use a different method instead.")

super(SRK, self).__init__(sde=sde, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions torchsde/_core/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def stable_division(a, b, epsilon=1e-7):
def vjp(outputs, inputs, **kwargs):
if torch.is_tensor(inputs):
inputs = [inputs]
_dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784.
_dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. # noqa: 74

if torch.is_tensor(outputs):
outputs = [outputs]
Expand All @@ -85,7 +85,7 @@ def jvp(outputs, inputs, grad_inputs=None, **kwargs):
# Unlike `torch.autograd.functional.jvp`, this function avoids repeating forward computation.
if torch.is_tensor(inputs):
inputs = [inputs]
_dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784.
_dummy_inputs = [torch.as_strided(i, (), ()) for i in inputs] # Workaround for PyTorch bug #39784. # noqa: 88

if torch.is_tensor(outputs):
outputs = [outputs]
Expand Down
10 changes: 5 additions & 5 deletions torchsde/_core/sdeint.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ def check_contract(sde, y0, ts, bm, method, adaptive, options, names, logqp):
sde = base_sde.RenameMethodsSDE(sde, **names_to_change)

if not hasattr(sde, "noise_type"):
raise ValueError(f"sde does not have the attribute noise_type.")
raise ValueError("sde does not have the attribute noise_type.")

if sde.noise_type not in NOISE_TYPES:
raise ValueError(f"Expected noise type in {NOISE_TYPES}, but found {sde.noise_type}.")

if not hasattr(sde, "sde_type"):
raise ValueError(f"sde does not have the attribute sde_type.")
raise ValueError("sde does not have the attribute sde_type.")

if sde.sde_type not in SDE_TYPES:
raise ValueError(f"Expected sde type in {SDE_TYPES}, but found {sde.sde_type}.")
Expand Down Expand Up @@ -160,7 +160,7 @@ def check_contract(sde, y0, ts, bm, method, adaptive, options, names, logqp):

if not torch.is_tensor(ts):
if not isinstance(ts, (tuple, list)) or not all(isinstance(t, (float, int)) for t in ts):
raise ValueError(f"Evaluation times `ts` must be a 1-D Tensor or list/tuple of floats.")
raise ValueError("Evaluation times `ts` must be a 1-D Tensor or list/tuple of floats.")
ts = torch.tensor(ts, dtype=y0.dtype, device=y0.device)
if not misc.is_strictly_increasing(ts):
raise ValueError("Evaluation times `ts` must be strictly increasing.")
Expand Down Expand Up @@ -275,8 +275,8 @@ def _check_2d_or_3d(name, shape):
options = options.copy()

if adaptive and method == METHODS.euler and sde.noise_type != NOISE_TYPES.additive:
warnings.warn(f"Numerical solution is not guaranteed to converge to the correct solution when using adaptive "
f"time-stepping with the Euler--Maruyama method with non-additive noise.")
warnings.warn("Numerical solution is not guaranteed to converge to the correct solution when using adaptive "
"time-stepping with the Euler--Maruyama method with non-additive noise.")

return sde, y0, ts, bm, method, options

Expand Down
2 changes: 1 addition & 1 deletion torchsde/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

# We import from `typing` more than what's enough, so that other modules can import from this file and not `typing`.
from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable
from typing import Sequence, Union, Optional, Any, Dict, Tuple, Callable # noqa: F401

import torch

Expand Down