Skip to content

Commit

Permalink
python vectorized input support
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Feb 2, 2022
1 parent d76c2f5 commit 11fdbbb
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 23 deletions.
Binary file modified camb/cambdll.dll
Binary file not shown.
10 changes: 6 additions & 4 deletions camb/initialpower.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def set_scalar_table(self, k, PK):
:param k: array of k values (Mpc^{-1})
:param PK: array of scalar power spectrum values
"""
self.f_SetScalarTable(byref(c_int(len(k))), np.asarray(k), np.asarray(PK))
self.f_SetScalarTable(byref(c_int(len(k))), np.ascontiguousarray(k, dtype=np.float64),
np.ascontiguousarray(PK, dtype=np.float64))

def set_tensor_table(self, k, PK):
"""
Expand All @@ -68,7 +69,8 @@ def set_tensor_table(self, k, PK):
:param k: array of k values (Mpc^{-1})
:param PK: array of tensor power spectrum values
"""
self.f_SetTensorTable(byref(c_int(len(k))), np.asarray(k), np.asarray(PK))
self.f_SetTensorTable(byref(c_int(len(k))), np.ascontiguousarray(k, dtype=np.float64),
np.ascontiguousarray(PK, dtype=np.float64))

def set_scalar_log_regular(self, kmin, kmax, PK):
"""
Expand All @@ -79,7 +81,7 @@ def set_scalar_log_regular(self, kmin, kmax, PK):
:param PK: array of scalar power spectrum values, with PK[0]=P(kmin) and PK[-1]=P(kmax)
"""
self.f_SetScalarLogRegular(byref(c_double(kmin)), byref(c_double(kmax)), byref(c_int(len(PK))),
np.asarray(PK))
np.ascontiguousarray(PK, dtype=np.float64))

def set_tensor_log_regular(self, kmin, kmax, PK):
"""
Expand All @@ -91,7 +93,7 @@ def set_tensor_log_regular(self, kmin, kmax, PK):
"""

self.f_SetTensorLogRegular(byref(c_double(kmin)), byref(c_double(kmax)), byref(c_int(len(PK))),
np.asarray(PK))
np.ascontiguousarray(PK, dtype=np.float64))


@fortran_class
Expand Down
53 changes: 36 additions & 17 deletions camb/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,10 @@ class CAMBdata(F2003Class):
('Hofz', [d_arg], c_double),
('HofzArr', [numpy_1d, numpy_1d, int_arg]),
('DeltaPhysicalTimeGyr', [d_arg, d_arg, d_arg], c_double),
('DeltaPhysicalTimeGyrArr', [numpy_1d, numpy_1d, numpy_1d, int_arg, d_arg]),
('GetBackgroundDensities', [int_arg, numpy_1d, numpy_2d]),
('DeltaTime', [d_arg, d_arg, d_arg], c_double),
('DeltaTimeArr', [numpy_1d, numpy_1d, numpy_1d, int_arg, d_arg]),
('TimeOfzArr', [numpy_1d, numpy_1d, int_arg, d_arg]),
('sound_horizon_zArr', [numpy_1d, numpy_1d, int_arg]),
('RedshiftAtTimeArr', [numpy_1d, numpy_1d, int_arg]),
Expand Down Expand Up @@ -1361,6 +1363,22 @@ def angular_diameter_distance(self, z):
arr[indices] = arr.copy()
return arr

def _make_scalar_or_arrays(self, z1, z2):
if np.isscalar(z1):
if np.isscalar(z2):
return z1, z2
else:
z1 = np.ones(len(z2)) * z1
else:
z1 = np.ascontiguousarray(z1, dtype=np.float64)
if np.isscalar(z2):
z2 = np.ones(len(z1)) * z2
else:
z2 = np.ascontiguousarray(z2, dtype=np.float64)
if len(z1) != len(z2):
raise CAMBError('z1 nand z2 must be scalar or same-length 1D arrays')
return z1, z2

def angular_diameter_distance2(self, z1, z2):
r"""
Get angular diameter distance between two redshifts
Expand All @@ -1375,16 +1393,9 @@ def angular_diameter_distance2(self, z1, z2):
:param z2: redshift 2, or orray of redshifts
:return: result (scalar or array of distances between pairs of z1, z2)
"""
if np.isscalar(z1):
if not np.isscalar(z2):
z1 = np.ones(len(z2)) * z1
elif np.isscalar(z2):
if not np.isscalar(z1):
z2 = np.ones(len(z1)) * z2
z1, z2 = self._make_scalar_or_arrays(z1, z2)
if not np.isscalar(z1):
z1 = np.ascontiguousarray(z1, dtype=np.float64)
z2 = np.ascontiguousarray(z2, dtype=np.float64)
dists = np.empty(z1.shape, dtype=np.float64)
dists = np.empty(z1.shape)
self.f_AngularDiameterDistance2Arr(dists, z1, z2, byref(c_int(dists.shape[0])))
return dists
else:
Expand Down Expand Up @@ -1500,9 +1511,13 @@ def physical_time_a1_a2(self, a1, a2):
:param a2: scale factor 2
:return: (age(a2)-age(a1))/Gigayear
"""
if not np.isscalar(a1) or not np.isscalar(a2):
raise CAMBError('vector inputs not supported yet')
return self.f_DeltaPhysicalTimeGyr(byref(c_double(a1)), byref(c_double(a2)), None)
a1, a2 = self._make_scalar_or_arrays(a1, a2)
if not np.isscalar(a1):
times = np.empty(a1.shape)
self.f_DeltaPhysicalTimeGyrArr(times, a1, a2, byref(c_int(times.shape[0])), None)
return times
else:
return self.f_DeltaPhysicalTimeGyr(byref(c_double(a1)), byref(c_double(a2)), None)

def physical_time(self, z):
"""
Expand All @@ -1511,6 +1526,8 @@ def physical_time(self, z):
:param z: redshift
:return: t(z)/Gigayear
"""
if not np.isscalar(z):
z = np.asarray(z, dtype=np.float64)
return self.physical_time_a1_a2(0, 1.0 / (1 + z))

def conformal_time_a1_a2(self, a1, a2):
Expand All @@ -1521,11 +1538,13 @@ def conformal_time_a1_a2(self, a1, a2):
:param a2: scale factor 2
:return: eta(a2)-eta(a1) = chi(a1)-chi(a2) in Megaparsec
"""

if not np.isscalar(a1) or not np.isscalar(a2):
raise CAMBError('vector inputs not supported yet')

return self.f_DeltaTime(byref(c_double(a1)), byref(c_double(a2)), None)
a1, a2 = self._make_scalar_or_arrays(a1, a2)
if not np.isscalar(a1):
times = np.empty(a1.shape)
self.f_DeltaTimeArr(times, a1, a2, byref(c_int(times.shape[0])), None)
return times
else:
return self.f_DeltaTime(byref(c_double(a1)), byref(c_double(a2)), None)

def conformal_time(self, z, presorted=None, tol=None):
"""
Expand Down
5 changes: 3 additions & 2 deletions camb/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ def set_table(self, z, W, bias_z=None):
raise ValueError(
"Redshifts must be well sampled and in ascending order, with window function the same length as z")
if bias_z is not None:
bias_z = np.asarray(bias_z, dtype=np.float64)
bias_z = np.ascontiguousarray(bias_z, dtype=np.float64)
if len(bias_z) != len(z):
raise ValueError("bias array must be same size as the redshift array")

self.f_SetTable(byref(c_int(len(z))), np.asarray(z, dtype=np.float64), np.asarray(W, dtype=np.float64), bias_z)
self.f_SetTable(byref(c_int(len(z))), np.ascontiguousarray(z, dtype=np.float64),
np.ascontiguousarray(W, dtype=np.float64), bias_z)
5 changes: 5 additions & 0 deletions camb/tests/camb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ def testBackground(self):
self.assertAlmostEqual(delta2, dists[1])
self.assertEqual(0, dists[2])

self.assertEqual(data.physical_time(0.4), data.physical_time([0.2, 0.4])[1])
d = data.conformal_time_a1_a2(0, 0.5) + data.conformal_time_a1_a2(0.5, 1)
self.assertAlmostEqual(d, data.conformal_time_a1_a2(0, 1))
self.assertAlmostEqual(d, sum(data.conformal_time_a1_a2([0, 0.5], [0.5, 1])))

def testEvolution(self):
redshifts = [0.4, 31.5]
pars = camb.set_params(H0=67.5, ombh2=0.022, omch2=0.122, As=2e-9, ns=0.95,
Expand Down
33 changes: 33 additions & 0 deletions fortran/results.f90
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,11 @@ module results

contains
procedure :: DeltaTime => CAMBdata_DeltaTime
procedure :: DeltaTimeArr => CAMBdata_DeltaTimeArr
procedure :: TimeOfz => CAMBdata_TimeOfz
procedure :: TimeOfzArr => CAMBdata_TimeOfzArr
procedure :: DeltaPhysicalTimeGyr => CAMBdata_DeltaPhysicalTimeGyr
procedure :: DeltaPhysicalTimeGyrArr => CAMBdata_DeltaPhysicalTimeGyrArr
procedure :: AngularDiameterDistance => CAMBdata_AngularDiameterDistance
procedure :: AngularDiameterDistanceArr => CAMBdata_AngularDiameterDistanceArr
procedure :: AngularDiameterDistance2 => CAMBdata_AngularDiameterDistance2
Expand Down Expand Up @@ -604,6 +606,21 @@ function CAMBdata_DeltaTime(this, a1,a2, in_tol)

end function CAMBdata_DeltaTime

subroutine CAMBdata_DeltaTimeArr(this, arr, a1, a2, n, tol)
class(CAMBdata) :: this
integer, intent(in) :: n
real(dl), intent(out) :: arr(n)
real(dl), intent(in) :: a1(n), a2(n)
real(dl), intent(in), optional :: tol
integer i

!$OMP PARALLEL DO DEFAULT(SHARED),SCHEDULE(STATIC)
do i = 1, n
arr(i) = this%DeltaTime(a1(i), a2(i), tol)
end do

end subroutine CAMBdata_DeltaTimeArr

function CAMBdata_TimeOfz(this, z, tol)
class(CAMBdata) :: this
real(dl) CAMBdata_TimeOfz
Expand Down Expand Up @@ -654,6 +671,22 @@ function CAMBdata_DeltaPhysicalTimeGyr(this, a1,a2, in_tol)
CAMBdata_DeltaPhysicalTimeGyr = Integrate_Romberg(this, dtda,a1,a2,atol)*Mpc/c/Gyr
end function CAMBdata_DeltaPhysicalTimeGyr

subroutine CAMBdata_DeltaPhysicalTimeGyrArr(this, arr, a1, a2, n, tol)
class(CAMBdata) :: this
integer, intent(in) :: n
real(dl), intent(out) :: arr(n)
real(dl), intent(in) :: a1(n), a2(n)
real(dl), intent(in), optional :: tol
integer i

!$OMP PARALLEL DO DEFAULT(SHARED),SCHEDULE(STATIC)
do i = 1, n
arr(i) = this%DeltaPhysicalTimeGyr(a1(i), a2(i), tol)
end do

end subroutine CAMBdata_DeltaPhysicalTimeGyrArr


function CAMBdata_AngularDiameterDistance(this,z)
class(CAMBdata) :: this
!This is the physical (non-comoving) angular diameter distance in Mpc
Expand Down

0 comments on commit 11fdbbb

Please sign in to comment.