Skip to content

Commit

Permalink
Fix error when using load_mem with non-radian units
Browse files Browse the repository at this point in the history
  • Loading branch information
rmjarvis committed Jun 13, 2022
1 parent 234b90e commit 47434de
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 33 deletions.
38 changes: 21 additions & 17 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1653,15 +1653,17 @@ def test_save_patches():
y = rng.normal(0,s, (ngal,) ) + 100 # Put everything at large y, so smallish angle on sky
z = rng.normal(0,s, (ngal,) )
ra, dec = coord.CelestialCoord.xyz_to_radec(x,y,z)
ra *= 180./np.pi
dec *= 180./np.pi

file_name = os.path.join('output','test_save_patches.fits')
cat0 = treecorr.Catalog(ra=ra, dec=dec, ra_units='rad', dec_units='rad')
cat0 = treecorr.Catalog(ra=ra, dec=dec, ra_units='deg', dec_units='deg')
cat0.write(file_name)


# When catalog has explicit ra, dec, etc., then file names are patch000.fits, ...
clear_save('patch%03d.fits', npatch)
cat1 = treecorr.Catalog(ra=ra, dec=dec, ra_units='rad', dec_units='rad', npatch=npatch,
cat1 = treecorr.Catalog(ra=ra, dec=dec, ra_units='deg', dec_units='deg', npatch=npatch,
save_patch_dir='output')
assert len(cat1.patches) == npatch
for i in range(npatch):
Expand All @@ -1676,7 +1678,7 @@ def test_save_patches():

# When catalog is a file, then base name off of given file_name.
clear_save('test_save_patches_%03d.fits', npatch)
cat2 = treecorr.Catalog(file_name, ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
cat2 = treecorr.Catalog(file_name, ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
npatch=npatch, save_patch_dir='output')
assert not cat2.loaded
cat2.get_patches(low_mem=True)
Expand Down Expand Up @@ -1706,7 +1708,7 @@ def test_save_patches():
cat0.write(file_name)
# And also try to match the type if HDF
clear_save('test_save_patches_%03d.hdf5', npatch)
cat2 = treecorr.Catalog(file_name, ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
cat2 = treecorr.Catalog(file_name, ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
npatch=npatch, save_patch_dir='output')
assert not cat2.loaded
cat2.get_patches(low_mem=True)
Expand Down Expand Up @@ -2295,9 +2297,11 @@ def test_lowmem():
y = rng.uniform(80,120, (ngal,) ) # Put everything at large y, so smallish angle on sky
z = rng.uniform(-20,20, (ngal,) )
ra, dec, r = coord.CelestialCoord.xyz_to_radec(x,y,z, return_r=True)
ra *= 180./np.pi # -> deg
dec *= 180./np.pi

file_name = os.path.join('output','test_lowmem.fits')
orig_cat = treecorr.Catalog(ra=ra, dec=dec, ra_units='rad', dec_units='rad')
orig_cat = treecorr.Catalog(ra=ra, dec=dec, ra_units='deg', dec_units='deg')
orig_cat.write(file_name)
del orig_cat

Expand All @@ -2309,14 +2313,14 @@ def test_lowmem():
hp = None

partial_cat = treecorr.Catalog(file_name, every_nth=100,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
npatch=npatch)

patch_centers = partial_cat.patch_centers
del partial_cat

full_cat = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
patch_centers=patch_centers)

dd = treecorr.NNCorrelation(bin_size=0.5, min_sep=1., max_sep=30., sep_units='arcmin')
Expand All @@ -2337,7 +2341,7 @@ def test_lowmem():
# Remake with save_patch_dir.
clear_save('test_lowmem_%03d.fits', npatch)
save_cat = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
patch_centers=patch_centers, save_patch_dir='output')

t0 = time.time()
Expand Down Expand Up @@ -2369,7 +2373,7 @@ def test_lowmem():
g2 = rng.uniform(-0.1,0.1, (ngal//100,) )
k = rng.uniform(-0.1,0.1, (ngal//100,) )
gk_cat0 = treecorr.Catalog(ra=ra[:ngal//100], dec=dec[:ngal//100], r=r[:ngal//100],
ra_units='rad', dec_units='rad',
ra_units='deg', dec_units='deg',
g1=g1, g2=g2, k=k,
npatch=4)
patch_centers = gk_cat0.patch_centers
Expand All @@ -2387,11 +2391,11 @@ def test_lowmem():
# First GG with normal ra,dec from a file
clear_save('test_lowmem_gk_%03d.fits', npatch)
gk_cat1 = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
g1_col='g1', g2_col='g2', k_col='k',
patch_centers=patch_centers)
gk_cat2 = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
g1_col='g1', g2_col='g2', k_col='k',
patch_centers=patch_centers, save_patch_dir='output')

Expand Down Expand Up @@ -2438,11 +2442,11 @@ def test_lowmem():
# KK with r_col now to test that that works properly.
clear_save('test_lowmem_gk_%03d.fits', npatch)
gk_cat1 = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
r_col='r', g1_col='g1', g2_col='g2', k_col='k',
patch_centers=patch_centers)
gk_cat2 = treecorr.Catalog(file_name,
ra_col='ra', dec_col='dec', ra_units='rad', dec_units='rad',
ra_col='ra', dec_col='dec', ra_units='deg', dec_units='deg',
r_col='r', g1_col='g1', g2_col='g2', k_col='k',
patch_centers=patch_centers, save_patch_dir='output')

Expand All @@ -2469,11 +2473,11 @@ def test_lowmem():
clear_save('patch%03d.fits', npatch)
gk_cat1 = treecorr.Catalog(ra=ra[:ngal//100], dec=dec[:ngal//100], r=r[:ngal//100],
g1=g1[:ngal//100], g2=g2[:ngal//100], k=k[:ngal//100],
ra_units='rad', dec_units='rad',
ra_units='deg', dec_units='deg',
patch_centers=patch_centers)
gk_cat2 = treecorr.Catalog(ra=ra[:ngal//100], dec=dec[:ngal//100], r=r[:ngal//100],
g1=g1[:ngal//100], g2=g2[:ngal//100], k=k[:ngal//100],
ra_units='rad', dec_units='rad',
ra_units='deg', dec_units='deg',
patch_centers=patch_centers, save_patch_dir='output')

nk1 = treecorr.NKCorrelation(bin_size=0.5, min_sep=1., max_sep=20.)
Expand All @@ -2499,11 +2503,11 @@ def test_lowmem():
clear_save('patch%03d.fits', npatch)
gk_cat1 = treecorr.Catalog(ra=ra[:ngal//100], dec=dec[:ngal//100],
g1=g1[:ngal//100], g2=g2[:ngal//100], k=k[:ngal//100],
ra_units='rad', dec_units='rad',
ra_units='deg', dec_units='deg',
patch_centers=patch_centers)
gk_cat2 = treecorr.Catalog(ra=ra[:ngal//100], dec=dec[:ngal//100],
g1=g1[:ngal//100], g2=g2[:ngal//100], k=k[:ngal//100],
ra_units='rad', dec_units='rad',
ra_units='deg', dec_units='deg',
patch_centers=patch_centers, save_patch_dir='output')

kg1 = treecorr.KGCorrelation(bin_size=0.5, min_sep=1., max_sep=30., sep_units='arcmin')
Expand Down
38 changes: 22 additions & 16 deletions treecorr/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,10 @@ def __init__(self, file_name=None, config=None, num=0, logger=None, is_rand=Fals
self._patch = self.makeArray(patch,'patch',int)
if self._patch is not None:
self._set_npatch()
if self._x is not None:
self._apply_xyz_units()
if self._ra is not None:
self._apply_radec_units()

# Check that all columns have the same length. (This is impossible in file input)
if self._x is not None:
Expand Down Expand Up @@ -949,9 +953,6 @@ def _finish_input(self):
self.checkForNaN(self._w,'w')
self.checkForNaN(self._wpos,'wpos')

# Apply units as appropriate
self._apply_units()

# If using ra/dec, generate x,y,z
# Note: This also makes self.ntot work properly.
self._generate_xyz()
Expand Down Expand Up @@ -1053,18 +1054,17 @@ def _get_patch_index(self, single_patch):
use = slice(None) # Which ironically means use all. :)
return use

def _apply_units(self):
# Apply units to x,y,ra,dec
if self._ra is not None:
self.ra_units = get_from_list(self.config,'ra_units',self._num)
self.dec_units = get_from_list(self.config,'dec_units',self._num)
self._ra *= self.ra_units
self._dec *= self.dec_units
else:
self.x_units = get_from_list(self.config,'x_units',self._num,str, 'radians')
self.y_units = get_from_list(self.config,'y_units',self._num,str, 'radians')
self._x *= self.x_units
self._y *= self.y_units
def _apply_radec_units(self):
self.ra_units = get_from_list(self.config,'ra_units',self._num)
self.dec_units = get_from_list(self.config,'dec_units',self._num)
self._ra *= self.ra_units
self._dec *= self.dec_units

def _apply_xyz_units(self):
self.x_units = get_from_list(self.config,'x_units',self._num,str, 'radians')
self.y_units = get_from_list(self.config,'y_units',self._num,str, 'radians')
self._x *= self.x_units
self._y *= self.y_units

def _generate_xyz(self):
if self._x is None:
Expand Down Expand Up @@ -1304,6 +1304,7 @@ def set_pos(data, x_col, y_col, z_col, ra_col, dec_col, r_col):
if z_col != '0':
self._z = data[z_col].astype(float)
self.logger.debug('read z')
self._apply_xyz_units()
if ra_col != '0' and ra_col in data:
self._ra = data[ra_col].astype(float)
self.logger.debug('read ra')
Expand All @@ -1312,6 +1313,7 @@ def set_pos(data, x_col, y_col, z_col, ra_col, dec_col, r_col):
if r_col != '0':
self._r = data[r_col].astype(float)
self.logger.debug('read r')
self._apply_radec_units()

def set_patch(data, patch_col):
if patch_col != '0' and patch_col in data:
Expand Down Expand Up @@ -1413,6 +1415,7 @@ def set_patch(data, patch_col):
for c in use_cols1:
data[c] = data1[c]
set_pos(data, x_col, y_col, z_col, ra_col, dec_col, r_col)
x_col = y_col = z_col = ra_col = dec_col = r_col = '0'
use = self._get_patch_index(self._single_patch)
self.select(use)
if isinstance(s,np.ndarray):
Expand Down Expand Up @@ -1976,7 +1979,10 @@ def write_patches(self, save_patch_dir=None):
file_names = self.get_patch_file_names(save_patch_dir)
for i, p, file_name in zip(range(self.npatch), self.patches, file_names):
self.logger.info('Writing patch %d to %s',i,file_name)
cols = p.write(file_name)
if p.ra is not None:
# Don't multiply and divide by the units on round trip.
p.ra_units = p.dec_units = 1
p.write(file_name)

def read_patches(self, save_patch_dir=None):
"""Read the patches from files on disk.
Expand Down

0 comments on commit 47434de

Please sign in to comment.