Skip to content

Commit

Permalink
Allow user to aggressively use memory (#164)
Browse files Browse the repository at this point in the history
* Allow user to aggressively use memory

* Fix pre-commit

* Fix little bug
  • Loading branch information
dachengx authored May 20, 2024
1 parent 81b334e commit c711e74
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 27 deletions.
9 changes: 7 additions & 2 deletions appletree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
# stop jax to preallocate memory
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
if "AGGRESSIVE_MEMORY_ALLOCATION" not in os.environ:
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")
XLA_PYTHON_CLIENT_PREALLOCATE = os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]
XLA_PYTHON_CLIENT_ALLOCATOR = os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]
print(f"XLA_PYTHON_CLIENT_PREALLOCATE is set to {XLA_PYTHON_CLIENT_PREALLOCATE}")
print(f"XLA_PYTHON_CLIENT_ALLOCATOR is set to {XLA_PYTHON_CLIENT_ALLOCATOR}")

from . import utils
from .utils import *
Expand Down
6 changes: 3 additions & 3 deletions appletree/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ def __init__(self, name: Optional[str] = None, llh_name: Optional[str] = None, *

if self.bins_type != "meshgrid" and self.add_eps_to_hist:
warn(
"It is empirically dangerous to have add_eps_to_hist==True,\
when your bins_type is not meshgrid! It may lead to very bad fit with\
lots of eff==0."
"It is empirically dangerous to have add_eps_to_hist==True, "
"when your bins_type is not meshgrid! It may lead to very bad fit with "
"lots of eff==0."
)

def set_binning(self, **kwargs):
Expand Down
38 changes: 19 additions & 19 deletions appletree/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,18 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):

print(f"{indent}COMPONENT {i}: {name}")
if isinstance(component, ComponentSim):
print(f"{indent*2}type: simulation")
print(f"{indent*2}rate_par: {component.rate_name}")
print(f"{indent*2}pars: {need}")
print(f"{indent * 2}type: simulation")
print(f"{indent * 2}rate_par: {component.rate_name}")
print(f"{indent * 2}pars: {need}")
if not short:
print(f"{indent*2}worksheet: {component.worksheet}")
print(f"{indent * 2}worksheet: {component.worksheet}")
elif isinstance(component, ComponentFixed):
print(f"{indent*2}type: fixed")
print(f"{indent*2}file_name: {component._file_name}")
print(f"{indent*2}rate_par: {component.rate_name}")
print(f"{indent*2}pars: {need}")
print(f"{indent * 2}type: fixed")
print(f"{indent * 2}file_name: {component._file_name}")
print(f"{indent * 2}rate_par: {component.rate_name}")
print(f"{indent * 2}pars: {need}")
if not short:
print(f"{indent*2}from_file: {component._file_name}")
print(f"{indent * 2}from_file: {component._file_name}")
else:
pass
print()
Expand Down Expand Up @@ -525,7 +525,7 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):
print("LOGPDF\n")
print(f"{indent}logpdf_args:")
for k, v in self.logpdf_args.items():
print(f"{indent*2}{k}: {v}")
print(f"{indent * 2}{k}: {v}")
print("\n" + "-" * 40)

print("MODEL\n")
Expand All @@ -536,18 +536,18 @@ def print_likelihood_summary(self, indent: str = " " * 4, short: bool = True):

print(f"{indent}COMPONENT {i}: {name}")
if isinstance(component, ComponentSim):
print(f"{indent*2}type: simulation")
print(f"{indent*2}rate_par: {component.rate_name}")
print(f"{indent*2}pars: {need}")
print(f"{indent * 2}type: simulation")
print(f"{indent * 2}rate_par: {component.rate_name}")
print(f"{indent * 2}pars: {need}")
if not short:
print(f"{indent*2}worksheet: {component.worksheet}")
print(f"{indent * 2}worksheet: {component.worksheet}")
elif isinstance(component, ComponentFixed):
print(f"{indent*2}type: fixed")
print(f"{indent*2}file_name: {component._file_name}")
print(f"{indent*2}rate_par: {component.rate_name}")
print(f"{indent*2}pars: {need}")
print(f"{indent * 2}type: fixed")
print(f"{indent * 2}file_name: {component._file_name}")
print(f"{indent * 2}rate_par: {component.rate_name}")
print(f"{indent * 2}pars: {need}")
if not short:
print(f"{indent*2}from_file: {component._file_name}")
print(f"{indent * 2}from_file: {component._file_name}")
else:
pass
print()
Expand Down
6 changes: 3 additions & 3 deletions appletree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ def set_gpu_memory_usage(fraction=0.3):
"""
if fraction > 1:
fraction = 1
fraction = 0.99
if fraction <= 0:
raise ValueError("fraction must be positive!")
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f".{int(fraction * 100):d}"


@export
Expand Down Expand Up @@ -335,7 +335,7 @@ def plot_irreg_histogram_2d(bins_x, bins_y, hist, **kwargs):
bins_y = np.asarray(bins_y)

density = kwargs.get("density", False)
cmap = mpl.cm.get_cmap("RdBu_r")
cmap = mpl.cm.RdBu_r

loc = []
width = []
Expand Down

0 comments on commit c711e74

Please sign in to comment.