Skip to content

Commit

Permalink
trivial refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Aug 27, 2024
1 parent 55edb9c commit ecd375b
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 48 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Branches

The master branch contains latest changes to the main release version.

There isa a test suite, which runs automatically on GitHub actions for new commits and pull requests.
There is a test suite, which runs automatically on GitHub actions for new commits and pull requests.
Reference results and test outputs are stored in the `test outputs repository <https://github.com/cmbant/CAMB_test_outputs/>`_. Tests can also be run locally.

To reproduce legacy results, see these branches:
Expand Down
7 changes: 3 additions & 4 deletions camb/baseconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,8 @@ def __init__(self, name, **kwargs):
else:
assert isinstance(names, dict)
self.name_values = names
for name in names:
self.values[names[name]] = name
for name, value in names.items():
self.values[value] = name

def __get__(self, instance, owner):
value = getattr(instance, self.real_name)
Expand Down Expand Up @@ -693,8 +693,7 @@ def __new__(cls, *args, **kwargs):

def __init__(self, **kwargs):
if kwargs:
unknowns = set(kwargs) - self.get_valid_field_names()
if unknowns:
if unknowns := set(kwargs).difference(self.get_valid_field_names()):
raise ValueError('Unknown argument(s): %s' % unknowns)
super().__init__(**kwargs)

Expand Down
6 changes: 3 additions & 3 deletions camb/recombination.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class Recfast(RecombinationModel):
@optional_fortran_class
class CosmoRec(RecombinationModel):
"""
`CosmoRec <http://www.jb.man.ac.uk/~jchluba/Science/CosmoRec/CosmoRec.html>`_ recombination model.
To use this, the library must be build with CosmoRec installed and RECOMBINATION_FILES including cosmorec
`CosmoRec <https://www.jb.man.ac.uk/~jchluba/Science/CosmoRec/CosmoRec.html>`_ recombination model.
To use this, the library must be built with CosmoRec installed and RECOMBINATION_FILES including cosmorec
in the Makefile.
CosmoRec must be built with -fPIC added to the compiler flags.
Expand All @@ -60,7 +60,7 @@ class CosmoRec(RecombinationModel):
class HyRec(RecombinationModel):
r"""
`HyRec <https://github.com/nanoomlee/HYREC-2>`_ recombination model.
To use this, the library must be build with HyRec installed and RECOMBINATION_FILES including hyrec in the Makefile.
To use this, the library must be built with HyRec installed and RECOMBINATION_FILES including hyrec in the Makefile.
"""
_fortran_class_module_ = 'HyRec'
Expand Down
61 changes: 21 additions & 40 deletions camb/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,15 @@ def get_derived_params(self):
"""
:return: dictionary of derived parameter values, indexed by name ('kd', 'age', etc..)
"""
res = {}
for name, value in zip(model.derived_names, self.ThermoDerivedParams):
res[name] = value
return res
return dict(zip(model.derived_names, self.ThermoDerivedParams))

def get_background_outputs(self):
"""
Get BAO values for redshifts set in Params.z_outputs
:return: rs/DV, H, DA, F_AP for each requested redshift (as 2D array)
"""
n = len(self.Params.z_outputs)
if not n:
if not (n := len(self.Params.z_outputs)):
raise CAMBError('Set z_outputs with required redshifts (and then calculate transfers/results)'
' before calling get_background_outputs')
outputs = np.empty((n, 4))
Expand Down Expand Up @@ -527,8 +523,7 @@ def get_time_evolution(self, q, eta, vars=model.evolve_names, lAccuracyBoost=4,
import sympy
named_vars = [var for var in vars if isinstance(var, str)]

unknown = set(named_vars) - set(model.evolve_names)
if unknown:
if unknown := set(named_vars).difference(model.evolve_names):
raise CAMBError('Unknown names %s; valid names are %s' % (unknown, model.evolve_names))

num_standard_names = len(model.evolve_names)
Expand All @@ -551,18 +546,17 @@ def get_time_evolution(self, q, eta, vars=model.evolve_names, lAccuracyBoost=4,
k = np.array(q, dtype=np.float64)
times = np.array(np.atleast_1d(eta), dtype=np.float64)
indices = np.argsort(times) # times must be in increasing order
ncustom = len(custom_vars)
if ncustom:
if n_custom := len(custom_vars):
from . import symbolic
funcPtr = symbolic.compile_sympy_to_camb_source_func(custom_vars, frame=frame)
custom_source_func = ctypes.cast(funcPtr, ctypes.c_void_p)
else:
custom_source_func = ctypes.c_void_p(0)
nvars = num_standard_names + ncustom
nvars = num_standard_names + n_custom
outputs = np.empty((k.shape[0], times.shape[0], nvars))
if CAMB_TimeEvolution(byref(self), byref(c_int(k.shape[0])), k, byref(c_int(times.shape[0])),
times[indices], byref(c_int(nvars)), outputs,
byref(c_int(ncustom)), byref(custom_source_func)):
byref(c_int(n_custom)), byref(custom_source_func)):
config.check_global_error('get_time_evolution')
i_rev = np.zeros(times.shape, dtype=int)
i_rev[indices] = np.arange(times.shape[0])
Expand Down Expand Up @@ -599,17 +593,13 @@ def get_background_time_evolution(self, eta, vars=model.background_names, format

if isinstance(vars, str):
vars = [vars]
unknown = set(vars) - set(model.background_names)
if unknown:
if unknown := set(vars).difference(model.background_names):
raise CAMBError('Unknown names %s; valid names are %s' % (unknown, model.background_names))
outputs = np.zeros((eta.shape[0], 9))
CAMB_BackgroundThermalEvolution(byref(self), byref(c_int(eta.shape[0])), eta, outputs)
indices = [model.background_names.index(var) for var in vars]
if format == 'dict':
res = {}
for var, index in zip(vars, indices):
res[var] = outputs[:, index]
return res
return {var: outputs[:, index] for var, index in zip(vars, indices)}
else:
assert format == 'array', "format must be dict or array"
return outputs[:, np.array(indices)]
Expand Down Expand Up @@ -639,18 +629,14 @@ def get_background_densities(self, a, vars=model.density_names, format='dict'):
"""
if isinstance(vars, str):
vars = [vars]
unknown = set(vars) - set(model.density_names)
if unknown:
if unknown := set(vars).difference(model.density_names):
raise CAMBError('Unknown names %s; valid names are %s' % (unknown, model.density_names))
arr = np.atleast_1d(a)
outputs = np.zeros((arr.shape[0], 8))
self.f_GetBackgroundDensities(byref(c_int(arr.shape[0])), arr, outputs)
indices = [model.density_names.index(var) for var in vars]
if format == 'dict':
res = {}
for var, index in zip(vars, indices):
res[var] = outputs[:, index]
return res
return {var: outputs[:, index] for var, index in zip(vars, indices)}
else:
assert format == 'array', "format must be dict or array"
return outputs[:, np.array(indices)]
Expand All @@ -663,14 +649,14 @@ def get_dark_energy_rho_w(self, a):
:param a: scalar factor or array of scale factors
:return: rho, w arrays at redshifts :math:`1/a-1` [or scalars if :math:`a` is scalar]
"""
if np.isscalar(a):
if scalar := np.isscalar(a):
scales = np.array([a])
else:
scales = np.ascontiguousarray(a)
rho = np.zeros(scales.shape)
w = np.zeros(scales.shape)
self.f_DarkEnergyStressEnergy(scales, rho, w, byref(c_int(len(scales))))
if np.isscalar(a):
if scalar:
return rho[0], w[0]
else:
return rho, w
Expand All @@ -685,10 +671,7 @@ def get_Omega(self, var, z=0):
"""
dic = self.get_background_densities(1. / (1 + z), ['tot', var])
res = dic[var] / dic['tot']
if np.isscalar(z):
return res[0]
else:
return res
return res[0] if np.isscalar(z) else res

def get_matter_transfer_data(self) -> MatterTransferData:
"""
Expand Down Expand Up @@ -1371,7 +1354,8 @@ def angular_diameter_distance(self, z):
arr[indices] = arr.copy()
return arr

def _make_scalar_or_arrays(self, z1, z2):
@staticmethod
def _make_scalar_or_arrays(z1, z2):
if np.isscalar(z1):
if np.isscalar(z2):
return z1, z2
Expand Down Expand Up @@ -1449,14 +1433,14 @@ def redshift_at_conformal_time(self, eta):
:return: redshift at eta, scalar or array
"""

if np.isscalar(eta):
if scalar := np.isscalar(eta):
times = np.array([eta], dtype=np.float64)
else:
times = np.ascontiguousarray(eta, dtype=np.float64)
redshifts = np.empty(times.shape)
self.f_RedshiftAtTimeArr(redshifts, times, byref(c_int(times.shape[0])))
config.check_global_error('redshift_at_conformal_time')
if np.isscalar(eta):
if scalar:
return redshifts[0]
else:
return redshifts
Expand Down Expand Up @@ -1564,7 +1548,7 @@ def conformal_time(self, z, presorted=None, tol=None):
:param tol: integration tolerance
:return: eta(z)/Mpc
"""
if np.isscalar(z):
if scalar := np.isscalar(z):
redshifts = np.array([z], dtype=np.float64)
else:
redshifts = np.asarray(z, dtype=np.float64)
Expand All @@ -1579,7 +1563,7 @@ def conformal_time(self, z, presorted=None, tol=None):
tol = byref(c_double(tol))

self.f_TimeOfzArr(eta, redshifts, byref(c_int(eta.shape[0])), tol)
if np.isscalar(z):
if scalar:
return eta[0]
else:
if presorted is False:
Expand All @@ -1598,16 +1582,13 @@ def sound_horizon(self, z):
:param z: redshift or array of redshifts
:return: r_s(z)
"""
if np.isscalar(z):
if scalar := np.isscalar(z):
redshifts = np.array([z], dtype=np.float64)
else:
redshifts = np.array(z, dtype=np.float64)
rs = np.empty(redshifts.shape)
self.f_sound_horizon_zArr(rs, redshifts, byref(c_int(redshifts.shape[0])))
if np.isscalar(z):
return rs[0]
else:
return rs
return rs[0] if scalar else rs

def cosmomc_theta(self):
r"""
Expand Down

0 comments on commit ecd375b

Please sign in to comment.