Feature/moscot not #567

Closed
wants to merge 257 commits into from
257 commits
d11df90
condOT, kwargs combiner & pair sampling missing
AlejandroTL Nov 29, 2022
0e42fbc
fix for data sampler and growth rates
lucaeyring Nov 29, 2022
a054d01
seperated f and g loss fn
lucaeyring Nov 29, 2022
232418b
seperated f and g loss fn
lucaeyring Nov 29, 2022
bc70ba8
final picnn
AlejandroTL Dec 2, 2022
54834a8
merge moscot main
MUCDK Dec 6, 2022
599d417
new import for dualpotential
MUCDK Dec 6, 2022
e24df5f
adapt import statement of dualpotentials
MUCDK Dec 6, 2022
4eb53cf
adapt import statements
MUCDK Dec 6, 2022
ccffc11
fix pre-commit
MUCDK Dec 6, 2022
c95d081
Merge pull request #2 from theislab/basic_not
MUCDK Dec 6, 2022
01d7c1e
Merge branch 'main' into conditional_not
MUCDK Dec 6, 2022
427133f
fix pre-commit
MUCDK Dec 6, 2022
5c43ac8
fix pre-commit
MUCDK Dec 6, 2022
855cb52
reiteration
MUCDK Dec 6, 2022
aecc14b
introduce combiner_kwargs
MUCDK Dec 6, 2022
37240ce
introduce CondOTProblem
MUCDK Dec 7, 2022
c48436b
fix circular import and add problemtype for condOT
MUCDK Dec 7, 2022
9e69655
adapt attributes of CondOTProblem
MUCDK Dec 7, 2022
6b91637
adapt import statements
MUCDK Dec 7, 2022
0347ae2
Fix in _neuraldual.py
lucaeyring Dec 8, 2022
2d02181
adjustments for ott-jax change
lucaeyring Dec 13, 2022
6797978
adjustments for ott-jax change
lucaeyring Dec 13, 2022
05c15f1
adjustment for ott-jax change
lucaeyring Dec 13, 2022
51eb514
[ci skip] add generic neural ot classes
MUCDK Dec 27, 2022
45fcac6
[ci skip] add tests
MUCDK Dec 27, 2022
b3eab0f
adapt ottjax sampler
MUCDK Dec 29, 2022
78289b1
fix condot problem
MUCDK Dec 29, 2022
53f0249
fix condot problem
MUCDK Dec 29, 2022
c569755
adapt tests
MUCDK Dec 29, 2022
082df9f
merge main;
MUCDK Dec 29, 2022
54a3ed9
remove jointProblem
MUCDK Dec 29, 2022
342e618
fix jaxsampler
MUCDK Dec 29, 2022
815c4a3
fix jaxsampler
MUCDK Dec 29, 2022
7f60442
fix jaxsampler
MUCDK Dec 29, 2022
5e43ef5
fix tests
MUCDK Dec 29, 2022
d7f65fc
add plot_convergence
MUCDK Dec 30, 2022
7264441
remove jit from _compute_unbalanced marginals
MUCDK Jan 2, 2023
c427978
fix sinkhorn_divergence
MUCDK Jan 2, 2023
6c5e55a
adapt tox.ini file
MUCDK Jan 4, 2023
0d38985
shape mismatch fixed without precommit
AlejandroTL Jan 11, 2023
85e421b
remove print statement
MUCDK Jan 12, 2023
13abc46
finish merge
MUCDK Jan 12, 2023
688b421
fix moscot main
MUCDK Jan 12, 2023
e4c796d
fix pre commits
MUCDK Jan 12, 2023
20963ae
Merge remote-tracking branch 'moscot/main' into conditional_not_preco…
MUCDK Feb 5, 2023
1727307
add marginal_kwargs to prepare method of TemporalNeuralProblem
MUCDK Feb 5, 2023
8f166f6
added neural mixin and knn tm computation
lucaeyring Feb 13, 2023
129b7b8
clean up and add typing
lucaeyring Feb 14, 2023
ae18265
restructure and fix to neural cell transition
lucaeyring Feb 15, 2023
3a5434b
incorporated comments
lucaeyring Feb 15, 2023
a7fc1f1
removed get_projected_transport_matrix
lucaeyring Feb 16, 2023
f13831f
Merge pull request #8 from theislab/neural_mixin
lucaeyring Feb 16, 2023
3178cc2
fix to projected tm computation
lucaeyring Feb 16, 2023
0a6f7db
fix to scaling in
lucaeyring Mar 3, 2023
92bd10c
Revert "fix to scaling in"
lucaeyring Mar 3, 2023
0f5ca49
fix to scaling argument in marginal_kwargs
lucaeyring Mar 3, 2023
0e0a721
conditional by split_indices & icnn architecture
AlejandroTL Nov 25, 2022
eeb5e78
combiner fixed, still missing kwargs
AlejandroTL Nov 25, 2022
c6f89be
condICNN & clipping function
AlejandroTL Nov 28, 2022
f39845d
condOT, kwargs combiner & pair sampling missing
AlejandroTL Nov 29, 2022
d570491
final picnn
AlejandroTL Dec 2, 2022
cde666a
fix pre-commit
MUCDK Dec 6, 2022
1796115
fix pre-commit
MUCDK Dec 6, 2022
018c118
reiteration
MUCDK Dec 6, 2022
2c7bf11
introduce combiner_kwargs
MUCDK Dec 6, 2022
e3bad0e
introduce CondOTProblem
MUCDK Dec 7, 2022
c77cb66
fix circular import and add problemtype for condOT
MUCDK Dec 7, 2022
acfe011
adapt attributes of CondOTProblem
MUCDK Dec 7, 2022
46ca711
adapt import statements
MUCDK Dec 7, 2022
74e8ec6
[ci skip] add generic neural ot classes
MUCDK Dec 27, 2022
61d0b83
[ci skip] add tests
MUCDK Dec 27, 2022
146ffa9
adapt ottjax sampler
MUCDK Dec 29, 2022
7d7815d
fix condot problem
MUCDK Dec 29, 2022
0c492bc
fix condot problem
MUCDK Dec 29, 2022
d85cbf4
adapt tests
MUCDK Dec 29, 2022
a28c08c
remove jointProblem
MUCDK Dec 29, 2022
f86eb0e
fix jaxsampler
MUCDK Dec 29, 2022
ac3c781
fix jaxsampler
MUCDK Dec 29, 2022
385bcfd
fix jaxsampler
MUCDK Dec 29, 2022
cfebc4d
fix tests
MUCDK Dec 29, 2022
e056042
add plot_convergence
MUCDK Dec 30, 2022
a7af86b
remove jit from _compute_unbalanced marginals
MUCDK Jan 2, 2023
c8cbe9d
fix sinkhorn_divergence
MUCDK Jan 2, 2023
ff145d1
adapt tox.ini file
MUCDK Jan 4, 2023
9e241db
shape mismatch fixed without precommit
AlejandroTL Jan 11, 2023
3cc279f
remove print statement
MUCDK Jan 12, 2023
c10c5b6
finish merge
MUCDK Jan 12, 2023
74adb1e
adapt callbacks and rename tag `cost` to `cost_matrix` (#426)
MUCDK Dec 9, 2022
a282376
Feature/correlation test (#423)
MUCDK Dec 14, 2022
7360819
fix sankey return statement (#428)
MUCDK Dec 14, 2022
e08f63f
Bump version: 0.1.0 → 0.1.1
MUCDK Dec 14, 2022
728537c
fix return statements
MUCDK Dec 14, 2022
f2153e9
add save tests
MUCDK Dec 14, 2022
4f1d26e
fix return type in mpl (#432)
MUCDK Dec 14, 2022
0ffc730
Simplify linear operator (#431)
michalk8 Dec 14, 2022
cee2466
Explicitly jit the solvers (#433)
michalk8 Dec 16, 2022
6b11188
Feature/interpolate colors sankey (#434)
MUCDK Dec 23, 2022
e28816f
Remove `FGWSolver` (#437)
michalk8 Jan 4, 2023
3ff63be
fix bug in SinkhornProblem (#442)
MUCDK Jan 4, 2023
865ad5a
fix pre commits
MUCDK Jan 12, 2023
2571935
make push/pull always use source/target (#443)
MUCDK Jan 5, 2023
3d22677
fix strip plotting in sankey (#445)
MUCDK Jan 16, 2023
b5370e0
Feature/spearman correlation (#444)
MUCDK Jan 17, 2023
34ef904
Delete logo.png
MUCDK Jan 23, 2023
9806854
Feature/plot order (#453)
MUCDK Jan 27, 2023
50dbdca
Expose marginal kwargs for `moscot.temporal` and check for numeric ty…
MUCDK Feb 2, 2023
46e1c65
adapt plot_convergence (#454)
MUCDK Feb 3, 2023
61f5481
Bug/docs generic analysis mixin (#455)
MUCDK Feb 3, 2023
082c879
Docs/improvements (#456)
MUCDK Feb 3, 2023
a53ef22
remove uns_key from set_plotting_vars (#458)
MUCDK Feb 3, 2023
3a9f992
resolve `fig referenced before assignment` (#460)
MUCDK Feb 3, 2023
c323cb5
move generic mixins tests to problems` (#461)
MUCDK Feb 3, 2023
edb3c11
Tests/spatiotemporalproblem (#464)
MUCDK Feb 5, 2023
e3c3911
Feature/move taggedarray (#457)
MUCDK Feb 5, 2023
fdc5d54
add marginal_kwargs to prepare method of TemporalNeuralProblem
MUCDK Feb 5, 2023
859700c
fix to scaling in
lucaeyring Mar 3, 2023
22a48f5
Revert "fix to scaling in"
lucaeyring Mar 3, 2023
dfa2d82
fix to scaling argument in marginal_kwargs
lucaeyring Mar 3, 2023
4bcd966
updated conditional not pipeline
lucaeyring Mar 21, 2023
cb7bc2d
merge into condot branch
lucaeyring Mar 21, 2023
ddb1772
merge into condot branch
lucaeyring Mar 21, 2023
9636f1c
incoporated comments
lucaeyring Mar 23, 2023
f27dab0
incoporated comments
lucaeyring Mar 23, 2023
816b084
incoporated comments
lucaeyring Mar 24, 2023
1eb60f4
removed new_adata for push/pull
lucaeyring Mar 27, 2023
2b6686f
Merge pull request #9 from theislab/condot_revamp
MUCDK Mar 27, 2023
9e053ca
Merge pull request #7 from theislab/conditional_not_precommit
lucaeyring Mar 27, 2023
0d96dbe
[ci skip] start docs
MUCDK Mar 27, 2023
83e6d71
added temporal neural test
lucaeyring Mar 28, 2023
df346d9
[ci skip] continue docs
MUCDK Mar 28, 2023
ec0ab31
continue docs
MUCDK Mar 29, 2023
fbd4181
continue docs
MUCDK Mar 29, 2023
10799ec
change validation epsilon
MUCDK Mar 29, 2023
7f94ef9
fixed error when not computing wasserstein baseline
lucaeyring Mar 29, 2023
1421c65
fixed error when not computing wasserstein baseline
lucaeyring Mar 29, 2023
b828752
Merge pull request #11 from theislab/feature/docs
lucaeyring Mar 29, 2023
1b584ae
Merge branch 'main' into temporal_neural_test
lucaeyring Mar 29, 2023
d37c7d3
Merge pull request #12 from theislab/temporal_neural_test
lucaeyring Mar 29, 2023
dd106af
correct typo
MUCDK Mar 31, 2023
453922c
fix bug
MUCDK Mar 31, 2023
298f8fb
added neural tests
lucaeyring Mar 31, 2023
491a9a4
[ci skip] draft CondNeuralOutput
MUCDK Apr 8, 2023
e6b1f9a
include CondDualPotentials and CondDualSolver
MUCDK Apr 8, 2023
77fea48
merge moscot main restructuring
lucaeyring May 1, 2023
b1316bc
fixes to main merge
lucaeyring May 1, 2023
1ff844f
fix typo
MUCDK May 8, 2023
29ff674
fix test_cell_transition_subset_pipeline
MUCDK May 8, 2023
518a8e5
fix tests
MUCDK May 8, 2023
2a5bcd7
update conditionalDualPotentials
MUCDK May 25, 2023
6594d43
update conditionalDualPotentials
MUCDK May 25, 2023
2c43b07
fix most pre-commit hooks and fix tests
MUCDK May 25, 2023
594e2f9
fix pandas version to <2.0
MUCDK May 25, 2023
ebc7877
fix tests for non-conditional solvers
MUCDK May 25, 2023
df0b6a8
merge continue_docs
MUCDK May 25, 2023
db7b2c0
continue
MUCDK May 25, 2023
3220f9e
fix
MUCDK May 25, 2023
4fb6e51
continue fixing
MUCDK May 25, 2023
fdf2882
fix ICNN setup
MUCDK May 29, 2023
35db509
fix tests
MUCDK May 29, 2023
65850ae
Merge pull request #18 from theislab/neural_tests_local
MUCDK May 29, 2023
b7d8a00
swap role of f and g, such that push/pull is correct again
MUCDK May 30, 2023
ea7187e
[ci skip] restructure to include more general neural solvers
MUCDK Jun 1, 2023
0c9a587
[ci skip] restructure ICNNs to allow passing instances of ICNN
MUCDK Jun 1, 2023
d3a529d
adapt tests
MUCDK Jun 1, 2023
aba225d
Filled in Monge Gap structure
gocato Jun 10, 2023
d5005b1
Added Monge Gap paper to documentation
gocato Jun 10, 2023
2e0fb4a
Ammend PointCloud Import
gocato Jun 10, 2023
f328056
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato Jun 10, 2023
4692dc5
Update _utils.py
gocato Jun 10, 2023
36b1953
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato Jun 10, 2023
66e5794
Solve compatibility issue with ProblemKind
gocato Jun 10, 2023
2ad39a8
Solve missing Import
gocato Jun 10, 2023
16a797d
Fix call to deprecated function
gocato Jun 10, 2023
7d63082
Fix style and comment issues
gocato Jun 12, 2023
9aff229
add callback, swap f & g
lucaeyring Jun 21, 2023
8f91e4a
add callback, swap f & g
lucaeyring Jun 21, 2023
2d63db7
add callback, swap f & g
lucaeyring Jun 21, 2023
5452485
Merge pull request #26 from theislab/feature/add_callback
MUCDK Jun 21, 2023
6f413a2
Merge pull request #27 from theislab/dev
MUCDK Jun 21, 2023
43b873d
intermediate save
MUCDK Sep 1, 2023
602b8a7
intermediate save
MUCDK Sep 1, 2023
b85c8e0
intermediate save
MUCDK Sep 5, 2023
de6facc
intermediate save
MUCDK Sep 5, 2023
a32038c
partially resolve precommit errors
MUCDK Oct 10, 2023
fa9d6cf
[ci skip] fix merge conflicts
MUCDK Oct 10, 2023
501b018
resolve conflict
MUCDK Oct 10, 2023
0419670
remove pairwise policy
MUCDK Oct 10, 2023
ce7667a
add neural dependencies
MUCDK Oct 10, 2023
f97b08c
add neural dependencies
MUCDK Oct 10, 2023
243c73a
add flax
MUCDK Oct 10, 2023
4516612
fix _call_kwargs
MUCDK Oct 11, 2023
68e7bf3
fix marginal kwargs
MUCDK Oct 11, 2023
536f681
remove monge gap solver
MUCDK Oct 11, 2023
3185371
clean condneuralsolver
MUCDK Oct 11, 2023
582ad43
[ci skip] introduce new data container for joint neural problems
MUCDK Oct 12, 2023
7e3d4f7
add conditions in distirbutioncontainer
MUCDK Oct 17, 2023
51915db
resolve unfreeze/freeze
MUCDK Oct 17, 2023
47508b3
enable pretraining and weight clipping
MUCDK Oct 17, 2023
5363d98
make dicts compatible with older python versions
MUCDK Oct 17, 2023
0dd6b8b
resolve precommit errors partially
MUCDK Oct 17, 2023
21f3309
resolve precommit errors partially
MUCDK Oct 17, 2023
3bbb4da
adapt tests
MUCDK Oct 19, 2023
2cced99
[ci skip] draft unbalancedNeuralMixin
MUCDK Oct 21, 2023
aa49c10
[ci skip] fix naming of posterior marginals
MUCDK Oct 21, 2023
16d3204
[ci skip] add MLP_marginals
MUCDK Oct 23, 2023
c61633c
adapt neural output to incorporate learnt rescaling functions
MUCDK Oct 24, 2023
2dc6ff3
fix _solve in neuraldualsolver
MUCDK Oct 25, 2023
badd57b
incorporate feedback
MUCDK Oct 25, 2023
bd983fd
fix distributioncollection class
MUCDK Oct 25, 2023
b00e2b7
unify _split_data
MUCDK Oct 25, 2023
79f050e
fix tests
MUCDK Oct 25, 2023
5c19f1f
fix some precommit hooks
MUCDK Oct 25, 2023
4e964e5
make neural dependencies optional
MUCDK Oct 25, 2023
517db5f
make neural dependencies optional
MUCDK Oct 25, 2023
742f02c
delete old files
MUCDK Oct 25, 2023
d763de4
adapt pyproject.toml
MUCDK Oct 25, 2023
2c64e3b
adapt pyproject.toml
MUCDK Oct 25, 2023
73c7830
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2023
6d40537
[ci skip] adjust _format_params
MUCDK Oct 25, 2023
9483e69
adapt neuraldualsolver to be more similar to ott-jax
MUCDK Oct 26, 2023
32300e5
adapt neuraldualsolver
MUCDK Oct 26, 2023
88b9a59
TODO: make JaxSampler return conditions
MUCDK Oct 26, 2023
dd74215
add basic neural test
MUCDK Oct 27, 2023
dce2b21
[ci skip] intermediate save
MUCDK Oct 30, 2023
d2ece68
adapt neuraldualsolver and finish tests for neural backend
MUCDK Oct 30, 2023
3bd674f
[ci skip] TODO: re-iterate on initialisation of neural solver
MUCDK Oct 30, 2023
61e3a01
adapt distributioncontainer
MUCDK Nov 3, 2023
b949d6f
fix dict bug
MUCDK Nov 3, 2023
e90934c
resolve passing of arguments in solver call methods
MUCDK Nov 3, 2023
295eb6e
[ci skip] adapt `solve` in `CondOTProblem`
MUCDK Nov 3, 2023
08fcec2
adapt tests and valid loader conditions
MUCDK Nov 5, 2023
868146a
adapt neural backend tests
MUCDK Nov 5, 2023
f84bd31
fix mypy errors
MUCDK Nov 5, 2023
e74e0a4
Merge branch 'main' into feature/moscot_not
MUCDK Nov 5, 2023
9482ccc
make basesolveroutput to basediscretesolveroutput
MUCDK Nov 5, 2023
182bab7
move `to` to BaseSolverOutput`
MUCDK Nov 5, 2023
166bd49
adapt transport_matrix docs
MUCDK Nov 5, 2023
fea3ac9
adapt transport_matrix docs
MUCDK Nov 5, 2023
7669f7b
adapt tests
MUCDK Nov 5, 2023
e53f9c0
adapt tests
MUCDK Nov 5, 2023
cc6bfb5
update unbalancedness mixin
MUCDK Nov 8, 2023
a438629
use implementation from moscot
MUCDK Nov 9, 2023
3f422bf
uncomment unused code
MUCDK Nov 9, 2023
3e9dfc5
before passing states to loss-fn
MUCDK Nov 10, 2023
f4b7c76
intermediate save
MUCDK Nov 10, 2023
4152baf
adapt neuraldualsolver
MUCDK Nov 10, 2023
06e55c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2023
0b46bae
resolve some / not all pre commit errors
MUCDK Nov 10, 2023
4566eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2023
2 changes: 2 additions & 0 deletions docs/developer.rst
@@ -62,13 +62,15 @@ Solvers
solver.BaseSolver
solver.OTSolver
output.BaseSolverOutput
output.BaseDiscreteSolverOutput

Output
^^^^^^
.. autosummary::
:toctree: genapi

output.BaseSolverOutput
output.BaseDiscreteSolverOutput
output.MatrixSolverOutput

Utils
34 changes: 34 additions & 0 deletions docs/references.bib
@@ -440,3 +440,37 @@ @article{klein:23
title = {Mapping cells through time and space with moscot},
year = {2023},
}

@article{eyring2022modeling,
title={Modeling Single-Cell Dynamics Using Unbalanced Parameterized Monge Maps},
author={Eyring, Luca Vincent and Klein, Dominik and Palla, Giovanni and Becker, Soeren and Weiler, Philipp and Kilbertus, Niki and Theis, Fabian},
journal={bioRxiv},
pages={2022--10},
year={2022},
publisher={Cold Spring Harbor Laboratory}
}

@inproceedings{amos2017input,
title={Input convex neural networks},
author={Amos, Brandon and Xu, Lei and Kolter, J Zico},
booktitle={International Conference on Machine Learning},
pages={146--155},
year={2017},
organization={PMLR}
}

@article{bunne2022supervised,
title={Supervised training of conditional monge maps},
author={Bunne, Charlotte and Krause, Andreas and Cuturi, Marco},
journal={arXiv preprint arXiv:2206.14262},
year={2022}
}

@misc{uscidda2023monge,
title={The Monge Gap: A Regularizer to Learn All Transport Maps},
author={Th\'eo Uscidda and Marco Cuturi},
year={2023},
eprint={2302.04953},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
8 changes: 8 additions & 0 deletions pyproject.toml
@@ -58,12 +58,20 @@ dependencies = [
"ott-jax>=0.4.4",
"cloudpickle>=2.2.0",
"rich>=13.5",
"optax>=0.1.7",
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michalk8 how do I make these optional, i.e. a neural optional dependency as below, but still make github CI install these optional dependencies?

"flax>=0.7.1",
]

[project.optional-dependencies]
neural = [
"optax>=0.1.7",
"flax>=0.7.1",
]

spatial = [
"squidpy>=1.2.3"
]

dev = [
"pre-commit>=3.0.0",
"tox>=4",
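
Regarding the question above about keeping `optax` and `flax` optional: one common pattern is to leave them only in the `neural` extra and guard the imports at the point of use, so that a missing package produces an actionable error. The snippet below is a minimal sketch of that pattern, not what this PR implements; `require_optional` is a hypothetical helper name, and the `moscot[neural]` install hint assumes the extra keeps the name used in the diff above. CI can then opt in by installing the extra explicitly, for example with `pip install -e ".[neural]"`.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, extra: str) -> ModuleType:
    """Import `module_name`, or raise an error naming the extra that provides it.

    Hypothetical helper for illustration; it is not part of moscot.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        # Re-raise with an install hint so users know which extra is missing.
        raise ImportError(
            f"`{module_name}` is required for this feature. "
            f"Install it with `pip install 'moscot[{extra}]'`."
        ) from err


# Usage: resolve the dependency lazily, only where the neural code path needs it.
json_module = require_optional("json", extra="neural")  # stdlib module, always present
```

Importing lazily like this keeps `import moscot` working for users who only need the discrete solvers.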
32 changes: 29 additions & 3 deletions src/moscot/backends/ott/__init__.py
@@ -1,11 +1,37 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import OTTOutput
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
from moscot.backends.ott.nets import ICNN, MLP_marginal
from moscot.backends.ott.output import (
ConditionalDualPotentials,
CondNeuralDualOutput,
NeuralDualOutput,
OTTOutput,
)
from moscot.backends.ott.solver import (
CondNeuralDualSolver,
GWSolver,
NeuralDualSolver,
OTTNeuralDualSolver,
SinkhornSolver,
)
from moscot.costs import register_cost

__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"]
__all__ = [
"OTTOutput",
"NeuralDualOutput",
"GWSolver",
"SinkhornSolver",
"NeuralDualSolver",
"OTTNeuralDualSolver",
"CondNeuralDualSolver",
"ConditionalDualPotentials",
"CondNeuralDualOutput",
"sinkhorn_divergence",
"ICNN",
"MLP_marginal",
]


register_cost("euclidean", backend="ott")(costs.Euclidean)
register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean)
161 changes: 161 additions & 0 deletions src/moscot/backends/ott/_jax_data.py
@@ -0,0 +1,161 @@
from functools import partial
from types import MappingProxyType
from typing import Any, Dict, Hashable, List, Literal, Optional, Tuple, TypeVar, Union

import jax
import jax.numpy as jnp
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

K = TypeVar("K", bound=Hashable)


class JaxSampler:
    """Data sampler for Jax."""

    def __init__(
        self,
        distributions: List[jnp.ndarray],
        policy_pairs: List[Tuple[Any, Any]],
        conditions: Optional[List[jnp.ndarray]],
        a: List[jnp.ndarray] = None,
        b: List[jnp.ndarray] = None,
        sample_to_idx: Dict[K, Any] = MappingProxyType({}),
        batch_size: int = 1024,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        epsilon: float = 0.1,
    ):
        """Initialize data sampler."""
        if not len(distributions) == len(a) == len(b):
            raise ValueError("Number of distributions, a, and b must be equal.")
        self._distributions = distributions
        self._conditions = conditions
        self._batch_size = batch_size
        self._policy_pairs = policy_pairs
        if not len(sample_to_idx):
            if len(self.policy_pairs) > 1:
                raise ValueError("If `policy_pairs` contains more than 1 value, `sample_to_idx` is required.")
            sample_to_idx = {self.policy_pairs[0][0]: 0, self.policy_pairs[0][1]: 1}
        self._sample_to_idx = sample_to_idx

        # `index` is a static argument, so it must be a hashable Python value (an int), not an array.
        @partial(jax.jit, static_argnames=["index"])
        def _sample_source(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, None]:
            """Jitted sample function."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(a[index]))
            return jnp.asarray(samples), None

        @partial(jax.jit, static_argnames=["index"])
        def _sample_source_conditional(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted sample function."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(a[index]))
            # Same key, weights and shape as above, so the drawn indices match those of `samples`
            # (assuming one condition row per sample).
            conds = jax.random.choice(key, self.conditions[index], shape=[batch_size], p=jnp.squeeze(a[index]))  # type: ignore[index] # noqa: E501
            return samples, conds

        @partial(jax.jit, static_argnames=["index"])
        def _sample_target(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, None]:
            """Jitted sample function."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(b[index]))
            return samples, None

        @partial(jax.jit, static_argnames=["index"])
        def _sample_target_conditional(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted sample function."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(b[index]))
            conds = jax.random.choice(key, self.conditions[index], shape=[batch_size], p=jnp.squeeze(b[index]))  # type: ignore[index] # noqa: E501
            return samples, conds

        @jax.jit
        def _compute_unbalanced_marginals(
            batch_source: jnp.ndarray,
            batch_target: jnp.ndarray,
            sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}),
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted function to compute the source and target marginals for a batch."""
            geom = PointCloud(batch_source, batch_target, epsilon=epsilon, scale_cost="mean")
            out = sinkhorn.Sinkhorn(**sinkhorn_kwargs)(linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b))
            return out.matrix.sum(axis=1), out.matrix.sum(axis=0)

        @jax.jit
        def _unbalanced_resample(
            key: jax.random.KeyArray,
            batch: Tuple[jnp.ndarray, ...],
            marginals: jnp.ndarray,
        ) -> Tuple[jnp.ndarray, ...]:
            """Resample a batch based upon marginals."""
            # sample indices with probability proportional to the marginals
            indices = jax.random.choice(key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)])
            return tuple(b[indices] if b is not None else None for b in batch)

        def _sample_policy_pair(key: jax.random.KeyArray) -> Tuple[Any, Any]:
            """Sample a policy pair."""
            index = jax.random.randint(key, shape=[], minval=0, maxval=len(self.policy_pairs))
            return self.policy_pairs[index]

        self._sample_source = _sample_source if self.conditions is None else _sample_source_conditional
        self._sample_target = _sample_target if self.conditions is None else _sample_target_conditional
        self.sample_policy_pair = _sample_policy_pair
        self.compute_unbalanced_marginals = _compute_unbalanced_marginals
        self.unbalanced_resample = _unbalanced_resample

    def __call__(
        self,
        key: jax.random.KeyArray,
        policy_pair: Tuple[Any, Any],
        sample: Literal["source", "target", "both"] = "both",
        full_dataset: bool = False,
    ) -> Union[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
        """Sample data. When sampling the source, the conditions are returned, too."""
        if full_dataset:
            if sample == "source":
                return jnp.asarray(
                    self.distributions[self.sample_to_idx[policy_pair[0]]]
                ), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
            if sample == "target":
                return jnp.asarray(
                    self.distributions[self.sample_to_idx[policy_pair[1]]]
                ), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
            if sample == "both":
                return (
                    jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[0]]]),
                    None
                    if self.conditions is None
                    else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]]),
                    jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[1]]]),
                )
            raise NotImplementedError(f"Sample type {sample} not implemented.")
        if sample == "source":
            return self._sample_source(key, self.sample_to_idx[policy_pair[0]])
        if sample == "target":
            return self._sample_target(key, self.sample_to_idx[policy_pair[1]])
        if sample == "both":
            return (
                self._sample_source(key, self.sample_to_idx[policy_pair[0]])[0],
                *self._sample_target(key, self.sample_to_idx[policy_pair[1]]),
            )
        raise NotImplementedError(f"Sample type {sample} not implemented.")

    @property
    def distributions(self) -> List[jnp.ndarray]:
        """Return distributions."""
        return self._distributions

    @property
    def policy_pairs(self) -> List[Tuple[Any, Any]]:
        """Return policy pairs."""
        return self._policy_pairs

    @property
    def conditions(self) -> Optional[List[jnp.ndarray]]:
        """Return conditions."""
        return self._conditions

    @property
    def sample_to_idx(self) -> Dict[K, Any]:
        """Return sample to idx."""
        return self._sample_to_idx

    @property
    def batch_size(self) -> int:
        """Return batch size."""
        return self._batch_size
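
For readers who want to follow the sampler's data flow without installing JAX or OTT, the following pure-Python sketch mirrors its three steps: weighted batch sampling, computing marginals of an entropic (optionally unbalanced) plan, and resampling by those marginals. It is an illustration under stated assumptions, not the PR's implementation: `random.choices` stands in for `jax.random.choice`, a hand-written Sinkhorn scaling loop stands in for `ott`'s solver, marginals are assumed uniform, and the function names `sample_batch`, `unbalanced_marginals` and `resample_by_marginals` are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def sample_batch(
    rng: random.Random, points: Sequence[Point], weights: Sequence[float], batch_size: int
) -> List[Point]:
    """Draw `batch_size` points with replacement, proportionally to `weights`."""
    return rng.choices(list(points), weights=list(weights), k=batch_size)


def unbalanced_marginals(
    source: Sequence[Point],
    target: Sequence[Point],
    epsilon: float = 0.1,
    tau_a: float = 1.0,
    tau_b: float = 1.0,
    n_iters: int = 200,
) -> Tuple[List[float], List[float]]:
    """Return row and column sums of an entropic transport plan between two batches.

    Assumes uniform input marginals and a squared Euclidean cost scaled by its mean.
    `tau_a = tau_b = 1.0` gives balanced Sinkhorn; values below 1 relax the marginals.
    """
    n, m = len(source), len(target)
    cost = [[sum((x - y) ** 2 for x, y in zip(s, t)) for t in target] for s in source]
    mean_cost = sum(sum(row) for row in cost) / (n * m) or 1.0  # avoid dividing by zero
    kernel = [[math.exp(-c / (mean_cost * epsilon)) for c in row] for row in cost]
    a, b = 1.0 / n, 1.0 / m
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Scaling updates; the exponent tau softens the marginal constraint.
        u = [(a / sum(kernel[i][j] * v[j] for j in range(m))) ** tau_a for i in range(n)]
        v = [(b / sum(kernel[i][j] * u[i] for i in range(n))) ** tau_b for j in range(m)]
    plan = [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]
    rows = [sum(row) for row in plan]
    cols = [sum(plan[i][j] for i in range(n)) for j in range(m)]
    return rows, cols


def resample_by_marginals(rng: random.Random, batch: Sequence, marginals: Sequence[float]) -> List:
    """Resample `batch` with replacement, proportionally to `marginals`."""
    indices = rng.choices(range(len(marginals)), weights=list(marginals), k=len(marginals))
    return [batch[i] for i in indices]


# Usage: the last source point is an outlier relative to the target batch.
rng = random.Random(0)
source = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (5.0, 5.0)]
target = [(0.1, 0.1), (0.9, 0.1), (0.1, 0.9)]
batch_source = sample_batch(rng, source, weights=[1, 1, 1, 1], batch_size=4)
rows, cols = unbalanced_marginals(source, target, tau_a=0.9, tau_b=0.9)
resampled_source = resample_by_marginals(rng, source, rows)
```

With `tau_a` and `tau_b` below 1, points that are expensive to transport tend to receive less mass in `rows`, so resampling by the marginals down-weights them in the batch.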