Skip to content

Commit

Permalink
added logging suppressor for falling back to cpu (#135)
Browse files Browse the repository at this point in the history
* added logging suppressor for falling back to cpu

* updated python version matrix for test, now >3.9 <3.12
  • Loading branch information
jgallowa07 authored Feb 22, 2024
1 parent 7f8d675 commit 5889166
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 256 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_test_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest]
python-version: [3.8, 3.9, "3.10", 3.11] # quoted 3.10 needed due to this bug: https://github.com/actions/runner/issues/1989
python-version: [3.9, "3.10", 3.11] # quoted 3.10 needed due to this bug: https://github.com/actions/runner/issues/1989

steps:
- name: Checkout 🛎️
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish_package_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
id-token: write
strategy:
matrix:
python-version: [3.8]
python-version: [3.9]

steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 0 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
# add sourcecode to path
import sys, os

# sys.path.insert(0, os.path.abspath("../multidms"))
# sys.path.insert(0, "{0}/..".format(os.path.abspath(".")))
sys.path.insert(0, "{}/..".format(os.path.abspath(".")))

# -- Project information -----------------------------------------------------
Expand All @@ -33,9 +31,7 @@
"sphinx.ext.mathjax",
"sphinx.ext.githubpages",
"sphinxcontrib.bibtex",
# "sphinx.ext.viewcode",
"sphinx.ext.napoleon",
# "matplotlib.sphinxext.plot_directive",
"nbsphinx",
"nbsphinx_link",
]
Expand Down
16 changes: 12 additions & 4 deletions multidms/model_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
import numpy as onp
import altair as alt

import logging

logging.getLogger("jax._src.xla_bridge").addFilter(
logging.Filter(
"An NVIDIA GPU may be present on this machine, "
"but a CUDA-enabled jaxlib is not installed. Falling back to cpu."
)
)


PARAMETER_NAMES_FOR_PLOTTING = {
"scale_coeff_lasso_shift": "Lasso Penalty",
Expand Down Expand Up @@ -922,10 +931,9 @@ def mut_type(mut):
.assign(mut_type=lambda x: x.mutation.apply(mut_type))
.reset_index()
.groupby(by=feature_cols)
.apply(
sparsity, include_groups=True
) # TODO This throws deprecation warning
.drop(columns=feature_cols + ["mutation"])
# .apply(sparsity, include_groups=True)
.apply(sparsity, include_groups=False)
# .drop(columns=feature_cols + ["mutation"])
.reset_index(drop=False)
.melt(id_vars=feature_cols, var_name="mut_param", value_name="sparsity")
)
Expand Down
405 changes: 161 additions & 244 deletions notebooks/simulation_validation.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
]
keywords = [
"multidms",
Expand All @@ -32,10 +33,10 @@ keywords = [


# Software Dependencies
requires-python = ">=3.8"
requires-python = ">=3.9"
dependencies = [
"polyclonal",
"jax",
"jax[cpu]==0.4.24",
"jaxopt",
"typing_extensions",
"numpy",
Expand Down

0 comments on commit 5889166

Please sign in to comment.