Feature/moscot not #567
Closed

Changes from all commits · 257 commits
d11df90
condOT, kwargs combiner & pair sampling missing
AlejandroTL 0e42fbc
fix for data sampler and growth rates
lucaeyring a054d01
seperated f and g loss fn
lucaeyring 232418b
seperated f and g loss fn
lucaeyring bc70ba8
final picnn
AlejandroTL 54834a8
merge moscot main
MUCDK 599d417
new import for dualpotential
MUCDK e24df5f
adapt import statement of dualpotentials
MUCDK 4eb53cf
adapt import statements
MUCDK ccffc11
fix pre-commit
MUCDK c95d081
Merge pull request #2 from theislab/basic_not
MUCDK 01d7c1e
Merge branch 'main' into conditional_not
MUCDK 427133f
fix pre-commit
MUCDK 5c43ac8
fix pre-commit
MUCDK 855cb52
reiteration
MUCDK aecc14b
introduce combiner_kwargs
MUCDK 37240ce
introduce CondOTProblem
MUCDK c48436b
fix circular import and add problemtype for condOT
MUCDK 9e69655
adapt attributes of CondOTProblem
MUCDK 6b91637
adapt import statements
MUCDK 0347ae2
Fix in _neuraldual.py
lucaeyring 2d02181
adjustments for ott-jax change
lucaeyring 6797978
adjustments for ott-jax change
lucaeyring 05c15f1
adjustment for ott-jax change
lucaeyring 51eb514
[ci skip] add generic neural ot classes
MUCDK 45fcac6
[ci skip] add tests
MUCDK b3eab0f
adapt ottjax sampler
MUCDK 78289b1
fix condot problem
MUCDK 53f0249
fix condot problem
MUCDK c569755
adapt tests
MUCDK 082df9f
merge main;
MUCDK 54a3ed9
remove jointProblem
MUCDK 342e618
fix jaxsampler
MUCDK 815c4a3
fix jaxsampler
MUCDK 7f60442
fix jaxsampler
MUCDK 5e43ef5
fix tests
MUCDK d7f65fc
add plot_convergence
MUCDK 7264441
remove jit from _compute_unbalanced marginals
MUCDK c427978
fix sinkhorn_divergence
MUCDK 6c5e55a
adapt tox.ini file
MUCDK 0d38985
shape mismatch fixed without precommit
AlejandroTL 85e421b
remove print statement
MUCDK 13abc46
finish merge
MUCDK 688b421
fix moscot main
MUCDK e4c796d
fix pre commits
MUCDK 20963ae
Merge remote-tracking branch 'moscot/main' into conditional_not_preco…
MUCDK 1727307
add marginal_kwargs to prepare method of TemporalNeuralProblem
MUCDK 8f166f6
added neural mixin and knn tm computation
lucaeyring 129b7b8
clean up and add typing
lucaeyring ae18265
restructure and fix to neural cell transition
lucaeyring 3a5434b
incorporated comments
lucaeyring a7fc1f1
removed get_projected_transport_matrix
lucaeyring f13831f
Merge pull request #8 from theislab/neural_mixin
lucaeyring 3178cc2
fix to projected tm computation
lucaeyring 0a6f7db
fix to scaling in
lucaeyring 92bd10c
Revert "fix to scaling in"
lucaeyring 0f5ca49
fix to scaling argument in marginal_kwargs
lucaeyring 0e0a721
conditional by split_indices & icnn architecture
AlejandroTL eeb5e78
combiner fixed, still missing kwargs
AlejandroTL c6f89be
condICNN & clipping function
AlejandroTL f39845d
condOT, kwargs combiner & pair sampling missing
AlejandroTL d570491
final picnn
AlejandroTL cde666a
fix pre-commit
MUCDK 1796115
fix pre-commit
MUCDK 018c118
reiteration
MUCDK 2c7bf11
introduce combiner_kwargs
MUCDK e3bad0e
introduce CondOTProblem
MUCDK c77cb66
fix circular import and add problemtype for condOT
MUCDK acfe011
adapt attributes of CondOTProblem
MUCDK 46ca711
adapt import statements
MUCDK 74e8ec6
[ci skip] add generic neural ot classes
MUCDK 61d0b83
[ci skip] add tests
MUCDK 146ffa9
adapt ottjax sampler
MUCDK 7d7815d
fix condot problem
MUCDK 0c492bc
fix condot problem
MUCDK d85cbf4
adapt tests
MUCDK a28c08c
remove jointProblem
MUCDK f86eb0e
fix jaxsampler
MUCDK ac3c781
fix jaxsampler
MUCDK 385bcfd
fix jaxsampler
MUCDK cfebc4d
fix tests
MUCDK e056042
add plot_convergence
MUCDK a7af86b
remove jit from _compute_unbalanced marginals
MUCDK c8cbe9d
fix sinkhorn_divergence
MUCDK ff145d1
adapt tox.ini file
MUCDK 9e241db
shape mismatch fixed without precommit
AlejandroTL 3cc279f
remove print statement
MUCDK c10c5b6
finish merge
MUCDK 74adb1e
adapt callbacks and rename tag `cost` to `cost_matrix` (#426)
MUCDK a282376
Feature/correlation test (#423)
MUCDK 7360819
fix sankey return statement (#428)
MUCDK e08f63f
Bump version: 0.1.0 → 0.1.1
MUCDK 728537c
fix return statements
MUCDK f2153e9
add save tests
MUCDK 4f1d26e
fix return type in mpl (#432)
MUCDK 0ffc730
Simplify linear operator (#431)
michalk8 cee2466
Explicitly jit the solvers (#433)
michalk8 6b11188
Feature/interpolate colors sankey (#434)
MUCDK e28816f
Remove `FGWSolver` (#437)
michalk8 3ff63be
fix bug in SinkhornProblem (#442)
MUCDK 865ad5a
fix pre commits
MUCDK 2571935
make push/pull always use source/target (#443)
MUCDK 3d22677
fix strip plotting in sankey (#445)
MUCDK b5370e0
Feature/spearman correlation (#444)
MUCDK 34ef904
Delete logo.png
MUCDK 9806854
Feature/plot order (#453)
MUCDK 50dbdca
Expose marginal kwargs for `moscot.temporal` and check for numeric ty…
MUCDK 46e1c65
adapt plot_convergence (#454)
MUCDK 61f5481
Bug/docs generic analysis mixin (#455)
MUCDK 082c879
Docs/improvements (#456)
MUCDK a53ef22
remove uns_key from set_plotting_vars (#458)
MUCDK 3a9f992
resolve `fig referenced before assignment` (#460)
MUCDK c323cb5
move generic mixins tests to problems` (#461)
MUCDK edb3c11
Tests/spatiotemporalproblem (#464)
MUCDK e3c3911
Feature/move taggedarray (#457)
MUCDK fdc5d54
add marginal_kwargs to prepare method of TemporalNeuralProblem
MUCDK 859700c
fix to scaling in
lucaeyring 22a48f5
Revert "fix to scaling in"
lucaeyring dfa2d82
fix to scaling argument in marginal_kwargs
lucaeyring 4bcd966
updated conditional not pipeline
lucaeyring cb7bc2d
merge into condot branch
lucaeyring ddb1772
merge into condot branch
lucaeyring 9636f1c
incoporated comments
lucaeyring f27dab0
incoporated comments
lucaeyring 816b084
incoporated comments
lucaeyring 1eb60f4
removed new_adata for push/pull
lucaeyring 2b6686f
Merge pull request #9 from theislab/condot_revamp
MUCDK 9e053ca
Merge pull request #7 from theislab/conditional_not_precommit
lucaeyring 0d96dbe
[ci skip] start docs
MUCDK 83e6d71
added temporal neural test
lucaeyring df346d9
[ci skip] continue docs
MUCDK ec0ab31
continue docs
MUCDK fbd4181
continue docs
MUCDK 10799ec
change validation epsilon
MUCDK 7f94ef9
fixed error when not computing wasserstein baseline
lucaeyring 1421c65
fixed error when not computing wasserstein baseline
lucaeyring b828752
Merge pull request #11 from theislab/feature/docs
lucaeyring 1b584ae
Merge branch 'main' into temporal_neural_test
lucaeyring d37c7d3
Merge pull request #12 from theislab/temporal_neural_test
lucaeyring dd106af
correct typo
MUCDK 453922c
fix bug
MUCDK 298f8fb
added neural tests
lucaeyring 491a9a4
[ci skip] draft CondNeuralOutput
MUCDK e6b1f9a
include CondDualPotentials and CondDualSolver
MUCDK 77fea48
merge moscot main restructuring
lucaeyring b1316bc
fixes to main merge
lucaeyring 1ff844f
fix typo
MUCDK 29ff674
fix test_cell_transition_subset_pipeline
MUCDK 518a8e5
fix tests
MUCDK 2a5bcd7
update conditionalDualPotentials
MUCDK 6594d43
update conditionalDualPotentials
MUCDK 2c43b07
fix most pre-commit hooks and fix tests
MUCDK 594e2f9
fix pandas version to <2.0
MUCDK ebc7877
fix tests for non-conditional solvers
MUCDK df0b6a8
merge continue_docs
MUCDK db7b2c0
continue
MUCDK 3220f9e
fix
MUCDK 4fb6e51
continue fixing
MUCDK fdf2882
fix ICNN setup
MUCDK 35db509
fix tests
MUCDK 65850ae
Merge pull request #18 from theislab/neural_tests_local
MUCDK b7d8a00
swap role of f and g, such that push/pull is correct again
MUCDK ea7187e
[ci skip] restructure to include more general neural solvers
MUCDK 0c9a587
[ci skip] restructure ICNNs to allow passing instances of ICNN
MUCDK d3a529d
adapt tests
MUCDK aba225d
Filled in Monge Gap structure
gocato d5005b1
Added Monge Gap paper to documentation
gocato 2e0fb4a
Ammend PointCloud Import
gocato f328056
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato 4692dc5
Update _utils.py
gocato 36b1953
Merge remote-tracking branch 'origin/feature/add_monge_gap' into feat…
gocato 66e5794
Solve compatibility issue with ProblemKind
gocato 2ad39a8
Solve missing Import
gocato 16a797d
Fix call to deprecated function
gocato 7d63082
Fix style and comment issues
gocato 9aff229
add callback, swap f & g
lucaeyring 8f91e4a
add callback, swap f & g
lucaeyring 2d63db7
add callback, swap f & g
lucaeyring 5452485
Merge pull request #26 from theislab/feature/add_callback
MUCDK 6f413a2
Merge pull request #27 from theislab/dev
MUCDK 43b873d
intermediate save
MUCDK 602b8a7
intermediate save
MUCDK b85c8e0
intermediate save
MUCDK de6facc
intermediate save
MUCDK a32038c
partially resolve precommit errors
MUCDK fa9d6cf
[ci skip] fix merge conflicts
MUCDK 501b018
resolve conflict
MUCDK 0419670
remove pairwise policy
MUCDK ce7667a
add neural dependencies
MUCDK f97b08c
add neural dependencies
MUCDK 243c73a
add flax
MUCDK 4516612
fix _call_kwargs
MUCDK 68e7bf3
fix marginal kwargs
MUCDK 536f681
remove monge gap solver
MUCDK 3185371
clean condneuralsolver
MUCDK 582ad43
[ci skip] introduce new data container for joint neural problems
MUCDK 7e3d4f7
add conditions in distirbutioncontainer
MUCDK 51915db
resolve unfreeze/freeze
MUCDK 47508b3
enable pretraining and weight clipping
MUCDK 5363d98
make dicts compatible with older python versions
MUCDK 0dd6b8b
resolve precommit errors partially
MUCDK 21f3309
resolve precommit errors partially
MUCDK 3bbb4da
adapt tests
MUCDK 2cced99
[ci skip] draft unbalancedNeuralMixin
MUCDK aa49c10
[ci skip] fix naming of posterior marginals
MUCDK 16d3204
[ci skip] add MLP_marginals
MUCDK c61633c
adapt neural output to incorporate learnt rescaling functions
MUCDK 2dc6ff3
fix _solve in neuraldualsolver
MUCDK badd57b
incorporate feedback
MUCDK bd983fd
fix distributioncollection class
MUCDK b00e2b7
unify _split_data
MUCDK 79f050e
fix tests
MUCDK 5c19f1f
fix some precommit hooks
MUCDK 4e964e5
make neural dependencies optional
MUCDK 517db5f
make neural dependencies optional
MUCDK 742f02c
delete old files
MUCDK d763de4
adapt pyproject.toml
MUCDK 2c64e3b
adapt pyproject.toml
MUCDK 73c7830
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6d40537
[ci skip] adjust _format_params
MUCDK 9483e69
adapt neuraldualsolver to be more similar to ott-jax
MUCDK 32300e5
adapt neuraldualsolver
MUCDK 88b9a59
TODO: make JaxSampler return conditions
MUCDK dd74215
add basic neural test
MUCDK dce2b21
[ci skip] intermediate save
MUCDK d2ece68
adapt neuraldualsolver and finish tests for neural backend
MUCDK 3bd674f
[ci skip] TODO: re-iterate on initialisation of neural solver
MUCDK 61e3a01
adapt distributioncontainer
MUCDK b949d6f
fix dict bug
MUCDK e90934c
resolve passing of arguments in solver call methods
MUCDK 295eb6e
[ci skip] adapt `solve` in `CondOTProblem`
MUCDK 08fcec2
adapt tests and valid loader conditions
MUCDK 868146a
adapt neural backend tests
MUCDK f84bd31
fix mypy errors
MUCDK e74e0a4
Merge branch 'main' into feature/moscot_not
MUCDK 9482ccc
make basesolveroutput to basediscretesolveroutput
MUCDK 182bab7
move `to` to BaseSolverOutput`
MUCDK 166bd49
adapt transport_matrix docs
MUCDK fea3ac9
adapt transport_matrix docs
MUCDK 7669f7b
adapt tests
MUCDK e53f9c0
adapt tests
MUCDK cc6bfb5
update unbalancedness mixin
MUCDK a438629
use implementation from moscot
MUCDK 3f422bf
uncomment unused code
MUCDK 3e9dfc5
before passing states to loss-fn
MUCDK f4b7c76
intermediate save
MUCDK 4152baf
adapt neuraldualsolver
MUCDK 06e55c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0b46bae
resolve some / not all pre commit errors
MUCDK 4566eb9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot]
@@ -0,0 +1,161 @@
from functools import partial
from types import MappingProxyType
from typing import Any, Dict, Hashable, List, Literal, Optional, Tuple, TypeVar, Union

import jax
import jax.numpy as jnp
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

K = TypeVar("K", bound=Hashable)


class JaxSampler:
    """Data sampler for Jax."""

    def __init__(
        self,
        distributions: List[jnp.ndarray],
        policy_pairs: List[Tuple[Any, Any]],
        conditions: Optional[List[jnp.ndarray]],
        a: Optional[List[jnp.ndarray]] = None,
        b: Optional[List[jnp.ndarray]] = None,
        sample_to_idx: Dict[K, Any] = MappingProxyType({}),
        batch_size: int = 1024,
        tau_a: float = 1.0,
        tau_b: float = 1.0,
        epsilon: float = 0.1,
    ):
        """Initialize the data sampler."""
        if a is not None and b is not None and not len(distributions) == len(a) == len(b):
            raise ValueError("Number of distributions, `a`, and `b` must be equal.")
        self._distributions = distributions
        self._conditions = conditions
        self._batch_size = batch_size
        self._policy_pairs = policy_pairs
        if not len(sample_to_idx):
            if len(self.policy_pairs) > 1:
                raise ValueError("If `policy_pairs` contains more than 1 value, `sample_to_idx` is required.")
            sample_to_idx = {self.policy_pairs[0][0]: 0, self.policy_pairs[0][1]: 1}
        self._sample_to_idx = sample_to_idx

        @partial(jax.jit, static_argnames=["index"])
        def _sample_source(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, None]:
            """Jitted function to sample a source batch, weighted by the marginals `a`."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(a[index]))
            return jnp.asarray(samples), None

        @partial(jax.jit, static_argnames=["index"])
        def _sample_source_conditional(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted function to sample a source batch together with its conditions."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(a[index]))
            conds = jax.random.choice(key, self.conditions[index], shape=[batch_size], p=jnp.squeeze(a[index]))  # type: ignore[index]
            return samples, conds

        @partial(jax.jit, static_argnames=["index"])
        def _sample_target(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, None]:
            """Jitted function to sample a target batch, weighted by the marginals `b`."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(b[index]))
            return samples, None

        @partial(jax.jit, static_argnames=["index"])
        def _sample_target_conditional(key: jax.random.KeyArray, index: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted function to sample a target batch together with its conditions."""
            samples = jax.random.choice(key, self.distributions[index], shape=[batch_size], p=jnp.squeeze(b[index]))
            conds = jax.random.choice(key, self.conditions[index], shape=[batch_size], p=jnp.squeeze(b[index]))  # type: ignore[index]
            return samples, conds

        @jax.jit
        def _compute_unbalanced_marginals(
            batch_source: jnp.ndarray,
            batch_target: jnp.ndarray,
            sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}),
        ) -> Tuple[jnp.ndarray, jnp.ndarray]:
            """Jitted function to compute the source and target marginals for a batch."""
            geom = PointCloud(batch_source, batch_target, epsilon=epsilon, scale_cost="mean")
            out = sinkhorn.Sinkhorn(**sinkhorn_kwargs)(linear_problem.LinearProblem(geom, tau_a=tau_a, tau_b=tau_b))
            # row sums give the source marginals, column sums the target marginals
            return out.matrix.sum(axis=1), out.matrix.sum(axis=0)

        @jax.jit
        def _unbalanced_resample(
            key: jax.random.KeyArray,
            batch: Tuple[jnp.ndarray, ...],
            marginals: jnp.ndarray,
        ) -> Tuple[jnp.ndarray, ...]:
            """Resample a batch proportionally to the estimated marginals."""
            # draw indices with replacement, weighted by the marginals
            indices = jax.random.choice(key, a=len(marginals), p=jnp.squeeze(marginals), shape=[len(marginals)])
            return tuple(arr[indices] if arr is not None else None for arr in batch)

        def _sample_policy_pair(key: jax.random.KeyArray) -> Tuple[Any, Any]:
            """Sample a policy pair uniformly at random."""
            index = jax.random.randint(key, shape=[], minval=0, maxval=len(self.policy_pairs))
            return self.policy_pairs[index]

        self._sample_source = _sample_source if self.conditions is None else _sample_source_conditional
        self._sample_target = _sample_target if self.conditions is None else _sample_target_conditional
        self.sample_policy_pair = _sample_policy_pair
        self.compute_unbalanced_marginals = _compute_unbalanced_marginals
        self.unbalanced_resample = _unbalanced_resample

    def __call__(
        self,
        key: jax.random.KeyArray,
        policy_pair: Tuple[Any, Any],
        sample: Literal["source", "target", "both"] = "both",
        full_dataset: bool = False,
    ) -> Union[Tuple[jnp.ndarray, Optional[jnp.ndarray]], Tuple[jnp.ndarray, Optional[jnp.ndarray], jnp.ndarray]]:
        """Sample data. When sampling the source, the conditions are returned, too."""
        if full_dataset:
            if sample == "source":
                return jnp.asarray(
                    self.distributions[self.sample_to_idx[policy_pair[0]]]
                ), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
            if sample == "target":
                return jnp.asarray(
                    self.distributions[self.sample_to_idx[policy_pair[1]]]
                ), None if self.conditions is None else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]])
            if sample == "both":
                return (
                    jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[0]]]),
                    None
                    if self.conditions is None
                    else jnp.asarray(self.conditions[self.sample_to_idx[policy_pair[0]]]),
                    jnp.asarray(self.distributions[self.sample_to_idx[policy_pair[1]]]),
                )
            raise NotImplementedError(f"Sample type {sample} not implemented.")
        if sample == "source":
            return self._sample_source(key, self.sample_to_idx[policy_pair[0]])
        if sample == "target":
            return self._sample_target(key, self.sample_to_idx[policy_pair[1]])
        if sample == "both":
            return (
                self._sample_source(key, self.sample_to_idx[policy_pair[0]])[0],
                *self._sample_target(key, self.sample_to_idx[policy_pair[1]]),
            )
        raise NotImplementedError(f"Sample type {sample} not implemented.")

    @property
    def distributions(self) -> List[jnp.ndarray]:
        """Return the distributions."""
        return self._distributions

    @property
    def policy_pairs(self) -> List[Tuple[Any, Any]]:
        """Return the policy pairs."""
        return self._policy_pairs

    @property
    def conditions(self) -> Optional[List[jnp.ndarray]]:
        """Return the conditions."""
        return self._conditions

    @property
    def sample_to_idx(self) -> Dict[K, Any]:
        """Return the mapping from sample name to index."""
        return self._sample_to_idx

    @property
    def batch_size(self) -> int:
        """Return the batch size."""
        return self._batch_size
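The core mechanic of `_unbalanced_resample` above is drawing batch indices with replacement, in proportion to the estimated marginal weights, so that mass created or destroyed by the unbalanced Sinkhorn solve is reflected in subsequent batches. A minimal stdlib sketch of that mechanic (hypothetical names, no JAX, seeded for reproducibility rather than keyed):

```python
import random
from typing import List, Sequence


def unbalanced_resample(batch: Sequence, marginals: Sequence[float], seed: int = 0) -> List:
    """Resample a batch with replacement, proportionally to per-item marginal weights."""
    rng = random.Random(seed)
    # draw len(marginals) indices, each index i picked with probability ~ marginals[i]
    indices = rng.choices(range(len(marginals)), weights=marginals, k=len(marginals))
    return [batch[i] for i in indices]
```

With a marginal vector that puts all mass on one item, every resampled entry is that item; with uniform marginals, the output is an ordinary bootstrap of the batch.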
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@michalk8 how do I make these optional, i.e. a `neural` optional dependency as below, but still make GitHub CI install these optional dependencies?
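A common pattern for this (a sketch, not the PR's actual implementation) is to declare the neural packages as a `[project.optional-dependencies]` extra in `pyproject.toml`, have CI install with `pip install ".[neural]"`, and guard the imports at runtime with an informative error. The helper name below is hypothetical:

```python
import importlib.util


def require_neural(module: str = "flax") -> None:
    """Raise an informative error if an optional neural dependency is missing."""
    # find_spec returns None when the module cannot be found, without importing it
    if importlib.util.find_spec(module) is None:
        raise ImportError(
            f"Unable to import `{module}`. "
            "Install the optional neural dependencies, e.g. `pip install moscot[neural]`."
        )
```

Calling `require_neural()` at the top of the neural solver modules keeps the base install lightweight while pointing users (and CI logs) at the right extra when the dependency is absent.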