+# Copyright (c) 2003-2019 by Mike Jarvis
+#
+# TreeCorr is free software: redistribution and use in source and binary forms,
+# with or without modification, are permitted provided that the following
+# conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions, and the disclaimer given in the accompanying LICENSE
+# file.
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions, and the disclaimer given in the documentation
+# and/or other materials provided with the distribution.
+
+"""
+.. module:: util
+"""
+
+import numpy as np
+import os
+import coord
+import functools
+import inspect
+import warnings
+
+from . import _lib, _ffi, Rperp_alias
+from .writer import AsciiWriter, FitsWriter, HdfWriter
+from .reader import AsciiReader, FitsReader, HdfReader
+
+max_omp_threads=None
+
+[docs]def set_max_omp_threads(num_threads, logger=None):
+
"""Set the maximum allowed number of OpenMP threads to use in the C++ layer in any
+
further TreeCorr functions
+
+
:param num_threads: The target maximum number of threads to allow. None means no limit.
+
:param logger: If desired, a logger object for logging any warnings here. (default: None)
+
"""
+
global max_omp_threads
+
max_omp_threads=num_threads
+
+[docs]def set_omp_threads(num_threads, logger=None):
+
"""Set the number of OpenMP threads to use in the C++ layer.
+
+
:param num_threads: The target number of threads to use
+
:param logger: If desired, a logger object for logging any warnings here. (default: None)
+
+
:returns: The number of threads OpenMP reports that it will use. Typically this
+
matches the input, but OpenMP reserves the right not to comply with
+
the requested number of threads.
+
"""
+
input_num_threads = num_threads # Save the input value.
+
+
# If num_threads is auto, get it from cpu_count
+
if num_threads is None or num_threads <= 0:
+
import multiprocessing
+
num_threads = multiprocessing.cpu_count()
+
if logger:
+
logger.debug('multiprocessing.cpu_count() = %d',num_threads)
+
+
# Max at max_omp_threads, if set.
+
if max_omp_threads is not None and num_threads > max_omp_threads:
+
num_threads = max_omp_threads
+
if logger:
+
logger.debug('max_omp_threads = %d',max_omp_threads)
+
+
# Tell OpenMP to use this many threads
+
if logger:
+
logger.debug('Telling OpenMP to use %d threads',num_threads)
+
num_threads = _lib.SetOMPThreads(num_threads)
+
+
# Report back appropriately.
+
if logger:
+
logger.debug('OpenMP reports that it will use %d threads',num_threads)
+
if num_threads > 1:
+
logger.info('Using %d threads.',num_threads)
+
elif input_num_threads is not None and input_num_threads != 1:
+
# Only warn if the user specifically asked for num_threads != 1.
+
logger.warning("Unable to use multiple threads, since OpenMP is not enabled.")
+
+
return num_threads
+
+[docs]def get_omp_threads():
+
"""Get the number of OpenMP threads currently set to be used in the C++ layer.
+
+
:returns: The number of threads OpenMP reports that it will use.
+
"""
+
return _lib.GetOMPThreads()
+
+def parse_file_type(file_type, file_name, output=False, logger=None):
+ """Parse the file_type from the file_name if necessary
+
+ :param file_type: The input file_type. If None, then parse from file_name's extension.
+ :param file_name: The filename to use for parsing if necessary.
+ :param output: Limit to output file types (FITS/ASCII)? (default: False)
+ :param logger: A logger if desired. (default: None)
+
+ :returns: The parsed file_type.
+ """
+ if file_type is None:
+ import os
+ name, ext = os.path.splitext(file_name)
+ if ext.lower().startswith('.fit'):
+ file_type = 'FITS'
+ elif ext.lower().startswith('.hdf'):
+ file_type = 'HDF'
+ elif not output and ext.lower().startswith('.par'):
+ file_type = 'Parquet'
+ else:
+ file_type = 'ASCII'
+ if logger:
+ logger.info(" file_type assumed to be %s from the file name.",file_type)
+ return file_type.upper()
+
+def make_writer(file_name, precision=4, file_type=None, logger=None):
+ """Factory function to make a writer instance of the correct type.
+ """
+ # Figure out which file type to use.
+ file_type = parse_file_type(file_type, file_name, output=True, logger=logger)
+ if file_type == 'FITS':
+ writer = FitsWriter(file_name, logger=logger)
+ elif file_type == 'HDF':
+ writer = HdfWriter(file_name, logger=logger)
+ elif file_type == 'ASCII':
+ writer = AsciiWriter(file_name, precision=precision, logger=logger)
+ else:
+ raise ValueError("Invalid file_type %s"%file_type)
+ return writer
+
+def make_reader(file_name, file_type=None, logger=None):
+ """Factory function to make a writer instance of the correct type.
+ """
+ # Figure out which file type to use.
+ file_type = parse_file_type(file_type, file_name, output=False, logger=logger)
+
+ if file_type == 'FITS':
+ reader = FitsReader(file_name, logger=logger)
+ elif file_type == 'HDF':
+ reader = HdfReader(file_name, logger=logger)
+ elif file_type == 'ASCII':
+ reader = AsciiReader(file_name, logger=logger)
+ else:
+ raise ValueError("Invalid file_type %s"%file_type)
+ return reader
+
+class LRU_Cache(object):
+ """ Simplified Least Recently Used Cache.
+ Mostly stolen from http://code.activestate.com/recipes/577970-simplified-lru-cache/,
+ but added a method for dynamic resizing. The least recently used cached item is
+ overwritten on a cache miss.
+
+ Note: This has additional functionality beyond what functools.lru_cache provides.
+ 1. The ability to resize the maxsize non-destructively.
+ 2. The key is only on the args, not kwargs, so a logger can be provided as a kwarg
+ without triggering a cache miss.
+
+ :param user_function: A python function to cache.
+ :param maxsize: Maximum number of inputs to cache. [Default: 1024]
+
+ Usage
+ -----
+ >>> def slow_function(*args) # A slow-to-evaluate python function
+ >>> ...
+ >>>
+ >>> v1 = slow_function(*k1) # Calling function is slow
+ >>> v1 = slow_function(*k1) # Calling again with same args is still slow
+ >>> cache = galsim.utilities.LRU_Cache(slow_function)
+ >>> v1 = cache(*k1) # Returns slow_function(*k1), slowly the first time
+ >>> v1 = cache(*k1) # Returns slow_function(*k1) again, but fast this time.
+
+ Methods
+ -------
+ >>> cache.resize(maxsize) # Resize the cache, either upwards or downwards. Upwards resizing
+ # is non-destructive. Downwards resizing will remove the least
+ # recently used items first.
+ """
+ def __init__(self, user_function, maxsize=1024):
+ # Link layout: [PREV, NEXT, KEY, RESULT]
+ self.root = [None, None, None, None]
+ self.user_function = user_function
+ self.cache = {}
+
+ last = self.root
+ for i in range(maxsize):
+ key = object()
+ self.cache[key] = last[1] = last = [last, self.root, key, None]
+ self.root[0] = last
+ self.count = 0
+
+ def __call__(self, *key, **kwargs):
+ link = self.cache.get(key)
+ if link is not None:
+ # Cache hit: move link to last position
+ link_prev, link_next, _, result = link
+ link_prev[1] = link_next
+ link_next[0] = link_prev
+ last = self.root[0]
+ last[1] = self.root[0] = link
+ link[0] = last
+ link[1] = self.root
+ return result
+ # Cache miss: evaluate and insert new key/value at root, then increment root
+ # so that just-evaluated value is in last position.
+ result = self.user_function(*key, **kwargs)
+ self.root[2] = key
+ self.root[3] = result
+ oldroot = self.root
+ self.root = self.root[1]
+ oldkey = self.root[2]
+ self.root[2] = None
+ self.root[3] = None
+ self.cache[key] = oldroot
+ del self.cache[oldkey]
+ if self.count < self.size: self.count += 1
+ return result
+
+ def values(self):
+ """Lists all items stored in the cache"""
+ return list([v[3] for v in self.cache.values() if v[3] is not None])
+
+ @property
+ def last_value(self):
+ """Return the most recently used value"""
+ return self.root[0][3]
+
+ def resize(self, maxsize):
+ """ Resize the cache. Increasing the size of the cache is non-destructive, i.e.,
+ previously cached inputs remain in the cache. Decreasing the size of the cache will
+ necessarily remove items from the cache if the cache is already filled. Items are removed
+ in least recently used order.
+
+ :param maxsize: The new maximum number of inputs to cache.
+ """
+ oldsize = len(self.cache)
+ if maxsize == oldsize:
+ return
+ else:
+ if maxsize < 0:
+ raise ValueError("Invalid maxsize")
+ elif maxsize < oldsize:
+ for i in range(oldsize - maxsize):
+ # Delete root.next
+ current_next_link = self.root[1]
+ new_next_link = self.root[1] = self.root[1][1]
+ new_next_link[0] = self.root
+ del self.cache[current_next_link[2]]
+ self.count = min(self.count, maxsize)
+ else: # maxsize > oldsize:
+ for i in range(maxsize - oldsize):
+ # Insert between root and root.next
+ key = object()
+ self.cache[key] = link = [self.root, self.root[1], key, None]
+ self.root[1][0] = link
+ self.root[1] = link
+
+ def clear(self):
+ """ Clear all items from the cache.
+ """
+ maxsize = len(self.cache)
+ self.cache.clear()
+ last = self.root
+ for i in range(maxsize):
+ last[3] = None # Sever pointer to any existing result.
+ key = object()
+ self.cache[key] = last[1] = last = [last, self.root, key, None]
+ self.root[0] = last
+ self.count = 0
+
+ @property
+ def size(self):
+ return len(self.cache)
+
+def double_ptr(x):
+ """
+ Cast x as a double* to pass to library C functions
+
+ :param x: A numpy array assumed to have dtype = float.
+
+ :returns: A version of the array that can be passed to cffi C functions.
+ """
+ if x is None:
+ return _ffi.cast('double*', 0)
+ else:
+ # This fails if x is read_only
+ #return _ffi.cast('double*', _ffi.from_buffer(x))
+ # This works, presumably by ignoring the numpy read_only flag. Although, I think it's ok.
+ return _ffi.cast('double*', x.ctypes.data)
+
+def long_ptr(x):
+ """
+ Cast x as a long* to pass to library C functions
+
+ :param x: A numpy array assumed to have dtype = int.
+
+ :returns: A version of the array that can be passed to cffi C functions.
+ """
+ if x is None: # pragma: no cover (I don't ever have x=None for this one.)
+ return _ffi.cast('long*', 0)
+ else:
+ return _ffi.cast('long*', x.ctypes.data)
+
+def parse_metric(metric, coords, coords2=None, coords3=None):
+ """
+ Convert a string metric into the corresponding enum to pass to the C code.
+ """
+ if coords2 is None:
+ auto = True
+ else:
+ auto = False
+ # Special Rlens doesn't care about the distance to the sources, so spherical is fine
+ # for cat2, cat3 in that case.
+ if metric == 'Rlens':
+ if coords2 == 'spherical': coords2 = '3d'
+ if coords3 == 'spherical': coords3 = '3d'
+
+ if metric == 'Arc':
+ # If all coords are 3d, then leave it 3d, but if any are spherical,
+ # then convert to spherical.
+ if all([c in [None, '3d'] for c in [coords, coords2, coords3]]):
+ # Leave coords as '3d'
+ pass
+ elif any([c not in [None, 'spherical', '3d'] for c in [coords, coords2, coords3]]):
+ raise ValueError("Arc metric is only valid for catalogs with spherical positions.")
+ elif any([c == 'spherical' for c in [coords, coords2, coords3]]): # pragma: no branch
+ # Switch to spherical
+ coords = 'spherical'
+ else: # pragma: no cover
+ # This is impossible now, but here in case we add additional coordinates.
+ raise ValueError("Cannot correlate catalogs with different coordinate systems.")
+ else:
+ if ( (coords2 != coords) or (coords3 is not None and coords3 != coords) ):
+ raise ValueError("Cannot correlate catalogs with different coordinate systems.")
+
+ if coords not in ['flat', 'spherical', '3d']:
+ raise ValueError("Invalid coords %s"%coords)
+
+ if metric not in ['Euclidean', 'Rperp', 'OldRperp', 'FisherRperp', 'Rlens', 'Arc', 'Periodic']:
+ raise ValueError("Invalid metric %s"%metric)
+
+ if metric in ['Rperp', 'OldRperp', 'FisherRperp'] and coords != '3d':
+ raise ValueError("%s metric is only valid for catalogs with 3d positions."%metric)
+ if metric == 'Rlens' and auto:
+ raise ValueError("Rlens metric is only valid for cross correlations.")
+ if metric == 'Rlens' and coords != '3d':
+ raise ValueError("Rlens metric is only valid for catalogs with 3d positions.")
+ if metric == 'Arc' and coords not in ['spherical', '3d']:
+ raise ValueError("Arc metric is only valid for catalogs with spherical positions.")
+
+ return coords, metric
+
+def coord_enum(coords):
+ """Return the C++-layer enum for the given string value of coords.
+ """
+ if coords == 'flat':
+ return _lib.Flat
+ elif coords == 'spherical':
+ return _lib.Sphere
+ elif coords == '3d':
+ return _lib.ThreeD
+ else:
+ raise ValueError("Invalid coords %s"%coords)
+
+def metric_enum(metric):
+ """Return the C++-layer enum for the given string value of metric.
+ """
+ if metric == 'Euclidean':
+ return _lib.Euclidean
+ elif metric == 'Rperp':
+ return metric_enum(Rperp_alias)
+ elif metric == 'FisherRperp':
+ return _lib.Rperp
+ elif metric in ['OldRperp']:
+ return _lib.OldRperp
+ elif metric == 'Rlens':
+ return _lib.Rlens
+ elif metric == 'Arc':
+ return _lib.Arc
+ elif metric == 'Periodic':
+ return _lib.Periodic
+ else:
+ raise ValueError("Invalid metric %s"%metric)
+
+def parse_xyzsep(args, kwargs, _coords):
+ """Parse the different options for passing a coordinate and separation.
+
+ The allowed parameters are:
+
+ 1. If _coords == Flat:
+
+ :param x: The x coordinate of the location for which to count nearby points.
+ :param y: The y coordinate of the location for which to count nearby points.
+ :param sep: The separation distance
+
+ 2. If _coords == ThreeD:
+
+ Either
+ :param x: The x coordinate of the location for which to count nearby points.
+ :param y: The y coordinate of the location for which to count nearby points.
+ :param z: The z coordinate of the location for which to count nearby points.
+ :param sep: The separation distance
+
+ Or
+ :param ra: The right ascension of the location for which to count nearby points.
+ :param dec: The declination of the location for which to count nearby points.
+ :param r: The distance to the location for which to count nearby points.
+ :param sep: The separation distance
+
+ 3. If _coords == Sphere:
+
+ :param ra: The right ascension of the location for which to count nearby points.
+ :param dec: The declination of the location for which to count nearby points.
+ :param sep: The separation distance as an angle
+
+ For all angle parameters (ra, dec, sep), this quantity may be a coord.Angle instance, or
+ units maybe be provided as ra_units, dec_units or sep_units respectively.
+
+ Finally, in cases where ra, dec are allowed, a coord.CelestialCoord instance may be
+ provided as the first argument.
+
+ :returns: The effective (x, y, z, sep) as a tuple.
+ """
+ radec = False
+ if _coords == _lib.Flat:
+ if len(args) == 0:
+ if 'x' not in kwargs:
+ raise TypeError("Missing required argument x")
+ if 'y' not in kwargs:
+ raise TypeError("Missing required argument y")
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ x = kwargs.pop('x')
+ y = kwargs.pop('y')
+ sep = kwargs.pop('sep')
+ elif len(args) == 1:
+ raise TypeError("x,y should be given as either args or kwargs, not mixed.")
+ elif len(args) == 2:
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ x,y = args
+ sep = kwargs.pop('sep')
+ elif len(args) == 3:
+ x,y,sep = args
+ else:
+ raise TypeError("Too many positional args")
+ z = 0
+
+ elif _coords == _lib.ThreeD:
+ if len(args) == 0:
+ if 'x' in kwargs:
+ if 'y' not in kwargs:
+ raise TypeError("Missing required argument y")
+ if 'z' not in kwargs:
+ raise TypeError("Missing required argument z")
+ x = kwargs.pop('x')
+ y = kwargs.pop('y')
+ z = kwargs.pop('z')
+ else:
+ if 'ra' not in kwargs:
+ raise TypeError("Missing required argument ra")
+ if 'dec' not in kwargs:
+ raise TypeError("Missing required argument dec")
+ ra = kwargs.pop('ra')
+ dec = kwargs.pop('dec')
+ radec = True
+ if 'r' not in kwargs:
+ raise TypeError("Missing required argument r")
+ r = kwargs.pop('r')
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 1:
+ if not isinstance(args[0], coord.CelestialCoord):
+ raise TypeError("Invalid unnamed argument %r"%args[0])
+ ra = args[0].ra
+ dec = args[0].dec
+ radec = True
+ if 'r' not in kwargs:
+ raise TypeError("Missing required argument r")
+ r = kwargs.pop('r')
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 2:
+ if isinstance(args[0], coord.CelestialCoord):
+ ra = args[0].ra
+ dec = args[0].dec
+ radec = True
+ r = args[1]
+ else:
+ ra, dec = args
+ radec = True
+ if 'r' not in kwargs:
+ raise TypeError("Missing required argument r")
+ r = kwargs.pop('r')
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 3:
+ if isinstance(args[0], coord.CelestialCoord):
+ ra = args[0].ra
+ dec = args[0].dec
+ radec = True
+ r = args[1]
+ sep = args[2]
+ elif isinstance(args[0], coord.Angle):
+ ra, dec, r = args
+ radec = True
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif 'ra_units' in kwargs or 'dec_units' in kwargs:
+ ra, dec, r = args
+ radec = True
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ else:
+ x, y, z = args
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 4:
+ if isinstance(args[0], coord.Angle):
+ ra, dec, r, sep = args
+ radec = True
+ elif 'ra_units' in kwargs or 'dec_units' in kwargs:
+ ra, dec, r, sep = args
+ radec = True
+ else:
+ x, y, z, sep = args
+ else:
+ raise TypeError("Too many positional args")
+
+ else: # Sphere
+ if len(args) == 0:
+ if 'ra' not in kwargs:
+ raise TypeError("Missing required argument ra")
+ if 'dec' not in kwargs:
+ raise TypeError("Missing required argument dec")
+ ra = kwargs.pop('ra')
+ dec = kwargs.pop('dec')
+ radec = True
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 1:
+ if not isinstance(args[0], coord.CelestialCoord):
+ raise TypeError("Invalid unnamed argument %r"%args[0])
+ ra = args[0].ra
+ dec = args[0].dec
+ radec = True
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 2:
+ if isinstance(args[0], coord.CelestialCoord):
+ ra = args[0].ra
+ dec = args[0].dec
+ radec = True
+ sep = args[1]
+ else:
+ ra, dec = args
+ radec = True
+ if 'sep' not in kwargs:
+ raise TypeError("Missing required argument sep")
+ sep = kwargs.pop('sep')
+ elif len(args) == 3:
+ ra, dec, sep = args
+ radec = True
+ else:
+ raise TypeError("Too many positional args")
+ if not isinstance(sep, coord.Angle):
+ if 'sep_units' not in kwargs:
+ raise TypeError("Missing required argument sep_units")
+ sep = sep * coord.AngleUnit.from_name(kwargs.pop('sep_units'))
+ # We actually want the chord distance for this angle.
+ sep = 2. * np.sin(sep/2.)
+
+ if radec:
+ if not isinstance(ra, coord.Angle):
+ if 'ra_units' not in kwargs:
+ raise TypeError("Missing required argument ra_units")
+ ra = ra * coord.AngleUnit.from_name(kwargs.pop('ra_units'))
+ if not isinstance(dec, coord.Angle):
+ if 'dec_units' not in kwargs:
+ raise TypeError("Missing required argument dec_units")
+ dec = dec * coord.AngleUnit.from_name(kwargs.pop('dec_units'))
+ x,y,z = coord.CelestialCoord(ra, dec).get_xyz()
+ if _coords == _lib.ThreeD:
+ x *= r
+ y *= r
+ z *= r
+ if len(kwargs) > 0:
+ raise TypeError("Invalid kwargs: %s"%(kwargs))
+
+ return float(x), float(y), float(z), float(sep)
+
+class lazy_property(object):
+ """
+ This decorator will act similarly to @property, but will be efficient for multiple access
+ to values that require some significant calculation.
+
+ It works by replacing the attribute with the computed value, so after the first access,
+ the property (an attribute of the class) is superseded by the new attribute of the instance.
+
+ Usage::
+
+ @lazy_property
+ def slow_function_to_be_used_as_a_property(self):
+ x = ... # Some slow calculation.
+ return x
+
+ Base on an answer from http://stackoverflow.com/a/6849299
+ This implementation taken from GalSim utilities.py
+ """
+ def __init__(self, fget):
+ self.fget = fget
+ self.func_name = fget.__name__
+
+ def __get__(self, obj, cls):
+ if obj is None:
+ return self
+ value = self.fget(obj)
+ setattr(obj, self.func_name, value)
+ return value
+
+
+
+def depr_pos_kwargs(fn):
+ """
+ This decorator will allow the old API where keywords are allowed as positional variables,
+ but it will give a deprecation warning about it.
+
+ @depr_pos_kwargs
+ def func_with_kwargs(a, *, b=3, c=4):
+ ...
+
+ # Expected usage:
+ func_with_kwargs(1, b=5, c=9)
+
+ # This works, but gives a deprecation warning
+ func_with_kwargs(1, 5, 9)
+ """
+ # Note: this is inspired by the legacy_api_wrap decorator by flying-sheep, which does something
+ # similar.
+ # https://github.com/flying-sheep/legacy-api-wrap/blob/master/legacy_api_wrap.py
+ # However, it was reimplemented from scratch my MJ.
+
+ params = inspect.signature(fn).parameters
+ nparams = len(params)
+ npos = np.sum([p.kind in [p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD] for p in params.values()])
+ assert nparams > npos # Otherwise developer probably forgot to add the * to the signature!
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ if len(args) > npos:
+ # Make sure providing too many params is still a TypeError.
+ if len(args) > nparams:
+ raise TypeError("{} takes at most {} arguments but {} were given.".format(
+ fn.__name__, nparams, len(args)))
+
+ # Which names need to turn into kwargs?
+ kw_names = list(params.keys())[npos:len(args)]
+
+ # Warn about deprecated syntax
+ warnings.warn(
+ "Use of keyword-only arguments as positional arguments is deprecated in "+
+ "the function " + fn.__name__ + ". " +
+ "The following parameters now require an explicit keyword name: "+
+ str(kw_names), FutureWarning)
+
+ # But make it work.
+ for a, n in zip(args[npos:], kw_names):
+ kwargs[n] = a
+ args = args[:npos]
+
+ return fn(*args, **kwargs)
+
+ return wrapper
+