diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 00000000..ce436f57
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,21 @@
+root = true
+
+[*]
+indent_style = space
+indent_size = 4
+insert_final_newline = true
+trim_trailing_whitespace = true
+charset = utf-8
+end_of_line = lf
+
+[*.py]
+max_line_length = 120
+
+[*.md]
+trim_trailing_whitespace = false
+
+[*.yml]
+indent_size = 2
+
+[Makefile]
+indent_style = tab
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index d418e864..a7cb3c3d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -13,10 +13,10 @@ jobs:
       fail-fast: false
       matrix:
         python-version: [
-          '3.8',
           '3.9',
           '3.10',
           '3.11',
+          '3.12',
         ]
     runs-on: ubuntu-latest
     steps:
@@ -30,19 +30,19 @@ jobs:
       - run: make unittests
       - run: coverage report
 
-  flake:
+  lint:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
       - uses: actions/setup-python@v5
         with:
           python-version: '3.11'
-      - run: pip install flake8 flake8-isort pep8-naming
-      - run: make flaketest
+      - run: pip install ruff
+      - run: make linttest
 
   docs:
     if: github.ref == 'refs/heads/master'
-    needs: [tests, flake]
+    needs: [tests, lint]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
diff --git a/Makefile b/Makefile
index ac10f96c..118ad2fd 100644
--- a/Makefile
+++ b/Makefile
@@ -1,16 +1,23 @@
-.PHONY: test unittests flaketest doctest update_data
-
+.PHONY: test
-test: unittests flaketest doctest
+test: unittests linttest doctest
 
+.PHONY: unittests
 unittests:
 	coverage run -m unittest
 
-flaketest:
-	flake8 sapphire
+.PHONY: linttest
+linttest:
+	ruff check .
+
+.PHONY: lintfix
+lintfix:
+	ruff check --fix-only .
 
+.PHONY: doctest
 doctest:
 	sphinx-build -anW doc doc/_build/html
 
+.PHONY: update_data
 update_data:
 ifeq ($(strip $(shell git status --porcelain | wc -l)), 0)
 	@echo "Updating local data. Creating test data to match local data and committing."
diff --git a/doc/conf.py b/doc/conf.py
index b8e46f08..325185ff 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -9,13 +9,11 @@
 # All configuration values have a default; values that are commented out
 # serve to show the default.
 
-import os
-import sys
 
 # If extensions (or modules to document with autodoc) are in another directory,
 # add these directories to sys.path here. If the directory is relative to the
 # documentation root, use os.path.abspath to make it absolute, like shown here.
-#sys.path.insert(0, os.path.abspath('.'))
+# sys.path.insert(0, os.path.abspath('.'))
 
 # Often reused
 AUTHORS = 'David Fokkema, Arne de Laat, and Tom Kooij'
@@ -83,27 +81,26 @@
 # Output file base name for HTML help builder.
 htmlhelp_basename = 'SAPPHiREdoc'
 
+
 def setup(app):
     app.add_css_file('hisparc_style.css')
 
+
 # -- Options for LaTeX output --------------------------------------------------
 
 latex_elements = {
-# The paper size ('letterpaper' or 'a4paper').
-#'papersize': 'letterpaper',
-
-# The font size ('10pt', '11pt' or '12pt').
-#'pointsize': '10pt',
-
-# Additional stuff for the LaTeX preamble.
-#'preamble': '',
+    # The paper size ('letterpaper' or 'a4paper').
+    #'papersize': 'letterpaper',
+    # The font size ('10pt', '11pt' or '12pt').
+    #'pointsize': '10pt',
+    # Additional stuff for the LaTeX preamble.
+    #'preamble': '',
 }
 
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [ - ('index', 'SAPPHiRE.tex', 'SAPPHiRE Documentation', - AUTHORS, 'manual'), + ('index', 'SAPPHiRE.tex', 'SAPPHiRE Documentation', AUTHORS, 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -116,8 +113,7 @@ def setup(app): # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'sapphire', 'SAPPHiRE Documentation', - [AUTHORS], 1) + ('index', 'sapphire', 'SAPPHiRE Documentation', [AUTHORS], 1), ] @@ -127,10 +123,15 @@ def setup(app): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'SAPPHiRE', 'SAPPHiRE Documentation', - AUTHORS, 'SAPPHiRE', - 'One line description of project.', - 'Miscellaneous'), + ( + 'index', + 'SAPPHiRE', + 'SAPPHiRE Documentation', + AUTHORS, + 'SAPPHiRE', + 'One line description of project.', + 'Miscellaneous', + ), ] diff --git a/doc/scripts/is_useful.py b/doc/scripts/is_useful.py index 6295c07a..9d04fb37 100644 --- a/doc/scripts/is_useful.py +++ b/doc/scripts/is_useful.py @@ -1,5 +1,6 @@ def square(x): - return x ** 2 + return x**2 + print(1, 2, 3) print(square(1), square(2), square(3)) diff --git a/doc/scripts/is_useful_and_importable.py b/doc/scripts/is_useful_and_importable.py index 7f8f1f47..f1d2b5d9 100644 --- a/doc/scripts/is_useful_and_importable.py +++ b/doc/scripts/is_useful_and_importable.py @@ -1,5 +1,5 @@ def square(x): - return x ** 2 + return x**2 if __name__ == '__main__': diff --git a/doc/scripts/plot_zenith_distribution.py b/doc/scripts/plot_zenith_distribution.py index 26d20125..37ac82be 100644 --- a/doc/scripts/plot_zenith_distribution.py +++ b/doc/scripts/plot_zenith_distribution.py @@ -13,8 +13,8 @@ def plot_zenith_distribution(data): zenith = zenith.compress(-isnan(zenith)) plt.hist(degrees(zenith), bins=linspace(0, 90, 51), histtype='step') - plt.xlabel("zenith [deg]") - plt.ylabel("count") + plt.xlabel('zenith [deg]') + plt.ylabel('count') if __name__ == '__main__': diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..42d49b14 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,189 @@ +[build-system] +requires = ['flit_core>=3.9'] +build-backend = 'flit_core.buildapi' + +[project] +name = 'hisparc-sapphire' +version = '2.0.0' +description = 'A framework for the HiSPARC experiment' +readme = 'README.rst' +requires-python = '>=3.9' +license = {file = 'LICENSE'} +authors = [ + {name = 'Arne de Laat', email = 'arne@delaat.net'}, + {name = 'David Fokkema'}, + {name = 'Tom Kooij'}, +] +maintainers = [ + {name = 'Arne de Laat', email = 'arne@delaat.net'}, +] +keywords = [ + 'cosmic rays', + 'detectors', + 'astrophysics', + 'HiSPARC', + 'Nikhef', + 'University of Utah', +] +classifiers = [ + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python', + 'Topic :: Education', + 'Topic :: Scientific/Engineering', +] +dependencies = [ + 'numpy>=1.25.2', + 'scipy>=1.13.0', + 'tables>=3.9.2', + 'progressbar2>=4.4.2', +] + +[project.optional-dependencies] +dev = [ + 'Sphinx', + 'coverage==7.4.4', + 'ruff==0.4.1', +] +astropy = [ + 'astropy>=5.0.0', +] +publish = [ + 'flit==3.9.0', +] + +[project.urls] +Homepage = 'https://data.hisparc.nl' +Documentation = 'https://docs.hisparc.nl/sapphire/' +Repository = 
'https://github.com/hisparc/sapphire/'
+Issues = 'https://github.com/HiSPARC/sapphire/issues'
+
+[project.scripts]
+create_and_store_test_data = 'sapphire.tests.create_and_store_test_data:main'
+update_local_data = 'sapphire.data.update_local_data:main'
+extend_local_data = 'sapphire.data.extend_local_data:main'
+
+[tool.flit.module]
+name = 'sapphire'
+
+[tool.ruff]
+line-length = 120
+target-version = 'py39'
+extend-exclude = [
+    'doc/',
+    'scripts/',
+]
+
+[tool.ruff.format]
+quote-style = 'single'  # Prefer single quotes, except for triple quotes strings
+
+[tool.ruff.lint]
+select = ['ALL']  # https://docs.astral.sh/ruff/rules/
+ignore = [
+    'ANN',  # No type annotations yet
+    'ARG002',  # Allow unused arguments to keep signature equal for similar functions
+    'B028',  # Allow warnings.warn without stacklevel
+    'B904',  # Allow skipping causes for exceptions
+    'BLE001',  # Allow catching Exception
+    'CPY001',  # Do not require copyright notices
+    'D',  # Ignore docstring checks
+    'DTZ',  # Ignore timezones for datetime
+    'E731',  # Allow assigning lambda to variable
+    'EM',  # Allow messages directly in exceptions
+    'ERA001',  # Allow commented code
+    'F405',  # Allow star imports
+    'FBT001',  # Allow positional for boolean arguments
+    'FBT002',  # Allow default value for boolean arguments
+    'PD',  # Not using pandas
+    'PERF203',  # Allow try-except in loop
+    'PLR0913',  # Allow functions with many arguments
+    'PLR6301',  # Allow not using self in methods
+    'PT',  # Not using pytest
+    'PTH',  # Allow using os.path
+    'RET504',  # Allow variable assignment before return
+    'RET505',  # Allow elif after return
+    'RET506',  # Allow elif after raise
+    'RET507',  # Allow elif after continue
+    'RET508',  # Allow elif after break
+    'S311',  # Allow using random for non-cryptographic purposes
+    'SIM108',  # Allow if-else block instead of ternary
+    'SIM117',  # Allow separate with statements to preserve context
+    'SLF001',  # Allow accessing private members
+    'T201',  # Allow using print
+    'TD002',  # Allow TODO without author
+    'TD003',  # Allow TODO without issue link
+    'TID252',  # Allow relative imports
+    'TRY003',  # Specific messages for common exception classes
+
+    # TODO these still need to be checked (i.e. ignored or fixed)
+    'ARG001',
+    'ARG003',
+    'ARG005',
+    'B018',
+    'C901',
+    'FBT003',
+    'FIX002',
+    'G002',
+    'G004',
+    'NPY002',
+    'PLR0911',
+    'PLR0912',
+    'PLR0915',
+    'PLR2004',
+    'PLW2901',
+    'RET503',
+    'RUF005',
+    'RUF012',
+    'S103',
+    'S108',
+    'S310',
+    'S602',
+    'S603',
+    'S604',
+    'S607',
+    'SIM102',
+    'SIM105',
+    'SIM115',
+    'TRY301',
+    'TRY400',
+    'UP031',
+]
+
+[tool.ruff.lint.per-file-ignores]
+'sapphire/corsika/units.py' = ['N816']  # Allow mixed case variables
+'sapphire/esd.py' = ['A002']  # Some keyword arguments shadow a builtin
+'sapphire/kascade.py' = ['N806']  # Allow non lower case variables to match KASCADE
+'sapphire/storage.py' = ['N815']  # Allow mixed case variables
+'sapphire/tests/simulations/test_gammas.py' = ['N806']  # Allow upper case variables
+'sapphire/tests/transformations/test_celestial.py' = ['N806']  # Allow upper case variables
+
+[tool.ruff.lint.flake8-quotes]
+inline-quotes = 'single'
+
+[tool.ruff.lint.isort]
+lines-between-types = 1
+section-order = [
+    'future',
+    'standard-library',
+    'third-party',
+    'extras',
+    'first-party',
+    'local-folder',
+]
+
+[tool.ruff.lint.isort.sections]
+extras = ['artist', 'pylab']
+
+[tool.coverage.run]
+branch = true
+source = [
+    'sapphire',
+]
+
+[tool.coverage.report]
+show_missing = true
+skip_empty = true
+skip_covered = true
diff --git a/sapphire/README b/sapphire/README
deleted file mode 100644
index 2210604a..00000000
--- a/sapphire/README
+++ /dev/null
@@ -1,2 +0,0 @@
-Simulation and Analysis Program Package for HiSPARC Research (SAPPHiRE)
-=======================================================================
diff --git a/sapphire/README.md b/sapphire/README.md
new file mode 100644
index 00000000..d5235a32
--- /dev/null
+++ b/sapphire/README.md
@@ -0,0 +1 @@
+# Simulation and Analysis Program Package for HiSPARC Research (SAPPHiRE)
diff --git a/sapphire/__init__.py b/sapphire/__init__.py
index 75529389..7c8e1fe8 100644
--- a/sapphire/__init__.py
+++ b/sapphire/__init__.py
@@ -54,6 +54,7 @@
     commonly used functions such as a progressbar
 
 """
+
 from .
import ( analysis, api, @@ -100,44 +101,61 @@ from .tests import run_tests from .transformations.celestial import zenithazimuth_to_equatorial from .transformations.clock import datetime_to_gps, gps_to_datetime -from .version import __version__ # noqa - -__all__ = ['analysis', - 'api', - 'clusters', - 'corsika', - 'data', - 'esd', - 'kascade', - 'publicdb', - 'qsub', - 'simulations', - 'storage', - 'time_util', - 'transformations', - 'utils', - 'determine_detector_timing_offsets', - 'DetermineStationTimingOffsets', - 'CoincidenceQuery', - 'Coincidences', 'CoincidencesESD', - 'FindMostProbableValueInSpectrum', - 'ProcessEvents', 'ProcessEventsFromSource', - 'ProcessEventsFromSourceWithTriggerOffset', - 'ProcessWeather', 'ProcessWeatherFromSource', - 'ProcessSingles', 'ProcessSinglesFromSource', - 'TraceObservables', 'MeanFilter', 'DataReduction', - 'ReconstructESDEvents', 'ReconstructESDEventsFromSource', - 'ReconstructESDCoincidences', - 'ProcessTimeDeltas', - 'Network', 'Station', - 'HiSPARCStations', 'HiSPARCNetwork', 'ScienceParkCluster', - 'CorsikaQuery', - 'quick_download', 'load_data', 'download_data', - 'download_lightning', 'download_coincidences', - 'GroundParticlesSimulation', 'MultipleGroundParticlesSimulation', - 'KascadeLdfSimulation', 'NkgLdfSimulation', - 'FlatFrontSimulation', 'ConeFrontSimulation', - 'run_tests', - 'zenithazimuth_to_equatorial', - 'gps_to_datetime', 'datetime_to_gps' - ] + +__all__ = [ + 'analysis', + 'api', + 'clusters', + 'corsika', + 'data', + 'esd', + 'kascade', + 'publicdb', + 'qsub', + 'simulations', + 'storage', + 'time_util', + 'transformations', + 'utils', + 'determine_detector_timing_offsets', + 'DetermineStationTimingOffsets', + 'CoincidenceQuery', + 'Coincidences', + 'CoincidencesESD', + 'FindMostProbableValueInSpectrum', + 'ProcessEvents', + 'ProcessEventsFromSource', + 'ProcessEventsFromSourceWithTriggerOffset', + 'ProcessWeather', + 'ProcessWeatherFromSource', + 'ProcessSingles', + 'ProcessSinglesFromSource', + 'TraceObservables', + 'MeanFilter', + 'DataReduction', + 'ReconstructESDEvents', + 'ReconstructESDEventsFromSource', + 'ReconstructESDCoincidences', + 'ProcessTimeDeltas', + 'Network', + 'Station', + 'HiSPARCStations', + 'HiSPARCNetwork', + 'ScienceParkCluster', + 'CorsikaQuery', + 'quick_download', + 'load_data', + 'download_data', + 'download_lightning', + 'download_coincidences', + 'GroundParticlesSimulation', + 'MultipleGroundParticlesSimulation', + 'KascadeLdfSimulation', + 'NkgLdfSimulation', + 'FlatFrontSimulation', + 'ConeFrontSimulation', + 'run_tests', + 'zenithazimuth_to_equatorial', + 'gps_to_datetime', + 'datetime_to_gps', +] diff --git a/sapphire/analysis/__init__.py b/sapphire/analysis/__init__.py index d6b8cb65..a9a82f71 100644 --- a/sapphire/analysis/__init__.py +++ b/sapphire/analysis/__init__.py @@ -52,6 +52,7 @@ determine time deltas for station pairs """ + from . 
import ( calibration, coincidence_queries, @@ -67,15 +68,17 @@ time_deltas, ) -__all__ = ['calibration', - 'coincidence_queries', - 'coincidences', - 'core_reconstruction', - 'direction_reconstruction', - 'event_utils', - 'find_mpv', - 'landau', - 'process_events', - 'process_traces', - 'reconstructions', - 'time_deltas'] +__all__ = [ + 'calibration', + 'coincidence_queries', + 'coincidences', + 'core_reconstruction', + 'direction_reconstruction', + 'event_utils', + 'find_mpv', + 'landau', + 'process_events', + 'process_traces', + 'reconstructions', + 'time_deltas', +] diff --git a/sapphire/analysis/calibration.py b/sapphire/analysis/calibration.py index a3427ef0..e55472bd 100644 --- a/sapphire/analysis/calibration.py +++ b/sapphire/analysis/calibration.py @@ -1,4 +1,4 @@ -""" Determine calibration values for data +"""Determine calibration values for data This module can be used to determine calibration values from data. @@ -6,6 +6,7 @@ Determine the PMT response curve to correct the detected number of MIPs. """ + from datetime import datetime, timedelta from itertools import chain, combinations, tee @@ -39,24 +40,24 @@ def determine_detector_timing_offsets(events, station=None): z = [d.get_coordinates()[2] for d in station.detectors] else: n_detectors = 4 - z = [0., 0., 0., 0.] + z = [0.0, 0.0, 0.0, 0.0] - for id in range(n_detectors): - t.append(events.col('t%d' % (id + 1))) - filters.append((events.col('n%d' % (id + 1)) > 0.3) & (t[id] >= 0.)) + for detector_id in range(n_detectors): + t.append(events.col(f't{detector_id + 1}')) + filters.append((events.col('n%d' % (detector_id + 1)) > 0.3) & (t[detector_id] >= 0.0)) if n_detectors == 2: ref_id = 1 else: ref_id = determine_best_reference(filters) - for id in range(n_detectors): - if id == ref_id: - offsets[id] = 0. + for detector_id in range(n_detectors): + if detector_id == ref_id: + offsets[detector_id] = 0.0 continue - dt = (t[id] - t[ref_id]).compress(filters[id] & filters[ref_id]) - dz = z[id] - z[ref_id] - offsets[id], _ = determine_detector_timing_offset(dt, dz) + dt = (t[detector_id] - t[ref_id]).compress(filters[detector_id] & filters[ref_id]) + dz = z[detector_id] - z[ref_id] + offsets[detector_id], _ = determine_detector_timing_offset(dt, dz) # If all except reference are nan, make reference nan. if sum(isnan(offsets)) == 3: @@ -101,9 +102,14 @@ class DetermineStationTimingOffsets: # Minimum number of timedeltas required to attempt a fit MIN_LEN_DT = 200 - def __init__(self, stations=None, data=None, progress=False, - force_stale=False, - time_deltas_group='/coincidences/time_deltas'): + def __init__( + self, + stations=None, + data=None, + progress=False, + force_stale=False, + time_deltas_group='/coincidences/time_deltas', + ): """Initialize the class :param stations: list of stations for which to determine offsets. 
@@ -118,8 +124,7 @@ def __init__(self, stations=None, data=None, progress=False, self.force_stale = force_stale self.time_deltas_group = time_deltas_group if stations is not None: - self.cluster = HiSPARCStations(stations, skip_missing=True, - force_stale=self.force_stale) + self.cluster = HiSPARCStations(stations, skip_missing=True, force_stale=self.force_stale) else: self.cluster = HiSPARCNetwork(force_stale=self.force_stale) @@ -129,22 +134,19 @@ def read_dt(self, station, ref_station, start, end): pair = (ref_station, station) table_path = self.time_deltas_group + '/station_%d/station_%d' % pair table = self.data.get_node(table_path, 'time_deltas') - ts0 = datetime_to_gps(start) # noqa - ts1 = datetime_to_gps(end) # noqa - return table.read_where('(timestamp >= ts0) & (timestamp < ts1)', - field='delta') + ts0 = datetime_to_gps(start) # noqa: F841 + ts1 = datetime_to_gps(end) # noqa: F841 + return table.read_where('(timestamp >= ts0) & (timestamp < ts1)', field='delta') @memoize def _get_gps_timestamps(self, station): """Get timestamps of station gps changes""" - return Station(station, - force_stale=self.force_stale).gps_locations['timestamp'] + return Station(station, force_stale=self.force_stale).gps_locations['timestamp'] @memoize def _get_electronics_timestamps(self, station): """Get timestamps of station electronics (hardware) changes""" - return Station(station, - force_stale=self.force_stale).electronics['timestamp'] + return Station(station, force_stale=self.force_stale).electronics['timestamp'] def _get_cuts(self, station, ref_station): """Get cuts for determination of offsets @@ -156,13 +158,17 @@ def _get_cuts(self, station, ref_station): :return: list of datetime objects """ - cuts = {self._datetime(gps_to_datetime(ts)) - for ts in chain(self._get_gps_timestamps(station), - self._get_gps_timestamps(ref_station), - self._get_electronics_timestamps(station), - self._get_electronics_timestamps(ref_station))} + cuts = { + self._datetime(gps_to_datetime(ts)) + for ts in chain( + self._get_gps_timestamps(station), + self._get_gps_timestamps(ref_station), + self._get_electronics_timestamps(station), + self._get_electronics_timestamps(ref_station), + ) + } today = self._datetime(datetime.now()) - cuts = sorted(list(cuts) + [today]) + cuts = sorted([*list(cuts), today]) return cuts @memoize @@ -177,7 +183,8 @@ def _get_r_dz(self, date, station, ref_station): self.cluster.set_timestamp(datetime_to_gps(date)) r, _, dz = self.cluster.calc_rphiz_for_stations( self.cluster.get_station(ref_station).station_id, - self.cluster.get_station(station).station_id) + self.cluster.get_station(station).station_id, + ) return r, dz def _determine_interval(self, r): @@ -187,7 +194,7 @@ def _determine_interval(self, r): :return: number of days in interval. 
""" - return max(int(r ** 1.2 / 10), 7) + return max(int(r**1.2 / 10), 7) def _get_left_and_right_bounds(self, cuts, date, days): """Determine left and right bounds between cuts @@ -261,8 +268,7 @@ def determine_station_timing_offset(self, date, station, ref_station): """ date = self._datetime(date) - left, right = self.determine_first_and_last_date(date, station, - ref_station) + left, right = self.determine_first_and_last_date(date, station, ref_station) r, dz = self._get_r_dz(date, station, ref_station) dt = self.read_dt(station, ref_station, left, right) if len(dt) < self.MIN_LEN_DT: @@ -272,8 +278,7 @@ def determine_station_timing_offset(self, date, station, ref_station): return s_off, error - def determine_station_timing_offsets(self, station, ref_station, - start=None, end=None): + def determine_station_timing_offsets(self, station, ref_station, start=None, end=None): """Determine the timing offsets between a station pair :param station: station number. @@ -291,11 +296,9 @@ def determine_station_timing_offsets(self, station, ref_station, offsets = [] length = (end - start).days - for date, _ in pbar(datetime_range(start, end), show=self.progress, - length=length): + for date, _ in pbar(datetime_range(start, end), show=self.progress, length=length): ts0 = datetime_to_gps(date) - s_off, error = self.determine_station_timing_offset(date, station, - ref_station) + s_off, error = self.determine_station_timing_offset(date, station, ref_station) offsets.append((ts0, s_off, error)) return offsets @@ -310,8 +313,7 @@ def determine_station_timing_offsets_for_date(self, date): station_pairs = self.get_station_pairs_within_max_distance(date) offsets = [] for station, ref_station in station_pairs: - s_off, error = self.determine_station_timing_offset(date, station, - ref_station) + s_off, error = self.determine_station_timing_offset(date, station, ref_station) offsets.append((station, ref_station, s_off, error)) return offsets @@ -364,8 +366,7 @@ def fit_timing_offset(dt, bins): x = (bins[:-1] + bins[1:]) / 2 sigma = sqrt(y + 1) try: - popt, pcov = curve_fit(gauss, x, y, p0=(len(dt), 0., std(dt)), - sigma=sigma, absolute_sigma=False) + popt, pcov = curve_fit(gauss, x, y, p0=(len(dt), 0.0, std(dt)), sigma=sigma, absolute_sigma=False) offset = popt[1] width = popt[2] offset_error = width / sqrt(sum(y)) @@ -386,10 +387,9 @@ def determine_best_reference(filters): lengths = [] ids = range(len(filters)) - for id in ids: - idx = [j for j in ids if j != id] - lengths.append(sum(filters[id] & (filters[idx[0]] | - filters[idx[1]] | filters[idx[2]]))) + for detector_id in ids: + idx = [j for j in ids if j != detector_id] + lengths.append(sum(filters[detector_id] & (filters[idx[0]] | filters[idx[1]] | filters[idx[2]]))) return lengths.index(max(lengths)) diff --git a/sapphire/analysis/coincidence_queries.py b/sapphire/analysis/coincidence_queries.py old mode 100755 new mode 100644 index fd750732..61116961 --- a/sapphire/analysis/coincidence_queries.py +++ b/sapphire/analysis/coincidence_queries.py @@ -8,7 +8,6 @@ class CoincidenceQuery: - """Perform queries on an ESD file where coincidences have been analysed. 
Functions in this class build and perform queries to easily filter @@ -47,8 +46,7 @@ def __init__(self, data, coincidence_group='/coincidences'): self.data = tables.open_file(data, 'r') else: self.data = data - self.coincidences = self.data.get_node(coincidence_group, - 'coincidences') + self.coincidences = self.data.get_node(coincidence_group, 'coincidences') self.c_index = self.data.get_node(coincidence_group, 'c_index') self.s_index = self.data.get_node(coincidence_group, 's_index') self.s_nodes = [] @@ -58,12 +56,10 @@ def __init__(self, data, coincidence_group='/coincidences'): except tables.NoSuchNodeError: self.s_nodes.append(None) re_number = re.compile('[0-9]+$') - self.s_numbers = [int(re_number.search(s_path.decode('utf-8')).group()) - for s_path in self.s_index] + self.s_numbers = [int(re_number.search(s_path.decode('utf-8')).group()) for s_path in self.s_index] try: - self.reconstructions = self.data.get_node(coincidence_group, - 'reconstructions') + self.reconstructions = self.data.get_node(coincidence_group, 'reconstructions') self.reconstructed = True except tables.NoSuchNodeError: self.reconstructed = False @@ -137,8 +133,7 @@ def at_least(self, stations, n, start=None, stop=None, iterator=False): if len(s_columns) < n: # No combinations possible because there are to few stations return [] - s_combinations = ['(%s)' % (' & '.join(combo)) - for combo in itertools.combinations(s_columns, n)] + s_combinations = ['(%s)' % (' & '.join(combo)) for combo in itertools.combinations(s_columns, n)] query = '(%s)' % ' | '.join(s_combinations) query = self._add_timestamp_filter(query, start, stop) filtered_coincidences = self.perform_query(query, iterator) @@ -222,8 +217,7 @@ def _get_events(self, coincidence): station_number = self.s_numbers[s_idx] s_node = self.s_nodes[s_idx] if s_node is None: - warnings.warn('Missing station group for station id %d. ' - 'Events from it are excluded.' % s_idx) + warnings.warn('Missing station group for station id %d. Events from it are excluded.' % s_idx) continue events.append((station_number, s_node.events[e_idx])) return events @@ -242,8 +236,7 @@ def _get_reconstructions(self, coincidence): station_number = self.s_numbers[s_idx] s_node = self.s_nodes[s_idx] if s_node is None: - warnings.warn(f'Missing station group for station id {s_idx}.' - 'Reconstructions from it are excluded.') + warnings.warn(f'Missing station group for station id {s_idx}. Reconstructions from it are excluded.') continue rec_table = s_node.reconstructions reconstructions.append((station_number, rec_table[e_idx])) @@ -260,9 +253,11 @@ def _get_reconstruction(self, coincidence): reconstruction = self.reconstructions[coincidence['id']] return reconstruction else: - raise Exception('Coincidences are not (properly) reconstructed.' - 'Perform reconstructions and reinitialize this ' - 'class.') + raise RuntimeError( + 'Coincidences are not (properly) reconstructed.' + 'Perform reconstructions and reinitialize this ' + 'class.', + ) def all_events(self, coincidences, n=0): """Get all events for the given coincidences. @@ -272,8 +267,7 @@ def all_events(self, coincidences, n=0): :return: list of events for each coincidence. 
""" - coincidence_events = (self._get_events(coincidence) - for coincidence in coincidences) + coincidence_events = (self._get_events(coincidence) for coincidence in coincidences) return self.minimum_events_for_coincidence(coincidence_events, n) def all_reconstructions(self, coincidences, n=0): @@ -284,8 +278,7 @@ def all_reconstructions(self, coincidences, n=0): :return: list of reconstructed events for each coincidence. """ - coincidence_recs = (self._get_reconstructions(coincidence) - for coincidence in coincidences) + coincidence_recs = (self._get_reconstructions(coincidence) for coincidence in coincidences) return self.minimum_events_for_coincidence(coincidence_recs, n) def minimum_events_for_coincidence(self, coincidences_events, n=2): @@ -295,9 +288,7 @@ def minimum_events_for_coincidence(self, coincidences_events, n=2): :param n: minimum number of events per coincidence. """ - filtered_coincidences = (coincidence - for coincidence in coincidences_events - if len(coincidence) >= n) + filtered_coincidences = (coincidence for coincidence in coincidences_events if len(coincidence) >= n) return filtered_coincidences def events_from_stations(self, coincidences, stations, n=2): @@ -308,10 +299,8 @@ def events_from_stations(self, coincidences, stations, n=2): :return: list of filtered events for each coincidence. """ - events_iterator = (self._get_events(coincidence) - for coincidence in coincidences) - coincidences_events = (self._events_from_stations(events, stations) - for events in events_iterator) + events_iterator = (self._get_events(coincidence) for coincidence in coincidences) + coincidences_events = (self._events_from_stations(events, stations) for events in events_iterator) return self.minimum_events_for_coincidence(coincidences_events, n) def reconstructions_from_stations(self, coincidences, stations, n=2): @@ -322,10 +311,8 @@ def reconstructions_from_stations(self, coincidences, stations, n=2): :return: list of filtered reconstructed events for each coincidence. """ - reconstructions_iterator = (self._get_reconstructions(coincidence) - for coincidence in coincidences) - coincidences_recs = (self._events_from_stations(recs, stations) - for recs in reconstructions_iterator) + reconstructions_iterator = (self._get_reconstructions(coincidence) for coincidence in coincidences) + coincidences_recs = (self._events_from_stations(recs, stations) for recs in reconstructions_iterator) return self.minimum_events_for_coincidence(coincidences_recs, n) def _events_from_stations(self, events, stations): @@ -370,10 +357,6 @@ def events_in_cluster(self, coincidences, cluster, n=2): def __repr__(self): try: - return "{}({!r}, {!r})".format( - self.__class__.__name__, - self.data.filename, - self.coincidences._v_parent._v_pathname - ) + return f'{self.__class__.__name__}({self.data.filename!r}, {self.coincidences._v_parent._v_pathname!r})' except AttributeError: - return f"" + return f'' diff --git a/sapphire/analysis/coincidences.py b/sapphire/analysis/coincidences.py index 04ecc8b3..bfc57834 100644 --- a/sapphire/analysis/coincidences.py +++ b/sapphire/analysis/coincidences.py @@ -1,35 +1,36 @@ -""" Search for coincidences between HiSPARC stations +"""Search for coincidences between HiSPARC stations - This module can be used to search for coincidences between several - HiSPARC stations. To skip this and directly download coincidences - use :func:`~sapphire.esd.download_coincidences`, this is slightly - less flexible because you can not choose the coincidence window. 
+This module can be used to search for coincidences between several +HiSPARC stations. To skip this and directly download coincidences +use :func:`~sapphire.esd.download_coincidences`, this is slightly +less flexible because you can not choose the coincidence window. - For regular usage, download events from the ESD and use the - :class:`CoincidencesESD` class. Example usage:: +For regular usage, download events from the ESD and use the +:class:`CoincidencesESD` class. Example usage:: - import datetime + import datetime - import tables + import tables - from sapphire import CoincidencesESD, download_data + from sapphire import CoincidencesESD, download_data - STATIONS = [501, 503, 506] - START = datetime.datetime(2013, 1, 1) - END = datetime.datetime(2013, 1, 2) + STATIONS = [501, 503, 506] + START = datetime.datetime(2013, 1, 1) + END = datetime.datetime(2013, 1, 2) - if __name__ == '__main__': - station_groups = ['/s%d' % u for u in STATIONS] + if __name__ == '__main__': + station_groups = ['/s%d' % u for u in STATIONS] - data = tables.open_file('data.h5', 'w') - for station, group in zip(STATIONS, station_groups): - download_data(data, group, station, START, END) + data = tables.open_file('data.h5', 'w') + for station, group in zip(STATIONS, station_groups): + download_data(data, group, station, START, END) - coin = CoincidencesESD(data, '/coincidences', station_groups) - coin.search_and_store_coincidences() + coin = CoincidencesESD(data, '/coincidences', station_groups) + coin.search_and_store_coincidences() """ + import os.path import numpy as np @@ -112,8 +113,7 @@ class Coincidences: """ - def __init__(self, data, coincidence_group, station_groups, - overwrite=False, progress=True): + def __init__(self, data, coincidence_group, station_groups, overwrite=False, progress=True): """Initialize the class. :param data: either a PyTables file or path to a HDF5 file. @@ -139,12 +139,11 @@ def __init__(self, data, coincidence_group, station_groups, if overwrite: self.data.remove_node(coincidence_group, recursive=True) else: - raise RuntimeError("Group %s already exists in datafile, " - "and overwrite is False" % - coincidence_group) + raise RuntimeError( + 'Group %s already exists in datafile, and overwrite is False' % coincidence_group, + ) head, tail = os.path.split(coincidence_group) - self.coincidence_group = self.data.create_group(head, tail, - createparents=True) + self.coincidence_group = self.data.create_group(head, tail, createparents=True) self.station_groups = station_groups self.trig_threshold = 0.5 @@ -163,7 +162,7 @@ def __exit__(self, exc_type, exc_value, traceback): if self.opened: self.data.close() - def search_and_store_coincidences(self, window=10000): + def search_and_store_coincidences(self, window=10_000): """Search, process and store coincidences. This is a semi-automatic method to search for coincidences, @@ -179,7 +178,7 @@ def search_and_store_coincidences(self, window=10000): self.process_events() self.store_coincidences() - def search_coincidences(self, window=10000, shifts=None, limit=None): + def search_coincidences(self, window=10_000, shifts=None, limit=None): """Search for coincidences. Search all data in the station_groups for coincidences, and store @@ -208,14 +207,10 @@ def search_coincidences(self, window=10000, shifts=None, limit=None): events. 
""" - c_index, timestamps = \ - self._search_coincidences(window, shifts, limit) + c_index, timestamps = self._search_coincidences(window, shifts, limit) timestamps = np.array(timestamps, dtype=np.uint64) - self.data.create_array(self.coincidence_group, '_src_timestamps', - timestamps) - src_c_index = self.data.create_vlarray(self.coincidence_group, - '_src_c_index', - tables.UInt32Atom()) + self.data.create_array(self.coincidence_group, '_src_timestamps', timestamps) + src_c_index = self.data.create_vlarray(self.coincidence_group, '_src_c_index', tables.UInt32Atom()) for coincidence in c_index: src_c_index.append(coincidence) @@ -238,29 +233,24 @@ def process_events(self, overwrite=None): if len(c_index) == 0: return - selected_timestamps = [] - for coincidence in c_index: - for event in coincidence: - selected_timestamps.append(timestamps[event]) + selected_timestamps = [timestamps[event] for coincidence in c_index for event in coincidence] full_index = np.array(selected_timestamps) for station_id, station_group in enumerate(self.station_groups): station_group = self.data.get_node(station_group) - selected = full_index.compress(full_index[:, 1] == station_id, - axis=0) + selected = full_index.compress(full_index[:, 1] == station_id, axis=0) index = selected[:, 2] if 'blobs' in station_group: if self.progress: - print("Processing coincidence events with traces") + print('Processing coincidence events with traces') processor = process_events.ProcessIndexedEventsWithLINT else: if self.progress: - print("Processing coincidence events without traces") + print('Processing coincidence events without traces') processor = process_events.ProcessIndexedEventsWithoutTraces - process = processor(self.data, station_group, index, - progress=self.progress) + process = processor(self.data, station_group, index, progress=self.progress) process.process_and_store_results(overwrite=overwrite) def store_coincidences(self): @@ -272,19 +262,13 @@ def store_coincidences(self): """ self.c_index = [] - self.coincidences = self.data.create_table(self.coincidence_group, - 'coincidences', - storage.Coincidence) - self.observables = self.data.create_table(self.coincidence_group, - 'observables', - storage.EventObservables) - - for coincidence in pbar(self.coincidence_group._src_c_index, - show=self.progress): + self.coincidences = self.data.create_table(self.coincidence_group, 'coincidences', storage.Coincidence) + self.observables = self.data.create_table(self.coincidence_group, 'observables', storage.EventObservables) + + for coincidence in pbar(self.coincidence_group._src_c_index, show=self.progress): self._store_coincidence(coincidence) - c_index = self.data.create_vlarray(self.coincidence_group, 'c_index', - tables.UInt32Col()) + c_index = self.data.create_vlarray(self.coincidence_group, 'c_index', tables.UInt32Col()) for coincidence in self.c_index: c_index.append(coincidence) c_index.flush() @@ -311,21 +295,17 @@ def _store_coincidence(self, coincidence): group = self.data.get_node(self.station_groups[station_id]) event = group.events[event_index] - idx = self._store_event_in_observables(event, coincidence_id, - station_id) + idx = self._store_event_in_observables(event, coincidence_id, station_id) observables_idx.append(idx) - timestamps.append((event['ext_timestamp'], event['timestamp'], - event['nanoseconds'])) + timestamps.append((event['ext_timestamp'], event['timestamp'], event['nanoseconds'])) first_timestamp = sorted(timestamps)[0] - row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = \ - 
first_timestamp + row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = first_timestamp row.append() self.c_index.append(observables_idx) self.coincidences.flush() - def _store_event_in_observables(self, event, coincidence_id, - station_id): + def _store_event_in_observables(self, event, coincidence_id, station_id): """Store a single event in the observables table.""" row = self.observables.row @@ -333,8 +313,7 @@ def _store_event_in_observables(self, event, coincidence_id, row['id'] = event_id row['station_id'] = station_id - for key in ('timestamp', 'nanoseconds', 'ext_timestamp', - 'n1', 'n2', 'n3', 'n4', 't1', 't2', 't3', 't4'): + for key in ('timestamp', 'nanoseconds', 'ext_timestamp', 'n1', 'n2', 'n3', 'n4', 't1', 't2', 't3', 't4'): row[key] = event[key] signals = [event[key] for key in ('n1', 'n2', 'n3', 'n4')] @@ -345,7 +324,7 @@ def _store_event_in_observables(self, event, coincidence_id, self.observables.flush() return event_id - def _search_coincidences(self, window=10000, shifts=None, limit=None): + def _search_coincidences(self, window=10_000, shifts=None, limit=None): """Search for coincidences Search for coincidences in a set of PyTables event tables, optionally @@ -376,8 +355,7 @@ def _search_coincidences(self, window=10000, shifts=None, limit=None): for station_group in self.station_groups: station_group = self.data.get_node(station_group) if 'events' in station_group: - event_tables.append(self.data.get_node(station_group, - 'events')) + event_tables.append(self.data.get_node(station_group, 'events')) timestamps = self._retrieve_timestamps(event_tables, shifts, limit) coincidences = self._do_search_coincidences(timestamps, window) @@ -407,20 +385,17 @@ def _retrieve_timestamps(self, event_tables, shifts=None, limit=None): # calculate the shifts in nanoseconds and cast them to int. # (prevent upcasting timestamps to float64 further on) if shifts is not None: - shifts = [int(shift * 1e9) if shift is not None else shift - for shift in shifts] + shifts = [int(shift * 1_000_000_000) if shift is not None else shift for shift in shifts] timestamps = [] for s_id, event_table in enumerate(event_tables): - ts = [(x, s_id, j) for j, x in - enumerate(event_table.col('ext_timestamp')[:limit])] + ts = [(x, s_id, j) for j, x in enumerate(event_table.col('ext_timestamp')[:limit])] try: # shift data. carefully avoid upcasting (we're adding two # ints, which is an int, and casting that back to uint64. if # we're not careful, an intermediate value will be a float64, # which doesn't hold the precision to store nanoseconds. - ts = [(np.uint64(int(x) + shifts[i]), i, j) - for x, i, j in ts] + ts = [(np.uint64(int(x) + shifts[i]), i, j) for x, i, j in ts] except (TypeError, IndexError): # shift is None or doesn't exist pass @@ -455,11 +430,9 @@ def _do_search_coincidences(self, timestamps, window): prev_coincidence = [] if self.progress and len(timestamps): - pbar = ProgressBar(max_value=len(timestamps), - widgets=[Percentage(), Bar(), ETA()]).start() + pbar = ProgressBar(max_value=len(timestamps), widgets=[Percentage(), Bar(), ETA()]).start() for i in range(len(timestamps)): - # build coincidence, starting with the current timestamp c = [i] t0 = timestamps[i][0] @@ -476,8 +449,7 @@ def _do_search_coincidences(self, timestamps, window): # if we have more than one event in the coincidence, save it if len(c) > 1: # is this coincidence part of the previous coincidence? 
- is_part_of_prev = np.array([u in prev_coincidence - for u in c]).all() + is_part_of_prev = np.array([u in prev_coincidence for u in c]).all() if not is_part_of_prev: # no, so it's a new one coincidences.append(c) @@ -493,16 +465,25 @@ def _do_search_coincidences(self, timestamps, window): def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ + return '' % self.__class__.__name__ try: - return ("%s(%r, %r, %r, overwrite=%r, progress=%r)" % - (self.__class__.__name__, self.data.filename, - self.coincidences._v_parent._v_pathname, - self.station_groups, self.overwrite, self.progress)) + return '%s(%r, %r, %r, overwrite=%r, progress=%r)' % ( + self.__class__.__name__, + self.data.filename, + self.coincidences._v_parent._v_pathname, + self.station_groups, + self.overwrite, + self.progress, + ) except AttributeError: - return ("%s(%r, %r, %r, overwrite=%r, progress=%r)" % - (self.__class__.__name__, self.data.filename, - None, self.station_groups, self.overwrite, self.progress)) + return '%s(%r, %r, %r, overwrite=%r, progress=%r)' % ( + self.__class__.__name__, + self.data.filename, + None, + self.station_groups, + self.overwrite, + self.progress, + ) class CoincidencesESD(Coincidences): @@ -577,8 +558,7 @@ class CoincidencesESD(Coincidences): """ - def search_and_store_coincidences(self, window=10000, - station_numbers=None): + def search_and_store_coincidences(self, window=10_000, station_numbers=None): """Search and store coincidences. This is a semi-automatic method to search for coincidences @@ -588,7 +568,7 @@ def search_and_store_coincidences(self, window=10000, self.search_coincidences(window=window) self.store_coincidences(station_numbers=station_numbers) - def search_coincidences(self, window=10000, shifts=None, limit=None): + def search_coincidences(self, window=10_000, shifts=None, limit=None): """Search for coincidences. 
Search all data in the station_groups for coincidences, and store @@ -637,21 +617,21 @@ def store_coincidences(self, station_numbers=None): n_coincidences = len(self._src_c_index) if station_numbers is not None: if len(station_numbers) != len(self.station_groups): - raise RuntimeError( - "Number of station numbers must equal number of groups.") + raise RuntimeError('Number of station numbers must equal number of groups.') self.station_numbers = station_numbers - s_columns = {'s%d' % number: tables.BoolCol(pos=p) - for p, number in enumerate(station_numbers, 12)} + s_columns = {'s%d' % number: tables.BoolCol(pos=p) for p, number in enumerate(station_numbers, 12)} else: self.station_numbers = None - s_columns = {'s%d' % n: tables.BoolCol(pos=(n + 12)) - for n, _ in enumerate(self.station_groups)} + s_columns = {'s%d' % n: tables.BoolCol(pos=(n + 12)) for n, _ in enumerate(self.station_groups)} description = storage.Coincidence description.columns.update(s_columns) self.coincidences = self.data.create_table( - self.coincidence_group, 'coincidences', description, - expectedrows=n_coincidences) + self.coincidence_group, + 'coincidences', + description, + expectedrows=n_coincidences, + ) self.c_index = [] @@ -659,15 +639,21 @@ def store_coincidences(self, station_numbers=None): self._store_coincidence(coincidence) c_index = self.data.create_vlarray( - self.coincidence_group, 'c_index', tables.UInt32Col(shape=2), - expectedrows=n_coincidences) + self.coincidence_group, + 'c_index', + tables.UInt32Col(shape=2), + expectedrows=n_coincidences, + ) for observables_idx in pbar(self.c_index, show=self.progress): c_index.append(observables_idx) c_index.flush() s_index = self.data.create_vlarray( - self.coincidence_group, 's_index', tables.VLStringAtom(), - expectedrows=len(self.station_groups)) + self.coincidence_group, + 's_index', + tables.VLStringAtom(), + expectedrows=len(self.station_groups), + ) for station_group in self.station_groups: s_index.append(station_group.encode('utf-8')) s_index.flush() @@ -699,12 +685,10 @@ def _store_coincidence(self, coincidence): group = self.data.get_node(self.station_groups[station_id]) event = group.events[event_index] observables_idx.append((station_id, event_index)) - timestamps.append((event['ext_timestamp'], event['timestamp'], - event['nanoseconds'])) + timestamps.append((event['ext_timestamp'], event['timestamp'], event['nanoseconds'])) first_timestamp = sorted(timestamps)[0] - row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = \ - first_timestamp + row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = first_timestamp row.append() self.c_index.append(observables_idx) self.coincidences.flush() @@ -736,8 +720,7 @@ def get_events(data, stations, coincidence, timestamps, get_raw_traces=False): process = process_events.ProcessEvents(data, stations[station]) event = process.source[index] if not get_raw_traces: - baseline = np.where(event['baseline'] != -999, event['baseline'], - 200)[np.where(event['traces'] >= 0)] + baseline = np.where(event['baseline'] != -999, event['baseline'], 200)[np.where(event['traces'] >= 0)] # transpose to get expected format traces = (process.get_traces_for_event(event) - baseline).T else: diff --git a/sapphire/analysis/core_reconstruction.py b/sapphire/analysis/core_reconstruction.py index 31c6bace..7350974f 100644 --- a/sapphire/analysis/core_reconstruction.py +++ b/sapphire/analysis/core_reconstruction.py @@ -1,18 +1,18 @@ -""" Core reconstruction - - This module contains two classes that can be used to reconstruct 
- HiSPARC events and coincidences. The classes know how to extract the - relevant information from the station and event or cluster and - coincidence. Various algorithms which do the reconstruction are also - defined here. The algorithms require positions and particle densties to - do the reconstruction. - - Each algorithm has a :meth:`~BaseCoreAlgorithm.reconstruct_common` - method which always requires particle denisties, x, and y positions - and optionally z positions and previous reconstruction results. The - data is then prepared for the algorithm and passed to - the :meth:`~CenterMassAlgorithm.reconstruct` method which returns the - reconstructed x and y coordinates. +"""Core reconstruction + +This module contains two classes that can be used to reconstruct +HiSPARC events and coincidences. The classes know how to extract the +relevant information from the station and event or cluster and +coincidence. Various algorithms which do the reconstruction are also +defined here. The algorithms require positions and particle densties to +do the reconstruction. + +Each algorithm has a :meth:`~BaseCoreAlgorithm.reconstruct_common` +method which always requires particle denisties, x, and y positions +and optionally z positions and previous reconstruction results. The +data is then prepared for the algorithm and passed to +the :meth:`~CenterMassAlgorithm.reconstruct` method which returns the +reconstructed x and y coordinates. """ @@ -28,7 +28,6 @@ class EventCoreReconstruction: - """Reconstruct core for station events This class is aware of 'events' and 'stations'. Initialize this class @@ -62,23 +61,21 @@ def reconstruct_event(self, event, detector_ids=None, initial=None): detector_ids = range(4) self.station.cluster.set_timestamp(event['timestamp']) - for id in detector_ids: - p_detector = detector_density(event, id, self.station) + for detector_id in detector_ids: + p_detector = detector_density(event, detector_id, self.station) if not isnan(p_detector): - dx, dy, dz = self.station.detectors[id].get_coordinates() + dx, dy, dz = self.station.detectors[detector_id].get_coordinates() p.append(p_detector) x.append(dx) y.append(dy) z.append(dz) if len(p) >= 3: - core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, - initial) + core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, initial) else: core_x, core_y = (nan, nan) return core_x, core_y - def reconstruct_events(self, events, detector_ids=None, progress=True, - initials=None): + def reconstruct_events(self, events, detector_ids=None, progress=True, initials=None): """Reconstruct events :param events: the events table for the station from an ESD data @@ -95,8 +92,7 @@ def reconstruct_events(self, events, detector_ids=None, progress=True, events = pbar(events, show=progress) events_init = zip_longest(events, initials) - cores = [self.reconstruct_event(event, detector_ids, initial) - for event, initial in events_init] + cores = [self.reconstruct_event(event, detector_ids, initial) for event, initial in events_init] if len(cores): core_x, core_y = zip(*cores) else: @@ -104,12 +100,10 @@ def reconstruct_events(self, events, detector_ids=None, progress=True, return core_x, core_y def __repr__(self): - return ("<%s, station: %r, estimator: %r>" % - (self.__class__.__name__, self.station, self.estimator)) + return '<%s, station: %r, estimator: %r>' % (self.__class__.__name__, self.station, self.estimator) class CoincidenceCoreReconstruction: - """Reconstruct core for coincidences This class is aware of 'coincidences' and 
'clusters'. Initialize @@ -124,8 +118,7 @@ def __init__(self, cluster): self.estimator = CenterMassAlgorithm self.cluster = cluster - def reconstruct_coincidence(self, coincidence, station_numbers=None, - initial=None): + def reconstruct_coincidence(self, coincidence, station_numbers=None, initial=None): """Reconstruct a single coincidence :param coincidence: a coincidence list consisting of @@ -158,14 +151,12 @@ def reconstruct_coincidence(self, coincidence, station_numbers=None, z.append(sz) if len(p) >= 3: - core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, - initial) + core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, initial) else: core_x, core_y = (nan, nan) return core_x, core_y - def reconstruct_coincidences(self, coincidences, station_numbers=None, - progress=True, initials=None): + def reconstruct_coincidences(self, coincidences, station_numbers=None, progress=True, initials=None): """Reconstruct all coincidences :param coincidences: a list of coincidences, each consisting of @@ -183,9 +174,9 @@ def reconstruct_coincidences(self, coincidences, station_numbers=None, coincidences = pbar(coincidences, show=progress) coin_init = zip_longest(coincidences, initials) - cores = [self.reconstruct_coincidence(coincidence, station_numbers, - initial) - for coincidence, initial in coin_init] + cores = [ + self.reconstruct_coincidence(coincidence, station_numbers, initial) for coincidence, initial in coin_init + ] if len(cores): core_x, core_y = list(zip(*cores)) else: @@ -193,13 +184,10 @@ def reconstruct_coincidences(self, coincidences, station_numbers=None, return core_x, core_y def __repr__(self): - return ("<%s, cluster: %r, estimator: %r>" % - (self.__class__.__name__, self.cluster, self.estimator)) - + return '<%s, cluster: %r, estimator: %r>' % (self.__class__.__name__, self.cluster, self.estimator) -class CoincidenceCoreReconstructionDetectors( - CoincidenceCoreReconstruction): +class CoincidenceCoreReconstructionDetectors(CoincidenceCoreReconstruction): """Reconstruct core for coincidences using each detector Instead of using the average station particle density this class @@ -207,8 +195,7 @@ class CoincidenceCoreReconstructionDetectors( """ - def reconstruct_coincidence(self, coincidence, station_numbers=None, - initial=None): + def reconstruct_coincidence(self, coincidence, station_numbers=None, initial=None): """Reconstruct a single coincidence :param coincidence: a coincidence list consisting of @@ -228,29 +215,26 @@ def reconstruct_coincidence(self, coincidence, station_numbers=None, return (nan, nan) for station_number, event in coincidence: - if station_numbers is not None: - if station_number not in station_numbers: - continue + if station_numbers is not None and station_number not in station_numbers: + continue station = self.cluster.get_station(station_number) - for id in range(4): - p_detector = detector_density(event, id, station) + for detector_id in range(4): + p_detector = detector_density(event, detector_id, station) if not isnan(p_detector): - dx, dy, dz = station.detectors[id].get_coordinates() + dx, dy, dz = station.detectors[detector_id].get_coordinates() p.append(p_detector) x.append(dx) y.append(dy) z.append(dz) if len(p) >= 3: - core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, - initial) + core_x, core_y = self.estimator.reconstruct_common(p, x, y, z, initial) else: core_x, core_y = (nan, nan) return core_x, core_y class BaseCoreAlgorithm: - """No actual core reconstruction algorithm Simply returns (nan, nan) as core. 
@@ -282,7 +266,6 @@ def reconstruct(): class CenterMassAlgorithm(BaseCoreAlgorithm): - """Simple core estimator Estimates the core by center of mass of the measurements. @@ -324,7 +307,6 @@ def reconstruct(p, x, y): class AverageIntersectionAlgorithm(BaseCoreAlgorithm): - """Core estimator To the densities in 3 stations correspond 2 possible cores. The line @@ -351,7 +333,7 @@ def reconstruct_common(cls, p, x, y, z=None, initial=None): """ if len(p) < 4 or len(x) < 4 or len(y) < 4: - raise Exception('This algorithm requires at least 4 detections.') + raise ValueError('This algorithm requires at least 4 detections.') if initial is None: initial = {} @@ -371,8 +353,8 @@ def reconstruct_common(cls, p, x, y, z=None, initial=None): linelist0 = [] linelist1 = [] for zero, one, two in subsets: - pp = (phit[zero] / phit[one]) ** (2. / m) - qq = (phit[zero] / phit[two]) ** (2. / m) + pp = (phit[zero] / phit[one]) ** (2.0 / m) + qq = (phit[zero] / phit[two]) ** (2.0 / m) if pp == 1: pp = 1.000001 if qq == 1: @@ -399,8 +381,7 @@ def reconstruct_common(cls, p, x, y, z=None, initial=None): linelist0.append(-e / f) linelist1.append((a * e + b * f + g * k) / f) - newx, newy = CenterMassAlgorithm.reconstruct_common(p, x, y, z, - initial) + newx, newy = CenterMassAlgorithm.reconstruct_common(p, x, y, z, initial) subsets = combinations(statindex, 2) xpointlist = [] @@ -419,13 +400,11 @@ def reconstruct_common(cls, p, x, y, z=None, initial=None): xpointlist.append(xint) ypointlist.append(yint) - subxplist, subyplist = cls.select_newlist( - newx, newy, xpointlist, ypointlist, 120.) + subxplist, subyplist = cls.select_newlist(newx, newy, xpointlist, ypointlist, 120.0) if len(subxplist) > 3: newx = mean(subxplist) newy = mean(subyplist) - subxplist, subyplist = cls.select_newlist( - newx, newy, xpointlist, ypointlist, 100.) + subxplist, subyplist = cls.select_newlist(newx, newy, xpointlist, ypointlist, 100.0) if len(subxplist) > 2: newx = mean(subxplist) newy = mean(subyplist) @@ -448,7 +427,6 @@ def select_newlist(newx, newy, xpointlist, ypointlist, distance): class EllipsLdfAlgorithm(BaseCoreAlgorithm): - """Simple core estimator Estimates the core by center of mass of the measurements. @@ -469,8 +447,8 @@ def reconstruct_common(cls, p, x, y, z=None, initial=None): """ if initial is None: initial = {} - theta = initial.get('theta', 0.) - phi = initial.get('phi', 0.) + theta = initial.get('theta', 0.0) + phi = initial.get('phi', 0.0) return cls.reconstruct(p, x, y, theta, phi)[:2] @classmethod @@ -484,22 +462,41 @@ def reconstruct(cls, p, x, y, theta, phi): """ xcmass, ycmass = CenterMassAlgorithm.reconstruct_common(p, x, y) - chi2best = 10 ** 99 + chi2best = 10**99 xbest = xcmass ybest = ycmass - factorbest = 1. - gridsize = 5. + factorbest = 1.0 + gridsize = 5.0 xbest1, ybest1, chi2best1, factorbest1 = cls.selectbest( - p, x, y, xbest, ybest, factorbest, chi2best, gridsize, theta, phi) - - xlines, ylines = AverageIntersectionAlgorithm.reconstruct_common(p, x, - y) - chi2best = 10 ** 99 + p, + x, + y, + xbest, + ybest, + factorbest, + chi2best, + gridsize, + theta, + phi, + ) + + xlines, ylines = AverageIntersectionAlgorithm.reconstruct_common(p, x, y) + chi2best = 10**99 xbest = xcmass ybest = ycmass - factorbest = 1. 
+ factorbest = 1.0 xbest2, ybest2, chi2best2, factorbest2 = cls.selectbest( - p, x, y, xbest, ybest, factorbest, chi2best, gridsize, theta, phi) + p, + x, + y, + xbest, + ybest, + factorbest, + chi2best, + gridsize, + theta, + phi, + ) if chi2best1 < chi2best2: chi2best = chi2best1 @@ -512,17 +509,26 @@ def reconstruct(cls, p, x, y, theta, phi): ybest = ybest2 factorbest = factorbest2 - gridsize = 2. + gridsize = 2.0 core_x, core_y, chi2best, factorbest = cls.selectbest( - p, x, y, xbest, ybest, factorbest, chi2best, gridsize, theta, phi) + p, + x, + y, + xbest, + ybest, + factorbest, + chi2best, + gridsize, + theta, + phi, + ) size = factorbest * ldf.EllipsLdf._n_electrons return core_x, core_y, chi2best, size @staticmethod - def selectbest(p, x, y, xstart, ystart, factorbest, chi2best, gridsize, - theta, phi): + def selectbest(p, x, y, xstart, ystart, factorbest, chi2best, gridsize, theta, phi): """selects the best core position in grid around (xstart, ystart). :param p: detector particle density in m^-2. @@ -540,22 +546,21 @@ def selectbest(p, x, y, xstart, ystart, factorbest, chi2best, gridsize, ytry = ystart + (i - 20) * gridsize xstations = array(x) ystations = array(y) - r, angle = a.calculate_core_distance_and_angle( - xstations, ystations, xtry, ytry) + r, angle = a.calculate_core_distance_and_angle(xstations, ystations, xtry, ytry) rho = a.calculate_ldf_value(r, angle) - mmdivk = 0. - m = 0. - k = 0. + mmdivk = 0.0 + m = 0.0 + k = 0.0 for i, j in zip(p, rho): - mmdivk += 1. * i * i / j + mmdivk += 1.0 * i * i / j m += i k += j sizefactor = sqrt(mmdivk / k) with warnings.catch_warnings(record=True): - chi2 = 2. * (sizefactor * k - m) + chi2 = 2.0 * (sizefactor * k - m) if chi2 < chi2best: factorbest = sizefactor xbest = xtry diff --git a/sapphire/analysis/direction_reconstruction.py b/sapphire/analysis/direction_reconstruction.py index dcca4f7f..4375eef3 100644 --- a/sapphire/analysis/direction_reconstruction.py +++ b/sapphire/analysis/direction_reconstruction.py @@ -1,25 +1,26 @@ -""" Direction reconstruction - - This module contains two classes that can be used to reconstruct - HiSPARC events and coincidences. The classes know how to extract the - relevant information from the station and event or cluster and - coincidence. Various algorithms which do the reconstruction are also - defined here. The algorithms require positions and arrival times to - do the reconstruction. - - Each algorithm has a :meth:`~BaseDirectionAlgorithm.reconstruct_common` - method which always requires arrival times, x, and y positions and - optionally z positions and previous reconstruction results. The data - is then prepared for the algorithm and passed to - the :meth:`~BaseDirectionAlgorithm.reconstruct` method which returns the - reconstructed theta and phi coordinates. +"""Direction reconstruction + +This module contains two classes that can be used to reconstruct +HiSPARC events and coincidences. The classes know how to extract the +relevant information from the station and event or cluster and +coincidence. Various algorithms which do the reconstruction are also +defined here. The algorithms require positions and arrival times to +do the reconstruction. + +Each algorithm has a :meth:`~BaseDirectionAlgorithm.reconstruct_common` +method which always requires arrival times, x, and y positions and +optionally z positions and previous reconstruction results. 
The data +is then prepared for the algorithm and passed to +the :meth:`~BaseDirectionAlgorithm.reconstruct` method which returns the +reconstructed theta and phi coordinates. """ + import warnings from itertools import combinations, zip_longest -import numpy +import numpy as np from numpy import ( arccos, @@ -47,12 +48,11 @@ from ..utils import c, floor_in_base, make_relative, memoize, norm_angle, pbar, vector_length from . import event_utils -NO_OFFSET = [0., 0., 0., 0.] -NO_STATION_OFFSET = (0., 100.) +NO_OFFSET = [0.0, 0.0, 0.0, 0.0] +NO_STATION_OFFSET = (0.0, 100.0) class EventDirectionReconstruction: - """Reconstruct direction for station events This class is aware of 'events' and 'stations'. Initialize this class @@ -69,8 +69,7 @@ def __init__(self, station): self.fit = RegressionAlgorithm3D self.station = station - def reconstruct_event(self, event, detector_ids=None, offsets=NO_OFFSET, - initial=None): + def reconstruct_event(self, event, detector_ids=None, offsets=NO_OFFSET, initial=None): """Reconstruct a single event :param event: an event (e.g. from an events table), or any @@ -91,10 +90,10 @@ def reconstruct_event(self, event, detector_ids=None, offsets=NO_OFFSET, self.station.cluster.set_timestamp(event['timestamp']) if isinstance(offsets, Station): offsets = offsets.detector_timing_offset(event['timestamp']) - for id in detector_ids: - t_detector = event_utils.detector_arrival_time(event, id, offsets) + for detector_id in detector_ids: + t_detector = event_utils.detector_arrival_time(event, detector_id, offsets) if not isnan(t_detector): - dx, dy, dz = self.station.detectors[id].get_coordinates() + dx, dy, dz = self.station.detectors[detector_id].get_coordinates() t.append(t_detector) x.append(dx) y.append(dy) @@ -108,8 +107,7 @@ def reconstruct_event(self, event, detector_ids=None, offsets=NO_OFFSET, theta, phi = (nan, nan) return theta, phi, ids - def reconstruct_events(self, events, detector_ids=None, offsets=NO_OFFSET, - progress=True, initials=None): + def reconstruct_events(self, events, detector_ids=None, offsets=NO_OFFSET, progress=True, initials=None): """Reconstruct events :param events: the events table for the station from an ESD data file. @@ -126,8 +124,7 @@ def reconstruct_events(self, events, detector_ids=None, offsets=NO_OFFSET, initials = [] events = pbar(events, show=progress) events_init = zip_longest(events, initials) - angles = [self.reconstruct_event(event, detector_ids, offsets, initial) - for event, initial in events_init] + angles = [self.reconstruct_event(event, detector_ids, offsets, initial) for event, initial in events_init] if len(angles): theta, phi, ids = zip(*angles) else: @@ -135,12 +132,10 @@ def reconstruct_events(self, events, detector_ids=None, offsets=NO_OFFSET, return theta, phi, ids def __repr__(self): - return ("<%s, station: %r, direct: %r, fit: %r>" % - (self.__class__.__name__, self.station, self.direct, self.fit)) + return '<%s, station: %r, direct: %r, fit: %r>' % (self.__class__.__name__, self.station, self.direct, self.fit) class CoincidenceDirectionReconstruction: - """Reconstruct direction for coincidences This class is aware of 'coincidences' and 'clusters'. 
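As a usage sketch for the EventDirectionReconstruction methods above (the data file name, station number, and the HiSPARCStations cluster helper are assumptions for illustration, not part of this change)::

    import tables

    from sapphire import HiSPARCStations
    from sapphire.analysis.direction_reconstruction import EventDirectionReconstruction

    # Assumed layout: an ESD events table for station 501 at /s501/events.
    with tables.open_file('data.h5', 'r') as data:
        events = data.root.s501.events
        station = HiSPARCStations([501]).get_station(501)
        rec = EventDirectionReconstruction(station)
        # One (theta, phi) pair and the list of used detector ids per event.
        theta, phi, ids = rec.reconstruct_events(events, progress=False)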
Initialize @@ -158,8 +153,7 @@ def __init__(self, cluster): self.curved = CurvedRegressionAlgorithm3D() self.cluster = cluster - def reconstruct_coincidence(self, coincidence_events, station_numbers=None, - offsets=None, initial=None): + def reconstruct_coincidence(self, coincidence_events, station_numbers=None, offsets=None, initial=None): """Reconstruct a single coincidence :param coincidence_events: a coincidence list consisting of three @@ -185,8 +179,7 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, self.cluster.set_timestamp(ts0) t, x, y, z, nums = ([], [], [], [], []) - offsets = self.get_station_offsets(coincidence_events, station_numbers, - offsets, ts0) + offsets = self.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) for station_number, event in coincidence_events: if station_numbers is not None: @@ -194,8 +187,7 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, continue t_off = offsets.get(station_number, NO_OFFSET) station = self.cluster.get_station(station_number) - t_first = event_utils.station_arrival_time( - event, ets0, offsets=t_off, station=station) + t_first = event_utils.station_arrival_time(event, ets0, offsets=t_off, station=station) if not isnan(t_first): sx, sy, sz = station.calc_center_of_mass_coordinates() t.append(t_first) @@ -215,8 +207,7 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, return theta, phi, nums - def reconstruct_coincidences(self, coincidences, station_numbers=None, - offsets=None, progress=True, initials=None): + def reconstruct_coincidences(self, coincidences, station_numbers=None, offsets=None, progress=True, initials=None): """Reconstruct all coincidences :param coincidences: a list of coincidence events, each consisting @@ -238,17 +229,17 @@ def reconstruct_coincidences(self, coincidences, station_numbers=None, initials = [] coincidences = pbar(coincidences, show=progress) coin_init = zip_longest(coincidences, initials) - angles = [self.reconstruct_coincidence(coincidence, station_numbers, - offsets, initial) - for coincidence, initial in coin_init] + angles = [ + self.reconstruct_coincidence(coincidence, station_numbers, offsets, initial) + for coincidence, initial in coin_init + ] if len(angles): theta, phi, nums = zip(*angles) else: theta, phi, nums = ((), (), ()) return theta, phi, nums - def get_station_offsets(self, coincidence_events, station_numbers, - offsets, ts0): + def get_station_offsets(self, coincidence_events, station_numbers, offsets, ts0): if offsets and isinstance(next(iter(offsets.values())), Station): if station_numbers is None: # stations in the coincidence @@ -256,8 +247,7 @@ def get_station_offsets(self, coincidence_events, station_numbers, else: stations = station_numbers midnight_ts = floor_in_base(ts0, 86400) - offsets = self.determine_best_offsets(stations, midnight_ts, - offsets) + offsets = self.determine_best_offsets(stations, midnight_ts, offsets) return offsets @memoize @@ -278,8 +268,7 @@ def determine_best_offsets(self, station_numbers, midnight_ts, offsets): relative to the reference station. 
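A hedged sketch of reconstructing a single coincidence with the class above; the coincidence_events list (pairs of station number and event row) is assumed to come from a coincidence query elsewhere, and passing api.Station objects as offsets lets the station and detector timing offsets be looked up automatically::

    from sapphire import HiSPARCStations, Station
    from sapphire.analysis.direction_reconstruction import CoincidenceDirectionReconstruction

    station_numbers = [501, 502, 503]
    cluster = HiSPARCStations(station_numbers)
    rec = CoincidenceDirectionReconstruction(cluster)

    offsets = {number: Station(number) for number in station_numbers}
    # coincidence_events: assumed to be a list of (station_number, event) pairs
    # obtained from a coincidence query (not shown here).
    theta, phi, nums = rec.reconstruct_coincidence(coincidence_events, offsets=offsets)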
""" - offset_stations = station_numbers + [sn for sn in list(offsets.keys()) - if sn not in station_numbers] + offset_stations = station_numbers + [sn for sn in list(offsets.keys()) if sn not in station_numbers] offset_matrix = zeros((len(offset_stations), len(offset_stations))) error_matrix = zeros((len(offset_stations), len(offset_stations))) @@ -287,8 +276,7 @@ def determine_best_offsets(self, station_numbers, midnight_ts, offsets): for i, sn in enumerate(offset_stations): for j, ref_sn in enumerate(offset_stations): try: - o, e = offsets[sn].station_timing_offset(ref_sn, - midnight_ts) + o, e = offsets[sn].station_timing_offset(ref_sn, midnight_ts) except Exception: o, e = NO_STATION_OFFSET else: @@ -296,25 +284,19 @@ def determine_best_offsets(self, station_numbers, midnight_ts, offsets): o, e = NO_STATION_OFFSET offset_matrix[i, j] = -o offset_matrix[j, i] = o - error_matrix[i, j] = e ** 2 - error_matrix[j, i] = e ** 2 + error_matrix[i, j] = e**2 + error_matrix[j, i] = e**2 - ref_sn, predecessors = self.determine_best_reference(error_matrix, - station_numbers) + ref_sn, predecessors = self.determine_best_reference(error_matrix, station_numbers) best_offsets = {} for sn in station_numbers: - best_offset = self._reconstruct_best_offset( - predecessors, sn, ref_sn, station_numbers, offset_matrix) - best_offsets[sn] = self._calculate_offsets(offsets[sn], - midnight_ts, - best_offset) + best_offset = self._reconstruct_best_offset(predecessors, sn, ref_sn, station_numbers, offset_matrix) + best_offsets[sn] = self._calculate_offsets(offsets[sn], midnight_ts, best_offset) return best_offsets def determine_best_reference(self, error_matrix, station_numbers): - paths, predecessors = shortest_path(error_matrix, method='FW', - directed=False, - return_predecessors=True) + paths, predecessors = shortest_path(error_matrix, method='FW', directed=False, return_predecessors=True) n = len(station_numbers) # Only consider station in coincidence for reference total_errors = paths[:n, :n].sum(axis=1) @@ -322,9 +304,8 @@ def determine_best_reference(self, error_matrix, station_numbers): return best_reference, predecessors - def _reconstruct_best_offset(self, predecessors, sn, ref_sn, - station_numbers, offset_matrix): - offset = 0. 
+ def _reconstruct_best_offset(self, predecessors, sn, ref_sn, station_numbers, offset_matrix): + offset = 0.0 if sn != ref_sn: i = station_numbers.index(sn) j = station_numbers.index(ref_sn) @@ -348,14 +329,16 @@ def _calculate_offsets(self, station, ts0, offset): return [offset + d_off for d_off in detector_offsets] def __repr__(self): - return ("<%s, cluster: %r, direct: %r, fit: %r, curved: %r>" % - (self.__class__.__name__, self.cluster, self.direct, self.fit, - self.curved)) - + return '<%s, cluster: %r, direct: %r, fit: %r, curved: %r>' % ( + self.__class__.__name__, + self.cluster, + self.direct, + self.fit, + self.curved, + ) -class CoincidenceDirectionReconstructionDetectors( - CoincidenceDirectionReconstruction): +class CoincidenceDirectionReconstructionDetectors(CoincidenceDirectionReconstruction): """Reconstruct direction for coincidences using each detector Instead of only the first arrival time per station this class @@ -363,8 +346,7 @@ class CoincidenceDirectionReconstructionDetectors( """ - def reconstruct_coincidence(self, coincidence_events, station_numbers=None, - offsets=None, initial=None): + def reconstruct_coincidence(self, coincidence_events, station_numbers=None, offsets=None, initial=None): """Reconstruct a single coincidence :param coincidence_events: a coincidence list consisting of one @@ -391,8 +373,7 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, self.cluster.set_timestamp(ts0) t, x, y, z, nums = ([], [], [], [], []) - offsets = self.get_station_offsets(coincidence_events, station_numbers, - offsets, ts0) + offsets = self.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) for station_number, event in coincidence_events: if station_numbers is not None: @@ -400,8 +381,7 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, continue t_off = offsets.get(station_number, NO_OFFSET) station = self.cluster.get_station(station_number) - t_detectors = event_utils.relative_detector_arrival_times( - event, ets0, offsets=t_off, station=station) + t_detectors = event_utils.relative_detector_arrival_times(event, ets0, offsets=t_off, station=station) for t_detector, detector in zip(t_detectors, station.detectors): if not isnan(t_detector): dx, dy, dz = detector.get_coordinates() @@ -425,7 +405,6 @@ def reconstruct_coincidence(self, coincidence_events, station_numbers=None, class BaseDirectionAlgorithm: - """No actual direction reconstruction algorithm Simply returns (nan, nan) as direction. @@ -457,7 +436,6 @@ def reconstruct(): class DirectAlgorithm(BaseDirectionAlgorithm): - """Reconstruct angles using direct analytical formula. This implements the equations derived in Fokkema2012 sec 4.2. @@ -514,19 +492,18 @@ def reconstruct(cls, dt1, dt2, r1, r2, phi1, phi2): # No time difference means shower came from zenith. return 0, 0 - phi = arctan2(-(r1 * dt2 * cos(phi1) - r2 * dt1 * cos(phi2)), - (r1 * dt2 * sin(phi1) - r2 * dt1 * sin(phi2))) + phi = arctan2(-(r1 * dt2 * cos(phi1) - r2 * dt1 * cos(phi2)), (r1 * dt2 * sin(phi1) - r2 * dt1 * sin(phi2))) # The directional vector c * dt should be negative, # not apparent in Fokkema2012 fig 4.4. 
theta = nan if r1 == 0 or r2 == 0: pass - elif not dt1 == 0 and not phi - phi1 == pi / 2: + elif dt1 != 0 and phi - phi1 != pi / 2: sintheta = c * -dt1 / (r1 * cos(phi - phi1)) if abs(sintheta) <= 1: theta = arcsin(sintheta) - elif not dt2 == 0 and not phi - phi2 == pi / 2: + elif dt2 != 0 and phi - phi2 != pi / 2: sintheta = c * -dt2 / (r2 * cos(phi - phi2)) if abs(sintheta) <= 1: theta = arcsin(sintheta) @@ -549,16 +526,19 @@ def rel_theta1_errorsq(cls, theta, phi, phi1, phi2, r1=10, r2=10): sintheta = sin(theta) sinphiphi1 = sin(phi - phi1) - den = r1 ** 2 * (1 - sintheta ** 2) * cos(phi - phi1) ** 2 + den = r1**2 * (1 - sintheta**2) * cos(phi - phi1) ** 2 - aa = (r1 ** 2 * sinphiphi1 ** 2 * - cls.rel_phi_errorsq(theta, phi, phi1, phi2, r1, r2)) - bb = -(2 * r1 * c * sinphiphi1 * - (cls.dphi_dt0(theta, phi, phi1, phi2, r1, r2) - - cls.dphi_dt1(theta, phi, phi1, phi2, r1, r2))) - cc = 2 * c ** 2 + aa = r1**2 * sinphiphi1**2 * cls.rel_phi_errorsq(theta, phi, phi1, phi2, r1, r2) + bb = -( + 2 + * r1 + * c + * sinphiphi1 + * (cls.dphi_dt0(theta, phi, phi1, phi2, r1, r2) - cls.dphi_dt1(theta, phi, phi1, phi2, r1, r2)) + ) + cc = 2 * c**2 - errsq = (aa * sintheta ** 2 + bb * sintheta + cc) / den + errsq = (aa * sintheta**2 + bb * sintheta + cc) / den return where(isnan(errsq), inf, errsq) @@ -569,16 +549,19 @@ def rel_theta2_errorsq(cls, theta, phi, phi1, phi2, r1=10, r2=10): sintheta = sin(theta) sinphiphi2 = sin(phi - phi2) - den = r2 ** 2 * (1 - sintheta ** 2) * cos(phi - phi2) ** 2 + den = r2**2 * (1 - sintheta**2) * cos(phi - phi2) ** 2 - aa = (r2 ** 2 * sinphiphi2 ** 2 * - cls.rel_phi_errorsq(theta, phi, phi1, phi2, r1, r2)) - bb = -(2 * r2 * c * sinphiphi2 * - (cls.dphi_dt0(theta, phi, phi1, phi2, r1, r2) - - cls.dphi_dt2(theta, phi, phi1, phi2, r1, r2))) - cc = 2 * c ** 2 + aa = r2**2 * sinphiphi2**2 * cls.rel_phi_errorsq(theta, phi, phi1, phi2, r1, r2) + bb = -( + 2 + * r2 + * c + * sinphiphi2 + * (cls.dphi_dt0(theta, phi, phi1, phi2, r1, r2) - cls.dphi_dt2(theta, phi, phi1, phi2, r1, r2)) + ) + cc = 2 * c**2 - errsq = (aa * sintheta ** 2 + bb * sintheta + cc) / den + errsq = (aa * sintheta**2 + bb * sintheta + cc) / den return where(isnan(errsq), inf, errsq) @@ -592,28 +575,30 @@ def rel_phi_errorsq(theta, phi, phi1, phi2, r1=10, r2=10): sinphi2 = sin(phi2) cosphi2 = cos(phi2) - den = ((1 + tanphi ** 2) ** 2 * r1 ** 2 * r2 ** 2 * sin(theta) ** 2 * - (sinphi1 * cos(phi - phi2) - sinphi2 * cos(phi - phi1)) ** 2 / - c ** 2) + den = ( + (1 + tanphi**2) ** 2 + * r1**2 + * r2**2 + * sin(theta) ** 2 + * (sinphi1 * cos(phi - phi2) - sinphi2 * cos(phi - phi1)) ** 2 + / c**2 + ) - aa = (r1 ** 2 * sinphi1 ** 2 + - r2 ** 2 * sinphi2 ** 2 - - r1 * r2 * sinphi1 * sinphi2) - bb = (2 * r1 ** 2 * sinphi1 * cosphi1 + - 2 * r2 ** 2 * sinphi2 * cosphi2 - - r1 * r2 * (sinphi2 * cosphi1 + sinphi1 * cosphi2)) - cc = (r1 ** 2 * cosphi1 ** 2 + - r2 ** 2 * cosphi2 ** 2 - - r1 * r2 * cosphi1 * cosphi2) + aa = r1**2 * sinphi1**2 + r2**2 * sinphi2**2 - r1 * r2 * sinphi1 * sinphi2 + bb = ( + 2 * r1**2 * sinphi1 * cosphi1 + + 2 * r2**2 * sinphi2 * cosphi2 + - r1 * r2 * (sinphi2 * cosphi1 + sinphi1 * cosphi2) + ) + cc = r1**2 * cosphi1**2 + r2**2 * cosphi2**2 - r1 * r2 * cosphi1 * cosphi2 - return 2 * (aa * tanphi ** 2 + bb * tanphi + cc) / den + return 2 * (aa * tanphi**2 + bb * tanphi + cc) / den @classmethod def dphi_dt0(cls, theta, phi, phi1, phi2, r1=10, r2=10): """Fokkema2012, eq 4.19""" - return -(cls.dphi_dt1(theta, phi, phi1, phi2, r1, r2) + - cls.dphi_dt2(theta, phi, phi1, phi2, r1, r2)) + return 
-(cls.dphi_dt1(theta, phi, phi1, phi2, r1, r2) + cls.dphi_dt2(theta, phi, phi1, phi2, r1, r2)) @staticmethod def dphi_dt1(theta, phi, phi1, phi2, r1=10, r2=10): @@ -624,9 +609,7 @@ def dphi_dt1(theta, phi, phi1, phi2, r1=10, r2=10): sinphi2 = sin(phi2) cosphi2 = cos(phi2) - den = ((1 + tanphi ** 2) * r1 * r2 * sin(theta) * - (sinphi2 * cos(phi - phi1) - sinphi1 * cos(phi - phi2)) / - c) + den = (1 + tanphi**2) * r1 * r2 * sin(theta) * (sinphi2 * cos(phi - phi1) - sinphi1 * cos(phi - phi2)) / c num = -r2 * (sinphi2 * tanphi + cosphi2) return num / den @@ -640,16 +623,13 @@ def dphi_dt2(theta, phi, phi1, phi2, r1=10, r2=10): cosphi1 = cos(phi1) sinphi2 = sin(phi2) - den = ((1 + tanphi ** 2) * r1 * r2 * sin(theta) * - (sinphi2 * cos(phi - phi1) - sinphi1 * cos(phi - phi2)) / - c) + den = (1 + tanphi**2) * r1 * r2 * sin(theta) * (sinphi2 * cos(phi - phi1) - sinphi1 * cos(phi - phi2)) / c num = r1 * (sinphi1 * tanphi + cosphi1) return num / den class DirectAlgorithmCartesian(BaseDirectionAlgorithm): - """Reconstruct angles using direct analytical formula. This implements the equations derived in Montanus2014. @@ -704,7 +684,7 @@ def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2): theta = nan phi = nan - if not vz == 0: + if vz != 0: usquared = ux * ux + uy * uy vzsquared = vz * vz uvzsqrt = sqrt(usquared / vzsquared) @@ -716,7 +696,6 @@ def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2): class DirectAlgorithmCartesian3D(BaseDirectionAlgorithm): - """Reconstruct angles using direct analytical formula. This implements the equations derived in Montanus2014. @@ -752,8 +731,7 @@ def reconstruct_common(cls, t, x, y, z=None, initial=None): dy = make_relative(y) dz = make_relative(z) - return cls.reconstruct(dt[1], dt[2], dx[1], dx[2], dy[1], dy[2], dz[1], - dz[2]) + return cls.reconstruct(dt[1], dt[2], dx[1], dx[2], dy[1], dy[2], dz[1], dz[2]) @staticmethod def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2, dz1=0, dz2=0): @@ -780,7 +758,7 @@ def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2, dz1=0, dz2=0): theta = nan phi = nan - if underroot > 0 and not vsquared == 0: + if underroot > 0 and vsquared != 0: term = v * sqrt(underroot) nplus = (uxv + term) / vsquared nmin = (uxv - term) / vsquared @@ -798,10 +776,10 @@ def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2, dz1=0, dz2=0): thetamin = pi # Allow solution only if it is the only one above horizon - if thetaplus <= pi / 2. and thetamin > pi / 2.: + if thetaplus <= pi / 2.0 and thetamin > pi / 2.0: theta = thetaplus phi = phiplus - elif thetaplus > pi / 2. and thetamin <= pi / 2.: + elif thetaplus > pi / 2.0 and thetamin <= pi / 2.0: theta = thetamin phi = phimin @@ -809,7 +787,6 @@ def reconstruct(dt1, dt2, dx1, dx2, dy1, dy2, dz1=0, dz2=0): class SphereAlgorithm: - """Reconstruct the direction in equatorial coordinates Note: currently incompatible with the other algorithms! @@ -836,9 +813,7 @@ def reconstruct_equatorial(cls, t, x, y, z, timestamp): """ t_int = array([-1000, -10000]) + t[0] x_int, y_int, z_int = cls.interaction_curve(x, y, z, t, t_int) - dec = arctan2(z_int[1] - z_int[0], - sqrt((x_int[1] - x_int[0]) ** 2. + - (y_int[1] - y_int[0]) ** 2.)) + dec = arctan2(z_int[1] - z_int[0], sqrt((x_int[1] - x_int[0]) ** 2.0 + (y_int[1] - y_int[0]) ** 2.0)) ra = arctan2(x_int[1] - x_int[0], y_int[1] - y_int[0]) return dec, ra @@ -866,30 +841,26 @@ def interaction_curve(x, y, z, t, t_int): t01 = t[0] - t[1] t02 = t[0] - t[2] - a = 2. * (x01 * y02 - x02 * y01) - b = 2. * (x02 * z01 - x01 * z02) - h = 2. 
* (x02 * t01 - x01 * t02) * c ** 2 - d = (x02 * (x01 ** 2 + y01 ** 2 + z01 ** 2 - (t01 * c) ** 2) - - x01 * (x02 ** 2 + y02 ** 2 + z02 ** 2 - (t02 * c) ** 2)) - e = 2. * (y01 * z02 - y02 * z01) - f = 2. * (y01 * t02 - y02 * t01) * c ** 2 - g = (y01 * (x02 ** 2 + y02 ** 2 + z02 ** 2 - (t02 * c) ** 2) - - y02 * (x01 ** 2 + y01 ** 2 + z01 ** 2 - (t01 * c) ** 2)) - - t = a ** 2 + b ** 2 + e ** 2 + a = 2.0 * (x01 * y02 - x02 * y01) + b = 2.0 * (x02 * z01 - x01 * z02) + h = 2.0 * (x02 * t01 - x01 * t02) * c**2 + d = x02 * (x01**2 + y01**2 + z01**2 - (t01 * c) ** 2) - x01 * (x02**2 + y02**2 + z02**2 - (t02 * c) ** 2) + e = 2.0 * (y01 * z02 - y02 * z01) + f = 2.0 * (y01 * t02 - y02 * t01) * c**2 + g = y01 * (x02**2 + y02**2 + z02**2 - (t02 * c) ** 2) - y02 * (x01**2 + y01**2 + z01**2 - (t01 * c) ** 2) + + t = a**2 + b**2 + e**2 v = (b * h + e * f) / t w = (b * d + e * g) / t - p = (d ** 2 + g ** 2) / t + p = (d**2 + g**2) / t q = 2 * (h * d + f * g) / t - r = (h ** 2 + f ** 2 - (a * c) ** 2) / t + r = (h**2 + f**2 - (a * c) ** 2) / t t_int0 = t_int - t[0] sign = 1 - z = -v * t_int0 - w + sign * sqrt((v ** 2 - r) * t_int0 ** 2 + - (2 * v * w - q) * t_int0 + - w ** 2 - p) + z = -v * t_int0 - w + sign * sqrt((v**2 - r) * t_int0**2 + (2 * v * w - q) * t_int0 + w**2 - p) y = (b * z + h * t_int0 + d) / a x = (e * z + f * t_int0 + g) / a @@ -904,9 +875,7 @@ def interaction_curve(x, y, z, t, t_int): # Select interaction above the earths surface. sign = -1 - z = -v * t_int0 - w + sign * sqrt((v ** 2 - r) * t_int0 ** 2 + - (2 * v * w - q) * t_int0 + - w ** 2 - p) + z = -v * t_int0 - w + sign * sqrt((v**2 - r) * t_int0**2 + (2 * v * w - q) * t_int0 + w**2 - p) y = (b * z + h * t_int0 + d) / a x = (e * z + f * t_int0 + g) / a @@ -918,7 +887,6 @@ def interaction_curve(x, y, z, t, t_int): class FitAlgorithm3D(BaseDirectionAlgorithm): - @classmethod def reconstruct_common(cls, t, x, y, z=None, initial=None): """Reconstruct angles from 3 or more detections @@ -958,11 +926,15 @@ def reconstruct(cls, t, x, y, z): cons = {'type': 'eq', 'fun': cls.constraint_normal_vector} - fit = minimize(cls.best_fit, x0=(0.1, 0.1, 0.989, 0.), - args=(dt, dx, dy, dz), method="SLSQP", - bounds=((-1, 1), (-1, 1), (-1, 1), (None, None)), - constraints=cons, - options={'ftol': 1e-9, 'eps': 1e-7, 'maxiter': 50}) + fit = minimize( + cls.best_fit, + x0=(0.1, 0.1, 0.989, 0.0), + args=(dt, dx, dy, dz), + method='SLSQP', + bounds=((-1, 1), (-1, 1), (-1, 1), (None, None)), + constraints=cons, + options={'ftol': 1e-9, 'eps': 1e-7, 'maxiter': 50}, + ) if fit.success: phi1 = arctan2(fit.x[1], fit.x[0]) theta1 = arccos(fit.x[2]) @@ -970,11 +942,15 @@ def reconstruct(cls, t, x, y, z): phi1 = nan theta1 = nan - fit = minimize(cls.best_fit, x0=(-0.1, -0.1, -0.989, 0.), - args=(dt, dx, dy, dz), method="SLSQP", - bounds=((-1, 1), (-1, 1), (-1, 1), (None, None)), - constraints=cons, - options={'ftol': 1e-9, 'eps': 1e-7, 'maxiter': 50}) + fit = minimize( + cls.best_fit, + x0=(-0.1, -0.1, -0.989, 0.0), + args=(dt, dx, dy, dz), + method='SLSQP', + bounds=((-1, 1), (-1, 1), (-1, 1), (None, None)), + constraints=cons, + options={'ftol': 1e-9, 'eps': 1e-7, 'maxiter': 50}, + ) if fit.success: phi2 = arctan2(fit.x[1], fit.x[0]) theta2 = arccos(fit.x[2]) @@ -986,10 +962,10 @@ def reconstruct(cls, t, x, y, z): # and the other is either nan or larger than pi/2 (shower from below), # the first one is considered correct. # If both come from above (or from below), both theta's are rejected. - if theta1 <= pi / 2. 
and (isnan(theta2) or theta2 > pi / 2.): + if theta1 <= pi / 2.0 and (isnan(theta2) or theta2 > pi / 2.0): theta = theta1 phi = phi1 - elif (isnan(theta1) or theta1 > pi / 2.) and theta2 <= pi / 2.: + elif (isnan(theta1) or theta1 > pi / 2.0) and theta2 <= pi / 2.0: theta = theta2 phi = phi2 else: @@ -1016,15 +992,11 @@ def best_fit(n_xyz, dt, dx, dy, dz): """ nx, ny, nz, m = n_xyz - slq = sum( - (nx * xi + ny * yi + zi * nz + c * ti + m) ** 2 - for ti, xi, yi, zi in zip(dt, dx, dy, dz) - ) + slq = sum((nx * xi + ny * yi + zi * nz + c * ti + m) ** 2 for ti, xi, yi, zi in zip(dt, dx, dy, dz)) return slq + m * m class RegressionAlgorithm(BaseDirectionAlgorithm): - """Reconstruct angles using an analytical regression formula. This implements the equations as for ISVHECRI (Montanus 2014). @@ -1064,42 +1036,39 @@ def reconstruct(cls, t, x, y): return nan, nan k = len(t) - xs = numpy.sum(x) - ys = numpy.sum(y) - ts = numpy.sum(t) + xs = np.sum(x) + ys = np.sum(y) + ts = np.sum(t) - xx = 0. - yy = 0. - tx = 0. - ty = 0. - xy = 0. + xx = 0.0 + yy = 0.0 + tx = 0.0 + ty = 0.0 + xy = 0.0 for ti, xi, yi in zip(t, x, y): - xx += xi ** 2 - yy += yi ** 2 + xx += xi**2 + yy += yi**2 tx += ti * xi ty += ti * yi xy += xi * yi - denom = (k * xy ** 2 + xs ** 2 * yy + ys ** 2 * xx - k * xx * yy - - 2 * xs * ys * xy) + denom = k * xy**2 + xs**2 * yy + ys**2 * xx - k * xx * yy - 2 * xs * ys * xy if denom == 0: denom = nan - numer = (tx * (k * yy - ys ** 2) + xy * (ts * ys - k * ty) + - xs * ys * ty - ts * xs * yy) + numer = tx * (k * yy - ys**2) + xy * (ts * ys - k * ty) + xs * ys * ty - ts * xs * yy nx = c * numer / denom - numer = (ty * (k * xx - xs ** 2) + xy * (ts * xs - k * tx) + - xs * ys * tx - ts * ys * xx) + numer = ty * (k * xx - xs**2) + xy * (ts * xs - k * tx) + xs * ys * tx - ts * ys * xx ny = c * numer / denom - horiz = nx ** 2 + ny ** 2 - if horiz > 1.: + horiz = nx**2 + ny**2 + if horiz > 1.0: theta = nan phi = nan else: - nz = sqrt(1 - nx ** 2 - ny ** 2) + nz = sqrt(1 - nx**2 - ny**2) phi = arctan2(ny, nx) theta = arccos(nz) @@ -1107,7 +1076,6 @@ def reconstruct(cls, t, x, y): class RegressionAlgorithm3D(BaseDirectionAlgorithm): - """Reconstruct angles by iteratively applying a regression formula. This implements the equations as recently derived (Montanus 2014). @@ -1153,7 +1121,7 @@ def reconstruct(cls, t, x, y, z): regress2d = RegressionAlgorithm() theta, phi = regress2d.reconstruct_common(t, x, y) - dtheta = 1. + dtheta = 1.0 iteration = 0 while dtheta > 0.001: iteration += 1 @@ -1173,7 +1141,6 @@ def reconstruct(cls, t, x, y, z): class CurvedMixin: - """Provide methods to estimate the time delay due to front curvature Given a core location, detector position, and shower angle the radial core @@ -1200,12 +1167,10 @@ def radial_core_distance(cls, x, y, core_x, core_y, theta, phi): dy = core_y - y nx = sin(theta) * cos(phi) ny = sin(theta) * sin(phi) - return sqrt(dx ** 2 * (1 - nx ** 2) + dy ** 2 * (1 - ny ** 2) - - 2 * dx * dy * nx * ny) + return sqrt(dx**2 * (1 - nx**2) + dy**2 * (1 - ny**2) - 2 * dx * dy * nx * ny) class CurvedRegressionAlgorithm(CurvedMixin, BaseDirectionAlgorithm): - """Reconstruct angles taking the shower front curvature into account. Take the shower front curvature into account. Assumes knowledge about the @@ -1256,14 +1221,13 @@ def reconstruct(self, t, x, y, core_x, core_y): regress2d = RegressionAlgorithm() theta, phi = regress2d.reconstruct_common(t, x, y) - dtheta = 1. 
+ dtheta = 1.0 iteration = 0 while dtheta > 0.001: iteration += 1 if iteration > self.MAX_ITERATIONS: return nan, nan - t_proj = [ti - self.time_delay(xi, yi, core_x, core_y, theta, phi) - for ti, xi, yi in zip(t, x, y)] + t_proj = [ti - self.time_delay(xi, yi, core_x, core_y, theta, phi) for ti, xi, yi in zip(t, x, y)] theta_prev = theta theta, phi = regress2d.reconstruct_common(t_proj, x, y) dtheta = abs(theta - theta_prev) @@ -1272,7 +1236,6 @@ def reconstruct(self, t, x, y, core_x, core_y): class CurvedRegressionAlgorithm3D(CurvedMixin, BaseDirectionAlgorithm): - """Reconstruct angles accounting for front curvature and detector altitudes Take the shower front curvature and different detector heights into @@ -1326,7 +1289,7 @@ def reconstruct(self, t, x, y, z, core_x, core_y): regress2d = RegressionAlgorithm() theta, phi = regress2d.reconstruct_common(t, x, y) - dtheta = 1. + dtheta = 1.0 iteration = 0 while dtheta > 0.001: iteration += 1 @@ -1337,9 +1300,10 @@ def reconstruct(self, t, x, y, z, core_x, core_y): nz = cos(theta) x_proj = [xi - zi * nxnz for xi, zi in zip(x, z)] y_proj = [yi - zi * nynz for yi, zi in zip(y, z)] - t_proj = [ti + zi / (c * nz) - - self.time_delay(xpi, ypi, core_x, core_y, theta, phi) - for ti, xpi, ypi, zi in zip(t, x_proj, y_proj, z)] + t_proj = [ + ti + zi / (c * nz) - self.time_delay(xpi, ypi, core_x, core_y, theta, phi) + for ti, xpi, ypi, zi in zip(t, x_proj, y_proj, z) + ] theta_prev = theta theta, phi = regress2d.reconstruct_common(t_proj, x_proj, y_proj) dtheta = abs(theta - theta_prev) @@ -1369,7 +1333,7 @@ def logic_checks(t, x, y, z): # Check for identical positions if len(t) == 3: xyz = list(zip(x, y, z)) - if not len(xyz) == len(set(xyz)): + if len(xyz) != len(set(xyz)): return False txyz = list(zip(t, x, y, z)) @@ -1402,8 +1366,7 @@ def logic_checks(t, x, y, z): lenvec12 = vector_length(dx3, dy3, dz3) # area triangle is |cross product| - area = abs(dx1 * dy2 - dx2 * dy1 + dy1 * dz2 - dy2 * dz1 + - dz1 * dx2 - dz2 * dx1) + area = abs(dx1 * dy2 - dx2 * dy1 + dy1 * dz2 - dy2 * dz1 + dz1 * dx2 - dz2 * dx1) # prevent floating point errors if area < 1e-7: @@ -1418,8 +1381,7 @@ def logic_checks(t, x, y, z): smallest_angle = min(sin1, sin2, sin3) # remember largest of smallest sines - largest_of_smallest_angles = max(largest_of_smallest_angles, - smallest_angle) + largest_of_smallest_angles = max(largest_of_smallest_angles, smallest_angle) # discard reconstruction if the largest of the smallest angles of each # triangle is smaller than 0.1 rad (5.73 degrees) diff --git a/sapphire/analysis/event_utils.py b/sapphire/analysis/event_utils.py index e05f6e07..dc196455 100644 --- a/sapphire/analysis/event_utils.py +++ b/sapphire/analysis/event_utils.py @@ -1,4 +1,4 @@ -""" Get data from HiSPARC events +"""Get data from HiSPARC events This module contains functions to derive data from HiSPARC events. Common tasks for data reconstruction are getting the particle density @@ -7,11 +7,12 @@ times and trigger time) and stations. """ + from numpy import nan, nanmean, nanmin from ..utils import ERR -NO_OFFSET = [0., 0., 0., 0.] 
+NO_OFFSET = [0.0, 0.0, 0.0, 0.0] def station_density(event, detector_ids=None, station=None): @@ -28,8 +29,7 @@ def station_density(event, detector_ids=None, station=None): """ if detector_ids is None: detector_ids = get_detector_ids(station, event) - p = nanmean(detector_densities(event, detector_ids=detector_ids, - station=station)) + p = nanmean(detector_densities(event, detector_ids=detector_ids, station=station)) return p @@ -45,7 +45,7 @@ def detector_densities(event, detector_ids=None, station=None): """ if detector_ids is None: detector_ids = get_detector_ids(station, event) - p = [detector_density(event, id, station) for id in detector_ids] + p = [detector_density(event, detector_id, station) for detector_id in detector_ids] return p @@ -71,8 +71,7 @@ def detector_density(event, detector_id, station=None): return p -def station_arrival_time(event, reference_ext_timestamp, - detector_ids=None, offsets=NO_OFFSET, station=None): +def station_arrival_time(event, reference_ext_timestamp, detector_ids=None, offsets=NO_OFFSET, station=None): """Get station arrival time, i.e. first detector hit Arrival time of first detector hit in the station. The returned time @@ -96,16 +95,12 @@ def station_arrival_time(event, reference_ext_timestamp, if event['t_trigger'] in ERR: t = nan else: - t_first = nanmin(detector_arrival_times(event, detector_ids, offsets, - station)) - t = ((int(event['ext_timestamp']) - int(reference_ext_timestamp)) - - event['t_trigger'] + t_first) + t_first = nanmin(detector_arrival_times(event, detector_ids, offsets, station)) + t = (int(event['ext_timestamp']) - int(reference_ext_timestamp)) - event['t_trigger'] + t_first return t -def relative_detector_arrival_times(event, reference_ext_timestamp, - detector_ids=None, offsets=NO_OFFSET, - station=None): +def relative_detector_arrival_times(event, reference_ext_timestamp, detector_ids=None, offsets=NO_OFFSET, station=None): """Get relative arrival times for all detectors :param event: Processed event row. @@ -124,16 +119,15 @@ def relative_detector_arrival_times(event, reference_ext_timestamp, if event['t_trigger'] in ERR: t = [nan] * len(detector_ids) else: - arrival_times = detector_arrival_times(event, detector_ids, - offsets, station) - t = [(int(event['ext_timestamp']) - int(reference_ext_timestamp)) - - event['t_trigger'] + arrival_time - for arrival_time in arrival_times] + arrival_times = detector_arrival_times(event, detector_ids, offsets, station) + t = [ + (int(event['ext_timestamp']) - int(reference_ext_timestamp)) - event['t_trigger'] + arrival_time + for arrival_time in arrival_times + ] return t -def detector_arrival_times(event, detector_ids=None, offsets=NO_OFFSET, - station=None): +def detector_arrival_times(event, detector_ids=None, offsets=NO_OFFSET, station=None): """Get corrected arrival times for all detectors :param event: Processed event row. 
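The algorithm classes in this module can also be driven directly with plain Python lists, without the event and coincidence machinery. A small sketch with made-up times, positions, and densities, combining a center-of-mass core estimate from the core reconstruction module earlier in this diff (import path assumed to be sapphire.analysis.core_reconstruction) with the curved-front regression; all numbers are illustrative only::

    from sapphire.analysis.core_reconstruction import CenterMassAlgorithm
    from sapphire.analysis.direction_reconstruction import CurvedRegressionAlgorithm, RegressionAlgorithm

    t = [2.5, 0.0, 7.4, 5.1]    # arrival times in ns
    x = [0.0, 10.0, 10.0, 0.0]  # detector x positions in m
    y = [0.0, 0.0, 10.0, 10.0]  # detector y positions in m
    p = [2.1, 3.4, 1.2, 0.8]    # particle densities in m^-2

    # Planar regression fit for the direction.
    theta, phi = RegressionAlgorithm().reconstruct_common(t, x, y)

    # Refit taking the front curvature into account, using a center-of-mass core estimate.
    core_x, core_y = CenterMassAlgorithm.reconstruct_common(p, x, y)
    theta, phi = CurvedRegressionAlgorithm().reconstruct(t, x, y, core_x, core_y)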
@@ -146,7 +140,7 @@ def detector_arrival_times(event, detector_ids=None, offsets=NO_OFFSET, """ if detector_ids is None: detector_ids = get_detector_ids(station, event) - t = [detector_arrival_time(event, id, offsets) for id in detector_ids] + t = [detector_arrival_time(event, detector_id, offsets) for detector_id in detector_ids] return t @@ -183,8 +177,7 @@ def get_detector_ids(station=None, event=None): if station is not None: detector_ids = list(range(len(station.detectors))) elif event is not None: - detector_ids = [i for i, ph in enumerate(event['pulseheights']) - if ph != -1] + detector_ids = [i for i, ph in enumerate(event['pulseheights']) if ph != -1] else: detector_ids = list(range(4)) return detector_ids diff --git a/sapphire/analysis/find_mpv.py b/sapphire/analysis/find_mpv.py index fae6ce65..4a87a729 100644 --- a/sapphire/analysis/find_mpv.py +++ b/sapphire/analysis/find_mpv.py @@ -4,6 +4,7 @@ find the most probable value in a HiSPARC spectrum """ + import warnings from scipy.optimize import curve_fit @@ -62,7 +63,7 @@ def find_mpv(self): try: mpv = self.fit_mpv(first_guess) except RuntimeError: - warnings.warn("Fit failed") + warnings.warn('Fit failed') return -999, False else: return mpv, True @@ -101,7 +102,7 @@ def find_first_guess_mpv(self): # calculate position of most probable value idx_mpv = idx_right_max + idx_greatest_decrease + left_idx - mpv = (bins[idx_mpv] + bins[idx_mpv + 1]) / 2. + mpv = (bins[idx_mpv] + bins[idx_mpv + 1]) / 2.0 return mpv @@ -123,11 +124,11 @@ def fit_mpv(self, first_guess, width_factor=MPV_FIT_WIDTH_FACTOR): """ n, bins = self.n, self.bins - bins_x = (bins[:-1] + bins[1:]) / 2. + bins_x = (bins[:-1] + bins[1:]) / 2.0 # calculate fit domain - left = (1. - width_factor) * first_guess - right = (1. + width_factor) * first_guess + left = (1.0 - width_factor) * first_guess + right = (1.0 + width_factor) * first_guess # bracket histogram data x = bins_x.compress((left <= bins_x) & (bins_x < right)) @@ -136,16 +137,15 @@ def fit_mpv(self, first_guess, width_factor=MPV_FIT_WIDTH_FACTOR): # sanity check: number of data points must be at least equal to # the number of fit parameters if len(x) < 3: - raise RuntimeError("Number of data points not sufficient") + raise RuntimeError('Number of data points not sufficient') # fit to a normal distribution - popt, pcov = curve_fit(gauss, x, y, - p0=(y.max(), first_guess, first_guess)) + popt, pcov = curve_fit(gauss, x, y, p0=(y.max(), first_guess, first_guess)) mpv = popt[1] # sanity check: if MPV is outside domain, the MIP peak was not # bracketed correctly if mpv < x[0] or mpv > x[-1]: - raise RuntimeError("Fitted MPV value outside range") + raise RuntimeError('Fitted MPV value outside range') return mpv diff --git a/sapphire/analysis/landau.py b/sapphire/analysis/landau.py index b120d929..546738a2 100644 --- a/sapphire/analysis/landau.py +++ b/sapphire/analysis/landau.py @@ -1,20 +1,21 @@ -""" Landau distribution function +"""Landau distribution function - This module computes the Landau distribution, which governs the - fluctuations in energy loss of particles travelling through a - relatively thin layer of matter. +This module computes the Landau distribution, which governs the +fluctuations in energy loss of particles travelling through a +relatively thin layer of matter. - Currently, this module only contains functions to calculate the exact - function using two integral representations of the defining complex - integral. 
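The most probable value fit from find_mpv.py above can be applied to any pulse-integral histogram; a sketch in which the integrals array is assumed to be available, using the same binning as the pulseintegral processing later in this diff::

    import numpy as np

    from sapphire.analysis.find_mpv import FindMostProbableValueInSpectrum

    # 'integrals' is assumed: a 1-D array of pulse integrals for one detector.
    n, bins = np.histogram(integrals, bins=np.linspace(0, 50_000, 201))
    mpv, is_fitted = FindMostProbableValueInSpectrum(n, bins).find_mpv()
    if not is_fitted:
        print('MIP peak fit failed; the returned value is only a placeholder')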
This should be extended by approximations when the need for - doing serious work arises. +Currently, this module only contains functions to calculate the exact +function using two integral representations of the defining complex +integral. This should be extended by approximations when the need for +doing serious work arises. - References are made to Fokkema2012, DOI: 10.3990/1.9789036534383. +References are made to Fokkema2012, DOI: 10.3990/1.9789036534383. """ + import warnings -from numpy import Inf, arctan, convolve, cos, exp, interp, linspace, log, pi, sin, vectorize +from numpy import arctan, convolve, cos, exp, inf, interp, linspace, log, pi, sin, vectorize from scipy import integrate, stats @@ -26,19 +27,20 @@ def pdf(lf): """ if lf < -10: - return 0. + return 0.0 elif lf < 0: sf = exp(-lf - 1) - integrant = integrate.quad(pdf_kernel, 0, Inf, args=(sf,))[0] + integrant = integrate.quad(pdf_kernel, 0, inf, args=(sf,))[0] return 1 / pi * exp(-sf) * integrant else: - integrant = integrate.quad(pdf_kernel2, 0, Inf, args=(lf,))[0] + integrant = integrate.quad(pdf_kernel2, 0, inf, args=(lf,))[0] return 1 / pi * integrant def pdf_kernel(y, sf): - return (exp(sf / 2 * log(1 + y ** 2 / sf ** 2) - y * arctan(y / sf)) * - cos(.5 * y * log(1 + y ** 2 / sf ** 2) - y + sf * arctan(y / sf))) + return exp(sf / 2 * log(1 + y**2 / sf**2) - y * arctan(y / sf)) * cos( + 0.5 * y * log(1 + y**2 / sf**2) - y + sf * arctan(y / sf), + ) def pdf_kernel2(u, lf): @@ -47,7 +49,7 @@ def pdf_kernel2(u, lf): Fokkema2012, eq 2.13. """ - return exp(-lf * u) * u ** -u * sin(pi * u) + return exp(-lf * u) * u**-u * sin(pi * u) class Scintillator: @@ -107,8 +109,7 @@ def pdf(self, lf): self.pdf_values = pdf(self.pdf_domain) return self.pdf(lf) - def conv_landau_for_x(self, x, count_scale=1, mev_scale=None, - gauss_scale=None): + def conv_landau_for_x(self, x, count_scale=1, mev_scale=None, gauss_scale=None): """Landau convolved with Gaussian Fokkema2012, eq 5.4. @@ -135,8 +136,7 @@ def conv_landau_for_x(self, x, count_scale=1, mev_scale=None, y = interp(x, x_calc, y_calc) return y - def conv_landau(self, x, count_scale=1, mev_scale=None, - gauss_scale=None): + def conv_landau(self, x, count_scale=1, mev_scale=None, gauss_scale=None): """Bare-bones convoluted landau function This thing is fragile. Use with great care! 
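A short sketch of evaluating the Landau curve defined above; the scale parameters passed to the convolution are placeholders, not recommended values::

    from numpy import linspace

    from sapphire.analysis.landau import Scintillator, pdf

    # Exact Landau pdf at one value of the reduced energy-loss variable.
    value = pdf(0.5)

    # Landau convolved with a Gaussian, evaluated on an arbitrary domain.
    scintillator = Scintillator()
    x = linspace(0, 10, 101)
    y = scintillator.conv_landau_for_x(x, count_scale=1, mev_scale=0.5, gauss_scale=0.3)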
First and foremost, @@ -163,21 +163,17 @@ def residuals(self, p, xdata, ydata, a, b): self.mev_scale = mev_scale self.gauss_scale = gauss_scale - return self._residuals(xdata, ydata, mev_scale, count_scale, - gauss_scale, a, b) + return self._residuals(xdata, ydata, mev_scale, count_scale, gauss_scale, a, b) def constrained_residuals(self, p, xdata, ydata, a, b): count_scale = p mev_scale = self.mev_scale gauss_scale = self.gauss_scale - return self._residuals(xdata, ydata, mev_scale, count_scale, - gauss_scale, a, b) + return self._residuals(xdata, ydata, mev_scale, count_scale, gauss_scale, a, b) - def _residuals(self, xdata, ydata, mev_scale, count_scale, - gauss_scale, a, b): - yfit = self.conv_landau_for_x(xdata, count_scale, mev_scale, - gauss_scale) + def _residuals(self, xdata, ydata, mev_scale, count_scale, gauss_scale, a, b): + yfit = self.conv_landau_for_x(xdata, count_scale, mev_scale, gauss_scale) yfit = yfit.compress((a <= xdata) & (xdata < b)) ydata = ydata.compress((a <= xdata) & (xdata < b)) @@ -194,7 +190,7 @@ def discrete_convolution(f, g, t): """ if abs(min(t) + max(t)) > 1e-6: - raise RuntimeError("Range needs to be symmetrical around zero.") + raise RuntimeError('Range needs to be symmetrical around zero.') dt = t[1] - t[0] return dt * convolve(f(t), g(t), mode='same') diff --git a/sapphire/analysis/process_events.py b/sapphire/analysis/process_events.py index dbc7d7e1..b6661a7e 100644 --- a/sapphire/analysis/process_events.py +++ b/sapphire/analysis/process_events.py @@ -1,32 +1,33 @@ -""" Process HiSPARC events +"""Process HiSPARC events - This module can be used analyse data to get observables like arrival - times and particle count in each detector for each event. +This module can be used analyse data to get observables like arrival +times and particle count in each detector for each event. - Example usage:: +Example usage:: - import datetime + import datetime - import tables + import tables - from sapphire.publicdb import download_data - from sapphire import ProcessEvents + from sapphire.publicdb import download_data + from sapphire import ProcessEvents - STATIONS = [501, 503, 506] - START = datetime.datetime(2013, 1, 1) - END = datetime.datetime(2013, 1, 2) + STATIONS = [501, 503, 506] + START = datetime.datetime(2013, 1, 1) + END = datetime.datetime(2013, 1, 2) - if __name__ == '__main__': - station_groups = ['/s%d' % u for u in STATIONS] + if __name__ == '__main__': + station_groups = ['/s%d' % u for u in STATIONS] - with tables.open_file('data.h5', 'w') as data: - for station, group in zip(STATIONS, station_groups): - download_data(data, group, station, START, END, True) - proc = ProcessEvents(data, group) - proc.process_and_store_results() + with tables.open_file('data.h5', 'w') as data: + for station, group in zip(STATIONS, station_groups): + download_data(data, group, station, START, END, True) + proc = ProcessEvents(data, group) + proc.process_and_store_results() """ + import operator import os import warnings @@ -41,7 +42,7 @@ from .process_traces import ADC_HIGH_THRESHOLD, ADC_LOW_THRESHOLD, ADC_TIME_PER_SAMPLE ADC_THRESHOLD = 20 #: Threshold for arrival times, relative to the baseline -ADC_LIMIT = 2 ** 12 +ADC_LIMIT = 2**12 #: Default trigger for 2-detector station #: 2 low and no high, no external @@ -52,7 +53,6 @@ class ProcessEvents: - """Process HiSPARC events to obtain several observables. 
This class can be used to process a set of HiSPARC events and adds a @@ -83,7 +83,8 @@ class ProcessEvents: 'n2': tables.Float32Col(pos=18, dflt=-1), 'n3': tables.Float32Col(pos=19, dflt=-1), 'n4': tables.Float32Col(pos=20, dflt=-1), - 't_trigger': tables.Float32Col(pos=21, dflt=-1)} + 't_trigger': tables.Float32Col(pos=21, dflt=-1), + } def __init__(self, data, group, source=None, progress=True): """Initialize the class. @@ -103,8 +104,7 @@ def __init__(self, data, group, source=None, progress=True): self.progress = progress self.limit = None - def process_and_store_results(self, destination=None, overwrite=False, - limit=None): + def process_and_store_results(self, destination=None, overwrite=False, limit=None): """Process events and store the results. :param destination: name of the table where the results will be @@ -132,8 +132,7 @@ def get_traces_for_event(self, event): :return: the traces: an array of pulseheight values. """ - traces = [list(self._get_trace(idx)) for idx in event['traces'] - if idx >= 0] + traces = [list(self._get_trace(idx)) for idx in event['traces'] if idx >= 0] # Make traces follow NumPy conventions traces = np.array(traces).T @@ -171,8 +170,7 @@ def _check_destination(self, destination, overwrite): """Check if the destination is valid""" if destination == '_events': - raise RuntimeError("The _events table is reserved for internal " - "use. Choose another destination.") + raise RuntimeError('The _events table is reserved for internal use. Choose another destination.') elif destination is None: destination = 'events' @@ -180,8 +178,7 @@ def _check_destination(self, destination, overwrite): # worry. Otherwise, destination may not exist or will be overwritten if self.source.name != destination: if destination in self.group and not overwrite: - raise RuntimeError("I will not overwrite previous results " - "(unless you specify overwrite=True)") + raise RuntimeError('I will not overwrite previous results (unless you specify overwrite=True)') self.destination = destination @@ -198,8 +195,7 @@ def _clean_events_table(self): unique_sorted_ids = self._find_unique_row_ids(enumerated_timestamps) - new_events = self._replace_table_with_selected_rows(events, - unique_sorted_ids) + new_events = self._replace_table_with_selected_rows(events, unique_sorted_ids) self.source = new_events self._normalize_event_ids(new_events) @@ -224,8 +220,7 @@ def _replace_table_with_selected_rows(self, table, row_ids): the destination table. 
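In addition to the docstring example above, the processed results can be written to a separate table and the amount of work can be capped; a sketch with an assumed data file and station group::

    import tables

    from sapphire import ProcessEvents

    with tables.open_file('data.h5', 'a') as data:
        proc = ProcessEvents(data, '/s501')
        # Keep the raw events and store the results in a new table,
        # processing only the first 1000 events.
        proc.process_and_store_results(destination='events_new', overwrite=True, limit=1000)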
""" - tmptable = self.data.create_table(self.group, 't__events', - description=table.description) + tmptable = self.data.create_table(self.group, 't__events', description=table.description) selected_rows = table.read_coordinates(row_ids) tmptable.append(selected_rows) tmptable.flush() @@ -262,9 +257,7 @@ def _create_empty_results_table(self): if '_t_events' in self.group: self.data.remove_node(self.group, '_t_events') - table = self.data.create_table(self.group, '_t_events', - self.processed_events_description, - expectedrows=length) + table = self.data.create_table(self.group, '_t_events', self.processed_events_description, expectedrows=length) for _ in range(length): table.row.append() @@ -277,8 +270,7 @@ def _copy_events_into_table(self): source = self.source for col in pbar(source.colnames, show=self.progress): - table.modify_column(stop=self.limit, colname=col, - column=getattr(source.cols, col)[:self.limit]) + table.modify_column(stop=self.limit, colname=col, column=getattr(source.cols, col)[: self.limit]) table.flush() def _store_results_from_traces(self): @@ -300,8 +292,7 @@ def process_traces(self): else: events = self.source - timings = self._process_traces_from_event_list(events, - length=self.limit) + timings = self._process_traces_from_event_list(events, length=self.limit) return timings def _process_traces_from_event_list(self, events, length=None): @@ -333,9 +324,7 @@ def _reconstruct_time_from_traces(self, event): """ timings = [] - for baseline, pulseheight, trace_idx in zip(event['baseline'], - event['pulseheights'], - event['traces']): + for baseline, pulseheight, trace_idx in zip(event['baseline'], event['pulseheights'], event['traces']): if pulseheight < 0: # retain -1, -999 status flags in timing timings.append(pulseheight) @@ -343,11 +332,8 @@ def _reconstruct_time_from_traces(self, event): timings.append(-999) else: trace = self._get_trace(trace_idx) - timings.append(self._reconstruct_time_from_trace(trace, - baseline)) - timings = [time * ADC_TIME_PER_SAMPLE - if time not in ERR else time - for time in timings] + timings.append(self._reconstruct_time_from_trace(trace, baseline)) + timings = [time * ADC_TIME_PER_SAMPLE if time not in ERR else time for time in timings] return timings def _get_trace(self, idx): @@ -364,8 +350,7 @@ def _get_trace(self, idx): try: trace = zlib.decompress(blobs[idx]).decode('utf-8').split(',') except zlib.error: - trace = (zlib.decompress(blobs[idx][1:-1]) - .decode('utf-8').split(',')) + trace = zlib.decompress(blobs[idx][1:-1]).decode('utf-8').split(',') if trace[-1] == '': del trace[-1] trace = (int(x) for x in trace) @@ -433,8 +418,7 @@ def _process_pulseintegrals(self): if (detector_integrals < 0).all(): all_mpv.append(np.nan) else: - n, bins = np.histogram(detector_integrals, - bins=np.linspace(0, 50000, 201)) + n, bins = np.histogram(detector_integrals, bins=np.linspace(0, 50_000, 201)) find_mpv = FindMostProbableValueInSpectrum(n, bins) mpv, is_fitted = find_mpv.find_mpv() if is_fitted: @@ -443,15 +427,12 @@ def _process_pulseintegrals(self): all_mpv.append(np.nan) all_mpv = np.array(all_mpv) - for event in self.source[:self.limit]: + for event in self.source[: self.limit]: pulseintegrals = event['integrals'] # retain -1, -999 status flags - pulseintegrals = np.where(pulseintegrals >= 0, - pulseintegrals / all_mpv, - pulseintegrals) + pulseintegrals = np.where(pulseintegrals >= 0, pulseintegrals / all_mpv, pulseintegrals) # if mpv fit failed, value is nan. 
Make it -999 - pulseintegrals = np.where(np.isnan(pulseintegrals), -999, - pulseintegrals) + pulseintegrals = np.where(np.isnan(pulseintegrals), -999, pulseintegrals) n_particles.append(pulseintegrals) return np.array(n_particles) @@ -467,16 +448,18 @@ def _move_results_table_into_destination(self): def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ + return '' % self.__class__.__name__ else: - return ("%s(%r, %r, source=%r, progress=%r)" % - (self.__class__.__name__, self.data.filename, - self.group._v_pathname, self.source._v_pathname, - self.progress)) + return '%s(%r, %r, source=%r, progress=%r)' % ( + self.__class__.__name__, + self.data.filename, + self.group._v_pathname, + self.source._v_pathname, + self.progress, + ) class ProcessIndexedEvents(ProcessEvents): - """Process a subset of events using an index. This is a subclass of :class:`ProcessEvents`. Using an index, this @@ -533,7 +516,6 @@ def get_traces_for_indexed_event_index(self, idx): class ProcessEventsWithLINT(ProcessEvents): - """Process events using LInear INTerpolation for arrival times. This is a subclass of :class:`ProcessEvents`. Use a linear @@ -557,10 +539,10 @@ def _reconstruct_time_from_trace(self, trace, baseline): if i == 0: value = i - elif not i == -999: + elif i != -999: x0, x1 = i - 1, i y0, y1 = trace[x0], trace[x1] - value = 1. * (threshold - y0) / (y1 - y0) + x0 + value = 1.0 * (threshold - y0) / (y1 - y0) + x0 else: value = -999 @@ -568,7 +550,6 @@ def _reconstruct_time_from_trace(self, trace, baseline): class ProcessIndexedEventsWithLINT(ProcessIndexedEvents, ProcessEventsWithLINT): - """Process a subset of events using LInear INTerpolation. This is a subclass of :class:`ProcessIndexedEvents` and @@ -576,11 +557,8 @@ class ProcessIndexedEventsWithLINT(ProcessIndexedEvents, ProcessEventsWithLINT): """ - pass - class ProcessEventsWithoutTraces(ProcessEvents): - """Process events without traces This is a subclass of :class:`ProcessEvents`. 
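The LINT variant above only changes how the arrival time is interpolated from the trace around the threshold crossing; it is used exactly like ProcessEvents (file and group names are again assumptions)::

    import tables

    from sapphire.analysis.process_events import ProcessEventsWithLINT

    with tables.open_file('data.h5', 'a') as data:
        proc = ProcessEventsWithLINT(data, '/s501')
        proc.process_and_store_results(destination='events_lint', overwrite=True)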
Processing events @@ -594,11 +572,8 @@ class ProcessEventsWithoutTraces(ProcessEvents): def _store_results_from_traces(self): """Fake storing results from traces.""" - pass - class ProcessIndexedEventsWithoutTraces(ProcessEventsWithoutTraces, ProcessIndexedEvents): - """Process a subset of events without traces This is a subclass of :class:`ProcessIndexedEvents` and @@ -610,11 +585,8 @@ class ProcessIndexedEventsWithoutTraces(ProcessEventsWithoutTraces, ProcessIndex """ - pass - class ProcessEventsWithTriggerOffset(ProcessEvents): - """Process events and reconstruct trigger time from traces The trigger times are stored in the columnt_trigger, they are @@ -648,7 +620,7 @@ def __init__(self, data, group, source=None, progress=True, station=None): elif n == 4: self.trigger = TRIGGER_4 else: - raise Exception('No trigger settings available') + raise ValueError('No trigger settings available') else: self.station = Station(station) @@ -693,8 +665,11 @@ def _reconstruct_time_from_traces(self, event): low_idx = [] high_idx = [] for baseline, pulseheight, trace_idx, trig_thresholds in zip( - event['baseline'], event['pulseheights'], event['traces'], - self.thresholds): + event['baseline'], + event['pulseheights'], + event['traces'], + self.thresholds, + ): if pulseheight < 0: # Retain -1 and -999 status flags in timing timings.append(pulseheight) @@ -724,17 +699,15 @@ def _reconstruct_time_from_traces(self, event): trace = self._get_trace(trace_idx) - t, l, h = self._first_above_thresholds(trace, thresholds, - max_signal) - timings.append(t) - low_idx.append(l) - high_idx.append(h) + time, low, high = self._first_above_thresholds(trace, thresholds, max_signal) + timings.append(time) + low_idx.append(low) + high_idx.append(high) t_trigger = self._reconstruct_trigger(low_idx, high_idx) timings.append(t_trigger) - timings = [time * ADC_TIME_PER_SAMPLE if time not in ERR else time - for time in timings] + timings = [time * ADC_TIME_PER_SAMPLE if time not in ERR else time for time in timings] return timings @classmethod @@ -781,8 +754,7 @@ def _first_value_above_threshold(trace, threshold, t=0): threshold, and the value. 
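When a station number is passed to the class above, the trigger thresholds and logic are looked up via the Station API object instead of the hard-coded two- and four-detector defaults; a sketch in which the file name, group, and station number are illustrative::

    import tables

    from sapphire.analysis.process_events import ProcessEventsWithTriggerOffset

    with tables.open_file('data.h5', 'a') as data:
        proc = ProcessEventsWithTriggerOffset(data, '/s501', station=501)
        proc.process_and_store_results(overwrite=True)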
""" - return next(((i, x) for i, x in enumerate(trace, t) if x >= threshold), - (-999, 0)) + return next(((i, x) for i, x in enumerate(trace, t) if x >= threshold), (-999, 0)) def _reconstruct_trigger(self, low_idx, high_idx): """Reconstruct the moment of trigger from the threshold info @@ -798,8 +770,8 @@ def _reconstruct_trigger(self, low_idx, high_idx): if external: return -999 - low_idx = sorted(idx for idx in low_idx if not idx == -999) - high_idx = sorted(idx for idx in high_idx if not idx == -999) + low_idx = sorted(idx for idx in low_idx if idx != -999) + high_idx = sorted(idx for idx in high_idx if idx != -999) if and_or: # low or high, which ever is first @@ -809,39 +781,45 @@ def _reconstruct_trigger(self, low_idx, high_idx): return high_idx[n_high - 1] elif n_low and len(low_idx) >= n_low: return low_idx[n_low - 1] - else: - if n_low and n_high: - # low and high - if len(low_idx) >= n_low + n_high and len(high_idx) >= n_high: - return max(low_idx[n_low + n_high - 1], high_idx[n_high - 1]) - elif n_high: - # 0 low and high - if len(high_idx) >= n_high: - return high_idx[n_high - 1] - elif n_low: - # low and 0 high - if len(low_idx) >= n_low: - return low_idx[n_low - 1] + elif n_low and n_high: + # low and high + if len(low_idx) >= n_low + n_high and len(high_idx) >= n_high: + return max(low_idx[n_low + n_high - 1], high_idx[n_high - 1]) + elif n_high: + # 0 low and high + if len(high_idx) >= n_high: + return high_idx[n_high - 1] + elif n_low: + # low and 0 high + if len(low_idx) >= n_low: + return low_idx[n_low - 1] return -999 def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ + return f'' elif self.station is None: - return ("%s(%r, %r, source=%r, progress=%r, Station=%r)" % - (self.__class__.__name__, self.data.filename, - self.group._v_pathname, self.source._v_pathname, - self.progress, None)) + return '%s(%r, %r, source=%r, progress=%r, Station=%r)' % ( + self.__class__.__name__, + self.data.filename, + self.group._v_pathname, + self.source._v_pathname, + self.progress, + None, + ) else: - return ("%s(%r, %r, source=%r, progress=%r, station=%d)" % - (self.__class__.__name__, self.data.filename, - self.group._v_pathname, self.source._v_pathname, - self.progress, self.station.number)) + return '%s(%r, %r, source=%r, progress=%r, station=%d)' % ( + self.__class__.__name__, + self.data.filename, + self.group._v_pathname, + self.source._v_pathname, + self.progress, + self.station.number, + ) class ProcessEventsFromSource(ProcessEvents): - """Process HiSPARC events from a different source. This class is a subclass of ProcessEvents. The difference is that in @@ -851,8 +829,7 @@ class ProcessEventsFromSource(ProcessEvents): """ - def __init__(self, source_file, dest_file, source_group, dest_group, - progress=False): + def __init__(self, source_file, dest_file, source_group, dest_group, progress=False): """Initialize the class. :param source_file,dest_file: PyTables source and destination files. @@ -897,7 +874,6 @@ def _get_source(self): def _check_destination(self, destination, overwrite): """Override method, the destination is empty""" - pass def _replace_table_with_selected_rows(self, table, row_ids): """Replace events table with selected rows. @@ -907,8 +883,7 @@ def _replace_table_with_selected_rows(self, table, row_ids): the destination table. 
""" - new_events = self.dest_file.create_table(self.dest_group, '_events', - description=table.description) + new_events = self.dest_file.create_table(self.dest_group, '_events', description=table.description) selected_rows = table.read_coordinates(row_ids) new_events.append(selected_rows) new_events.flush() @@ -922,9 +897,12 @@ def _create_empty_results_table(self): else: length = len(self.source) - table = self.dest_file.create_table(self.dest_group, 'events', - self.processed_events_description, - expectedrows=length) + table = self.dest_file.create_table( + self.dest_group, + 'events', + self.processed_events_description, + expectedrows=length, + ) for _ in range(length): table.row.append() @@ -943,17 +921,19 @@ def _get_blobs(self): def __repr__(self): if not self.source_file.isopen or not self.dest_file.isopen: - return f"" + return f'' else: - return ("%s(%r, %r, %r, %r, progress=%r)" % - (self.__class__.__name__, self.source_file.filename, - self.dest_file.filename, self.source_group._v_pathname, - self.dest_group._v_pathname, self.progress)) + return '%s(%r, %r, %r, %r, progress=%r)' % ( + self.__class__.__name__, + self.source_file.filename, + self.dest_file.filename, + self.source_group._v_pathname, + self.dest_group._v_pathname, + self.progress, + ) -class ProcessEventsFromSourceWithTriggerOffset(ProcessEventsFromSource, - ProcessEventsWithTriggerOffset): - +class ProcessEventsFromSourceWithTriggerOffset(ProcessEventsFromSource, ProcessEventsWithTriggerOffset): """Process events from a different source and find trigger. This is a subclass of :class:`ProcessEventsFromSource` and @@ -963,8 +943,7 @@ class ProcessEventsFromSourceWithTriggerOffset(ProcessEventsFromSource, """ - def __init__(self, source_file, dest_file, source_group, dest_group, - station=None, progress=False): + def __init__(self, source_file, dest_file, source_group, dest_group, station=None, progress=False): """Initialize the class. :param source_file,dest_file: PyTables source and destination files. @@ -995,28 +974,35 @@ def __init__(self, source_file, dest_file, source_group, dest_group, elif n == 4: self.trigger = TRIGGER_4 else: - raise Exception('No trigger settings available') + raise ValueError('No trigger settings available') else: self.station = Station(station) def __repr__(self): if not self.source_file.isopen or not self.dest_file.isopen: - return "" % self.__class__.__name__ + return f'' elif self.station is None: - return ("%s(%r, %r, %r, %r, progress=%r)" % - (self.__class__.__name__, self.source_file.filename, - self.dest_file.filename, self.source_group._v_pathname, - self.dest_group._v_pathname, self.progress)) + return '%s(%r, %r, %r, %r, progress=%r)' % ( + self.__class__.__name__, + self.source_file.filename, + self.dest_file.filename, + self.source_group._v_pathname, + self.dest_group._v_pathname, + self.progress, + ) else: - return ("%s(%r, %r, %r, %r, station=%d, progress=%r)" % - (self.__class__.__name__, self.source_file.filename, - self.dest_file.filename, self.source_group._v_pathname, - self.dest_group._v_pathname, self.station.number, - self.progress)) + return '%s(%r, %r, %r, %r, station=%d, progress=%r)' % ( + self.__class__.__name__, + self.source_file.filename, + self.dest_file.filename, + self.source_group._v_pathname, + self.dest_group._v_pathname, + self.station.number, + self.progress, + ) class ProcessDataTable(ProcessEvents): - """Process HiSPARC abstract data table to clean the data. Abstract data is a PyTables table containing a timestamp for each row. 
@@ -1025,6 +1011,7 @@ class ProcessDataTable(ProcessEvents): sort the data by timestamp to store it in to a copy of the table. """ + table_name = 'abstract_data' # overwrite with 'weather' or 'singles' def process_and_store_results(self, destination=None, overwrite=False, limit=None): @@ -1062,16 +1049,14 @@ def _check_destination(self, destination, overwrite): """Check if the destination is valid""" if destination == f'_t_{self.table_name}': - raise RuntimeError(f"The _t_{self.table_name} table is for internal use. Choose " - "another destination.") + raise RuntimeError(f'The _t_{self.table_name} table is for internal use. Choose another destination.') elif destination is None: destination = self.table_name # If destination == source, source will be overwritten. if self.source.name != destination: if destination in self.group and not overwrite: - raise RuntimeError("I will not overwrite previous results " - "(unless you specify overwrite=True)") + raise RuntimeError('I will not overwrite previous results (unless you specify overwrite=True)') self.destination = destination @@ -1088,8 +1073,7 @@ def _clean_data_table(self): unique_sorted_ids = self._find_unique_row_ids(enumerated_timestamps) - new_data = self._replace_table_with_selected_rows(data, - unique_sorted_ids) + new_data = self._replace_table_with_selected_rows(data, unique_sorted_ids) self.source = new_data self._normalize_event_ids(new_data) @@ -1101,9 +1085,7 @@ def _replace_table_with_selected_rows(self, table, row_ids): the destination table. """ - tmptable = self.data.create_table(self.group, - f'_t_{self.table_name}', - description=table.description) + tmptable = self.data.create_table(self.group, f'_t_{self.table_name}', description=table.description) selected_rows = table.read_coordinates(row_ids) tmptable.append(selected_rows) tmptable.flush() @@ -1112,7 +1094,6 @@ def _replace_table_with_selected_rows(self, table, row_ids): class ProcessDataTableFromSource(ProcessDataTable): - """Process HiSPARC abstract data table from a different source. This class is a subclass of ProcessDataTable. The difference is that in @@ -1153,7 +1134,6 @@ def _get_source(self): def _check_destination(self, destination, overwrite): """Override method, the destination should be empty""" - pass def _replace_table_with_selected_rows(self, table, row_ids): """Replace data table with selected rows. @@ -1163,9 +1143,7 @@ def _replace_table_with_selected_rows(self, table, row_ids): the destination table. """ - new_table = self.dest_file.create_table(self.dest_group, - self.table_name, - description=table.description) + new_table = self.dest_file.create_table(self.dest_group, self.table_name, description=table.description) selected_rows = table.read_coordinates(row_ids) new_table.append(selected_rows) new_table.flush() @@ -1173,16 +1151,19 @@ def _replace_table_with_selected_rows(self, table, row_ids): def __repr__(self): if not self.source_file.isopen or not self.dest_file.isopen: - return "" % self.__class__.__name__ + return f'' else: - return ("%s(%r, %r, %r, %r, progress=%r)" % - (self.__class__.__name__, self.source_file.filename, - self.dest_file.filename, self.source_group._v_pathname, - self.dest_group._v_pathname, self.progress)) + return '%s(%r, %r, %r, %r, progress=%r)' % ( + self.__class__.__name__, + self.source_file.filename, + self.dest_file.filename, + self.source_group._v_pathname, + self.dest_group._v_pathname, + self.progress, + ) class ProcessWeather(ProcessDataTable): - """Process HiSPARC weather to clean the data. 
This class can be used to process a set of HiSPARC weather, to @@ -1190,11 +1171,11 @@ class ProcessWeather(ProcessDataTable): copy of the weather table. """ + table_name = 'weather' class ProcessWeatherFromSource(ProcessDataTableFromSource): - """Process HiSPARC weather from a different source. This class behaves like a subclass of ProcessWeather because of a common @@ -1205,11 +1186,11 @@ class ProcessWeatherFromSource(ProcessDataTableFromSource): assumed to be empty. """ + table_name = 'weather' class ProcessSingles(ProcessDataTable): - """Process HiSPARC singles data to clean the data. This class can be used to process a set of HiSPARC singles data, to @@ -1217,11 +1198,11 @@ class ProcessSingles(ProcessDataTable): copy of the singles data table. """ + table_name = 'singles' class ProcessSinglesFromSource(ProcessDataTableFromSource): - """Process HiSPARC singles data from a different source. This class behaves like a subclass of ProcessSingles because of a common @@ -1232,4 +1213,5 @@ class ProcessSinglesFromSource(ProcessDataTableFromSource): assumed to be empty. """ + table_name = 'singles' diff --git a/sapphire/analysis/process_traces.py b/sapphire/analysis/process_traces.py index d8ce3049..8618f00d 100644 --- a/sapphire/analysis/process_traces.py +++ b/sapphire/analysis/process_traces.py @@ -1,12 +1,13 @@ -""" Process HiSPARC traces +"""Process HiSPARC traces - This module can be used analyse (raw) traces. It implements the same - algorithms as are implemented in the HiSPARC DAQ. +This module can be used analyse (raw) traces. It implements the same +algorithms as are implemented in the HiSPARC DAQ. - The :class:`MeanFilter` is meant to mimic the filter in the HiSPARC DAQ. - It is reproduced here to make it easy to read the algorithm. +The :class:`MeanFilter` is meant to mimic the filter in the HiSPARC DAQ. +It is reproduced here to make it easy to read the algorithm. """ + from functools import cached_property from numpy import around, convolve, ones, where @@ -35,7 +36,6 @@ class TraceObservables: - """Reconstruct trace observables If one wants to reconstruct trace observables from existing data some @@ -56,8 +56,7 @@ class TraceObservables: """ - def __init__(self, traces, threshold=ADC_BASELINE_THRESHOLD, - padding=DATA_REDUCTION_PADDING): + def __init__(self, traces, threshold=ADC_BASELINE_THRESHOLD, padding=DATA_REDUCTION_PADDING): """Initialize the class :param traces: a NumPy array of traces, ordered such that the first @@ -73,7 +72,7 @@ def __init__(self, traces, threshold=ADC_BASELINE_THRESHOLD, self.n = self.traces.shape[1] self.missing = [-1] * (4 - self.n) if self.n not in [2, 4]: - raise Exception('Unsupported number of detectors') + raise ValueError('Unsupported number of detectors') @cached_property def baselines(self): @@ -89,7 +88,7 @@ def baselines(self): :return: the baseline in ADC count. """ - baselines = around(self.traces[:self.padding].mean(axis=0)) + baselines = around(self.traces[: self.padding].mean(axis=0)) return baselines.astype('int').tolist() + self.missing @cached_property @@ -99,7 +98,7 @@ def std_dev(self): :return: the standard deviation in milli ADC count. """ - std_dev = around(self.traces[:self.padding].std(axis=0) * 1000) + std_dev = around(self.traces[: self.padding].std(axis=0) * 1000) return std_dev.astype('int').tolist() + self.missing @cached_property @@ -109,7 +108,7 @@ def pulseheights(self): :return: the pulseheights in ADC count. 
""" - pulseheights = self.traces.max(axis=0) - self.baselines[:self.n] + pulseheights = self.traces.max(axis=0) - self.baselines[: self.n] return pulseheights.tolist() + self.missing @cached_property @@ -122,8 +121,11 @@ def integrals(self): """ threshold = self.threshold - integrals = where(self.traces - self.baselines[:self.n] > threshold, - self.traces - self.baselines[:self.n], 0).sum(axis=0) + integrals = where( + self.traces - self.baselines[: self.n] > threshold, + self.traces - self.baselines[: self.n], + 0, + ).sum(axis=0) return integrals.tolist() + self.missing @cached_property @@ -136,12 +138,12 @@ def n_peaks(self): """ # Make rough guess at the baseline/threshold to expect - if all(b < 100 for b in self.baselines[:self.n]): + if all(b < 100 for b in self.baselines[: self.n]): peak_threshold = ADC_LOW_THRESHOLD_III - 30 else: peak_threshold = ADC_LOW_THRESHOLD - 200 - traces = self.traces - self.baselines[:self.n] + traces = self.traces - self.baselines[: self.n] n_peaks = [] for trace in traces.T: @@ -157,20 +159,18 @@ def n_peaks(self): in_peak = True local_maximum = value n_peak += 1 - else: - if value > local_maximum: - local_maximum = value - elif local_maximum - value > peak_threshold: - # enough signal decrease to be out of peak - in_peak = False - local_minimum = value if value > 0 else 0 + elif value > local_maximum: + local_maximum = value + elif local_maximum - value > peak_threshold: + # enough signal decrease to be out of peak + in_peak = False + local_minimum = value if value > 0 else 0 n_peaks.append(n_peak) return n_peaks + self.missing class MeanFilter: - """Filter raw traces This class replicates the behavior of the Mean_Filter.vi in the HiSPARC @@ -225,9 +225,7 @@ def filter_trace(self, raw_trace): filtered_even = self.filter(even_trace) filtered_odd = self.filter(odd_trace) - recombined_trace = [v - for eo in zip(filtered_even, filtered_odd) - for v in eo] + recombined_trace = [v for eo in zip(filtered_even, filtered_odd) for v in eo] filtered_trace = self.filter(recombined_trace) return filtered_trace @@ -241,7 +239,7 @@ def mean_filter_with_threshold(self, trace): local_mean = moving_average[3] local_mean_rounded = rounded_average[3] - if all([abs(v - local_mean) <= self.threshold for v in trace[:4]]): + if all(abs(v - local_mean) <= self.threshold for v in trace[:4]): filtered_trace.extend([local_mean_rounded] * 4) else: filtered_trace.extend(trace[:4]) @@ -283,14 +281,12 @@ def mean_filter_without_threshold(self, trace): def __repr__(self): try: - return ("%s(use_threshold=%s, threshold=%r)" % - (self.__class__.__name__, True, self.threshold)) + return '%s(use_threshold=%s, threshold=%r)' % (self.__class__.__name__, True, self.threshold) except AttributeError: - return f"{self.__class__.__name__}(use_threshold={False})" + return f'{self.__class__.__name__}(use_threshold={False})' class DataReduction: - """Data reduce traces This class replicates the behavior also implemented in the HiSPARC DAQ. @@ -302,8 +298,7 @@ class DataReduction: """ - def __init__(self, threshold=ADC_BASELINE_THRESHOLD, - padding=DATA_REDUCTION_PADDING): + def __init__(self, threshold=ADC_BASELINE_THRESHOLD, padding=DATA_REDUCTION_PADDING): """Initialize the class :param threshold: value of the threshold to use, in ADC counts. 
@@ -326,7 +321,7 @@ def reduce_traces(self, traces, baselines=None, return_offset=False): """ if baselines is None: - baselines = TraceObservables(traces).baselines[:len(traces[0])] + baselines = TraceObservables(traces).baselines[: len(traces[0])] left, right = self.determine_cuts(traces, baselines) left, right = self.add_padding(left, right, len(traces)) if return_offset: @@ -346,10 +341,11 @@ def determine_cuts(self, traces, baselines): right cross the threshold. """ - left = next((i for i, t in enumerate(traces) - if max(t - baselines) > self.threshold), 0) - right = len(traces) - next((i for i, t in enumerate(reversed(traces)) - if max(t - baselines) > self.threshold), 0) + left = next((i for i, t in enumerate(traces) if max(t - baselines) > self.threshold), 0) + right = len(traces) - next( + (i for i, t in enumerate(reversed(traces)) if max(t - baselines) > self.threshold), + 0, + ) return left, right def add_padding(self, left, right, length=None): diff --git a/sapphire/analysis/reconstructions.py b/sapphire/analysis/reconstructions.py index 7be7d980..11faf84d 100644 --- a/sapphire/analysis/reconstructions.py +++ b/sapphire/analysis/reconstructions.py @@ -1,23 +1,24 @@ -""" Reconstruct HiSPARC events and coincidences +"""Reconstruct HiSPARC events and coincidences - This module contains classes that can be used to reconstruct - HiSPARC events and coincidences. These classes can be used to automate - the tasks of reconstructing directions and/or cores. +This module contains classes that can be used to reconstruct +HiSPARC events and coincidences. These classes can be used to automate +the tasks of reconstructing directions and/or cores. - The classes can reconstruct measured data from the ESD as well as - simulated data from :mod:`sapphire.simulations`. +The classes can reconstruct measured data from the ESD as well as +simulated data from :mod:`sapphire.simulations`. - The classes read data stored in HDF5 files and extract station metadata - (cluster and detector layout, station and detector offsets) from - various sources: +The classes read data stored in HDF5 files and extract station metadata +(cluster and detector layout, station and detector offsets) from +various sources: - - from the public database using :class:`sapphire.api.Station` objects - - from stored or provided :class`sappire.cluster.Station` objects, - usually cluster or station layout stored by :mod:`sapphire.simulations` +- from the public database using :class:`sapphire.api.Station` objects +- from stored or provided :class`sappire.cluster.Station` objects, + usually cluster or station layout stored by :mod:`sapphire.simulations` - Reconstructed data is stored in HDF5 files. + Reconstructed data is stored in HDF5 files. """ + import os import warnings @@ -36,7 +37,6 @@ class ReconstructESDEvents: - """Reconstruct events from single stations Example usage:: @@ -62,10 +62,18 @@ class ReconstructESDEvents: """ - def __init__(self, data, station_group, station, - overwrite=False, progress=True, verbose=False, - destination='reconstructions', - force_fresh=False, force_stale=False): + def __init__( + self, + data, + station_group, + station, + overwrite=False, + progress=True, + verbose=False, + destination='reconstructions', + force_fresh=False, + force_stale=False, + ): """Initialize the class. :param data: the PyTables datafile. @@ -92,7 +100,7 @@ def __init__(self, data, station_group, station, self.force_fresh = force_fresh self.force_stale = force_stale - self.offsets = [0., 0., 0., 0.] 
+ self.offsets = [0.0, 0.0, 0.0, 0.0] self._get_or_create_station_object(station) @@ -121,13 +129,10 @@ def reconstruct_directions(self, detector_ids=None): """ if len(self.core_x) and len(self.core_y): - initials = ({'core_x': x, 'core_y': y} - for x, y in zip(self.core_x, self.core_y)) + initials = ({'core_x': x, 'core_y': y} for x, y in zip(self.core_x, self.core_y)) else: initials = [] - angles = self.direction.reconstruct_events(self.events, detector_ids, - self.offsets, self.progress, - initials) + angles = self.direction.reconstruct_events(self.events, detector_ids, self.offsets, self.progress, initials) self.theta, self.phi, self.detector_ids = angles def reconstruct_cores(self, detector_ids=None): @@ -137,12 +142,10 @@ def reconstruct_cores(self, detector_ids=None): """ if len(self.theta) and len(self.phi): - initials = ({'theta': theta, 'phi': phi} - for theta, phi in zip(self.theta, self.phi)) + initials = ({'theta': theta, 'phi': phi} for theta, phi in zip(self.theta, self.phi)) else: initials = [] - cores = self.core.reconstruct_events(self.events, detector_ids, - self.progress, initials) + cores = self.core.reconstruct_events(self.events, detector_ids, self.progress, initials) self.core_x, self.core_y = cores def prepare_output(self): @@ -150,15 +153,17 @@ def prepare_output(self): if self.destination in self.station_group: if self.overwrite: - self.data.remove_node(self.station_group, self.destination, - recursive=True) + self.data.remove_node(self.station_group, self.destination, recursive=True) else: - raise RuntimeError("Reconstructions table already exists for " - "%s, and overwrite is False" % - self.station_group) + raise RuntimeError( + 'Reconstructions table already exists for %s, and overwrite is False' % self.station_group, + ) self.reconstructions = self.data.create_table( - self.station_group, self.destination, ReconstructedEvent, - expectedrows=self.events.nrows) + self.station_group, + self.destination, + ReconstructedEvent, + expectedrows=self.events.nrows, + ) try: self.reconstructions._v_attrs.station = self.station except tables.HDF5ExtError: @@ -191,14 +196,15 @@ def get_detector_offsets(self): print('Read detector offsets from station object.') except AttributeError: if self.station_number is not None: - self.offsets = api.Station(self.station_number, - force_fresh=self.force_fresh, - force_stale=self.force_stale) + self.offsets = api.Station( + self.station_number, + force_fresh=self.force_fresh, + force_stale=self.force_stale, + ) if self.verbose: print('Reading detector offsets from public database.') else: - self.offsets = determine_detector_timing_offsets(self.events, - self.station) + self.offsets = determine_detector_timing_offsets(self.events, self.station) self.store_offsets() if self.verbose: print('Determined offsets from event data: ', self.offsets) @@ -208,14 +214,12 @@ def store_offsets(self): if 'detector_offsets' in self.station_group: if self.overwrite: - self.data.remove_node(self.station_group.detector_offsets, - recursive=True) + self.data.remove_node(self.station_group.detector_offsets, recursive=True) else: - raise RuntimeError("Detector offset table already exists for " - "%s, and overwrite is False" % - self.station_group) - self.detector_offsets = self.data.create_array( - self.station_group, 'detector_offsets', self.offsets) + raise RuntimeError( + 'Detector offset table already exists for %s, and overwrite is False' % self.station_group, + ) + self.detector_offsets = self.data.create_array(self.station_group, 'detector_offsets', 
self.offsets) self.detector_offsets.flush() def store_reconstructions(self): @@ -226,31 +230,34 @@ def store_reconstructions(self): """ for event, core_x, core_y, theta, phi, detector_ids in zip_longest( - self.events, self.core_x, self.core_y, - self.theta, self.phi, self.detector_ids): - self._store_reconstruction(event, core_x, core_y, theta, phi, - detector_ids) + self.events, + self.core_x, + self.core_y, + self.theta, + self.phi, + self.detector_ids, + ): + self._store_reconstruction(event, core_x, core_y, theta, phi, detector_ids) self.reconstructions.flush() - def _store_reconstruction(self, event, core_x, core_y, theta, phi, - detector_ids): + def _store_reconstruction(self, event, core_x, core_y, theta, phi, detector_ids): """Store single reconstruction""" row = self.reconstructions.row row['id'] = event['event_id'] row['ext_timestamp'] = event['ext_timestamp'] try: - row['min_n'] = min(event['n%d' % (id + 1)] for id in detector_ids) + row['min_n'] = min(event['n%d' % (detector_id + 1)] for detector_id in detector_ids) except ValueError: # sometimes, all arrival times are -999 or -1, and then # detector_ids = []. So min([]) gives a ValueError. - row['min_n'] = -999. + row['min_n'] = -999.0 row['x'] = core_x row['y'] = core_y row['zenith'] = theta row['azimuth'] = phi - for id in detector_ids: - row['d%d' % (id + 1)] = True + for detector_id in detector_ids: + row['d%d' % (detector_id + 1)] = True row.append() def _get_or_create_station_object(self, station): @@ -258,23 +265,30 @@ def _get_or_create_station_object(self, station): self.station = station self.station_number = None if self.verbose: - print('Using object %s for metadata.' % self.station) + print(f'Using object {self.station} for metadata.') else: self.station_number = station - cluster = HiSPARCStations([station], - force_fresh=self.force_fresh, - force_stale=self.force_stale) + cluster = HiSPARCStations([station], force_fresh=self.force_fresh, force_stale=self.force_stale) self.station = cluster.get_station(station) if self.verbose: print(f'Constructed object {self.station} from public database.') class ReconstructESDEventsFromSource(ReconstructESDEvents): - - def __init__(self, source_data, dest_data, source_group, dest_group, - station, overwrite=False, progress=True, verbose=False, - destination='reconstructions', - force_fresh=False, force_stale=False): + def __init__( + self, + source_data, + dest_data, + source_group, + dest_group, + station, + overwrite=False, + progress=True, + verbose=False, + destination='reconstructions', + force_fresh=False, + force_stale=False, + ): """Initialize the class. :param data: the PyTables datafile. 
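ReconstructESDEvents, changed above, bundles direction and core reconstruction for a single station; the individual steps shown above can also be called by hand. A sketch against a hypothetical ESD file; the top-level import and the call order follow common usage and are assumptions, not part of this diff::

    import tables

    from sapphire import ReconstructESDEvents

    with tables.open_file('data.h5', 'a') as data:
        rec = ReconstructESDEvents(data, '/s501', station=501, overwrite=True)
        rec.prepare_output()
        rec.get_detector_offsets()
        rec.reconstruct_directions()
        rec.reconstruct_cores()
        rec.store_reconstructions()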
@@ -291,8 +305,16 @@ def __init__(self, source_data, dest_data, source_group, dest_group, """ super().__init__( - source_data, source_group, station, overwrite, progress, verbose, - destination, force_fresh, force_stale) + source_data, + source_group, + station, + overwrite, + progress, + verbose, + destination, + force_fresh, + force_stale, + ) self.dest_data = dest_data self.dest_group = dest_group @@ -305,12 +327,16 @@ def prepare_output(self): if self.overwrite: self.dest_data.remove_node(dest_path, recursive=True) else: - raise RuntimeError("Reconstructions table already exists for " - "%s, and overwrite is False" % - self.dest_group) + raise RuntimeError( + 'Reconstructions table already exists for %s, and overwrite is False' % self.dest_group, + ) self.reconstructions = self.dest_data.create_table( - self.dest_group, self.destination, ReconstructedEvent, - expectedrows=self.events.nrows, createparents=True) + self.dest_group, + self.destination, + ReconstructedEvent, + expectedrows=self.events.nrows, + createparents=True, + ) try: self.reconstructions._v_attrs.station = self.station except tables.HDF5ExtError: @@ -318,7 +344,6 @@ def prepare_output(self): class ReconstructSimulatedEvents(ReconstructESDEvents): - """Reconstruct simulated events from single stations Simulated events use simulated meta-data (e.g. timing offsets) @@ -367,8 +392,7 @@ def _get_or_create_station_object(self, station): cluster = self.data.get_node_attr('/coincidences', 'cluster') self.station = cluster.get_station(station) if self.station is None: - raise RuntimeError('Station %d not found in cluster' - ' object.' % self.station_number) + raise RuntimeError('Station %d not found in cluster object.' % self.station_number) if self.verbose: print('Read object %s from datafile.' % self.station) except (tables.NoSuchNodeError, AttributeError): @@ -376,7 +400,6 @@ def _get_or_create_station_object(self, station): class ReconstructESDCoincidences: - """Reconstruct coincidences, e.g. event between multiple stations Example usage:: @@ -390,10 +413,18 @@ class ReconstructESDCoincidences: """ - def __init__(self, data, coincidences_group='/coincidences', - overwrite=False, progress=True, verbose=False, - destination='reconstructions', cluster=None, - force_fresh=False, force_stale=False): + def __init__( + self, + data, + coincidences_group='/coincidences', + overwrite=False, + progress=True, + verbose=False, + destination='reconstructions', + cluster=None, + force_fresh=False, + force_stale=False, + ): """Initialize the class. :param data: the PyTables datafile. 
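ReconstructESDEventsFromSource only changes where the output goes: the source ESD file can stay read-only while the reconstructions table is created in a second file. A sketch of the constructor; file and group names are illustrative and the remaining steps are the same as in the ReconstructESDEvents sketch above::

    import tables

    from sapphire.analysis.reconstructions import ReconstructESDEventsFromSource

    with tables.open_file('esd.h5', 'r') as source, tables.open_file('rec.h5', 'a') as dest:
        rec = ReconstructESDEventsFromSource(source, dest, '/s501', '/s501', station=501)
        rec.prepare_output()  # the reconstructions table is created in the destination file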
@@ -444,15 +475,17 @@ def reconstruct_directions(self, station_numbers=None): """ if len(self.core_x) and len(self.core_y): - initials = ({'core_x': x, 'core_y': y} - for x, y in zip(self.core_x, self.core_y)) + initials = ({'core_x': x, 'core_y': y} for x, y in zip(self.core_x, self.core_y)) else: initials = [] - coincidences = pbar(self.cq.all_coincidences(iterator=True), - length=self.coincidences.nrows, show=self.progress) + coincidences = pbar(self.cq.all_coincidences(iterator=True), length=self.coincidences.nrows, show=self.progress) angles = self.direction.reconstruct_coincidences( - self.cq.all_events(coincidences, n=0), station_numbers, - self.offsets, progress=False, initials=initials) + self.cq.all_events(coincidences, n=0), + station_numbers, + self.offsets, + progress=False, + initials=initials, + ) self.theta, self.phi, self.station_numbers = angles def reconstruct_cores(self, station_numbers=None): @@ -462,15 +495,16 @@ def reconstruct_cores(self, station_numbers=None): """ if len(self.theta) and len(self.phi): - initials = ({'theta': theta, 'phi': phi} - for theta, phi in zip(self.theta, self.phi)) + initials = ({'theta': theta, 'phi': phi} for theta, phi in zip(self.theta, self.phi)) else: initials = [] - coincidences = pbar(self.cq.all_coincidences(iterator=True), - length=self.coincidences.nrows, show=self.progress) + coincidences = pbar(self.cq.all_coincidences(iterator=True), length=self.coincidences.nrows, show=self.progress) cores = self.core.reconstruct_coincidences( - self.cq.all_events(coincidences, n=0), station_numbers, - progress=False, initials=initials) + self.cq.all_events(coincidences, n=0), + station_numbers, + progress=False, + initials=initials, + ) self.core_x, self.core_y = cores def prepare_output(self): @@ -478,20 +512,23 @@ def prepare_output(self): if self.destination in self.coincidences_group: if self.overwrite: - self.data.remove_node(self.coincidences_group, - self.destination, recursive=True) + self.data.remove_node(self.coincidences_group, self.destination, recursive=True) else: - raise RuntimeError("Reconstructions table already exists for " - "%s, and overwrite is False" % - self.coincidences_group) + raise RuntimeError( + 'Reconstructions table already exists for %s, and overwrite is False' % self.coincidences_group, + ) - s_columns = {'s%d' % station.number: tables.BoolCol(pos=p) - for p, station in enumerate(self.cluster.stations, 26)} + s_columns = { + 's%d' % station.number: tables.BoolCol(pos=p) for p, station in enumerate(self.cluster.stations, 26) + } description = ReconstructedCoincidence description.columns.update(s_columns) self.reconstructions = self.data.create_table( - self.coincidences_group, self.destination, description, - expectedrows=self.coincidences.nrows) + self.coincidences_group, + self.destination, + description, + expectedrows=self.coincidences.nrows, + ) try: self.reconstructions._v_attrs.cluster = self.cluster except tables.HDF5ExtError: @@ -507,17 +544,17 @@ def get_station_timing_offsets(self): """ try: - self.offsets = {station.number: [station.gps_offset + d.offset - for d in station.detectors] - for station in self.cluster.stations} + self.offsets = { + station.number: [station.gps_offset + d.offset for d in station.detectors] + for station in self.cluster.stations + } if self.verbose: print('Using timing offsets from cluster object.') except AttributeError: - self.offsets = {station.number: - api.Station(station.number, - force_fresh=self.force_fresh, - force_stale=self.force_stale) - for station in 
self.cluster.stations} + self.offsets = { + station.number: api.Station(station.number, force_fresh=self.force_fresh, force_stale=self.force_stale) + for station in self.cluster.stations + } if self.verbose: print('Using timing offsets from public database.') @@ -529,14 +566,17 @@ def store_reconstructions(self): """ for coincidence, x, y, theta, phi, station_numbers in zip_longest( - self.coincidences, self.core_x, self.core_y, - self.theta, self.phi, self.station_numbers): - self._store_reconstruction(coincidence, x, y, theta, phi, - station_numbers) + self.coincidences, + self.core_x, + self.core_y, + self.theta, + self.phi, + self.station_numbers, + ): + self._store_reconstruction(coincidence, x, y, theta, phi, station_numbers) self.reconstructions.flush() - def _store_reconstruction(self, coincidence, core_x, core_y, theta, phi, - station_numbers): + def _store_reconstruction(self, coincidence, core_x, core_y, theta, phi, station_numbers): """Store single reconstruction""" row = self.reconstructions.row @@ -565,16 +605,11 @@ def _get_or_create_cluster_object(self, cluster): if cluster is None: s_active = self._get_active_stations() - cluster = HiSPARCStations(s_active, - force_fresh=self.force_fresh, - force_stale=self.force_stale) + cluster = HiSPARCStations(s_active, force_fresh=self.force_fresh, force_stale=self.force_stale) if self.verbose: - print('Constructed cluster %s from public database.' - % self.cluster) - else: - # TODO: check cluster object isinstance - if self.verbose: - print('Using cluster %s for metadata.' % self.cluster) + print('Constructed cluster %s from public database.' % self.cluster) + elif self.verbose: + print('Using cluster %s for metadata.' % self.cluster) return cluster def _get_active_stations(self): @@ -584,8 +619,7 @@ def _get_active_stations(self): for s_path in self.coincidences_group.s_index: try: - station_event_table = self.data.get_node(s_path.decode() + - '/events') + station_event_table = self.data.get_node(s_path.decode() + '/events') except tables.NoSuchNodeError: continue if not station_event_table.nrows: @@ -596,11 +630,20 @@ def _get_active_stations(self): class ReconstructESDCoincidencesFromSource(ReconstructESDCoincidences): - - def __init__(self, source_data, dest_data, source_group, dest_group, - overwrite=False, progress=True, verbose=False, - destination='reconstructions', cluster=None, - force_fresh=False, force_stale=False): + def __init__( + self, + source_data, + dest_data, + source_group, + dest_group, + overwrite=False, + progress=True, + verbose=False, + destination='reconstructions', + cluster=None, + force_fresh=False, + force_stale=False, + ): """Initialize the class. :param data: the PyTables datafile. 
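ReconstructESDCoincidences works on coincidences between stations rather than on single-station events, using station timing offsets instead of detector offsets. A sketch using the methods shown above; the top-level import is an assumption::

    import tables

    from sapphire import ReconstructESDCoincidences

    with tables.open_file('coincidences.h5', 'a') as data:
        rec = ReconstructESDCoincidences(data, '/coincidences', overwrite=True)
        rec.prepare_output()
        rec.get_station_timing_offsets()
        rec.reconstruct_directions()
        rec.reconstruct_cores()
        rec.store_reconstructions()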
@@ -617,8 +660,16 @@ def __init__(self, source_data, dest_data, source_group, dest_group, """ super().__init__( - source_data, source_group, overwrite, progress, verbose, - destination, cluster, force_fresh, force_stale) + source_data, + source_group, + overwrite, + progress, + verbose, + destination, + cluster, + force_fresh, + force_stale, + ) self.dest_data = dest_data self.dest_group = dest_group @@ -631,17 +682,22 @@ def prepare_output(self): if self.overwrite: self.dest_data.remove_node(dest_path, recursive=True) else: - raise RuntimeError("Reconstructions table already exists for " - "%s, and overwrite is False" % - self.dest_group) + raise RuntimeError( + 'Reconstructions table already exists for %s, and overwrite is False' % self.dest_group, + ) - s_columns = {'s%d' % station.number: tables.BoolCol(pos=p) - for p, station in enumerate(self.cluster.stations, 26)} + s_columns = { + 's%d' % station.number: tables.BoolCol(pos=p) for p, station in enumerate(self.cluster.stations, 26) + } description = ReconstructedCoincidence description.columns.update(s_columns) self.reconstructions = self.dest_data.create_table( - self.dest_group, self.destination, description, - expectedrows=self.coincidences.nrows, createparents=True) + self.dest_group, + self.destination, + description, + expectedrows=self.coincidences.nrows, + createparents=True, + ) try: self.reconstructions._v_attrs.cluster = self.cluster except tables.HDF5ExtError: @@ -649,7 +705,6 @@ def prepare_output(self): class ReconstructSimulatedCoincidences(ReconstructESDCoincidences): - """Reconstruct simulated coincidences. Simulated coincidences use simulated meta-data (e.g. timing offsets) @@ -687,14 +742,11 @@ def _get_or_create_cluster_object(self, cluster): """ if cluster is None: try: - cluster = self.data.get_node_attr(self.coincidences_group, - 'cluster') + cluster = self.data.get_node_attr(self.coincidences_group, 'cluster') if self.verbose: print('Read cluster %s from datafile.' % self.cluster) except (tables.NoSuchNodeError, AttributeError): raise RuntimeError('Unable to read cluster object from HDF') - else: - # TODO: check cluster object - if self.verbose: - print('Using cluster %s for metadata.' % self.cluster) + elif self.verbose: + print('Using cluster %s for metadata.' % self.cluster) return cluster diff --git a/sapphire/analysis/time_deltas.py b/sapphire/analysis/time_deltas.py index 4b692571..30dc07a3 100644 --- a/sapphire/analysis/time_deltas.py +++ b/sapphire/analysis/time_deltas.py @@ -1,26 +1,27 @@ -""" Determine time differences between coincident events +"""Determine time differences between coincident events - Determine time delta between coincidence events from station pairs. +Determine time delta between coincidence events from station pairs. 
- Example usage:: +Example usage:: - import datetime + import datetime - import tables + import tables - from sapphire import download_coincidences - from sapphire import ProcessTimeDeltas + from sapphire import download_coincidences + from sapphire import ProcessTimeDeltas - START = datetime.datetime(2015, 2, 1) - END = datetime.datetime(2015, 2, 5) + START = datetime.datetime(2015, 2, 1) + END = datetime.datetime(2015, 2, 5) - if __name__ == '__main__': - with tables.open_file('data.h5', 'w') as data: - download_coincidences(data, start=START, end=END) - td = ProcessTimeDeltas(data) - td.determine_and_store_time_deltas() + if __name__ == '__main__': + with tables.open_file('data.h5', 'w') as data: + download_coincidences(data, start=START, end=END) + td = ProcessTimeDeltas(data) + td.determine_and_store_time_deltas() """ + import posixpath import re @@ -38,7 +39,6 @@ class ProcessTimeDeltas: - """Process HiSPARC event coincidences to obtain time deltas. Use this to determine arrival time differences between station pairs which @@ -46,8 +46,7 @@ class ProcessTimeDeltas: """ - def __init__(self, data, coincidence_group='/coincidences', progress=True, - destination='time_deltas'): + def __init__(self, data, coincidence_group='/coincidences', progress=True, destination='time_deltas'): """Initialize the class. :param data: the PyTables datafile. @@ -86,13 +85,12 @@ def find_station_pairs(self): """ s_index = self.cq.s_index re_number = re.compile('[0-9]+$') - s_numbers = [int(re_number.search(s_path.decode('utf-8')).group()) - for s_path in s_index] + s_numbers = [int(re_number.search(s_path.decode('utf-8')).group()) for s_path in s_index] c_index = self.cq.c_index - self.pairs = {(s_numbers[s1], s_numbers[s2]) - for c_idx in c_index - for s1, s2 in combinations(sorted(c_idx[:, 0]), 2)} + self.pairs = { + (s_numbers[s1], s_numbers[s2]) for c_idx in c_index for s1, s2 in combinations(sorted(c_idx[:, 0]), 2) + } def get_detector_offsets(self): """Retrieve the API detector_timing_offset method for all pairs @@ -102,8 +100,7 @@ def get_detector_offsets(self): """ station_numbers = {station for pair in self.pairs for station in pair} - self.detector_timing_offsets = {sn: Station(sn).detector_timing_offset - for sn in station_numbers} + self.detector_timing_offsets = {sn: Station(sn).detector_timing_offset for sn in station_numbers} def determine_time_deltas_for_pair(self, ref_station, station): """Determine the arrival time differences between two stations. 
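Besides the determine_and_store_time_deltas() entry point shown in the module docstring, the individual steps can be driven per station pair. A sketch; that determine_time_deltas_for_pair() returns the timestamps and time deltas passed on to store_time_deltas() below is an assumption::

    import tables

    from sapphire import ProcessTimeDeltas

    with tables.open_file('data.h5', 'a') as data:
        td = ProcessTimeDeltas(data)
        td.find_station_pairs()
        td.get_detector_offsets()
        for pair in td.pairs:
            ext_timestamps, time_deltas = td.determine_time_deltas_for_pair(*pair)
            td.store_time_deltas(ext_timestamps, time_deltas, pair)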
@@ -118,8 +115,7 @@ def determine_time_deltas_for_pair(self, ref_station, station): previous_ets = 0 coincidences = self.cq.all([ref_station, station], iterator=True) - coin_events = self.cq.events_from_stations(coincidences, - [ref_station, station]) + coin_events = self.cq.events_from_stations(coincidences, [ref_station, station]) ref_offsets = self.detector_timing_offsets[ref_station] offsets = self.detector_timing_offsets[station] @@ -136,20 +132,18 @@ def determine_time_deltas_for_pair(self, ref_station, station): continue if events[0][0] == ref_station: ref_id = 0 - id = 1 + other_id = 1 else: ref_id = 1 - id = 0 + other_id = 0 ref_event = events[ref_id][1] ref_detector_offsets = ref_offsets(ref_event['timestamp']) - event = events[id][1] + event = events[other_id][1] detector_offsets = offsets(event['timestamp']) - ref_t = station_arrival_time(ref_event, ref_ets, [0, 1, 2, 3], - ref_detector_offsets) - t = station_arrival_time(event, ref_ets, [0, 1, 2, 3], - detector_offsets) + ref_t = station_arrival_time(ref_event, ref_ets, [0, 1, 2, 3], ref_detector_offsets) + t = station_arrival_time(event, ref_ets, [0, 1, 2, 3], detector_offsets) if isnan(t) or isnan(ref_t): continue dt.append(t - ref_t) @@ -165,20 +159,28 @@ def store_time_deltas(self, ext_timestamps, time_deltas, pair): dt_table.remove() except tables.NoSuchNodeError: pass - delta_data = [(ets, int(ets) / 1_000_000_000, int(ets) % 1_000_000_000, - time_delta) - for ets, time_delta in zip(ext_timestamps, time_deltas)] - table = self.data.create_table(table_path, 'time_deltas', TimeDelta, - createparents=True, - expectedrows=len(delta_data)) + delta_data = [ + (ets, int(ets) / 1_000_000_000, int(ets) % 1_000_000_000, time_delta) + for ets, time_delta in zip(ext_timestamps, time_deltas) + ] + table = self.data.create_table( + table_path, + 'time_deltas', + TimeDelta, + createparents=True, + expectedrows=len(delta_data), + ) table.append(delta_data) table.flush() def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ + return '' % self.__class__.__name__ coincidence_group = self.cq.coincidences._v_parent._v_pathname - return ("<%s, data: %r, coincidence_group: %r, progress: %r, " - "destination: %r>" % - (self.__class__.__name__, self.data.filename, - coincidence_group, self.progress, self.destination)) + return '<%s, data: %r, coincidence_group: %r, progress: %r, destination: %r>' % ( + self.__class__.__name__, + self.data.filename, + coincidence_group, + self.progress, + self.destination, + ) diff --git a/sapphire/api.py b/sapphire/api.py index 22cb4ccc..799ae9b4 100644 --- a/sapphire/api.py +++ b/sapphire/api.py @@ -1,26 +1,27 @@ -""" Access the HiSPARC public database API. +"""Access the HiSPARC public database API. - This provides easy classes and functions to access the HiSPARC - publicdb API. This takes care of the url retrieval and conversion - from JSON to Python dictionaries. +This provides easy classes and functions to access the HiSPARC +publicdb API. This takes care of the url retrieval and conversion +from JSON to Python dictionaries. - Example usage: +Example usage: - .. code-block:: python +.. code-block:: python - >>> from sapphire import Station - >>> stations = [5, 3102, 504, 7101, 8008, 13005] - >>> clusters = [Station(station).cluster for station in stations] - >>> for station, cluster in zip(stations, clusters): - ... print('Station %d is in cluster %s.' % (station, cluster)) - Station 5 is in cluster Amsterdam. - Station 3102 is in cluster Leiden. 
- Station 504 is in cluster Amsterdam. - Station 7101 is in cluster Enschede. - Station 8008 is in cluster Eindhoven. - Station 13005 is in cluster Bristol. + >>> from sapphire import Station + >>> stations = [5, 3102, 504, 7101, 8008, 13005] + >>> clusters = [Station(station).cluster for station in stations] + >>> for station, cluster in zip(stations, clusters): + ... print('Station %d is in cluster %s.' % (station, cluster)) + Station 5 is in cluster Amsterdam. + Station 3102 is in cluster Leiden. + Station 504 is in cluster Amsterdam. + Station 7101 is in cluster Enschede. + Station 8008 is in cluster Eindhoven. + Station 13005 is in cluster Bristol. """ + import datetime import json import logging @@ -28,7 +29,7 @@ from functools import cached_property from io import BytesIO -from os import extsep, path +from pathlib import Path from urllib.error import HTTPError, URLError from urllib.parse import urljoin from urllib.request import urlopen @@ -40,7 +41,7 @@ logger = logging.getLogger(__name__) -LOCAL_BASE = path.join(path.dirname(__file__), 'data') +LOCAL_BASE = Path(__file__).parent / 'data' def get_api_base(): @@ -52,7 +53,6 @@ def get_src_base(): class API: - """Base API class This provided the methods to retrieve data from the API. The results @@ -63,21 +63,21 @@ class API: """ urls = { - "stations": 'stations/', - "stations_in_subcluster": 'subclusters/{subcluster_number}/', - "subclusters": 'subclusters/', - "subclusters_in_cluster": 'clusters/{cluster_number}/', - "clusters": 'clusters/', - "clusters_in_country": 'countries/{country_number}/', - "countries": 'countries/', - "stations_with_data": 'stations/data/{year}/{month}/{day}/', - "stations_with_weather": 'stations/weather/{year}/{month}/{day}/', - "station_info": 'station/{station_number}/', - "has_data": 'station/{station_number}/data/{year}/{month}/{day}/', - "has_weather": 'station/{station_number}/weather/{year}/{month}/{day}/', - "configuration": 'station/{station_number}/config/{year}/{month}/{day}/', - "number_of_events": 'station/{station_number}/num_events/{year}/{month}/{day}/{hour}/', - "event_trace": 'station/{station_number}/trace/{ext_timestamp}/' + 'stations': 'stations/', + 'stations_in_subcluster': 'subclusters/{subcluster_number}/', + 'subclusters': 'subclusters/', + 'subclusters_in_cluster': 'clusters/{cluster_number}/', + 'clusters': 'clusters/', + 'clusters_in_country': 'countries/{country_number}/', + 'countries': 'countries/', + 'stations_with_data': 'stations/data/{year}/{month}/{day}/', + 'stations_with_weather': 'stations/weather/{year}/{month}/{day}/', + 'station_info': 'station/{station_number}/', + 'has_data': 'station/{station_number}/data/{year}/{month}/{day}/', + 'has_weather': 'station/{station_number}/weather/{year}/{month}/{day}/', + 'configuration': 'station/{station_number}/config/{year}/{month}/{day}/', + 'number_of_events': 'station/{station_number}/num_events/{year}/{month}/{day}/{hour}/', + 'event_trace': 'station/{station_number}/trace/{ext_timestamp}/', } src_urls = { @@ -97,7 +97,7 @@ class API: 'trigger': 'trigger/{station_number}/', 'layout': 'layout/{station_number}/', 'detector_timing_offsets': 'detector_timing_offsets/{station_number}/', - 'station_timing_offsets': 'station_timing_offsets/{station_1}/{station_2}/' + 'station_timing_offsets': 'station_timing_offsets/{station_1}/{station_2}/', } def __init__(self, force_fresh=False, force_stale=False): @@ -120,23 +120,23 @@ def _get_json(self, urlpath): """ urlpath = urlpath.rstrip('/') if self.force_fresh and 
self.force_stale: - raise Exception('Can not force fresh and stale simultaneously.') + raise ValueError('Can not force fresh and stale simultaneously.') try: if self.force_stale: - raise Exception + raise ValueError('Should not get data from server') json_data = self._retrieve_url(urlpath, base=get_api_base()) data = json.loads(json_data) - except Exception: + except Exception as remote_error: if self.force_fresh: - raise Exception('Couldn\'t get requested data from server.') - localpath = path.join(LOCAL_BASE, urlpath + extsep + 'json') + raise RuntimeError("Couldn't get requested data from server.") from remote_error + localpath = LOCAL_BASE / f'{urlpath}.json' try: - with open(localpath) as localdata: + with localpath.open() as localdata: data = json.load(localdata) - except Exception: + except Exception as local_error: if self.force_stale: - raise Exception('Couldn\'t find requested data locally.') - raise Exception('Couldn\'t get requested data from server nor find it locally.') + raise RuntimeError("Couldn't find requested data locally.") from local_error + raise RuntimeError("Couldn't get requested data from server nor find it locally.") from remote_error if not self.force_stale: warnings.warn('Using local data. Possibly outdated.') @@ -152,30 +152,29 @@ def _get_tsv(self, urlpath, names=None): """ urlpath = urlpath.rstrip('/') if self.force_fresh and self.force_stale: - raise Exception('Can not force fresh and stale simultaneously.') + raise ValueError('Can not force fresh and stale simultaneously.') try: if self.force_stale: - raise Exception + raise ValueError('Should not get data from server') tsv_data = self._retrieve_url(urlpath, base=get_src_base()) - except Exception: + except Exception as remote_error: if self.force_fresh: - raise Exception('Couldn\'t get requested data from server.') - localpath = path.join(LOCAL_BASE, urlpath + extsep + 'tsv') + raise RuntimeError("Couldn't get requested data from server.") from remote_error + localpath = LOCAL_BASE / f'{urlpath}.tsv' try: with warnings.catch_warnings(): warnings.filterwarnings('ignore') data = genfromtxt(localpath, delimiter='\t', dtype=None, names=names) - except Exception: + except Exception as local_error: if self.force_stale: - raise Exception('Couldn\'t find requested data locally.') - raise Exception('Couldn\'t get requested data from server nor find it locally.') + raise RuntimeError("Couldn't find requested data locally.") from local_error + raise RuntimeError("Couldn't get requested data from server nor find it locally.") from remote_error if not self.force_stale: warnings.warn('Using local data. 
Possibly outdated.') else: with warnings.catch_warnings(): warnings.filterwarnings('ignore') - data = genfromtxt(BytesIO(tsv_data.encode('utf-8')), - delimiter='\t', dtype=None, names=names) + data = genfromtxt(BytesIO(tsv_data.encode('utf-8')), delimiter='\t', dtype=None, names=names) return atleast_1d(data) @@ -192,13 +191,13 @@ def _retrieve_url(urlpath, base=None): base = get_api_base() url = urljoin(base, urlpath + '/' if urlpath else '') - logging.debug('Getting: ' + url) + logging.debug(f'Getting: {url}') try: result = urlopen(url).read().decode('utf-8') - except HTTPError as e: - raise Exception('A HTTP %d error occured for the url: %s' % (e.code, url)) + except HTTPError as error: + raise RuntimeError(f'A HTTP {error.code} error occured for the url: {url}') except URLError: - raise Exception('An URL error occured.') + raise RuntimeError('An URL error occured.') return result @@ -218,18 +217,17 @@ def check_connection(): @staticmethod def validate_partial_date(year='', month='', day='', hour=''): if year == '' and (month != '' or day != '' or hour != ''): - raise Exception('You must also specify the year') + raise ValueError('You must also specify the year') elif month == '' and (day != '' or hour != ''): - raise Exception('You must also specify the month') + raise ValueError('You must also specify the month') elif day == '' and hour != '': - raise Exception('You must also specify the day') + raise ValueError('You must also specify the day') def __repr__(self): - return f"{self.__class__.__name__}(force_fresh={self.force_fresh}, force_stale={self.force_stale})" + return f'{self.__class__.__name__}(force_fresh={self.force_fresh}, force_stale={self.force_stale})' class Network(API): - """Get info about the network (countries/clusters/subclusters/stations)""" @cached_property @@ -432,12 +430,12 @@ def coincidence_number(self, year, month, day): @staticmethod def validate_numbers(country=None, cluster=None, subcluster=None): - if country is not None and country % 10000: - raise Exception('Invalid country number, must be multiple of 10000.') + if country is not None and country % 10_000: + raise ValueError('Invalid country number, must be multiple of 10000.') if cluster is not None and cluster % 1000: - raise Exception('Invalid cluster number, must be multiple of 1000.') + raise ValueError('Invalid cluster number, must be multiple of 1000.') if subcluster is not None and subcluster % 100: - raise Exception('Invalid subcluster number, must be multiple of 100.') + raise ValueError('Invalid subcluster number, must be multiple of 100.') def uptime(self, stations, start=None, end=None): """Get number of hours for which the given stations have been simultaneously active @@ -457,8 +455,7 @@ def uptime(self, stations, start=None, end=None): stations = [stations] for station in stations: - data[station] = Station(station, force_fresh=self.force_fresh, - force_stale=self.force_stale).event_time() + data[station] = Station(station, force_fresh=self.force_fresh, force_stale=self.force_stale).event_time() first = min(values['timestamp'][0] for values in data.values()) last = max(values['timestamp'][-1] for values in data.values()) @@ -466,11 +463,16 @@ def uptime(self, stations, start=None, end=None): len_array = (last - first) // 3600 + 1 all_active = ones(len_array) - for station in data.keys(): + minimum_events_per_hour = 500 + maximum_events_per_hour = 5_000 + + for station in data: is_active = zeros(len_array) start_i = (data[station]['timestamp'][0] - first) // 3600 end_i = start_i + 
len(data[station]) - is_active[start_i:end_i] = (data[station]['counts'] > 500) & (data[station]['counts'] < 5000) + is_active[start_i:end_i] = (data[station]['counts'] > minimum_events_per_hour) & ( + data[station]['counts'] < maximum_events_per_hour + ) all_active = logical_and(all_active, is_active) # filter start, end @@ -488,7 +490,6 @@ def uptime(self, stations, start=None, end=None): class Station(API): - """Access data about a single station""" def __init__(self, station, force_fresh=False, force_stale=False): @@ -501,7 +502,7 @@ def __init__(self, station, force_fresh=False, force_stale=False): """ if force_fresh and force_stale: - raise Exception('Can not force fresh and stale simultaneously.') + raise ValueError('Can not force fresh and stale simultaneously.') if station not in Network(force_fresh=force_fresh, force_stale=force_stale).station_numbers(): warnings.warn('Possibly invalid station, or without config.') self.force_fresh = force_fresh @@ -539,8 +540,12 @@ def config(self, date=None): """ if date is None: date = datetime.date.today() - path = (self.urls['configuration'] - .format(station_number=self.station, year=date.year, month=date.month, day=date.day)) + path = self.urls['configuration'].format( + station_number=self.station, + year=date.year, + month=date.month, + day=date.day, + ) return self._get_json(path) @@ -558,8 +563,13 @@ def n_events(self, year='', month='', day='', hour=''): """ self.validate_partial_date(year, month, day, hour) - path = (self.urls['number_of_events'] - .format(station_number=self.station, year=year, month=month, day=day, hour=hour)) + path = self.urls['number_of_events'].format( + station_number=self.station, + year=year, + month=month, + day=day, + hour=hour, + ) return self._get_json(path) def has_data(self, year='', month='', day=''): @@ -634,8 +644,7 @@ def pulse_height(self, year, month, day): """ columns = ('pulseheight', 'ph1', 'ph2', 'ph3', 'ph4') - path = self.src_urls['pulseheight'].format(station_number=self.station, - year=year, month=month, day=day) + path = self.src_urls['pulseheight'].format(station_number=self.station, year=year, month=month, day=day) return self._get_tsv(path, names=columns) def pulse_integral(self, year, month, day): @@ -679,8 +688,7 @@ def barometer(self, year, month, day): """ columns = ('timestamp', 'air_pressure') - path = self.src_urls['barometer'].format(station_number=self.station, - year=year, month=month, day=day) + path = self.src_urls['barometer'].format(station_number=self.station, year=year, month=month, day=day) return self._get_tsv(path, names=columns) def temperature(self, year, month, day): @@ -691,8 +699,7 @@ def temperature(self, year, month, day): """ columns = ('timestamp', 'temperature') - path = self.src_urls['temperature'].format(station_number=self.station, - year=year, month=month, day=day) + path = self.src_urls['temperature'].format(station_number=self.station, year=year, month=month, day=day) return self._get_tsv(path, names=columns) @cached_property @@ -718,8 +725,7 @@ def electronic(self, timestamp=None): idx = -1 else: idx = get_active_index(electronics['timestamp'], timestamp) - electronic = [electronics[idx][field] for field in - ('primary', 'secondary', 'primary_fpga', 'secondary_fpga')] + electronic = [electronics[idx][field] for field in ('primary', 'secondary', 'primary_fpga', 'secondary_fpga')] return electronic @cached_property @@ -745,7 +751,7 @@ def voltage(self, timestamp=None): idx = -1 else: idx = get_active_index(voltages['timestamp'], timestamp) - voltage = 
[voltages[idx]['voltage%d' % i] for i in range(1, 5)] + voltage = [voltages[idx][f'voltage{detector_id}'] for detector_id in range(1, 5)] return voltage @cached_property @@ -771,7 +777,7 @@ def current(self, timestamp=None): idx = -1 else: idx = get_active_index(currents['timestamp'], timestamp) - current = [currents[idx]['current%d' % i] for i in range(1, 5)] + current = [currents[idx][f'current{detector_id}'] for detector_id in range(1, 5)] return current @cached_property @@ -799,9 +805,11 @@ def gps_location(self, timestamp=None): timestamp = process_time(timestamp) locations = self.gps_locations idx = get_active_index(locations['timestamp'], timestamp) - location = {'latitude': locations[idx]['latitude'], - 'longitude': locations[idx]['longitude'], - 'altitude': locations[idx]['altitude']} + location = { + 'latitude': locations[idx]['latitude'], + 'longitude': locations[idx]['longitude'], + 'altitude': locations[idx]['altitude'], + } return location @cached_property @@ -811,10 +819,21 @@ def triggers(self): :return: array of timestamps and values. """ - columns = ('timestamp', - 'low1', 'low2', 'low3', 'low4', - 'high1', 'high2', 'high3', 'high4', - 'n_low', 'n_high', 'and_or', 'external') + columns = ( + 'timestamp', + 'low1', + 'low2', + 'low3', + 'low4', + 'high1', + 'high2', + 'high3', + 'high4', + 'n_low', + 'n_high', + 'and_or', + 'external', + ) path = self.src_urls['trigger'].format(station_number=self.station) return self._get_tsv(path, names=columns) @@ -830,10 +849,10 @@ def trigger(self, timestamp=None): idx = -1 else: idx = get_active_index(triggers['timestamp'], timestamp) - thresholds = [[triggers[idx]['%s%d' % (t, i)] - for t in ('low', 'high')] - for i in range(1, 5)] - trigger = [triggers[idx][t] for t in ('n_low', 'n_high', 'and_or', 'external')] + thresholds = [ + [triggers[idx][f'{threshold}{detector_id}'] for threshold in ('low', 'high')] for detector_id in range(1, 5) + ] + trigger = [triggers[idx][trigger_option] for trigger_option in ('n_low', 'n_high', 'and_or', 'external')] return thresholds, trigger @cached_property @@ -843,11 +862,25 @@ def station_layouts(self): :return: array of timestamps and values. 
""" - columns = ('timestamp', - 'radius1', 'alpha1', 'height1', 'beta1', - 'radius2', 'alpha2', 'height2', 'beta2', - 'radius3', 'alpha3', 'height3', 'beta3', - 'radius4', 'alpha4', 'height4', 'beta4') + columns = ( + 'timestamp', + 'radius1', + 'alpha1', + 'height1', + 'beta1', + 'radius2', + 'alpha2', + 'height2', + 'beta2', + 'radius3', + 'alpha3', + 'height3', + 'beta3', + 'radius4', + 'alpha4', + 'height4', + 'beta4', + ) base = self.src_urls['layout'] path = base.format(station_number=self.station) return self._get_tsv(path, names=columns) @@ -864,9 +897,10 @@ def station_layout(self, timestamp=None): idx = -1 else: idx = get_active_index(station_layouts['timestamp'], timestamp) - station_layout = [[station_layouts[idx]['%s%d' % (c, i)] - for c in ('radius', 'alpha', 'height', 'beta')] - for i in range(1, 5)] + station_layout = [ + [station_layouts[idx][f'{coordinate}{detector_id}'] for coordinate in ('radius', 'alpha', 'height', 'beta')] + for detector_id in range(1, 5) + ] return station_layout @cached_property @@ -893,7 +927,7 @@ def detector_timing_offset(self, timestamp=None): idx = -1 else: idx = get_active_index(detector_timing_offsets['timestamp'], timestamp) - detector_timing_offset = [detector_timing_offsets[idx]['offset%d' % i] for i in range(1, 5)] + detector_timing_offset = [detector_timing_offsets[idx][f'offset{detector_id}'] for detector_id in range(1, 5)] return detector_timing_offset @@ -906,7 +940,7 @@ def station_timing_offsets(self, reference_station): """ if reference_station == self.station: - raise Exception('Reference station cannot be the same station') + raise ValueError('Reference station cannot be the same station') if reference_station > self.station: station_1, station_2 = self.station, reference_station toggle_sign = True @@ -931,7 +965,7 @@ def station_timing_offset(self, reference_station, timestamp=None): """ if self.station == reference_station: - return (0., 0.) + return (0.0, 0.0) station_timing_offsets = self.station_timing_offsets(reference_station) if timestamp is None: @@ -943,5 +977,9 @@ def station_timing_offset(self, reference_station, timestamp=None): return station_timing_offset def __repr__(self): - return ("%s(%d, force_fresh=%s, force_stale=%s)" % - (self.__class__.__name__, self.station, self.force_fresh, self.force_stale)) + return '%s(%d, force_fresh=%s, force_stale=%s)' % ( + self.__class__.__name__, + self.station, + self.force_fresh, + self.force_stale, + ) diff --git a/sapphire/clusters.py b/sapphire/clusters.py index a778436b..4ac62f05 100644 --- a/sapphire/clusters.py +++ b/sapphire/clusters.py @@ -1,19 +1,19 @@ -""" Define HiSPARC detectors, stations and clusters. +"""Define HiSPARC detectors, stations and clusters. - The :class:`BaseCluster` defines a HiSPARC cluster consisting of one or - more stations. The :class:`Station` defines a HiSPARC station, - consisting of one or more :class:`Detector` objects. +The :class:`BaseCluster` defines a HiSPARC cluster consisting of one or +more stations. The :class:`Station` defines a HiSPARC station, +consisting of one or more :class:`Detector` objects. 
- To easily create a cluster object for a specific set of real HiSPARC - stations the :class:`HiSPARCStations` can be used, for example:: +To easily create a cluster object for a specific set of real HiSPARC +stations the :class:`HiSPARCStations` can be used, for example:: - >>> from sapphire import HiSPARCStations - >>> cluster = HiSPARCStations([102, 104, 105], force_stale=True) +>>> from sapphire import HiSPARCStations +>>> cluster = HiSPARCStations([102, 104, 105], force_stale=True) - The use of ``force_stale`` forces the use of local data, which - is much faster to load than data from the server. +The use of ``force_stale`` forces the use of local data, which +is much faster to load than data from the server. - These cluster objects are mainly used by simulations and reconstructions. +These cluster objects are mainly used by simulations and reconstructions. """ @@ -31,7 +31,7 @@ class Detector: """A HiSPARC detector""" - _detector_size = (.5, 1.) + _detector_size = (0.5, 1.0) def __init__(self, station, position, orientation='UD', detector_timestamps=None): """Initialize detector @@ -51,27 +51,26 @@ def __init__(self, station, position, orientation='UD', detector_timestamps=None if detector_timestamps is None: detector_timestamps = [0] self.station = station - if hasattr(position[0], "__len__"): + if hasattr(position[0], '__len__'): self.x = position[0] self.y = position[1] - self.z = position[2] if len(position) == 3 else [0.] * len(self.x) + self.z = position[2] if len(position) == 3 else [0.0] * len(self.x) else: self.x = [position[0]] self.y = [position[1]] - self.z = [position[2]] if len(position) == 3 else [0.] + self.z = [position[2]] if len(position) == 3 else [0.0] if isinstance(orientation, str) and orientation == 'UD': self.orientation = [0] * len(self.x) elif isinstance(orientation, str) and orientation == 'LR': self.orientation = [pi / 2] * len(self.x) + elif hasattr(orientation, '__len__'): + self.orientation = orientation else: - if hasattr(orientation, "__len__"): - self.orientation = orientation - else: - self.orientation = [orientation] + self.orientation = [orientation] if len(detector_timestamps) == len(self.x): self.timestamps = detector_timestamps else: - raise Exception('Number of timestamps must equal number of postions') + raise ValueError('Number of timestamps must equal number of postions') self.index = -1 def _update_timestamp(self, timestamp): @@ -161,14 +160,13 @@ def get_corners(self): # cluster frame sina = sin(alpha_station) cosa = cos(alpha_station) - corners = [(x_station + xc * cosa - yc * sina, y_station + xc * sina + yc * cosa) - for xc, yc in corners] + corners = [(x_station + xc * cosa - yc * sina, y_station + xc * sina + yc * cosa) for xc, yc in corners] return corners def __repr__(self): - id = next(i for i, d in enumerate(self.station.detectors) if self is d) - return "<%s, id: %d, station: %r>" % (self.__class__.__name__, id, self.station) + detector_id = next(i for i, d in enumerate(self.station.detectors) if self is d) + return '<%s, id: %d, station: %r>' % (self.__class__.__name__, detector_id, self.station) class Station: @@ -176,9 +174,17 @@ class Station: _detectors = None - def __init__(self, cluster, station_id, position, angle=None, - detectors=None, station_timestamps=None, - detector_timestamps=None, number=None): + def __init__( + self, + cluster, + station_id, + position, + angle=None, + detectors=None, + station_timestamps=None, + detector_timestamps=None, + number=None, + ): """Initialize station :param cluster: cluster this 
station is a part of @@ -208,17 +214,17 @@ def __init__(self, cluster, station_id, position, angle=None, detector_timestamps = [0] self.cluster = cluster self.station_id = station_id - if hasattr(position[0], "__len__"): + if hasattr(position[0], '__len__'): self.x = position[0] self.y = position[1] - self.z = position[2] if len(position) == 3 else [0.] * len(self.x) + self.z = position[2] if len(position) == 3 else [0.0] * len(self.x) else: self.x = [position[0]] self.y = [position[1]] - self.z = [position[2]] if len(position) == 3 else [0.] + self.z = [position[2]] if len(position) == 3 else [0.0] if angle is None: - self.angle = [0.] * len(self.x) - elif hasattr(angle, "__len__"): + self.angle = [0.0] * len(self.x) + elif hasattr(angle, '__len__'): self.angle = angle else: self.angle = [angle] @@ -227,15 +233,14 @@ def __init__(self, cluster, station_id, position, angle=None, if len(station_timestamps) == len(self.x): self.timestamps = station_timestamps else: - raise Exception('Number of timestamps must equal number of postions') + raise ValueError('Number of timestamps must equal number of postions') if detectors is None: # detector positions for a standard station station_size = 10 a = station_size / 2 b = a * sqrt(3) - detectors = [((0, b, 0), 'UD'), ((0, b / 3, 0), 'UD'), - ((-a, 0, 0), 'LR'), ((a, 0, 0), 'LR')] + detectors = [((0, b, 0), 'UD'), ((0, b / 3, 0), 'UD'), ((-a, 0, 0), 'LR'), ((a, 0, 0), 'LR')] for position, orientation in detectors: self._add_detector(position, orientation, detector_timestamps) @@ -277,7 +282,7 @@ def get_area(self, detector_ids=None): """ if detector_ids is not None: - return sum(self._detectors[id].get_area() for id in detector_ids) + return sum(self._detectors[detector_id].get_area() for detector_id in detector_ids) else: return sum(d.get_area() for d in self._detectors) @@ -334,9 +339,9 @@ def get_lla_coordinates(self): transform = geographic.FromWGS84ToENUTransformation(lla) latitude, longitude, altitude = transform.enu_to_lla(enu) - latitude = latitude if abs(latitude) > 1e-7 else 0. - longitude = longitude if abs(longitude) > 1e-7 else 0. - altitude = altitude if abs(altitude) > 1e-7 else 0. + latitude = latitude if abs(latitude) > 1e-7 else 0.0 + longitude = longitude if abs(longitude) > 1e-7 else 0.0 + altitude = altitude if abs(altitude) > 1e-7 else 0.0 return latitude, longitude, altitude @@ -380,8 +385,12 @@ def calc_center_of_mass_coordinates(self): return x0, y0, z0 def __repr__(self): - return ("<%s, id: %d, number: %d, cluster: %r>" % - (self.__class__.__name__, self.station_id, self.number, self.cluster)) + return '<%s, id: %d, number: %d, cluster: %r>' % ( + self.__class__.__name__, + self.station_id, + self.number, + self.cluster, + ) class BaseCluster: @@ -400,7 +409,7 @@ def __init__(self, position=(0, 0, 0), angle=0, lla=(52.35592417, 4.95114402, 56 """ self.x = position[0] self.y = position[1] - self.z = position[2] if len(position) == 3 else 0. 
+ self.z = position[2] if len(position) == 3 else 0.0 self.alpha = angle self.lla = lla # Set initial timestamp in the future to use latest positions @@ -417,9 +426,15 @@ def set_timestamp(self, timestamp): for station in self.stations: station._update_timestamp(self._timestamp) - def _add_station(self, position, angle=None, detectors=None, - station_timestamps=None, detector_timestamps=None, - number=None): + def _add_station( + self, + position, + angle=None, + detectors=None, + station_timestamps=None, + detector_timestamps=None, + number=None, + ): """Add a station to the cluster :param position: x,y,z position of the station relative to @@ -454,9 +469,9 @@ def _add_station(self, position, angle=None, detectors=None, self._stations = [] station_id = len(self._stations) - self._stations.append(Station(self, station_id, position, angle, - detectors, station_timestamps, - detector_timestamps, number)) + self._stations.append( + Station(self, station_id, position, angle, detectors, station_timestamps, detector_timestamps, number), + ) def set_center_off_mass_at_origin(self): """Set the cluster center of mass to (0, 0, 0)""" @@ -584,9 +599,7 @@ def calc_center_of_mass_coordinates(self): absolute coordinate system """ - x, y, z = zip(*[detector.get_coordinates() - for station in self.stations - for detector in station.detectors]) + x, y, z = zip(*[detector.get_coordinates() for station in self.stations for detector in station.detectors]) x0 = np.nanmean(x) y0 = np.nanmean(y) @@ -640,11 +653,10 @@ def calc_horizontal_distance_between_stations(self, s1, s2): return self._distance(*xy) def __repr__(self): - return "<%s>" % self.__class__.__name__ + return '<%s>' % self.__class__.__name__ class CompassStations(BaseCluster): - """Add detectors to stations using compass coordinates Compass coordinates consist of r, alpha, z, beta. These define @@ -656,8 +668,7 @@ class CompassStations(BaseCluster): """ - def _add_station(self, position, detectors, station_timestamps=None, - detector_timestamps=None, number=None): + def _add_station(self, position, detectors, station_timestamps=None, detector_timestamps=None, number=None): """Add a station to the cluster :param position: x,y,z coordinates of the station relative @@ -691,15 +702,12 @@ def _add_station(self, position, detectors, station_timestamps=None, ... number=104) """ - detectors = [(axes.compass_to_cartesian(r, alpha, z), np.radians(beta)) - for r, alpha, z, beta in detectors] + detectors = [(axes.compass_to_cartesian(r, alpha, z), np.radians(beta)) for r, alpha, z, beta in detectors] - super()._add_station( - position, None, detectors, station_timestamps, detector_timestamps, number) + super()._add_station(position, None, detectors, station_timestamps, detector_timestamps, number) class SimpleCluster(BaseCluster): - """Define a simple cluster containing four stations :param size: This value is the distance between the three outer stations. 
@@ -751,7 +759,6 @@ def __init__(self): class SingleDiamondStation(BaseCluster): - """Define a cluster containing a single diamond shaped station Detectors 1, 3 and 4 are in the usual position for a 4 detector @@ -766,14 +773,12 @@ def __init__(self): station_size = 10 a = station_size / 2 b = a * sqrt(3) - detectors = [((0., b, 0), 'UD'), ((a * 2, b, 0), 'UD'), - ((-a, 0., 0), 'LR'), ((a, 0., 0), 'LR')] + detectors = [((0.0, b, 0), 'UD'), ((a * 2, b, 0), 'UD'), ((-a, 0.0, 0), 'LR'), ((a, 0.0, 0), 'LR')] self._add_station((0, 0, 0), 0, detectors) class HiSPARCStations(CompassStations): - """A cluster containing any real station from the HiSPARC network The gps position and number of detectors are taken from the API. @@ -830,8 +835,7 @@ def __init__(self, stations, skip_missing=False, force_fresh=False, force_stale= try: detectors = station_info.station_layouts fields = ('radius', 'alpha', 'height', 'beta') - razbs = [[detectors['%s%d' % (field, i)] for field in fields] - for i in range(1, n_detectors + 1)] + razbs = [[detectors['%s%d' % (field, i)] for field in fields] for i in range(1, n_detectors + 1)] detector_ts = detectors['timestamp'] except Exception: missing_detectors.append(station) @@ -842,7 +846,7 @@ def __init__(self, stations, skip_missing=False, force_fresh=False, force_stale= d = 10 / sqrt(3) razbs = [(d, 0, 0, 0), (0, 0, 0, 0), (d, -120, 0, 90), (d, 120, 0, 90)] else: - raise RuntimeError("Detector count unknown for station %d." % station) + raise RuntimeError('Detector count unknown for station %d.' % station) detector_ts = [0] self._add_station(enu, razbs, station_ts, detector_ts, station) @@ -850,18 +854,19 @@ def __init__(self, stations, skip_missing=False, force_fresh=False, force_stale= self.set_center_off_mass_at_origin() if len(missing_gps): - warnings.warn('Could not get GPS location for stations: %s. ' - 'Those stations are excluded.' % str(missing_gps)) + warnings.warn( + 'Could not get GPS location for stations: %s. Those stations are excluded.' % str(missing_gps), + ) if len(missing_detectors): - warnings.warn('Could not get detector layout for stations %s, ' - 'defaults will be used!' % str(missing_detectors)) + warnings.warn( + 'Could not get detector layout for stations %s, defaults will be used!' % str(missing_detectors), + ) def __repr__(self): - return f"{self.__class__.__name__}({[s.number for s in self.stations]!r})" + return f'{self.__class__.__name__}({[s.number for s in self.stations]!r})' class ScienceParkCluster(HiSPARCStations): - """A cluster containing stations from the Science Park subcluster :param stations: A list of station numbers to include. 
Only stations @@ -870,8 +875,7 @@ class ScienceParkCluster(HiSPARCStations): """ - def __init__(self, stations=None, skip_missing=False, force_fresh=False, - force_stale=False): + def __init__(self, stations=None, skip_missing=False, force_fresh=False, force_stale=False): if stations is None: network = api.Network(force_fresh, force_stale) stations = [sn for sn in network.station_numbers(subcluster=500) if sn != 507] @@ -881,7 +885,6 @@ def __init__(self, stations=None, skip_missing=False, force_fresh=False, class HiSPARCNetwork(HiSPARCStations): - """A cluster containing all station from the HiSPARC network""" def __init__(self, force_fresh=False, force_stale=False): @@ -891,7 +894,7 @@ def __init__(self, force_fresh=False, force_stale=False): super().__init__(stations, skip_missing, force_fresh, force_stale) def __repr__(self): - return "<%s>" % self.__class__.__name__ + return '<%s>' % self.__class__.__name__ def flatten_cluster(cluster): @@ -904,6 +907,6 @@ def flatten_cluster(cluster): """ for station in cluster.stations: - station.z = [0.] * len(station.z) + station.z = [0.0] * len(station.z) for detector in station.detectors: - detector.z = [0.] * len(detector.z) + detector.z = [0.0] * len(detector.z) diff --git a/sapphire/corsika/__init__.py b/sapphire/corsika/__init__.py index 77441261..26ccda5f 100644 --- a/sapphire/corsika/__init__.py +++ b/sapphire/corsika/__init__.py @@ -30,9 +30,7 @@ convert values in units used by CORSIKA to HiSPARC units """ + from . import corsika_queries, particles, reader, units -__all__ = ['corsika_queries', - 'particles', - 'reader', - 'units'] +__all__ = ['corsika_queries', 'particles', 'reader', 'units'] diff --git a/sapphire/corsika/blocks.py b/sapphire/corsika/blocks.py index 70df164e..ad48abc4 100644 --- a/sapphire/corsika/blocks.py +++ b/sapphire/corsika/blocks.py @@ -10,21 +10,22 @@ import math import struct -import numpy +import numpy as np from . import particles, units try: from numba import jit except ImportError: + def jit(func): return func # All sizes are in bytes -class Format: +class Format: """The binary format information of the file. As specified in the CORSIKA user manual, Section 10.2.1. @@ -57,8 +58,7 @@ def __init__(self): self.particles_per_subblock = 39 # Full particle sub block - self.particles_format = (self.particle_format * - self.particles_per_subblock) + self.particles_format = self.particle_format * self.particles_per_subblock self.particles_size = self.particle_size * self.particles_per_subblock def __repr__(self): @@ -67,8 +67,8 @@ def __repr__(self): # From here on, things should not depend on the field size as everything is -class RunHeader: +class RunHeader: """The run header sub-block As specified in the CORSIKA user manual, Table 7. 
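For the cluster classes earlier in this patch, a minimal usage sketch (the station numbers are illustrative; force_stale keeps the cluster construction offline):

    >>> from sapphire.clusters import HiSPARCStations, ScienceParkCluster, flatten_cluster
    >>> cluster = HiSPARCStations([501, 502, 503], force_stale=True)
    >>> science_park = ScienceParkCluster(force_stale=True)
    >>> flatten_cluster(cluster)  # set all station and detector z coordinates to 0.0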
@@ -82,7 +82,7 @@ def __init__(self, subblock): self.version = subblock[3] self.observation_levels = subblock[4] - self.observation_heights = numpy.array(subblock[5:15]) * units.cm + self.observation_heights = np.array(subblock[5:15]) * units.cm self.spectral_slope = subblock[15] self.min_energy = subblock[16] * units.GeV @@ -96,7 +96,7 @@ def __init__(self, subblock): self.cutoff_electrons = subblock[22] * units.GeV self.cutoff_photons = subblock[23] * units.GeV - self.C = numpy.array(subblock[24:74]) + self.C = np.array(subblock[24:74]) self.x_inclined = subblock[74] self.y_inclined = subblock[75] @@ -105,25 +105,25 @@ def __init__(self, subblock): self.phi_inclined = subblock[78] self.n_showers = subblock[92] - self.CKA = numpy.array(subblock[94:134]) - self.CETA = numpy.array(subblock[134:139]) - self.CSTRBA = numpy.array(subblock[139:150]) + self.CKA = np.array(subblock[94:134]) + self.CETA = np.array(subblock[134:139]) + self.CSTRBA = np.array(subblock[139:150]) self.x_scatter_Cherenkov = subblock[247] self.y_scatter_Cherenkov = subblock[248] - self.atmospheric_layer_boundaries = numpy.array(subblock[249:254]) - self.a_atmospheric = numpy.array(subblock[254:259]) - self.b_atmospheric = numpy.array(subblock[259:264]) - self.c_atmospheric = numpy.array(subblock[264:269]) + self.atmospheric_layer_boundaries = np.array(subblock[249:254]) + self.a_atmospheric = np.array(subblock[254:259]) + self.b_atmospheric = np.array(subblock[259:264]) + self.c_atmospheric = np.array(subblock[264:269]) self.NFLAIN = subblock[269] self.NFLDIF = subblock[270] - self.NFLPIF = numpy.floor(subblock[271] / 100) + self.NFLPIF = np.floor(subblock[271] / 100) self.NFLPI0 = subblock[271] % 100 - self.NFRAGM = numpy.floor(subblock[272] / 100) + self.NFRAGM = np.floor(subblock[272] / 100) self.NFLCHE = subblock[272] % 100 def thickness_to_height(self, thickness): - """"Calculate height (m) for given thickness (gramms/cm**2) + """ "Calculate height (m) for given thickness (gramms/cm**2) As specified in the CORSIKA user manual, Appendix D. @@ -174,7 +174,6 @@ def height_to_thickness(self, height): class EventHeader: - """The event header sub-block As specified in the CORSIKA user manual, Table 8. @@ -192,7 +191,7 @@ def __init__(self, subblock): self.first_interaction_altitude = subblock[6] * units.cm self.p_x = subblock[7] * units.GeV self.p_y = subblock[8] * units.GeV - self.p_z = - subblock[9] * units.GeV # Same direction as axis + self.p_z = -subblock[9] * units.GeV # Same direction as axis self.zenith = subblock[10] * units.rad # CORSIKA coordinate conventions are shown in Figure 1 of the manual. @@ -203,7 +202,7 @@ def __init__(self, subblock): # So finally we need to subtract pi/2 rad from the azimuth and # normalize its range. azimuth_corsika = subblock[11] * units.rad - azimuth = azimuth_corsika - (math.pi / 2.) 
+ azimuth = azimuth_corsika - (math.pi / 2.0) if azimuth >= math.pi: self.azimuth = azimuth - (2 * math.pi) elif azimuth < -math.pi: @@ -212,16 +211,14 @@ def __init__(self, subblock): self.azimuth = azimuth self.n_seeds = subblock[12] - self.seeds = numpy.array(list(zip(subblock[13:41:3], - subblock[14:42:3], - subblock[15:43:3]))) + self.seeds = np.array(list(zip(subblock[13:41:3], subblock[14:42:3], subblock[15:43:3]))) self.run_number = subblock[43] self.date_start = subblock[44] self.version = subblock[45] self.n_observation_levels = subblock[46] - self.observation_heights = numpy.array(subblock[47:57]) * units.cm + self.observation_heights = np.array(subblock[47:57]) * units.cm self.spectral_slope = subblock[57] self.min_energy = subblock[58] * units.GeV @@ -271,8 +268,8 @@ def __init__(self, subblock): self.Cherenkov_wavelength_min = subblock[95] * units.nanometer self.Cherenkov_wavelength_max = subblock[96] * units.nanometer self.uses_of_Cherenkov_event = subblock[97] - self.core_x = numpy.array(subblock[98:118]) * units.cm - self.core_y = numpy.array(subblock[118:138]) * units.cm + self.core_x = np.array(subblock[98:118]) * units.cm + self.core_y = np.array(subblock[118:138]) * units.cm self.flag_SIBYLL = subblock[138] self.flag_SIBYLL_cross = subblock[139] @@ -320,8 +317,7 @@ def hadron_model_low(self): @property def hadron_model_high(self): - hadron_models_high = {0: 'HDPM', 1: 'VENUS', 2: 'SIBYLL', 3: 'QGSJET', - 4: 'DPMJET', 5: 'NEXUS', 6: 'EPOS'} + hadron_models_high = {0: 'HDPM', 1: 'VENUS', 2: 'SIBYLL', 3: 'QGSJET', 4: 'DPMJET', 5: 'NEXUS', 6: 'EPOS'} return hadron_models_high.get(self.flag_hadron_model_high, 'unknown') @property @@ -330,15 +326,16 @@ def computer(self): return computers.get(self.flag_computer, 'unknown') def __repr__(self): - return ('<%s, particle: %r, energy: 10**%.1f eV, zenith: %r deg,' - ' azimuth: %r deg>' % - (self.__class__.__name__, self.particle, - math.log10(self.energy), math.degrees(self.zenith), - math.degrees(self.azimuth))) + return '<%s, particle: %r, energy: 10**%.1f eV, zenith: %r deg, azimuth: %r deg>' % ( + self.__class__.__name__, + self.particle, + math.log10(self.energy), + math.degrees(self.zenith), + math.degrees(self.azimuth), + ) class RunEnd: - """The run end sub-block As specified in the CORSIKA user manual, Table 14. @@ -351,16 +348,10 @@ def __init__(self, subblock): self.n_events_processed = subblock[2] def __repr__(self): - return '{}(({!r}, {!r}, {!r}))'.format( - self.__class__.__name__, - self.id, - self.run_number, - self.n_events_processed - ) + return f'{self.__class__.__name__}(({self.id!r}, {self.run_number!r}, {self.n_events_processed!r}))' class EventEnd: - """The event end sub-block As specified in the CORSIKA user manual, Table 13. 
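The azimuth handling in EventHeader above amounts to a small standalone helper; this is only a sketch mirroring the logic shown in the diff, not code from the package:

    import math

    def corsika_to_sapphire_azimuth(azimuth_corsika):
        # Shift the CORSIKA azimuth by -pi/2 to match the SAPPHiRE convention,
        # then normalize the result to the range [-pi, pi).
        azimuth = azimuth_corsika - (math.pi / 2.0)
        if azimuth >= math.pi:
            return azimuth - (2 * math.pi)
        if azimuth < -math.pi:
            return azimuth + (2 * math.pi)
        return azimuth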
@@ -378,29 +369,28 @@ def __init__(self, subblock): self.n_particles_levels = subblock[6] # NKG output - self.NKG_lateral_1_x = numpy.array(subblock[7:28]) / units.cm2 - self.NKG_lateral_1_y = numpy.array(subblock[28:49]) / units.cm2 - self.NKG_lateral_1_xy = numpy.array(subblock[49:70]) / units.cm2 - self.NKG_lateral_1_yx = numpy.array(subblock[70:91]) / units.cm2 - - self.NKG_lateral_2_x = numpy.array(subblock[91:112]) / units.cm2 - self.NKG_lateral_2_y = numpy.array(subblock[112:133]) / units.cm2 - self.NKG_lateral_2_xy = numpy.array(subblock[133:154]) / units.cm2 - self.NKG_lateral_2_yx = numpy.array(subblock[154:175]) / units.cm2 - - self.NKG_electron_number = numpy.array(subblock[175:185]) - self.NKG_pseudo_age = numpy.array(subblock[185:195]) - self.NKG_electron_distances = numpy.array(subblock[195:205]) * units.cm - self.NKG_local_pseudo_age_1 = numpy.array(subblock[205:215]) - - self.NKG_level_height_mass = numpy.array(subblock[215:225]) - self.NKG_level_height_distance = numpy.array(subblock[225:235]) - self.NKG_distance_bins_local_pseudo_age = \ - numpy.array(subblock[235:245]) * units.cm - self.NKG_local_pseudo_age_2 = numpy.array(subblock[245:255]) + self.NKG_lateral_1_x = np.array(subblock[7:28]) / units.cm2 + self.NKG_lateral_1_y = np.array(subblock[28:49]) / units.cm2 + self.NKG_lateral_1_xy = np.array(subblock[49:70]) / units.cm2 + self.NKG_lateral_1_yx = np.array(subblock[70:91]) / units.cm2 + + self.NKG_lateral_2_x = np.array(subblock[91:112]) / units.cm2 + self.NKG_lateral_2_y = np.array(subblock[112:133]) / units.cm2 + self.NKG_lateral_2_xy = np.array(subblock[133:154]) / units.cm2 + self.NKG_lateral_2_yx = np.array(subblock[154:175]) / units.cm2 + + self.NKG_electron_number = np.array(subblock[175:185]) + self.NKG_pseudo_age = np.array(subblock[185:195]) + self.NKG_electron_distances = np.array(subblock[195:205]) * units.cm + self.NKG_local_pseudo_age_1 = np.array(subblock[205:215]) + + self.NKG_level_height_mass = np.array(subblock[215:225]) + self.NKG_level_height_distance = np.array(subblock[225:235]) + self.NKG_distance_bins_local_pseudo_age = np.array(subblock[235:245]) * units.cm + self.NKG_local_pseudo_age_2 = np.array(subblock[245:255]) # Longitudinal distribution - self.longitudinal_parameters = numpy.array(subblock[255:261]) + self.longitudinal_parameters = np.array(subblock[255:261]) self.chi2_longitudinal_fit = subblock[261] self.n_photons_output = subblock[262] @@ -438,15 +428,14 @@ def particle_data(subblock): y = x_corsika t = subblock[6] * units.ns # or z for additional muon info - id = description // 1000 + particle_id = description // 1000 hadron_generation = description // 10 % 100 observation_level = description % 10 - r = math.sqrt(x ** 2 + y ** 2) + r = math.sqrt(x**2 + y**2) phi = math.atan2(y, x) - return (p_x, p_y, p_z, x, y, t, id, r, hadron_generation, - observation_level, phi) + return (p_x, p_y, p_z, x, y, t, particle_id, r, hadron_generation, observation_level, phi) @jit @@ -463,7 +452,6 @@ def particle_data_thin(subblock): class ParticleData: - """The particle data sub-block As specified in the CORSIKA user manual, Table 10. 
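The packed particle description decoded by particle_data above splits into three fields with plain integer arithmetic; the value 1001 is a hypothetical example (a gamma at observation level 1):

    >>> description = 1001
    >>> description // 1000  # particle id (1 = gamma)
    1
    >>> description // 10 % 100  # hadron generation counter
    0
    >>> description % 10  # observation level
    1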
@@ -471,9 +459,19 @@ class ParticleData: """ def __init__(self, subblock): - self.p_x, self.p_y, self.p_z, self.x, self.y, self.t, self.id, \ - self.r, self.hadron_generation, self.observation_level, \ - self.phi = particle_data(subblock) + ( + self.p_x, + self.p_y, + self.p_z, + self.x, + self.y, + self.t, + self.id, + self.r, + self.hadron_generation, + self.observation_level, + self.phi, + ) = particle_data(subblock) @property def is_detectable(self): @@ -498,7 +496,7 @@ def is_nucleus(self): @property def is_cherenkov(self): - return 9900 <= self.id + return self.id >= 9900 @property def atomic_number(self): @@ -518,13 +516,16 @@ def atom(self): return None def __repr__(self): - return ('<%s, particle: %r, x: %r m, y: %r m, t: %r ns>' % - (self.__class__.__name__, self.particle, self.x, self.y, - self.t)) + return '<%s, particle: %r, x: %r m, y: %r m, t: %r ns>' % ( + self.__class__.__name__, + self.particle, + self.x, + self.y, + self.t, + ) class CherenkovData: - """The cherenkov photon sub-block As specified in CORSIKA user manual, Table 11. @@ -546,8 +547,8 @@ def __init__(self, subblock): # THIN versions -class FormatThin(Format): +class FormatThin(Format): """The format information of the thinned file As specified in CORSIKA user manual, Section 10.2.2. @@ -576,13 +577,11 @@ def __init__(self): self.particle_size = struct.calcsize(self.particle_format) # Full particle sub block - self.particles_format = (self.particle_format * - self.particles_per_subblock) + self.particles_format = self.particle_format * self.particles_per_subblock self.particles_size = self.particle_size * self.particles_per_subblock class ParticleDataThin(ParticleData): - """The thinned particle data sub-block As specified in the CORSIKA user manual, Table 10. @@ -590,13 +589,23 @@ class ParticleDataThin(ParticleData): """ def __init__(self, subblock): - self.p_x, self.p_y, self.p_z, self.x, self.y, self.t, self.id, \ - self.r, self.hadron_generation, self.observation_level, \ - self.phi, self.weight = particle_data_thin(subblock) + ( + self.p_x, + self.p_y, + self.p_z, + self.x, + self.y, + self.t, + self.id, + self.r, + self.hadron_generation, + self.observation_level, + self.phi, + self.weight, + ) = particle_data_thin(subblock) class CherenkovDataThin(CherenkovData): - """The thinned cherenkov photon sub-block As specified in CORSIKA user manual, Table 11. diff --git a/sapphire/corsika/corsika_queries.py b/sapphire/corsika/corsika_queries.py old mode 100755 new mode 100644 index f96d4d6c..58368990 --- a/sapphire/corsika/corsika_queries.py +++ b/sapphire/corsika/corsika_queries.py @@ -8,7 +8,6 @@ class CorsikaQuery: - def __init__(self, data, simulations_group='/simulations'): """Setup variables to point to the tables @@ -106,8 +105,7 @@ def all_zeniths(self): """ return {degrees(zenith) for zenith in set(self.sims.col('zenith'))} - def simulations(self, particle='proton', energy=None, zenith=None, - azimuth=None, iterator=False): + def simulations(self, particle='proton', energy=None, zenith=None, azimuth=None, iterator=False): """Set of available energies given the requirements :param particle: primary particle must be this kind, name of particle. @@ -184,19 +182,19 @@ def float_filter(self, key, value): return query - def range_filter(self, key, min=None, max=None): + def range_filter(self, key, min_value=None, max_value=None): """Filter to be in a range :param key: variable to filter. - :param min,max: limits on the value. + :param min_value,max_value: limits on the value. :return: query. 
""" queries = [] - if min is not None: - queries.append(f'({key} >= {min})') - if max is not None: - queries.append(f'({key} <= {max})') + if min_value is not None: + queries.append(f'({key} >= {min_value})') + if max_value is not None: + queries.append(f'({key} <= {max_value})') query = ' & '.join(queries) return query @@ -220,7 +218,5 @@ def perform_query(self, query, iterator=False): def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ - return ("%s(%r, simulations_group=%r)" % - (self.__class__.__name__, self.data.filename, - self.sims._v_pathname)) + return '' % self.__class__.__name__ + return '%s(%r, simulations_group=%r)' % (self.__class__.__name__, self.data.filename, self.sims._v_pathname) diff --git a/sapphire/corsika/generate_corsika_overview.py b/sapphire/corsika/generate_corsika_overview.py index 3588ca41..88a3ec7b 100644 --- a/sapphire/corsika/generate_corsika_overview.py +++ b/sapphire/corsika/generate_corsika_overview.py @@ -1,16 +1,17 @@ -""" Generate an overview table of the CORSIKA simulations +"""Generate an overview table of the CORSIKA simulations - This script will look for all completed and converted CORSIKA - simulations in the given data path. Information about each - simulation is collected and then summarized in a new h5 file as an - overview. +This script will look for all completed and converted CORSIKA +simulations in the given data path. Information about each +simulation is collected and then summarized in a new h5 file as an +overview. - The given source path should contain subdirectories named after the - seeds used for the simulation in the format ``{seed1}_{seed2}``, - e.g. ``821280921_182096636``. These in turn should contain converted - CORSIKA simulation results called ``corsika.h5``. +The given source path should contain subdirectories named after the +seeds used for the simulation in the format ``{seed1}_{seed2}``, +e.g. ``821280921_182096636``. These in turn should contain converted +CORSIKA simulation results called ``corsika.h5``. 
""" + import argparse import glob import logging @@ -121,8 +122,7 @@ def prepare_output(n): os.umask(0o02) tmp_path = create_tempfile_path() overview = tables.open_file(tmp_path, 'w') - overview.create_table('/', 'simulations', Simulations, - 'Simulations overview', expectedrows=n) + overview.create_table('/', 'simulations', Simulations, 'Simulations overview', expectedrows=n) return tmp_path, overview @@ -139,8 +139,8 @@ def move_tempfile_to_destination(tmp_path, destination): def all_seeds(source): """Get set of all seeds in the corsika data directory""" - dirs = glob.glob(os.path.join(source, '*_*')) - seeds = [os.path.basename(dir) for dir in dirs] + directories = glob.glob(os.path.join(source, '*_*')) + seeds = [os.path.basename(directory) for directory in directories] return sorted(set(seeds)) @@ -163,24 +163,21 @@ def generate_corsika_overview(source, destination, progress=False): def main(): - parser = argparse.ArgumentParser(description='Generate an overview of ' - 'CORSIKA simulations.') - parser.add_argument('source', nargs='?', default=DATA_PATH, - help="directory path containing CORSIKA simulations") - parser.add_argument('destination', nargs='?', default=OUTPUT_PATH, - help="path of the HDF5 output file") - parser.add_argument('--progress', action='store_true', - help='show progressbar during generation') - parser.add_argument('--log', action='store_true', - help='write logs to file, only for use on server') - parser.add_argument('--lazy', action='store_true', - help='only run if the overview is outdated') + parser = argparse.ArgumentParser(description='Generate an overview of CORSIKA simulations.') + parser.add_argument('source', nargs='?', default=DATA_PATH, help='directory path containing CORSIKA simulations') + parser.add_argument('destination', nargs='?', default=OUTPUT_PATH, help='path of the HDF5 output file') + parser.add_argument('--progress', action='store_true', help='show progressbar during generation') + parser.add_argument('--log', action='store_true', help='write logs to file, only for use on server') + parser.add_argument('--lazy', action='store_true', help='only run if the overview is outdated') args = parser.parse_args() if args.log: - logging.basicConfig(filename=LOGFILE, filemode='a', - format='%(asctime)s %(name)s %(levelname)s: ' - '%(message)s', - datefmt='%y%m%d_%H%M%S', level=logging.INFO) + logging.basicConfig( + filename=LOGFILE, + filemode='a', + format='%(asctime)s %(name)s %(levelname)s: %(message)s', + datefmt='%y%m%d_%H%M%S', + level=logging.INFO, + ) if args.lazy: last_store = os.path.getmtime(args.source) last_overview = os.path.getmtime(args.destination) @@ -188,9 +185,7 @@ def main(): logger.info('Overview up to date.') return - generate_corsika_overview(source=args.source, - destination=args.destination, - progress=args.progress) + generate_corsika_overview(source=args.source, destination=args.destination, progress=args.progress) if __name__ == '__main__': diff --git a/sapphire/corsika/mergesort.py b/sapphire/corsika/mergesort.py index 4dfc25fd..7388db37 100644 --- a/sapphire/corsika/mergesort.py +++ b/sapphire/corsika/mergesort.py @@ -9,18 +9,24 @@ class TableMergeSort: - - """ Sort a PyTables HDF5 table either in memory or on-disk """ + """Sort a PyTables HDF5 table either in memory or on-disk""" _iterators = [] - _BUFSIZE = 100000 + _BUFSIZE = 100_000 hdf5_temp = None - def __init__(self, key, inputfile, outputfile=None, tempfile=None, - tablename='groundparticles', destination=None, - overwrite=False, progress=True): - - """ Initialize 
the class + def __init__( + self, + key, + inputfile, + outputfile=None, + tempfile=None, + tablename='groundparticles', + destination=None, + overwrite=False, + progress=True, + ): + """Initialize the class :param key: the name of the column which is to be sorted. :param inputfile: PyTables HDF5 input file. @@ -48,8 +54,7 @@ def __init__(self, key, inputfile, outputfile=None, tempfile=None, self.hdf5_out = inputfile self.destination = destination else: - raise RuntimeError("Must specify either an outputfile or a " - "destination table") + raise RuntimeError('Must specify either an outputfile or a destination table') else: self.hdf5_out = outputfile if destination is not None: @@ -62,12 +67,9 @@ def __init__(self, key, inputfile, outputfile=None, tempfile=None, if self.overwrite: self.hd5_out.remove_nove('/', self.destination, recursive=True) else: - raise RuntimeError("Destination table exists and overwrite " - "is False") + raise RuntimeError('Destination table exists and overwrite is False') except tables.NoSuchNodeError: - self.outtable = self.hdf5_out.create_table('/', self.destination, - self.description, - expectedrows=self.nrows) + self.outtable = self.hdf5_out.create_table('/', self.destination, self.description, expectedrows=self.nrows) self._calc_nrows_in_chunk() @@ -79,10 +81,9 @@ def __init__(self, key, inputfile, outputfile=None, tempfile=None, self.hdf5_temp = tempfile if self.progress: parts = int(len(self.table) / self.nrows_in_chunk) + 1 - print("On disk mergesort in %d parts." % parts) - else: - if self.progress: - print("Table can be sorted in memory.") + print('On disk mergesort in %d parts.' % parts) + elif self.progress: + print('Table can be sorted in memory.') def __enter__(self): return self @@ -104,18 +105,15 @@ def sort(self): parts = int(nrows / chunk) + 1 if parts == 1: if self.progress: - print("Sorting table in memory and writing to disk.") + print('Sorting table in memory and writing to disk.') self._sort_chunk(self.outtable, 0, nrows) else: if self.progress: - print("Sorting in %d chunks of %d rows:" % (parts, chunk)) + print('Sorting in %d chunks of %d rows:' % (parts, chunk)) - for idx, start in pbar(enumerate(range(0, nrows, chunk)), - length=parts, show=self.progress): + for idx, start in pbar(enumerate(range(0, nrows, chunk)), length=parts, show=self.progress): table_name = 'temp_table_%d' % idx - table = self.hdf5_temp.create_table('/', table_name, - self.description, - expectedrows=chunk) + table = self.hdf5_temp.create_table('/', table_name, self.description, expectedrows=chunk) iterator = self._sort_chunk(table, start, start + chunk) self._iterators.append(iterator) @@ -123,10 +121,9 @@ def sort(self): idx = 0 if self.progress: - print("Merging:") + print('Merging:') - for keyedrow in pbar(merge(*self._iterators), length=nrows, - show=self.progress): + for keyedrow in pbar(merge(*self._iterators), length=nrows, show=self.progress): x, row = keyedrow if idx == self._BUFSIZE: diff --git a/sapphire/corsika/particles.py b/sapphire/corsika/particles.py index fe8a6117..69d653ee 100644 --- a/sapphire/corsika/particles.py +++ b/sapphire/corsika/particles.py @@ -23,6 +23,7 @@ """ + import re @@ -37,8 +38,7 @@ def name(particle_id): try: return ID[particle_id] except KeyError: - return (ATOMIC_NUMBER[int(particle_id) % 100] + - str(int(particle_id // 100))) + return ATOMIC_NUMBER[int(particle_id) % 100] + str(int(particle_id // 100)) def particle_id(name): @@ -58,268 +58,260 @@ def particle_id(name): # weight append the weight to the name, e.g. 
helium4 or carbon14 if name == atom_name: return z * 100 + z - atom = re.match(r"^([a-z]+)(\d+)$", name) + atom = re.match(r'^([a-z]+)(\d+)$', name) if atom is not None: return int(atom.group(2)) * 100 + (particle_id(atom.group(1)) % 100) -ID = {1: 'gamma', - 2: 'positron', - 3: 'electron', - 4: 'neutrino', # No longer used? - 5: 'muon_p', - 6: 'muon_m', - 7: 'pion_0', - 8: 'pion_p', - 9: 'pion_m', - 10: 'Kaon_0_long', - 11: 'Kaon_p', - 12: 'Kaon_m', - 13: 'neutron', - 14: 'proton', - 15: 'anti_proton', - 16: 'Kaon_0_short', - 17: 'eta', - 18: 'Lambda', - 19: 'Sigma_p', - 20: 'Sigma_0', - 21: 'Sigma_m', - 22: 'Xi_0', - 23: 'Xi_m', - 24: 'Omega_m', - 25: 'anti_neutron', - 26: 'anti_Lambda', - 27: 'anti_Sigma_m', - 28: 'anti_Sigma_0', - 29: 'anti_Sigma_p', - 30: 'anti_Xi_0', - 31: 'anti_Xi_p', - 32: 'anti_Omega_p', - 50: 'omega', - 51: 'rho_0', - 52: 'rho_p', - 53: 'rho_m', - 54: 'Delta_pp', - 55: 'Delta_p', - 56: 'Delta_0', - 57: 'Delta_m', - 58: 'anti_Delta_mm', - 59: 'anti_Delta_m', - 60: 'anti_Delta_0', - 61: 'anti_Delta_p', - 62: 'Kaon_star_0', - 63: 'Kaon_star_p', - 64: 'Kaon_star_m', - 65: 'anti_Kaon_star_0', - 66: 'electron_neutrino', - 67: 'anti_electron_neutrino', - 68: 'muon_neutrino', - 69: 'anti_muon_neutrino', - - 71: 'eta__2_gamma', - 72: 'eta__3_pion_0', - 73: 'eta__pion_p_pion_m_pion_0', - 74: 'eta__pion_p_pion_m_gamma', - 75: 'additional_muon_p', - 76: 'additional_muon_m', - - 85: 'decay_start_muon_p', - 86: 'decay_start_muon_m', - - 95: 'decay_end_muon_p', - 96: 'decay_end_muon_m', - - 116: 'D_0', - 117: 'D_p', - 118: 'anti_D_m', - 119: 'anti_D_0', - 120: 'D_p_short', - 121: 'anti_D_m_short', - 122: 'eta_c', - 123: 'D_star_0', - 124: 'D_star_p', - 125: 'anti_D_star_m', - 126: 'anti_D_star_0', - 127: 'D_star_p_short', - 128: 'anti_D_star_m_short', - - 130: 'j_psi', - 131: 'tau_p', - 132: 'tau_m', - 133: 'tau_neutrino', - 134: 'anti_tau_neutrino', - - 137: 'Lambda_c_p', - 138: 'Xi_c_p', - 139: 'Xi_c_0', - 140: 'Sigma_c_pp', - 141: 'Sigma_c_', - 142: 'Sigma_c_0', - 143: 'Xi_c_prime_p', - 144: 'Xi_c_prime_0', - 145: 'Omega_c_0', - - 149: 'anti_Lambda_c_m', - 150: 'anti_Xi_c_m', - 151: 'anti_Xi_c_0', - 152: 'anti_Sigma_c_mm', - 153: 'anti_Sigma_c_m', - 154: 'anti_Sigma_c_0', - 155: 'anti_Xi_c_prime_m', - 156: 'anti_Xi_c_prime_0', - 157: 'anti_Omega_c_0', - - 161: 'Sigma_c_star_pp', - 162: 'Sigma_c_star_p', - 163: 'Sigma_c_star_0', - - 171: 'anti_Sigma_c_star_mm', - 172: 'anti_Sigma_c_star_m', - 173: 'anti_Sigma_c_star_0', - - 176: 'B_0', - 177: 'B_p', - 178: 'anti_B_m', - 179: 'anti_B_0', - 180: 'B_s_0', - 181: 'anti_B_s_0', - 182: 'B_c_p', - 183: 'anti_B_c_m', - 184: 'Lambda_b_0', - 185: 'Sigma_b_m', - 186: 'Sigma_b_p', - 187: 'Xi_b_0', - 188: 'Xi_b_m', - 189: 'Omega_b_m', - 190: 'anti_Lambda_b_0', - 191: 'anti_Sigma_b_p', - 192: 'anti_Sigma_b_m', - 193: 'anti_Xi_b_0', - 194: 'anti_Xi_b_p', - 195: 'anti_Omega_b_p', - - # A x 100 + Z - 101: 'hydrogen', - 201: 'deuteron', - 301: 'tritium', - 302: 'helium3', - 402: 'alpha', - 703: 'lithium', - 904: 'beryllium', - 1105: 'boron', - 1206: 'carbon', - 1407: 'nitrogen', - 1608: 'oxygen', - 2713: 'aluminium', - 2814: 'silicon', - 3216: 'sulfur', - 4020: 'calcium', - 5626: 'iron', - 5828: 'nickel', - - 9900: 'cherenkov_photons'} +ID = { + 1: 'gamma', + 2: 'positron', + 3: 'electron', + 4: 'neutrino', # No longer used? 
+ 5: 'muon_p', + 6: 'muon_m', + 7: 'pion_0', + 8: 'pion_p', + 9: 'pion_m', + 10: 'Kaon_0_long', + 11: 'Kaon_p', + 12: 'Kaon_m', + 13: 'neutron', + 14: 'proton', + 15: 'anti_proton', + 16: 'Kaon_0_short', + 17: 'eta', + 18: 'Lambda', + 19: 'Sigma_p', + 20: 'Sigma_0', + 21: 'Sigma_m', + 22: 'Xi_0', + 23: 'Xi_m', + 24: 'Omega_m', + 25: 'anti_neutron', + 26: 'anti_Lambda', + 27: 'anti_Sigma_m', + 28: 'anti_Sigma_0', + 29: 'anti_Sigma_p', + 30: 'anti_Xi_0', + 31: 'anti_Xi_p', + 32: 'anti_Omega_p', + 50: 'omega', + 51: 'rho_0', + 52: 'rho_p', + 53: 'rho_m', + 54: 'Delta_pp', + 55: 'Delta_p', + 56: 'Delta_0', + 57: 'Delta_m', + 58: 'anti_Delta_mm', + 59: 'anti_Delta_m', + 60: 'anti_Delta_0', + 61: 'anti_Delta_p', + 62: 'Kaon_star_0', + 63: 'Kaon_star_p', + 64: 'Kaon_star_m', + 65: 'anti_Kaon_star_0', + 66: 'electron_neutrino', + 67: 'anti_electron_neutrino', + 68: 'muon_neutrino', + 69: 'anti_muon_neutrino', + 71: 'eta__2_gamma', + 72: 'eta__3_pion_0', + 73: 'eta__pion_p_pion_m_pion_0', + 74: 'eta__pion_p_pion_m_gamma', + 75: 'additional_muon_p', + 76: 'additional_muon_m', + 85: 'decay_start_muon_p', + 86: 'decay_start_muon_m', + 95: 'decay_end_muon_p', + 96: 'decay_end_muon_m', + 116: 'D_0', + 117: 'D_p', + 118: 'anti_D_m', + 119: 'anti_D_0', + 120: 'D_p_short', + 121: 'anti_D_m_short', + 122: 'eta_c', + 123: 'D_star_0', + 124: 'D_star_p', + 125: 'anti_D_star_m', + 126: 'anti_D_star_0', + 127: 'D_star_p_short', + 128: 'anti_D_star_m_short', + 130: 'j_psi', + 131: 'tau_p', + 132: 'tau_m', + 133: 'tau_neutrino', + 134: 'anti_tau_neutrino', + 137: 'Lambda_c_p', + 138: 'Xi_c_p', + 139: 'Xi_c_0', + 140: 'Sigma_c_pp', + 141: 'Sigma_c_', + 142: 'Sigma_c_0', + 143: 'Xi_c_prime_p', + 144: 'Xi_c_prime_0', + 145: 'Omega_c_0', + 149: 'anti_Lambda_c_m', + 150: 'anti_Xi_c_m', + 151: 'anti_Xi_c_0', + 152: 'anti_Sigma_c_mm', + 153: 'anti_Sigma_c_m', + 154: 'anti_Sigma_c_0', + 155: 'anti_Xi_c_prime_m', + 156: 'anti_Xi_c_prime_0', + 157: 'anti_Omega_c_0', + 161: 'Sigma_c_star_pp', + 162: 'Sigma_c_star_p', + 163: 'Sigma_c_star_0', + 171: 'anti_Sigma_c_star_mm', + 172: 'anti_Sigma_c_star_m', + 173: 'anti_Sigma_c_star_0', + 176: 'B_0', + 177: 'B_p', + 178: 'anti_B_m', + 179: 'anti_B_0', + 180: 'B_s_0', + 181: 'anti_B_s_0', + 182: 'B_c_p', + 183: 'anti_B_c_m', + 184: 'Lambda_b_0', + 185: 'Sigma_b_m', + 186: 'Sigma_b_p', + 187: 'Xi_b_0', + 188: 'Xi_b_m', + 189: 'Omega_b_m', + 190: 'anti_Lambda_b_0', + 191: 'anti_Sigma_b_p', + 192: 'anti_Sigma_b_m', + 193: 'anti_Xi_b_0', + 194: 'anti_Xi_b_p', + 195: 'anti_Omega_b_p', + # A x 100 + Z + 101: 'hydrogen', + 201: 'deuteron', + 301: 'tritium', + 302: 'helium3', + 402: 'alpha', + 703: 'lithium', + 904: 'beryllium', + 1105: 'boron', + 1206: 'carbon', + 1407: 'nitrogen', + 1608: 'oxygen', + 2713: 'aluminium', + 2814: 'silicon', + 3216: 'sulfur', + 4020: 'calcium', + 5626: 'iron', + 5828: 'nickel', + 9900: 'cherenkov_photons', +} # Z numbers -ATOMIC_NUMBER = {1: 'hydrogen', - 2: 'helium', - 3: 'lithium', - 4: 'beryllium', - 5: 'boron', - 6: 'carbon', - 7: 'nitrogen', - 8: 'oxygen', - 9: 'fluorine', - 10: 'neon', - 11: 'sodium', - 12: 'magnesium', - 13: 'aluminium', - 14: 'silicon', - 15: 'phosphorus', - 16: 'sulfur', - 17: 'chlorine', - 18: 'argon', - 19: 'potassium', - 20: 'calcium', - 21: 'scandium', - 22: 'titanium', - 23: 'vanadium', - 24: 'chromium', - 25: 'manganese', - 26: 'iron', - 27: 'cobalt', - 28: 'nickel', - 29: 'copper', - 30: 'zinc', - 31: 'gallium', - 32: 'germanium', - 33: 'arsenic', - 34: 'selenium', - 35: 'bromine', - 36: 'krypton', - 37: 'rubidium', - 38: 
'strontium', - 39: 'yttrium', - 40: 'zirconium', - 41: 'niobium', - 42: 'molybdenum', - 43: 'technetium', - 44: 'ruthenium', - 45: 'rhodium', - 46: 'palladium', - 47: 'silver', - 48: 'cadmium', - 49: 'indium', - 50: 'tin', - 51: 'antimony', - 52: 'tellurium', - 53: 'iodine', - 54: 'xenon', - 55: 'caesium', - 56: 'barium', - 57: 'lanthanum', - 58: 'cerium', - 59: 'praseodym.', - 60: 'neodymium', - 61: 'promethium', - 62: 'samarium', - 63: 'europium', - 64: 'gadolinium', - 65: 'terbium', - 66: 'dysprosium', - 67: 'holmium', - 68: 'erbium', - 69: 'thulium', - 70: 'ytterbium', - 71: 'lutetium', - 72: 'hafnium', - 73: 'tantalum', - 74: 'tungsten', - 75: 'rhenium', - 76: 'osmium', - 77: 'iridium', - 78: 'platinum', - 79: 'gold', - 80: 'mercury', - 81: 'thallium', - 82: 'lead', - 83: 'bismuth', - 84: 'polonium', - 85: 'astatine', - 86: 'radon', - 87: 'francium', - 88: 'radium', - 89: 'actinium', - 90: 'thorium', - 91: 'protactin.', - 92: 'uranium', - 93: 'neptunium', - 94: 'plutonium', - 95: 'americium', - 96: 'curium', - 97: 'berkelium', - 98: 'californium', - 99: 'einsteinium'} +ATOMIC_NUMBER = { + 1: 'hydrogen', + 2: 'helium', + 3: 'lithium', + 4: 'beryllium', + 5: 'boron', + 6: 'carbon', + 7: 'nitrogen', + 8: 'oxygen', + 9: 'fluorine', + 10: 'neon', + 11: 'sodium', + 12: 'magnesium', + 13: 'aluminium', + 14: 'silicon', + 15: 'phosphorus', + 16: 'sulfur', + 17: 'chlorine', + 18: 'argon', + 19: 'potassium', + 20: 'calcium', + 21: 'scandium', + 22: 'titanium', + 23: 'vanadium', + 24: 'chromium', + 25: 'manganese', + 26: 'iron', + 27: 'cobalt', + 28: 'nickel', + 29: 'copper', + 30: 'zinc', + 31: 'gallium', + 32: 'germanium', + 33: 'arsenic', + 34: 'selenium', + 35: 'bromine', + 36: 'krypton', + 37: 'rubidium', + 38: 'strontium', + 39: 'yttrium', + 40: 'zirconium', + 41: 'niobium', + 42: 'molybdenum', + 43: 'technetium', + 44: 'ruthenium', + 45: 'rhodium', + 46: 'palladium', + 47: 'silver', + 48: 'cadmium', + 49: 'indium', + 50: 'tin', + 51: 'antimony', + 52: 'tellurium', + 53: 'iodine', + 54: 'xenon', + 55: 'caesium', + 56: 'barium', + 57: 'lanthanum', + 58: 'cerium', + 59: 'praseodym.', + 60: 'neodymium', + 61: 'promethium', + 62: 'samarium', + 63: 'europium', + 64: 'gadolinium', + 65: 'terbium', + 66: 'dysprosium', + 67: 'holmium', + 68: 'erbium', + 69: 'thulium', + 70: 'ytterbium', + 71: 'lutetium', + 72: 'hafnium', + 73: 'tantalum', + 74: 'tungsten', + 75: 'rhenium', + 76: 'osmium', + 77: 'iridium', + 78: 'platinum', + 79: 'gold', + 80: 'mercury', + 81: 'thallium', + 82: 'lead', + 83: 'bismuth', + 84: 'polonium', + 85: 'astatine', + 86: 'radon', + 87: 'francium', + 88: 'radium', + 89: 'actinium', + 90: 'thorium', + 91: 'protactin.', + 92: 'uranium', + 93: 'neptunium', + 94: 'plutonium', + 95: 'americium', + 96: 'curium', + 97: 'berkelium', + 98: 'californium', + 99: 'einsteinium', +} # From the CORSIKA `corsikaread.cpp` program diff --git a/sapphire/corsika/qsub_corsika.py b/sapphire/corsika/qsub_corsika.py index d5cfc6cf..2165ab0e 100644 --- a/sapphire/corsika/qsub_corsika.py +++ b/sapphire/corsika/qsub_corsika.py @@ -1,26 +1,26 @@ -""" Run CORSIKA simulations on Stoomboot +"""Run CORSIKA simulations on Stoomboot - In order to quickly get a good sample of simulated showers we use the - Nikhef computer cluster Stoomboot to run multiple jobs simultaneously. - For this purpose a script has been written that will make this easy. - The :mod:`~sapphire.corsika.qsub_corsika` script can submit as many - jobs as you want with the parameters that you desire. 
It automatically - ensures that a unique combination of seeds for the random number - sequences are used for each simulation. +In order to quickly get a good sample of simulated showers we use the +Nikhef computer cluster Stoomboot to run multiple jobs simultaneously. +For this purpose a script has been written that will make this easy. +The :mod:`~sapphire.corsika.qsub_corsika` script can submit as many +jobs as you want with the parameters that you desire. It automatically +ensures that a unique combination of seeds for the random number +sequences are used for each simulation. - To run this file correctly do it in the correct env:: +To run this file correctly do it in the correct env:: - $ source activate corsika + $ source activate corsika - The syntax for calling the script can be seen by calling its help:: +The syntax for calling the script can be seen by calling its help:: - $ qsub_corsika --help + $ qsub_corsika --help - For example, running 100 showers with proton primaries of 1e16 eV - coming in at 22.5 degrees zenith and 90 degrees azimuth on the - standard Stoomboot queue with the default CORSIKA configuration:: +For example, running 100 showers with proton primaries of 1e16 eV +coming in at 22.5 degrees zenith and 90 degrees azimuth on the +standard Stoomboot queue with the default CORSIKA configuration:: - $ qsub_corsika 100 16 proton 22.5 -q generic -a 90 + $ qsub_corsika 100 16 proton 22.5 -q generic -a 90 """ @@ -74,7 +74,7 @@ DIRECT ./ output directory USER hisparc user DEBUG F 6 F 1000000 debug flag and log.unit for out - EXIT terminates input""") # noqa: E501 + EXIT terminates input""") SCRIPT_TEMPLATE = textwrap.dedent("""\ #!/usr/bin/env bash @@ -95,7 +95,6 @@ class CorsikaBatch: - """Run many simultaneous CORSIKA simulations using Stoomboot Stoomboot is the Nikhef computer cluster. 
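A minimal sketch of driving CorsikaBatch directly, based on the constructor and run() calls visible in this patch (this only makes sense on the Nikhef Stoomboot cluster; the parameter values are illustrative):

    >>> from sapphire.corsika.qsub_corsika import CorsikaBatch
    >>> batch = CorsikaBatch(energy=16, particle='proton', zenith=22.5, azimuth=90, queue='generic')
    >>> batch.run()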
@@ -121,8 +120,15 @@ class CorsikaBatch: """ - def __init__(self, energy=16, particle='proton', zenith=22.5, azimuth=180, - queue='generic', corsika='corsika74000Linux_QGSII_gheisha'): + def __init__( + self, + energy=16, + particle='proton', + zenith=22.5, + azimuth=180, + queue='generic', + corsika='corsika74000Linux_QGSII_gheisha', + ): self.energy_pre, self.energy_pow = self.corsika_energy(energy) self.particle = particles.particle_id(particle) # Store as particle id self.theta = zenith @@ -143,8 +149,8 @@ def corsika_energy(self, energy): :return: separate multiplier and power """ - if modf(energy)[0] == 0.: - return (1., int(energy - 9)) + if modf(energy)[0] == 0.0: + return (1.0, int(energy - 9)) elif modf(energy)[0] == 0.5: return (3.16228, int(modf(energy)[1] - 9)) else: @@ -173,10 +179,10 @@ def prepare_env(self): def submit_job(self): """Submit job to Stoomboot""" - name = f"cor_{self.seed1}_{self.seed2}" - extra = f"-d {self.get_rundir()}" + name = f'cor_{self.seed1}_{self.seed2}' + extra = f'-d {self.get_rundir()}' if self.queue == 'long': - extra += " -l walltime=96:00:00" + extra += ' -l walltime=96:00:00' script = self.create_script() qsub.submit_job(script, name, self.queue, extra) @@ -195,9 +201,9 @@ def generate_random_seeds(self, taken): each is formatted like this: 'seed1_seed2' """ - seed1 = random.randint(1, 900000000) - seed2 = random.randint(1, 900000000) - seed = f"{seed1}_{seed2}" + seed1 = random.randint(1, 900_000_000) + seed2 = random.randint(1, 900_000_000) + seed = f'{seed1}_{seed2}' if seed not in taken: self.seed1 = seed1 self.seed2 = seed2 @@ -224,13 +230,18 @@ def create_input(self): """Make CORSIKA steering file""" input_path = os.path.join(self.get_rundir(), 'input-hisparc') - input = INPUT_TEMPLATE.format(seed1=self.seed1, seed2=self.seed2, - particle=self.particle, phi=self.phi, - energy_pre=self.energy_pre, - energy_pow=self.energy_pow, - theta=self.theta, tablesdir=CORSIKADIR) + input_content = INPUT_TEMPLATE.format( + seed1=self.seed1, + seed2=self.seed2, + particle=self.particle, + phi=self.phi, + energy_pre=self.energy_pre, + energy_pow=self.energy_pow, + theta=self.theta, + tablesdir=CORSIKADIR, + ) with open(input_path, 'w') as input_file: - input_file.write(input) + input_file.write(input_content) def create_script(self): """Make Stoomboot script file""" @@ -238,8 +249,7 @@ def create_script(self): exec_path = os.path.join(CORSIKADIR, self.corsika) run_path = self.get_rundir() - script = SCRIPT_TEMPLATE.format(corsika=exec_path, rundir=run_path, - datadir=DATADIR, faildir=FAILDIR) + script = SCRIPT_TEMPLATE.format(corsika=exec_path, rundir=run_path, datadir=DATADIR, faildir=FAILDIR) return script def copy_config(self): @@ -254,17 +264,21 @@ def copy_config(self): subprocess.check_output(['cp', source, destination]) def __repr__(self): - energy = round(log10(self.energy_pre * 10 ** self.energy_pow) + 9, 1) + energy = round(log10(self.energy_pre * 10**self.energy_pow) + 9, 1) particle = particles.name(self.particle) azimuth = self.phi - 90 - return ('%s(energy=%r, particle=%r, zenith=%r, azimuth=%r, queue=%r, ' - 'corsika=%r)' % - (self.__class__.__name__, energy, particle, self.theta, - azimuth, self.queue, self.corsika)) - - -def multiple_jobs(n, energy, particle, zenith, azimuth, queue, corsika, - progress=True): + return '%s(energy=%r, particle=%r, zenith=%r, azimuth=%r, queue=%r, corsika=%r)' % ( + self.__class__.__name__, + energy, + particle, + self.theta, + azimuth, + self.queue, + self.corsika, + ) + + +def multiple_jobs(n, energy, 
particle, zenith, azimuth, queue, corsika, progress=True): """Use this to sumbit multiple jobs to Stoomboot :param n: Number of jobs to submit @@ -280,66 +294,83 @@ def multiple_jobs(n, energy, particle, zenith, azimuth, queue, corsika, """ if progress: - print(textwrap.dedent("""\ + print( + textwrap.dedent(f"""\ Batch submitting jobs to Stoomboot: Number of jobs {n} - Particle energy 10^{e} eV - Primary particle {p} - Zenith angle {z} degrees - Azimuth angle {a} degrees - Stoomboot queue {q} - CORSIKA executable {c} - """.format(n=n, e=energy, p=particle, z=zenith, a=azimuth, q=queue, - c=corsika))) + Particle energy 10^{energy} eV + Primary particle {particle} + Zenith angle {zenith} degrees + Azimuth angle {azimuth} degrees + Stoomboot queue {queue} + CORSIKA executable {corsika} + """), + ) available_slots = qsub.check_queue(queue) if available_slots <= 0: - raise Exception('Submitting no jobs because selected queue is full.') + raise RuntimeError('Submitting no jobs because selected queue is full.') elif available_slots < n: n = available_slots - warnings.warn('Submitting {n} jobs because queue almost full.' - .format(n=n)) + warnings.warn(f'Submitting {n} jobs because queue almost full.') for _ in pbar(range(n), show=progress): - batch = CorsikaBatch(energy=energy, particle=particle, zenith=zenith, - azimuth=azimuth, queue=queue, corsika=corsika) + batch = CorsikaBatch( + energy=energy, + particle=particle, + zenith=zenith, + azimuth=azimuth, + queue=queue, + corsika=corsika, + ) batch.run() def main(): - parser = argparse.ArgumentParser(description='Submit CORSIKA jobs to ' - 'Stoomboot, only at Nikhef.') - parser.add_argument('n', type=int, help="number of jobs to submit") - parser.add_argument('energy', metavar='energy', type=float, - help="energy of the primary particle in range 11..17, " - "in steps of .5 (log10(E[eV]))", - choices=[11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, - 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5]) - parser.add_argument('particle', help="primary particle kind (e.g. proton " - "or iron)") - parser.add_argument('zenith', metavar='zenith', - help="zenith angle of primary particle in range 0..60," - " in steps of 7.5 [degrees]", - type=float, - choices=[0, 7.5, 15, 22.5, 30, 37.5, 45, 52.5, 60]) - parser.add_argument('-a', '--azimuth', metavar='angle', - help="azimuth angle of primary particle in range " - "0..315, in steps of 45 [degrees]", - type=int, - default=0, - choices=[0, 45, 90, 135, 180, 225, 270, 315]) - parser.add_argument('-q', '--queue', metavar='name', - help="name of the Stoomboot queue to use, choose from " - "express, short, generic (default), and long", - default='generic', - choices=['express', 'short', 'generic', 'long']) - parser.add_argument('-c', '--corsika', metavar='exec', - help="name of the CORSIKA executable to use", - default="corsika74000Linux_QGSII_gheisha") + parser = argparse.ArgumentParser(description='Submit CORSIKA jobs to Stoomboot, only at Nikhef.') + parser.add_argument('n', type=int, help='number of jobs to submit') + parser.add_argument( + 'energy', + metavar='energy', + type=float, + help='energy of the primary particle in range 11..17, in steps of .5 (log10(E[eV]))', + choices=[11, 11.5, 12, 12.5, 13, 13.5, 14, 14.5, 15, 15.5, 16, 16.5, 17, 17.5, 18, 18.5], + ) + parser.add_argument('particle', help='primary particle kind (e.g. 
proton or iron)') + parser.add_argument( + 'zenith', + metavar='zenith', + help='zenith angle of primary particle in range 0..60, in steps of 7.5 [degrees]', + type=float, + choices=[0, 7.5, 15, 22.5, 30, 37.5, 45, 52.5, 60], + ) + parser.add_argument( + '-a', + '--azimuth', + metavar='angle', + help='azimuth angle of primary particle in range 0..315, in steps of 45 [degrees]', + type=int, + default=0, + choices=[0, 45, 90, 135, 180, 225, 270, 315], + ) + parser.add_argument( + '-q', + '--queue', + metavar='name', + help='name of the Stoomboot queue to use, choose from express, short, generic (default), and long', + default='generic', + choices=['express', 'short', 'generic', 'long'], + ) + parser.add_argument( + '-c', + '--corsika', + metavar='exec', + help='name of the CORSIKA executable to use', + default='corsika74000Linux_QGSII_gheisha', + ) args = parser.parse_args() - multiple_jobs(args.n, args.energy, args.particle, args.zenith, - args.azimuth, args.queue, args.corsika) + multiple_jobs(args.n, args.energy, args.particle, args.zenith, args.azimuth, args.queue, args.corsika) if __name__ == '__main__': diff --git a/sapphire/corsika/qsub_store_corsika_data.py b/sapphire/corsika/qsub_store_corsika_data.py index f3e782db..b356c4ed 100644 --- a/sapphire/corsika/qsub_store_corsika_data.py +++ b/sapphire/corsika/qsub_store_corsika_data.py @@ -1,15 +1,16 @@ -""" Convert CORSIKA stored showers to HDF5 on Stoomboot +"""Convert CORSIKA stored showers to HDF5 on Stoomboot - Automatically submits Stoomboot jobs to convert corsika data. The - script ``store_corsika_data`` can be used to convert a DAT000000 - CORSIKA file to a HDF5 file. This script checks our data folder for - new or unconverted simulations and creates Stoomboot jobs to perform - the conversion. +Automatically submits Stoomboot jobs to convert corsika data. The +script ``store_corsika_data`` can be used to convert a DAT000000 +CORSIKA file to a HDF5 file. This script checks our data folder for +new or unconverted simulations and creates Stoomboot jobs to perform +the conversion. - This job is run as a cron job to ensure the simulations remain up to - date. +This job is run as a cron job to ensure the simulations remain up to +date. 
""" + import argparse import glob import logging @@ -38,8 +39,8 @@ def all_seeds(): """Get set of all seeds in the corsika data directory""" - dirs = glob.glob(os.path.join(DATADIR, '*_*')) - seeds = [os.path.basename(dir) for dir in dirs] + directories = glob.glob(os.path.join(DATADIR, '*_*')) + seeds = [os.path.basename(directory) for directory in directories] return set(seeds) @@ -94,8 +95,7 @@ def filter_large_seeds(seeds_todo): """Exclude seeds for data files that are to large""" limit = 70e9 # larger than 70 GB has not been tested yet - return {s for s in seeds_todo - if os.path.getsize(os.path.join(DATADIR, s, SOURCE_FILE)) < limit} + return {s for s in seeds_todo if os.path.getsize(os.path.join(DATADIR, s, SOURCE_FILE)) < limit} def store_command(seed): @@ -103,9 +103,7 @@ def store_command(seed): source = os.path.join(DATADIR, seed, SOURCE_FILE) destination = os.path.join(DATADIR, seed, DESTINATION_FILE) - command = ('{bin_path}python {bin_path}store_corsika_data {source} ' - '{destination}'.format(bin_path=BIN_PATH, source=source, - destination=destination)) + command = f'{BIN_PATH}python {BIN_PATH}store_corsika_data {source} {destination}' return command @@ -120,7 +118,7 @@ def run(queue): n_jobs_to_submit = min(len(seeds), qsub.check_queue(queue), 50) extra = '' if queue == 'long': - extra += " -l walltime=96:00:00" + extra += ' -l walltime=96:00:00' logger.info('Submitting jobs for %d simulations.' % n_jobs_to_submit) try: @@ -136,13 +134,15 @@ def run(queue): def main(): - parser = argparse.ArgumentParser(description='Submit jobs to Stoomboot to ' - 'store CORSIKA data as HDF5.') - parser.add_argument('-q', '--queue', metavar='name', - help="name of the Stoomboot queue to use, choose from " - "express, short, generic, and long (default)", - default='long', - choices=['express', 'short', 'generic', 'long']) + parser = argparse.ArgumentParser(description='Submit jobs to Stoomboot to store CORSIKA data as HDF5.') + parser.add_argument( + '-q', + '--queue', + metavar='name', + help='name of the Stoomboot queue to use, choose from express, short, generic, and long (default)', + default='long', + choices=['express', 'short', 'generic', 'long'], + ) args = parser.parse_args() logger.debug('Starting to submit new jobs.') run(args.queue) @@ -151,7 +151,10 @@ def main(): if __name__ == '__main__': logging.basicConfig( - filename=LOGFILE, filemode='a', + filename=LOGFILE, + filemode='a', format='%(asctime)s %(name)s %(levelname)s: %(message)s', - datefmt='%y%m%d_%H%M%S', level=logging.INFO) + datefmt='%y%m%d_%H%M%S', + level=logging.INFO, + ) main() diff --git a/sapphire/corsika/reader.py b/sapphire/corsika/reader.py index 0394a268..0e2d7d2c 100644 --- a/sapphire/corsika/reader.py +++ b/sapphire/corsika/reader.py @@ -1,67 +1,67 @@ -""" Read CORSIKA data files. +"""Read CORSIKA data files. - This provides functionality to read CORSIKA output - files with `Python `_. It provides the following main - classes: +This provides functionality to read CORSIKA output +files with `Python `_. It provides the following main +classes: - * :class:`~sapphire.corsika.reader.CorsikaFile`: The file class provides a - generator over all events in the file. - * :class:`~sapphire.corsika.reader.CorsikaEvent`: The event class that - provides a generator over all particles at ground. +* :class:`~sapphire.corsika.reader.CorsikaFile`: The file class provides a + generator over all events in the file. 
+* :class:`~sapphire.corsika.reader.CorsikaEvent`: The event class that + provides a generator over all particles at ground. - and the following classes that correspond to the sub-blocks defined in - the CORSIKA manual: +and the following classes that correspond to the sub-blocks defined in +the CORSIKA manual: - * :class:`~sapphire.corsika.blocks.RunHeader` - * :class:`~sapphire.corsika.blocks.RunEnd` - * :class:`~sapphire.corsika.blocks.EventHeader` - * :class:`~sapphire.corsika.blocks.EventEnd` - * :class:`~sapphire.corsika.blocks.ParticleData` - * :class:`~sapphire.corsika.blocks.CherenkovData` +* :class:`~sapphire.corsika.blocks.RunHeader` +* :class:`~sapphire.corsika.blocks.RunEnd` +* :class:`~sapphire.corsika.blocks.EventHeader` +* :class:`~sapphire.corsika.blocks.EventEnd` +* :class:`~sapphire.corsika.blocks.ParticleData` +* :class:`~sapphire.corsika.blocks.CherenkovData` - Additionally version for thinned showers are available: +Additionally version for thinned showers are available: - * :class:`CorsikaFileThin` - * :class:`~sapphire.corsika.blocks.ParticleDataThin` - * :class:`~sapphire.corsika.blocks.CherenkovDataThin` +* :class:`CorsikaFileThin` +* :class:`~sapphire.corsika.blocks.ParticleDataThin` +* :class:`~sapphire.corsika.blocks.CherenkovDataThin` - Issues - ====== +Issues +====== - This module does not handle platform dependent issues such as byte - ordering (endianness) and field size. This was the result of an - afternoon hack and has only been tested with files generated using - 32 bit CORSIKA files on a linux system compiled with gfortran. +This module does not handle platform dependent issues such as byte +ordering (endianness) and field size. This was the result of an +afternoon hack and has only been tested with files generated using +32 bit CORSIKA files on a linux system compiled with gfortran. - * **Field Size**: According to the CORSIKA user manual section 10.2 - all quantities are written as single precision real numbers - independently of 32-bit or 64-bit, so each field in the file - should be 4 bytes long. - * **Endianness**: There is no check for byte ordering. It can be added - using Python's `struct module - `_. - * **Special Particles**: This module currently ignores all special - (book-keeping) particles like for muon additional information and - history. +* **Field Size**: According to the CORSIKA user manual section 10.2 + all quantities are written as single precision real numbers + independently of 32-bit or 64-bit, so each field in the file + should be 4 bytes long. +* **Endianness**: There is no check for byte ordering. It can be added + using Python's `struct module + `_. +* **Special Particles**: This module currently ignores all special + (book-keeping) particles like for muon additional information and + history. - More Info - ========= +More Info +========= - For short information on fortran unformatted binary files, take a look - at http://paulbourke.net/dataformats/reading/ +For short information on fortran unformatted binary files, take a look +at http://paulbourke.net/dataformats/reading/ - For detailed information on the CORSIKA format, check the 'Outputs' - chapter in the CORSIKA user manual. In particular, check the 'Normal - Particle Output' section. +For detailed information on the CORSIKA format, check the 'Outputs' +chapter in the CORSIKA user manual. In particular, check the 'Normal +Particle Output' section. 
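A minimal usage sketch of the classes described above. The constructor argument is assumed from the repr shown in these hunks, and ``get_events`` is assumed from the docstring's description of the event generator (it does not appear in the hunks themselves)::

    from sapphire.corsika.reader import CorsikaFile

    # CorsikaFile supports the context-manager protocol (__enter__/__exit__).
    with CorsikaFile('DAT000000') as corsika_file:
        corsika_file.check()                # verify the block structure
        header = corsika_file.get_header()  # run header sub-block
        for event in corsika_file.get_events():  # assumed generator over events
            for particle in event.get_particles():
                pass  # each particle is a tuple of ground-particle fields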
- Authors - ======= +Authors +======= - - Javier Gonzalez - - Arne de Laat +- Javier Gonzalez +- Arne de Laat """ @@ -139,8 +139,7 @@ def get_particles(self): :yield: each particle in the event """ - for sub_block_index in self._raw_file._subblocks_indices( - self._header_index, self._end_index): + for sub_block_index in self._raw_file._subblocks_indices(self._header_index, self._end_index): for particle in self._raw_file._get_particles(sub_block_index): particle_type = particle[6] observation_level = particle[9] @@ -160,16 +159,10 @@ def get_particles(self): yield particle def __repr__(self): - return "{}({!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._raw_file, - self._header_index, - self._end_index - ) + return f'{self.__class__.__name__}({self._raw_file!r}, {self._header_index!r}, {self._end_index!r})' class CorsikaFile: - """CORSIKA output file handler This class will probide an interface for CORSIKA output files. @@ -201,7 +194,7 @@ def finish(self): def __enter__(self): return self - def __exit__(self, type, value, traceback): + def __exit__(self, exc_type, exc_value, traceback): self.finish() def check(self): @@ -221,8 +214,7 @@ def check(self): """ if self._size % self.format.block_size != 0: - raise Exception('File "{name}" does not have an integer number ' - 'of blocks!'.format(name=self._filename)) + raise ValueError(f'File "{self._filename}" does not have an integer number of blocks!') block_size = self.format.block_size padding = self.format.block_padding_size n_blocks = self._size // block_size @@ -232,8 +224,7 @@ def check(self): self._file.seek((block + 1) * block_size - padding) b = unpack('i', self._file.read(padding))[0] if a != b: - raise Exception('Block #{block} is not right: ({head}, {tail})' - .format(block=block, head=a, tail=b)) + raise ValueError(f'Block #{block} is not right: ({a}, {b})') return True def get_sub_blocks(self): @@ -247,11 +238,10 @@ def get_sub_blocks(self): subblock_size = self.format.subblock_size n_blocks = self._size / block_size for b in range(0, n_blocks * block_size, block_size): - for s in range(0, self.format.subblocks_per_block): + for s in range(self.format.subblocks_per_block): pos = b + s * subblock_size + self.format.block_padding_size self._file.seek(pos) - yield unpack(self.format.subblock_format, - self._file.read(subblock_size)) + yield unpack(self.format.subblock_format, self._file.read(subblock_size)) def get_header(self): """Get the Run header @@ -310,10 +300,11 @@ def _subblocks_indices(self, min_sub_block=None, max_sub_block=None): subblock_size = self.format.subblock_size n_blocks = self._size // block_size for b in range(0, n_blocks * block_size, block_size): - for s in range(0, self.format.subblocks_per_block): + for s in range(self.format.subblocks_per_block): pos = b + s * subblock_size + self.format.block_padding_size - if ((min_sub_block is not None and pos <= min_sub_block) or - (max_sub_block is not None and pos >= max_sub_block)): + if (min_sub_block is not None and pos <= min_sub_block) or ( + max_sub_block is not None and pos >= max_sub_block + ): continue yield pos @@ -376,8 +367,7 @@ def _get_particles(self, word): """Get subblock of particles from the contents as tuples""" unpacked_particles = self._unpack_particles(word) - particles = zip(*[iter(unpacked_particles)] * - self.format.fields_per_particle) + particles = zip(*[iter(unpacked_particles)] * self.format.fields_per_particle) return (particle_data(particle) for particle in particles) def _unpack_subblock(self, word): @@ -387,8 +377,7 @@ 
def _unpack_subblock(self, word): """ self._file.seek(word) - return unpack(self.format.subblock_format, - self._file.read(self.format.subblock_size)) + return unpack(self.format.subblock_format, self._file.read(self.format.subblock_size)) def _unpack_particle(self, word): """Unpack a particle block @@ -397,8 +386,7 @@ def _unpack_particle(self, word): """ self._file.seek(word) - return unpack(self.format.particle_format, - self._file.read(self.format.particle_size)) + return unpack(self.format.particle_format, self._file.read(self.format.particle_size)) def _unpack_particles(self, word): """Unpack a particles subblock @@ -407,15 +395,13 @@ def _unpack_particles(self, word): """ self._file.seek(word) - return unpack(self.format.particles_format, - self._file.read(self.format.particles_size)) + return unpack(self.format.particles_format, self._file.read(self.format.particles_size)) def __repr__(self): - return f"{self.__class__.__name__}({self._filename!r})" + return f'{self.__class__.__name__}({self._filename!r})' class CorsikaFileThin(CorsikaFile): - """CORSIKA thinned output file handler Same as the unthinned output handler, but with support for @@ -447,6 +433,5 @@ def _get_particles(self, word): """Get subblock of thinned particles from the contents as tuples""" unpacked_particles = self._unpack_particles(word) - particles = zip(*[iter(unpacked_particles)] * - self.format.fields_per_particle) + particles = zip(*[iter(unpacked_particles)] * self.format.fields_per_particle) return (particle_data_thin(particle) for particle in particles) diff --git a/sapphire/corsika/store_corsika_data.py b/sapphire/corsika/store_corsika_data.py index 79eeb147..3717716d 100644 --- a/sapphire/corsika/store_corsika_data.py +++ b/sapphire/corsika/store_corsika_data.py @@ -1,17 +1,17 @@ -""" Store CORSIKA simulation data in HDF5 file +"""Store CORSIKA simulation data in HDF5 file - This module reads the CORSIKA binary ground particles file and stores - each particle individually in a HDF5 file, using PyTables. This file - can then be used as input for the detector simulation. +This module reads the CORSIKA binary ground particles file and stores +each particle individually in a HDF5 file, using PyTables. This file +can then be used as input for the detector simulation. - The syntax and options for calling this script can be seen with:: +The syntax and options for calling this script can be seen with:: - $ store_corsika_data --help + $ store_corsika_data --help - For example to convert a CORSIKA file in the current directory called - DAT000000 to a HDF5 called corsika.h5 with a progress bar run:: +For example to convert a CORSIKA file in the current directory called +DAT000000 to a HDF5 called corsika.h5 with a progress bar run:: - $ store_corsika_data --progress DAT000000 corsika.h5 + $ store_corsika_data --progress DAT000000 corsika.h5 """ @@ -44,7 +44,6 @@ class GroundParticles(tables.IsDescription): class ThinnedGroundParticles(GroundParticles): - """Store information about thinned shower particles .. 
attribute:: weight @@ -57,13 +56,12 @@ class ThinnedGroundParticles(GroundParticles): weight = tables.Float32Col(pos=11) -def save_particle(row, p): +def save_particle(row, particle): """Write the information of a particle into a row""" - (p_x, p_y, p_z, x, y, t, id, r, hadron_generation, observation_level, - phi) = p + (p_x, p_y, p_z, x, y, t, particle_id, r, hadron_generation, observation_level, phi) = particle - row['particle_id'] = id + row['particle_id'] = particle_id row['r'] = r row['phi'] = phi row['x'] = x @@ -85,14 +83,13 @@ def save_thinned_particle(row, p): save_particle(row, p[:-1]) -def store_and_sort_corsika_data(source, destination, overwrite=False, - progress=False, thin=False): +def store_and_sort_corsika_data(source, destination, overwrite=False, progress=False, thin=False): """First convert the data to HDF5 and create a sorted version""" if os.path.exists(destination): if not overwrite: if progress: - raise Exception("Destination already exists, doing nothing") + raise RuntimeError('Destination already exists, doing nothing') return else: os.remove(destination) @@ -106,17 +103,15 @@ def store_and_sort_corsika_data(source, destination, overwrite=False, unsorted = create_tempfile_path(temp_dir) temp_path = create_tempfile_path(temp_dir) - with corsika_reader(source) as corsika_data, \ - tables.open_file(unsorted, 'a') as hdf_temp: - store_corsika_data(corsika_data, hdf_temp, progress=progress, - thin=thin) - - with tables.open_file(unsorted, 'r') as hdf_unsorted, \ - tables.open_file(destination, 'w') as hdf_data, \ - tables.open_file(temp_path, 'w') as hdf_temp: + with corsika_reader(source) as corsika_data, tables.open_file(unsorted, 'a') as hdf_temp: + store_corsika_data(corsika_data, hdf_temp, progress=progress, thin=thin) - with TableMergeSort('x', hdf_unsorted, hdf_data, hdf_temp, - progress=progress) as mergesort: + with ( + tables.open_file(unsorted, 'r') as hdf_unsorted, + tables.open_file(destination, 'w') as hdf_data, + tables.open_file(temp_path, 'w') as hdf_temp, + ): + with TableMergeSort('x', hdf_unsorted, hdf_data, hdf_temp, progress=progress) as mergesort: mergesort.sort() event_header = hdf_unsorted.get_node_attr('/', 'event_header') @@ -135,8 +130,7 @@ def store_and_sort_corsika_data(source, destination, overwrite=False, create_index(hdf_data, progress=progress) -def store_corsika_data(source, destination, table_name='groundparticles', - progress=False, thin=False): +def store_corsika_data(source, destination, table_name='groundparticles', progress=False, thin=False): """Store particles from a CORSIKA simulation in a HDF5 file :param source: CorsikaFile instance of the source DAT file. 
@@ -147,7 +141,7 @@ def store_corsika_data(source, destination, table_name='groundparticles', """ if progress: - print("Converting CORSIKA data (%s) to HDF5 format" % source._filename) + print('Converting CORSIKA data (%s) to HDF5 format' % source._filename) source.check() if not thin: @@ -161,9 +155,13 @@ def store_corsika_data(source, destination, table_name='groundparticles', n_particles = event.get_end().n_particles_levels progress = progress and n_particles > 1 try: - table = destination.create_table('/', table_name, description, - 'All groundparticles', - expectedrows=n_particles) + table = destination.create_table( + '/', + table_name, + description, + 'All groundparticles', + expectedrows=n_particles, + ) except tables.NodeError: if progress: print('%s already exists, doing nothing' % table_name) @@ -171,15 +169,14 @@ def store_corsika_data(source, destination, table_name='groundparticles', else: raise if progress: - pbar = ProgressBar(max_value=n_particles - 1, - widgets=[Percentage(), Bar(), ETA()]).start() + pbar = ProgressBar(max_value=n_particles - 1, widgets=[Percentage(), Bar(), ETA()]).start() particle_row = table.row for row, particle in enumerate(event.get_particles()): save_particle_to_row(particle_row, particle) - if progress and not row % 5000: + if progress and not row % 5_000: pbar.update(row) - if not row % 1000000: + if not row % 1_000_000: table.flush() if progress: @@ -214,8 +211,7 @@ def create_index(hdf_data, table_name='groundparticles', progress=False): table.reindex_dirty() -def copy_and_sort_node(hdf_temp, hdf_data, table_name='groundparticles', - progress=False): +def copy_and_sort_node(hdf_temp, hdf_data, table_name='groundparticles', progress=False): """Sort the data in the tables by the x column This speeds up queries to select data based on the x column. 
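For reference, the Python entry point behind the ``store_corsika_data`` command shown in the module docstring; this sketch is equivalent to ``store_corsika_data --progress DAT000000 corsika.h5``::

    from sapphire.corsika.store_corsika_data import store_and_sort_corsika_data

    # Convert the binary ground-particle file and write an x-sorted,
    # indexed HDF5 copy next to it.
    store_and_sort_corsika_data('DAT000000', 'corsika.h5', overwrite=False, progress=True, thin=False)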
@@ -239,19 +235,14 @@ def create_tempfile_path(temp_dir=None): def main(): parser = argparse.ArgumentParser(description='Store CORSIKA data as HDF5.') - parser.add_argument('source', help="path of the CORSIKA source file") - parser.add_argument('destination', - help="path of the HDF5 destination file") - parser.add_argument('--overwrite', action='store_true', - help='overwrite destination file it is already exists') - parser.add_argument('--progress', action='store_true', - help='show progressbar during conversion') - parser.add_argument('--thin', action='store_true', - help='indicate if thinning was active in CORSIKA') + parser.add_argument('source', help='path of the CORSIKA source file') + parser.add_argument('destination', help='path of the HDF5 destination file') + parser.add_argument('--overwrite', action='store_true', help='overwrite destination file it is already exists') + parser.add_argument('--progress', action='store_true', help='show progressbar during conversion') + parser.add_argument('--thin', action='store_true', help='indicate if thinning was active in CORSIKA') args = parser.parse_args() - store_and_sort_corsika_data(args.source, args.destination, args.overwrite, - args.progress, args.thin) + store_and_sort_corsika_data(args.source, args.destination, args.overwrite, args.progress, args.thin) if __name__ == '__main__': diff --git a/sapphire/corsika/units.py b/sapphire/corsika/units.py index ba94a356..9df854ff 100644 --- a/sapphire/corsika/units.py +++ b/sapphire/corsika/units.py @@ -1,4 +1,4 @@ -""" Defines units in terms of HiSPARC standard units +"""Defines units in terms of HiSPARC standard units You should use the units defined in this file whenever you have a dimensional quantity in your code. For example: @@ -86,7 +86,7 @@ yotta = 1e24 # Length [L] -meter = 1. +meter = 1.0 meter2 = meter * meter meter3 = meter * meter * meter @@ -135,11 +135,11 @@ km3 = kilometer3 # Angle -radian = 1. +radian = 1.0 milliradian = milli * radian -degree = (3.14159265358979323846 / 180.) * radian +degree = (3.14159265358979323846 / 180.0) * radian -steradian = 1. +steradian = 1.0 # symbols rad = radian @@ -148,7 +148,7 @@ deg = degree # Time [T] -nanosecond = 1. +nanosecond = 1.0 nanosecond2 = nanosecond * nanosecond second = giga * nanosecond millisecond = milli * second @@ -158,7 +158,7 @@ hour = 60 * minute day = 24 * hour -hertz = 1. / second +hertz = 1.0 / second kilohertz = kilo * hertz megahertz = mega * hertz @@ -168,12 +168,12 @@ ms = millisecond # Electric charge [Q] -eplus = 1. # positron charge +eplus = 1.0 # positron charge eSI = 1.602176462e-19 # positron charge in coulomb coulomb = eplus / eSI # coulomb = 6.24150e18 * eplus # Energy [E] -electronvolt = 1. +electronvolt = 1.0 megaelectronvolt = mega * electronvolt kiloelectronvolt = kilo * electronvolt gigaelectronvolt = giga * electronvolt @@ -252,20 +252,20 @@ henry = weber / ampere # henry = 1.60217e-7 * MeV * (ns / eplus) ** 2 # Temperature -kelvin = 1. +kelvin = 1.0 # Amount of substance -mole = 1. +mole = 1.0 # Activity [T^-1] -becquerel = 1. / second +becquerel = 1.0 / second curie = 3.7e10 * becquerel # Absorbed dose [L^2][T^-2] gray = joule / kilogram # Luminous intensity [I] -candela = 1. +candela = 1.0 # Luminous flux [I] lumen = candela * steradian diff --git a/sapphire/data/__init__.py b/sapphire/data/__init__.py index dc47df0d..031ff2fb 100644 --- a/sapphire/data/__init__.py +++ b/sapphire/data/__init__.py @@ -9,7 +9,7 @@ bring already included data up to date """ + from . 
import extend_local_data, update_local_data -__all__ = ['extend_local_data', - 'update_local_data'] +__all__ = ['extend_local_data', 'update_local_data'] diff --git a/sapphire/data/extend_local_data.py b/sapphire/data/extend_local_data.py index 68fa7dab..6f6943b2 100644 --- a/sapphire/data/extend_local_data.py +++ b/sapphire/data/extend_local_data.py @@ -1,4 +1,4 @@ -""" Add more local JSON and TSV data +"""Add more local JSON and TSV data Add additional local data, to be used by :mod:`~sapphire.api` if internet is unavailable. The use of local data can also be forced to skip calls to the @@ -17,6 +17,7 @@ $ extend_local_data --help """ + import argparse from ..api import Network diff --git a/sapphire/data/update_local_data.py b/sapphire/data/update_local_data.py index 9b88c83b..c10ba386 100644 --- a/sapphire/data/update_local_data.py +++ b/sapphire/data/update_local_data.py @@ -1,4 +1,4 @@ -""" Update local JSON and TSV data +"""Update local JSON and TSV data This script updates the local copies of the JSON and TSV data from the Public Database API. If internet is unavailable the :mod:`~sapphire.api` uses these @@ -39,12 +39,14 @@ def update_local_json(progress=True): for data_type in pbar(toplevel_types, show=progress): update_toplevel_json(data_type) - for arg_type, data_type in [('stations', 'station_info'), - ('subclusters', 'stations_in_subcluster'), - ('clusters', 'subclusters_in_cluster'), - ('countries', 'clusters_in_country')]: + for arg_type, data_type in [ + ('stations', 'station_info'), + ('subclusters', 'stations_in_subcluster'), + ('clusters', 'subclusters_in_cluster'), + ('countries', 'clusters_in_country'), + ]: if progress: - print('Downloading JSONs: %s' % data_type) + print(f'Downloading JSONs: {data_type}') update_sublevel_json(arg_type, data_type, progress) @@ -53,11 +55,10 @@ def update_local_tsv(progress=True): station_numbers = Network().station_numbers() - for data_type in ['gps', 'trigger', 'layout', 'voltage', 'current', - 'electronics', 'detector_timing_offsets']: + for data_type in ['gps', 'trigger', 'layout', 'voltage', 'current', 'electronics', 'detector_timing_offsets']: if progress: - print('Downloading TSVs: %s' % data_type) - update_sublevel_tsv(data_type, station_numbers) + print(f'Downloading TSVs: {data_type}') + update_sublevel_tsv(data_type, station_numbers, progress=progress) # GPS and layout data should now be up to date, local data can be used with warnings.catch_warnings(record=True): @@ -65,8 +66,8 @@ def update_local_tsv(progress=True): for data_type in ['station_timing_offsets']: if progress: - print('Downloading TSVs: %s' % data_type) - update_subsublevel_tsv(data_type, station_numbers, network) + print(f'Downloading TSVs: {data_type}') + update_subsublevel_tsv(data_type, station_numbers, network, progress=progress) def update_toplevel_json(data_type): @@ -74,7 +75,7 @@ def update_toplevel_json(data_type): try: get_and_store_json(url) except Exception: - print('Failed to get %s data' % data_type) + print(f'Failed to get {data_type} data') def update_sublevel_json(arg_type, data_type, progress=True): @@ -89,19 +90,17 @@ def update_sublevel_json(arg_type, data_type, progress=True): numbers = [x['number'] for x in loads(API._retrieve_url(url))] except Exception: if progress: - print('Failed to get %s data' % data_type) + print(f'Failed to get {data_type} data') return kwarg = API.urls[data_type].split('/')[1].strip('{}') for number in pbar(numbers, show=progress): - url = API.urls[data_type].format(**{kwarg: number, 'year': '', - 'month': '', 
'day': ''}) + url = API.urls[data_type].format(**{kwarg: number, 'year': '', 'month': '', 'day': ''}) try: get_and_store_json(url.strip('/')) except Exception: if progress: - print('Failed to get %s data for %s %d' % - (data_type, arg_type, number)) + print('Failed to get %s data for %s %d' % (data_type, arg_type, number)) return @@ -113,8 +112,7 @@ def update_sublevel_tsv(data_type, station_numbers, progress=True): pass for number in pbar(station_numbers, show=progress): - url = API.src_urls[data_type].format(station_number=number, - year='', month='', day='') + url = API.src_urls[data_type].format(station_number=number, year='', month='', day='') url = url.strip('/') + '/' try: get_and_store_tsv(url) @@ -126,8 +124,7 @@ def update_sublevel_tsv(data_type, station_numbers, progress=True): def update_subsublevel_tsv(data_type, station_numbers, network, progress=True): subdir = API.src_urls[data_type].split('/')[0] - for number1, number2 in pbar(list(combinations(station_numbers, 2)), - show=progress): + for number1, number2 in pbar(list(combinations(station_numbers, 2)), show=progress): distance = network.calc_distance_between_stations(number1, number2) if distance is None or distance > 1e3: continue @@ -135,14 +132,12 @@ def update_subsublevel_tsv(data_type, station_numbers, network, progress=True): makedirs(path.join(LOCAL_BASE, subdir, str(number1))) except OSError: pass - url = API.src_urls[data_type].format(station_1=number1, - station_2=number2) + url = API.src_urls[data_type].format(station_1=number1, station_2=number2) try: get_and_store_tsv(url) except Exception: if progress: - print('Failed to get %s data for station pair %d-%d' % - (data_type, number1, number2)) + print('Failed to get %s data for station pair %d-%d' % (data_type, number1, number2)) def get_and_store_json(url): diff --git a/sapphire/esd.py b/sapphire/esd.py index 91d0f955..eb385a36 100644 --- a/sapphire/esd.py +++ b/sapphire/esd.py @@ -1,17 +1,18 @@ -""" Fetch events and other data from the event summary data (ESD). +"""Fetch events and other data from the event summary data (ESD). - This module enables you to access the event summary data. +This module enables you to access the event summary data. - If you are in a real hurry and know what you're doing (and took the - time to read this far), you can call the :func:`quick_download` - function like this:: +If you are in a real hurry and know what you're doing (and took the +time to read this far), you can call the :func:`quick_download` +function like this:: - >>> from sapphire import quick_download - >>> data = quick_download(501) + >>> from sapphire import quick_download + >>> data = quick_download(501) - For regular use, look up :func:`download_data`. +For regular use, look up :func:`download_data`. 
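A short sketch of the regular-use path mentioned above, assuming the ``type`` keyword shown in the ``download_data`` hunks below and a group path of your choosing::

    import datetime

    import tables

    from sapphire import esd

    # One day of event data for station 501, stored under /s501 in a local file.
    with tables.open_file('data.h5', 'w') as data:
        esd.download_data(
            data,
            '/s501',
            501,
            start=datetime.datetime(2016, 1, 1),
            end=datetime.datetime(2016, 1, 2),
            type='events',
        )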
""" + import calendar import collections import csv @@ -122,18 +123,18 @@ def load_data(file, group, tsv_file, type='events'): """ if type == 'events': table = _get_or_create_events_table(file, group) - read_and_store_class = _read_line_and_store_event_class + read_and_store_class = ReadLineAndStoreEventClass elif type == 'weather': table = _get_or_create_weather_table(file, group) - read_and_store_class = _read_line_and_store_weather_class + read_and_store_class = ReadLineAndStoreWeatherClass elif type == 'singles': table = _get_or_create_singles_table(file, group) - read_and_store_class = _read_line_and_store_singles_class + read_and_store_class = ReadLineAndStoreSinglesClass elif type == 'lightning': table = _get_or_create_lightning_table(file, group) - read_and_store_class = _read_line_and_store_lightning_class + read_and_store_class = ReadLineAndStoreLightningClass else: - raise ValueError("Data type not recognized.") + raise ValueError('Data type not recognized.') with open(tsv_file, 'rb') as data: reader = csv.reader(iterdecode(data, 'utf-8'), delimiter='\t') @@ -189,21 +190,21 @@ def download_data(file, group, station_number, start=None, end=None, type='event if type == 'events': url = get_events_url().format(station_number=station_number, query=query) table = _get_or_create_events_table(file, group) - read_and_store = _read_line_and_store_event_class + read_and_store = ReadLineAndStoreEventClass elif type == 'weather': url = get_weather_url().format(station_number=station_number, query=query) table = _get_or_create_weather_table(file, group) - read_and_store = _read_line_and_store_weather_class + read_and_store = ReadLineAndStoreWeatherClass elif type == 'singles': url = get_singles_url().format(station_number=station_number, query=query) table = _get_or_create_singles_table(file, group) - read_and_store = _read_line_and_store_singles_class + read_and_store = ReadLineAndStoreSinglesClass elif type == 'lightning': url = get_lightning_url().format(lightning_type=station_number, query=query) table = _get_or_create_lightning_table(file, group) - read_and_store = _read_line_and_store_lightning_class + read_and_store = ReadLineAndStoreLightningClass else: - raise ValueError("Data type not recognized.") + raise ValueError('Data type not recognized.') try: data = urlopen(url) @@ -217,7 +218,7 @@ def download_data(file, group, station_number, start=None, end=None, type='event t_end = calendar.timegm(end.utctimetuple()) t_delta = t_end - t_start if progress: - pbar = ProgressBar(max_value=1., widgets=[Percentage(), Bar(), ETA()]).start() + pbar = ProgressBar(max_value=1.0, widgets=[Percentage(), Bar(), ETA()]).start() # loop over lines in tsv as they come streaming in prev_update = time.time() @@ -226,8 +227,8 @@ def download_data(file, group, station_number, start=None, end=None, type='event for line in reader: timestamp = writer.store_line(line) # update progressbar every 0.5 seconds - if progress and time.time() - prev_update > 0.5 and not timestamp == 0.: - pbar.update((1. 
* timestamp - t_start) / t_delta) + if progress and time.time() - prev_update > 0.5 and timestamp != 0.0: + pbar.update((1.0 * timestamp - t_start) / t_delta) prev_update = time.time() if progress: pbar.finish() @@ -235,13 +236,13 @@ def download_data(file, group, station_number, start=None, end=None, type='event if line[0][0] == '#': if len(line[0]) == 1: # No events recieved, and no success line - raise Exception('Failed to download data, no data recieved.') + raise ValueError('Failed to download data, no data recieved.') else: # Successful download because last line is a non-empty comment return else: # Last line is data, report failed download and date/time of last line - raise Exception('Failed to complete download, last received data from: %s %s.' % tuple(line[:2])) + raise ValueError('Failed to complete download, last received data from: %s %s.' % tuple(line[:2])) def download_lightning(file, group, lightning_type=4, start=None, end=None, progress=True): @@ -269,7 +270,7 @@ def download_lightning(file, group, lightning_type=4, start=None, end=None, prog group = '/l%d' % lightning_type if lightning_type not in range(6): - raise ValueError("Invalid lightning type.") + raise ValueError('Invalid lightning type.') download_data(file, group, lightning_type, start=start, end=end, type='lightning', progress=progress) @@ -323,20 +324,18 @@ def load_coincidences(file, tsv_file, group=''): if line[0][0] == '#': if len(line[0]) == 1: # No events to load, and no success line - raise Exception('No data to load, source contains no data.') + raise ValueError('No data to load, source contains no data.') else: # Successful download because last line is a non-empty comment pass else: # Last line is data, report possible fail and last date/time - raise Exception('Source file seems incomplete, last received data ' - 'from: %s %s.' % tuple(line[2:4])) + raise ValueError('Source file seems incomplete, last received data from: %s %s.' % tuple(line[2:4])) file.flush() -def download_coincidences(file, group='', cluster=None, stations=None, - start=None, end=None, n=2, progress=True): +def download_coincidences(file, group='', cluster=None, stations=None, start=None, end=None, n=2, progress=True): """Download event summary data coincidences :param file: PyTables datafile handler. @@ -377,7 +376,7 @@ def download_coincidences(file, group='', cluster=None, stations=None, end = start + datetime.timedelta(days=1) if stations is not None and len(stations) < n: - raise Exception('To few stations in query, give at least n.') + raise ValueError('To few stations in query, give at least n.') # build and open url, create tables and set read function query = urlencode({'cluster': cluster, 'stations': stations, 'start': start, 'end': end, 'n': n}) @@ -397,7 +396,7 @@ def download_coincidences(file, group='', cluster=None, stations=None, t_end = calendar.timegm(end.utctimetuple()) t_delta = t_end - t_start if progress: - pbar = ProgressBar(max_value=1., widgets=[Percentage(), Bar(), ETA()]).start() + pbar = ProgressBar(max_value=1.0, widgets=[Percentage(), Bar(), ETA()]).start() # loop over lines in tsv as they come streaming in, keep temporary # lists until a full coincidence is in. @@ -414,8 +413,8 @@ def download_coincidences(file, group='', cluster=None, stations=None, # Full coincidence has been received, store it. 
timestamp = _read_lines_and_store_coincidence(file, c_group, coincidence, station_groups) # update progressbar every 0.5 seconds - if progress and time.time() - prev_update > 0.5 and not timestamp == 0.: - pbar.update((1. * timestamp - t_start) / t_delta) + if progress and time.time() - prev_update > 0.5 and timestamp != 0.0: + pbar.update((1.0 * timestamp - t_start) / t_delta) prev_update = time.time() coincidence = [line] current_coincidence = int(line[0]) @@ -430,14 +429,13 @@ def download_coincidences(file, group='', cluster=None, stations=None, if line[0][0] == '#': if len(line[0]) == 1: # No events recieved, and no success line - raise Exception('Failed to download data, no data recieved.') + raise ValueError('Failed to download data, no data recieved.') else: # Successful download because last line is a non-empty comment pass else: # Last line is data, report failed download and date/time of last line - raise Exception('Failed to complete download, last received data ' - 'from: %s %s.' % tuple(line[2:4])) + raise ValueError('Failed to complete download, last received data from: %s %s.' % tuple(line[2:4])) file.flush() @@ -458,8 +456,8 @@ def _read_or_get_station_groups(file, group): else: re_number = re.compile('[0-9]+$') groups = collections.OrderedDict() - for sid, station_group in enumerate(s_index): - station_group = station_group.decode() + for sid, encoded_station_group in enumerate(s_index): + station_group = encoded_station_group.decode() station = int(re_number.search(station_group).group()) groups[station] = {'group': station_group, 's_index': sid} return groups @@ -481,9 +479,10 @@ def _get_station_groups(group): for cluster in clusters: stations = network.station_numbers(cluster=cluster['number']) for station in stations: - groups[station] = {'group': ('%s/hisparc/cluster_%s/station_%d' % - (group, cluster['name'].lower(), station)), - 's_index': s_index} + groups[station] = { + 'group': ('%s/hisparc/cluster_%s/station_%d' % (group, cluster['name'].lower(), station)), + 's_index': s_index, + } s_index += 1 return groups @@ -510,8 +509,10 @@ def _create_coincidences_tables(file, group, station_groups): # Create coincidences table description = storage.Coincidence - s_columns = {'s%d' % station: tables.BoolCol(pos=p) - for p, station in enumerate(station_groups, 12)} + start_position = len(storage.Coincidence.columns) + 1 + s_columns = { + f's{station}': tables.BoolCol(pos=position) for position, station in enumerate(station_groups, start_position) + } description.columns.update(s_columns) coincidences = file.create_table(coin_group, 'coincidences', description, createparents=True) @@ -546,21 +547,23 @@ def _create_events_table(file, group): exist. 
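Similarly, a sketch for the ``download_coincidences`` function from the hunk above, using its signature as shown there; at least ``n`` of the listed stations must be given::

    import datetime

    import tables

    from sapphire import esd

    # Two-fold coincidences between three stations during one day.
    with tables.open_file('coincidences.h5', 'w') as data:
        esd.download_coincidences(
            data,
            stations=[501, 502, 503],
            start=datetime.datetime(2016, 1, 1),
            end=datetime.datetime(2016, 1, 2),
            n=2,
        )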
""" - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'nanoseconds': tables.UInt32Col(pos=2), - 'ext_timestamp': tables.UInt64Col(pos=3), - 'pulseheights': tables.Int16Col(pos=4, shape=4), - 'integrals': tables.Int32Col(pos=5, shape=4), - 'n1': tables.Float32Col(pos=6), - 'n2': tables.Float32Col(pos=7), - 'n3': tables.Float32Col(pos=8), - 'n4': tables.Float32Col(pos=9), - 't1': tables.Float32Col(pos=10), - 't2': tables.Float32Col(pos=11), - 't3': tables.Float32Col(pos=12), - 't4': tables.Float32Col(pos=13), - 't_trigger': tables.Float32Col(pos=14)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'nanoseconds': tables.UInt32Col(pos=2), + 'ext_timestamp': tables.UInt64Col(pos=3), + 'pulseheights': tables.Int16Col(pos=4, shape=4), + 'integrals': tables.Int32Col(pos=5, shape=4), + 'n1': tables.Float32Col(pos=6), + 'n2': tables.Float32Col(pos=7), + 'n3': tables.Float32Col(pos=8), + 'n4': tables.Float32Col(pos=9), + 't1': tables.Float32Col(pos=10), + 't2': tables.Float32Col(pos=11), + 't3': tables.Float32Col(pos=12), + 't4': tables.Float32Col(pos=13), + 't_trigger': tables.Float32Col(pos=14), + } return file.create_table(group, 'events', description, createparents=True) @@ -585,22 +588,24 @@ def _create_weather_table(file, group): exist. """ - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'temp_inside': tables.Float32Col(pos=2), - 'temp_outside': tables.Float32Col(pos=3), - 'humidity_inside': tables.Int16Col(pos=4), - 'humidity_outside': tables.Int16Col(pos=5), - 'barometer': tables.Float32Col(pos=6), - 'wind_dir': tables.Int16Col(pos=7), - 'wind_speed': tables.Int16Col(pos=8), - 'solar_rad': tables.Int16Col(pos=9), - 'uv': tables.Int16Col(pos=10), - 'evapotranspiration': tables.Float32Col(pos=11), - 'rain_rate': tables.Float32Col(pos=12), - 'heat_index': tables.Int16Col(pos=13), - 'dew_point': tables.Float32Col(pos=14), - 'wind_chill': tables.Float32Col(pos=15)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'temp_inside': tables.Float32Col(pos=2), + 'temp_outside': tables.Float32Col(pos=3), + 'humidity_inside': tables.Int16Col(pos=4), + 'humidity_outside': tables.Int16Col(pos=5), + 'barometer': tables.Float32Col(pos=6), + 'wind_dir': tables.Int16Col(pos=7), + 'wind_speed': tables.Int16Col(pos=8), + 'solar_rad': tables.Int16Col(pos=9), + 'uv': tables.Int16Col(pos=10), + 'evapotranspiration': tables.Float32Col(pos=11), + 'rain_rate': tables.Float32Col(pos=12), + 'heat_index': tables.Int16Col(pos=13), + 'dew_point': tables.Float32Col(pos=14), + 'wind_chill': tables.Float32Col(pos=15), + } return file.create_table(group, 'weather', description, createparents=True) @@ -625,16 +630,18 @@ def _create_singles_table(file, group): exist. 
""" - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'mas_ch1_low': tables.Int32Col(pos=2), - 'mas_ch1_high': tables.Int32Col(pos=3), - 'mas_ch2_low': tables.Int32Col(pos=4), - 'mas_ch2_high': tables.Int32Col(pos=5), - 'slv_ch1_low': tables.Int32Col(pos=6), - 'slv_ch1_high': tables.Int32Col(pos=7), - 'slv_ch2_low': tables.Int32Col(pos=8), - 'slv_ch2_high': tables.Int32Col(pos=9)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'mas_ch1_low': tables.Int32Col(pos=2), + 'mas_ch1_high': tables.Int32Col(pos=3), + 'mas_ch2_low': tables.Int32Col(pos=4), + 'mas_ch2_high': tables.Int32Col(pos=5), + 'slv_ch1_low': tables.Int32Col(pos=6), + 'slv_ch1_high': tables.Int32Col(pos=7), + 'slv_ch2_low': tables.Int32Col(pos=8), + 'slv_ch2_high': tables.Int32Col(pos=9), + } return file.create_table(group, 'singles', description, createparents=True) @@ -659,13 +666,15 @@ def _create_lightning_table(file, group): exist. """ - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'nanoseconds': tables.UInt32Col(pos=2), - 'ext_timestamp': tables.UInt64Col(pos=3), - 'latitude': tables.Float32Col(pos=4), - 'longitude': tables.Float32Col(pos=5), - 'current': tables.Float32Col(pos=6)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'nanoseconds': tables.UInt32Col(pos=2), + 'ext_timestamp': tables.UInt64Col(pos=3), + 'latitude': tables.Float32Col(pos=4), + 'longitude': tables.Float32Col(pos=5), + 'current': tables.Float32Col(pos=6), + } return file.create_table(group, 'lightning', description, createparents=True) @@ -695,14 +704,15 @@ def _read_lines_and_store_coincidence(file, c_group, coincidence, station_groups for event in coincidence: station_number = int(event[1]) try: - row['s%d' % station_number] = True + row[f's{station_number}'] = True group_path = station_groups[station_number]['group'] except KeyError: # Can not add new column, so user should make a new data file. - raise Exception('Unexpected station number: %d, no column and/or ' - 'station group path available.' % station_number) + raise KeyError( + f'Unexpected station number: {station_number}, no column and/or station group path available.', + ) event_group = _get_or_create_events_table(file, group_path) - with _read_line_and_store_event_class(event_group) as writer: + with ReadLineAndStoreEventClass(event_group) as writer: s_idx = station_groups[station_number]['s_index'] e_idx = len(event_group) c_idx.append((s_idx, e_idx)) @@ -715,8 +725,7 @@ def _read_lines_and_store_coincidence(file, c_group, coincidence, station_groups return int(coincidence[0][4]) -class _read_line_and_store_event_class: - +class ReadLineAndStoreEventClass: """Store lines of event data from the ESD Use this contextmanager to store events from a TSV file into a PyTables @@ -743,12 +752,34 @@ def store_line(self, line): """ # ignore comment lines if line[0][0] == '#': - return 0. 
+ return 0.0 # break up TSV line - (date, time_str, timestamp, nanoseconds, ph1, ph2, ph3, ph4, int1, - int2, int3, int4, n1, n2, n3, n4, t1, t2, t3, t4, t_trigger, zenith, - azimuth) = line[:23] + ( + date, + time_str, + timestamp, + nanoseconds, + ph1, + ph2, + ph3, + ph4, + int1, + int2, + int3, + int4, + n1, + n2, + n3, + n4, + t1, + t2, + t3, + t4, + t_trigger, + zenith, + azimuth, + ) = line[:23] row = self.table.row @@ -774,7 +805,7 @@ def store_line(self, line): self.event_counter += 1 # force flush every 1e6 rows to free buffers - if not self.event_counter % 1000000: + if not self.event_counter % 1_000_000: self.table.flush() return int(timestamp) @@ -783,21 +814,34 @@ def __exit__(self, type, value, traceback): self.table.flush() -class _read_line_and_store_weather_class(_read_line_and_store_event_class): - +class ReadLineAndStoreWeatherClass(ReadLineAndStoreEventClass): """Store lines of weather data from the ESD""" def store_line(self, line): # ignore comment lines if line[0][0] == '#': - return 0. + return 0.0 # break up TSV line - (date, time, timestamp, temperature_inside, temperature_outside, - humidity_inside, humidity_outside, atmospheric_pressure, - wind_direction, wind_speed, solar_radiation, uv_index, - evapotranspiration, rain_rate, heat_index, dew_point, - wind_chill) = line + ( + date, + time, + timestamp, + temperature_inside, + temperature_outside, + humidity_inside, + humidity_outside, + atmospheric_pressure, + wind_direction, + wind_speed, + solar_radiation, + uv_index, + evapotranspiration, + rain_rate, + heat_index, + dew_point, + wind_chill, + ) = line row = self.table.row @@ -824,25 +868,34 @@ def store_line(self, line): self.event_counter += 1 # force flush every 1e6 rows to free buffers - if not self.event_counter % 1000000: + if not self.event_counter % 1_000_000: self.table.flush() return int(timestamp) -class _read_line_and_store_singles_class(_read_line_and_store_event_class): - +class ReadLineAndStoreSinglesClass(ReadLineAndStoreEventClass): """Store lines of singles data from the ESD""" def store_line(self, line): # ignore comment lines if line[0][0] == '#': - return 0. + return 0.0 # break up TSV line - (date, time, timestamp, - mas_ch1_low, mas_ch1_high, mas_ch2_low, mas_ch2_high, - slv_ch1_low, slv_ch1_high, slv_ch2_low, slv_ch2_high) = line + ( + date, + time, + timestamp, + mas_ch1_low, + mas_ch1_high, + mas_ch2_low, + mas_ch2_high, + slv_ch1_low, + slv_ch1_high, + slv_ch2_low, + slv_ch2_high, + ) = line row = self.table.row @@ -863,20 +916,19 @@ def store_line(self, line): self.event_counter += 1 # force flush every 1e6 rows to free buffers - if not self.event_counter % 1000000: + if not self.event_counter % 1_000_000: self.table.flush() return int(timestamp) -class _read_line_and_store_lightning_class(_read_line_and_store_event_class): - +class ReadLineAndStoreLightningClass(ReadLineAndStoreEventClass): """Store lines of lightning data from the ESD""" def store_line(self, line): # ignore comment lines if line[0][0] == '#': - return 0. 
+ return 0.0 # break up TSV line (date, time_str, timestamp, nanoseconds, latitude, longitude, current) = line[:7] @@ -897,7 +949,7 @@ def store_line(self, line): self.event_counter += 1 # force flush every 1e6 rows to free buffers - if not self.event_counter % 1000000: + if not self.event_counter % 1_000_000: self.table.flush() return int(timestamp) diff --git a/sapphire/kascade.py b/sapphire/kascade.py index ba37cfa6..e47e0570 100644 --- a/sapphire/kascade.py +++ b/sapphire/kascade.py @@ -1,15 +1,15 @@ -""" Read and store KASCADE data. +"""Read and store KASCADE data. - Read data files provided by the KASCADE collaboration and store them - in a format compatible with HiSPARC data. +Read data files provided by the KASCADE collaboration and store them +in a format compatible with HiSPARC data. - This module contains the following class: +This module contains the following class: - :class:`StoreKascadeData` - Read and store KASCADE data files. +:class:`StoreKascadeData` + Read and store KASCADE data files. - :class:`KascadeCoincidences` - Find HiSPARC and KASCADE events that belong together. +:class:`KascadeCoincidences` + Find HiSPARC and KASCADE events that belong together. """ @@ -25,8 +25,7 @@ class StoreKascadeData: - def __init__(self, data, kascade_filename, kascade_path='/kascade', - hisparc_path=None, force=False, progress=True): + def __init__(self, data, kascade_filename, kascade_path='/kascade', hisparc_path=None, force=False, progress=True): """Initialize the class. :param data: the PyTables datafile @@ -47,12 +46,11 @@ def __init__(self, data, kascade_filename, kascade_path='/kascade', if kascade_path in data: if not force: - raise RuntimeError(f"Cancelling data storage; {kascade_path} already exists") + raise RuntimeError(f'Cancelling data storage; {kascade_path} already exists') else: data.remove_node(kascade_path, recursive=True) - self.kascade = data.create_table(kascade_path, 'events', KascadeEvent, - "KASCADE events", createparents=True) + self.kascade = data.create_table(kascade_path, 'events', KascadeEvent, 'KASCADE events', createparents=True) self.kascade_filename = kascade_filename def read_and_store_data(self): @@ -70,15 +68,15 @@ def read_and_store_data(self): start = clock.gps_to_utc(min(timestamps)) - 5 stop = clock.gps_to_utc(max(timestamps)) + 5 except IndexError: - raise RuntimeError("HiSPARC event table is empty") + raise RuntimeError('HiSPARC event table is empty') if self.progress: - print(f"Processing data from {time.ctime(start)} to {time.ctime(stop)}") + print(f'Processing data from {time.ctime(start)} to {time.ctime(stop)}') else: start = None stop = None if self.progress: - print("Processing all data") + print('Processing all data') self._process_events_in_range(start, stop) @@ -138,8 +136,29 @@ def _store_kascade_event(self, data): tablerow = self.kascade.row # read all columns into KASCADE-named variables - (Irun, Ieve, Gt, Mmn, EnergyArray, Xc, Yc, Ze, Az, Size, Nmu, He0, # noqa: N806 - Hmu0, He1, Hmu1, He2, Hmu2, He3, Hmu3, T200, P200) = data + ( + Irun, + Ieve, + Gt, + Mmn, + EnergyArray, + Xc, + Yc, + Ze, + Az, + Size, + Nmu, + He0, + Hmu0, + He1, + Hmu1, + He2, + Hmu2, + He3, + Hmu3, + T200, + P200, + ) = data tablerow['run_id'] = Irun tablerow['event_id'] = Ieve @@ -168,7 +187,7 @@ def __init__(self, data, hisparc_group, kascade_group, overwrite=False, ignore_e if 'c_index' in self.kascade_group: if not overwrite and not ignore_existing: - raise RuntimeError("I found existing coincidences stored in the KASCADE group") + raise RuntimeError('I 
found existing coincidences stored in the KASCADE group') elif overwrite: data.remove_node(kascade_group, 'c_index') @@ -196,11 +215,11 @@ def search_coincidences(self, timeshift=0, dtlimit=None, limit=None): # Shift the kascade data instead of the hisparc data. There is less of # it, so this is much faster. - k['ext_timestamp'] += int(-1e9) * timeshift + k['ext_timestamp'] += -1_000_000_000 * timeshift if dtlimit: # dtlimit in ns - dtlimit *= 1e9 + dtlimit *= 1_000_000_000 coinc_dt, coinc_h_idx, coinc_k_idx = [], [], [] @@ -215,7 +234,7 @@ def search_coincidences(self, timeshift=0, dtlimit=None, limit=None): if limit: max_k_idx = k_idx + limit - 1 else: - max_k_idx = np.Inf + max_k_idx = np.inf while k_idx <= max_k_idx: # Try to get the timestamps of the kascade event and the @@ -257,8 +276,7 @@ def search_coincidences(self, timeshift=0, dtlimit=None, limit=None): # one. k_idx += 1 - self.coincidences = np.rec.fromarrays( - [coinc_dt, coinc_h_idx, coinc_k_idx], names='dt, h_idx, k_idx') + self.coincidences = np.rec.fromarrays([coinc_dt, coinc_h_idx, coinc_k_idx], names='dt, h_idx, k_idx') def store_coincidences(self): self.data.create_table(self.kascade_group, 'c_index', self.coincidences) @@ -273,7 +291,6 @@ def _get_cached_sorted_id_and_timestamp_arrays(self): def _get_sorted_id_and_timestamp_array(self, group): timestamps = group.events.col('ext_timestamp') ids = group.events.col('event_id') - data = np.rec.fromarrays([ids, timestamps], - names='event_id, ext_timestamp') + data = np.rec.fromarrays([ids, timestamps], names='event_id, ext_timestamp') data.sort(order='ext_timestamp') return data diff --git a/sapphire/publicdb.py b/sapphire/publicdb.py index b2463061..6a35d93e 100644 --- a/sapphire/publicdb.py +++ b/sapphire/publicdb.py @@ -1,11 +1,12 @@ -""" Fetch raw events and other data from the public database +"""Fetch raw events and other data from the public database - This module enables you to access the public database and even the raw - event data. This is intended for specialized use only. For most uses, it is - faster and more convenient to access the event summary data (ESD) using - :mod:`~sapphire.esd`. +This module enables you to access the public database and even the raw +event data. This is intended for specialized use only. For most uses, it is +faster and more convenient to access the event summary data (ESD) using +:mod:`~sapphire.esd`. 
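For the raw-data path, a sketch using the ``publicdb.download_data`` signature visible in the hunk below; as the docstring notes, this is for specialised use only::

    import datetime

    import tables

    from sapphire import publicdb

    # Raw events (without blobs) for station 501, fetched via the XML-RPC server.
    with tables.open_file('raw.h5', 'w') as data:
        publicdb.download_data(
            data,
            '/s501',
            501,
            datetime.datetime(2016, 1, 1),
            datetime.datetime(2016, 1, 2),
            get_blobs=False,
        )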
""" + import datetime import logging import os @@ -62,21 +63,21 @@ def download_data(file, group, station_id, start, end, get_blobs=False): server = ServerProxy(get_publicdb_xmlrpc_url()) for t0, t1 in datetimerange(start, end): - logger.info(f"{t0} {t1}") - logger.info(f"Getting server data URL {t0}") + logger.info(f'{t0} {t1}') + logger.info(f'Getting server data URL {t0}') try: url = server.hisparc.get_data_url(station_id, t0, get_blobs) - except Exception as exc: - if re.search("No data", str(exc)): - logger.warning(f"No data for {t0}") + except Exception as error: + if re.search('No data', str(error)): + logger.warning(f'No data for {t0}') continue else: raise - logger.info("Downloading data...") + logger.info('Downloading data...') tmp_datafile, headers = urlretrieve(url) - logger.info("Storing data...") + logger.info('Storing data...') _store_data(file, group, tmp_datafile, t0, t1) - logger.info("Done.") + logger.info('Done.') def _store_data(dst_file, dst_group, src_filename, t0, t1): @@ -102,14 +103,21 @@ def _store_data(dst_file, dst_group, src_filename, t0, t1): for row in node: dst_node.append(row) - elif node.name in ['events', 'errors', 'config', 'comparator', - 'singles', 'satellites', 'weather', - 'weather_error', 'weather_config']: + elif node.name in [ + 'events', + 'errors', + 'config', + 'comparator', + 'singles', + 'satellites', + 'weather', + 'weather_error', + 'weather_config', + ]: if t1 is None: - cond = 'timestamp >= %d' % datetime_to_gps(t0) + cond = f'timestamp >= {datetime_to_gps(t0)}' else: - cond = ('(%d <= timestamp) & (timestamp <= %d)' % - (datetime_to_gps(t0), datetime_to_gps(t1))) + cond = f'({datetime_to_gps(t0)} <= timestamp) & (timestamp <= {datetime_to_gps(t1)})' rows = node.read_where(cond) @@ -176,7 +184,7 @@ def datetimerange(start, stop): """ if start > stop: - raise Exception('Start can not be after stop.') + raise ValueError('Start can not be after stop.') elif start.date() == stop.date(): yield start, stop return @@ -214,6 +222,6 @@ def _get_or_create_node(file, group, src_node): elif isinstance(src_node, tables.VLArray): node = file.create_vlarray(group, src_node.name, src_node.atom, src_node.title) else: - raise Exception("Unknown node class: %s" % type(src_node)) + raise TypeError(f'Unknown node class: {type(src_node)}') return node diff --git a/sapphire/qsub.py b/sapphire/qsub.py index cacb6a0e..6a553329 100644 --- a/sapphire/qsub.py +++ b/sapphire/qsub.py @@ -1,23 +1,25 @@ -""" Access the Nikhef Stoomboot cluster. +"""Access the Nikhef Stoomboot cluster. - .. note:: - This module is only for use at Nikhef. The Stoomboot cluster is only - accessible for Nikhef users. +.. note:: + This module is only for use at Nikhef. The Stoomboot cluster is only + accessible for Nikhef users. - Easy to use functions to make use of the Nikhef Stoomboot facilities. - This checks the available slots on the requested queue, creates the - scripts to submit, submits the jobs, and cleans up afterwards. +Easy to use functions to make use of the Nikhef Stoomboot facilities. +This checks the available slots on the requested queue, creates the +scripts to submit, submits the jobs, and cleans up afterwards. 
- Example usage:: +Example usage:: - >>> from sapphire import qsub - >>> qsub.check_queue('long') - 340 - >>> qsub.submit_job('touch /data/hisparc/test', 'job_1', 'express') + >>> from sapphire import qsub + >>> qsub.check_queue('long') + 340 + >>> qsub.submit_job('touch /data/hisparc/test', 'job_1', 'express') """ + import os import subprocess +import tempfile from . import utils @@ -73,7 +75,7 @@ def submit_job(script, name, queue, extra=''): result = subprocess.check_output(qsub, stderr=subprocess.STDOUT, shell=True) if not result == b'': - raise Exception(f'{name} - Error occured: {result}') + raise RuntimeError(f'{name} - Error occured: {result}') delete_script(script_path) @@ -82,7 +84,7 @@ def create_script(script, name): """Create script as temp file to be run on Stoomboot""" script_name = f'his_{name}.sh' - script_path = os.path.join('/tmp', script_name) + script_path = os.path.join(tempfile.gettempdir(), script_name) with open(script_path, 'w') as script_file: script_file.write(script) diff --git a/sapphire/simulations/__init__.py b/sapphire/simulations/__init__.py index 692268bc..beb12d98 100644 --- a/sapphire/simulations/__init__.py +++ b/sapphire/simulations/__init__.py @@ -22,11 +22,7 @@ simulation of detector response due to gammas """ + from . import base, detector, gammas, groundparticles, ldf, showerfront -__all__ = ['base', - 'detector', - 'groundparticles', - 'ldf', - 'showerfront', - 'gammas'] +__all__ = ['base', 'detector', 'groundparticles', 'ldf', 'showerfront', 'gammas'] diff --git a/sapphire/simulations/base.py b/sapphire/simulations/base.py index d8d94641..ba89fd5c 100644 --- a/sapphire/simulations/base.py +++ b/sapphire/simulations/base.py @@ -20,6 +20,7 @@ >>> sim.run() """ + import random import warnings @@ -32,7 +33,6 @@ class BaseSimulation: - """Base class for simulations. :param cluster: :class:`~sapphire.clusters.BaseCluster` instance. 
@@ -45,8 +45,7 @@ class BaseSimulation: """ - def __init__(self, cluster, data, output_path='/', n=1, seed=None, - progress=True): + def __init__(self, cluster, data, output_path='/', n=1, seed=None, progress=True): self.cluster = cluster self.data = data self.output_path = output_path @@ -76,22 +75,21 @@ def _prepare_output_tables(self): def run(self): """Run the simulations.""" - for (shower_id, shower_parameters) in enumerate( - self.generate_shower_parameters()): - + for shower_id, shower_parameters in enumerate(self.generate_shower_parameters()): station_events = self.simulate_events_for_shower(shower_parameters) - self.store_coincidence(shower_id, shower_parameters, - station_events) + self.store_coincidence(shower_id, shower_parameters, station_events) def generate_shower_parameters(self): """Generate shower parameters like core position, energy, etc.""" - shower_parameters = {'core_pos': (None, None), - 'zenith': None, - 'azimuth': None, - 'size': None, - 'energy': None, - 'ext_timestamp': None} + shower_parameters = { + 'core_pos': (None, None), + 'zenith': None, + 'azimuth': None, + 'size': None, + 'energy': None, + 'ext_timestamp': None, + } for _ in pbar(range(self.n), show=self.progress): yield shower_parameters @@ -101,26 +99,19 @@ def simulate_events_for_shower(self, shower_parameters): station_events = [] for station_id, station in enumerate(self.cluster.stations): - has_triggered, station_observables = \ - self.simulate_station_response(station, - shower_parameters) + has_triggered, station_observables = self.simulate_station_response(station, shower_parameters) if has_triggered: - event_index = \ - self.store_station_observables(station_id, - station_observables) + event_index = self.store_station_observables(station_id, station_observables) station_events.append((station_id, event_index)) return station_events def simulate_station_response(self, station, shower_parameters): """Simulate station response to a shower.""" - detector_observables = self.simulate_all_detectors( - station.detectors, shower_parameters) + detector_observables = self.simulate_all_detectors(station.detectors, shower_parameters) has_triggered = self.simulate_trigger(detector_observables) - station_observables = \ - self.process_detector_observables(detector_observables) - station_observables = self.simulate_gps(station_observables, - shower_parameters, station) + station_observables = self.process_detector_observables(detector_observables) + station_observables = self.simulate_gps(station_observables, shower_parameters, station) return has_triggered, station_observables @@ -133,8 +124,7 @@ def simulate_all_detectors(self, detectors, shower_parameters): """ detector_observables = [] for detector in detectors: - observables = self.simulate_detector_response(detector, - shower_parameters) + observables = self.simulate_detector_response(detector, shower_parameters) detector_observables.append(observables) return detector_observables @@ -149,7 +139,7 @@ def simulate_detector_response(self, detector, shower_parameters): """ # implement this! - observables = {'n': 0., 't': -999} + observables = {'n': 0.0, 't': -999} return observables @@ -179,8 +169,7 @@ def process_detector_observables(self, detector_observables): like n1, n2, n3, etc. 
""" - station_observables = {'pulseheights': 4 * [-1.], - 'integrals': 4 * [-1.]} + station_observables = {'pulseheights': 4 * [-1.0], 'integrals': 4 * [-1.0]} for detector_id, observables in enumerate(detector_observables, 1): for key, value in observables.items(): @@ -215,8 +204,7 @@ def store_station_observables(self, station_id, station_observables): return events_table.nrows - 1 - def store_coincidence(self, shower_id, shower_parameters, - station_events): + def store_coincidence(self, shower_id, shower_parameters, station_events): """Store coincidence. Store the information to find events of different stations @@ -246,16 +234,14 @@ def store_coincidence(self, shower_id, shower_parameters, row['s%d' % station.number] = True station_group = self.station_groups[station_id] event = station_group.events[event_index] - timestamps.append((event['ext_timestamp'], event['timestamp'], - event['nanoseconds'])) + timestamps.append((event['ext_timestamp'], event['timestamp'], event['nanoseconds'])) try: first_timestamp = sorted(timestamps)[0] except IndexError: first_timestamp = (0, 0, 0) - row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = \ - first_timestamp + row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = first_timestamp row.append() self.coincidences.flush() @@ -270,27 +256,23 @@ def _prepare_coincidence_tables(self): This makes it easy to link events detected by multiple stations. """ - self.coincidence_group = self.data.create_group(self.output_path, - 'coincidences', - createparents=True) + self.coincidence_group = self.data.create_group(self.output_path, 'coincidences', createparents=True) try: self.coincidence_group._v_attrs.cluster = self.cluster except tables.HDF5ExtError: warnings.warn('Unable to store cluster object, to large for HDF.') description = storage.Coincidence - s_columns = {'s%d' % station.number: tables.BoolCol(pos=p) - for p, station in enumerate(self.cluster.stations, 12)} + s_columns = { + 's%d' % station.number: tables.BoolCol(pos=p) for p, station in enumerate(self.cluster.stations, 12) + } description.columns.update(s_columns) - self.coincidences = self.data.create_table( - self.coincidence_group, 'coincidences', description) + self.coincidences = self.data.create_table(self.coincidence_group, 'coincidences', description) - self.c_index = self.data.create_vlarray( - self.coincidence_group, 'c_index', tables.UInt32Col(shape=2)) + self.c_index = self.data.create_vlarray(self.coincidence_group, 'c_index', tables.UInt32Col(shape=2)) - self.s_index = self.data.create_vlarray( - self.coincidence_group, 's_index', tables.VLStringAtom()) + self.s_index = self.data.create_vlarray(self.coincidence_group, 's_index', tables.VLStringAtom()) def _prepare_station_tables(self): """Create the groups and events table to store the observables @@ -299,17 +281,12 @@ def _prepare_station_tables(self): :param station: a :class:`sapphire.clusters.Station` object """ - self.cluster_group = self.data.create_group(self.output_path, - 'cluster_simulations', - createparents=True) + self.cluster_group = self.data.create_group(self.output_path, 'cluster_simulations', createparents=True) self.station_groups = [] for station in self.cluster.stations: - station_group = self.data.create_group(self.cluster_group, - 'station_%d' % - station.number) + station_group = self.data.create_group(self.cluster_group, 'station_%d' % station.number) description = ProcessEvents.processed_events_description - self.data.create_table(station_group, 'events', description, - expectedrows=self.n) 
+ self.data.create_table(station_group, 'events', description, expectedrows=self.n) self.station_groups.append(station_group) def _store_station_index(self): @@ -321,7 +298,10 @@ def _store_station_index(self): def __repr__(self): if not self.data.isopen: - return "" % self.__class__.__name__ - return ('<%s, cluster: %r, data: %r, output_path: %r>' % - (self.__class__.__name__, self.cluster, self.data.filename, - self.output_path)) + return '' % self.__class__.__name__ + return '<%s, cluster: %r, data: %r, output_path: %r>' % ( + self.__class__.__name__, + self.cluster, + self.data.filename, + self.output_path, + ) diff --git a/sapphire/simulations/detector.py b/sapphire/simulations/detector.py index b485d90e..850882c3 100644 --- a/sapphire/simulations/detector.py +++ b/sapphire/simulations/detector.py @@ -3,6 +3,7 @@ These are some common simulations for HiSPARC detectors. """ + import warnings from math import acos, cos, pi, sin, sqrt @@ -15,7 +16,6 @@ class HiSPARCSimulation(BaseSimulation): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -103,12 +103,11 @@ def simulate_signal_transport_time(cls, n=1): """ numbers = np.random.random(n) if n < 20: - dt = np.array([2.5507 + 2.39885 * number if number < 0.39377 else - 1.56764 + 4.89536 * number for number in numbers]) + dt = np.array( + [2.5507 + 2.39885 * number if number < 0.39377 else 1.56764 + 4.89536 * number for number in numbers], + ) else: - dt = np.where(numbers < 0.39377, - 2.5507 + 2.39885 * numbers, - 1.56764 + 4.89536 * numbers) + dt = np.where(numbers < 0.39377, 2.5507 + 2.39885 * numbers, 1.56764 + 4.89536 * numbers) return dt @classmethod @@ -144,7 +143,7 @@ def simulate_detector_mips(cls, n, theta): """ # Limit cos theta to maximum length though the detector. - min_costheta = 2. / 112. + min_costheta = 2.0 / 112.0 costheta = np.cos(theta) if isinstance(costheta, float): costheta = max(costheta, min_costheta) @@ -167,13 +166,9 @@ def simulate_detector_mips(cls, n, theta): if not isinstance(costheta, float): mips = sum(mips) else: - mips = np.where(y < 0.3394, - (0.48 + 0.8583 * np.sqrt(y)) / costheta, - (0.73 + 0.7366 * y) / costheta) - mips = np.where(y < 0.4344, mips, - (1.7752 - 1.0336 * np.sqrt(0.9267 - y)) / costheta) - mips = np.where(y < 0.9041, mips, - (2.28 - 2.1316 * np.sqrt(1 - y)) / costheta) + mips = np.where(y < 0.3394, (0.48 + 0.8583 * np.sqrt(y)) / costheta, (0.73 + 0.7366 * y) / costheta) + mips = np.where(y < 0.4344, mips, (1.7752 - 1.0336 * np.sqrt(0.9267 - y)) / costheta) + mips = np.where(y < 0.9041, mips, (2.28 - 2.1316 * np.sqrt(1 - y)) / costheta) mips = sum(mips) warnings.resetwarnings() return mips @@ -193,14 +188,14 @@ def generate_core_position(cls, r_max): :return: Random x, y position in the disc with radius r_max. """ - r = sqrt(np.random.uniform(0, r_max ** 2)) + r = sqrt(np.random.uniform(0, r_max**2)) phi = np.random.uniform(-pi, pi) x = r * cos(phi) y = r * sin(phi) return x, y @classmethod - def generate_zenith(cls, min=0, max=pi / 3.): + def generate_zenith(cls, min_value=0, max_value=pi / 3.0): """Generate a random zenith Generate a random zenith for a uniform distribution on a sphere. @@ -218,7 +213,7 @@ def generate_zenith(cls, min=0, max=pi / 3.): :return: random zenith position on a sphere, in radians. 
""" - p = np.random.uniform(cos(max), cos(min)) + p = np.random.uniform(cos(max_value), cos(min_value)) return acos(p) @classmethod @@ -249,7 +244,7 @@ def inverse_zenith_probability(cls, p): :return: zenith with corresponding cumulative probability, in radians. """ - return acos((1 - p) ** (1 / 8.)) + return acos((1 - p) ** (1 / 8.0)) @classmethod def generate_azimuth(cls): @@ -277,44 +272,36 @@ def generate_energy(cls, e_min=1e14, e_max=1e21, alpha=-2.75): """ x = np.random.random() - a1 = alpha + 1. - energy = (e_min ** a1 + x * (e_max ** a1 - e_min ** a1)) ** (1 / a1) + a1 = alpha + 1.0 + energy = (e_min**a1 + x * (e_max**a1 - e_min**a1)) ** (1 / a1) return energy class ErrorlessSimulation(HiSPARCSimulation): - @classmethod def simulate_detector_offsets(cls, n_detectors): - - return [0.] * n_detectors + return [0.0] * n_detectors @classmethod def simulate_detector_offset(cls): - - return 0. + return 0.0 @classmethod def simulate_station_offset(cls): - - return 0. + return 0.0 @classmethod def simulate_gps_uncertainty(cls): - - return 0. + return 0.0 @classmethod def simulate_adc_sampling(cls, t): - return t @classmethod def simulate_signal_transport_time(cls, n=1): - - return np.array([0.] * n) + return np.array([0.0] * n) @classmethod def simulate_detector_mips(cls, n, theta): - return n diff --git a/sapphire/simulations/gammas.py b/sapphire/simulations/gammas.py index 368315c5..8b108d82 100644 --- a/sapphire/simulations/gammas.py +++ b/sapphire/simulations/gammas.py @@ -10,7 +10,7 @@ import numpy as np SCINTILLATOR_THICKNESS = 2.0 # cm -MAX_DEPTH = 112. # longest straight path in scintillator in cm +MAX_DEPTH = 112.0 # longest straight path in scintillator in cm ENERGY_LOSS = 2.0 # 2 MeV per cm MAX_E = ENERGY_LOSS * SCINTILLATOR_THICKNESS MIP = 3.38 # MeV @@ -47,14 +47,11 @@ def compton_energy_transfer(gamma_energy): recoil_energies = np.linspace(0, edge, 1000) # electron energy distribution - electron_energy = [energy_transfer_cross_section(gamma_energy, - recoil_energy) - for recoil_energy in recoil_energies] + electron_energy = [energy_transfer_cross_section(gamma_energy, recoil_energy) for recoil_energy in recoil_energies] cumulative_energy = np.cumsum(electron_energy) - normalised_energy_distribution = (cumulative_energy / - cumulative_energy[-1]) + normalised_energy_distribution = cumulative_energy / cumulative_energy[-1] r = np.random.random() conversion_factor = normalised_energy_distribution.searchsorted(r) / 1000 @@ -79,9 +76,12 @@ def energy_transfer_cross_section(gamma_energy, recoil_energy): s = recoil_energy / gamma_energy - return (np.pi * (r_e ** 2) / (ELECTRON_REST_MASS_MeV * gamma ** 2) * - (2 + (s ** 2 / ((gamma ** 2) * ((1 - s) ** 2))) + - (s / (1 - s)) * (s - 2 / gamma))) + return ( + np.pi + * (r_e**2) + / (ELECTRON_REST_MASS_MeV * gamma**2) + * (2 + (s**2 / ((gamma**2) * ((1 - s) ** 2))) + (s / (1 - s)) * (s - 2 / gamma)) + ) def max_energy_deposit_in_mips(depth, scintillator_depth): @@ -114,16 +114,14 @@ def simulate_detector_mips_gammas(p, theta): mips = 0 for energy, angle in zip(energies, theta): # project depth onto direction of incident particle - scintillator_depth = min(SCINTILLATOR_THICKNESS / np.cos(angle), - MAX_DEPTH) + scintillator_depth = min(SCINTILLATOR_THICKNESS / np.cos(angle), MAX_DEPTH) # Calculate interaction point in units of scinitlator depth. # If depth > 1 there is no interaction. 
depth_compton = expovariate(1 / compton_mean_free_path(energy)) depth_pair = expovariate(1 / pair_mean_free_path(energy)) - if ((depth_pair > scintillator_depth) & - (depth_compton > scintillator_depth)): + if (depth_pair > scintillator_depth) & (depth_compton > scintillator_depth): # no interaction continue @@ -133,8 +131,7 @@ def simulate_detector_mips_gammas(p, theta): # kinetic energy transfered to electron by compton scattering energy_deposit = compton_energy_transfer(energy) / MIP - max_deposit = max_energy_deposit_in_mips(depth_compton, - scintillator_depth) + max_deposit = max_energy_deposit_in_mips(depth_compton, scintillator_depth) mips += min(max_deposit, energy_deposit) elif energy > 1.022: @@ -143,8 +140,7 @@ def simulate_detector_mips_gammas(p, theta): # 1.022 MeV used for creation of two particles # all the rest is electron kinetic energy energy_deposit = (energy - 1.022) / MIP - max_deposit = max_energy_deposit_in_mips(depth_pair, - scintillator_depth) + max_deposit = max_energy_deposit_in_mips(depth_pair, scintillator_depth) mips += min(max_deposit, energy_deposit) return mips @@ -163,24 +159,59 @@ def pair_mean_free_path(gamma_energy): :return: mean free path [cm]. """ - energy_path_pair_production = np.array([ - (4, 689.31), (5, 504.52), (6, 404.96), - (7, 343.56), (8, 302.00), (9, 271.84), - (10, 249.03), (11, 231.28), (12, 217.04), - (13, 205.23), (14, 195.32), (15, 186.88), - (16, 179.47), (18, 167.40), (20, 157.85), - (22, 149.97), (24, 143.51), (26, 138.00), - (28, 133.30), (30, 129.20), (40, 114.65), - (50, 105.64), (60, 99.37), (80, 91.17), - (100, 85.90), (150, 78.25), (200, 74.07), - (300, 69.44), (400, 66.93), (500, 65.34), - (600, 64.21), (800, 62.73), (1000, 61.82), - (1500, 60.47), (2000, 59.72), (3000, 58.97), - (4000, 58.53), (5000, 58.28), (6000, 58.09), - (8000, 57.85), (10000, 57.70), (15000, 57.51), - (20000, 57.41), (30000, 57.27), (40000, 57.21), - (50000, 57.17), (60000, 57.13), (80000, 57.12), - (100000, 57.08)]) + energy_path_pair_production = np.array( + [ + (4, 689.31), + (5, 504.52), + (6, 404.96), + (7, 343.56), + (8, 302.00), + (9, 271.84), + (10, 249.03), + (11, 231.28), + (12, 217.04), + (13, 205.23), + (14, 195.32), + (15, 186.88), + (16, 179.47), + (18, 167.40), + (20, 157.85), + (22, 149.97), + (24, 143.51), + (26, 138.00), + (28, 133.30), + (30, 129.20), + (40, 114.65), + (50, 105.64), + (60, 99.37), + (80, 91.17), + (100, 85.90), + (150, 78.25), + (200, 74.07), + (300, 69.44), + (400, 66.93), + (500, 65.34), + (600, 64.21), + (800, 62.73), + (1000, 61.82), + (1500, 60.47), + (2000, 59.72), + (3000, 58.97), + (4000, 58.53), + (5000, 58.28), + (6000, 58.09), + (8000, 57.85), + (10000, 57.70), + (15000, 57.51), + (20000, 57.41), + (30000, 57.27), + (40000, 57.21), + (50000, 57.17), + (60000, 57.13), + (80000, 57.12), + (100000, 57.08), + ], + ) gamma_energies = energy_path_pair_production[:, 0] mean_free_paths = energy_path_pair_production[:, 1] @@ -202,24 +233,59 @@ def compton_mean_free_path(gamma_energy): :return: mean free path [cm]. 
""" - energy_path_compton_scattering = np.array([ - (4, 31.88), (5, 36.90), (6, 41.75), - (7, 46.47), (8, 51.05), (9, 55.52), - (10, 59.95), (11, 64.27), (12, 68.54), - (13, 72.73), (14, 76.86), (15, 80.97), - (16, 85.03), (18, 93.02), (20, 100.92), - (22, 108.60), (24, 116.23), (26, 123.81), - (28, 131.23), (30, 138.64), (40, 174.40), - (50, 208.94), (60, 242.54), (80, 307.50), - (100, 370.51), (150, 520.29), (200, 663.57), - (300, 936.33), (400, 1195.46), (500, 1444.04), - (600, 1686.34), (800, 2159.36), (1000, 2624.67), - (1500, 3757.99), (2000, 4856.73), (3000, 6983.24), - (4000, 9049.77), (5000, 11063.17), (6000, 13048.02), - (8000, 16940.54), (10000, 20746.89), (15000, 30021.01), - (20000, 39047.25), (30000, 56625.14), (40000, 73746.31), - (50000, 90579.71), (60000, 107146.68), (80000, 139684.31), - (100000, 171791.79)]) + energy_path_compton_scattering = np.array( + [ + (4, 31.88), + (5, 36.90), + (6, 41.75), + (7, 46.47), + (8, 51.05), + (9, 55.52), + (10, 59.95), + (11, 64.27), + (12, 68.54), + (13, 72.73), + (14, 76.86), + (15, 80.97), + (16, 85.03), + (18, 93.02), + (20, 100.92), + (22, 108.60), + (24, 116.23), + (26, 123.81), + (28, 131.23), + (30, 138.64), + (40, 174.40), + (50, 208.94), + (60, 242.54), + (80, 307.50), + (100, 370.51), + (150, 520.29), + (200, 663.57), + (300, 936.33), + (400, 1195.46), + (500, 1444.04), + (600, 1686.34), + (800, 2159.36), + (1000, 2624.67), + (1500, 3757.99), + (2000, 4856.73), + (3000, 6983.24), + (4000, 9049.77), + (5000, 11063.17), + (6000, 13048.02), + (8000, 16940.54), + (10000, 20746.89), + (15000, 30021.01), + (20000, 39047.25), + (30000, 56625.14), + (40000, 73746.31), + (50000, 90579.71), + (60000, 107146.68), + (80000, 139684.31), + (100000, 171791.79), + ], + ) gamma_energies = energy_path_compton_scattering[:, 0] mean_free_paths = energy_path_compton_scattering[:, 1] diff --git a/sapphire/simulations/groundparticles.py b/sapphire/simulations/groundparticles.py index 7418f1a6..4c80760e 100644 --- a/sapphire/simulations/groundparticles.py +++ b/sapphire/simulations/groundparticles.py @@ -29,7 +29,6 @@ class GroundParticlesSimulation(HiSPARCSimulation): - def __init__(self, corsikafile_path, max_core_distance, *args, **kwargs): """Simulation initialization @@ -70,10 +69,12 @@ def generate_shower_parameters(self): event_header = self.corsikafile.get_node_attr('/', 'event_header') event_end = self.corsikafile.get_node_attr('/', 'event_end') - corsika_parameters = {'zenith': event_header.zenith, - 'size': event_end.n_electrons_levels, - 'energy': event_header.energy, - 'particle': event_header.particle} + corsika_parameters = { + 'zenith': event_header.zenith, + 'size': event_end.n_electrons_levels, + 'energy': event_header.energy, + 'particle': event_header.particle, + } self.corsika_azimuth = event_header.azimuth for i in pbar(range(self.n), show=self.progress): @@ -81,9 +82,7 @@ def generate_shower_parameters(self): x, y = self.generate_core_position(r_max) shower_azimuth = self.generate_azimuth() - shower_parameters = {'ext_timestamp': ext_timestamp, - 'core_pos': (x, y), - 'azimuth': shower_azimuth} + shower_parameters = {'ext_timestamp': ext_timestamp, 'core_pos': (x, y), 'azimuth': shower_azimuth} # Subtract CORSIKA shower azimuth from desired shower azimuth # make it fit in (-pi, pi] to get rotation angle of the cluster. 
@@ -132,10 +131,9 @@ def simulate_detector_response(self, detector, shower_parameters): nz = cos(shower_parameters['zenith']) tproj = detector.get_coordinates()[-1] / (c * nz) first_signal = particles['t'].min() + detector.offset - tproj - observables = {'n': round(mips, 3), - 't': self.simulate_adc_sampling(first_signal)} + observables = {'n': round(mips, 3), 't': self.simulate_adc_sampling(first_signal)} else: - observables = {'n': 0., 't': -999} + observables = {'n': 0.0, 't': -999} return observables @@ -147,9 +145,7 @@ def simulate_detector_mips_for_particles(self, particles): """ # determination of lepton angle of incidence - theta = np.arccos(abs(particles['p_z']) / - vector_length(particles['p_x'], particles['p_y'], - particles['p_z'])) + theta = np.arccos(abs(particles['p_z']) / vector_length(particles['p_x'], particles['p_y'], particles['p_z'])) n = len(particles) mips = self.simulate_detector_mips(n, theta) @@ -168,21 +164,12 @@ def simulate_trigger(self, detector_observables): """ n_detectors = len(detector_observables) - detectors_low = sum( - True for observables in detector_observables - if observables['n'] > 0.3 - ) - detectors_high = sum( - True for observables in detector_observables - if observables['n'] > 0.5 - ) + detectors_low = sum(True for observables in detector_observables if observables['n'] > 0.3) + detectors_high = sum(True for observables in detector_observables if observables['n'] > 0.5) - if n_detectors == 4 and (detectors_high >= 2 or detectors_low >= 3): - return True - elif n_detectors == 2 and detectors_low >= 2: - return True - else: - return False + return ( + n_detectors == 4 and (detectors_high >= 2 or detectors_low >= 3) or n_detectors == 2 and detectors_low >= 2 + ) def simulate_gps(self, station_observables, shower_parameters, station): """Simulate gps timestamp. @@ -196,23 +183,26 @@ def simulate_gps(self, station_observables, shower_parameters, station): trigger time. """ - arrival_times = [station_observables['t%d' % id] - for id in range(1, 5) - if station_observables.get('n%d' % id, -1) > 0] + arrival_times = [ + station_observables[f't{detector_id}'] + for detector_id in range(1, 5) + if station_observables.get(f'n{detector_id}', -1) > 0 + ] if len(arrival_times) > 1: trigger_time = sorted(arrival_times)[1] ext_timestamp = shower_parameters['ext_timestamp'] - ext_timestamp += int(trigger_time + station.gps_offset + - self.simulate_gps_uncertainty()) + ext_timestamp += int(trigger_time + station.gps_offset + self.simulate_gps_uncertainty()) timestamp = int(ext_timestamp / 1_000_000_000) nanoseconds = int(ext_timestamp % 1_000_000_000) - gps_timestamp = {'ext_timestamp': ext_timestamp, - 'timestamp': timestamp, - 'nanoseconds': nanoseconds, - 't_trigger': trigger_time} + gps_timestamp = { + 'ext_timestamp': ext_timestamp, + 'timestamp': timestamp, + 'nanoseconds': nanoseconds, + 't_trigger': trigger_time, + } station_observables.update(gps_timestamp) return station_observables @@ -238,7 +228,7 @@ def get_particles_in_detector(self, detector, shower_parameters): :param shower_parameters: dictionary with the shower parameters. """ - detector_boundary = sqrt(0.5) / 2. 
+ detector_boundary = sqrt(0.5) / 2.0 x, y, z = detector.get_coordinates() zenith = shower_parameters['zenith'] @@ -249,10 +239,12 @@ def get_particles_in_detector(self, detector, shower_parameters): xproj = x - z * nxnz yproj = y - z * nynz - query = ('(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f)' - ' & (particle_id >= 2) & (particle_id <= 6)' % - (xproj - detector_boundary, xproj + detector_boundary, - yproj - detector_boundary, yproj + detector_boundary)) + query = '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & (particle_id >= 2) & (particle_id <= 6)' % ( + xproj - detector_boundary, + xproj + detector_boundary, + yproj - detector_boundary, + yproj + detector_boundary, + ) return self.groundparticles.read_where(query) @@ -271,8 +263,7 @@ def simulate_detector_response(self, detector, shower_parameters): :param shower_parameters: dictionary with the shower parameters. """ - leptons, gammas = self.get_particles_in_detector(detector, - shower_parameters) + leptons, gammas = self.get_particles_in_detector(detector, shower_parameters) n_leptons = len(leptons) n_gammas = len(gammas) @@ -300,8 +291,7 @@ def simulate_detector_response(self, detector, shower_parameters): elif n_gammas: first_signal = first_gamma + detector.offset - return {'n': mips_lepton + mips_gamma, - 't': self.simulate_adc_sampling(first_signal)} + return {'n': mips_lepton + mips_gamma, 't': self.simulate_adc_sampling(first_signal)} def get_particles_in_detector(self, detector, shower_parameters): """Get particles that hit a detector. @@ -321,7 +311,7 @@ def get_particles_in_detector(self, detector, shower_parameters): :param shower_parameters: dictionary with the shower parameters. """ - detector_boundary = sqrt(.5) / 2. + detector_boundary = sqrt(0.5) / 2.0 x, y, z = detector.get_coordinates() zenith = shower_parameters['zenith'] @@ -332,20 +322,21 @@ def get_particles_in_detector(self, detector, shower_parameters): xproj = x - z * nxnz yproj = y - z * nynz - query_leptons = \ - ('(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f)' - ' & (particle_id >= 2) & (particle_id <= 6)' % - (xproj - detector_boundary, xproj + detector_boundary, - yproj - detector_boundary, yproj + detector_boundary)) + query_leptons = '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & (particle_id >= 2) & (particle_id <= 6)' % ( + xproj - detector_boundary, + xproj + detector_boundary, + yproj - detector_boundary, + yproj + detector_boundary, + ) - query_gammas = \ - ('(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f)' - ' & (particle_id == 1)' % - (xproj - detector_boundary, xproj + detector_boundary, - yproj - detector_boundary, yproj + detector_boundary)) + query_gammas = '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & (particle_id == 1)' % ( + xproj - detector_boundary, + xproj + detector_boundary, + yproj - detector_boundary, + yproj + detector_boundary, + ) - return (self.groundparticles.read_where(query_leptons), - self.groundparticles.read_where(query_gammas)) + return (self.groundparticles.read_where(query_leptons), self.groundparticles.read_where(query_gammas)) def simulate_detector_mips_for_gammas(self, particles): """Simulate the detector signal for gammas @@ -354,12 +345,10 @@ def simulate_detector_mips_for_gammas(self, particles): components of the particle momenta. 
""" - p_gamma = np.sqrt(particles['p_x'] ** 2 + particles['p_y'] ** 2 + - particles['p_z'] ** 2) + p_gamma = np.sqrt(particles['p_x'] ** 2 + particles['p_y'] ** 2 + particles['p_z'] ** 2) # determination of lepton angle of incidence - theta = np.arccos(abs(particles['p_z']) / - p_gamma) + theta = np.arccos(abs(particles['p_z']) / p_gamma) mips = simulate_detector_mips_gammas(p_gamma, theta) @@ -367,7 +356,6 @@ def simulate_detector_mips_for_gammas(self, particles): class DetectorBoundarySimulation(GroundParticlesSimulation): - """More accuratly simulate the detection area of the detectors. Take the orientation of the detectors into account and use the @@ -407,12 +395,21 @@ def get_particles_in_detector(self, detector, shower_parameters): b11, line1, b12 = self.get_line_boundary_eqs(*cproj[0:3]) b21, line2, b22 = self.get_line_boundary_eqs(*cproj[1:4]) - query = ("(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & " - "(b11 < %s) & (%s < b12) & (b21 < %s) & (%s < b22) & " - "(particle_id >= 2) & (particle_id <= 6)" % - (xproj - detector_boundary, xproj + detector_boundary, - yproj - detector_boundary, yproj + detector_boundary, - line1, line1, line2, line2)) + query = ( + '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & ' + '(b11 < %s) & (%s < b12) & (b21 < %s) & (%s < b22) & ' + '(particle_id >= 2) & (particle_id <= 6)' + % ( + xproj - detector_boundary, + xproj + detector_boundary, + yproj - detector_boundary, + yproj + detector_boundary, + line1, + line1, + line2, + line2, + ) + ) return self.groundparticles.read_where(query) @@ -442,7 +439,7 @@ def get_line_boundary_eqs(self, p0, p1, p2): # Compute the general equation for the lines if x0 == x1: # line is exactly vertical - line = "x" + line = 'x' b1, b2 = x0, x2 else: # First, compute the slope @@ -452,7 +449,7 @@ def get_line_boundary_eqs(self, p0, p1, p2): b1 = y0 - a * x0 b2 = y2 - a * x2 - line = "y - %f * x" % a + line = 'y - %f * x' % a # And order the y-intercepts if b1 > b2: @@ -462,7 +459,6 @@ def get_line_boundary_eqs(self, p0, p1, p2): class ParticleCounterSimulation(GroundParticlesSimulation): - """Do not simulate mips, just count the number of particles.""" def simulate_detector_mips(self, n, theta): @@ -472,7 +468,6 @@ def simulate_detector_mips(self, n, theta): class FixedCoreDistanceSimulation(GroundParticlesSimulation): - """Shower core at a fixed core distance (from cluster origin). :param core_distance: distance of shower core to center of cluster. @@ -493,9 +488,7 @@ def generate_core_position(cls, r_max): return x, y -class GroundParticlesSimulationWithoutErrors(ErrorlessSimulation, - GroundParticlesSimulation): - +class GroundParticlesSimulationWithoutErrors(ErrorlessSimulation, GroundParticlesSimulation): """This simulation does not simulate errors/uncertainties This results in perfect timing (first particle through detector) @@ -503,11 +496,8 @@ class GroundParticlesSimulationWithoutErrors(ErrorlessSimulation, """ - pass - class MultipleGroundParticlesSimulation(GroundParticlesSimulation): - """Use multiple CORSIKA simulated air showers in one run. Simulations will be selected from the set of available showers. 
@@ -526,8 +516,7 @@ class MultipleGroundParticlesSimulation(GroundParticlesSimulation): # CORSIKA data location at Nikhef DATA = '/data/hisparc/corsika/data/{seeds}/corsika.h5' - def __init__(self, corsikaoverview_path, max_core_distance, min_energy, - max_energy, *args, **kwargs): + def __init__(self, corsikaoverview_path, max_core_distance, min_energy, max_energy, *args, **kwargs): """Simulation initialization :param corsikaoverview_path: path to the corsika_overview.h5 file @@ -545,11 +534,8 @@ def __init__(self, corsikaoverview_path, max_core_distance, min_energy, self.max_core_distance = max_core_distance self.min_energy = min_energy self.max_energy = max_energy - self.available_energies = {e for e in self.cq.all_energies - if min_energy <= 10 ** e <= max_energy} - self.available_zeniths = {e: self.cq.available_parameters('zenith', - energy=e) - for e in self.available_energies} + self.available_energies = {e for e in self.cq.all_energies if min_energy <= 10**e <= max_energy} + self.available_zeniths = {e: self.cq.available_parameters('zenith', energy=e) for e in self.available_energies} def finish(self): """Clean-up after simulation""" @@ -577,10 +563,12 @@ def generate_shower_parameters(self): if sim is None: continue - corsika_parameters = {'zenith': sim['zenith'], - 'size': sim['n_electron'], - 'energy': sim['energy'], - 'particle': sim['particle_id']} + corsika_parameters = { + 'zenith': sim['zenith'], + 'size': sim['n_electron'], + 'energy': sim['energy'], + 'particle': sim['particle_id'], + } self.corsika_azimuth = sim['azimuth'] seeds = self.cq.seeds([sim])[0] @@ -596,9 +584,7 @@ def generate_shower_parameters(self): x, y = self.generate_core_position(r) shower_azimuth = self.generate_azimuth() - shower_parameters = {'ext_timestamp': ext_timestamp, - 'core_pos': (x, y), - 'azimuth': shower_azimuth} + shower_parameters = {'ext_timestamp': ext_timestamp, 'core_pos': (x, y), 'azimuth': shower_azimuth} # Subtract CORSIKA shower azimuth from desired shower # azimuth to get rotation angle of the cluster. 
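# Editor's note (illustrative sketch, not part of the diff): the overview of
# available CORSIKA showers is reduced to the set of log10(energy) values
# inside [min_energy, max_energy], and a generated energy is then snapped to
# the closest available value with closest_in_list (from sapphire.utils).
# The helper below is a stand-in with assumed behaviour, not the real function.
from math import log10

min_energy, max_energy = 1e15, 1e17
all_energies = [14.0, 14.5, 15.0, 15.5, 16.0, 16.5, 17.0, 17.5]  # hypothetical log10(E/eV) values

available_energies = {e for e in all_energies if min_energy <= 10**e <= max_energy}

def closest_in_list(value, values):
    # Assumed behaviour: return the element of `values` nearest to `value`.
    return min(values, key=lambda v: abs(v - value))

# A generated 3.2e16 eV shower would be matched to the 16.5 log10-energy set.
shower_energy = closest_in_list(log10(3.2e16), available_energies)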
@@ -619,8 +605,7 @@ def select_simulation(self): shower_energy = closest_in_list(log10(energy), self.available_energies) zenith = self.generate_zenith() - shower_zenith = closest_in_list(np.degrees(zenith), - self.available_zeniths[shower_energy]) + shower_zenith = closest_in_list(np.degrees(zenith), self.available_zeniths[shower_energy]) sims = self.cq.simulations(energy=shower_energy, zenith=shower_zenith) if not len(sims): diff --git a/sapphire/simulations/ldf.py b/sapphire/simulations/ldf.py index f2fdf807..54c70f99 100644 --- a/sapphire/simulations/ldf.py +++ b/sapphire/simulations/ldf.py @@ -17,6 +17,7 @@ >>> sim.run() """ + import warnings from numpy import arctan2, cos, log10, pi, random, sin, sqrt @@ -27,9 +28,7 @@ class BaseLdfSimulation(HiSPARCSimulation): - - def __init__(self, max_core_distance, min_energy, max_energy, *args, - **kwargs): + def __init__(self, max_core_distance, min_energy, max_energy, *args, **kwargs): """Simulation initialization :param max_core_distance: maximum distance of shower core to @@ -66,12 +65,14 @@ def generate_shower_parameters(self): for i in pbar(range(self.n), show=self.progress): energy = self.generate_energy(self.min_energy, self.max_energy) size = 10 ** (log10(energy) - 15 + 4.8) - shower_parameters = {'ext_timestamp': (giga + i) * giga, - 'azimuth': self.generate_azimuth(), - 'zenith': 0., - 'core_pos': self.generate_core_position(r), - 'size': size, - 'energy': energy} + shower_parameters = { + 'ext_timestamp': (giga + i) * giga, + 'azimuth': self.generate_azimuth(), + 'zenith': 0.0, + 'core_pos': self.generate_core_position(r), + 'size': size, + 'energy': energy, + } yield shower_parameters @@ -85,15 +86,14 @@ def simulate_detector_response(self, detector, shower_parameters): :param shower_parameters: dictionary with the shower parameters. """ - n_detected = self.get_num_particles_in_detector(detector, - shower_parameters) + n_detected = self.get_num_particles_in_detector(detector, shower_parameters) theta = shower_parameters['zenith'] if n_detected: mips = self.simulate_detector_mips(n_detected, theta) observables = {'n': mips} else: - observables = {'n': 0.} + observables = {'n': 0.0} return observables def get_num_particles_in_detector(self, detector, shower_parameters): @@ -110,13 +110,11 @@ def get_num_particles_in_detector(self, detector, shower_parameters): azimuth = shower_parameters['azimuth'] size = shower_parameters['size'] - r = self.ldf.calculate_core_distance(x, y, core_x, core_y, zenith, - azimuth) + r = self.ldf.calculate_core_distance(x, y, core_x, core_y, zenith, azimuth) p_shower = self.ldf.calculate_ldf_value(r, n_electrons=size) p_ground = p_shower * cos(zenith) - num_particles = self.simulate_particles_for_density( - p_ground * detector.get_area()) + num_particles = self.simulate_particles_for_density(p_ground * detector.get_area()) return num_particles @@ -132,7 +130,6 @@ def simulate_particles_for_density(p): class BaseLdfSimulationWithoutErrors(ErrorlessSimulation, BaseLdfSimulation): - """This simulation does not simulate errors/uncertainties This should result in perfect particle counting for the detectors. 
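# Editor's note (illustrative sketch, not part of the diff): the LDF
# simulations convert primary energy to shower size with
# size = 10 ** (log10(energy) - 15 + 4.8), i.e. 10**4.8 electrons at 1 PeV
# and a size proportional to energy.  A quick check of that relation:
from math import log10

def shower_size(energy):
    # energy in eV; returns the electron shower size used by the LDF simulations.
    return 10 ** (log10(energy) - 15 + 4.8)

assert abs(shower_size(1e15) - 10**4.8) < 1e-6
assert abs(shower_size(1e16) / shower_size(1e15) - 10) < 1e-9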
@@ -147,7 +144,6 @@ def simulate_particles_for_density(p): class NkgLdfSimulation(BaseLdfSimulation): - """Same as the BaseLdfSimulation but uses the NkgLdf as LDF""" def __init__(self, *args, **kwargs): @@ -156,16 +152,11 @@ def __init__(self, *args, **kwargs): self.ldf = NkgLdf() -class NkgLdfSimulationWithoutErrors(NkgLdfSimulation, - BaseLdfSimulationWithoutErrors): - +class NkgLdfSimulationWithoutErrors(NkgLdfSimulation, BaseLdfSimulationWithoutErrors): """Same as the NkgLdfSimulation but without error simulation""" - pass - class KascadeLdfSimulation(BaseLdfSimulation): - """Same as the BaseLdfSimulation but uses the KascadeLdf as LDF""" def __init__(self, *args, **kwargs): @@ -174,16 +165,11 @@ def __init__(self, *args, **kwargs): self.ldf = KascadeLdf() -class KascadeLdfSimulationWithoutErrors(KascadeLdfSimulation, - BaseLdfSimulationWithoutErrors): - +class KascadeLdfSimulationWithoutErrors(KascadeLdfSimulation, BaseLdfSimulationWithoutErrors): """Same as the KascadeLdfSimulation but without error simulation""" - pass - class EllipsLdfSimulation(BaseLdfSimulation): - """Same as BaseLdfSimulation but uses the EllipsLdF as LDF""" def __init__(self, *args, **kwargs): @@ -207,12 +193,14 @@ def generate_shower_parameters(self): for i in pbar(range(self.n), show=self.progress): energy = self.generate_energy(self.min_energy, self.max_energy) size = 10 ** (log10(energy) - 15 + 4.8) - shower_parameters = {'ext_timestamp': (giga + i) * giga, - 'azimuth': self.generate_azimuth(), - 'zenith': self.generate_zenith(), - 'core_pos': self.generate_core_position(r), - 'size': size, - 'energy': energy} + shower_parameters = { + 'ext_timestamp': (giga + i) * giga, + 'azimuth': self.generate_azimuth(), + 'zenith': self.generate_zenith(), + 'core_pos': self.generate_core_position(r), + 'size': size, + 'energy': energy, + } yield shower_parameters @@ -230,18 +218,15 @@ def get_num_particles_in_detector(self, detector, shower_parameters): azimuth = shower_parameters['azimuth'] size = shower_parameters['size'] - r, phi = self.ldf.calculate_core_distance_and_angle(x, y, core_x, - core_y) + r, phi = self.ldf.calculate_core_distance_and_angle(x, y, core_x, core_y) p_ground = self.ldf.calculate_ldf_value(r, phi, size, zenith, azimuth) - num_particles = self.simulate_particles_for_density( - p_ground * detector.get_area()) + num_particles = self.simulate_particles_for_density(p_ground * detector.get_area()) return num_particles class BaseLdf: - """Base LDF class No particles! Always returns a particle density of 0. @@ -249,7 +234,7 @@ class BaseLdf: """ def calculate_ldf_value(self, r, n_electrons=None, s=None): - return 0. + return 0.0 def calculate_core_distance(self, x, y, x0, y0, theta, phi): """Calculate core distance @@ -267,19 +252,17 @@ def calculate_core_distance(self, x, y, x0, y0, theta, phi): x = x - x0 y = y - y0 - return sqrt(x ** 2 + y ** 2 - - (x * cos(phi) + y * sin(phi)) ** 2 * sin(theta) ** 2) + return sqrt(x**2 + y**2 - (x * cos(phi) + y * sin(phi)) ** 2 * sin(theta) ** 2) class NkgLdf(BaseLdf): - """The Nishimura-Kamata-Greisen function""" # shower parameters # Age parameter and Moliere radius from Thoudam2012 sec 5.6. - _n_electrons = 10 ** 4.8 + _n_electrons = 10**4.8 _s = 1.7 - _r0 = 30. 
+ _r0 = 30.0 def __init__(self, n_electrons=None, s=None): """NKG LDF setup @@ -336,8 +319,7 @@ def ldf_value(self, r, n_electrons, s): c_s = self._c(s) r0 = self._r0 - return (n_electrons * c_s * (r / r0) ** (s - 2) * - (1 + r / r0) ** (s - 4.5)) + return n_electrons * c_s * (r / r0) ** (s - 2) * (1 + r / r0) ** (s - 4.5) def _c(self, s): """Part of the LDF @@ -349,19 +331,17 @@ def _c(self, s): """ r0 = self._r0 - return (gamma(4.5 - s) / - (2 * pi * r0 ** 2 * gamma(s) * gamma(4.5 - 2 * s))) + return gamma(4.5 - s) / (2 * pi * r0**2 * gamma(s) * gamma(4.5 - 2 * s)) class KascadeLdf(NkgLdf): - """The KASCADE modified NKG function""" # shower parameters # Values from Fokkema2012 sec 7.1. - _n_electrons = 10 ** 4.8 + _n_electrons = 10**4.8 _s = 0.94 # Shape parameter - _r0 = 40. + _r0 = 40.0 _alpha = 1.5 _beta = 3.6 @@ -385,8 +365,7 @@ def ldf_value(self, r, n_electrons, s): alpha = self._alpha beta = self._beta - return (n_electrons * c_s * (r / r0) ** (s - alpha) * - (1 + r / r0) ** (s - beta)) + return n_electrons * c_s * (r / r0) ** (s - alpha) * (1 + r / r0) ** (s - beta) def _c(self, s): """Part of the LDF @@ -400,26 +379,22 @@ def _c(self, s): r0 = self._r0 beta = self._beta alpha = self._alpha - return (gamma(beta - s) / - (2 * pi * r0 ** 2 * gamma(s - alpha + 2) * - gamma(alpha + beta - 2 * s - 2))) + return gamma(beta - s) / (2 * pi * r0**2 * gamma(s - alpha + 2) * gamma(alpha + beta - 2 * s - 2)) class EllipsLdf(KascadeLdf): - """The NKG function modified for leptons and azimuthal asymmetry""" # shower parameters # Values from Montanus, paper to follow. - _n_electrons = 10 ** 4.8 - _s1 = -.5 # Shape parameter + _n_electrons = 10**4.8 + _s1 = -0.5 # Shape parameter _s2 = -2.6 # Shape parameter - _r0 = 30. - _zenith = 0. - _azimuth = 0. + _r0 = 30.0 + _zenith = 0.0 + _azimuth = 0.0 - def __init__(self, n_electrons=None, zenith=None, azimuth=None, s1=None, - s2=None): + def __init__(self, n_electrons=None, zenith=None, azimuth=None, s1=None, s2=None): if n_electrons is not None: self._n_electrons = n_electrons if zenith is not None: @@ -441,8 +416,7 @@ def _cache_c_s_value(self): """ self._c_s = self._c(self._s1, self._s2) - def calculate_ldf_value(self, r, phi, n_electrons=None, zenith=None, - azimuth=None): + def calculate_ldf_value(self, r, phi, n_electrons=None, zenith=None, azimuth=None): """Calculate the LDF value for a given core distance and polar angle :param r: core distance in m. @@ -457,8 +431,7 @@ def calculate_ldf_value(self, r, phi, n_electrons=None, zenith=None, zenith = self._zenith if azimuth is None: azimuth = self._azimuth - return self.ldf_value(r, phi, n_electrons, zenith, azimuth, self._s1, - self._s2) + return self.ldf_value(r, phi, n_electrons, zenith, azimuth, self._s1, self._s2) def ldf_value(self, r, phi, n_electrons, zenith, azimuth, s1, s2): """Calculate the LDF value @@ -496,8 +469,7 @@ def ldf_value(self, r, phi, n_electrons, zenith, azimuth, s1, s2): term2 = 1 + k / r0 muoncorr = 1 + k / (11.24 * r0) # See warning in docstring. 
with warnings.catch_warnings(record=True): - p = (n_electrons * c_s * cos(zenith) * term1 ** s1 * term2 ** s2 * - muoncorr) + p = n_electrons * c_s * cos(zenith) * term1**s1 * term2**s2 * muoncorr return p def _c(self, s1, s2): @@ -511,8 +483,7 @@ def _c(self, s1, s2): """ r0 = self._r0 - return (gamma(-s2) / - (2 * pi * r0 ** 2 * gamma(s1 + 2) * gamma(-s1 - s2 - 2))) + return gamma(-s2) / (2 * pi * r0**2 * gamma(s1 + 2) * gamma(-s1 - s2 - 2)) def calculate_core_distance_and_angle(self, x, y, x0, y0): """Calculate core distance diff --git a/sapphire/simulations/showerfront.py b/sapphire/simulations/showerfront.py index 793659c8..1abdac8d 100644 --- a/sapphire/simulations/showerfront.py +++ b/sapphire/simulations/showerfront.py @@ -26,15 +26,13 @@ class FlatFrontSimulation(HiSPARCSimulation): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Since the cluster is not rotated detector positions can be cached. for station in self.cluster.stations: for detector in station.detectors: - detector.cylindrical_coordinates = \ - detector.get_cylindrical_coordinates() + detector.cylindrical_coordinates = detector.get_cylindrical_coordinates() def generate_shower_parameters(self): """Generate shower parameters, i.e. azimuth and zenith angles. @@ -49,12 +47,14 @@ def generate_shower_parameters(self): """ for i in pbar(range(self.n), show=self.progress): - shower_parameters = {'ext_timestamp': (1_000_000_000 + i) * 1_000_000_000, - 'azimuth': self.generate_azimuth(), - 'zenith': self.generate_attenuated_zenith(), - 'core_pos': (None, None), - 'size': None, - 'energy': None} + shower_parameters = { + 'ext_timestamp': (1_000_000_000 + i) * 1_000_000_000, + 'azimuth': self.generate_azimuth(), + 'zenith': self.generate_attenuated_zenith(), + 'core_pos': (None, None), + 'size': None, + 'energy': None, + } yield shower_parameters @@ -66,9 +66,10 @@ def simulate_detector_response(self, detector, shower_parameters): """ arrival_time = self.simulate_adc_sampling( - self.get_arrival_time(detector, shower_parameters) + - self.simulate_signal_transport_time(1)[0] + - detector.offset) + self.get_arrival_time(detector, shower_parameters) + + self.simulate_signal_transport_time(1)[0] + + detector.offset, + ) observables = {'t': arrival_time} return observables @@ -92,7 +93,7 @@ def get_arrival_time(self, detector, shower_parameters): phi = shower_parameters['azimuth'] theta = shower_parameters['zenith'] r = r1 * cos(phi - phi1) - z1 * tan(theta) - cdt = - (r * sin(theta) + z1 / cos(theta)) + cdt = -(r * sin(theta) + z1 / cos(theta)) dt = cdt / c return dt @@ -103,45 +104,41 @@ def simulate_gps(self, station_observables, shower_parameters, station): """ n_detectors = len(station.detectors) - ids = list(range(1, n_detectors + 1)) - arrival_times = [station_observables['t%d' % id] for id in ids] + detector_ids = list(range(1, n_detectors + 1)) + arrival_times = [station_observables[f't{detector_id}'] for detector_id in detector_ids] ext_timestamp = shower_parameters['ext_timestamp'] first_time = sorted(arrival_times)[0] - for id in ids: - station_observables['t%d' % id] -= first_time + for detector_id in detector_ids: + station_observables[f't{detector_id}'] -= first_time - arrival_times = [station_observables['t%d' % id] for id in ids] + arrival_times = [station_observables[f't{detector_id}'] for detector_id in detector_ids] trigger_time = sorted(arrival_times)[1] - ext_timestamp += int(first_time + trigger_time + station.gps_offset + - self.simulate_gps_uncertainty()) + ext_timestamp += 
int(first_time + trigger_time + station.gps_offset + self.simulate_gps_uncertainty()) timestamp = int(ext_timestamp / 1_000_000_000) nanoseconds = int(ext_timestamp % 1_000_000_000) - gps_timestamp = {'ext_timestamp': ext_timestamp, - 'timestamp': timestamp, - 'nanoseconds': nanoseconds, - 't_trigger': trigger_time} + gps_timestamp = { + 'ext_timestamp': ext_timestamp, + 'timestamp': timestamp, + 'nanoseconds': nanoseconds, + 't_trigger': trigger_time, + } station_observables.update(gps_timestamp) return station_observables -class FlatFrontSimulationWithoutErrors(ErrorlessSimulation, - FlatFrontSimulation): - +class FlatFrontSimulationWithoutErrors(ErrorlessSimulation, FlatFrontSimulation): """This simulation does not simulate errors/uncertainties This should result in perfect timing for the detectors. """ - pass - class FlatFrontSimulation2D(FlatFrontSimulation): - """This simulation ignores detector altitudes.""" def get_arrival_time(self, detector, shower_parameters): @@ -162,16 +159,11 @@ def get_arrival_time(self, detector, shower_parameters): return dt -class FlatFrontSimulation2DWithoutErrors(FlatFrontSimulation2D, - FlatFrontSimulationWithoutErrors): - +class FlatFrontSimulation2DWithoutErrors(FlatFrontSimulation2D, FlatFrontSimulationWithoutErrors): """Ignore altitude of detectors and do not simulate errors.""" - pass - class ConeFrontSimulation(FlatFrontSimulation): - """This simulation uses a cone shaped shower front. The opening angle of the cone is given in the init @@ -216,12 +208,14 @@ def generate_shower_parameters(self): x, y = self.generate_core_position(r_max) azimuth = self.generate_azimuth() - shower_parameters = {'ext_timestamp': (1_000_000_000 + i) * 1_000_000_000, - 'azimuth': azimuth, - 'zenith': self.generate_attenuated_zenith(), - 'core_pos': (x, y), - 'size': None, - 'energy': self.generate_energy(1e15, 1e17)} + shower_parameters = { + 'ext_timestamp': (1_000_000_000 + i) * 1_000_000_000, + 'azimuth': azimuth, + 'zenith': self.generate_attenuated_zenith(), + 'core_pos': (x, y), + 'size': None, + 'energy': self.generate_energy(1e15, 1e17), + } self._prepare_cluster_for_shower(x, y, azimuth) @@ -251,14 +245,13 @@ def get_arrival_time(self, detector, shower_parameters): phi = shower_parameters['azimuth'] theta = shower_parameters['zenith'] r = r1 * cos(phi - phi1) - z * tan(theta) - cdt = - (r * sin(theta) + z / cos(theta)) + cdt = -(r * sin(theta) + z / cos(theta)) nx = sin(theta) * cos(phi) ny = sin(theta) * sin(phi) nz = cos(theta) - r_core = sqrt(x ** 2 + y ** 2 + z ** 2 - - (x * nx + y * ny + z * nz) ** 2) + r_core = sqrt(x**2 + y**2 + z**2 - (x * nx + y * ny + z * nz) ** 2) t_shape = self.delay_at_r(r_core) dt = t_shape + (cdt / c) @@ -266,18 +259,16 @@ def get_arrival_time(self, detector, shower_parameters): class FlatFront: - """Simple flat shower front""" def delay_at_r(self, r): - return 0. + return 0.0 def front_shape(self, r): - return 0. + return 0.0 class ConeFront: - """Simple cone shaped shower front""" def delay_at_r(self, r): @@ -294,7 +285,6 @@ def front_shape(self, r): class CorsikaStationFront: - """Shower front shape derrived from CORSIKA simulations on a station. 
A set of CORSIKA generated showers were used to determine the median @@ -331,4 +321,4 @@ def front_shape(self, r, energy, particle='proton'): return self._front_shape(r, a, b) def _front_shape(self, r, a, b): - return a * r ** b + return a * r**b diff --git a/sapphire/storage.py b/sapphire/storage.py index 6290a4a2..6dcdb964 100644 --- a/sapphire/storage.py +++ b/sapphire/storage.py @@ -1,14 +1,14 @@ -""" PyTables table descriptions for data storage +"""PyTables table descriptions for data storage - This module contains the table descriptions used by the detector - simulation to store intermediate and final data in a HDF5 file. +This module contains the table descriptions used by the detector +simulation to store intermediate and final data in a HDF5 file. """ + import tables class EventObservables(tables.IsDescription): - """Store information about the observables of an event. The observables are described for each station independently. So, for each @@ -40,6 +40,7 @@ class EventObservables(tables.IsDescription): number of detectors with at least one particle """ + id = tables.UInt32Col() station_id = tables.UInt8Col() timestamp = tables.Time32Col() @@ -63,7 +64,6 @@ class EventObservables(tables.IsDescription): class Coincidence(tables.IsDescription): - """Store information about a coincidence of stations within a cluster. An extensive air shower can trigger multiple stations, resulting in a set @@ -123,6 +123,7 @@ class Coincidence(tables.IsDescription): The primary particle energy of the (simulated) shower. """ + id = tables.UInt32Col(pos=0) timestamp = tables.Time32Col(pos=1) nanoseconds = tables.UInt32Col(pos=2) @@ -138,7 +139,6 @@ class Coincidence(tables.IsDescription): class TimeDelta(tables.IsDescription): - """Store time differences""" ext_timestamp = tables.UInt64Col(pos=0) @@ -148,7 +148,6 @@ class TimeDelta(tables.IsDescription): class ReconstructedCoincidence(tables.IsDescription): - """Store information about reconstructed coincidences""" id = tables.UInt32Col(pos=1) @@ -177,7 +176,6 @@ class ReconstructedCoincidence(tables.IsDescription): class ReconstructedEvent(ReconstructedCoincidence): - """Store information about reconstructed events .. attribute:: id @@ -198,7 +196,6 @@ class ReconstructedEvent(ReconstructedCoincidence): class KascadeEvent(tables.IsDescription): - """Store events from KASCADE""" run_id = tables.IntCol(pos=0) @@ -220,7 +217,6 @@ class KascadeEvent(tables.IsDescription): class ReconstructedKascadeEvent(tables.IsDescription): - """Store information about reconstructed events""" # r, phi is core position diff --git a/sapphire/tests/__init__.py b/sapphire/tests/__init__.py index d909ddcf..bccd209b 100644 --- a/sapphire/tests/__init__.py +++ b/sapphire/tests/__init__.py @@ -5,6 +5,7 @@ installed correctly. Simply call the :func:`run_tests` function. 
""" + import os from unittest import TestSuite, TextTestRunner, defaultTestLoader diff --git a/sapphire/tests/analysis/test_calibration.py b/sapphire/tests/analysis/test_calibration.py index d7b82f6b..35e23417 100644 --- a/sapphire/tests/analysis/test_calibration.py +++ b/sapphire/tests/analysis/test_calibration.py @@ -1,14 +1,13 @@ -import os import unittest import warnings from datetime import date, datetime +from pathlib import Path from unittest.mock import MagicMock, Mock, call, patch, sentinel import tables -from numpy import all, array, isnan, nan, std -from numpy.random import normal, uniform +from numpy import all, array, isnan, nan, random, std from sapphire import HiSPARCNetwork, HiSPARCStations from sapphire.analysis import calibration @@ -20,10 +19,8 @@ class DetectorTimingTests(unittest.TestCase): - def get_testdata_path(self): - dir_path = os.path.dirname(__file__) - return os.path.join(dir_path, TEST_DATA_ESD) + return Path(__file__).parent / TEST_DATA_ESD def test_determine_detector_timing_offsets(self): with tables.open_file(self.get_testdata_path(), 'r') as data: @@ -44,22 +41,22 @@ def test_determine_detector_timing_offset(self, mock_fit): dzc = dz / c # Good result - mock_fit.return_value = (1., 2.) + mock_fit.return_value = (1.0, 2.0) offset, _ = calibration.determine_detector_timing_offset(dt) - self.assertEqual(offset, 1.) + self.assertEqual(offset, 1.0) offset, _ = calibration.determine_detector_timing_offset(dt, dz=dz) - self.assertEqual(offset, 1. + dzc) + self.assertEqual(offset, 1.0 + dzc) - mock_fit.return_value = (-1.5, 5.) + mock_fit.return_value = (-1.5, 5.0) offset, _ = calibration.determine_detector_timing_offset(dt) self.assertEqual(offset, -1.5) offset, _ = calibration.determine_detector_timing_offset(dt, dz=dz) self.assertEqual(offset, -1.5 + dzc) - mock_fit.return_value = (250., 100.) + mock_fit.return_value = (250.0, 100.0) offset, _ = calibration.determine_detector_timing_offset(dt, dz=dz) self.assertTrue(isnan(offset)) - mock_fit.return_value = (-150., 100.) + mock_fit.return_value = (-150.0, 100.0) offset, _ = calibration.determine_detector_timing_offset(dt, dz=dz) self.assertTrue(isnan(offset)) @@ -69,11 +66,10 @@ def test_determine_detector_timing_offset(self, mock_fit): class StationTimingTests(unittest.TestCase): - @patch.object(calibration, 'percentile') @patch.object(calibration, 'fit_timing_offset') def test_determine_station_timing_offset(self, mock_fit, mock_percentile): - mock_percentile.return_value = (-50., 50.) + mock_percentile.return_value = (-50.0, 50.0) dz = 0.6 dzc = dz / c @@ -82,23 +78,23 @@ def test_determine_station_timing_offset(self, mock_fit, mock_percentile): self.assertTrue(all(isnan(offset))) # Good result - mock_fit.return_value = (1., 5.) + mock_fit.return_value = (1.0, 5.0) offset, _ = calibration.determine_station_timing_offset([sentinel.dt]) - self.assertEqual(offset, 1.) + self.assertEqual(offset, 1.0) mock_percentile.assert_called_once_with([sentinel.dt], [0.5, 99.5]) offset, _ = calibration.determine_station_timing_offset([sentinel.dt], dz=dz) - self.assertEqual(offset, 1. + dzc) + self.assertEqual(offset, 1.0 + dzc) - mock_fit.return_value = (-1.5, 5.) + mock_fit.return_value = (-1.5, 5.0) offset, _ = calibration.determine_station_timing_offset([sentinel.dt]) self.assertEqual(offset, -1.5) offset, _ = calibration.determine_station_timing_offset([sentinel.dt], dz=dz) self.assertEqual(offset, -1.5 + dzc) - mock_fit.return_value = (2500., 100.) 
+ mock_fit.return_value = (2500.0, 100.0) offset, _ = calibration.determine_station_timing_offset([sentinel.dt]) self.assertTrue(isnan(offset)) - mock_fit.return_value = (-1500., 100.) + mock_fit.return_value = (-1500.0, 100.0) offset, _ = calibration.determine_station_timing_offset([sentinel.dt]) self.assertTrue(isnan(offset)) @@ -108,21 +104,17 @@ def test_determine_station_timing_offset(self, mock_fit, mock_percentile): class BestReferenceTests(unittest.TestCase): - def test_determine_best_reference(self): # Tie - filters = array([[True, True, False], [True, False, True], - [False, True, True], [True, True, False]]) + filters = array([[True, True, False], [True, False, True], [False, True, True], [True, True, False]]) self.assertEqual(calibration.determine_best_reference(filters), 0) # 1 has most matches - filters = array([[True, False, False], [True, True, True], - [False, False, False], [True, True, False]]) + filters = array([[True, False, False], [True, True, True], [False, False, False], [True, True, False]]) self.assertEqual(calibration.determine_best_reference(filters), 1) # Another winner - filters = array([[True, True, False], [True, False, True], - [False, True, True], [True, True, True]]) + filters = array([[True, True, False], [True, False, True], [False, True, True], [True, True, True]]) self.assertEqual(calibration.determine_best_reference(filters), 3) # Not yet support number of detectors @@ -131,7 +123,6 @@ def test_determine_best_reference(self): class SplitDatetimeRangeTests(unittest.TestCase): - def test_split_range(self): # 101 days start = date(2016, 1, 1) @@ -188,17 +179,16 @@ def test_pairwise(self): class FitTimingOffsetTests(unittest.TestCase): - def test_fit_timing_offset(self): deviations = [] for _ in range(50): - center = uniform(-40, 40) - sigma = uniform(10, 30) + center = random.uniform(-40, 40) + sigma = random.uniform(10, 30) n = int(4e4) lower = center - 3 * sigma upper = center + 3 * sigma bins = list(range(int(lower), int(upper), 1)) - dt = normal(center, sigma, n) + dt = random.normal(center, sigma, n) offset, error = calibration.fit_timing_offset(dt, bins) deviations.append((center - offset) / error) # Test if determined offset close to the actual center. 
@@ -208,15 +198,16 @@ def test_fit_timing_offset(self): class DetermineStationTimingOffsetsTests(unittest.TestCase): - def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) stations = [501, 102, 105, 8001] - self.off = calibration.DetermineStationTimingOffsets(stations=stations, data=sentinel.data, - progress=sentinel.progress, force_stale=True) - - def tearDown(self): - warnings.resetwarnings() + self.off = calibration.DetermineStationTimingOffsets( + stations=stations, + data=sentinel.data, + progress=sentinel.progress, + force_stale=True, + ) def test_init(self): self.assertEqual(self.off.progress, sentinel.progress) @@ -236,7 +227,7 @@ def test_read_dt(self): start = datetime(2014, 1, 1) end = datetime(2016, 12, 31) self.off.read_dt(station, ref_station, start, end) - table_path = ('/coincidences/time_deltas/station_%d/station_%d' % (ref_station, station)) + table_path = '/coincidences/time_deltas/station_%d/station_%d' % (ref_station, station) table_name = 'time_deltas' self.off.data.get_node.assert_called_once_with(table_path, table_name) self.assertTrue(table_mock.read_where.called) @@ -248,8 +239,12 @@ def test_station_pairs_within_max_distance(self): def test_station_pairs_wrong_order(self): stations = [105, 102, 8001, 501] - self.off = calibration.DetermineStationTimingOffsets(stations=stations, data=sentinel.data, - progress=sentinel.progress, force_stale=True) + self.off = calibration.DetermineStationTimingOffsets( + stations=stations, + data=sentinel.data, + progress=sentinel.progress, + force_stale=True, + ) results = list(self.off.get_station_pairs_within_max_distance()) self.assertEqual([(102, 105)], results) @@ -271,22 +266,24 @@ def test_get_r_dz(self): self.assertAlmostEqual(dz, dz_102_105, places=5) def test_determine_interval(self): - combinations = ((0., 7), - (50., 10), - (200., 57), - (1000., 398)) + combinations = ((0.0, 7), (50.0, 10), (200.0, 57), (1000.0, 398)) for r, ref_int in combinations: self.assertEqual(self.off._determine_interval(r), ref_int) def test_get_cuts(self): - gps_station = (datetime_to_gps(datetime(2014, 1, 1, 10, 3)), - datetime_to_gps(datetime(2014, 3, 1, 11, 32))) - gps_ref_station = (datetime_to_gps(datetime(2014, 1, 5, 0, 1, 1)), - datetime_to_gps(datetime(2014, 3, 5, 3, 34, 4))) - elec_station = (datetime_to_gps(datetime(2014, 1, 3, 3, 34, 3)), - datetime_to_gps(datetime(2014, 3, 5, 23, 59, 59))) - elec_ref_station = (datetime_to_gps(datetime(2014, 1, 9, 0, 0, 0)), - datetime_to_gps(datetime(2014, 3, 15, 1, 2, 3))) + gps_station = (datetime_to_gps(datetime(2014, 1, 1, 10, 3)), datetime_to_gps(datetime(2014, 3, 1, 11, 32))) + gps_ref_station = ( + datetime_to_gps(datetime(2014, 1, 5, 0, 1, 1)), + datetime_to_gps(datetime(2014, 3, 5, 3, 34, 4)), + ) + elec_station = ( + datetime_to_gps(datetime(2014, 1, 3, 3, 34, 3)), + datetime_to_gps(datetime(2014, 3, 5, 23, 59, 59)), + ) + elec_ref_station = ( + datetime_to_gps(datetime(2014, 1, 9, 0, 0, 0)), + datetime_to_gps(datetime(2014, 3, 15, 1, 2, 3)), + ) gps_mock = Mock() elec_mock = Mock() @@ -308,15 +305,14 @@ def test_get_cuts(self): self.assertEqual(cuts[-1], datetime(today.year, today.month, today.day)) def test_get_left_and_right_bounds(self): - cuts = (datetime(2014, 1, 1), - datetime(2015, 1, 1), - datetime(2015, 1, 5), - datetime(2015, 1, 10)) - combinations = [(datetime(2015, 1, 1), 7, datetime(2015, 1, 1), datetime(2015, 1, 4)), - (datetime(2015, 1, 3), 7, datetime(2015, 1, 1), datetime(2015, 1, 4)), - (datetime(2015, 1, 3).date(), 7, 
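# Editor's note (illustrative sketch, not part of the diff): the expected
# string in test_at_least is the OR over all 2-out-of-3 station combinations.
# The snippet below reproduces that expression with itertools as a stand-in
# for how such a condition can be built; it is not the CoincidenceQuery
# implementation itself.
from itertools import combinations

columns = ['s501', 's502', 's503']
n = 2
query = '(%s)' % ' | '.join('(%s)' % ' & '.join(combo) for combo in combinations(columns, n))
assert query == '((s501 & s502) | (s501 & s503) | (s502 & s503))'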
datetime(2015, 1, 1), datetime(2015, 1, 4)), - (datetime(2015, 1, 5), 7, datetime(2015, 1, 5), datetime(2015, 1, 9)), - (datetime(2015, 1, 10), 7, datetime(2015, 1, 5), datetime(2015, 1, 10))] + cuts = (datetime(2014, 1, 1), datetime(2015, 1, 1), datetime(2015, 1, 5), datetime(2015, 1, 10)) + combinations = [ + (datetime(2015, 1, 1), 7, datetime(2015, 1, 1), datetime(2015, 1, 4)), + (datetime(2015, 1, 3), 7, datetime(2015, 1, 1), datetime(2015, 1, 4)), + (datetime(2015, 1, 3).date(), 7, datetime(2015, 1, 1), datetime(2015, 1, 4)), + (datetime(2015, 1, 5), 7, datetime(2015, 1, 5), datetime(2015, 1, 9)), + (datetime(2015, 1, 10), 7, datetime(2015, 1, 5), datetime(2015, 1, 10)), + ] for d, days, ref_left, ref_right in combinations: left, right = self.off._get_left_and_right_bounds(cuts, d, days) self.assertEqual(left, ref_left) @@ -366,16 +362,16 @@ def test_determine_station_timing_offset(self, mock_det_offset): self.off._get_r_dz.return_value = sentinel.r, sentinel.dz self.off.determine_first_and_last_date.return_value = (0, 0) - self.off.read_dt.return_value = 1000 * [0.] - mock_det_offset.return_value = (10., 1.) + self.off.read_dt.return_value = 1000 * [0.0] + mock_det_offset.return_value = (10.0, 1.0) offsets = self.off.determine_station_timing_offset(date, sentinel.station, sentinel.ref_station) self.off._get_r_dz.assert_called_once_with(date, sentinel.station, sentinel.ref_station) self.off.determine_first_and_last_date.assert_called_once_with(date, sentinel.station, sentinel.ref_station) - self.assertEqual(offsets, (10., 1.)) + self.assertEqual(offsets, (10.0, 1.0)) - self.off.read_dt.return_value = 90 * [0.] + self.off.read_dt.return_value = 90 * [0.0] offsets = self.off.determine_station_timing_offset(date, sentinel.station, sentinel.ref_station) self.assertEqual(offsets, (nan, nan)) diff --git a/sapphire/tests/analysis/test_coincidence_queries.py b/sapphire/tests/analysis/test_coincidence_queries.py index d8cb144e..43be913b 100644 --- a/sapphire/tests/analysis/test_coincidence_queries.py +++ b/sapphire/tests/analysis/test_coincidence_queries.py @@ -6,22 +6,22 @@ class BaseCoincidenceQueryTest(unittest.TestCase): - @patch.object(coincidence_queries.tables, 'open_file') def setUp(self, mock_open_file): self.mock_open_file = mock_open_file self.data_path = sentinel.data_path self.coincidences_group = sentinel.coincidences_group - self.cq = coincidence_queries.CoincidenceQuery( - self.data_path, self.coincidences_group) + self.cq = coincidence_queries.CoincidenceQuery(self.data_path, self.coincidences_group) def test_init_opens_file_and_gets_nodes(self): self.mock_open_file.assert_called_once_with(self.data_path, 'r') - expected = [call(self.coincidences_group, 'coincidences'), - call(self.coincidences_group, 'c_index'), - call(self.coincidences_group, 's_index'), - call(self.coincidences_group, 'reconstructions')] + expected = [ + call(self.coincidences_group, 'coincidences'), + call(self.coincidences_group, 'c_index'), + call(self.coincidences_group, 's_index'), + call(self.coincidences_group, 'reconstructions'), + ] call_list = self.mock_open_file.return_value.get_node.call_args_list self.assertEqual(call_list, expected) @@ -70,8 +70,7 @@ def test_at_least(self, mock_columns, mock_query): n = 2 self.cq.at_least(sentinel.stations, n) mock_columns.assert_called_once_with(sentinel.stations) - mock_query.assert_called_once_with('((s501 & s502) | (s501 & s503) | ' - '(s502 & s503))', False) + mock_query.assert_called_once_with('((s501 & s502) | (s501 & s503) | (s502 & s503))', False) 
@patch.object(coincidence_queries.CoincidenceQuery, 'perform_query') def test_timerange(self, mock_query): @@ -153,8 +152,7 @@ def test_reconstructions_from_stations(self, mock_get_reconstructions, mock_even # mock_minimum.assert_called_once_with([sentinel.coincidence_events]) def test__events_from_stations(self): - events = ([sentinel.station1, sentinel.event1], - [sentinel.station2, sentinel.event2]) + events = ([sentinel.station1, sentinel.event1], [sentinel.station2, sentinel.event2]) stations = [sentinel.station2] result = self.cq._events_from_stations(events, stations) self.assertEqual(result, [[sentinel.station2, sentinel.event2]]) diff --git a/sapphire/tests/analysis/test_coincidences.py b/sapphire/tests/analysis/test_coincidences.py index ac9aea45..e2ba8a4d 100644 --- a/sapphire/tests/analysis/test_coincidences.py +++ b/sapphire/tests/analysis/test_coincidences.py @@ -17,7 +17,6 @@ class CoincidencesTests(unittest.TestCase): - @patch.object(coincidences.tables, 'open_file') def setUp(self, mock_open_file): self.mock_open_file = mock_open_file @@ -32,7 +31,7 @@ def test_init(self): @patch.object(coincidences.Coincidences, 'store_coincidences') def test_search_and_store_coincidences(self, mock_store, mock_process, mock_search): self.c.search_and_store_coincidences() - mock_search.assert_called_with(window=10000) + mock_search.assert_called_with(window=10_000) mock_process.assert_called_with() mock_store.assert_called_with() self.c.search_and_store_coincidences(sentinel.window) @@ -44,53 +43,88 @@ def test__retrieve_timestamps(self): station1 = Mock() station2 = Mock() # Station 2 timestamps are not already correctly sorted. - station1.col.return_value = [uint64(1400000002000000050), uint64(1400000018000000500)] - station2.col.return_value = [uint64(1400000002000000510), uint64(1400000030000000000)][::-1] + station1.col.return_value = [uint64(1400000002_000000050), uint64(1400000018_000000500)] + station2.col.return_value = [uint64(1400000002_000000510), uint64(1400000030_000000000)][::-1] stations = [station1, station2] timestamps = self.c._retrieve_timestamps(stations) - self.assertEqual(timestamps, - [(uint64(1400000002000000050), 0, 0), (uint64(1400000002000000510), 1, 1), - (uint64(1400000018000000500), 0, 1), (uint64(1400000030000000000), 1, 0)]) + self.assertEqual( + timestamps, + [ + (uint64(1400000002_000000050), 0, 0), + (uint64(1400000002_000000510), 1, 1), + (uint64(1400000018_000000500), 0, 1), + (uint64(1400000030_000000000), 1, 0), + ], + ) # Shift both timestamps = self.c._retrieve_timestamps(stations, shifts=[1, 17]) - self.assertEqual(timestamps, - [(uint64(1400000003000000050), 0, 0), (uint64(1400000019000000500), 0, 1), - (uint64(1400000019000000510), 1, 1), (uint64(1400000047000000000), 1, 0)]) + self.assertEqual( + timestamps, + [ + (uint64(1400000003_000000050), 0, 0), + (uint64(1400000019_000000500), 0, 1), + (uint64(1400000019_000000510), 1, 1), + (uint64(1400000047_000000000), 1, 0), + ], + ) # Wrong value type shifts - self.assertRaises(TypeError, self.c._retrieve_timestamps, stations, shifts=['', '']) - self.assertRaises(TypeError, self.c._retrieve_timestamps, stations, shifts=['', 90]) + self.assertRaises(ValueError, self.c._retrieve_timestamps, stations, shifts=['', '']) + self.assertRaises(ValueError, self.c._retrieve_timestamps, stations, shifts=['', 90]) # Different length shifts timestamps = self.c._retrieve_timestamps(stations, shifts=[110]) - self.assertEqual(timestamps, - [(uint64(1400000002000000510), 1, 1), (uint64(1400000030000000000), 1, 0), - 
(uint64(1400000112000000050), 0, 0), (uint64(1400000128000000500), 0, 1)]) + self.assertEqual( + timestamps, + [ + (uint64(1400000002_000000510), 1, 1), + (uint64(1400000030_000000000), 1, 0), + (uint64(1400000112_000000050), 0, 0), + (uint64(1400000128_000000500), 0, 1), + ], + ) timestamps = self.c._retrieve_timestamps(stations, shifts=[None, 60]) - self.assertEqual(timestamps, - [(uint64(1400000002000000050), 0, 0), (uint64(1400000018000000500), 0, 1), - (uint64(1400000062000000510), 1, 1), (uint64(1400000090000000000), 1, 0)]) + self.assertEqual( + timestamps, + [ + (uint64(1400000002_000000050), 0, 0), + (uint64(1400000018_000000500), 0, 1), + (uint64(1400000062_000000510), 1, 1), + (uint64(1400000090_000000000), 1, 0), + ], + ) # Subsecond shifts timestamps = self.c._retrieve_timestamps(stations, shifts=[3e-9, 5e-9]) - self.assertEqual(timestamps, - [(uint64(1400000002000000053), 0, 0), (uint64(1400000002000000515), 1, 1), - (uint64(1400000018000000503), 0, 1), (uint64(1400000030000000005), 1, 0)]) + self.assertEqual( + timestamps, + [ + (uint64(1400000002_000000053), 0, 0), + (uint64(1400000002_000000515), 1, 1), + (uint64(1400000018_000000503), 0, 1), + (uint64(1400000030_000000005), 1, 0), + ], + ) # Using limits timestamps = self.c._retrieve_timestamps(stations, limit=1) - self.assertEqual(timestamps, - [(uint64(1400000002000000050), 0, 0), (uint64(1400000030000000000), 1, 0)]) - # This should fail but does not - self.assertEqual(timestamps, - [(1400000002000000049, 0, 0), (1400000030000000000, 1, 0)]) - # Using uint64 does work correctly - self.assertNotEqual(timestamps, - [(uint64(1400000002000000049), 0, 0), (uint64(1400000030000000000), 1, 0)]) - self.assertNotEqual(timestamps, - [(uint64(1400000002000000051), 0, 0), (uint64(1400000030000000001), 1, 0)]) + + # Check accuracy of comparisons between different types + self.assertEqual(timestamps, [(1400000002_000000050, 0, 0), (1400000030_000000000, 1, 0)]) + self.assertEqual(timestamps, [(uint64(1400000002_000000050), 0, 0), (uint64(1400000030_000000000), 1, 0)]) + self.assertNotEqual(timestamps, [(1400000002_000000049, 0, 0), (1400000030_000000000, 1, 0)]) + self.assertNotEqual(timestamps, [(1400000002_000000051, 0, 0), (1400000030_000000000, 1, 0)]) + self.assertNotEqual(timestamps, [(uint64(1400000002_000000049), 0, 0), (uint64(1400000030_000000000), 1, 0)]) + self.assertNotEqual(timestamps, [(uint64(1400000002_000000051), 0, 0), (uint64(1400000030_000000001), 1, 0)]) def test__do_search_coincidences(self): # [(timestamp, station_idx, event_idx), ..] 
- timestamps = [(uint64(0), 0, 0), (uint64(0), 1, 0), (uint64(10), 1, 1), - (uint64(15), 2, 0), (uint64(100), 1, 2), (uint64(200), 2, 1), - (uint64(250), 0, 1), (uint64(251), 0, 2)] + timestamps = [ + (uint64(0), 0, 0), + (uint64(0), 1, 0), + (uint64(10), 1, 1), + (uint64(15), 2, 0), + (uint64(100), 1, 2), + (uint64(200), 2, 1), + (uint64(250), 0, 1), + (uint64(251), 0, 2), + ] c = self.c._do_search_coincidences(timestamps, window=6) expected_coincidences = [[0, 1], [2, 3], [6, 7]] @@ -106,7 +140,6 @@ def test__do_search_coincidences(self): class CoincidencesESDTests(CoincidencesTests): - @patch.object(coincidences.tables, 'open_file') def setUp(self, mock_open_file): self.mock_open_file = mock_open_file @@ -135,19 +168,14 @@ def test_search_coincidences(self, mock__search): class CoincidencesDataTests(unittest.TestCase): - def setUp(self): self.data_path = self.create_tempfile_from_testdata() - - def tearDown(self): - os.remove(self.data_path) + self.addCleanup(os.remove, self.data_path) def test_coincidencesesd_output(self): with tables.open_file(self.data_path, 'a') as data: with patch('sapphire.analysis.process_events.ProcessIndexedEventsWithoutTraces'): - c = coincidences.Coincidences(data, '/coincidences', - ['/station_501', '/station_502'], - progress=False) + c = coincidences.Coincidences(data, '/coincidences', ['/station_501', '/station_502'], progress=False) c.search_and_store_coincidences() validate_results(self, self.get_testdata_path(), self.data_path) @@ -174,12 +202,9 @@ def remove_existing_coincidences(self, path): class CoincidencesESDDataTests(CoincidencesDataTests): - def test_coincidencesesd_output(self): with tables.open_file(self.data_path, 'a') as data: - c = coincidences.CoincidencesESD(data, '/coincidences', - ['/station_501', '/station_502'], - progress=False) + c = coincidences.CoincidencesESD(data, '/coincidences', ['/station_501', '/station_502'], progress=False) self.assertRaises(RuntimeError, c.search_and_store_coincidences, station_numbers=[501]) c.search_and_store_coincidences(station_numbers=[501, 502]) validate_results(self, self.get_testdata_path(), self.data_path) diff --git a/sapphire/tests/analysis/test_core_reconstruction.py b/sapphire/tests/analysis/test_core_reconstruction.py index d6b24b65..6cccf4e3 100644 --- a/sapphire/tests/analysis/test_core_reconstruction.py +++ b/sapphire/tests/analysis/test_core_reconstruction.py @@ -4,7 +4,6 @@ class BaseAlgorithm: - """Use this class to check the different algorithms They should give similar results and errors in some cases. @@ -18,28 +17,25 @@ def test_stations_square(self): """Four detection points in a square shape.""" # Same density - p = (1., 1., 1., 1.) - x = (0., 0., 10., 10.) - y = (0., 10., 10., 0.) - z = (0., 0., 0., 0.) + p = (1.0, 1.0, 1.0, 1.0) + x = (0.0, 0.0, 10.0, 10.0) + y = (0.0, 10.0, 10.0, 0.0) + z = (0.0, 0.0, 0.0, 0.0) result = self.call_reconstruct(p, x, y, z) - self.assertAlmostEqual(result[0], 5.) - self.assertAlmostEqual(result[1], 5.) 
+ self.assertAlmostEqual(result[0], 5.0) + self.assertAlmostEqual(result[1], 5.0) class CenterMassAlgorithmTest(unittest.TestCase, BaseAlgorithm): - def setUp(self): self.algorithm = core_reconstruction.CenterMassAlgorithm() class AverageIntersectionAlgorithmTest(unittest.TestCase, BaseAlgorithm): - def setUp(self): self.algorithm = core_reconstruction.AverageIntersectionAlgorithm() class EllipsLdfAlgorithmTest(unittest.TestCase, BaseAlgorithm): - def setUp(self): self.algorithm = core_reconstruction.EllipsLdfAlgorithm() diff --git a/sapphire/tests/analysis/test_direction_reconstruction.py b/sapphire/tests/analysis/test_direction_reconstruction.py index 7bc0a391..1a0b9689 100644 --- a/sapphire/tests/analysis/test_direction_reconstruction.py +++ b/sapphire/tests/analysis/test_direction_reconstruction.py @@ -10,7 +10,6 @@ class EventDirectionReconstructionTest(unittest.TestCase): - def test_init(self): dirrec = direction_reconstruction.EventDirectionReconstruction(sentinel.station) self.assertEqual(dirrec.direct, direction_reconstruction.DirectAlgorithmCartesian3D) @@ -36,7 +35,7 @@ def test_bad_times(self, mock_detector_arrival_time): @patch.object(direction_reconstruction.event_utils, 'detector_arrival_time') def test_reconstruct_event(self, mock_detector_arrival_time): - mock_detector_arrival_time.return_value = 0. + mock_detector_arrival_time.return_value = 0.0 station = MagicMock() detector = Mock() detector.get_coordinates.return_value = [sentinel.x, sentinel.y, sentinel.z] @@ -58,7 +57,12 @@ def test_reconstruct_event(self, mock_detector_arrival_time): # Three detections, direct reconstruction theta, phi, ids = dirrec.reconstruct_event(event, detector_ids=[0, 1, 2]) dirrec.direct.reconstruct_common.assert_called_once_with( - [0.] * 3, [sentinel.x] * 3, [sentinel.y] * 3, [sentinel.z] * 3, None) + [0.0] * 3, + [sentinel.x] * 3, + [sentinel.y] * 3, + [sentinel.z] * 3, + None, + ) dirrec.fit.reconstruct_common.assert_not_called() self.assertEqual(theta, sentinel.theta) self.assertEqual(phi, sentinel.phi) @@ -68,13 +72,23 @@ def test_reconstruct_event(self, mock_detector_arrival_time): theta, phi, ids = dirrec.reconstruct_event(event, detector_ids=[0, 1, 2, 3]) self.assertEqual(dirrec.direct.reconstruct_common.call_count, 1) dirrec.fit.reconstruct_common.assert_called_once_with( - [0.] * 4, [sentinel.x] * 4, [sentinel.y] * 4, [sentinel.z] * 4, None) + [0.0] * 4, + [sentinel.x] * 4, + [sentinel.y] * 4, + [sentinel.z] * 4, + None, + ) self.assertEqual(theta, sentinel.theta) self.assertEqual(phi, sentinel.phi) self.assertEqual(len(ids), 4) theta, phi, ids = dirrec.reconstruct_event(event, detector_ids=None) dirrec.fit.reconstruct_common.assert_called_with( - [0.] 
* 4, [sentinel.x] * 4, [sentinel.y] * 4, [sentinel.z] * 4, None) + [0.0] * 4, + [sentinel.x] * 4, + [sentinel.y] * 4, + [sentinel.z] * 4, + None, + ) self.assertEqual(dirrec.fit.reconstruct_common.call_count, 2) # Four detections, fit reconstruction with offsets @@ -88,18 +102,25 @@ def test_reconstruct_event(self, mock_detector_arrival_time): def test_reconstruct_events(self, mock_reconstruct_event): mock_reconstruct_event.return_value = [sentinel.theta, sentinel.phi, sentinel.ids] dirrec = direction_reconstruction.EventDirectionReconstruction(sentinel.station) - self.assertEqual(dirrec.reconstruct_events([sentinel.event, sentinel.event], - sentinel.detector_ids, sentinel.offsets, progress=False), - ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.ids, sentinel.ids))) + self.assertEqual( + dirrec.reconstruct_events( + [sentinel.event, sentinel.event], + sentinel.detector_ids, + sentinel.offsets, + progress=False, + ), + ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.ids, sentinel.ids)), + ) self.assertEqual(mock_reconstruct_event.call_count, 2) mock_reconstruct_event.assert_called_with(sentinel.event, sentinel.detector_ids, sentinel.offsets, None) - self.assertEqual(dirrec.reconstruct_events([], sentinel.detector_ids, sentinel.offsets, progress=False), - ((), (), ())) + self.assertEqual( + dirrec.reconstruct_events([], sentinel.detector_ids, sentinel.offsets, progress=False), + ((), (), ()), + ) self.assertEqual(mock_reconstruct_event.call_count, 2) class CoincidenceDirectionReconstructionTest(unittest.TestCase): - def setUp(self): self.dirrec = direction_reconstruction.CoincidenceDirectionReconstruction(sentinel.cluster) @@ -122,7 +143,7 @@ def test_set_cluster_timestamp(self): @patch.object(direction_reconstruction.event_utils, 'station_arrival_time') def test_reconstruct_coincidence(self, mock_station_arrival_time): dirrec = self.dirrec - mock_station_arrival_time.return_value = 0. + mock_station_arrival_time.return_value = 0.0 cluster = MagicMock() station = MagicMock() cluster.get_station.return_value = station @@ -135,10 +156,13 @@ def test_reconstruct_coincidence(self, mock_station_arrival_time): dirrec.fit.reconstruct_common.return_value = (sentinel.theta, sentinel.phi) dirrec.curved.reconstruct_common.return_value = (sentinel.theta, sentinel.phi) coincidence_2 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event]] - coincidence_3 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], - [2, sentinel.event]] - coincidence_4 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], - [2, sentinel.event], [3, sentinel.event]] + coincidence_3 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], [2, sentinel.event]] + coincidence_4 = [ + [sentinel.station_number, {'timestamp': 1}], + [1, sentinel.event], + [2, sentinel.event], + [3, sentinel.event], + ] # To few events theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_2) @@ -157,7 +181,12 @@ def test_reconstruct_coincidence(self, mock_station_arrival_time): theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_3) cluster.set_timestamp.assert_called_with(1) dirrec.direct.reconstruct_common.assert_called_once_with( - [0.] 
* 3, [sentinel.x] * 3, [sentinel.y] * 3, [sentinel.z] * 3, {}) + [0.0] * 3, + [sentinel.x] * 3, + [sentinel.y] * 3, + [sentinel.z] * 3, + {}, + ) dirrec.fit.reconstruct_common.assert_not_called() self.assertEqual(theta, sentinel.theta) self.assertEqual(phi, sentinel.phi) @@ -167,7 +196,12 @@ def test_reconstruct_coincidence(self, mock_station_arrival_time): theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_4) cluster.set_timestamp.assert_called_with(1) dirrec.fit.reconstruct_common.assert_called_once_with( - [0.] * 4, [sentinel.x] * 4, [sentinel.y] * 4, [sentinel.z] * 4, {}) + [0.0] * 4, + [sentinel.x] * 4, + [sentinel.y] * 4, + [sentinel.z] * 4, + {}, + ) self.assertEqual(dirrec.direct.reconstruct_common.call_count, 1) dirrec.curved.reconstruct_common.assert_not_called() self.assertEqual(theta, sentinel.theta) @@ -197,13 +231,26 @@ def test_reconstruct_coincidence(self, mock_station_arrival_time): @patch.object(direction_reconstruction.CoincidenceDirectionReconstruction, 'reconstruct_coincidence') def test_reconstruct_coincidences(self, mock_reconstruct_coincidence): mock_reconstruct_coincidence.return_value = [sentinel.theta, sentinel.phi, sentinel.nums] - self.assertEqual(self.dirrec.reconstruct_coincidences([sentinel.coincidence, sentinel.coincidence], - sentinel.station_numbers, sentinel.offsets, progress=False), - ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.nums, sentinel.nums))) + self.assertEqual( + self.dirrec.reconstruct_coincidences( + [sentinel.coincidence, sentinel.coincidence], + sentinel.station_numbers, + sentinel.offsets, + progress=False, + ), + ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.nums, sentinel.nums)), + ) self.assertEqual(mock_reconstruct_coincidence.call_count, 2) - mock_reconstruct_coincidence.assert_called_with(sentinel.coincidence, sentinel.station_numbers, sentinel.offsets, None) - self.assertEqual(self.dirrec.reconstruct_coincidences([], sentinel.station_numbers, sentinel.offsets, progress=False), - ((), (), ())) + mock_reconstruct_coincidence.assert_called_with( + sentinel.coincidence, + sentinel.station_numbers, + sentinel.offsets, + None, + ) + self.assertEqual( + self.dirrec.reconstruct_coincidences([], sentinel.station_numbers, sentinel.offsets, progress=False), + ((), (), ()), + ) self.assertEqual(mock_reconstruct_coincidence.call_count, 2) def test_get_station_offsets(self): @@ -215,25 +262,21 @@ def test_get_station_offsets(self): station_numbers = None offsets = {} ts0 = 86400 - result = dirrec.get_station_offsets(coincidence_events, station_numbers, - offsets, ts0) + result = dirrec.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) self.assertEqual(result, offsets) offsets = {1: MagicMock(spec=direction_reconstruction.Station)} - result = dirrec.get_station_offsets(coincidence_events, station_numbers, - offsets, ts0) + result = dirrec.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) self.assertEqual(result, sentinel.best_offset) mock_offsets.assert_called_once_with([sentinel.sn1], ts0, offsets) ts0 = 864000 + 12345 - result = dirrec.get_station_offsets(coincidence_events, station_numbers, - offsets, ts0) + result = dirrec.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) self.assertEqual(result, sentinel.best_offset) mock_offsets.assert_called_with([sentinel.sn1], 864000, offsets) station_numbers = sentinel.station_numbers - result = dirrec.get_station_offsets(coincidence_events, station_numbers, - offsets, 
ts0) + result = dirrec.get_station_offsets(coincidence_events, station_numbers, offsets, ts0) self.assertEqual(result, sentinel.best_offset) mock_offsets.assert_called_with(sentinel.station_numbers, 864000, offsets) @@ -247,41 +290,27 @@ def test_determine_best_offsets(self): midnight_ts = sentinel.midnight_ts best_offsets = dirrec.determine_best_offsets(station_numbers, midnight_ts, offsets) self.assertEqual(list(best_offsets.keys()), station_numbers) - self.assertEqual(list(best_offsets.values()), [[1.0, 0.0, 2.0, 3.0], - [2.0, 1.0, 3.0, 4.0]]) + self.assertEqual(list(best_offsets.values()), [[1.0, 0.0, 2.0, 3.0], [2.0, 1.0, 3.0, 4.0]]) def test_determine_best_reference(self): # last station would be best reference, but not in station_numbers # second and third station are tied, so second is best reference - error_matrix = array([[0, 5, 2, 1], - [5, 0, 1, 1], - [2, 1, 0, 1], - [1, 1, 1, 0]]) + error_matrix = array([[0, 5, 2, 1], [5, 0, 1, 1], [2, 1, 0, 1], [1, 1, 1, 0]]) station_numbers = [1, 2, 3] ref, pred = self.dirrec.determine_best_reference(error_matrix, station_numbers) self.assertEqual(ref, 2) - predecessors = array([[-9999, 3, 0, 0], - [3, -9999, 1, 1], - [2, 2, -9999, 2], - [3, 3, 3, -9999]]) + predecessors = array([[-9999, 3, 0, 0], [3, -9999, 1, 1], [2, 2, -9999, 2], [3, 3, 3, -9999]]) self.assertEqual(pred.tolist(), predecessors.tolist()) def test__reconstruct_best_offset(self): offset = self.dirrec._reconstruct_best_offset([], 1, 1, [], []) self.assertEqual(offset, 0) - predecessors = array([[-9999, 0, 1], - [1, -9999, 1], - [1, 2, -9999]]) - offset_matrix = array([[0, -1, -1], - [1, 0, -1], - [1, 1, 0]]) + predecessors = array([[-9999, 0, 1], [1, -9999, 1], [1, 2, -9999]]) + offset_matrix = array([[0, -1, -1], [1, 0, -1], [1, 1, 0]]) station_numbers = [1, 2, 3] - combinations = [(1, 1, 0), - (1, 2, 1), - (1, 3, 2), - (2, 3, 1)] + combinations = [(1, 1, 0), (1, 2, 1), (1, 3, 2), (2, 3, 1)] for sn1, sn2, offset in combinations: o12 = self.dirrec._reconstruct_best_offset(predecessors, sn1, sn2, station_numbers, offset_matrix) o21 = self.dirrec._reconstruct_best_offset(predecessors, sn2, sn1, station_numbers, offset_matrix) @@ -290,23 +319,22 @@ def test__reconstruct_best_offset(self): def test__calculate_offsets(self): mock_station = Mock() - mock_station.detector_timing_offset.return_value = [0., 1., 2., 3.] - offset = 1. 
+ mock_station.detector_timing_offset.return_value = [0.0, 1.0, 2.0, 3.0] + offset = 1.0 ts0 = sentinel.timestamp offsets = self.dirrec._calculate_offsets(mock_station, ts0, offset) mock_station.detector_timing_offset.assert_called_once_with(ts0) - self.assertEqual(offsets, [1., 2., 3., 4.]) + self.assertEqual(offsets, [1.0, 2.0, 3.0, 4.0]) class CoincidenceDirectionReconstructionDetectorsTest(CoincidenceDirectionReconstructionTest): - def setUp(self): self.dirrec = direction_reconstruction.CoincidenceDirectionReconstructionDetectors(sentinel.cluster) @patch.object(direction_reconstruction.event_utils, 'relative_detector_arrival_times') def test_reconstruct_coincidence(self, mock_arrival_times): dirrec = self.dirrec - mock_arrival_times.return_value = [0., 0., nan, nan] + mock_arrival_times.return_value = [0.0, 0.0, nan, nan] cluster = MagicMock() station = MagicMock() cluster.get_station.return_value = station @@ -323,10 +351,13 @@ def test_reconstruct_coincidence(self, mock_arrival_times): coincidence_0 = [] coincidence_1 = [[sentinel.station_number, {'timestamp': 1}]] coincidence_2 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event]] - coincidence_3 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], - [2, sentinel.event]] - coincidence_4 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], - [2, sentinel.event], [3, sentinel.event]] + coincidence_3 = [[sentinel.station_number, {'timestamp': 1}], [1, sentinel.event], [2, sentinel.event]] + coincidence_4 = [ + [sentinel.station_number, {'timestamp': 1}], + [1, sentinel.event], + [2, sentinel.event], + [3, sentinel.event], + ] # To few detection points, no reconstruction theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_0) @@ -351,7 +382,12 @@ def test_reconstruct_coincidence(self, mock_arrival_times): theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_2) cluster.set_timestamp.assert_called_with(1) dirrec.fit.reconstruct_common.assert_called_once_with( - [0.] * 4, [sentinel.x] * 4, [sentinel.y] * 4, [sentinel.z] * 4, {}) + [0.0] * 4, + [sentinel.x] * 4, + [sentinel.y] * 4, + [sentinel.z] * 4, + {}, + ) self.assertEqual(dirrec.fit.reconstruct_common.call_count, 1) self.assertEqual(theta, sentinel.theta) self.assertEqual(phi, sentinel.phi) @@ -367,13 +403,18 @@ def test_reconstruct_coincidence(self, mock_arrival_times): self.assertEqual(phi, sentinel.phi) self.assertEqual(len(nums), 4) - mock_arrival_times.return_value = [0., nan, nan, nan] + mock_arrival_times.return_value = [0.0, nan, nan, nan] # Three stations with three detection points, direct reconstruction theta, phi, nums = dirrec.reconstruct_coincidence(coincidence_3) cluster.set_timestamp.assert_called_with(1) dirrec.direct.reconstruct_common.assert_called_once_with( - [0.] 
* 3, [sentinel.x] * 3, [sentinel.y] * 3, [sentinel.z] * 3, {}) + [0.0] * 3, + [sentinel.x] * 3, + [sentinel.y] * 3, + [sentinel.z] * 3, + {}, + ) self.assertEqual(dirrec.direct.reconstruct_common.call_count, 1) self.assertEqual(theta, sentinel.theta) self.assertEqual(phi, sentinel.phi) @@ -394,18 +435,30 @@ def test_reconstruct_coincidence(self, mock_arrival_times): @patch.object(direction_reconstruction.CoincidenceDirectionReconstructionDetectors, 'reconstruct_coincidence') def test_reconstruct_coincidences(self, mock_reconstruct_coincidence): mock_reconstruct_coincidence.return_value = [sentinel.theta, sentinel.phi, sentinel.nums] - self.assertEqual(self.dirrec.reconstruct_coincidences([sentinel.coincidence, sentinel.coincidence], - sentinel.station_numbers, sentinel.offsets, progress=False), - ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.nums, sentinel.nums))) + self.assertEqual( + self.dirrec.reconstruct_coincidences( + [sentinel.coincidence, sentinel.coincidence], + sentinel.station_numbers, + sentinel.offsets, + progress=False, + ), + ((sentinel.theta, sentinel.theta), (sentinel.phi, sentinel.phi), (sentinel.nums, sentinel.nums)), + ) self.assertEqual(mock_reconstruct_coincidence.call_count, 2) - mock_reconstruct_coincidence.assert_called_with(sentinel.coincidence, sentinel.station_numbers, sentinel.offsets, None) - self.assertEqual(self.dirrec.reconstruct_coincidences([], sentinel.station_numbers, sentinel.offsets, progress=False), - ((), (), ())) + mock_reconstruct_coincidence.assert_called_with( + sentinel.coincidence, + sentinel.station_numbers, + sentinel.offsets, + None, + ) + self.assertEqual( + self.dirrec.reconstruct_coincidences([], sentinel.station_numbers, sentinel.offsets, progress=False), + ((), (), ()), + ) self.assertEqual(mock_reconstruct_coincidence.call_count, 2) class BaseAlgorithm: - """Use this class to check the different algorithms This provides a shortcut to call the reconstruct_common method. @@ -417,7 +470,6 @@ def call_reconstruct(self, t, x, y, z, initial=None): class FlatAlgorithm(BaseAlgorithm): - """Use this class to test algorithms for flat shower fronts. They should give similar results and errors in some cases. @@ -429,18 +481,18 @@ def test_stations_in_line(self): """Three detection points on a line do not provide a solution.""" # On a line in x - t = (0., 2., 3.) - x = (0., 0., 0.) # same x - y = (0., 5., 10.) - z = (0., 0., 0.) # same z + t = (0.0, 2.0, 3.0) + x = (0.0, 0.0, 0.0) # same x + y = (0.0, 5.0, 10.0) + z = (0.0, 0.0, 0.0) # same z result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) # Diagonal line - t = (0., 2., 3.) - x = (0., 5., 10.) - y = (0., 5., 10.) - z = (0., 0., 0.) # same z + t = (0.0, 2.0, 3.0) + x = (0.0, 5.0, 10.0) + y = (0.0, 5.0, 10.0) + z = (0.0, 0.0, 0.0) # same z result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) @@ -451,44 +503,44 @@ def test_same_stations(self): """ # Two at same location - t = (0., 2., 3.) - x = (0., 0., 1.) - y = (5., 5., 6.) - z = (0., 0., 1.) + t = (0.0, 2.0, 3.0) + x = (0.0, 0.0, 1.0) + y = (5.0, 5.0, 6.0) + z = (0.0, 0.0, 1.0) result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) - t = (0., 2., 3.) - x = (0., 1., 0.) - y = (5., 6., 5.) - z = (0., 1., 0.) + t = (0.0, 2.0, 3.0) + x = (0.0, 1.0, 0.0) + y = (5.0, 6.0, 5.0) + z = (0.0, 1.0, 0.0) result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) - t = (0., 2., 3.) - x = (1., 0., 0.) - y = (6., 5., 5.) 
- z = (1., 0., 0.) + t = (0.0, 2.0, 3.0) + x = (1.0, 0.0, 0.0) + y = (6.0, 5.0, 5.0) + z = (1.0, 0.0, 0.0) result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) # Three at same location - t = (0., 2., 3.) - x = (0., 0., 0.) # same x - y = (5., 5., 5.) # same y - z = (0., 0., 0.) # same z + t = (0.0, 2.0, 3.0) + x = (0.0, 0.0, 0.0) # same x + y = (5.0, 5.0, 5.0) # same y + z = (0.0, 0.0, 0.0) # same z result = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(result).all()) def test_shower_from_above(self): """Simple shower from zenith, azimuth can be any allowed value.""" - t = (0., 0., 0.) # same t - x = (0., 10., 0.) - y = (0., 0., 10.) - z = (0., 0., 0.) # same z + t = (0.0, 0.0, 0.0) # same t + x = (0.0, 10.0, 0.0) + y = (0.0, 0.0, 10.0) + z = (0.0, 0.0, 0.0) # same z theta, phi = self.call_reconstruct(t, x, y, z) - self.assertAlmostEqual(theta, 0., 4) + self.assertAlmostEqual(theta, 0.0, 4) # azimuth can be any value between -pi and pi self.assertTrue(-pi <= phi < pi) @@ -497,11 +549,11 @@ def test_to_large_dt(self): # TODO: Add better test with smaller tolerance - x = (0., -5., 5.) - y = (sqrt(100 - 25), 0., 0.) - z = (0., 0., 0.) + x = (0.0, -5.0, 5.0) + y = (sqrt(100 - 25), 0.0, 0.0) + z = (0.0, 0.0, 0.0) - t = (35., 0., 0.) + t = (35.0, 0.0, 0.0) theta, phi = self.call_reconstruct(t, x, y, z) self.assertTrue(isnan(theta)) @@ -510,20 +562,20 @@ def test_showers_at_various_angles(self): c = 0.299792458 - x = (0., -5., 5.) - y = (sqrt(100 - 25), 0., 0.) - z = (0., 0., 0.) + x = (0.0, -5.0, 5.0) + y = (sqrt(100 - 25), 0.0, 0.0) + z = (0.0, 0.0, 0.0) # triangle height h = sqrt(100 - 25) - times = (2.5, 5., 7.5, 10., 12.5, 15., 17.5, 20., 22.5, 25., 27.5) + times = (2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5, 25.0, 27.5) for time in times: for i in range(3): zenith = arcsin((time * c) / h) - t = [0., 0., 0.] + t = [0.0, 0.0, 0.0] t[i] = time azimuths = [-pi / 2, pi / 6, pi * 5 / 6] theta, phi = self.call_reconstruct(t, x, y, z) @@ -534,7 +586,7 @@ def test_showers_at_various_angles(self): self.assertEqual((theta, phi), (theta_no_z, phi_no_z)) t = [time] * 3 - t[i] = 0. + t[i] = 0.0 azimuths = [pi / 2, -pi * 5 / 6, -pi / 6] theta, phi = self.call_reconstruct(t, x, y, z) self.assertAlmostEqual(phi, azimuths[i], 4) @@ -545,7 +597,6 @@ def test_showers_at_various_angles(self): class DirectAlgorithm(FlatAlgorithm): - """Use this class to check algorithms that only support three detections They should give similar warnings in some cases. @@ -559,21 +610,20 @@ def test_to_many_stations(self): """ # Shower from above (for first three detectors) - x = (0., 10., 0., 10.) - y = (0., 0., 10., 10.) - z = (0., 0., 0., 0.) - t = (0., 0., 0., 10.) + x = (0.0, 10.0, 0.0, 10.0) + y = (0.0, 0.0, 10.0, 10.0) + z = (0.0, 0.0, 0.0, 0.0) + t = (0.0, 0.0, 0.0, 10.0) with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + warnings.simplefilter('always') theta, phi = self.call_reconstruct(t, x, y, z) self.assertTrue(issubclass(w[0].category, UserWarning)) - self.assertAlmostEqual(theta, 0., 4) + self.assertAlmostEqual(theta, 0.0, 4) self.assertTrue(-pi <= phi < pi) class AltitudeAlgorithm(FlatAlgorithm): - """Use this class to check the altitude support They should give similar results and errors in some cases. @@ -583,14 +633,14 @@ class AltitudeAlgorithm(FlatAlgorithm): def test_stations_altitude(self): """Simple shower on a non horizontal square.""" - x = (0., 10., 10.) - y = (0, 0., 10.) - z = (2., 0., -2.) 
+ x = (0.0, 10.0, 10.0) + y = (0, 0.0, 10.0) + z = (2.0, 0.0, -2.0) - zenith = arctan(4. / 10. / sqrt(2)) + zenith = arctan(4.0 / 10.0 / sqrt(2)) - t = [0., 0., 0.] - azimuth = pi / 4. + t = [0.0, 0.0, 0.0] + azimuth = pi / 4.0 theta, phi = self.call_reconstruct(t, x, y, z) self.assertAlmostEqual(phi, azimuth, 5) @@ -598,14 +648,10 @@ def test_stations_altitude(self): class DirectAltitudeAlgorithm(DirectAlgorithm, AltitudeAlgorithm): - """Test algorithm that uses only 3 detectors and has altitude support.""" - pass - class MultiAlgorithm(FlatAlgorithm): - """Use this class to check the different algorithms for more stations They should give similar results and errors in some cases. @@ -617,20 +663,20 @@ def test_diamond_stations(self): c = 0.299792458 - x = (0., -5., 5., 10.) - y = (sqrt(100 - 25), 0., 0., sqrt(100 - 25)) - z = (0., 0., 0., 0.) + x = (0.0, -5.0, 5.0, 10.0) + y = (sqrt(100 - 25), 0.0, 0.0, sqrt(100 - 25)) + z = (0.0, 0.0, 0.0, 0.0) # triangle height h = sqrt(100 - 25) - times = (2.5, 5., 7.5, 10., 12.5, 15., 17.5, 20., 22.5, 25., 27.5) + times = (2.5, 5.0, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 22.5, 25.0, 27.5) for time in times: zenith = arcsin((time * c) / h) azimuth = pi / 6 - t = [0., 0., 0., 0.] + t = [0.0, 0.0, 0.0, 0.0] t[1] = time t[3] = -time theta, phi = self.call_reconstruct(t, x, y, z) @@ -642,20 +688,20 @@ def test_square_stations(self): c = 0.299792458 - x = (0., 5., 5., 0.) - y = (0, 0., 5., 5.) - z = (0., 0., 0., 0.) + x = (0.0, 5.0, 5.0, 0.0) + y = (0, 0.0, 5.0, 5.0) + z = (0.0, 0.0, 0.0, 0.0) # triangle height - h = sqrt(50. / 4.) + h = sqrt(50.0 / 4.0) - times = (2.5, 5., 7.5, 10.) + times = (2.5, 5.0, 7.5, 10.0) for time in times: zenith = arcsin((time * c) / h) - azimuth = - 3 * pi / 4 + azimuth = -3 * pi / 4 - t = [0., 0., 0., 0.] + t = [0.0, 0.0, 0.0, 0.0] t[0] = -time t[2] = time theta, phi = self.call_reconstruct(t, x, y, z) @@ -664,7 +710,6 @@ def test_square_stations(self): class MultiAltitudeAlgorithm(MultiAlgorithm, AltitudeAlgorithm): - """Check some algorithms for multiple stations at different altitudes. They should give similar results and errors in some cases. @@ -674,14 +719,14 @@ class MultiAltitudeAlgorithm(MultiAlgorithm, AltitudeAlgorithm): def test_hexagon_altitude(self): """Simple shower on a non horizontal square.""" - x = (-5., 5., 10., 5., -5., -10.) - y = (-5. * sqrt(3), -5. * sqrt(3), 0., 5. * sqrt(3), 5. * sqrt(3), 0.) - z = (0., -3., -5., -3., 0., 4.) + x = (-5.0, 5.0, 10.0, 5.0, -5.0, -10.0) + y = (-5.0 * sqrt(3), -5.0 * sqrt(3), 0.0, 5.0 * sqrt(3), 5.0 * sqrt(3), 0.0) + z = (0.0, -3.0, -5.0, -3.0, 0.0, 4.0) zenith = 0.38333 azimuth = 0.00000 - t = [0., 0., 0., 0., 0., 0.] + t = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] theta, phi = self.call_reconstruct(t, x, y, z) self.assertAlmostEqual(phi, azimuth, 4) @@ -689,7 +734,6 @@ def test_hexagon_altitude(self): class CurvedAlgorithm(BaseAlgorithm): - """Check some algorithms supporting a curved shower front. They should give similar results and errors in some cases. @@ -699,20 +743,19 @@ class CurvedAlgorithm(BaseAlgorithm): def test_curved_shower(self): """Simple curved shower on three detectors.""" - t = (0., 0., 10.) - x = (0., 100., 50.) - y = (0., 0., 100.) - z = (0., 0., 0.) 
+ t = (0.0, 0.0, 10.0) + x = (0.0, 100.0, 50.0) + y = (0.0, 0.0, 100.0) + z = (0.0, 0.0, 0.0) init = {'core_x': 50, 'core_y': 0} theta, phi = self.call_reconstruct(t, x, y, z, initial=init) - self.assertAlmostEqual(theta, 0., 4) + self.assertAlmostEqual(theta, 0.0, 4) self.assertTrue(-pi <= phi < pi) class CurvedAltitudeAlgorithm(CurvedAlgorithm): - """Check algorithms for curved fronts and stations at different altitudes. They should give similar results and errors in some cases. @@ -724,67 +767,59 @@ def test_curved_shower_on_stations_with_altitude(self): c = 0.299792458 - z = (10, 0., 40.) - t = (-z[0] / c, 0., 10. - z[2] / c) - x = (0., 100., 50.) - y = (0., 0., 100.) + z = (10, 0.0, 40.0) + t = (-z[0] / c, 0.0, 10.0 - z[2] / c) + x = (0.0, 100.0, 50.0) + y = (0.0, 0.0, 100.0) init = {'core_x': 50, 'core_y': 0} theta, phi = self.call_reconstruct(t, x, y, z, initial=init) - self.assertAlmostEqual(theta, 0., 4) + self.assertAlmostEqual(theta, 0.0, 4) self.assertTrue(-pi <= phi < pi) class DirectAlgorithmTest(unittest.TestCase, DirectAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.DirectAlgorithm() class DirectAlgorithmCartesianTest(unittest.TestCase, DirectAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.DirectAlgorithmCartesian() class DirectAlgorithmCartesian3DTest(unittest.TestCase, DirectAltitudeAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.DirectAlgorithmCartesian3D() class FitAlgorithm3DTest(unittest.TestCase, MultiAltitudeAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.FitAlgorithm3D() - @unittest.expectedFailure + @unittest.skip('Fails on CI') def test_square_stations(self): super().test_square_stations() class RegressionAlgorithmTest(unittest.TestCase, MultiAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.RegressionAlgorithm() class RegressionAlgorithm3DTest(unittest.TestCase, MultiAltitudeAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.RegressionAlgorithm3D() class CurvedRegressionAlgorithmTest(unittest.TestCase, CurvedAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.CurvedRegressionAlgorithm() self.algorithm.front = ConeFront() class CurvedRegressionAlgorithm3DTest(unittest.TestCase, CurvedAltitudeAlgorithm): - def setUp(self): self.algorithm = direction_reconstruction.CurvedRegressionAlgorithm3D() self.algorithm.front = ConeFront() diff --git a/sapphire/tests/analysis/test_event_utils.py b/sapphire/tests/analysis/test_event_utils.py index 5ac632cc..27d616ab 100644 --- a/sapphire/tests/analysis/test_event_utils.py +++ b/sapphire/tests/analysis/test_event_utils.py @@ -1,6 +1,7 @@ import unittest import warnings +from itertools import repeat from unittest.mock import MagicMock, patch, sentinel from numpy import isnan, nan @@ -9,7 +10,6 @@ class StationDensityTests(unittest.TestCase): - @patch.object(event_utils, 'detector_densities') @patch.object(event_utils, 'get_detector_ids') def test_station_density(self, mock_detector_ids, mock_detector_densities): @@ -26,117 +26,132 @@ def test_station_density(self, mock_detector_ids, mock_detector_densities): class DetectorDensitiesTests(unittest.TestCase): - @patch.object(event_utils, 'detector_density') @patch.object(event_utils, 'get_detector_ids') def test_detector_densities(self, mock_detector_ids, mock_detector_density): mock_detector_ids.return_value = list(range(4)) mock_detector_density.return_value = sentinel.density - 
self.assertEqual(event_utils.detector_densities(sentinel.event, list(range(4))), - [sentinel.density] * 4) - self.assertEqual(event_utils.detector_densities(sentinel.event, list(range(2))), - [sentinel.density] * 2) + self.assertEqual(event_utils.detector_densities(sentinel.event, list(range(4))), [sentinel.density] * 4) + self.assertEqual(event_utils.detector_densities(sentinel.event, list(range(2))), [sentinel.density] * 2) mock_detector_ids.assert_not_called() - self.assertEqual(event_utils.detector_densities(sentinel.event), - [sentinel.density] * 4) + self.assertEqual(event_utils.detector_densities(sentinel.event), [sentinel.density] * 4) mock_detector_ids.assert_called_once_with(None, sentinel.event) - self.assertEqual(event_utils.detector_densities(sentinel.event, station=sentinel.station), - [sentinel.density] * 4) + self.assertEqual( + event_utils.detector_densities(sentinel.event, station=sentinel.station), + [sentinel.density] * 4, + ) mock_detector_ids.assert_called_with(sentinel.station, sentinel.event) class DetectorDensityTests(unittest.TestCase): - def setUp(self): self.event = MagicMock() self.station = MagicMock() def test_detector_density(self): - self.event.__getitem__.side_effect = lambda name: 2 + self.event.__getitem__.side_effect = repeat(2) self.assertEqual(event_utils.detector_density(self.event, 0), 4) self.event.__getitem__.assert_called_with('n1') def test_no_good_detector_density(self): - self.event.__getitem__.side_effect = lambda name: -999 + self.event.__getitem__.side_effect = repeat(-999) self.assertTrue(isnan(event_utils.detector_density(self.event, 0))) self.event.__getitem__.assert_called_with('n1') class StationArrivalTimeTests(unittest.TestCase): - @patch.object(event_utils, 'detector_arrival_times') @patch.object(event_utils, 'get_detector_ids') def test_station_arrival_time(self, mock_detector_ids, mock_detector_arrival_times): mock_detector_ids.return_value = list(range(4)) - mock_detector_arrival_times.return_value = [7.5, 5., 2.5, 5.] 
+ mock_detector_arrival_times.return_value = [7.5, 5.0, 2.5, 5.0] event_dict = {'t_trigger': 10, 'ext_timestamp': 1000} event = MagicMock() event.__getitem__.side_effect = lambda name: event_dict[name] ref_ets = 500 rel_arrival_time = event_dict['ext_timestamp'] - ref_ets - event_dict['t_trigger'] - self.assertEqual(event_utils.station_arrival_time(event, ref_ets, list(range(4)), sentinel.offsets, sentinel.station), - rel_arrival_time + 2.5) + self.assertEqual( + event_utils.station_arrival_time(event, ref_ets, list(range(4)), sentinel.offsets, sentinel.station), + rel_arrival_time + 2.5, + ) mock_detector_ids.assert_not_called() - self.assertEqual(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets), - rel_arrival_time + 2.5) + self.assertEqual( + event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets), + rel_arrival_time + 2.5, + ) mock_detector_ids.assert_called_once_with(None, event) - self.assertEqual(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station), - rel_arrival_time + 2.5) + self.assertEqual( + event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station), + rel_arrival_time + 2.5, + ) mock_detector_ids.assert_called_with(sentinel.station, event) @patch.object(event_utils, 'detector_arrival_times') @patch.object(event_utils, 'get_detector_ids') def test_nan_station_arrival_time(self, mock_detector_ids, mock_detector_arrival_times): mock_detector_ids.return_value = list(range(4)) - mock_detector_arrival_times.return_value = [7.5, 5., nan, 5.] + mock_detector_arrival_times.return_value = [7.5, 5.0, nan, 5.0] event_dict = {'t_trigger': 10, 'ext_timestamp': 1000} event = MagicMock() event.__getitem__.side_effect = lambda name: event_dict[name] ref_ets = 500 rel_arrival_time = event_dict['ext_timestamp'] - ref_ets - event_dict['t_trigger'] - self.assertEqual(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station), - rel_arrival_time + 5) + self.assertEqual( + event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station), + rel_arrival_time + 5, + ) event_dict['t_trigger'] = -999 - self.assertTrue(isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station))) + self.assertTrue( + isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station)), + ) event_dict['t_trigger'] = nan - self.assertTrue(isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station))) + self.assertTrue( + isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station)), + ) event_dict['t_trigger'] = 10 mock_detector_arrival_times.return_value = [nan, nan, nan, nan] with warnings.catch_warnings(record=True) as warned: - self.assertTrue(isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station))) + self.assertTrue( + isnan(event_utils.station_arrival_time(event, ref_ets, None, sentinel.offsets, sentinel.station)), + ) self.assertEqual(len(warned), 1) class RelativeDetectorArrivalTimesTests(unittest.TestCase): - @patch.object(event_utils, 'detector_arrival_times') @patch.object(event_utils, 'get_detector_ids') def test_relative_detector_arrival_times(self, mock_detector_ids, mock_detector_arrival_times): mock_detector_ids.return_value = list(range(4)) - mock_detector_arrival_times.return_value = [7.5, 5., 2.5, 5.] 
+ mock_detector_arrival_times.return_value = [7.5, 5.0, 2.5, 5.0] event_dict = {'t_trigger': 10, 'ext_timestamp': 1000} event = MagicMock() event.__getitem__.side_effect = lambda name: event_dict[name] ref_ets = 500 rel_arrival_time = event_dict['ext_timestamp'] - ref_ets - event_dict['t_trigger'] - self.assertEqual(event_utils.relative_detector_arrival_times(event, 500, list(range(4)), sentinel.offsets, sentinel.station), - [rel_arrival_time + t for t in mock_detector_arrival_times()]) + self.assertEqual( + event_utils.relative_detector_arrival_times(event, 500, list(range(4)), sentinel.offsets, sentinel.station), + [rel_arrival_time + t for t in mock_detector_arrival_times()], + ) mock_detector_ids.assert_not_called() - self.assertEqual(event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets), - [rel_arrival_time + t for t in mock_detector_arrival_times()]) + self.assertEqual( + event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets), + [rel_arrival_time + t for t in mock_detector_arrival_times()], + ) mock_detector_ids.assert_called_once_with(None, event) - self.assertEqual(event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station), - [rel_arrival_time + t for t in mock_detector_arrival_times()]) + self.assertEqual( + event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station), + [rel_arrival_time + t for t in mock_detector_arrival_times()], + ) mock_detector_ids.assert_called_with(sentinel.station, event) @patch.object(event_utils, 'detector_arrival_times') @patch.object(event_utils, 'get_detector_ids') def test_nan_relative_detector_arrival_times(self, mock_detector_ids, mock_detector_arrival_times): mock_detector_ids.return_value = list(range(4)) - mock_detector_arrival_times.return_value = [7.5, 5., 5., nan] + mock_detector_arrival_times.return_value = [7.5, 5.0, 5.0, nan] event_dict = {'t_trigger': 10, 'ext_timestamp': 1000} event = MagicMock() event.__getitem__.side_effect = lambda name: event_dict[name] @@ -147,43 +162,52 @@ def test_nan_relative_detector_arrival_times(self, mock_detector_ids, mock_detec self.assertEqual(result[:-1], [rel_arrival_time + t for t in mock_detector_arrival_times()[:-1]]) self.assertTrue(isnan(result[-1])) event_dict['t_trigger'] = -999 - self.assertTrue(isnan(event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station)).all()) + self.assertTrue( + isnan( + event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station), + ).all(), + ) event_dict['t_trigger'] = nan - self.assertTrue(isnan(event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station)).all()) + self.assertTrue( + isnan( + event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station), + ).all(), + ) event_dict['t_trigger'] = 10 mock_detector_arrival_times.return_value = [nan, nan, nan, nan] - self.assertTrue(isnan(event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station)).all()) + self.assertTrue( + isnan( + event_utils.relative_detector_arrival_times(event, 500, None, sentinel.offsets, sentinel.station), + ).all(), + ) class DetectorArrivalTimesTests(unittest.TestCase): - @patch.object(event_utils, 'detector_arrival_time') @patch.object(event_utils, 'get_detector_ids') def test_detector_arrival_times(self, mock_detector_ids, mock_detector_arrival_time): mock_detector_ids.return_value = 
list(range(4)) mock_detector_arrival_time.return_value = sentinel.time - self.assertEqual(event_utils.detector_arrival_times(sentinel.event, list(range(4))), - [sentinel.time] * 4) - self.assertEqual(event_utils.detector_arrival_times(sentinel.event, list(range(2))), - [sentinel.time] * 2) + self.assertEqual(event_utils.detector_arrival_times(sentinel.event, list(range(4))), [sentinel.time] * 4) + self.assertEqual(event_utils.detector_arrival_times(sentinel.event, list(range(2))), [sentinel.time] * 2) mock_detector_ids.assert_not_called() - self.assertEqual(event_utils.detector_arrival_times(sentinel.event), - [sentinel.time] * 4) + self.assertEqual(event_utils.detector_arrival_times(sentinel.event), [sentinel.time] * 4) mock_detector_ids.assert_called_once_with(None, sentinel.event) - self.assertEqual(event_utils.detector_arrival_times(sentinel.event, station=sentinel.station), - [sentinel.time] * 4) + self.assertEqual( + event_utils.detector_arrival_times(sentinel.event, station=sentinel.station), + [sentinel.time] * 4, + ) mock_detector_ids.assert_called_with(sentinel.station, sentinel.event) class DetectorArrivalTimeTests(unittest.TestCase): - def setUp(self): self.event = MagicMock() self.offsets = [1, 2, 3, 4] def test_detector_arrival_time(self): - self.event.__getitem__.side_effect = lambda name: 2.5 + self.event.__getitem__.side_effect = repeat(2.5) self.assertEqual(event_utils.detector_arrival_time(self.event, 0), 2.5) self.event.__getitem__.assert_called_with('t1') self.assertEqual(event_utils.detector_arrival_time(self.event, 0, self.offsets), 1.5) @@ -192,7 +216,7 @@ def test_detector_arrival_time(self): self.event.__getitem__.assert_called_with('t2') def test_no_good_detector_arrival_time(self): - self.event.__getitem__.side_effect = lambda name: -999 + self.event.__getitem__.side_effect = repeat(-999) self.assertTrue(isnan(event_utils.detector_arrival_time(self.event, 0))) self.event.__getitem__.assert_called_with('t1') self.assertTrue(isnan(event_utils.detector_arrival_time(self.event, 0, self.offsets))) @@ -200,13 +224,12 @@ def test_no_good_detector_arrival_time(self): class GetDetectorIdsTests(unittest.TestCase): - def test_get_detector_ids(self): self.assertEqual(event_utils.get_detector_ids(), list(range(4))) station = MagicMock() station.detectors.__len__.return_value = 2 self.assertEqual(event_utils.get_detector_ids(station=station), list(range(2))) event = MagicMock() - event.__getitem__.side_effect = lambda name: [10, 100, 40, -1] + event.__getitem__.side_effect = repeat([10, 100, 40, -1]) self.assertEqual(event_utils.get_detector_ids(event=event), list(range(3))) self.assertEqual(event_utils.get_detector_ids(station=station, event=event), list(range(2))) diff --git a/sapphire/tests/analysis/test_find_mpv.py b/sapphire/tests/analysis/test_find_mpv.py index 13cec795..4cc468c2 100644 --- a/sapphire/tests/analysis/test_find_mpv.py +++ b/sapphire/tests/analysis/test_find_mpv.py @@ -7,7 +7,6 @@ class FindMostProbableValueInSpectrumTest(unittest.TestCase): - def test_failing_fit(self): """Check for correct warnings/errors for failing fit""" @@ -19,7 +18,7 @@ def test_failing_fit(self): first_guess = fmpv.find_first_guess_mpv() with self.assertRaises(RuntimeError) as cm: fmpv.fit_mpv(first_guess) - self.assertEqual(str(cm.exception), "Number of data points not sufficient") + self.assertEqual(str(cm.exception), 'Number of data points not sufficient') # Warning from the find mpv function with warnings.catch_warnings(record=True) as w: @@ -27,7 +26,7 @@ def 
test_failing_fit(self): # https://bugs.python.org/issue4180 if hasattr(find_mpv, '__warningregistry__'): find_mpv.__warningregistry__ = {} - warnings.simplefilter("always") + warnings.simplefilter('always') mpv, is_fitted = fmpv.find_mpv() self.assertTrue(issubclass(w[0].category, UserWarning)) self.assertEqual(mpv, -999) @@ -42,4 +41,4 @@ def test_bad_fit(self): fmpv = find_mpv.FindMostProbableValueInSpectrum(n, bins) with self.assertRaises(RuntimeError) as cm: fmpv.fit_mpv(111) - self.assertEqual(str(cm.exception), "Fitted MPV value outside range") + self.assertEqual(str(cm.exception), 'Fitted MPV value outside range') diff --git a/sapphire/tests/analysis/test_landau.py b/sapphire/tests/analysis/test_landau.py index 355c96ce..e8dc9ee6 100644 --- a/sapphire/tests/analysis/test_landau.py +++ b/sapphire/tests/analysis/test_landau.py @@ -6,19 +6,17 @@ class LandauTest(unittest.TestCase): - def test_pdf_mpv(self): """Check if peak of Landau pdf is at correct place The peak of the Landau should be around -0.22 """ - x = linspace(-.4, 0, 100) + x = linspace(-0.4, 0, 100) self.assertAlmostEqual(x[landau.pdf(x).argmax()] + 0.222, 0, 2) class ScintillatorTest(unittest.TestCase): - def setUp(self): self.scin = landau.Scintillator() @@ -26,5 +24,5 @@ def test_pdf(self): """Check if the integral of the Landau pdf is almost 1""" self.scin.pdf(0) - step_size = (self.scin.full_domain[-1] - self.scin.full_domain[-2]) + step_size = self.scin.full_domain[-1] - self.scin.full_domain[-2] self.assertAlmostEqual(self.scin.pdf_values.sum() * step_size, 1, 1) diff --git a/sapphire/tests/analysis/test_process_events.py b/sapphire/tests/analysis/test_process_events.py index 6bfe7257..52b2db45 100644 --- a/sapphire/tests/analysis/test_process_events.py +++ b/sapphire/tests/analysis/test_process_events.py @@ -21,15 +21,13 @@ class ProcessEventsTests(unittest.TestCase): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = tables.open_file(self.data_path, 'a') + self.addCleanup(self.data.close) self.proc = process_events.ProcessEvents(self.data, DATA_GROUP, progress=False) - def tearDown(self): - warnings.resetwarnings() - self.data.close() - os.remove(self.data_path) - def test_get_traces_for_event(self): event = self.proc.source[0] self.assertEqual(self.proc.get_traces_for_event(event)[12][3], 1334) @@ -38,7 +36,7 @@ def test__find_unique_row_ids(self): ext_timestamps = self.proc.source.col('ext_timestamp') enumerated_timestamps = list(enumerate(ext_timestamps)) enumerated_timestamps.sort(key=operator.itemgetter(1)) - ids_in = [id for id, _ in enumerated_timestamps] + ids_in = [row_id for row_id, _ in enumerated_timestamps] ids = self.proc._find_unique_row_ids(enumerated_timestamps) self.assertEqual(ids, ids_in) @@ -74,12 +72,12 @@ def test_first_above_threshold(self): self.assertEqual(self.proc.first_above_threshold(trace, 4), 2) self.assertEqual(self.proc.first_above_threshold(trace, 5), -999) -# @patch.object(process_events.FindMostProbableValueInSpectrum, 'find_mpv') + # @patch.object(process_events.FindMostProbableValueInSpectrum, 'find_mpv') def test__process_pulseintegrals(self): self.proc.limit = 1 -# mock_find_mpv.return_value = (-999, False) + # mock_find_mpv.return_value = (-999, False) # Because of small data sample fit fails for detector 1 - self.assertEqual(self.proc._process_pulseintegrals()[0][1], -999.) 
+ self.assertEqual(self.proc._process_pulseintegrals()[0][1], -999.0) self.assertAlmostEqual(self.proc._process_pulseintegrals()[0][3], 3.98951741969) self.proc.limit = None @@ -103,7 +101,9 @@ class ProcessIndexedEventsTests(ProcessEventsTests): def setUp(self): warnings.filterwarnings('ignore') self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = tables.open_file(self.data_path, 'a') + self.addCleanup(self.data.close) self.proc = process_events.ProcessIndexedEvents(self.data, DATA_GROUP, [0, 10], progress=False) def test_process_traces(self): @@ -118,8 +118,11 @@ def test_get_traces_for_indexed_event_index(self): class ProcessEventsWithLINTTests(ProcessEventsTests): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = tables.open_file(self.data_path, 'a') + self.addCleanup(self.data.close) self.proc = process_events.ProcessEventsWithLINT(self.data, DATA_GROUP, progress=False) def test__reconstruct_time_from_traces(self): @@ -139,8 +142,11 @@ def test__reconstruct_time_from_trace(self): class ProcessEventsWithTriggerOffsetTests(ProcessEventsTests): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = tables.open_file(self.data_path, 'a') + self.addCleanup(self.data.close) self.proc = process_events.ProcessEventsWithTriggerOffset(self.data, DATA_GROUP, progress=False) def test__reconstruct_time_from_traces(self): @@ -162,13 +168,28 @@ def test__first_above_thresholds(self): # 2 detectors self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 200, 900]), [300, 400], 900), [2, 2, -999]) self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 200, 400]), [300, 400], 400), [2, 2, -999]) - self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 350, 450, 550]), [300, 400], 550), [1, 2, -999]) + self.assertEqual( + self.proc._first_above_thresholds((x for x in [200, 350, 450, 550]), [300, 400], 550), + [1, 2, -999], + ) # 4 detectors - self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 200, 900]), [300, 400, 500], 900), [2, 2, 2]) - self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 200, 400]), [300, 400, 500], 400), [2, 2, -999]) - self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 350, 450, 550]), [300, 400, 500], 550), [1, 2, 3]) + self.assertEqual( + self.proc._first_above_thresholds((x for x in [200, 200, 900]), [300, 400, 500], 900), + [2, 2, 2], + ) + self.assertEqual( + self.proc._first_above_thresholds((x for x in [200, 200, 400]), [300, 400, 500], 400), + [2, 2, -999], + ) + self.assertEqual( + self.proc._first_above_thresholds((x for x in [200, 350, 450, 550]), [300, 400, 500], 550), + [1, 2, 3], + ) # No signal - self.assertEqual(self.proc._first_above_thresholds((x for x in [200, 250, 200, 2000]), [300, 400, 500], 250), [-999, -999, -999]) + self.assertEqual( + self.proc._first_above_thresholds((x for x in [200, 250, 200, 2000]), [300, 400, 500], 250), + [-999, -999, -999], + ) def test__first_value_above_threshold(self): trace = [200, 200, 300, 200] @@ -247,53 +268,70 @@ def test__reconstruct_trigger(self): class ProcessEventsFromSourceTests(ProcessEventsTests): def setUp(self): warnings.filterwarnings('ignore') + 
self.addCleanup(warnings.resetwarnings) self.source_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.source_path) self.source_data = tables.open_file(self.source_path, 'r') + self.addCleanup(self.source_data.close) self.dest_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.dest_path) self.dest_data = tables.open_file(self.dest_path, 'a') - self.proc = process_events.ProcessEventsFromSource( - self.source_data, self.dest_data, DATA_GROUP, DATA_GROUP) - - def tearDown(self): - warnings.resetwarnings() - self.source_data.close() - os.remove(self.source_path) - self.dest_data.close() - os.remove(self.dest_path) + self.addCleanup(self.dest_data.close) + self.proc = process_events.ProcessEventsFromSource(self.source_data, self.dest_data, DATA_GROUP, DATA_GROUP) def test_process_and_store_results(self): self.proc.process_and_store_results() -class ProcessEventsFromSourceWithTriggerOffsetTests(ProcessEventsFromSourceTests, - ProcessEventsWithTriggerOffsetTests): +class ProcessEventsFromSourceWithTriggerOffsetTests(ProcessEventsFromSourceTests, ProcessEventsWithTriggerOffsetTests): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.source_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.source_path) self.source_data = tables.open_file(self.source_path, 'r') + self.addCleanup(self.source_data.close) self.dest_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.dest_path) self.dest_data = tables.open_file(self.dest_path, 'a') + self.addCleanup(self.dest_data.close) self.proc = process_events.ProcessEventsFromSourceWithTriggerOffset( - self.source_data, self.dest_data, DATA_GROUP, DATA_GROUP) + self.source_data, + self.dest_data, + DATA_GROUP, + DATA_GROUP, + ) -class ProcessEventsFromSourceWithTriggerOffsetStationTests(ProcessEventsFromSourceTests, - ProcessEventsWithTriggerOffsetTests): +class ProcessEventsFromSourceWithTriggerOffsetStationTests( + ProcessEventsFromSourceTests, + ProcessEventsWithTriggerOffsetTests, +): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.source_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.source_path) self.source_data = tables.open_file(self.source_path, 'r') + self.addCleanup(self.source_data.close) self.dest_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.dest_path) self.dest_data = tables.open_file(self.dest_path, 'a') + self.addCleanup(self.dest_data.close) self.proc = process_events.ProcessEventsFromSourceWithTriggerOffset( - self.source_data, self.dest_data, DATA_GROUP, DATA_GROUP, - station=501) + self.source_data, + self.dest_data, + DATA_GROUP, + DATA_GROUP, + station=501, + ) def test__reconstruct_time_from_traces_with_external(self): mock_trigger = Mock() - mock_trigger.return_value = ([(process_events.ADC_LOW_THRESHOLD, - process_events.ADC_HIGH_THRESHOLD)] * 4, - [0, 0, 0, 1]) + mock_trigger.return_value = ( + [(process_events.ADC_LOW_THRESHOLD, process_events.ADC_HIGH_THRESHOLD)] * 4, + [0, 0, 0, 1], + ) self.proc.station.trigger = mock_trigger event = self.proc.source[10] @@ -306,15 +344,12 @@ def test__reconstruct_time_from_traces_with_external(self): class ProcessSinglesTests(unittest.TestCase): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = 
tables.open_file(self.data_path, 'a') - self.proc = process_events.ProcessSingles(self.data, DATA_GROUP, - progress=False) - - def tearDown(self): - warnings.resetwarnings() - self.data.close() - os.remove(self.data_path) + self.addCleanup(self.data.close) + self.proc = process_events.ProcessSingles(self.data, DATA_GROUP, progress=False) def test_process_and_store_results(self): self.proc.process_and_store_results() @@ -343,19 +378,16 @@ def get_testdata_path(self): class ProcessSinglesFromSourceTests(ProcessSinglesTests): def setUp(self): warnings.filterwarnings('ignore') + self.addCleanup(warnings.resetwarnings) self.source_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.source_path) self.source_data = tables.open_file(self.source_path, 'r') + self.addCleanup(self.source_data.close) self.dest_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.dest_path) self.dest_data = tables.open_file(self.dest_path, 'a') - self.proc = process_events.ProcessSinglesFromSource( - self.source_data, self.dest_data, DATA_GROUP, '/') - - def tearDown(self): - warnings.resetwarnings() - self.source_data.close() - os.remove(self.source_path) - self.dest_data.close() - os.remove(self.dest_path) + self.addCleanup(self.dest_data.close) + self.proc = process_events.ProcessSinglesFromSource(self.source_data, self.dest_data, DATA_GROUP, '/') def test_process_and_store_results(self): self.proc.process_and_store_results() diff --git a/sapphire/tests/analysis/test_process_traces.py b/sapphire/tests/analysis/test_process_traces.py index eefcf5f3..698c4d52 100644 --- a/sapphire/tests/analysis/test_process_traces.py +++ b/sapphire/tests/analysis/test_process_traces.py @@ -9,7 +9,6 @@ class TraceObservablesTests(unittest.TestCase): - def setUp(self): trace = [200] * 400 + [500] + [510] + [400] * 10 + [200] * 600 + [400] * 10 + [200] trace2 = [203, 199] * 200 + [500] + [510] + [398, 402] * 5 + [203, 199] * 300 + [400] * 10 + [200] @@ -33,7 +32,6 @@ def test_n_peaks(self): class MeanFilterTests(unittest.TestCase): - def setUp(self): self.trace = [[200] * 400 + [500] + [400] * 20 + [200] * 600] self.traces = self.trace * 2 @@ -51,15 +49,14 @@ def test_init(self): @patch.object(process_traces.MeanFilter, 'filter_trace') def test_filter_traces(self, mock_filter_trace): mock_filter_trace.return_value = sentinel.filtered_trace - self.assertEqual(self.mf.filter_traces(self.traces), - [sentinel.filtered_trace, sentinel.filtered_trace]) + self.assertEqual(self.mf.filter_traces(self.traces), [sentinel.filtered_trace, sentinel.filtered_trace]) def test_filter_trace(self): mock_filter = MagicMock() self.mf.filter = mock_filter - mock_filter.side_effect = cycle([[sentinel.filtered_even] * 2, - [sentinel.filtered_odd] * 2, - [sentinel.filtered_recombined]]) + mock_filter.side_effect = cycle( + [[sentinel.filtered_even] * 2, [sentinel.filtered_odd] * 2, [sentinel.filtered_recombined]], + ) trace_segment = [sentinel.trace_even, sentinel.trace_odd] filtered_trace = self.mf.filter_trace(trace_segment * 4) @@ -148,7 +145,6 @@ def test_mean_filter_without_threshold(self): class DataReductionTests(unittest.TestCase): - def setUp(self): self.dr = process_traces.DataReduction() @@ -161,8 +157,14 @@ def test_reduce_traces(self): pre = 400 post = 300 baseline = 200 - trace = ([baseline] * pre + [baseline + 50] + [baseline + 60] * 4 + - [baseline] * 600 + [baseline + 90] * 5 + [baseline] * post) + trace = ( + [baseline] * pre + + [baseline + 50] + + [baseline + 60] * 4 + + [baseline] * 600 + + 
[baseline + 90] * 5 + + [baseline] * post + ) traces = array([trace, [baseline] * len(trace)]).T reduced_traces = self.dr.reduce_traces(traces, [baseline] * 2) r_traces, left = self.dr.reduce_traces(traces, [baseline] * 2, True) @@ -175,8 +177,14 @@ def test_reduce_traces(self): pre = 10 post = 10 baseline = 200 - trace = ([baseline] * pre + [baseline + 50] + [baseline + 60] * 4 + - [baseline] * 600 + [baseline + 90] * 5 + [baseline] * post) + trace = ( + [baseline] * pre + + [baseline + 50] + + [baseline + 60] * 4 + + [baseline] * 600 + + [baseline + 90] * 5 + + [baseline] * post + ) traces = array([trace, [baseline] * len(trace)]).T reduced_traces = self.dr.reduce_traces(traces, [baseline] * 2) r_traces, left = self.dr.reduce_traces(traces, [baseline] * 2, True) @@ -189,8 +197,14 @@ def test_determine_cuts(self): pre = 400 post = 300 baseline = 200 - trace = ([baseline] * pre + [baseline + 50] + [baseline + 60] * 4 + - [baseline] * 600 + [baseline + 90] * 5 + [baseline] * post) + trace = ( + [baseline] * pre + + [baseline + 50] + + [baseline + 60] * 4 + + [baseline] * 600 + + [baseline + 90] * 5 + + [baseline] * post + ) traces = array([trace, [baseline] * len(trace)]).T left, right = self.dr.determine_cuts(traces, [baseline] * 2) self.assertEqual(left, pre) @@ -206,11 +220,13 @@ def test_determine_cuts(self): self.assertEqual(right, length) def test_add_padding(self): - combinations = (((0, 20), (0, 46)), # left at limit - ((4, 20), (0, 46)), # left close to limit - ((50, 2400), (24, 2426)), # left far from limit - ((50, 2400, 2400), (24, 2400)), # right at limit - ((50, 2400, 2410), (24, 2410)), # right close to limit - ((0, 200, 2400), (0, 226)),) # right far from limit - for input, expected in combinations: - self.assertEqual(self.dr.add_padding(*input), expected) + combinations = ( + ((0, 20), (0, 46)), # left at limit + ((4, 20), (0, 46)), # left close to limit + ((50, 2400), (24, 2426)), # left far from limit + ((50, 2400, 2400), (24, 2400)), # right at limit + ((50, 2400, 2410), (24, 2410)), # right close to limit + ((0, 200, 2400), (0, 226)), + ) # right far from limit + for args, expected in combinations: + self.assertEqual(self.dr.add_padding(*args), expected) diff --git a/sapphire/tests/analysis/test_reconstructions.py b/sapphire/tests/analysis/test_reconstructions.py index b8bba1da..0109703e 100644 --- a/sapphire/tests/analysis/test_reconstructions.py +++ b/sapphire/tests/analysis/test_reconstructions.py @@ -1,6 +1,7 @@ import os import unittest +from itertools import repeat from unittest.mock import MagicMock, patch, sentinel import tables @@ -11,14 +12,18 @@ class ReconstructESDEventsTest(unittest.TestCase): - def setUp(self): self.data = MagicMock() self.station = MagicMock(spec=reconstructions.Station) self.rec = reconstructions.ReconstructESDEvents( - self.data, sentinel.station_group, self.station, - overwrite=sentinel.overwrite, progress=sentinel.progress, - verbose=False, destination=sentinel.destination) + self.data, + sentinel.station_group, + self.station, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + ) def test_init(self): rec = self.rec @@ -29,10 +34,10 @@ def test_init(self): self.assertEqual(rec.events, self.data.get_node.return_value.events) self.assertEqual(rec.overwrite, sentinel.overwrite) - self.assertEqual(rec.progress, sentinel.progress) + self.assertFalse(rec.progress) self.assertFalse(rec.verbose) self.assertEqual(rec.destination, sentinel.destination) - self.assertEqual(rec.offsets, [0.] 
* 4) + self.assertEqual(rec.offsets, [0.0] * 4) self.assertEqual(rec.station, self.station) @@ -50,35 +55,51 @@ def test_reconstruct_directions(self): self.rec.direction.reconstruct_events.return_value = (sentinel.theta, sentinel.phi, sentinel.ids) self.rec.reconstruct_directions() self.rec.direction.reconstruct_events.assert_called_once_with( - self.rec.events, None, self.rec.offsets, self.rec.progress, []) + self.rec.events, + None, + self.rec.offsets, + self.rec.progress, + [], + ) self.assertEqual(self.rec.theta, sentinel.theta) self.assertEqual(self.rec.phi, sentinel.phi) self.assertEqual(self.rec.detector_ids, sentinel.ids) self.rec.reconstruct_directions(sentinel.detector_ids) self.rec.direction.reconstruct_events.assert_called_with( - self.rec.events, sentinel.detector_ids, self.rec.offsets, self.rec.progress, []) + self.rec.events, + sentinel.detector_ids, + self.rec.offsets, + self.rec.progress, + [], + ) def test_reconstruct_cores(self): self.rec.core = MagicMock() self.rec.core.reconstruct_events.return_value = (sentinel.core_x, sentinel.core_y) self.rec.reconstruct_cores() - self.rec.core.reconstruct_events.assert_called_once_with( - self.rec.events, None, self.rec.progress, []) + self.rec.core.reconstruct_events.assert_called_once_with(self.rec.events, None, self.rec.progress, []) self.assertEqual(self.rec.core_x, sentinel.core_x) self.assertEqual(self.rec.core_y, sentinel.core_y) self.rec.reconstruct_cores(sentinel.detector_ids) self.rec.core.reconstruct_events.assert_called_with( - self.rec.events, sentinel.detector_ids, self.rec.progress, []) + self.rec.events, + sentinel.detector_ids, + self.rec.progress, + [], + ) def test_prepare_output(self): self.rec.events = MagicMock() self.rec.events.nrows = sentinel.nrows self.rec.prepare_output() self.data.create_table.assert_called_once_with( - self.rec.station_group, sentinel.destination, - reconstructions.ReconstructedEvent, expectedrows=sentinel.nrows) + self.rec.station_group, + sentinel.destination, + reconstructions.ReconstructedEvent, + expectedrows=sentinel.nrows, + ) self.assertEqual(self.rec.reconstructions, self.data.create_table.return_value) self.assertEqual(self.rec.reconstructions._v_attrs.station, self.station) @@ -90,11 +111,13 @@ def test_prepare_output_existing(self): # Overwrite existing self.rec.overwrite = True self.rec.prepare_output() - self.data.remove_node.assert_called_once_with( - self.rec.station_group, sentinel.destination, recursive=True) + self.data.remove_node.assert_called_once_with(self.rec.station_group, sentinel.destination, recursive=True) self.data.create_table.assert_called_with( - self.rec.station_group, sentinel.destination, - reconstructions.ReconstructedEvent, expectedrows=sentinel.nrows) + self.rec.station_group, + sentinel.destination, + reconstructions.ReconstructedEvent, + expectedrows=sentinel.nrows, + ) self.assertEqual(self.rec.reconstructions, self.data.create_table.return_value) self.assertEqual(self.rec.reconstructions._v_attrs.station, self.station) @@ -112,8 +135,7 @@ def test_get_detector_offsets(self, mock_determine_detctor_timing_offets, mock_s # no offsets in station object no station_number -> # determine offsets from events self.rec.get_detector_offsets() - mock_determine_detctor_timing_offets.assert_called_with( - sentinel.events, self.station) + mock_determine_detctor_timing_offets.assert_called_with(sentinel.events, self.station) # no offsets in station object and station number -> api.Station self.rec.station_number = sentinel.station @@ -131,24 +153,28 @@ def 
test__store_reconstruction(self): # _store_reconstruction calls min(event['n1'], ...). # but MagicMock is unordered in python 3! # Mock a dict that always returns 42. - event.__getitem__.side_effect = lambda x: 42. + event.__getitem__.side_effect = repeat(42.0) self.rec.reconstructions = MagicMock() - self.rec._store_reconstruction(event, sentinel.core_x, sentinel.core_y, - sentinel.theta, sentinel.phi, [1, 3, 4]) + self.rec._store_reconstruction(event, sentinel.core_x, sentinel.core_y, sentinel.theta, sentinel.phi, [1, 3, 4]) self.rec.reconstructions.row.append.assert_called_once_with() class ReconstructESDEventsFromSourceTest(ReconstructESDEventsTest): - def setUp(self): self.data = MagicMock() self.dest_data = MagicMock() self.station = MagicMock(spec=reconstructions.Station) self.rec = reconstructions.ReconstructESDEventsFromSource( - self.data, self.dest_data, sentinel.station_group, - sentinel.dest_group, self.station, overwrite=sentinel.overwrite, - progress=sentinel.progress, verbose=False, - destination=sentinel.destination) + self.data, + self.dest_data, + sentinel.station_group, + sentinel.dest_group, + self.station, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + ) @unittest.skip('WIP') def test_prepare_output(self): @@ -160,7 +186,6 @@ def test_prepare_output_existing(self): class ReconstructSimulatedEventsTest(unittest.TestCase): - def setUp(self): self.data = None @@ -172,24 +197,27 @@ def test_station_is_object(self): self.data = MagicMock() station = MagicMock(spec=reconstructions.Station) rec = reconstructions.ReconstructSimulatedEvents( - self.data, sentinel.station_group, station, - overwrite=sentinel.overwrite, progress=sentinel.progress, - verbose=False, destination=sentinel.destination) + self.data, + sentinel.station_group, + station, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + ) self.assertEqual(rec.station, station) def test_read_object_from_hdf5(self): fn = self.get_testdata_path(TEST_DATA_FILE) self.data = tables.open_file(fn, 'r') station_group = '/cluster_simulations/station_0' - rec = reconstructions.ReconstructSimulatedEvents( - self.data, station_group, 0) + rec = reconstructions.ReconstructSimulatedEvents(self.data, station_group, 0) # isinstance does not work on classes that are read from pickles. 
self.assertEqual(rec.station.station_id, 0) with self.assertRaises(RuntimeError): - rec = reconstructions.ReconstructSimulatedEvents( - self.data, station_group, -999) + rec = reconstructions.ReconstructSimulatedEvents(self.data, station_group, -999) def get_testdata_path(self, fn): dir_path = os.path.dirname(__file__) @@ -197,16 +225,20 @@ def get_testdata_path(self, fn): class ReconstructESDCoincidencesTest(unittest.TestCase): - @patch.object(reconstructions, 'CoincidenceQuery') def setUp(self, mock_cq): self.data = MagicMock() self.cluster = MagicMock() self.cq = mock_cq self.rec = reconstructions.ReconstructESDCoincidences( - self.data, sentinel.coin_group, overwrite=sentinel.overwrite, - progress=sentinel.progress, verbose=False, - destination=sentinel.destination, cluster=self.cluster) + self.data, + sentinel.coin_group, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + cluster=self.cluster, + ) def test_init(self): rec = self.rec @@ -217,7 +249,7 @@ def test_init(self): self.assertEqual(rec.coincidences, self.data.get_node.return_value.coincidences) self.assertEqual(rec.overwrite, sentinel.overwrite) - self.assertEqual(rec.progress, sentinel.progress) + self.assertFalse(rec.progress) self.assertFalse(rec.verbose) self.assertEqual(rec.destination, sentinel.destination) self.assertEqual(rec.offsets, {}) @@ -245,13 +277,12 @@ def test_get_station_timing_offsets(self, mock_station): self.assertEqual(list(self.rec.offsets.keys()), [sentinel.number]) self.assertEqual(list(self.rec.offsets.values()), [sentinel.station]) - detector = MagicMock(offset=1.) - station = MagicMock(number=sentinel.number, gps_offset=2., detectors=[detector], - spec=['gps_offset', 'number']) + detector = MagicMock(offset=1.0) + station = MagicMock(number=sentinel.number, gps_offset=2.0, detectors=[detector], spec=['gps_offset', 'number']) self.rec.cluster.stations = [station] self.rec.get_station_timing_offsets() self.assertEqual(list(self.rec.offsets.keys()), [sentinel.number]) - self.assertEqual(list(self.rec.offsets.values()), [[3.]]) + self.assertEqual(list(self.rec.offsets.values()), [[3.0]]) def test_reconstruct_directions(self): self.rec.coincidences = MagicMock() @@ -260,14 +291,24 @@ def test_reconstruct_directions(self): self.rec.direction.reconstruct_coincidences.return_value = (sentinel.theta, sentinel.phi, sentinel.nums) self.rec.reconstruct_directions() self.rec.direction.reconstruct_coincidences.assert_called_once_with( - self.rec.cq.all_events.return_value, None, self.rec.offsets, progress=False, initials=[]) + self.rec.cq.all_events.return_value, + None, + self.rec.offsets, + progress=False, + initials=[], + ) self.assertEqual(self.rec.theta, sentinel.theta) self.assertEqual(self.rec.phi, sentinel.phi) self.assertEqual(self.rec.station_numbers, sentinel.nums) self.rec.reconstruct_directions(sentinel.nums) self.rec.direction.reconstruct_coincidences.assert_called_with( - self.rec.cq.all_events.return_value, sentinel.nums, self.rec.offsets, progress=False, initials=[]) + self.rec.cq.all_events.return_value, + sentinel.nums, + self.rec.offsets, + progress=False, + initials=[], + ) def test_reconstruct_cores(self): self.rec.coincidences = MagicMock() @@ -276,13 +317,21 @@ def test_reconstruct_cores(self): self.rec.core.reconstruct_coincidences.return_value = (sentinel.core_x, sentinel.core_y) self.rec.reconstruct_cores() self.rec.core.reconstruct_coincidences.assert_called_once_with( - self.rec.cq.all_events.return_value, None, progress=False, 
initials=[]) + self.rec.cq.all_events.return_value, + None, + progress=False, + initials=[], + ) self.assertEqual(self.rec.core_x, sentinel.core_x) self.assertEqual(self.rec.core_y, sentinel.core_y) self.rec.reconstruct_cores(sentinel.nums) self.rec.core.reconstruct_coincidences.assert_called_with( - self.rec.cq.all_events.return_value, sentinel.nums, progress=False, initials=[]) + self.rec.cq.all_events.return_value, + sentinel.nums, + progress=False, + initials=[], + ) def test_prepare_output(self): self.rec.coincidences = MagicMock() @@ -290,8 +339,11 @@ def test_prepare_output(self): self.cluster.stations.return_value = [] self.rec.prepare_output() self.data.create_table.assert_called_once_with( - self.rec.coincidences_group, sentinel.destination, - reconstructions.ReconstructedCoincidence, expectedrows=sentinel.nrows) + self.rec.coincidences_group, + sentinel.destination, + reconstructions.ReconstructedCoincidence, + expectedrows=sentinel.nrows, + ) self.assertEqual(self.rec.reconstructions, self.data.create_table.return_value) self.assertEqual(self.rec.reconstructions._v_attrs.cluster, self.cluster) @@ -304,11 +356,13 @@ def test_prepare_output_existing(self): # Overwrite existing self.rec.overwrite = True self.rec.prepare_output() - self.data.remove_node.assert_called_once_with( - self.rec.coincidences_group, sentinel.destination, recursive=True) + self.data.remove_node.assert_called_once_with(self.rec.coincidences_group, sentinel.destination, recursive=True) self.data.create_table.assert_called_with( - self.rec.coincidences_group, sentinel.destination, - reconstructions.ReconstructedCoincidence, expectedrows=sentinel.nrows) + self.rec.coincidences_group, + sentinel.destination, + reconstructions.ReconstructedCoincidence, + expectedrows=sentinel.nrows, + ) self.assertEqual(self.rec.reconstructions, self.data.create_table.return_value) self.assertEqual(self.rec.reconstructions._v_attrs.cluster, self.cluster) @@ -325,22 +379,22 @@ def test_prepare_output_columns(self, mock_description): self.rec.cluster.stations = [station] self.rec.prepare_output() - mock_description.columns.update.assert_called_once_with( - {'s1': reconstructions.tables.BoolCol(pos=26)}) + mock_description.columns.update.assert_called_once_with({'s1': reconstructions.tables.BoolCol(pos=26)}) self.data.create_table.assert_called_with( - self.rec.coincidences_group, sentinel.destination, mock_description, - expectedrows=sentinel.nrows) + self.rec.coincidences_group, + sentinel.destination, + mock_description, + expectedrows=sentinel.nrows, + ) def test__store_reconstruction(self): coin = MagicMock() self.rec.reconstructions = MagicMock() - self.rec._store_reconstruction(coin, sentinel.core_x, sentinel.core_y, - sentinel.theta, sentinel.phi, [2, 3, 4]) + self.rec._store_reconstruction(coin, sentinel.core_x, sentinel.core_y, sentinel.theta, sentinel.phi, [2, 3, 4]) self.rec.reconstructions.row.append.assert_called_once_with() class ReconstructESDCoincidencesFromSourceTest(ReconstructESDCoincidencesTest): - @patch.object(reconstructions, 'CoincidenceQuery') def setUp(self, mock_cq): self.data = MagicMock() @@ -348,10 +402,16 @@ def setUp(self, mock_cq): self.cluster = MagicMock() self.cq = mock_cq self.rec = reconstructions.ReconstructESDCoincidencesFromSource( - self.data, self.dest_data, sentinel.coin_group, - sentinel.dest_group, overwrite=sentinel.overwrite, - progress=sentinel.progress, verbose=False, - destination=sentinel.destination, cluster=self.cluster) + self.data, + self.dest_data, + sentinel.coin_group, + 
sentinel.dest_group, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + cluster=self.cluster, + ) @unittest.skip('WIP') def test_prepare_output(self): @@ -367,7 +427,6 @@ def test_prepare_output_columns(self): class ReconstructSimulatedCoincidencesTest(unittest.TestCase): - def setUp(self): self.data = MagicMock() @@ -379,9 +438,14 @@ def tearDown(self): def test_cluster_is_object(self, mock_cq): cluster = MagicMock() rec = reconstructions.ReconstructSimulatedCoincidences( - self.data, sentinel.coin_group, overwrite=sentinel.overwrite, - progress=sentinel.progress, verbose=False, - destination=sentinel.destination, cluster=cluster) + self.data, + sentinel.coin_group, + overwrite=sentinel.overwrite, + progress=False, + verbose=False, + destination=sentinel.destination, + cluster=cluster, + ) self.assertEqual(rec.cluster, cluster) def test_read_object_from_hdf5(self): diff --git a/sapphire/tests/analysis/test_time_deltas.py b/sapphire/tests/analysis/test_time_deltas.py index 536e2d7c..9385a774 100644 --- a/sapphire/tests/analysis/test_time_deltas.py +++ b/sapphire/tests/analysis/test_time_deltas.py @@ -15,13 +15,11 @@ class ProcessTimeDeltasTests(unittest.TestCase): def setUp(self): self.data_path = self.create_tempfile_from_testdata() + self.addCleanup(os.remove, self.data_path) self.data = tables.open_file(self.data_path, 'a') + self.addCleanup(self.data.close) self.td = time_deltas.ProcessTimeDeltas(self.data, progress=False) - def tearDown(self): - self.data.close() - os.remove(self.data_path) - def test_init(self): self.assertEqual(self.td.progress, False) self.assertEqual(self.td.data, self.data) @@ -38,18 +36,22 @@ def test_get_detector_offsets(self, mock_station): self.td.pairs = {(sentinel.station1, sentinel.station2), (sentinel.station1, sentinel.station3)} self.td.get_detector_offsets() - self.assertEqual(self.td.detector_timing_offsets, - {sentinel.station1: mock_offsets.detector_timing_offset, - sentinel.station2: mock_offsets.detector_timing_offset, - sentinel.station3: mock_offsets.detector_timing_offset}) + self.assertEqual( + self.td.detector_timing_offsets, + { + sentinel.station1: mock_offsets.detector_timing_offset, + sentinel.station2: mock_offsets.detector_timing_offset, + sentinel.station3: mock_offsets.detector_timing_offset, + }, + ) def test_store_time_deltas(self): pair = (501, 502) node_path = '/coincidences/time_deltas/station_%d/station_%d' % pair self.assertRaises(Exception, self.data.get_node, node_path, 'time_deltas') - self.td.store_time_deltas([12345678987654321], [2.5], pair) + self.td.store_time_deltas([12345678_987654321], [2.5], pair) stored_data = self.data.get_node(node_path, 'time_deltas') - self.assertEqual(list(stored_data[0]), [12345678987654321, 12345678, 987654321, 2.5]) + self.assertEqual(list(stored_data[0]), [12345678_987654321, 12345678, 987654321, 2.5]) def create_tempfile_from_testdata(self): tmp_path = self.create_tempfile_path() diff --git a/sapphire/tests/corsika/test_blocks.py b/sapphire/tests/corsika/test_blocks.py index 19328b21..a3ae44da 100644 --- a/sapphire/tests/corsika/test_blocks.py +++ b/sapphire/tests/corsika/test_blocks.py @@ -6,6 +6,7 @@ try: import numba + numba.__version__ # stop flake8 from complaining about unused module except ImportError: numba_available = False @@ -17,47 +18,44 @@ class CorsikaBlocksTests(unittest.TestCase): def setUp(self): self.format = blocks.Format() - def tearDown(self): - pass - def test_validate_block_format(self): """Verify that the block 
format is logical""" - self.assertEqual((self.format.block_size - 2 * self.format.block_padding_size) / self.format.subblock_size, - self.format.subblocks_per_block, - msg=('The block format ({block}) and sub-block format ' - '({sub_block}) do not agree! block size is {block_size} ' - 'and sub-block size is {sub_block_size}. Block size should' - ' be {subblocks_per_block} times the sub-block size plus ' - 'padding (usually 8 bytes).' - .format(block=self.format.block_format, - sub_block=self.format.subblock_format, - block_size=self.format.block_size, - sub_block_size=self.format.subblock_size, - subblocks_per_block=self.format.subblocks_per_block))) + self.assertEqual( + (self.format.block_size - 2 * self.format.block_padding_size) / self.format.subblock_size, + self.format.subblocks_per_block, + msg=( + f'The block format ({self.format.block_format}) and sub-block format ' + f'({self.format.subblock_format}) do not agree! block size is {self.format.block_size} ' + f'and sub-block size is {self.format.subblock_size}. Block size should' + f' be {self.format.subblocks_per_block} times the sub-block size plus ' + 'padding (usually 8 bytes).' + ), + ) def test_validate_subblock_format(self): """Verify that the subblock format is logical""" - self.assertEqual(self.format.subblock_size / self.format.particle_size, - self.format.particles_per_subblock, - msg=('The sub_block format ({sub_block}) and particle format ' - '({particle}) do not agree! sub-block size is ' - '{sub_block_size} and particle record size is ' - '{particle_size}. Sub-block size should be ' - '{particles_per_subblock} times the particle record size.' - .format(sub_block=self.format.subblock_format, - particle=self.format.particle_format, - sub_block_size=self.format.subblock_size, - particle_size=self.format.particle_size, - particles_per_subblock=self.format.particles_per_subblock))) + self.assertEqual( + self.format.subblock_size / self.format.particle_size, + self.format.particles_per_subblock, + msg=( + f'The sub_block format ({self.format.subblock_format}) and particle format ' + f'({self.format.particle_format}) do not agree! sub-block size is ' + f'{self.format.subblock_size} and particle record size is ' + f'{self.format.particle_size}. Sub-block size should be ' + f'{self.format.particles_per_subblock} times the particle record size.' + ), + ) def test_validate_particle_format(self): """Verify that the particle format is correct""" - self.assertEqual(self.format.particle_format, '7f', - msg=('The particle format ({particle}) is incorrect.' - .format(particle=self.format.particle_format))) + self.assertEqual( + self.format.particle_format, + '7f', + msg=(f'The particle format ({self.format.particle_format}) is incorrect.'), + ) class CorsikaBlocksThinTests(CorsikaBlocksTests): @@ -67,24 +65,25 @@ def setUp(self): def test_validate_particle_format(self): """Verify that the particle format is correct""" - self.assertEqual(self.format.particle_format, '8f', - msg=('The thinned particle format ({particle}) is incorrect.' - .format(particle=self.format.particle_format))) + self.assertEqual( + self.format.particle_format, + '8f', + msg=(f'The thinned particle format ({self.format.particle_format}) is incorrect.'), + ) class ParticleDataTests(unittest.TestCase): - def setUp(self): # Input - id = 1000 - p_x = 2. # GeV - p_y = 1. # GeV - p_z = 10. # GeV - x = 300. # cm - y = 400. # cm - t = 12345678. 
# ns + particle_id = 1000 + p_x = 2.0 # GeV + p_y = 1.0 # GeV + p_z = 10.0 # GeV + x = 300.0 # cm + y = 400.0 # cm + t = 12345678.0 # ns - self.subblock = (id, p_x, p_y, p_z, x, y, t) + self.subblock = (particle_id, p_x, p_y, p_z, x, y, t) # Output p_x *= 1e9 # eV @@ -92,35 +91,33 @@ def setUp(self): p_z *= 1e9 x *= 1e-2 # m y *= 1e-2 - r = sqrt(x ** 2 + y ** 2) + r = sqrt(x**2 + y**2) phi = atan2(x, -y) - self.result = (p_x, p_y, -p_z, -y, x, t, id / 1000, r, id / 10 % 100, - id % 10, phi) + self.result = (p_x, p_y, -p_z, -y, x, t, particle_id / 1000, r, particle_id / 10 % 100, particle_id % 10, phi) def test_particle_data(self): """Verify conversion of particle information by particle_data()""" self.assertAlmostEqual(blocks.particle_data(self.subblock), self.result) - @unittest.skipUnless(numba_available, "Numba required") + @unittest.skipUnless(numba_available, 'Numba required') def test_numba_jit(self): """Verify particle_data() with numba JIT disabled""" self.assertTrue(hasattr(blocks.particle_data, '__numba__')) - old_value = getattr(numba.config, 'DISABLE_JIT') - setattr(numba.config, 'DISABLE_JIT', 1) + old_value = numba.config.DISABLE_JIT + numba.config.DISABLE_JIT = 1 self.assertAlmostEqual(blocks.particle_data(self.subblock), self.result) - setattr(numba.config, 'DISABLE_JIT', old_value) + numba.config.DISABLE_JIT = old_value class ParticleDataThinTests(ParticleDataTests): - def setUp(self): super().setUp() # Input - weight = 9. + weight = 9.0 self.subblock = self.subblock + (weight,) # Output @@ -131,12 +128,12 @@ def test_particle_data(self): self.assertAlmostEqual(blocks.particle_data_thin(self.subblock), self.result) - @unittest.skipUnless(numba_available, "Numba required") + @unittest.skipUnless(numba_available, 'Numba required') def test_numba_jit(self): """Verify particle_data() with numba JIT disabled""" self.assertTrue(hasattr(blocks.particle_data_thin, '__numba__')) - old_value = getattr(numba.config, 'DISABLE_JIT') - setattr(numba.config, 'DISABLE_JIT', 1) + old_value = numba.config.DISABLE_JIT + numba.config.DISABLE_JIT = 1 self.assertAlmostEqual(blocks.particle_data_thin(self.subblock), self.result) - setattr(numba.config, 'DISABLE_JIT', old_value) + numba.config.DISABLE_JIT = old_value diff --git a/sapphire/tests/corsika/test_corsika.py b/sapphire/tests/corsika/test_corsika.py index c6c8761f..d2316899 100644 --- a/sapphire/tests/corsika/test_corsika.py +++ b/sapphire/tests/corsika/test_corsika.py @@ -12,9 +12,7 @@ class CorsikaFileTests(unittest.TestCase): def setUp(self): self.file = corsika.reader.CorsikaFile(DATA_FILE) - - def tearDown(self): - self.file.finish() + self.addCleanup(self.file.finish) def test_validate_file(self): """Verify that the data file is valid""" @@ -28,7 +26,7 @@ def test_run_header(self): self.assertIsInstance(header, corsika.blocks.RunHeader) self.assertEqual(header.id, b'RUNH') self.assertAlmostEqual(header.version, 7.4, 4) - for h in [10., 5000., 30000., 50000., 110000.]: + for h in [10.0, 5000.0, 30000.0, 50000.0, 110000.0]: t = header.height_to_thickness(h) self.assertAlmostEqual(header.thickness_to_height(t), h, 8) @@ -58,7 +56,7 @@ def test_event_header(self): self.assertEqual(header.id, b'EVTH') self.assertEqual(corsika.particles.name(header.particle_id), 'proton') self.assertEqual(header.energy, 1e14) - self.assertEqual(header.azimuth, -pi / 2.) 
+ self.assertEqual(header.azimuth, -pi / 2.0) self.assertEqual(header.zenith, 0.0) self.assertEqual(header.hadron_model_high, 'QGSJET') diff --git a/sapphire/tests/corsika/test_corsika_queries.py b/sapphire/tests/corsika/test_corsika_queries.py index 49671e6c..22e804be 100644 --- a/sapphire/tests/corsika/test_corsika_queries.py +++ b/sapphire/tests/corsika/test_corsika_queries.py @@ -11,12 +11,9 @@ class CorsikaQueryTest(unittest.TestCase): - def setUp(self): self.cq = corsika_queries.CorsikaQuery(self.get_overview_path()) - - def tearDown(self): - self.cq.finish() + self.addCleanup(self.cq.finish) def test_seeds(self): result = self.cq.seeds(self.cq.all_simulations()) @@ -34,7 +31,7 @@ def test_get_info(self): def test_all_energies(self): energies = list(self.cq.all_energies) - assert_allclose(energies, [14.]) + assert_allclose(energies, [14.0]) def test_all_particles(self): particles = self.cq.all_particles @@ -42,19 +39,19 @@ def test_all_particles(self): def test_all_azimuths(self): azimuths = self.cq.all_azimuths - self.assertEqual(azimuths, {-90.}) + self.assertEqual(azimuths, {-90.0}) def test_all_zeniths(self): zeniths = self.cq.all_zeniths - self.assertEqual(zeniths, {0.}) + self.assertEqual(zeniths, {0.0}) def test_available_parameters(self): result = list(self.cq.available_parameters('energy', particle='proton')) assert_allclose(result, [14.0]) - result = self.cq.available_parameters('particle_id', zenith=0.) + result = self.cq.available_parameters('particle_id', zenith=0.0) self.assertEqual(result, {'proton'}) - result = self.cq.available_parameters('zenith', azimuth=-90.) - self.assertEqual(result, {0.}) + result = self.cq.available_parameters('zenith', azimuth=-90.0) + self.assertEqual(result, {0.0}) self.assertRaises(RuntimeError, self.cq.available_parameters, 'zenith', energy=19) self.assertRaises(RuntimeError, self.cq.available_parameters, 'zenith', particle='iron') @@ -64,7 +61,6 @@ def get_overview_path(self): class MockCorsikaQueryTest(unittest.TestCase): - @patch.object(corsika_queries.tables, 'open_file') def setUp(self, mock_open): self.mock_open = mock_open @@ -75,8 +71,7 @@ def setUp(self, mock_open): def test_init(self): self.mock_open.assert_called_once_with(self.data_path, 'r') - self.mock_open.return_value.get_node.assert_called_once_with( - sentinel.simulations_group) + self.mock_open.return_value.get_node.assert_called_once_with(sentinel.simulations_group) @patch.object(corsika_queries.tables, 'open_file') def test_init_file(self, mock_open): @@ -98,32 +93,33 @@ def test_simulations(self, mock_perform): self.cq.all_particles = ['electron'] self.cq.all_energies = [15.5] - result = self.cq.simulations(particle='electron', energy=15.5, - zenith=0., azimuth=0.) 
+ result = self.cq.simulations(particle='electron', energy=15.5, zenith=0.0, azimuth=0.0) self.assertEqual(result, sentinel.simulations) mock_perform.assert_called_with( '(particle_id == 3) & ' '(abs(log10(energy) - 15.5) < 1e-4) & ' '(abs(zenith - 0.0) < 1e-4) & ' - '(abs(azimuth - 0.0) < 1e-4)', False) + '(abs(azimuth - 0.0) < 1e-4)', + False, + ) def test_filter(self): - filter = self.cq.filter('foo', 123) - self.assertEqual(filter, '(foo == 123)') + tables_filter = self.cq.filter('foo', 123) + self.assertEqual(tables_filter, '(foo == 123)') def test_float_filter(self): - filter = self.cq.float_filter('foo', 12.3) - self.assertEqual(filter, '(abs(foo - 12.3) < 1e-4)') + tables_filter = self.cq.float_filter('foo', 12.3) + self.assertEqual(tables_filter, '(abs(foo - 12.3) < 1e-4)') def test_range_filter(self): - filter = self.cq.range_filter('foo', 12.3, 14.5) - self.assertEqual(filter, '(foo >= 12.3) & (foo <= 14.5)') - filter = self.cq.range_filter('foo', 12.3) - self.assertEqual(filter, '(foo >= 12.3)') - filter = self.cq.range_filter('foo', max=14.5) - self.assertEqual(filter, '(foo <= 14.5)') - filter = self.cq.range_filter('foo') - self.assertEqual(filter, '') + tables_filter = self.cq.range_filter('foo', 12.3, 14.5) + self.assertEqual(tables_filter, '(foo >= 12.3) & (foo <= 14.5)') + tables_filter = self.cq.range_filter('foo', 12.3) + self.assertEqual(tables_filter, '(foo >= 12.3)') + tables_filter = self.cq.range_filter('foo', max_value=14.5) + self.assertEqual(tables_filter, '(foo <= 14.5)') + tables_filter = self.cq.range_filter('foo') + self.assertEqual(tables_filter, '') def test_all_simulations(self): result = self.cq.all_simulations() diff --git a/sapphire/tests/corsika/test_generate_corsika_overview.py b/sapphire/tests/corsika/test_generate_corsika_overview.py index a9a5a4ce..cacdb220 100644 --- a/sapphire/tests/corsika/test_generate_corsika_overview.py +++ b/sapphire/tests/corsika/test_generate_corsika_overview.py @@ -11,18 +11,14 @@ class GenerateCorsikaOverviewTests(unittest.TestCase): - def setUp(self): self.source_path = self.get_testdata_path() self.expected_path = self.get_expected_path() self.destination_path = self.create_tempfile_path() - - def tearDown(self): - os.remove(self.destination_path) + self.addCleanup(os.remove, self.destination_path) def test_store_data(self): - generate_corsika_overview(source=self.source_path, - destination=self.destination_path) + generate_corsika_overview(source=self.source_path, destination=self.destination_path) validate_results(self, self.expected_path, self.destination_path) def create_tempfile_path(self): diff --git a/sapphire/tests/corsika/test_particles.py b/sapphire/tests/corsika/test_particles.py index 02d1b9ba..29812f32 100644 --- a/sapphire/tests/corsika/test_particles.py +++ b/sapphire/tests/corsika/test_particles.py @@ -4,48 +4,46 @@ class CorsikaParticlesTests(unittest.TestCase): - def setUp(self): - self.pid_name = [(1, 'gamma'), - (2, 'positron'), - (3, 'electron'), - (5, 'muon_p'), - (6, 'muon_m'), - (13, 'neutron'), - (14, 'proton'), - (201, 'deuteron'), - (301, 'tritium'), - (302, 'helium3'), - (402, 'alpha'), - (1206, 'carbon'), - (1407, 'nitrogen'), - (1608, 'oxygen'), - (2713, 'aluminium'), - (2814, 'silicon'), - (3216, 'sulfur'), - (5626, 'iron')] - self.massless_atoms = [(909, 'fluorine'), - (3232, 'germanium'), - (9999, 'einsteinium')] - self.atoms = [(1406, 'carbon14'), - (9999, 'einsteinium99')] + self.pid_name = [ + (1, 'gamma'), + (2, 'positron'), + (3, 'electron'), + (5, 'muon_p'), + (6, 'muon_m'), + 
(13, 'neutron'), + (14, 'proton'), + (201, 'deuteron'), + (301, 'tritium'), + (302, 'helium3'), + (402, 'alpha'), + (1206, 'carbon'), + (1407, 'nitrogen'), + (1608, 'oxygen'), + (2713, 'aluminium'), + (2814, 'silicon'), + (3216, 'sulfur'), + (5626, 'iron'), + ] + self.massless_atoms = [(909, 'fluorine'), (3232, 'germanium'), (9999, 'einsteinium')] + self.atoms = [(1406, 'carbon14'), (9999, 'einsteinium99')] def test_particle_ids(self): """Verify that the correct names belong to each ID""" - for id, name in self.pid_name: - self.assertEqual(particles.ID[id], name) + for particle_id, name in self.pid_name: + self.assertEqual(particles.ID[particle_id], name) def test_conversion_functions(self): """Verify that the functions correctly convert back and forth""" - for id, name in self.pid_name: - self.assertEqual(particles.name(id), name) - self.assertEqual(particles.particle_id(name), id) + for particle_id, name in self.pid_name: + self.assertEqual(particles.name(particle_id), name) + self.assertEqual(particles.particle_id(name), particle_id) - for id, name in self.massless_atoms: - self.assertEqual(particles.particle_id(name), id) + for particle_id, name in self.massless_atoms: + self.assertEqual(particles.particle_id(name), particle_id) - for id, name in self.atoms: - self.assertEqual(particles.name(id), name) - self.assertEqual(particles.particle_id(name), id) + for particle_id, name in self.atoms: + self.assertEqual(particles.name(particle_id), name) + self.assertEqual(particles.particle_id(name), particle_id) diff --git a/sapphire/tests/corsika/test_qsub_corsika.py b/sapphire/tests/corsika/test_qsub_corsika.py index 8adbe60f..a5a43058 100644 --- a/sapphire/tests/corsika/test_qsub_corsika.py +++ b/sapphire/tests/corsika/test_qsub_corsika.py @@ -8,16 +8,14 @@ class CorsikaBatchTest(unittest.TestCase): - def setUp(self): self.cb = qsub_corsika.CorsikaBatch() @patch.object(qsub_corsika.particles, 'particle_id') def test_init(self, mock_particles): mock_particles.return_value = sentinel.particle_id - cb = qsub_corsika.CorsikaBatch(16, sentinel.particle, sentinel.zenith, - 30, sentinel.queue, sentinel.corsika) - self.assertEqual(cb.energy_pre, 1.) 
+ cb = qsub_corsika.CorsikaBatch(16, sentinel.particle, sentinel.zenith, 30, sentinel.queue, sentinel.corsika) + self.assertEqual(cb.energy_pre, 1.0) self.assertEqual(cb.energy_pow, 7) mock_particles.assert_called_once_with(sentinel.particle) self.assertEqual(cb.particle, sentinel.particle_id) @@ -53,14 +51,16 @@ def test_submit_job(self, mock_create_script, mock_rundir, mock_submit_job): mock_rundir.return_value = '/data/123_456/' mock_create_script.return_value = sentinel.script self.cb.submit_job() - mock_submit_job.assert_called_once_with(sentinel.script, 'cor_123_456', - 'generic', '-d /data/123_456/') + mock_submit_job.assert_called_once_with(sentinel.script, 'cor_123_456', 'generic', '-d /data/123_456/') # Check addition of walltime argument for long queue self.cb.queue = 'long' self.cb.submit_job() - mock_submit_job.assert_called_with(sentinel.script, 'cor_123_456', - 'long', - '-d /data/123_456/ -l walltime=96:00:00') + mock_submit_job.assert_called_with( + sentinel.script, + 'cor_123_456', + 'long', + '-d /data/123_456/ -l walltime=96:00:00', + ) @patch.object(qsub_corsika.os, 'listdir') def test_taken_seeds(self, mock_listdir): @@ -116,16 +116,23 @@ def test_create_input(self, mock_rundir): class MultipleJobsTest(unittest.TestCase): - @patch.object(qsub_corsika.qsub, 'check_queue') def test_no_available_slots(self, mock_check_queue): """No slots available on queue""" mock_check_queue.return_value = 0 - self.assertRaises(Exception, qsub_corsika.multiple_jobs, sentinel.n, - sentinel.energy, sentinel.particle, sentinel.zenith, - sentinel.azimuth, sentinel.queue, sentinel.corsika, - progress=False) + self.assertRaises( + Exception, + qsub_corsika.multiple_jobs, + sentinel.n, + sentinel.energy, + sentinel.particle, + sentinel.zenith, + sentinel.azimuth, + sentinel.queue, + sentinel.corsika, + progress=False, + ) mock_check_queue.assert_called_once_with(sentinel.queue) @patch.object(qsub_corsika, 'CorsikaBatch') @@ -135,15 +142,25 @@ def test_one_available_wanted_more(self, mock_check_queue, mock_corsika_batch): mock_check_queue.return_value = 1 with warnings.catch_warnings(record=True) as warned: - qsub_corsika.multiple_jobs(2, sentinel.energy, sentinel.particle, - sentinel.zenith, sentinel.azimuth, - sentinel.queue, sentinel.corsika, - progress=False) + qsub_corsika.multiple_jobs( + 2, + sentinel.energy, + sentinel.particle, + sentinel.zenith, + sentinel.azimuth, + sentinel.queue, + sentinel.corsika, + progress=False, + ) mock_check_queue.assert_called_once_with(sentinel.queue) mock_corsika_batch.assert_called_once_with( - energy=sentinel.energy, particle=sentinel.particle, - zenith=sentinel.zenith, azimuth=sentinel.azimuth, - queue=sentinel.queue, corsika=sentinel.corsika) + energy=sentinel.energy, + particle=sentinel.particle, + zenith=sentinel.zenith, + azimuth=sentinel.azimuth, + queue=sentinel.queue, + corsika=sentinel.corsika, + ) mock_corsika_batch.return_value.run.assert_called_once_with() self.assertEqual(len(warned), 1) @@ -154,15 +171,25 @@ def test_two_available_wanted_more(self, mock_check_queue, mock_corsika_batch): mock_check_queue.return_value = 2 with warnings.catch_warnings(record=True) as warned: - qsub_corsika.multiple_jobs(3, sentinel.energy, sentinel.particle, - sentinel.zenith, sentinel.azimuth, - sentinel.queue, sentinel.corsika, - progress=False) + qsub_corsika.multiple_jobs( + 3, + sentinel.energy, + sentinel.particle, + sentinel.zenith, + sentinel.azimuth, + sentinel.queue, + sentinel.corsika, + progress=False, + ) 
mock_check_queue.assert_called_once_with(sentinel.queue) mock_corsika_batch.assert_called_with( - energy=sentinel.energy, particle=sentinel.particle, - zenith=sentinel.zenith, azimuth=sentinel.azimuth, - queue=sentinel.queue, corsika=sentinel.corsika) + energy=sentinel.energy, + particle=sentinel.particle, + zenith=sentinel.zenith, + azimuth=sentinel.azimuth, + queue=sentinel.queue, + corsika=sentinel.corsika, + ) mock_corsika_batch.return_value.run.assert_called_with() # This is twice as often because it includes the calls to run() self.assertEqual(len(mock_corsika_batch.mock_calls), 4) @@ -177,15 +204,25 @@ def test_plenty_available(self, mock_check_queue, mock_corsika_batch): mock_check_queue.return_value = 50 n = 10 with warnings.catch_warnings(record=True) as warned: - qsub_corsika.multiple_jobs(n, sentinel.energy, sentinel.particle, - sentinel.zenith, sentinel.azimuth, - sentinel.queue, sentinel.corsika, - progress=False) + qsub_corsika.multiple_jobs( + n, + sentinel.energy, + sentinel.particle, + sentinel.zenith, + sentinel.azimuth, + sentinel.queue, + sentinel.corsika, + progress=False, + ) mock_check_queue.assert_called_once_with(sentinel.queue) mock_corsika_batch.assert_called_with( - energy=sentinel.energy, particle=sentinel.particle, - zenith=sentinel.zenith, azimuth=sentinel.azimuth, - queue=sentinel.queue, corsika=sentinel.corsika) + energy=sentinel.energy, + particle=sentinel.particle, + zenith=sentinel.zenith, + azimuth=sentinel.azimuth, + queue=sentinel.queue, + corsika=sentinel.corsika, + ) mock_corsika_batch.return_value.run.assert_called_with() # This is twice as often because it includes the calls to run() self.assertEqual(len(mock_corsika_batch.mock_calls), n * 2) diff --git a/sapphire/tests/corsika/test_qsub_store_corsika_data.py b/sapphire/tests/corsika/test_qsub_store_corsika_data.py index 43fade0f..28834705 100644 --- a/sapphire/tests/corsika/test_qsub_store_corsika_data.py +++ b/sapphire/tests/corsika/test_qsub_store_corsika_data.py @@ -7,7 +7,6 @@ class SeedsTest(unittest.TestCase): - @patch.object(qsub_store_corsika_data.glob, 'glob') def test_all_seeds(self, mock_glob): mock_glob.return_value = ['/data/123_456', '/data/234_567'] @@ -31,7 +30,7 @@ def test_seeds_in_queue(self): # Empty set if log not available with patch.object(builtins, 'open', mock_open()) as mock_file: - mock_file.side_effect = IOError('no log!') + mock_file.side_effect = OSError('no log!') seeds = qsub_store_corsika_data.seeds_in_queue() mock_file.assert_called_with(qsub_store_corsika_data.QUEUED_SEEDS) self.assertEqual(seeds, set()) @@ -67,9 +66,12 @@ def test_store_command(self): tmp = qsub_store_corsika_data.DATADIR qsub_store_corsika_data.DATADIR = '/data' command = qsub_store_corsika_data.store_command('123_456') - self.assertEqual(command, '/data/hisparc/env/miniconda/envs/corsika/bin/python ' - '/data/hisparc/env/miniconda/envs/corsika/bin/store_corsika_data ' - '/data/123_456/DAT000000 /data/123_456/corsika.h5') + self.assertEqual( + command, + '/data/hisparc/env/miniconda/envs/corsika/bin/python ' + '/data/hisparc/env/miniconda/envs/corsika/bin/store_corsika_data ' + '/data/123_456/DAT000000 /data/123_456/corsika.h5', + ) qsub_store_corsika_data.DATADIR = tmp @patch.object(qsub_store_corsika_data.os.path, 'getsize') @@ -80,8 +82,17 @@ def test_store_command(self): @patch.object(qsub_store_corsika_data.qsub, 'submit_job') @patch.object(qsub_store_corsika_data, 'append_queued_seeds') @patch.object(qsub_store_corsika_data, 'SCRIPT_TEMPLATE') - def test_run(self, mock_template, 
mock_append, mock_submit, mock_store, - mock_check, mock_get_seeds, mock_umask, mock_size): + def test_run( + self, + mock_template, + mock_append, + mock_submit, + mock_store, + mock_check, + mock_get_seeds, + mock_umask, + mock_size, + ): seeds = {'123_456', '234_567'} mock_size.return_value = 12355 mock_get_seeds.return_value = seeds.copy() @@ -92,6 +103,5 @@ def test_run(self, mock_template, mock_append, mock_submit, mock_store, for seed in seeds: mock_submit.assert_any_call(sentinel.script, seed, sentinel.queue, '') mock_append.assert_any_call([seed]) - mock_template.format.assert_called_with(command=sentinel.command, - datadir=qsub_store_corsika_data.DATADIR) + mock_template.format.assert_called_with(command=sentinel.command, datadir=qsub_store_corsika_data.DATADIR) mock_umask.assert_called_once_with(0o02) diff --git a/sapphire/tests/corsika/test_store_corsika_data.py b/sapphire/tests/corsika/test_store_corsika_data.py index 2151f7dd..7770929b 100644 --- a/sapphire/tests/corsika/test_store_corsika_data.py +++ b/sapphire/tests/corsika/test_store_corsika_data.py @@ -14,26 +14,27 @@ class StoreCorsikaDataTests(unittest.TestCase): - """Store CORSIKA test using the function directly""" def setUp(self): self.source_path = self.get_testdata_path() self.expected_path = self.get_expected_path() self.destination_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.destination_path) self.thin = False - def tearDown(self): - os.remove(self.destination_path) - def test_store_data(self): # First with overwrite false - self.assertRaises(Exception, store_and_sort_corsika_data, - self.source_path, self.destination_path, - progress=True, thin=self.thin) + self.assertRaises( + Exception, + store_and_sort_corsika_data, + self.source_path, + self.destination_path, + progress=True, + thin=self.thin, + ) # Now with overwrite true - store_and_sort_corsika_data(self.source_path, self.destination_path, - overwrite=True, thin=self.thin) + store_and_sort_corsika_data(self.source_path, self.destination_path, overwrite=True, thin=self.thin) validate_results(self, self.expected_path, self.destination_path) def create_tempfile_path(self): @@ -51,7 +52,6 @@ def get_expected_path(self): class StoreThinCorsikaDataTests(StoreCorsikaDataTests): - """Store thinned CORSIKA test using the function directly""" def setUp(self): diff --git a/sapphire/tests/corsika/test_units.py b/sapphire/tests/corsika/test_units.py index ac98d205..154f477c 100644 --- a/sapphire/tests/corsika/test_units.py +++ b/sapphire/tests/corsika/test_units.py @@ -5,20 +5,19 @@ class CorsikaUnitsTests(unittest.TestCase): - def test_base_units(self): """Verify that the correct units are one""" - self.assertEqual(units.meter, 1.) + self.assertEqual(units.meter, 1.0) self.assertEqual(units.m, units.meter) - self.assertEqual(units.nanosecond, 1.) + self.assertEqual(units.nanosecond, 1.0) self.assertEqual(units.ns, units.nanosecond) - self.assertEqual(units.electronvolt, 1.) + self.assertEqual(units.electronvolt, 1.0) self.assertEqual(units.eV, units.electronvolt) - self.assertEqual(units.radian, 1.) + self.assertEqual(units.radian, 1.0) self.assertEqual(units.rad, units.radian) - self.assertEqual(units.eplus, 1.) - self.assertEqual(units.volt, 1.) 
+ self.assertEqual(units.eplus, 1.0) + self.assertEqual(units.volt, 1.0) self.assertEqual(units.volt, units.electronvolt / units.eplus) def test_corsika_units(self): @@ -33,13 +32,13 @@ def test_corsika_units(self): self.assertEqual(units.second, units.giga * units.ns) self.assertEqual(units.s, units.second) self.assertEqual(units.EeV, units.exa * units.eV) - self.assertEqual(units.degree, (math.pi / 180.) * units.rad) + self.assertEqual(units.degree, (math.pi / 180.0) * units.rad) self.assertEqual(units.joule, units.eV / units.eSI) self.assertEqual(units.joule, units.eV / units.eSI) - self.assertEqual(units.gram, units.peta * units.joule * units.ns ** 2 / units.m ** 2) + self.assertEqual(units.gram, units.peta * units.joule * units.ns**2 / units.m**2) self.assertEqual(units.g, units.gram) - self.assertEqual(units.tesla, units.giga * units.volt * units.ns / units.m ** 2) + self.assertEqual(units.tesla, units.giga * units.volt * units.ns / units.m**2) def test_prefixes(self): """Verify the values of the prefixes""" diff --git a/sapphire/tests/create_and_store_test_data.py b/sapphire/tests/create_and_store_test_data.py old mode 100755 new mode 100644 index e4700441..45b35d20 --- a/sapphire/tests/create_and_store_test_data.py +++ b/sapphire/tests/create_and_store_test_data.py @@ -4,7 +4,7 @@ def main(): - descr = "Download and update test data for usage in tests." - parser = argparse.ArgumentParser(description=descr) + description = 'Download and update test data for usage in tests.' + parser = argparse.ArgumentParser(description=description) parser.parse_args() create_and_store_test_data() diff --git a/sapphire/tests/data/__init__.py b/sapphire/tests/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sapphire/tests/data/test_extend_local_data.py b/sapphire/tests/data/test_extend_local_data.py index 801c29a1..469270bf 100644 --- a/sapphire/tests/data/test_extend_local_data.py +++ b/sapphire/tests/data/test_extend_local_data.py @@ -6,7 +6,6 @@ class UpdateLocalDataTests(unittest.TestCase): - @patch.object(extend_local_data, 'Network') @patch.object(extend_local_data, 'update_sublevel_tsv') def test_update_local_json(self, mock_sub, mock_net): diff --git a/sapphire/tests/data/test_update_local_data.py b/sapphire/tests/data/test_update_local_data.py index ddccaa70..4ab2dba5 100644 --- a/sapphire/tests/data/test_update_local_data.py +++ b/sapphire/tests/data/test_update_local_data.py @@ -11,7 +11,6 @@ def fake_pbar(*args, **kwargs): class UpdateLocalDataTests(unittest.TestCase): - @patch.object(update_local_data, 'pbar', side_effect=fake_pbar) @patch.object(builtins, 'print') @patch.object(update_local_data, 'update_sublevel_json') @@ -25,17 +24,15 @@ def test_update_local_json(self, mock_top, mock_sub, mock_print, mock_pbar): self.assertTrue(mock_print.called) self.assertTrue(mock_pbar.called) - @patch.object(update_local_data, 'pbar', side_effect=fake_pbar) @patch.object(builtins, 'print') @patch.object(update_local_data, 'Network') @patch.object(update_local_data, 'HiSPARCNetwork') @patch.object(update_local_data, 'update_subsublevel_tsv') @patch.object(update_local_data, 'update_sublevel_tsv') - def test_update_local_tsv(self, mock_sub, mock_ssub, mock_hnet, mock_net, mock_print, mock_pbar): + def test_update_local_tsv(self, mock_sub, mock_ssub, mock_hnet, mock_net, mock_print): update_local_data.update_local_tsv(progress=False) self.assertTrue(mock_sub.called) self.assertTrue(mock_ssub.called) self.assertFalse(mock_print.called) 
update_local_data.update_local_tsv(progress=True) self.assertTrue(mock_print.called) - self.assertTrue(mock_pbar.called) diff --git a/sapphire/tests/esd_load_data.py b/sapphire/tests/esd_load_data.py index f51e99ef..f2e8cd58 100644 --- a/sapphire/tests/esd_load_data.py +++ b/sapphire/tests/esd_load_data.py @@ -2,22 +2,23 @@ import os import tempfile +from pathlib import Path from urllib.request import urlretrieve import tables from sapphire import esd -self_path = os.path.dirname(__file__) +self_path = Path(__file__).parent -test_data_path = os.path.join(self_path, 'test_data/esd_load_data.h5') -test_data_coincidences_path = os.path.join(self_path, 'test_data/esd_coincidence_data.h5') +test_data_path = self_path / 'test_data/esd_load_data.h5' +test_data_coincidences_path = self_path / 'test_data/esd_coincidence_data.h5' -events_source = os.path.join(self_path, 'test_data/events-s501-20120101.tsv') -weather_source = os.path.join(self_path, 'test_data/weather-s501-20120101.tsv') -singles_source = os.path.join(self_path, 'test_data/singles-s501-20170101.tsv') -lightning_source = os.path.join(self_path, 'test_data/lightning-knmi-20150717.tsv') -coincidences_source = os.path.join(self_path, 'test_data/coincidences-20160310.tsv') +events_source = self_path / 'test_data/events-s501-20120101.tsv' +weather_source = self_path / 'test_data/weather-s501-20120101.tsv' +singles_source = self_path / 'test_data/singles-s501-20170101.tsv' +lightning_source = self_path / 'test_data/lightning-knmi-20150717.tsv' +coincidences_source = self_path / 'test_data/coincidences-20160310.tsv' def create_tempfile_path(): @@ -81,16 +82,26 @@ def create_and_store_test_data(): perform_esd_download_data(test_data_path) perform_download_coincidences(test_data_coincidences_path) - urlretrieve(esd.get_weather_url().format(station_number=501, query='start=2012-01-01&end=2012-01-01+00:01:00'), - weather_source) - urlretrieve(esd.get_events_url().format(station_number=501, query='start=2012-01-01&end=2012-01-01+00:01:00'), - events_source) - urlretrieve(esd.get_singles_url().format(station_number=501, query='start=2017-01-01&end=2017-01-01+00:10:00'), - singles_source) - urlretrieve(esd.get_lightning_url().format(lightning_type=4, query='start=2015-07-17&end=2015-07-17+00:10:00'), - lightning_source) - urlretrieve(esd.get_coincidences_url().format(query='start=2016-03-10&end=2016-03-10+00:01:00&stations=501,+510&n=2'), - coincidences_source) + urlretrieve( + esd.get_weather_url().format(station_number=501, query='start=2012-01-01&end=2012-01-01+00:01:00'), + weather_source, + ) + urlretrieve( + esd.get_events_url().format(station_number=501, query='start=2012-01-01&end=2012-01-01+00:01:00'), + events_source, + ) + urlretrieve( + esd.get_singles_url().format(station_number=501, query='start=2017-01-01&end=2017-01-01+00:10:00'), + singles_source, + ) + urlretrieve( + esd.get_lightning_url().format(lightning_type=4, query='start=2015-07-17&end=2015-07-17+00:10:00'), + lightning_source, + ) + urlretrieve( + esd.get_coincidences_url().format(query='start=2016-03-10&end=2016-03-10+00:01:00&stations=501,+510&n=2'), + coincidences_source, + ) if __name__ == '__main__': diff --git a/sapphire/tests/simulations/perform_simulation.py b/sapphire/tests/simulations/perform_simulation.py index 3676eb8c..d904bff5 100644 --- a/sapphire/tests/simulations/perform_simulation.py +++ b/sapphire/tests/simulations/perform_simulation.py @@ -30,8 +30,7 @@ def perform_groundparticlessimulation(filename, mock_time): cluster = 
sapphire.clusters.SimpleCluster(size=40) filters = tables.Filters(complevel=1) with tables.open_file(filename, 'w', filters=filters) as data: - sim = GroundParticlesSimulation(corsika_data_path, 70, cluster, - data, n=10, seed=1, progress=False) + sim = GroundParticlesSimulation(corsika_data_path, 70, cluster, data, n=10, seed=1, progress=False) sim.run() @@ -45,8 +44,7 @@ def perform_groundparticlesgammasimulation(filename, mock_time): cluster = sapphire.clusters.SimpleCluster(size=40) filters = tables.Filters(complevel=1) with tables.open_file(filename, 'w', filters=filters) as data: - sim = GroundParticlesGammaSimulation(corsika_data_path, 70, cluster, - data, n=10, seed=42, progress=False) + sim = GroundParticlesGammaSimulation(corsika_data_path, 70, cluster, data, n=10, seed=42, progress=False) sim.run() @@ -56,8 +54,7 @@ def perform_flatfrontsimulation(filename): cluster = sapphire.clusters.SimpleCluster(size=40) filters = tables.Filters(complevel=1) with tables.open_file(filename, 'w', filters=filters) as data: - sim = FlatFrontSimulation(cluster, data, '/', 10, seed=1, - progress=False) + sim = FlatFrontSimulation(cluster, data, '/', 10, seed=1, progress=False) sim.run() @@ -67,8 +64,7 @@ def perform_nkgldfsimulation(filename): cluster = sapphire.clusters.SimpleCluster(size=40) filters = tables.Filters(complevel=1) with tables.open_file(filename, 'w', filters=filters) as data: - sim = NkgLdfSimulation(400, 1e15, 1e19, cluster, data, '/', 10, - seed=1, progress=False) + sim = NkgLdfSimulation(400, 1e15, 1e19, cluster, data, '/', 10, seed=1, progress=False) sim.run() diff --git a/sapphire/tests/simulations/test_base_simulation.py b/sapphire/tests/simulations/test_base_simulation.py index 1b203e73..9c47c947 100644 --- a/sapphire/tests/simulations/test_base_simulation.py +++ b/sapphire/tests/simulations/test_base_simulation.py @@ -11,7 +11,6 @@ class BaseSimulationTest(unittest.TestCase): - @patch.object(BaseSimulation, '_prepare_output_tables') def setUp(self, mock_method): self.mock_prepare_output_tables = mock_method @@ -20,9 +19,7 @@ def setUp(self, mock_method): self.output_path = sentinel.output_path self.n = sentinel.n - self.simulation = BaseSimulation(self.cluster, self.data, - self.output_path, self.n, - progress=False) + self.simulation = BaseSimulation(self.cluster, self.data, self.output_path, self.n, progress=False) def test_init_sets_attributes(self): self.assertIs(self.simulation.cluster, self.cluster) @@ -36,8 +33,7 @@ def test_init_calls_prepare_output_tables(self): @patch.object(BaseSimulation, '_prepare_coincidence_tables') @patch.object(BaseSimulation, '_prepare_station_tables') @patch.object(BaseSimulation, '_store_station_index') - def test_prepare_output_tables_calls(self, mock_method3, mock_method2, - mock_method1): + def test_prepare_output_tables_calls(self, mock_method3, mock_method2, mock_method1): self.simulation._prepare_output_tables() mock_method1.assert_called_once_with() mock_method2.assert_called_once_with() @@ -58,8 +54,7 @@ def test_run(self, mock_store, mock_simulate, mock_generate): # test store_coincidence called 2nd time with shower_id 1, # parameters and events - mock_store.assert_called_with(1, sentinel.params2, - sentinel.events) + mock_store.assert_called_with(1, sentinel.params2, sentinel.events) def test_generate_shower_parameters(self): self.simulation.n = 10 @@ -69,29 +64,33 @@ def test_generate_shower_parameters(self): output = list(output) self.assertEqual(len(output), 10) - expected = {'core_pos': (None, None), 'zenith': None, 
'azimuth': None, - 'size': None, 'energy': None, 'ext_timestamp': None} + expected = { + 'core_pos': (None, None), + 'zenith': None, + 'azimuth': None, + 'size': None, + 'energy': None, + 'ext_timestamp': None, + } self.assertEqual(output[0], expected) @patch.object(BaseSimulation, 'simulate_station_response') @patch.object(BaseSimulation, 'store_station_observables') def test_simulate_events_for_shower(self, mock_store, mock_simulate): self.simulation.cluster = Mock() - self.simulation.cluster.stations = [sentinel.station1, - sentinel.station2, - sentinel.station3] + self.simulation.cluster.stations = [sentinel.station1, sentinel.station2, sentinel.station3] - mock_simulate.side_effect = [(True, sentinel.obs1), (False, None), - (True, sentinel.obs3)] + mock_simulate.side_effect = [(True, sentinel.obs1), (False, None), (True, sentinel.obs3)] mock_store.side_effect = [sentinel.index1, sentinel.index2] - events = self.simulation.simulate_events_for_shower( - sentinel.params) + events = self.simulation.simulate_events_for_shower(sentinel.params) # test simulate_station_response called for each station, with # shower parameters - expected = [call(sentinel.station1, sentinel.params), - call(sentinel.station2, sentinel.params), - call(sentinel.station3, sentinel.params)] + expected = [ + call(sentinel.station1, sentinel.params), + call(sentinel.station2, sentinel.params), + call(sentinel.station3, sentinel.params), + ] self.assertEqual(mock_simulate.call_args_list, expected) # test store_station_observables called only for triggered @@ -101,15 +100,13 @@ def test_simulate_events_for_shower(self, mock_store, mock_simulate): # test returned events consists of list of station indexes and # stored event indexes - self.assertEqual(events, [(0, sentinel.index1), - (2, sentinel.index2)]) + self.assertEqual(events, [(0, sentinel.index1), (2, sentinel.index2)]) @patch.object(BaseSimulation, 'simulate_all_detectors') @patch.object(BaseSimulation, 'simulate_trigger') @patch.object(BaseSimulation, 'process_detector_observables') @patch.object(BaseSimulation, 'simulate_gps') - def test_simulate_station_response(self, mock_gps, mock_process, - mock_trigger, mock_detectors): + def test_simulate_station_response(self, mock_gps, mock_process, mock_trigger, mock_detectors): mock_detectors.return_value = sentinel.detector_observables mock_trigger.return_value = sentinel.has_triggered mock_process.return_value = sentinel.station_observables @@ -118,40 +115,33 @@ def test_simulate_station_response(self, mock_gps, mock_process, mock_station = Mock() mock_station.detectors = sentinel.detectors - has_triggered, station_observables = \ - self.simulation.simulate_station_response(mock_station, - sentinel.parameters) + has_triggered, station_observables = self.simulation.simulate_station_response( + mock_station, + sentinel.parameters, + ) # Tests - mock_detectors.assert_called_once_with(sentinel.detectors, - sentinel.parameters) + mock_detectors.assert_called_once_with(sentinel.detectors, sentinel.parameters) mock_trigger.assert_called_once_with(sentinel.detector_observables) mock_process.assert_called_once_with(sentinel.detector_observables) - mock_gps.assert_called_once_with(sentinel.station_observables, - sentinel.parameters, - mock_station) + mock_gps.assert_called_once_with(sentinel.station_observables, sentinel.parameters, mock_station) self.assertIs(has_triggered, sentinel.has_triggered) self.assertIs(station_observables, sentinel.gps_observables) @patch.object(BaseSimulation, 'simulate_detector_response') def 
test_simulate_all_detectors(self, mock_response): detectors = [sentinel.detector1, sentinel.detector2] - mock_response.side_effect = [sentinel.observables1, - sentinel.observables2] + mock_response.side_effect = [sentinel.observables1, sentinel.observables2] - observables = self.simulation.simulate_all_detectors( - detectors, sentinel.parameters) + observables = self.simulation.simulate_all_detectors(detectors, sentinel.parameters) - expected = [call(sentinel.detector1, sentinel.parameters), - call(sentinel.detector2, sentinel.parameters)] + expected = [call(sentinel.detector1, sentinel.parameters), call(sentinel.detector2, sentinel.parameters)] self.assertEqual(mock_response.call_args_list, expected) - self.assertEqual(observables, [sentinel.observables1, - sentinel.observables2]) + self.assertEqual(observables, [sentinel.observables1, sentinel.observables2]) def test_simulate_detector_response(self): - observables = self.simulation.simulate_detector_response(Mock(), - Mock()) + observables = self.simulation.simulate_detector_response(Mock(), Mock()) self.assertIsInstance(observables, dict) self.assertIn('n', observables) self.assertIn('t', observables) @@ -172,17 +162,21 @@ def test_simulate_gps(self): self.assertIn('nanoseconds', gps_dict) def test_process_detector_observables(self): - detector_observables = [{'n': 1., 't': 2., 'pulseheights': 3., - 'integrals': 4.}, - {'n': 5., 't': 6., 'pulseheights': 7., - 'integrals': 8.}, - {'foo': -999.}] - - expected = {'n1': 1., 'n2': 5., 't1': 2., 't2': 6., - 'pulseheights': [3., 7., -1., -1.], - 'integrals': [4., 8., -1, -1]} - actual = self.simulation.process_detector_observables( - detector_observables) + detector_observables = [ + {'n': 1.0, 't': 2.0, 'pulseheights': 3.0, 'integrals': 4.0}, + {'n': 5.0, 't': 6.0, 'pulseheights': 7.0, 'integrals': 8.0}, + {'foo': -999.0}, + ] + + expected = { + 'n1': 1.0, + 'n2': 5.0, + 't1': 2.0, + 't2': 6.0, + 'pulseheights': [3.0, 7.0, -1.0, -1.0], + 'integrals': [4.0, 8.0, -1, -1], + } + actual = self.simulation.process_detector_observables(detector_observables) self.assertEqual(expected, actual) @@ -192,16 +186,14 @@ def test_store_station_observables(self): table = station_groups.__getitem__.return_value.events table.nrows = 123 - observables = {'key1': 1., 'key2': 2.} + observables = {'key1': 1.0, 'key2': 2.0} table.colnames = ['key1', 'key2'] - idx = self.simulation.store_station_observables( - sentinel.station_id, observables) + idx = self.simulation.store_station_observables(sentinel.station_id, observables) # tests station_groups.__getitem__.assert_called_once_with(sentinel.station_id) - calls = [call('event_id', table.nrows), call('key2', 2.), - call('key1', 1.)] + calls = [call('event_id', table.nrows), call('key2', 2.0), call('key1', 1.0)] station_groups.asser_has_calls(calls, any_order=True) table.row.append.assert_called_once_with() table.flush.assert_called_once_with() @@ -211,50 +203,53 @@ def test_store_station_observables_raises_warning(self): station_groups = MagicMock() self.simulation.station_groups = station_groups table = station_groups.__getitem__.return_value.events - observables = {'key1': 1., 'key2': 2.} + observables = {'key1': 1.0, 'key2': 2.0} table.colnames = ['key1'] with warnings.catch_warnings(record=True) as warned: warnings.simplefilter('always') - self.simulation.store_station_observables(sentinel.station_id, - observables) + self.simulation.store_station_observables(sentinel.station_id, observables) self.assertEqual(len(warned), 1) - @unittest.skip("WIP") + 
@unittest.skip('WIP') def test_store_coincidence(self, shower_id, shower_parameters, station_events): pass - @unittest.skip("WIP") + @unittest.skip('WIP') def test_prepare_coincidence_tables(self): pass - @unittest.skip("WIP") + @unittest.skip('WIP') def test_prepare_station_tables(self): pass - @unittest.skip("WIP") + @unittest.skip('WIP') def test_store_station_index(self): pass - @unittest.skip("Does not test this unit") + @unittest.skip('Does not test this unit') def test_init_creates_coincidences_output_group(self): - self.data.create_group.assert_any_call( - self.output_path, 'coincidences', createparents=True) + self.data.create_group.assert_any_call(self.output_path, 'coincidences', createparents=True) self.data.create_table.assert_called_with( - self.simulation.coincidence_group, 'coincidences', storage.Coincidence) + self.simulation.coincidence_group, + 'coincidences', + storage.Coincidence, + ) self.assertEqual(self.data.create_vlarray.call_count, 2) self.data.create_vlarray.assert_any_call( - self.simulation.coincidence_group, 'c_index', tables.UInt32Col(shape=2)) + self.simulation.coincidence_group, + 'c_index', + tables.UInt32Col(shape=2), + ) - @unittest.skip("Does not test this unit") + @unittest.skip('Does not test this unit') def test_init_creates_cluster_output_group(self): - self.data.create_group.assert_any_call( - self.output_path, 'cluster_simulations', createparents=True) + self.data.create_group.assert_any_call(self.output_path, 'cluster_simulations', createparents=True) # The following tests need a better mock of cluster in order to work. # self.data.create_group.assert_any_call(self.simulation.cluster_group, 'station_0') # self.data.create_table.assert_any_call( # station_group, 'events', storage.ProcessedHisparcEvent, expectedrows=self.n) - @unittest.skip("Does not test this unit") + @unittest.skip('Does not test this unit') def test_init_stores_cluster_in_attrs(self): self.assertIs(self.simulation.coincidence_group._v_attrs.cluster, self.cluster) diff --git a/sapphire/tests/simulations/test_detectors.py b/sapphire/tests/simulations/test_detectors.py index 317e3c04..ddb4fbe3 100644 --- a/sapphire/tests/simulations/test_detectors.py +++ b/sapphire/tests/simulations/test_detectors.py @@ -12,69 +12,68 @@ class HiSPARCSimulationTest(unittest.TestCase): - def setUp(self): self.simulation = HiSPARCSimulation random.seed(1) np.random.seed(1) def test_simulate_detector_offsets(self): - self.assertEqual(self.simulation.simulate_detector_offsets(1), - [4.49943665734718]) + self.assertEqual(self.simulation.simulate_detector_offsets(1), [4.49943665734718]) offsets = self.simulation.simulate_detector_offsets(10000) - assert_almost_equal(np.mean(offsets), 0., 1) + assert_almost_equal(np.mean(offsets), 0.0, 1) assert_almost_equal(np.std(offsets), 2.77, 2) def test_simulate_detector_offset(self): - self.assertEqual(self.simulation.simulate_detector_offset(), - 4.49943665734718) + self.assertEqual(self.simulation.simulate_detector_offset(), 4.49943665734718) def test_simulate_station_offset(self): - self.assertEqual(self.simulation.simulate_station_offset(), - 25.989525818611867) + self.assertEqual(self.simulation.simulate_station_offset(), 25.989525818611867) def test_simulate_gps_uncertainty(self): - self.assertEqual(self.simulation.simulate_gps_uncertainty(), - 7.3095541364845875) + self.assertEqual(self.simulation.simulate_gps_uncertainty(), 7.3095541364845875) def test_simulate_adc_sampling(self): self.assertEqual(self.simulation.simulate_adc_sampling(0), 0) 
self.assertEqual(self.simulation.simulate_adc_sampling(0.1), 2.5) self.assertEqual(self.simulation.simulate_adc_sampling(1.25), 2.5) self.assertEqual(self.simulation.simulate_adc_sampling(2.5), 2.5) - self.assertEqual(self.simulation.simulate_adc_sampling(4), 5.) + self.assertEqual(self.simulation.simulate_adc_sampling(4), 5.0) def test_simulate_signal_transport_time(self): - self.assertEqual(list(self.simulation.simulate_signal_transport_time()), - [3.6091128409407927]) - self.assertEqual(list(self.simulation.simulate_signal_transport_time(1)), - [5.0938877122170032]) - self.assertEqual(list(self.simulation.simulate_signal_transport_time(11)), - [2.5509743680305879, 3.2759504918578886, - 2.9027453686866318, 2.7722064380611307, - 2.9975103080633256, 3.3796483500672148, - 3.5099596226498524, 4.2053418869706736, - 3.6197480580293133, 4.9220361334622806, - 3.0411502792684506]) + self.assertEqual(list(self.simulation.simulate_signal_transport_time()), [3.6091128409407927]) + self.assertEqual(list(self.simulation.simulate_signal_transport_time(1)), [5.0938877122170032]) + self.assertEqual( + list(self.simulation.simulate_signal_transport_time(11)), + [ + 2.5509743680305879, + 3.2759504918578886, + 2.9027453686866318, + 2.7722064380611307, + 2.9975103080633256, + 3.3796483500672148, + 3.5099596226498524, + 4.2053418869706736, + 3.6197480580293133, + 4.9220361334622806, + 3.0411502792684506, + ], + ) def test_simulate_detector_mips(self): # Test with single angle - assert_almost_equal(self.simulation.simulate_detector_mips(1, 0.5), - 1.1818585) - assert_almost_equal(self.simulation.simulate_detector_mips(2, 0.2), - 1.8313342374) + assert_almost_equal(self.simulation.simulate_detector_mips(1, 0.5), 1.1818585) + assert_almost_equal(self.simulation.simulate_detector_mips(2, 0.2), 1.8313342374) # Test with multiple angles - assert_almost_equal(self.simulation.simulate_detector_mips(2, np.array([0.5, 1])), - 2.58167027) - assert_almost_equal(self.simulation.simulate_detector_mips(1, np.array([0.5, 1])), - 2.21526297) + assert_almost_equal(self.simulation.simulate_detector_mips(2, np.array([0.5, 1])), 2.58167027) + assert_almost_equal(self.simulation.simulate_detector_mips(1, np.array([0.5, 1])), 2.21526297) # Test limiting detector length - assert_almost_equal(self.simulation.simulate_detector_mips(1, np.radians(90)), - 47.6237460) - assert_almost_equal(self.simulation.simulate_detector_mips(2, np.array([np.radians(90), np.radians(87)])), - 74.6668728) + assert_almost_equal(self.simulation.simulate_detector_mips(1, np.radians(90)), 47.6237460) + assert_almost_equal( + self.simulation.simulate_detector_mips(2, np.array([np.radians(90), np.radians(87)])), + 74.6668728, + ) def test_generate_core_position(self): x, y = self.simulation.generate_core_position(500) @@ -82,18 +81,16 @@ def test_generate_core_position(self): assert_almost_equal(y, 317.2896993591305) def test_generate_azimuth(self): - self.assertEqual(self.simulation.generate_azimuth(), - -0.521366120872004) + self.assertAlmostEqual(self.simulation.generate_azimuth(), -0.521366120872004) def test_generate_energy(self): self.assertEqual(self.simulation.generate_energy(), 136117213526167.64) io = 1e17 - assert_almost_equal(self.simulation.generate_energy(io, io) / io, 1.) 
+ assert_almost_equal(self.simulation.generate_energy(io, io) / io, 1.0) self.assertEqual(self.simulation.generate_energy(alpha=-3), 100005719231473.97) class ErrorlessSimulationTest(HiSPARCSimulationTest): - def setUp(self): self.simulation = ErrorlessSimulation random.seed(1) @@ -117,7 +114,7 @@ def test_simulate_adc_sampling(self): self.assertEqual(self.simulation.simulate_adc_sampling(0.1), 0.1) self.assertEqual(self.simulation.simulate_adc_sampling(1.25), 1.25) self.assertEqual(self.simulation.simulate_adc_sampling(2.5), 2.5) - self.assertEqual(self.simulation.simulate_adc_sampling(4), 4.) + self.assertEqual(self.simulation.simulate_adc_sampling(4), 4.0) def test_simulate_signal_transport_time(self): self.assertEqual(list(self.simulation.simulate_signal_transport_time()), [0]) diff --git a/sapphire/tests/simulations/test_gammas.py b/sapphire/tests/simulations/test_gammas.py index a8c430eb..d3e1ccd0 100644 --- a/sapphire/tests/simulations/test_gammas.py +++ b/sapphire/tests/simulations/test_gammas.py @@ -8,7 +8,6 @@ class GammasTest(unittest.TestCase): - def test_compton_edge(self): # Compton edges of well known gammas sources # http://web.mit.edu/lululiu/Public/8.13/xray/TKA%20files/annihilation-Na.pdf @@ -21,9 +20,7 @@ def test_compton_edge(self): # Co-60 # 1.17 MeV: 0.96 MeV # 1.33 MeV: 1.12 MeV - combinations = ((0.511, 0.340), (1.27, 1.06), - (0.662, 0.482), - (1.17, 0.96), (1.33, 1.12)) + combinations = ((0.511, 0.340), (1.27, 1.06), (0.662, 0.482), (1.17, 0.96), (1.33, 1.12)) for E, edge in combinations: self.assertAlmostEqual(gammas.compton_edge(E), edge, places=2) @@ -31,7 +28,7 @@ def test_compton_edge(self): def test_compton_mean_free_path(self): # Relevant mean-free-paths in vinyltoluene scintillator # Values checked with: Jos Steijer, Nikhef internal note, 16 juni 2010, figure 3 - combinations = ((1., 32.), (10., 60.)) + combinations = ((1.0, 32.0), (10.0, 60.0)) for E, edge in combinations: self.assertAlmostEqual(gammas.compton_mean_free_path(E), edge, places=0) @@ -39,7 +36,7 @@ def test_compton_mean_free_path(self): def test_pair_mean_free_path(self): # Relevant mean-free-paths in vinyltoluene scintillator # Values checked with: Jos Steijer, Nikhef internal note, 16 juni 2010, figure 5 - combinations = ((10, 249.), (1000., 62.)) + combinations = ((10, 249.0), (1000.0, 62.0)) for E, edge in combinations: self.assertAlmostEqual(gammas.pair_mean_free_path(E), edge, places=0) @@ -48,14 +45,14 @@ def test_pair_mean_free_path(self): def test_compton_energy_transfer(self, mock_random): # if random() return 1, energy should approach the kinematic maximum (compton edge) mock_random.return_value = 1.0 - for gamma_energy in [3., 10., 100.]: + for gamma_energy in [3.0, 10.0, 100.0]: expected = gammas.compton_edge(gamma_energy) self.assertAlmostEqual(gammas.compton_energy_transfer(gamma_energy), expected, places=0) # if random() returns 0, energy should approach 0 mock_random.return_value = 0.0 - for gamma_energy in [3., 10., 100.]: - self.assertAlmostEqual(gammas.compton_energy_transfer(gamma_energy), 0.) 
+ for gamma_energy in [3.0, 10.0, 100.0]: + self.assertAlmostEqual(gammas.compton_energy_transfer(gamma_energy), 0.0) def test_energy_transfer_cross_section(self): # The plot from github.com/tomkooij/lio-project/photons/check_sapphire_gammas.py @@ -69,9 +66,9 @@ def test_energy_transfer_cross_section(self): self.assertAlmostEqual(gammas.energy_transfer_cross_section(E, edge) / barn, cross_section, places=1) def test_max_energy_transfer(self): - self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(0., 1.), gammas.MAX_E / gammas.MIP) - self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(0.5, 1.), 0.5 * gammas.MAX_E / gammas.MIP) - self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(1., 1.), 0.) + self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(0.0, 1.0), gammas.MAX_E / gammas.MIP) + self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(0.5, 1.0), 0.5 * gammas.MAX_E / gammas.MIP) + self.assertAlmostEqual(gammas.max_energy_deposit_in_mips(1.0, 1.0), 0.0) @patch.object(gammas, 'compton_energy_transfer') @patch.object(gammas, 'pair_mean_free_path') @@ -81,12 +78,12 @@ def test_simulate_detector_mips_gammas_compton(self, mock_l_compton, mock_l_pair mock_l_compton.return_value = 1e-3 mock_l_pair.return_value = 1e50 - mock_compton.return_value = 1. + mock_compton.return_value = 1.0 p = np.array([10]) - theta = np.array([0.]) + theta = np.array([0.0]) mips = gammas.simulate_detector_mips_gammas(p, theta) - mock_compton.assert_called_once_with(10. / 1e6) + mock_compton.assert_called_once_with(10.0 / 1e6) self.assertLessEqual(mips, gammas.MAX_E) @patch.object(gammas, 'compton_energy_transfer') @@ -97,10 +94,10 @@ def test_simulate_detector_mips_gammas_pair(self, mock_l_compton, mock_l_pair, m mock_l_compton.return_value = 1e50 mock_l_pair.return_value = 1e-3 - mock_compton.return_value = 42. - energies = np.array([10., 7.]) # MeV + mock_compton.return_value = 42.0 + energies = np.array([10.0, 7.0]) # MeV p = energies * 1e6 # eV - theta = np.array([0.]) + theta = np.array([0.0]) for _ in range(100): mips = gammas.simulate_detector_mips_gammas(p, theta) @@ -110,14 +107,14 @@ def test_simulate_detector_mips_gammas_pair(self, mock_l_compton, mock_l_pair, m # not enough energy for pair production energies = np.array([0.5, 0.7]) # MeV p = energies * 1e6 # eV - theta = np.array([0., 0.]) + theta = np.array([0.0, 0.0]) for _ in range(100): self.assertEqual(gammas.simulate_detector_mips_gammas(p, theta), 0) @patch('sapphire.simulations.gammas.expovariate') def test_simulate_detector_mips_no_interaction(self, mock_expovariate): p = np.array([10e6]) - theta = np.array([0.]) + theta = np.array([0.0]) # force no interaction mock_expovariate.side_effect = [1e6, 1e3] @@ -139,5 +136,5 @@ def test_simulate_detector_mips_no_interaction(self, mock_expovariate): n = 30 mock_expovariate.side_effect = [4, 5] * n p = np.array([10e6] * n) - theta = np.array([1.] 
* n) # projected depth would be 126 cm + theta = np.array([1.0] * n) # projected depth would be 126 cm self.assertEqual(gammas.simulate_detector_mips_gammas(p, theta), 0) diff --git a/sapphire/tests/simulations/test_groundparticles.py b/sapphire/tests/simulations/test_groundparticles.py index a66e6d00..bbbb2e52 100644 --- a/sapphire/tests/simulations/test_groundparticles.py +++ b/sapphire/tests/simulations/test_groundparticles.py @@ -14,31 +14,27 @@ class GroundParticlesSimulationTest(unittest.TestCase): - def setUp(self): - - self.simulation = groundparticles.GroundParticlesSimulation.__new__( - groundparticles.GroundParticlesSimulation) + self.simulation = groundparticles.GroundParticlesSimulation.__new__(groundparticles.GroundParticlesSimulation) corsika_data_path = os.path.join(self_path, 'test_data/corsika.h5') self.corsika_data = tables.open_file(corsika_data_path, 'r') self.simulation.corsikafile = self.corsika_data + self.addCleanup(self.corsika_data.close) self.simulation.cluster = SingleDiamondStation() self.detectors = self.simulation.cluster.stations[0].detectors - def tearDown(self): - self.corsika_data.close() - def test__prepare_cluster_for_shower(self): - # Combinations of shower parameters and detector after transformations - combinations = (((0, 0, 0), (-0, -0, -0)), - ((10, -60, 0), (-10, 60, -0)), - ((10, -60, pi / 2), (60, 10, -pi / 2))) - - for input, expected in combinations: - self.simulation._prepare_cluster_for_shower(*input) + combinations = ( + ((0, 0, 0), (-0, -0, -0)), + ((10, -60, 0), (-10, 60, -0)), + ((10, -60, pi / 2), (60, 10, -pi / 2)), + ) + + for args, expected in combinations: + self.simulation._prepare_cluster_for_shower(*args) self.assertAlmostEqual(self.simulation.cluster.x, expected[0]) self.assertAlmostEqual(self.simulation.cluster.y, expected[1]) self.assertAlmostEqual(self.simulation.cluster.alpha, expected[2]) @@ -49,19 +45,17 @@ def test_get_particles_query_string(self): # Combinations of shower parameters and detector after transformations shower_parameters = {'zenith': 0} self.simulation.corsika_azimuth = 0 - combinations = ((0, 0, 0), - (10, -60, 0), - (10, -60, pi / 2)) + combinations = ((0, 0, 0), (10, -60, 0), (10, -60, pi / 2)) - for input in combinations: - self.simulation._prepare_cluster_for_shower(*input) + for args in combinations: + self.simulation._prepare_cluster_for_shower(*args) self.simulation.get_particles_in_detector(self.detectors[0], shower_parameters) x, y = self.detectors[0].get_xy_coordinates() - size = sqrt(.5) / 2. 
+ size = sqrt(0.5) / 2.0 self.simulation.groundparticles.read_where.assert_called_with( '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & ' - '(particle_id >= 2) & (particle_id <= 6)' % - (x - size, x + size, y - size, y + size)) + '(particle_id >= 2) & (particle_id <= 6)' % (x - size, x + size, y - size, y + size), + ) def test_get_particles(self): self.groundparticles = self.corsika_data.root.groundparticles @@ -69,44 +63,46 @@ def test_get_particles(self): shower_parameters = {'zenith': 0} self.simulation.corsika_azimuth = 0 - combinations = (((0, 0, 0), (1, 0, 0, 0)), - ((1, -1, 0), (0, 1, 0, 3)), - ((1, -1, pi / 2), (1, 1, 0, 1))) - - for input, expected in combinations: - self.simulation._prepare_cluster_for_shower(*input) + combinations = ( + ((0, 0, 0), (1, 0, 0, 0)), + ((1, -1, 0), (0, 1, 0, 3)), + ((1, -1, pi / 2), (1, 1, 0, 1)), + ) + + for args, expected in combinations: + self.simulation._prepare_cluster_for_shower(*args) for d, e in zip(self.detectors, expected): self.assertEqual(len(self.simulation.get_particles_in_detector(d, shower_parameters)), e) class GroundParticlesGammaSimulationTest(unittest.TestCase): - def setUp(self): self.simulation = groundparticles.GroundParticlesGammaSimulation.__new__( - groundparticles.GroundParticlesGammaSimulation) + groundparticles.GroundParticlesGammaSimulation, + ) corsika_data_path = os.path.join(self_path, 'test_data/corsika.h5') self.corsika_data = tables.open_file(corsika_data_path, 'r') + self.addCleanup(self.corsika_data.close) self.simulation.corsikafile = self.corsika_data self.simulation.cluster = SingleDiamondStation() self.detectors = self.simulation.cluster.stations[0].detectors - def tearDown(self): - self.corsika_data.close() - def test_get_particles(self): self.groundparticles = self.corsika_data.root.groundparticles self.simulation.groundparticles = self.groundparticles shower_parameters = {'zenith': 0} self.simulation.corsika_azimuth = 0 - combinations = (((0, 0, 0), (1, 0, 0, 0), (5, 0, 2, 4)), - ((1, -1, 0), (0, 1, 0, 3), (1, 1, 4, 8)), - ((1, -1, pi / 2), (1, 1, 0, 1), (1, 3, 6, 1))) - - for input, n1, n2 in combinations: - self.simulation._prepare_cluster_for_shower(*input) + combinations = ( + ((0, 0, 0), (1, 0, 0, 0), (5, 0, 2, 4)), + ((1, -1, 0), (0, 1, 0, 3), (1, 1, 4, 8)), + ((1, -1, pi / 2), (1, 1, 0, 1), (1, 3, 6, 1)), + ) + + for args, n1, n2 in combinations: + self.simulation._prepare_cluster_for_shower(*args) for d, n_lep, n_gam in zip(self.detectors, n1, n2): lep, gamma = self.simulation.get_particles_in_detector(d, shower_parameters) self.assertEqual(len(lep), n_lep) @@ -114,10 +110,8 @@ def test_get_particles(self): class DetectorBoundarySimulationTest(GroundParticlesSimulationTest): - def setUp(self): - self.simulation = groundparticles.DetectorBoundarySimulation.__new__( - groundparticles.DetectorBoundarySimulation) + self.simulation = groundparticles.DetectorBoundarySimulation.__new__(groundparticles.DetectorBoundarySimulation) corsika_data_path = os.path.join(self_path, 'test_data/corsika.h5') self.corsika_data = tables.open_file(corsika_data_path, 'r') @@ -132,11 +126,10 @@ def test_get_particles_query_string(self): # Combinations of shower parameters and detector after transformations shower_parameters = {'zenith': 0} self.simulation.corsika_azimuth = 0 - combinations = ((0, 0, 0), - (10, -60, 0)) + combinations = ((0, 0, 0), (10, -60, 0)) - for input in combinations: - self.simulation._prepare_cluster_for_shower(*input) + for args in combinations: + 
self.simulation._prepare_cluster_for_shower(*args) self.simulation.get_particles_in_detector(self.detectors[0], shower_parameters) x, y = self.detectors[0].get_xy_coordinates() size = 0.6 @@ -144,8 +137,8 @@ def test_get_particles_query_string(self): '(x >= %f) & (x <= %f) & (y >= %f) & (y <= %f) & ' '(b11 < y - 0.000000 * x) & (y - 0.000000 * x < b12) & ' '(b21 < x) & (x < b22) & ' - '(particle_id >= 2) & (particle_id <= 6)' % - (x - size, x + size, y - size, y + size)) + '(particle_id >= 2) & (particle_id <= 6)' % (x - size, x + size, y - size, y + size), + ) def test_get_particles(self): self.groundparticles = self.corsika_data.root.groundparticles @@ -153,37 +146,40 @@ def test_get_particles(self): shower_parameters = {'zenith': 0} self.simulation.corsika_azimuth = 0 - combinations = (((0, 0, 0), (1, 0, 1, 0)), - ((1, -1, 0), (0, 1, 1, 3)), - ((1, -1, pi / 2), (1, 1, 0, 1))) - - for input, expected in combinations: - self.simulation._prepare_cluster_for_shower(*input) + combinations = ( + ((0, 0, 0), (1, 0, 1, 0)), + ((1, -1, 0), (0, 1, 1, 3)), + ((1, -1, pi / 2), (1, 1, 0, 1)), + ) + + for args, expected in combinations: + self.simulation._prepare_cluster_for_shower(*args) for d, e in zip(self.detectors, expected): self.assertEqual(len(self.simulation.get_particles_in_detector(d, shower_parameters)), e) def test_get_line_boundary_eqs(self): - combos = ((((0, 0), (1, 1), (0, 2)), (0.0, 'y - 1.000000 * x', 2.0)), - (((0, 0), (0, 1), (1, 2)), (0.0, 'x', 1))) + combinations = ( + (((0, 0), (1, 1), (0, 2)), (0.0, 'y - 1.000000 * x', 2.0)), + (((0, 0), (0, 1), (1, 2)), (0.0, 'x', 1)), + ) - for input, expected in combos: - result = self.simulation.get_line_boundary_eqs(*input) + for args, expected in combinations: + result = self.simulation.get_line_boundary_eqs(*args) self.assertEqual(result, expected) class FixedCoreDistanceSimulationTest(unittest.TestCase): - def test_fixed_core_distance(self): r = random.uniform(1e-15, 4000, size=300) x, y = groundparticles.FixedCoreDistanceSimulation.generate_core_position(r) - testing.assert_allclose(sqrt(x ** 2 + y ** 2), r, 1e-11) + testing.assert_allclose(sqrt(x**2 + y**2), r, 1e-11) class MultipleGroundParticlesSimulationTest(unittest.TestCase): - def setUp(self): self.simulation = groundparticles.MultipleGroundParticlesSimulation.__new__( - groundparticles.MultipleGroundParticlesSimulation) + groundparticles.MultipleGroundParticlesSimulation, + ) self.simulation.cq = Mock() self.simulation.max_core_distance = sentinel.max_core_distance @@ -201,18 +197,16 @@ def test_generate_shower_parameters(self): self.simulation.select_simulation.return_value = None shower_parameters = self.simulation.generate_shower_parameters() self.assertRaises(StopIteration, shower_parameters.__next__) - self.assertEqual(self.simulation.select_simulation.call_count, - self.simulation.n) + self.assertEqual(self.simulation.select_simulation.call_count, self.simulation.n) def test_select_simulation(self): self.simulation.generate_zenith = lambda: 0.27 # 15.5 deg - self.simulation.generate_energy = lambda e_min, e_max: 10 ** 16.4 + self.simulation.generate_energy = lambda e_min, e_max: 10**16.4 self.simulation.available_energies = set(arange(12, 18, 0.5)) - self.simulation.available_zeniths = {e: set(arange(0, 60, 7.5)) - for e in self.simulation.available_energies} + self.simulation.available_zeniths = {e: set(arange(0, 60, 7.5)) for e in self.simulation.available_energies} self.simulation.cq.simulations.return_value = [sentinel.sim] result = 
self.simulation.select_simulation() - self.simulation.cq.simulations.assert_called_once_with(energy=16.5, zenith=15.) + self.simulation.cq.simulations.assert_called_once_with(energy=16.5, zenith=15.0) self.assertEqual(result, sentinel.sim) self.simulation.cq.simulations.return_value = [] diff --git a/sapphire/tests/simulations/test_ldf.py b/sapphire/tests/simulations/test_ldf.py index 04df5dcb..62aa3262 100644 --- a/sapphire/tests/simulations/test_ldf.py +++ b/sapphire/tests/simulations/test_ldf.py @@ -7,7 +7,6 @@ class BaseLdfSimulationTest(unittest.TestCase): - def setUp(self): self.simulation = ldf.BaseLdfSimulation random.seed(1) @@ -21,7 +20,6 @@ def test_simulate_particles_for_density(self): class BaseLdfSimulationWithoutErrorsTest(BaseLdfSimulationTest): - def setUp(self): super().setUp() self.simulation = ldf.BaseLdfSimulationWithoutErrors @@ -34,7 +32,6 @@ def test_simulate_particles_for_density(self): class BaseLdfTest(unittest.TestCase): - def setUp(self): self.ldf = ldf.BaseLdf() @@ -48,7 +45,7 @@ def test_calculate_ldf_value(self): def test_calculate_core_distance(self): # TODO: Add core distances for inclined showers - self.assertEqual(self.ldf.calculate_core_distance(0., 0., 0., 0., 0., 0.), 0) - self.assertEqual(self.ldf.calculate_core_distance(10., 0., 0., 0., 0., 0.), 10.) - self.assertEqual(self.ldf.calculate_core_distance(10., 0., 10., 0., 0., 0.), 0.) - self.assertEqual(self.ldf.calculate_core_distance(10., 3., 10., 3., 0., 0.), 0.) + self.assertEqual(self.ldf.calculate_core_distance(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 0) + self.assertEqual(self.ldf.calculate_core_distance(10.0, 0.0, 0.0, 0.0, 0.0, 0.0), 10.0) + self.assertEqual(self.ldf.calculate_core_distance(10.0, 0.0, 10.0, 0.0, 0.0, 0.0), 0.0) + self.assertEqual(self.ldf.calculate_core_distance(10.0, 3.0, 10.0, 3.0, 0.0, 0.0), 0.0) diff --git a/sapphire/tests/simulations/test_simulation_acceptance.py b/sapphire/tests/simulations/test_simulation_acceptance.py index 3e711362..82cb148f 100644 --- a/sapphire/tests/simulations/test_simulation_acceptance.py +++ b/sapphire/tests/simulations/test_simulation_acceptance.py @@ -17,44 +17,40 @@ class GroundparticlesSimulationAcceptanceTest(unittest.TestCase): - def test_simulation_output(self): """Perform a simulation and verify the output""" output_path = create_tempfile_path() + self.addCleanup(os.remove, output_path) perform_groundparticlessimulation(output_path) validate_results(self, test_data_path, output_path) - os.remove(output_path) class GroundparticlesGammaSimulationAcceptanceTest(unittest.TestCase): - def test_simulation_output(self): """Perform a simulation and verify the output""" output_path = create_tempfile_path() + self.addCleanup(os.remove, output_path) perform_groundparticlesgammasimulation(output_path) validate_results(self, test_data_gamma, output_path) - os.remove(output_path) class FlatFrontSimulationAcceptanceTest(unittest.TestCase): - def test_simulation_output(self): """Perform a simulation and verify the output""" output_path = create_tempfile_path() + self.addCleanup(os.remove, output_path) perform_flatfrontsimulation(output_path) validate_results(self, test_data_flat, output_path) - os.remove(output_path) class NkgLdfSimulationAcceptanceTest(unittest.TestCase): - def test_simulation_output(self): """Perform a simulation and verify the output""" output_path = create_tempfile_path() + self.addCleanup(os.remove, output_path) perform_nkgldfsimulation(output_path) validate_results(self, test_data_nkg, output_path) - os.remove(output_path) diff --git 
a/sapphire/tests/test_api.py b/sapphire/tests/test_api.py index a76c3996..29c07d11 100644 --- a/sapphire/tests/test_api.py +++ b/sapphire/tests/test_api.py @@ -54,20 +54,18 @@ def test__retrieve_url(self, mock_urlopen): def test__get_tsv(self, mock_urlopen): mock_urlopen.return_value.read.return_value = b'1297956608\t52.3414237\t4.8807081\t43.32' self.api.force_fresh = True - self.assertEqual(self.api._get_tsv('gps/2/').tolist(), - [(1297956608, 52.3414237, 4.8807081, 43.32)]) + self.assertEqual(self.api._get_tsv('gps/2/').tolist(), [(1297956608, 52.3414237, 4.8807081, 43.32)]) mock_urlopen.return_value.read.side_effect = URLError('no interwebs!') self.assertRaises(Exception, self.api._get_tsv, 'gps/2/') self.api.force_fresh = False self.assertRaises(Exception, self.api._get_tsv, 'gps/0/') with warnings.catch_warnings(record=True) as warned: - self.assertEqual(self.api._get_tsv('gps/2/').tolist()[0], - (1297953008, 52.3414237, 4.8807081, 43.32)) + self.assertEqual(self.api._get_tsv('gps/2/').tolist()[0], (1297953008, 52.3414237, 4.8807081, 43.32)) self.assertEqual(len(warned), 1) -@unittest.skipUnless(api.API.check_connection(), "Internet connection required") +@unittest.skipUnless(api.API.check_connection(), 'Internet connection required') class APITestsLive(unittest.TestCase): def setUp(self): self.api = api.API() @@ -81,7 +79,7 @@ def test__get_json(self): self.assertIsInstance(json, dict) -@unittest.skipUnless(api.API.check_connection(), "Internet connection required") +@unittest.skipUnless(api.API.check_connection(), 'Internet connection required') class NetworkTests(unittest.TestCase): def setUp(self): self.network = api.Network(force_fresh=True, force_stale=False) @@ -91,29 +89,34 @@ def setUp(self): @patch.object(api.Network, 'clusters') @patch.object(api.Network, 'subclusters') @patch.object(api.Network, 'stations') - def test_nested_network(self, mock_stations, mock_subcluster, - mock_clusters, mock_countries): - mock_countries.return_value = [{'name': sentinel.country_name, - 'number': sentinel.country_number}] - mock_clusters.return_value = [{'name': sentinel.cluster_name, - 'number': sentinel.cluster_number}] - mock_subcluster.return_value = [{'name': sentinel.subcluster_name, - 'number': sentinel.subcluster_number}] - mock_stations.return_value = [{'name': sentinel.station_name, - 'number': sentinel.station_number}] + def test_nested_network(self, mock_stations, mock_subcluster, mock_clusters, mock_countries): + mock_countries.return_value = [{'name': sentinel.country_name, 'number': sentinel.country_number}] + mock_clusters.return_value = [{'name': sentinel.cluster_name, 'number': sentinel.cluster_number}] + mock_subcluster.return_value = [{'name': sentinel.subcluster_name, 'number': sentinel.subcluster_number}] + mock_stations.return_value = [{'name': sentinel.station_name, 'number': sentinel.station_number}] nested_network = self.network.nested_network() - self.assertEqual(nested_network, - [{'clusters': [ - {'subclusters': [ - {'stations': [ - {'name': sentinel.station_name, - 'number': sentinel.station_number}], - 'name': sentinel.subcluster_name, - 'number': sentinel.subcluster_number}], + self.assertEqual( + nested_network, + [ + { + 'clusters': [ + { + 'subclusters': [ + { + 'stations': [{'name': sentinel.station_name, 'number': sentinel.station_number}], + 'name': sentinel.subcluster_name, + 'number': sentinel.subcluster_number, + }, + ], 'name': sentinel.cluster_name, - 'number': sentinel.cluster_number}], - 'name': sentinel.country_name, - 'number': 
sentinel.country_number}]) + 'number': sentinel.cluster_number, + }, + ], + 'name': sentinel.country_name, + 'number': sentinel.country_number, + }, + ], + ) def test_lazy_countries(self): self.laziness_of_method('countries') @@ -152,43 +155,37 @@ def test_bad_subcluster(self): self.assertRaises(Exception, self.network.subclusters, cluster=bad_number) def test_country_numbers(self): - self.network._all_countries = [{'number': sentinel.number1}, - {'number': sentinel.number2}] - self.assertEqual(self.network.country_numbers(), - [sentinel.number1, sentinel.number2]) + self.network._all_countries = [{'number': sentinel.number1}, {'number': sentinel.number2}] + self.assertEqual(self.network.country_numbers(), [sentinel.number1, sentinel.number2]) @patch.object(api.Network, 'clusters') @patch.object(api.Network, 'validate_numbers') def test_cluster_numbers(self, mock_validate, mock_clusters): - mock_clusters.return_value = [{'number': sentinel.number1}, - {'number': sentinel.number2}] - self.assertEqual(self.network.cluster_numbers(sentinel.country), - [sentinel.number1, sentinel.number2]) + mock_clusters.return_value = [{'number': sentinel.number1}, {'number': sentinel.number2}] + self.assertEqual(self.network.cluster_numbers(sentinel.country), [sentinel.number1, sentinel.number2]) mock_clusters.assert_called_once_with(country=sentinel.country) @patch.object(api.Network, 'subclusters') @patch.object(api.Network, 'validate_numbers') def test_subcluster_numbers(self, mock_validate, mock_subclusters): - mock_subclusters.return_value = [{'number': sentinel.number1}, - {'number': sentinel.number2}] - self.assertEqual(self.network.subcluster_numbers(sentinel.country, - sentinel.cluster), - [sentinel.number1, sentinel.number2]) - mock_subclusters.assert_called_once_with(country=sentinel.country, - cluster=sentinel.cluster) + mock_subclusters.return_value = [{'number': sentinel.number1}, {'number': sentinel.number2}] + self.assertEqual( + self.network.subcluster_numbers(sentinel.country, sentinel.cluster), + [sentinel.number1, sentinel.number2], + ) + mock_subclusters.assert_called_once_with(country=sentinel.country, cluster=sentinel.cluster) @patch.object(api.Network, 'stations') @patch.object(api.Network, 'validate_numbers') def test_station_numbers(self, mock_validate, mock_stations): - mock_stations.return_value = [{'number': sentinel.number1}, - {'number': sentinel.number2}] - station_numbers = self.network.station_numbers(sentinel.country, - sentinel.cluster, - sentinel.subcluster) + mock_stations.return_value = [{'number': sentinel.number1}, {'number': sentinel.number2}] + station_numbers = self.network.station_numbers(sentinel.country, sentinel.cluster, sentinel.subcluster) self.assertEqual(station_numbers, [sentinel.number1, sentinel.number2]) - mock_stations.assert_called_once_with(country=sentinel.country, - cluster=sentinel.cluster, - subcluster=sentinel.subcluster) + mock_stations.assert_called_once_with( + country=sentinel.country, + cluster=sentinel.cluster, + subcluster=sentinel.subcluster, + ) @patch.object(api.Network, '_retrieve_url') def test_station_numbers_disconnected(self, mock_retrieve_url): @@ -259,36 +256,26 @@ def test_coincidence_number(self): def test_uptime(self, mock_urlopen): # datetime(2014,1,1) 2 days on, 2 days off, 1 day on sn = b'[{"name": "foo", "number": 501}, {"name": "bar", "number": 502}]' - event_time_1 = str.encode('1388534400\t2000.\n' - '1388538000\t2000.\n' - '1388541600\t12.\n' - '1388545200\t125.\n' - '1388548800\t3000.\n') + event_time_1 = 
str.encode( + '1388534400\t2000.\n1388538000\t2000.\n1388541600\t12.\n1388545200\t125.\n1388548800\t3000.\n', + ) # datetime(2014,1,1) 2 days off, 3 days on - event_time_2 = str.encode('1388534400\t50.\n' - '1388538000\t20.\n' - '1388541600\t2000.\n' - '1388545200\t2000.\n' - '1388548800\t3000.\n') + event_time_2 = str.encode( + '1388534400\t50.\n1388538000\t20.\n1388541600\t2000.\n1388545200\t2000.\n1388548800\t3000.\n', + ) # station 1 mock_urlopen.return_value.read.side_effect = [sn, event_time_1] * 4 self.assertEqual(self.network.uptime([501]), 3) - self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), - end=datetime(2014, 1, 1, 2)), 2) - self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), - end=datetime(2014, 1, 2)), 3) - self.assertEqual(self.network.uptime([501], start=datetime(2013, 1, 1), - end=datetime(2013, 1, 2)), 0) + self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), end=datetime(2014, 1, 1, 2)), 2) + self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), end=datetime(2014, 1, 2)), 3) + self.assertEqual(self.network.uptime([501], start=datetime(2013, 1, 1), end=datetime(2013, 1, 2)), 0) # station 2 mock_urlopen.return_value.read.side_effect = [sn, event_time_2] * 3 self.assertEqual(self.network.uptime([501]), 3) - self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), - end=datetime(2014, 1, 1, 2)), 0) - self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), - end=datetime(2014, 1, 2)), 3) + self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), end=datetime(2014, 1, 1, 2)), 0) + self.assertEqual(self.network.uptime([501], start=datetime(2014, 1, 1), end=datetime(2014, 1, 2)), 3) # two stations together - mock_urlopen.return_value.read.side_effect = [sn, event_time_1, - sn, event_time_2] + mock_urlopen.return_value.read.side_effect = [sn, event_time_1, sn, event_time_2] self.assertEqual(self.network.uptime([501, 502]), 1) def laziness_of_method(self, method): @@ -303,7 +290,6 @@ def laziness_of_method(self, method): class StaleNetworkTests(NetworkTests): - """Tests using local data Overwrite tests using data not available locally. 
@@ -331,16 +317,19 @@ def test_coincidence_time(self): def test_coincidence_number(self): self.assertRaises(Exception, self.network.coincidence_number, 2013, 1, 1) - @unittest.skipIf(has_extended_local_data('eventtime/%d/' % STATION), - "Local data is extended") + @unittest.skipIf(has_extended_local_data('eventtime/%d/' % STATION), 'Local data is extended') def test_uptime(self): self.assertRaises(Exception, self.network.uptime, [501]) - self.assertRaises(Exception, self.network.uptime, [501], - start=datetime(2014, 1, 1), - end=datetime(2014, 1, 1, 2)) + self.assertRaises( + Exception, + self.network.uptime, + [501], + start=datetime(2014, 1, 1), + end=datetime(2014, 1, 1, 2), + ) -@unittest.skipUnless(api.API.check_connection(), "Internet connection required") +@unittest.skipUnless(api.API.check_connection(), 'Internet connection required') class StationTests(unittest.TestCase): def setUp(self): self.station = api.Station(STATION, force_fresh=True, force_stale=False) @@ -355,7 +344,7 @@ def test_no_stale_station(self, mock_retrieve_url): def test_bad_station_number(self, mock_station_numbers): mock_station_numbers.return_value = [501, 502, 503] with warnings.catch_warnings(record=True) as warned: - warnings.simplefilter("always") + warnings.simplefilter('always') api.Station(1) self.assertEqual(len(warned), 1) @@ -370,8 +359,7 @@ def test_properties(self): def test_config(self): self.assertEqual(self.station.config()['detnum'], 501) - self.assertAlmostEqual(self.station.config( - date(2011, 1, 1))['mas_ch1_current'], 7.54901960784279) + self.assertAlmostEqual(self.station.config(date(2011, 1, 1))['mas_ch1_current'], 7.54901960784279) def test_num_events(self): self.assertIsInstance(self.station.n_events(2004), int) @@ -421,10 +409,11 @@ def test_has_weather_bad_args(self): @patch.object(api, 'urlopen') def test_event_trace(self, mock_urlopen): def make_trace(start, end): - """ return a trace (type bytes) to mock urlopen """ + """return a trace (type bytes) to mock urlopen""" trace = '[%s]' % ', '.join(str(v) for v in range(start, end)) return_value = '[%s]' % ', '.join(4 * [trace]) return return_value.encode() + mock_urlopen.return_value.read.return_value = make_trace(0, 11) self.assertEqual(self.station.event_trace(1378771205, 571920029)[3][9], 9) mock_urlopen.return_value.read.return_value = make_trace(200, 211) @@ -472,14 +461,12 @@ def test_voltage(self): data2 = self.station.voltages[0] data = self.station.voltage(0) # 1970-1-1 - self.assertEqual(data, [data2['voltage1'], data2['voltage2'], - data2['voltage3'], data2['voltage4']]) + self.assertEqual(data, [data2['voltage1'], data2['voltage2'], data2['voltage3'], data2['voltage4']]) data2 = self.station.voltages[-1] data1 = self.station.voltage(FUTURE) data = self.station.voltage() - self.assertEqual(data1, [data2['voltage1'], data2['voltage2'], - data2['voltage3'], data2['voltage4']]) + self.assertEqual(data1, [data2['voltage1'], data2['voltage2'], data2['voltage3'], data2['voltage4']]) self.assertEqual(data, data1) def test_laziness_currents(self): @@ -518,8 +505,21 @@ def test_laziness_station_layouts(self): self.laziness_of_attribute('station_layouts') def test_triggers(self): - names = ('timestamp', 'low1', 'low2', 'low3', 'low4', 'high1', 'high2', - 'high3', 'high4', 'n_low', 'n_high', 'and_or', 'external') + names = ( + 'timestamp', + 'low1', + 'low2', + 'low3', + 'low4', + 'high1', + 'high2', + 'high3', + 'high4', + 'n_low', + 'n_high', + 'and_or', + 'external', + ) data = self.station.triggers 
self.assertEqual(data.dtype.names, names) @@ -535,11 +535,25 @@ def test_laziness_triggers(self): self.laziness_of_attribute('triggers') def test_station_layouts(self): - names = ('timestamp', - 'radius1', 'alpha1', 'height1', 'beta1', - 'radius2', 'alpha2', 'height2', 'beta2', - 'radius3', 'alpha3', 'height3', 'beta3', - 'radius4', 'alpha4', 'height4', 'beta4') + names = ( + 'timestamp', + 'radius1', + 'alpha1', + 'height1', + 'beta1', + 'radius2', + 'alpha2', + 'height2', + 'beta2', + 'radius3', + 'alpha3', + 'height3', + 'beta3', + 'radius4', + 'alpha4', + 'height4', + 'beta4', + ) data = self.station.station_layouts self.assertEqual(data.dtype.names, names) @@ -602,7 +616,7 @@ def test_station_timing_offset(self): # Zero offset to self data = self.station.station_timing_offset(STATION) - self.assertEqual(data, (0., 0.)) + self.assertEqual(data, (0.0, 0.0)) def laziness_of_attribute(self, attribute): with patch.object(api.API, '_get_tsv') as mock_get_tsv: @@ -626,7 +640,6 @@ def laziness_of_method(self, method, args=None): class StaleStationTests(StationTests): - """Tests using local data Overwrite tests using data not available locally. @@ -658,7 +671,7 @@ def test_has_weather(self): @patch.object(api, 'urlopen') def test_event_trace(self, mock_urlopen): - trace = '[%s]' % ', '.join(str(v) for v in range(0, 11)) + trace = '[%s]' % ', '.join(str(v) for v in range(11)) mock_urlopen.return_value.read.return_value = '[%s]' % ', '.join(4 * [trace]) self.assertRaises(Exception, self.station.event_trace, 1378771205, 571920029) self.assertRaises(Exception, self.station.event_trace, 1378771205, 571920029, raw=True) diff --git a/sapphire/tests/test_clusters.py b/sapphire/tests/test_clusters.py index 7c55fe95..745ad9a0 100644 --- a/sapphire/tests/test_clusters.py +++ b/sapphire/tests/test_clusters.py @@ -15,16 +15,19 @@ def setUp(self): self.mock_station = Mock() self.detector_1 = clusters.Detector(self.mock_station, (1, 0, 0), 'LR') self.detector_2 = clusters.Detector(self.mock_station, (-1, 2, 1), 'UD') - self.detector_s = clusters.Detector(self.mock_station, - (sentinel.x, sentinel.y, sentinel.z), - sentinel.orientation) - self.detector_4d = clusters.Detector(station=self.mock_station, - position=([0, 5], [0, 5], [0, 5]), - detector_timestamps=[0, 5]) + self.detector_s = clusters.Detector( + self.mock_station, + (sentinel.x, sentinel.y, sentinel.z), + sentinel.orientation, + ) + self.detector_4d = clusters.Detector( + station=self.mock_station, + position=([0, 5], [0, 5], [0, 5]), + detector_timestamps=[0, 5], + ) def test_bad_arguments(self): - self.assertRaises(Exception, clusters.Detector, self.mock_station, - (1, 0, 0), 'LR', [0, 1]) + self.assertRaises(Exception, clusters.Detector, self.mock_station, (1, 0, 0), 'LR', [0, 1]) def test__update_timestamp(self): self.assertEqual(self.detector_4d.index, -1) @@ -39,9 +42,9 @@ def test_4d_positions(self): self.assertEqual(self.detector_4d.get_coordinates(), (0, 0, 0)) def test_detector_size(self): - self.assertEqual(self.detector_1.detector_size, (.5, 1.)) - self.assertEqual(self.detector_2.detector_size, (.5, 1.)) - self.assertEqual(self.detector_s.detector_size, (.5, 1.)) + self.assertEqual(self.detector_1.detector_size, (0.5, 1.0)) + self.assertEqual(self.detector_2.detector_size, (0.5, 1.0)) + self.assertEqual(self.detector_s.detector_size, (0.5, 1.0)) def test_get_area(self): self.assertEqual(self.detector_1.get_area(), 0.5) @@ -83,9 +86,9 @@ def test_get_cylindrical_coordinates(self): self.assertEqual(coordinates, (sqrt(5), atan2(2, -1), 
1)) def test_left_right_get_corners(self): - self.mock_station.get_coordinates.return_value = (.25, 3, 0, 0) + self.mock_station.get_coordinates.return_value = (0.25, 3, 0, 0) corners = self.detector_1.get_corners() - self.assertEqual(corners, [(.75, 3.25), (.75, 2.75), (1.75, 2.75), (1.75, 3.25)]) + self.assertEqual(corners, [(0.75, 3.25), (0.75, 2.75), (1.75, 2.75), (1.75, 3.25)]) def test_left_right_get_corners_rotated(self): self.mock_station.get_coordinates.return_value = (0, 0, 0, pi / 2) @@ -110,22 +113,34 @@ class StationTests(unittest.TestCase): def setUp(self): with patch('sapphire.clusters.Detector') as mock_detector: self.cluster = Mock() - self.station_1 = clusters.Station(self.cluster, 1, (0, 1, 2), pi / 4, - [((3, 4), 'LR')]) - self.station_s = clusters.Station(self.cluster, sentinel.id, - (sentinel.x, sentinel.y, sentinel.z), - sentinel.angle, [], - number=sentinel.number) - self.station_4d = clusters.Station(self.cluster, 4, - ([0, 5], [0, 5], [0, 5]), (0, pi), - station_timestamps=[0, 5]) + self.station_1 = clusters.Station(self.cluster, 1, (0, 1, 2), pi / 4, [((3, 4), 'LR')]) + self.station_s = clusters.Station( + self.cluster, + sentinel.id, + (sentinel.x, sentinel.y, sentinel.z), + sentinel.angle, + [], + number=sentinel.number, + ) + self.station_4d = clusters.Station( + self.cluster, + 4, + ([0, 5], [0, 5], [0, 5]), + (0, pi), + station_timestamps=[0, 5], + ) self.mock_detector_instance = mock_detector.return_value def test_bad_arguments(self): with patch('sapphire.clusters.Detector'): - self.assertRaises(Exception, clusters.Station, - cluster=self.cluster, station_id=1, - position=(0, 1, 2), station_timestamps=[1, 2]) + self.assertRaises( + Exception, + clusters.Station, + cluster=self.cluster, + station_id=1, + position=(0, 1, 2), + station_timestamps=[1, 2], + ) def test__update_timestamp(self): self.assertEqual(self.station_4d.index, -1) @@ -184,8 +199,7 @@ def test_get_coordinates(self): # Trivial cluster.get_coordinates.return_value = (0, 0, 0, 0) - station = clusters.Station(cluster, 1, position=(0, 0), angle=0, - detectors=[((0, 0), 'LR')]) + station = clusters.Station(cluster, 1, position=(0, 0), angle=0, detectors=[((0, 0), 'LR')]) coordinates = station.get_coordinates() self.assertEqual(coordinates, (0, 0, 0, 0)) coordinates = station.get_xyalpha_coordinates() @@ -243,8 +257,7 @@ def test_get_polar_alpha_coordinates(self): # Trivial cluster.get_coordinates.return_value = (0, 0, 0, 0) - station = clusters.Station(cluster, 1, position=(0, 0), angle=0, - detectors=[((0, 0), 'LR')]) + station = clusters.Station(cluster, 1, position=(0, 0), angle=0, detectors=[((0, 0), 'LR')]) coordinates = station.get_polar_alpha_coordinates() self.assertEqual(coordinates, (0, 0, 0)) @@ -269,19 +282,28 @@ def test_get_polar_alpha_coordinates(self): def test_calc_r_and_phi_for_detectors(self): cluster = Mock() cluster.get_coordinates.return_value = (0, 0, 0, 0) - station = clusters.Station(cluster, 1, position=(0, 0), angle=0, - detectors=[((0, 0), 'LR'), ((10., 10.), 'LR')]) + station = clusters.Station( + cluster, + 1, + position=(0, 0), + angle=0, + detectors=[((0, 0), 'LR'), ((10.0, 10.0), 'LR')], + ) r, phi = station.calc_r_and_phi_for_detectors(0, 1) - self.assertAlmostEqual(r ** 2, 10 ** 2 + 10 ** 2) + self.assertAlmostEqual(r**2, 10**2 + 10**2) self.assertAlmostEqual(phi, pi / 4) def test_calc_center_of_mass_coordinates(self): cluster = Mock() cluster.get_coordinates.return_value = (0, 0, 0, 0) - station = clusters.Station(cluster, 1, position=(0, 0), angle=0, - 
detectors=[((0, 0, 0), 'LR'), ((10, 9, 1), 'LR'), - ((nan, nan, nan), 'LR'), ((nan, nan, nan), 'LR')]) + station = clusters.Station( + cluster, + 1, + position=(0, 0), + angle=0, + detectors=[((0, 0, 0), 'LR'), ((10, 9, 1), 'LR'), ((nan, nan, nan), 'LR'), ((nan, nan, nan), 'LR')], + ) center = station.calc_xy_center_of_mass_coordinates() self.assert_tuple_almost_equal(center, (5, 4.5)) center = station.calc_center_of_mass_coordinates() @@ -291,7 +313,7 @@ def assert_tuple_almost_equal(self, actual, expected): self.assertIsInstance(actual, tuple) self.assertIsInstance(expected, tuple) - msg = f"Tuples differ: {str(actual)} != {str(expected)}" + msg = f'Tuples differ: {actual!s} != {expected!s}' for actual_value, expected_value in zip(actual, expected): self.assertAlmostEqual(actual_value, expected_value, msg=msg) @@ -309,8 +331,7 @@ def test_add_station(self): detector_list = Mock(name='detector_list') number = Mock(name='number') cluster._add_station((x, y, z), angle, detector_list, number=number) - mock_station.assert_called_with(cluster, 0, (x, y, z), angle, - detector_list, None, None, number) + mock_station.assert_called_with(cluster, 0, (x, y, z), angle, detector_list, None, None, number) def test_set_timestamp(self): with patch('sapphire.clusters.Station'): @@ -353,39 +374,41 @@ def test_get_station_by_number(self): self.assertEqual(cluster.get_station(501), cluster.stations[0]) def test_init_sets_position(self): - cluster = clusters.BaseCluster((10., 20.), pi / 2) - self.assertEqual(cluster.x, 10.) - self.assertEqual(cluster.y, 20.) - self.assertEqual(cluster.z, 0.) + cluster = clusters.BaseCluster((10.0, 20.0), pi / 2) + self.assertEqual(cluster.x, 10.0) + self.assertEqual(cluster.y, 20.0) + self.assertEqual(cluster.z, 0.0) self.assertEqual(cluster.alpha, pi / 2) def test_get_coordinates(self): - cluster = clusters.BaseCluster((10., 20., 0), pi / 2) + cluster = clusters.BaseCluster((10.0, 20.0, 0), pi / 2) coordinates = cluster.get_coordinates() - self.assertEqual(coordinates, (10., 20., 0, pi / 2)) + self.assertEqual(coordinates, (10.0, 20.0, 0, pi / 2)) coordinates = cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (10., 20., pi / 2)) + self.assertEqual(coordinates, (10.0, 20.0, pi / 2)) coordinates = cluster.get_xy_coordinates() - self.assertEqual(coordinates, (10., 20.)) + self.assertEqual(coordinates, (10.0, 20.0)) def test_get_polar_alpha_coordinates(self): cluster = clusters.BaseCluster((-sqrt(2) / 2, sqrt(2) / 2), pi / 2) r, phi, alpha = cluster.get_polar_alpha_coordinates() - self.assertAlmostEqual(r, 1.) + self.assertAlmostEqual(r, 1.0) self.assertAlmostEqual(phi, 3 * pi / 4) self.assertEqual(alpha, pi / 2) def test_set_coordinates(self): cluster = clusters.BaseCluster() cluster.set_coordinates(sentinel.x, sentinel.y, sentinel.z, sentinel.alpha) - self.assertEqual((cluster.x, cluster.y, cluster.z, cluster.alpha), - (sentinel.x, sentinel.y, sentinel.z, sentinel.alpha)) + self.assertEqual( + (cluster.x, cluster.y, cluster.z, cluster.alpha), + (sentinel.x, sentinel.y, sentinel.z, sentinel.alpha), + ) def test_set_rphialpha_coordinates(self): cluster = clusters.BaseCluster() - cluster.set_cylindrical_coordinates(10., pi / 2, sentinel.z, sentinel.alpha) - self.assertAlmostEqual(cluster.x, 0.) - self.assertAlmostEqual(cluster.y, 10.) 
+ cluster.set_cylindrical_coordinates(10.0, pi / 2, sentinel.z, sentinel.alpha) + self.assertAlmostEqual(cluster.x, 0.0) + self.assertAlmostEqual(cluster.y, 10.0) self.assertAlmostEqual(cluster.z, sentinel.z) self.assertAlmostEqual(cluster.alpha, sentinel.alpha) @@ -395,15 +418,16 @@ def test_calc_r_and_phi_for_stations(self): cluster._add_station((1, sqrt(3)), 0) r, phi, z = cluster.calc_rphiz_for_stations(0, 1) self.assertAlmostEqual(r, 2) - self.assertAlmostEqual(phi, pi / 3.) + self.assertAlmostEqual(phi, pi / 3.0) self.assertAlmostEqual(z, 0) def test_calc_xy_center_of_mass_coordinates(self): cluster = clusters.BaseCluster() - cluster._add_station((0, 0), 0, [((0, 5 * sqrt(3)), 'UD'), - ((0, 5 * sqrt(3) / 3), 'UD'), - ((-10, 0), 'LR'), - ((10, 0), 'LR')]) + cluster._add_station( + (0, 0), + 0, + [((0, 5 * sqrt(3)), 'UD'), ((0, 5 * sqrt(3) / 3), 'UD'), ((-10, 0), 'LR'), ((10, 0), 'LR')], + ) x, y = cluster.calc_xy_center_of_mass_coordinates() self.assertAlmostEqual(x, 0) self.assertAlmostEqual(y, 5 * sqrt(3) / 3) @@ -411,43 +435,42 @@ def test_calc_xy_center_of_mass_coordinates(self): def test_calc_xy_center_of_mass_coordinates_nan_detectors(self): # detector locations can be nan, esp two detector stations cluster = clusters.BaseCluster() - cluster._add_station((0, 0), 0, [((-10, 0), 'LR'), - ((10, 0), 'LR'), - ((nan, nan), 'LR'), - ((nan, nan), 'LR')]) + cluster._add_station((0, 0), 0, [((-10, 0), 'LR'), ((10, 0), 'LR'), ((nan, nan), 'LR'), ((nan, nan), 'LR')]) x, y = cluster.calc_xy_center_of_mass_coordinates() self.assertAlmostEqual(x, 0) self.assertAlmostEqual(y, 0) def test_set_center_off_mass_at_origin(self): cluster = clusters.BaseCluster() - cluster._add_station((0, 0), 0, [((0, 5 * sqrt(3)), 'UD'), - ((0, 5 * sqrt(3) / 3), 'UD'), - ((-10, 0), 'LR'), - ((10, 0), 'LR')]) + cluster._add_station( + (0, 0), + 0, + [((0, 5 * sqrt(3)), 'UD'), ((0, 5 * sqrt(3) / 3), 'UD'), ((-10, 0), 'LR'), ((10, 0), 'LR')], + ) cluster.set_center_off_mass_at_origin() center = cluster.calc_center_of_mass_coordinates() - assert_array_almost_equal(center, [0., 0., 0.]) + assert_array_almost_equal(center, [0.0, 0.0, 0.0]) def test_set_center_off_mass_at_origin_rotated_cluster(self): cluster = clusters.BaseCluster() - cluster._add_station((0, 0), 0, [((0, 5 * sqrt(3)), 'UD'), - ((0, 5 * sqrt(3) / 3), 'UD'), - ((-10, 0), 'LR'), - ((10, 0), 'LR')]) - cluster.set_coordinates(10., -10., 1., 1.) + cluster._add_station( + (0, 0), + 0, + [((0, 5 * sqrt(3)), 'UD'), ((0, 5 * sqrt(3) / 3), 'UD'), ((-10, 0), 'LR'), ((10, 0), 'LR')], + ) + cluster.set_coordinates(10.0, -10.0, 1.0, 1.0) cluster.set_center_off_mass_at_origin() center = cluster.calc_center_of_mass_coordinates() - assert_array_almost_equal(center, [0., 0., 0.]) - self.assertAlmostEqual(cluster.alpha, 1.) 
+ assert_array_almost_equal(center, [0.0, 0.0, 0.0]) + self.assertAlmostEqual(cluster.alpha, 1.0) def test__distance(self): - x = array([-5., 4., 3.]) - y = array([2., -1., 0.]) + x = array([-5.0, 4.0, 3.0]) + y = array([2.0, -1.0, 0.0]) dist = clusters.BaseCluster()._distance(x, y) self.assertAlmostEqual(dist, sqrt(49 + 25 + 9)) - x = array([-5., 4.]) - y = array([2., -1.]) + x = array([-5.0, 4.0]) + y = array([2.0, -1.0]) dist = clusters.BaseCluster()._distance(x, y) self.assertAlmostEqual(dist, sqrt(49 + 25)) @@ -485,22 +508,21 @@ def assert_tuple_almost_equal(self, actual, expected): self.assertIsInstance(actual, tuple) self.assertIsInstance(expected, tuple) - msg = f"Tuples differ: {str(actual)} != {str(expected)}" + msg = f'Tuples differ: {actual!s} != {expected!s}' for actual_value, expected_value in zip(actual, expected): self.assertAlmostEqual(actual_value, expected_value, msg=msg) class SimpleClusterTests(unittest.TestCase): def test_init_calls_super_init(self): - with patch.object(clusters.BaseCluster, '__init__', - mocksignature=True) as mock_base_init: + with patch.object(clusters.BaseCluster, '__init__', mocksignature=True) as mock_base_init: clusters.SimpleCluster() self.assertTrue(mock_base_init.called) def test_get_coordinates_after_init(self): cluster = clusters.SimpleCluster() coordinates = cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (0., 0., 0.)) + self.assertEqual(coordinates, (0.0, 0.0, 0.0)) def test_cluster_stations(self): cluster = clusters.SimpleCluster() @@ -510,15 +532,14 @@ def test_cluster_stations(self): class SingleStationTests(unittest.TestCase): def test_init_calls_super_init(self): - with patch.object(clusters.BaseCluster, '__init__', - mocksignature=True) as mock_base_init: + with patch.object(clusters.BaseCluster, '__init__', mocksignature=True) as mock_base_init: clusters.SingleStation() self.assertTrue(mock_base_init.called) def test_get_coordinates_after_init(self): cluster = clusters.SingleStation() coordinates = cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (0., 0., 0.)) + self.assertEqual(coordinates, (0.0, 0.0, 0.0)) def test_single_station(self): cluster = clusters.SingleStation() @@ -531,14 +552,13 @@ def setUp(self): self.cluster = clusters.SingleDetectorStation() def test_init_calls_super_init(self): - with patch.object(clusters.BaseCluster, '__init__', - mocksignature=True) as mock_base_init: + with patch.object(clusters.BaseCluster, '__init__', mocksignature=True) as mock_base_init: clusters.SingleDetectorStation() self.assertTrue(mock_base_init.called) def test_get_coordinates_after_init(self): coordinates = self.cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (0., 0., 0.)) + self.assertEqual(coordinates, (0.0, 0.0, 0.0)) def test_single_station_single_detector(self): self.assertEqual(len(self.cluster.stations), 1) @@ -547,15 +567,14 @@ def test_single_station_single_detector(self): class SingleTwoDetectorStationTests(unittest.TestCase): def test_init_calls_super_init(self): - with patch.object(clusters.BaseCluster, '__init__', - mocksignature=True) as mock_base_init: + with patch.object(clusters.BaseCluster, '__init__', mocksignature=True) as mock_base_init: clusters.SingleTwoDetectorStation() self.assertTrue(mock_base_init.called) def test_get_coordinates_after_init(self): cluster = clusters.SingleTwoDetectorStation() coordinates = cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (0., 0., 0.)) + self.assertEqual(coordinates, (0.0, 0.0, 0.0)) def 
test_single_station(self): cluster = clusters.SingleTwoDetectorStation() @@ -567,15 +586,14 @@ def test_single_station(self): class SingleDiamondStationTests(unittest.TestCase): def test_init_calls_super_init(self): - with patch.object(clusters.BaseCluster, '__init__', - mocksignature=True) as mock_base_init: + with patch.object(clusters.BaseCluster, '__init__', mocksignature=True) as mock_base_init: clusters.SingleDiamondStation() self.assertTrue(mock_base_init.called) def test_get_coordinates_after_init(self): cluster = clusters.SingleDiamondStation() coordinates = cluster.get_xyalpha_coordinates() - self.assertEqual(coordinates, (0., 0., 0.)) + self.assertEqual(coordinates, (0.0, 0.0, 0.0)) def test_single_station(self): cluster = clusters.SingleDiamondStation() @@ -587,17 +605,16 @@ def test_single_station(self): class HiSPARCStationTests(unittest.TestCase): def setUp(self): - self.cluster = clusters.HiSPARCStations([501, 508, 510], - force_stale=True) + self.cluster = clusters.HiSPARCStations([501, 508, 510], force_stale=True) def test_first_station_was_reference(self): """First station was origin before shift to center mass""" - self.assertNotEqual(self.cluster.get_station(501).get_coordinates(), (0., 0., 0., 0.)) - self.assertNotEqual(self.cluster.get_station(508).get_coordinates(), (0., 0., 0., 0.)) + self.assertNotEqual(self.cluster.get_station(501).get_coordinates(), (0.0, 0.0, 0.0, 0.0)) + self.assertNotEqual(self.cluster.get_station(508).get_coordinates(), (0.0, 0.0, 0.0, 0.0)) # Undo cluster center at center mass self.cluster.set_coordinates(0, 0, 0, 0) - self.assertEqual(self.cluster.get_station(501).get_coordinates(), (0., 0., 0., 0.)) - self.assertNotEqual(self.cluster.get_station(508).get_coordinates(), (0., 0., 0., 0.)) + self.assertEqual(self.cluster.get_station(501).get_coordinates(), (0.0, 0.0, 0.0, 0.0)) + self.assertNotEqual(self.cluster.get_station(508).get_coordinates(), (0.0, 0.0, 0.0, 0.0)) def test_allow_missing_gps(self): """Allow making cluster with station without GPS coords @@ -606,39 +623,36 @@ def test_allow_missing_gps(self): """ with warnings.catch_warnings(record=True) as warned: - cluster = clusters.HiSPARCStations( - [0, 508, 510], skip_missing=True, force_stale=True) + cluster = clusters.HiSPARCStations([0, 508, 510], skip_missing=True, force_stale=True) self.assertEqual(len(warned), 2) # Undo cluster center at center mass cluster.set_coordinates(0, 0, 0, 0) - self.assertEqual(cluster.get_station(508).get_coordinates(), (0., 0., 0., 0.)) + self.assertEqual(cluster.get_station(508).get_coordinates(), (0.0, 0.0, 0.0, 0.0)) def test_missing_gps_not_allowed(self): """Making cluster with station without GPS coords raises exception""" with self.assertRaises(KeyError): - clusters.HiSPARCStations( - [0, 508, 510], skip_missing=False, force_stale=True) + clusters.HiSPARCStations([0, 508, 510], skip_missing=False, force_stale=True) def test_zero_center_off_mass(self): center = self.cluster.calc_center_of_mass_coordinates() - assert_array_almost_equal(center, [0., 0., 0.]) + assert_array_almost_equal(center, [0.0, 0.0, 0.0]) class FlattenClusterTests(unittest.TestCase): - def test_flatten_cluster_mock(self): cluster = Mock() station = Mock() - station.z = [1.] + station.z = [1.0] detector = Mock() - detector.z = [1.] + detector.z = [1.0] station.detectors = [detector] cluster.stations = [station] - self.assertEqual(cluster.stations[0].z[0], 1.) - self.assertEqual(cluster.stations[0].detectors[0].z[0], 1.) 
+ self.assertEqual(cluster.stations[0].z[0], 1.0) + self.assertEqual(cluster.stations[0].detectors[0].z[0], 1.0) clusters.flatten_cluster(cluster) - self.assertEqual(cluster.stations[0].z[0], 0.) - self.assertEqual(cluster.stations[0].detectors[0].z[0], 0.) + self.assertEqual(cluster.stations[0].z[0], 0.0) + self.assertEqual(cluster.stations[0].detectors[0].z[0], 0.0) def test_flatten_cluster(self): cluster = clusters.CompassStations() diff --git a/sapphire/tests/test_clusters_acceptance.py b/sapphire/tests/test_clusters_acceptance.py index 92066c67..52d58c55 100644 --- a/sapphire/tests/test_clusters_acceptance.py +++ b/sapphire/tests/test_clusters_acceptance.py @@ -10,11 +10,9 @@ def setUp(self): self.cluster = clusters.SimpleCluster(size=100) def test_station_positions_and_angles(self): - a = sqrt(100 ** 2 - 50 ** 2) - expected = [(0, 2 * a / 3, 0, 0), (0, 0, 0, 0), - (-50, -a / 3, 0, 2 * pi / 3), (50, -a / 3, 0, -2 * pi / 3)] - actual = [(station.x[0], station.y[0], station.z[0], station.angle[0]) - for station in self.cluster.stations] + a = sqrt(100**2 - 50**2) + expected = [(0, 2 * a / 3, 0, 0), (0, 0, 0, 0), (-50, -a / 3, 0, 2 * pi / 3), (50, -a / 3, 0, -2 * pi / 3)] + actual = [(station.x[0], station.y[0], station.z[0], station.angle[0]) for station in self.cluster.stations] for actual_value, expected_value in zip(actual, expected): self.assert_tuple_almost_equal(actual_value, expected_value) @@ -28,6 +26,6 @@ def assert_tuple_almost_equal(self, actual, expected): self.assertIsInstance(actual, tuple) self.assertIsInstance(expected, tuple) - msg = f"Tuples differ: {str(actual)} != {str(expected)}" + msg = f'Tuples differ: {actual!s} != {expected!s}' for actual_value, expected_value in zip(actual, expected): self.assertAlmostEqual(actual_value, expected_value, msg=msg) diff --git a/sapphire/tests/test_esd.py b/sapphire/tests/test_esd.py index 6ed0e53b..f6e1b5c5 100644 --- a/sapphire/tests/test_esd.py +++ b/sapphire/tests/test_esd.py @@ -20,76 +20,76 @@ class StaleNetwork(api.Network): """api.Network with `force_stale=True` always true""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.force_stale = True class ESDTest(unittest.TestCase): - def test_create_table(self): - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'nanoseconds': tables.UInt32Col(pos=2), - 'ext_timestamp': tables.UInt64Col(pos=3), - 'pulseheights': tables.Int16Col(pos=4, shape=4), - 'integrals': tables.Int32Col(pos=5, shape=4), - 'n1': tables.Float32Col(pos=6), - 'n2': tables.Float32Col(pos=7), - 'n3': tables.Float32Col(pos=8), - 'n4': tables.Float32Col(pos=9), - 't1': tables.Float32Col(pos=10), - 't2': tables.Float32Col(pos=11), - 't3': tables.Float32Col(pos=12), - 't4': tables.Float32Col(pos=13), - 't_trigger': tables.Float32Col(pos=14)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'nanoseconds': tables.UInt32Col(pos=2), + 'ext_timestamp': tables.UInt64Col(pos=3), + 'pulseheights': tables.Int16Col(pos=4, shape=4), + 'integrals': tables.Int32Col(pos=5, shape=4), + 'n1': tables.Float32Col(pos=6), + 'n2': tables.Float32Col(pos=7), + 'n3': tables.Float32Col(pos=8), + 'n4': tables.Float32Col(pos=9), + 't1': tables.Float32Col(pos=10), + 't2': tables.Float32Col(pos=11), + 't3': tables.Float32Col(pos=12), + 't4': tables.Float32Col(pos=13), + 't_trigger': tables.Float32Col(pos=14), + } file = MagicMock() result = esd._create_events_table(file, sentinel.group) - 
file.create_table.assert_called_once_with(sentinel.group, 'events', - description, - createparents=True) + file.create_table.assert_called_once_with(sentinel.group, 'events', description, createparents=True) self.assertEqual(result, file.create_table.return_value) def test_create_weather_table(self): - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'temp_inside': tables.Float32Col(pos=2), - 'temp_outside': tables.Float32Col(pos=3), - 'humidity_inside': tables.Int16Col(pos=4), - 'humidity_outside': tables.Int16Col(pos=5), - 'barometer': tables.Float32Col(pos=6), - 'wind_dir': tables.Int16Col(pos=7), - 'wind_speed': tables.Int16Col(pos=8), - 'solar_rad': tables.Int16Col(pos=9), - 'uv': tables.Int16Col(pos=10), - 'evapotranspiration': tables.Float32Col(pos=11), - 'rain_rate': tables.Float32Col(pos=12), - 'heat_index': tables.Int16Col(pos=13), - 'dew_point': tables.Float32Col(pos=14), - 'wind_chill': tables.Float32Col(pos=15)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'temp_inside': tables.Float32Col(pos=2), + 'temp_outside': tables.Float32Col(pos=3), + 'humidity_inside': tables.Int16Col(pos=4), + 'humidity_outside': tables.Int16Col(pos=5), + 'barometer': tables.Float32Col(pos=6), + 'wind_dir': tables.Int16Col(pos=7), + 'wind_speed': tables.Int16Col(pos=8), + 'solar_rad': tables.Int16Col(pos=9), + 'uv': tables.Int16Col(pos=10), + 'evapotranspiration': tables.Float32Col(pos=11), + 'rain_rate': tables.Float32Col(pos=12), + 'heat_index': tables.Int16Col(pos=13), + 'dew_point': tables.Float32Col(pos=14), + 'wind_chill': tables.Float32Col(pos=15), + } file = MagicMock() result = esd._create_weather_table(file, sentinel.group) - file.create_table.assert_called_once_with(sentinel.group, 'weather', - description, - createparents=True) + file.create_table.assert_called_once_with(sentinel.group, 'weather', description, createparents=True) self.assertEqual(result, file.create_table.return_value) def test_create_singles_table(self): - description = {'event_id': tables.UInt32Col(pos=0), - 'timestamp': tables.Time32Col(pos=1), - 'mas_ch1_low': tables.Int32Col(pos=2), - 'mas_ch1_high': tables.Int32Col(pos=3), - 'mas_ch2_low': tables.Int32Col(pos=4), - 'mas_ch2_high': tables.Int32Col(pos=5), - 'slv_ch1_low': tables.Int32Col(pos=6), - 'slv_ch1_high': tables.Int32Col(pos=7), - 'slv_ch2_low': tables.Int32Col(pos=8), - 'slv_ch2_high': tables.Int32Col(pos=9)} + description = { + 'event_id': tables.UInt32Col(pos=0), + 'timestamp': tables.Time32Col(pos=1), + 'mas_ch1_low': tables.Int32Col(pos=2), + 'mas_ch1_high': tables.Int32Col(pos=3), + 'mas_ch2_low': tables.Int32Col(pos=4), + 'mas_ch2_high': tables.Int32Col(pos=5), + 'slv_ch1_low': tables.Int32Col(pos=6), + 'slv_ch1_high': tables.Int32Col(pos=7), + 'slv_ch2_low': tables.Int32Col(pos=8), + 'slv_ch2_high': tables.Int32Col(pos=9), + } file = MagicMock() result = esd._create_singles_table(file, sentinel.group) - file.create_table.assert_called_once_with(sentinel.group, 'singles', - description, - createparents=True) + file.create_table.assert_called_once_with(sentinel.group, 'singles', description, createparents=True) self.assertEqual(result, file.create_table.return_value) def test__first_available_numbered_path(self): @@ -107,16 +107,13 @@ def test_unsupported_type(self): """Check for Exception for unsupported data types""" self.assertRaises(ValueError, esd.load_data, None, None, None, 'bad') - self.assertRaises(ValueError, esd.download_data, None, None, 501, - 
type='bad') + self.assertRaises(ValueError, esd.download_data, None, None, 501, type='bad') def test_start_end_values(self): """Check for RuntimeError for impossible end=value with start=None""" - self.assertRaises(RuntimeError, esd.download_data, None, None, 501, - start=None, end='a_value') - self.assertRaises(RuntimeError, esd.download_coincidences, None, - start=None, end="a_value") + self.assertRaises(RuntimeError, esd.download_data, None, None, 501, start=None, end='a_value') + self.assertRaises(RuntimeError, esd.download_coincidences, None, start=None, end='a_value') def test_load_data_output(self): """Load data tsv into hdf5 and verify the output""" @@ -144,8 +141,7 @@ def test_quick_download(self, mock_open_file, mock_download_data): mock_open_file.assert_called_once_with('data1.h5', 'w') mock_download_data.assert_called_once_with(ANY, None, 501, None) - @unittest.skipUnless(api.API.check_connection(), - "Internet connection required") + @unittest.skipUnless(api.API.check_connection(), 'Internet connection required') def test_download_data(self): """Download data and validate results""" @@ -154,8 +150,7 @@ def test_download_data(self): validate_results(self, test_data_path, output_path) os.remove(output_path) - @unittest.skipUnless(api.API.check_connection(), - "Internet connection required") + @unittest.skipUnless(api.API.check_connection(), 'Internet connection required') @patch.object(esd.api, 'Network', side_effect=StaleNetwork) def test_download_coincidences(self, mock_esd_api_network): """Download coincidence data from esd and validate results""" diff --git a/sapphire/tests/test_kascade.py b/sapphire/tests/test_kascade.py index cdf10313..9f4e5f75 100644 --- a/sapphire/tests/test_kascade.py +++ b/sapphire/tests/test_kascade.py @@ -13,22 +13,17 @@ class StoreKascadeDataTests(unittest.TestCase): - def setUp(self): self.destination_path = self.create_tempfile_path() + self.addCleanup(os.remove, self.destination_path) def test_read_and_store_data(self): path = self.destination_path with tables.open_file(path, 'a') as self.destination_data: - self.kascade = kascade.StoreKascadeData(self.destination_data, - TEST_DATA_FILE, '/kascade', - progress=False) + self.kascade = kascade.StoreKascadeData(self.destination_data, TEST_DATA_FILE, '/kascade', progress=False) self.kascade.read_and_store_data() validate_results(self, TEST_DATA_REF, self.destination_path) - def tearDown(self): - os.remove(self.destination_path) - def create_tempfile_path(self): fd, path = tempfile.mkstemp('.h5') os.close(fd) diff --git a/sapphire/tests/test_publicdb.py b/sapphire/tests/test_publicdb.py index 32fdd062..537d5593 100644 --- a/sapphire/tests/test_publicdb.py +++ b/sapphire/tests/test_publicdb.py @@ -18,12 +18,9 @@ class DownloadDataTest(unittest.TestCase): - def setUp(self): logging.disable(logging.CRITICAL) - - def tearDown(self): - logging.disable(logging.NOTSET) + self.addCleanup(logging.disable, logging.NOTSET) @patch.object(publicdb, '_store_data') @patch.object(publicdb, 'urlretrieve') @@ -35,19 +32,23 @@ def test_download_data(self, mock_server, mock_retrieve, mock_store): mock_get_data_url = mock_server.return_value.hisparc.get_data_url mock_get_data_url.return_value = sentinel.url mock_retrieve.return_value = (sentinel.tmpdata, sentinel.headers) - publicdb.download_data(file, sentinel.group, sentinel.station_id, - start, end, get_blobs=sentinel.blobs) - mock_get_data_url.assert_called_once_with(sentinel.station_id, start, - sentinel.blobs) - - mock_get_data_url.side_effect = Exception("No data") 
- publicdb.download_data(file, sentinel.group, sentinel.station_id, - start, end, get_blobs=sentinel.blobs) - - mock_get_data_url.side_effect = Exception("Unknown error") - self.assertRaises(Exception, publicdb.download_data, file, - sentinel.group, sentinel.station_id, start, - end, get_blobs=sentinel.blobs) + publicdb.download_data(file, sentinel.group, sentinel.station_id, start, end, get_blobs=sentinel.blobs) + mock_get_data_url.assert_called_once_with(sentinel.station_id, start, sentinel.blobs) + + mock_get_data_url.side_effect = Exception('No data') + publicdb.download_data(file, sentinel.group, sentinel.station_id, start, end, get_blobs=sentinel.blobs) + + mock_get_data_url.side_effect = Exception('Unknown error') + self.assertRaises( + Exception, + publicdb.download_data, + file, + sentinel.group, + sentinel.station_id, + start, + end, + get_blobs=sentinel.blobs, + ) def test__store_data(self): # store data removes the source data when completed, so use a temp @@ -72,34 +73,38 @@ def test__store_data_no_end(self): start = datetime(2016, 4, 21) filters = tables.Filters(complevel=1) with tables.open_file(output_path, 'w', filters=filters) as datafile: - publicdb._store_data(datafile, '/station_501', tmp_src_path, start, - None) + publicdb._store_data(datafile, '/station_501', tmp_src_path, start, None) validate_results(self, test_data_src_path, output_path) os.remove(output_path) def test_datetimerange(self): combinations = [ - (datetime(2010, 1, 1, 11), - datetime(2010, 1, 1, 13), - [(datetime(2010, 1, 1, 11), datetime(2010, 1, 1, 13))]), - (datetime(2010, 1, 1, 11), - datetime(2010, 1, 2), - [(datetime(2010, 1, 1, 11), None)]), - (datetime(2010, 1, 1, 11), - datetime(2010, 1, 2, 13), - [(datetime(2010, 1, 1, 11), None), - (datetime(2010, 1, 2), datetime(2010, 1, 2, 13))]), - (datetime(2010, 1, 1, 11), - datetime(2010, 1, 5, 13), - [(datetime(2010, 1, 1, 11), None), - (datetime(2010, 1, 2), None), - (datetime(2010, 1, 3), None), - (datetime(2010, 1, 4), None), - (datetime(2010, 1, 5), datetime(2010, 1, 5, 13))])] + ( + datetime(2010, 1, 1, 11), + datetime(2010, 1, 1, 13), + [(datetime(2010, 1, 1, 11), datetime(2010, 1, 1, 13))], + ), + (datetime(2010, 1, 1, 11), datetime(2010, 1, 2), [(datetime(2010, 1, 1, 11), None)]), + ( + datetime(2010, 1, 1, 11), + datetime(2010, 1, 2, 13), + [(datetime(2010, 1, 1, 11), None), (datetime(2010, 1, 2), datetime(2010, 1, 2, 13))], + ), + ( + datetime(2010, 1, 1, 11), + datetime(2010, 1, 5, 13), + [ + (datetime(2010, 1, 1, 11), None), + (datetime(2010, 1, 2), None), + (datetime(2010, 1, 3), None), + (datetime(2010, 1, 4), None), + (datetime(2010, 1, 5), datetime(2010, 1, 5, 13)), + ], + ), + ] for start, stop, result in combinations: self.assertEqual(list(publicdb.datetimerange(start, stop)), result) - self.assertRaises(Exception, next, - publicdb.datetimerange(stop, start)) + self.assertRaises(Exception, next, publicdb.datetimerange(stop, start)) def test__get_or_create_group(self): file = Mock() @@ -111,9 +116,7 @@ def test__get_or_create_group(self): file.get_node.side_effect = tables.NoSuchNodeError('no such node!') in_group = '/hisparc/station_501' out_group = publicdb._get_or_create_group(file, in_group) - file.create_group.assert_called_once_with('/hisparc', 'station_501', - 'Data group', - createparents=True) + file.create_group.assert_called_once_with('/hisparc', 'station_501', 'Data group', createparents=True) self.assertEqual(file.create_group.return_value, out_group) def test__get_or_create_node(self): @@ -127,18 +130,14 @@ def 
test__get_or_create_node(self): file.get_node.side_effect = tables.NoSuchNodeError('no such node!') # Raise exception because type of Mock src_node is not Table or VLArray - self.assertRaises(Exception, publicdb._get_or_create_node, file, - sentinel.group, src_node) + self.assertRaises(Exception, publicdb._get_or_create_node, file, sentinel.group, src_node) src_node = Mock(spec=tables.Table) src_node.description = sentinel.description node = publicdb._get_or_create_node(file, sentinel.group, src_node) - file.create_table.assert_called_once_with( - sentinel.group, src_node.name, src_node.description, - src_node.title) + file.create_table.assert_called_once_with(sentinel.group, src_node.name, src_node.description, src_node.title) src_node = Mock(spec=tables.VLArray) src_node.atom = sentinel.atom node = publicdb._get_or_create_node(file, sentinel.group, src_node) - file.create_vlarray.assert_called_once_with( - sentinel.group, src_node.name, src_node.atom, src_node.title) + file.create_vlarray.assert_called_once_with(sentinel.group, src_node.name, src_node.atom, src_node.title) diff --git a/sapphire/tests/test_qsub.py b/sapphire/tests/test_qsub.py index ccf34a21..72692110 100644 --- a/sapphire/tests/test_qsub.py +++ b/sapphire/tests/test_qsub.py @@ -9,7 +9,6 @@ @patch.object(qsub.utils, 'which') class CheckQueueTest(unittest.TestCase): - @patch.object(qsub.subprocess, 'check_output') def test_queues(self, mock_check_output, mock_which): for queue in ['express', 'short', 'generic', 'long']: @@ -24,14 +23,16 @@ def test_bad_queue(self, mock_check_output, mock_which): @patch.object(qsub.subprocess, 'check_output') def test_check_queue(self, mock_check_output, mock_which): - combinations = ([[' 0\n'], 2, 'express'], - [[' 2\n'], 0, 'express'], - [[' 100\n'], 900, 'short'], - [['1100\n'], -100, 'short'], - [['2000\n', '1000\n'], 1000, 'generic'], - [['3600\n', '1000\n'], 400, 'generic'], - [[' 200\n', ' 100\n'], 400, 'long'], - [[' 620\n', ' 100\n'], 380, 'long']) + combinations = ( + [[' 0\n'], 2, 'express'], + [[' 2\n'], 0, 'express'], + [[' 100\n'], 900, 'short'], + [['1100\n'], -100, 'short'], + [['2000\n', '1000\n'], 1000, 'generic'], + [['3600\n', '1000\n'], 400, 'generic'], + [[' 200\n', ' 100\n'], 400, 'long'], + [[' 620\n', ' 100\n'], 380, 'long'], + ) for taken, available, queue in combinations: mock_check_output.side_effect = cycle(taken) self.assertEqual(qsub.check_queue(queue), available) @@ -39,60 +40,50 @@ def test_check_queue(self, mock_check_output, mock_which): @patch.object(qsub.utils, 'which') class SubmitJobTest(unittest.TestCase): - @patch.object(qsub, 'create_script') @patch.object(qsub.subprocess, 'check_output') @patch.object(qsub, 'delete_script') - def test_submit_job(self, mock_delete, mock_check_output, mock_create, - mock_which): + def test_submit_job(self, mock_delete, mock_check_output, mock_create, mock_which): mock_create.return_value = (sentinel.script_path, sentinel.script_name) mock_check_output.return_value = b'' qsub.submit_job(sentinel.script, sentinel.name, sentinel.queue, sentinel.extra) mock_create.assert_called_once_with(sentinel.script, sentinel.name) - command = ('qsub -q {queue} -V -z -j oe -N {name} {extra} {script}' - .format(queue=sentinel.queue, name=sentinel.script_name, - script=sentinel.script_path, extra=sentinel.extra)) - mock_check_output.assert_called_once_with(command, - stderr=qsub.subprocess.STDOUT, - shell=True) + command = ( + f'qsub -q {sentinel.queue} -V -z -j oe -N {sentinel.script_name} {sentinel.extra} {sentinel.script_path}' + ) 
+ mock_check_output.assert_called_once_with(command, stderr=qsub.subprocess.STDOUT, shell=True) mock_delete.assert_called_once_with(sentinel.script_path) @patch.object(qsub, 'create_script') @patch.object(qsub.subprocess, 'check_output') @patch.object(qsub, 'delete_script') - def test_failed_submit_job(self, mock_delete, mock_check_output, - mock_create, mock_which): + def test_failed_submit_job(self, mock_delete, mock_check_output, mock_create, mock_which): mock_create.return_value = (sentinel.script_path, sentinel.script_name) mock_check_output.return_value = 'Failed!' - self.assertRaises(Exception, qsub.submit_job, sentinel.script, - sentinel.name, sentinel.queue, sentinel.extra) + self.assertRaises(Exception, qsub.submit_job, sentinel.script, sentinel.name, sentinel.queue, sentinel.extra) mock_create.assert_called_once_with(sentinel.script, sentinel.name) - command = ('qsub -q {queue} -V -z -j oe -N {name} {extra} {script}' - .format(queue=sentinel.queue, name=sentinel.script_name, - script=sentinel.script_path, extra=sentinel.extra)) - mock_check_output.assert_called_once_with(command, - stderr=qsub.subprocess.STDOUT, - shell=True) + command = ( + f'qsub -q {sentinel.queue} -V -z -j oe -N {sentinel.script_name} {sentinel.extra} {sentinel.script_path}' + ) + mock_check_output.assert_called_once_with(command, stderr=qsub.subprocess.STDOUT, shell=True) self.assertFalse(mock_delete.called) class CreateScriptTest(unittest.TestCase): - @patch.object(qsub.os, 'chmod') def test_create_script(self, mock_chmod): with patch.object(builtins, 'open', mock_open()) as mock_file: res_path, res_name = qsub.create_script(sentinel.script, 'hoi') - self.assertEqual(res_path, '/tmp/his_hoi.sh') + self.assertTrue(res_path.endswith('/his_hoi.sh')) self.assertEqual(res_name, 'his_hoi.sh') mock_file.assert_called_once_with(res_path, 'w') - mock_file().write.called_once_with(sentinel.script) + mock_file().write.assert_called_once_with(sentinel.script) mock_chmod.assert_called_once_with(res_path, 0o774) class DeleteScriptTest(unittest.TestCase): - @patch.object(qsub.os, 'remove') def test_delete_script(self, mock_remove): self.assertFalse(mock_remove.called) diff --git a/sapphire/tests/test_time_util.py b/sapphire/tests/test_time_util.py index 3426c292..4127fbc2 100644 --- a/sapphire/tests/test_time_util.py +++ b/sapphire/tests/test_time_util.py @@ -23,7 +23,7 @@ def test_incorrect_arguments(self): @patch.object(time_util.GPSTime, 'description') def test_str_returns_description(self, mock_description): - expected = "Foobar" + expected = 'Foobar' mock_description.return_value = expected t = time_util.GPSTime(2014, 10, 27) actual = str(t) diff --git a/sapphire/tests/test_utils.py b/sapphire/tests/test_utils.py index cf3880cc..d1f6c646 100644 --- a/sapphire/tests/test_utils.py +++ b/sapphire/tests/test_utils.py @@ -11,7 +11,6 @@ class PbarTests(unittest.TestCase): - def setUp(self): self.iterable = list(range(10)) self.output = StringIO() @@ -25,7 +24,7 @@ def test_pbar_generator(self): """Return original generator, not a progressbar""" generator = (x for x in self.iterable) - pb = utils.pbar(generator) + pb = utils.pbar(generator, fd=self.output) self.assertIsInstance(pb, types.GeneratorType) self.assertEqual(list(pb), self.iterable) @@ -50,7 +49,6 @@ def test_pbar_hide_output(self): class InBaseTests(unittest.TestCase): - def test_ceil(self): self.assertEqual(utils.ceil_in_base(2.4, 2.5), 2.5) self.assertEqual(utils.ceil_in_base(0.1, 2.5), 2.5) @@ -75,7 +73,6 @@ def test_integers(self): class 
ActiveIndexTests(unittest.TestCase): - def test_get_active_index(self): """Test if the bisection returns the correct index @@ -86,29 +83,27 @@ def test_get_active_index(self): equal or less than the timestamp """ - timestamps = [1., 2., 3., 4.] + timestamps = [1.0, 2.0, 3.0, 4.0] - for idx, ts in [(0, 0.), (0, 1.), (0, 1.5), (1, 2.), (1, 2.1), (3, 4.), - (3, 5.)]: + for idx, ts in [(0, 0.0), (0, 1.0), (0, 1.5), (1, 2.0), (1, 2.1), (3, 4.0), (3, 5.0)]: self.assertEqual(utils.get_active_index(timestamps, ts), idx) class GaussTests(unittest.TestCase): - """Test against explicit Gaussian""" def gaussian(self, x, n, mu, sigma): - return n * exp(-(x - mu) ** 2. / (2. * sigma ** 2)) / (sigma * sqrt(2 * pi)) + return n * exp(-((x - mu) ** 2.0) / (2.0 * sigma**2)) / (sigma * sqrt(2 * pi)) def test_gauss(self): - x, n, mu, sigma = (1., 1., 0., 1.) + x, n, mu, sigma = (1.0, 1.0, 0.0, 1.0) self.assertEqual(utils.gauss(x, n, mu, sigma), self.gaussian(x, n, mu, sigma)) - n = 2. + n = 2.0 self.assertEqual(utils.gauss(x, n, mu, sigma), self.gaussian(x, n, mu, sigma)) - sigma = 2. + sigma = 2.0 self.assertEqual(utils.gauss(x, n, mu, sigma), self.gaussian(x, n, mu, sigma)) x = 1e5 - self.assertEqual(utils.gauss(x, n, mu, sigma), 0.) + self.assertEqual(utils.gauss(x, n, mu, sigma), 0.0) def test_gauss_array(self): """Test for arrays of random values""" @@ -123,7 +118,6 @@ def test_gauss_array(self): class AngleBetweenTests(unittest.TestCase): - """Check opening angle between two directions""" def test_zeniths(self): @@ -166,17 +160,18 @@ def test_single_values(self): class DistanceBetweenTests(unittest.TestCase): - """Check distance between two (x, y) cartesian coordinates""" def test_distances(self): """Check if distances are correctly calculated""" - combinations = [((0, 0, 1.6, 0), 1.6), - ((-1, 0, 1, 0), 2), - ((-1, 0, -1, 0), 0), - ((random.uniform(1e-15, 100),) * 4, 0), - ((-10, -10, 5, 5), sqrt(450))] + combinations = [ + ((0, 0, 1.6, 0), 1.6), + ((-1, 0, 1, 0), 2), + ((-1, 0, -1, 0), 0), + ((random.uniform(1e-15, 100),) * 4, 0), + ((-10, -10, 5, 5), sqrt(450)), + ] for coordinates, distance in combinations: self.assertEqual(utils.distance_between(*coordinates), distance) # same result if the coordinates and x, y are swapped @@ -184,7 +179,6 @@ def test_distances(self): class WhichTests(unittest.TestCase): - """Check if which works""" def test_which(self): @@ -195,5 +189,4 @@ def test_which(self): def test_non_existent_program(self): """Check for error for non-existent program""" - self.assertRaises(Exception, utils.which, - 'a_very_unlikely_program_name_to_exist_cosmic_ray') + self.assertRaises(Exception, utils.which, 'a_very_unlikely_program_name_to_exist_cosmic_ray') diff --git a/sapphire/tests/transformations/test_angles.py b/sapphire/tests/transformations/test_angles.py index d9d23e59..43bf2644 100644 --- a/sapphire/tests/transformations/test_angles.py +++ b/sapphire/tests/transformations/test_angles.py @@ -7,15 +7,16 @@ class DegreeRadianHourTests(unittest.TestCase): - def setUp(self): x = random.random() # (degrees, radians, hours) - self.combinations = ((0., 0., 0.), - (15., pi / 12., 1.), - (90., pi / 2., 6.), - (180., pi, 12.), - (360. * x, 2 * pi * x, 24. 
* x)) + self.combinations = ( + (0.0, 0.0, 0.0), + (15.0, pi / 12.0, 1.0), + (90.0, pi / 2.0, 6.0), + (180.0, pi, 12.0), + (360.0 * x, 2 * pi * x, 24.0 * x), + ) def test_hours_to_degrees(self): for degree, _, hour in self.combinations: diff --git a/sapphire/tests/transformations/test_axes.py b/sapphire/tests/transformations/test_axes.py index 91cd33de..35d9622c 100644 --- a/sapphire/tests/transformations/test_axes.py +++ b/sapphire/tests/transformations/test_axes.py @@ -6,7 +6,6 @@ class CoordinateSystemTests(unittest.TestCase): - def setUp(self): """Test combinations of coordinates @@ -17,12 +16,18 @@ def setUp(self): ((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0)), ((1, 0, 0), (1, pi / 2, 0), (1, 0, 0), (1, 90, 0)), ((-1, 0, 0), (1, pi / 2, pi), (1, pi, 0), (1, -90, 0)), - ((0, 1, 0), (1, pi / 2, pi / 2.), (1, pi / 2., 0), (1, 0, 0)), + ((0, 1, 0), (1, pi / 2, pi / 2.0), (1, pi / 2.0, 0), (1, 0, 0)), ((0, -1, 0), (1, pi / 2, -pi / 2), (1, -pi / 2, 0), (1, 180, 0)), ((0, 0, 1), (1, 0, 0), (0, 0, 1), (0, 0, 1)), ((0, 0, -1), (1, pi, 0), (0, 0, -1), (0, 0, -1)), ((1, 1, 1), (sqrt(3), arccos(1 / sqrt(3)), pi / 4), (sqrt(2), pi / 4, 1), (sqrt(2), 45, 1)), - ((-1, -1, -1), (sqrt(3), arccos(-1 / sqrt(3)), -pi * 3 / 4), (sqrt(2), -pi * 3 / 4, -1), (sqrt(2), -135, -1))) + ( + (-1, -1, -1), + (sqrt(3), arccos(-1 / sqrt(3)), -pi * 3 / 4), + (sqrt(2), -pi * 3 / 4, -1), + (sqrt(2), -135, -1), + ), + ) def test_cartesian_to_spherical(self): for cartesian, spherical, _, _ in self.combinations: @@ -58,21 +63,19 @@ def test_compass_to_cartesian(self): class RotateCartesianTests(unittest.TestCase): - def test_rotate_cartesian(self): - input = (3., 4., 5.) - x, y, z = input - self.assertEqual(input, axes.rotate_cartesian(x, y, z, 0, 'x')) - self.assertEqual(input, axes.rotate_cartesian(x, y, z, 0, 'y')) - self.assertEqual(input, axes.rotate_cartesian(x, y, z, 0, 'z')) + initial = (3.0, 4.0, 5.0) + x, y, z = initial + self.assertEqual(initial, axes.rotate_cartesian(x, y, z, 0, 'x')) + self.assertEqual(initial, axes.rotate_cartesian(x, y, z, 0, 'y')) + self.assertEqual(initial, axes.rotate_cartesian(x, y, z, 0, 'z')) - testing.assert_almost_equal((3., -5., 4.), axes.rotate_cartesian(x, y, z, pi / 2, 'x')) - testing.assert_almost_equal((5., 4., -3.), axes.rotate_cartesian(x, y, z, pi / 2, 'y')) - testing.assert_almost_equal((-4., 3., 5.), axes.rotate_cartesian(x, y, z, pi / 2, 'z')) + testing.assert_almost_equal((3.0, -5.0, 4.0), axes.rotate_cartesian(x, y, z, pi / 2, 'x')) + testing.assert_almost_equal((5.0, 4.0, -3.0), axes.rotate_cartesian(x, y, z, pi / 2, 'y')) + testing.assert_almost_equal((-4.0, 3.0, 5.0), axes.rotate_cartesian(x, y, z, pi / 2, 'z')) class RotationMatrixTests(unittest.TestCase): - def test_no_rotation_matrix(self): """Check if no rotation is correctly returned""" @@ -84,9 +87,6 @@ def test_no_rotation_matrix(self): def test_rotation_matrix(self): """Rotate by 90 degrees to swap the other two axes""" - testing.assert_almost_equal(axes.rotation_matrix(pi / 2., 'x'), - array(((1, 0, 0), (0, 0, 1), (0, -1, 0)))) - testing.assert_almost_equal(axes.rotation_matrix(pi / 2., 'y'), - array(((0, 0, -1), (0, 1, 0), (1, 0, 0)))) - testing.assert_almost_equal(axes.rotation_matrix(pi / 2, 'z'), - array(((0, 1, 0), (-1, 0, 0), (0, 0, 1)))) + testing.assert_almost_equal(axes.rotation_matrix(pi / 2.0, 'x'), array(((1, 0, 0), (0, 0, 1), (0, -1, 0)))) + testing.assert_almost_equal(axes.rotation_matrix(pi / 2.0, 'y'), array(((0, 0, -1), (0, 1, 0), (1, 0, 0)))) + 
testing.assert_almost_equal(axes.rotation_matrix(pi / 2, 'z'), array(((0, 1, 0), (-1, 0, 0), (0, 0, 1)))) diff --git a/sapphire/tests/transformations/test_base.py b/sapphire/tests/transformations/test_base.py index fd735304..5349ae46 100644 --- a/sapphire/tests/transformations/test_base.py +++ b/sapphire/tests/transformations/test_base.py @@ -4,18 +4,19 @@ class DecimalSexagesimalTests(unittest.TestCase): - def setUp(self): # (decimal, sexagesimal) - self.combinations = ((0, (0, 0, 0)), - (1, (1, 0, 0)), - (30, (30, 0, 0)), - (1 / 60., (0, 1, 0)), - (-1 + (30 / 60.), (0, -30, 0)), - (-1 - (30 / 60.) - (30 / 3600.), (-1, -30, -30)), - (.5, (0, 30, 0)), - (1 / 3600., (0, 0, 1)), - (30 / 3600., (0, 0, 30))) + self.combinations = ( + (0, (0, 0, 0)), + (1, (1, 0, 0)), + (30, (30, 0, 0)), + (1 / 60.0, (0, 1, 0)), + (-1 + (30 / 60.0), (0, -30, 0)), + (-1 - (30 / 60.0) - (30 / 3600.0), (-1, -30, -30)), + (0.5, (0, 30, 0)), + (1 / 3600.0, (0, 0, 1)), + (30 / 3600.0, (0, 0, 30)), + ) def test_decimal_to_sexagesimal(self): for dec, sexa in self.combinations: diff --git a/sapphire/tests/transformations/test_celestial.py b/sapphire/tests/transformations/test_celestial.py index c193593d..95fb3805 100644 --- a/sapphire/tests/transformations/test_celestial.py +++ b/sapphire/tests/transformations/test_celestial.py @@ -8,40 +8,34 @@ from sapphire.transformations import base, celestial, clock -# This is to switch off tests in case astropy is not present -# Noqa used to silence flake8 try: - import astropy # noqa : F401 - has_astropy = True + import astropy # noqa: F401 + except ImportError: - has_astropy = False + astropy_available = False +else: + astropy_available = True class ZenithAzimuthHorizontalTests(unittest.TestCase): - def setUp(self): - self.zenith = (0., pi / 4., pi / 2.) - self.altitude = (pi / 2., pi / 4., 0.) + self.zenith = (0.0, pi / 4.0, pi / 2.0) + self.altitude = (pi / 2.0, pi / 4.0, 0.0) - self.azimuth = (-pi / 2., 0., pi / 2.) # -pi - self.Azimuth = (-pi, pi / 2., 0.) # -pi / 2. + self.azimuth = (-pi / 2.0, 0.0, pi / 2.0) # -pi + self.Azimuth = (-pi, pi / 2.0, 0.0) # -pi / 2. def test_zenithazimuth_to_horizontal(self): for zenith, altitude in zip(self.zenith, self.altitude): - self.assertEqual(celestial.zenithazimuth_to_horizontal( - zenith, 0)[0], altitude) - self.assertEqual(celestial.horizontal_to_zenithazimuth( - altitude, 0)[0], zenith) + self.assertEqual(celestial.zenithazimuth_to_horizontal(zenith, 0)[0], altitude) + self.assertEqual(celestial.horizontal_to_zenithazimuth(altitude, 0)[0], zenith) for azimuth, Azimuth in zip(self.azimuth, self.Azimuth): - self.assertEqual(celestial.zenithazimuth_to_horizontal( - 0, azimuth)[1], Azimuth) - self.assertEqual(celestial.horizontal_to_zenithazimuth( - 0, Azimuth)[1], azimuth) + self.assertEqual(celestial.zenithazimuth_to_horizontal(0, azimuth)[1], Azimuth) + self.assertEqual(celestial.horizontal_to_zenithazimuth(0, Azimuth)[1], azimuth) class EquatorialTests(unittest.TestCase): - """Accuracy tests for Celestial coordinate transformations. Use references as tests also used by astropy. 
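The hunks above reformat ZenithAzimuthHorizontalTests, which exercise the convention that altitude = pi/2 - zenith and Azimuth = pi/2 - azimuth. Purely as an illustration, and not part of the patch, a minimal round-trip through the two conversion functions that appear in this diff could look like the sketch below; the tolerance-based checks are an assumption, since norm_angle may shift values by a rounding error.

# Illustrative sketch, not part of the patch: round-trip between
# zenith/azimuth and altitude/Azimuth as exercised by the tests above.
from numpy import pi

from sapphire.transformations import celestial

zenith, azimuth = pi / 4, 0.0
altitude, alt_azimuth = celestial.zenithazimuth_to_horizontal(zenith, azimuth)
assert abs(altitude - pi / 4) < 1e-9     # altitude = pi/2 - zenith
assert abs(alt_azimuth - pi / 2) < 1e-9  # Azimuth = pi/2 - azimuth

# The inverse conversion should recover the original pair.
zenith_back, azimuth_back = celestial.horizontal_to_zenithazimuth(altitude, alt_azimuth)
assert abs(zenith_back - zenith) < 1e-9
assert abs(azimuth_back - azimuth) < 1e-9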
@@ -92,13 +86,12 @@ def test_against_hor2eq(self): gps = clock.utc_to_gps(utc) zenith, azimuth = celestial.horizontal_to_zenithazimuth( np.radians(base.sexagesimal_to_decimal(*altitude)), - np.radians(base.sexagesimal_to_decimal(*azi))) - ra, dec = celestial.zenithazimuth_to_equatorial(latitude, longitude, - gps, zenith, azimuth) + np.radians(base.sexagesimal_to_decimal(*azi)), + ) + ra, dec = celestial.zenithazimuth_to_equatorial(latitude, longitude, gps, zenith, azimuth) # Test eq_to_zenaz merely against IDL - zencalc, azcalc = celestial.equatorial_to_zenithazimuth( - latitude, longitude, gps, ra_expected, dec_expected) + zencalc, azcalc = celestial.equatorial_to_zenithazimuth(latitude, longitude, gps, ra_expected, dec_expected) self.assertAlmostEqual(ra, ra_expected, 1) self.assertAlmostEqual(ra, ra_astropy, 1) @@ -136,11 +129,9 @@ def test_against_pyephem(self): # SAPPHiRE gps = clock.utc_to_gps(calendar.timegm(utc.utctimetuple())) zenith, azimuth = celestial.horizontal_to_zenithazimuth(altitude, azi) - ra, dec = celestial.zenithazimuth_to_equatorial( - latitude, longitude, gps, zenith, azimuth) + ra, dec = celestial.zenithazimuth_to_equatorial(latitude, longitude, gps, zenith, azimuth) - zencalc, azcalc = celestial.equatorial_to_zenithazimuth( - latitude, longitude, gps, ra_expected, dec_expected) + zencalc, azcalc = celestial.equatorial_to_zenithazimuth(latitude, longitude, gps, ra_expected, dec_expected) self.assertAlmostEqual(ra, ra_expected, 2) self.assertAlmostEqual(ra, ra_astropy, 2) @@ -177,13 +168,9 @@ def test_against_jpl_horizons(self): # SAPPHiRE gps = clock.utc_to_gps(calendar.timegm(utc.utctimetuple())) zenith, azimuth = celestial.horizontal_to_zenithazimuth(altitude, azi) - ra, dec = celestial.zenithazimuth_to_equatorial(latitude, longitude, - gps, zenith, azimuth) + ra, dec = celestial.zenithazimuth_to_equatorial(latitude, longitude, gps, zenith, azimuth) - zencalc, azcalc = celestial.equatorial_to_zenithazimuth(latitude, - longitude, gps, - ra_expected, - dec_expected) + zencalc, azcalc = celestial.equatorial_to_zenithazimuth(latitude, longitude, gps, ra_expected, dec_expected) self.assertAlmostEqual(ra, ra_expected, 3) self.assertAlmostEqual(ra, ra_astropy, 3) @@ -194,7 +181,7 @@ def test_against_jpl_horizons(self): self.assertAlmostEqual(azcalc, azimuth, 2) -@unittest.skipUnless(has_astropy, "astropy required.") +@unittest.skipUnless(astropy_available, 'astropy required.') class AstropyEquatorialTests(unittest.TestCase): """ This tests the 4 new astropy functions. 
They should be very close to @@ -209,30 +196,31 @@ def setUp(self): and mute output """ from astropy.utils import iers + iers.conf.auto_download = False def test_pyephem_htoea(self): - """ Check celestial.horizontal_to_equatorial_astropy """ + """Check celestial.horizontal_to_equatorial_astropy""" # This is the transform inputs - eq = [(-39.34633914878846, -112.2277168069694, 1295503840, - 3.8662384455822716, -0.31222454326513827), - (53.13143508448587, -49.24074935964933, 985619982, - 3.901575896592809, -0.3926720112815971), - (48.02031016860923, -157.4023812557098, 1126251396, - 3.366278312183976, -1.3610394240813288)] + eq = [ + (-39.34633914878846, -112.2277168069694, 1295503840, 3.8662384455822716, -0.31222454326513827), + (53.13143508448587, -49.24074935964933, 985619982, 3.901575896592809, -0.3926720112815971), + (48.02031016860923, -157.4023812557098, 1126251396, 3.366278312183976, -1.3610394240813288), + ] # result of pyephem hor->eq/zenaz-> eq - efemeq = [(5.620508199785029, -0.3651173667585858), - (5.244630787139936, -0.7866376569183651), - (2.276751381056623, -1.0406498066785745)] + efemeq = [ + (5.620508199785029, -0.3651173667585858), + (5.244630787139936, -0.7866376569183651), + (2.276751381056623, -1.0406498066785745), + ] htoea_test = [] # Produce horizontal_to_equatorial_astropy results for latitude, longitude, gps, az, alt in eq: - result = celestial.horizontal_to_equatorial_astropy( - latitude, longitude, gps, [(az, alt)]) + result = celestial.horizontal_to_equatorial_astropy(latitude, longitude, gps, [(az, alt)]) htoea_test.extend(result) # Check if all inputs are correct @@ -243,23 +231,23 @@ def test_pyephem_etoha(self): """Check celestial.equatorial_to_horizontal_astropy""" # This is the transform inputs - eq = [(-39.34633914878846, -112.2277168069694, 1295503840, - 3.8662384455822716, -0.31222454326513827), - (53.13143508448587, -49.24074935964933, 985619982, - 3.901575896592809, -0.3926720112815971), - (48.02031016860923, -157.4023812557098, 1126251396, - 3.366278312183976, -1.3610394240813288)] + eq = [ + (-39.34633914878846, -112.2277168069694, 1295503840, 3.8662384455822716, -0.31222454326513827), + (53.13143508448587, -49.24074935964933, 985619982, 3.901575896592809, -0.3926720112815971), + (48.02031016860923, -157.4023812557098, 1126251396, 3.366278312183976, -1.3610394240813288), + ] # result of pyephem eq->hor - altaz = [(2.175107479095459, -0.19537943601608276), - (5.25273323059082, -0.8308737874031067), - (3.4536221027374268, -0.894329845905304)] + altaz = [ + (2.175107479095459, -0.19537943601608276), + (5.25273323059082, -0.8308737874031067), + (3.4536221027374268, -0.894329845905304), + ] etoha_test = [] # Produce equatorial_to_horizontal_astropy results for latitude, longitude, gps, ra, dec in eq: - result = celestial.equatorial_to_horizontal_astropy( - latitude, longitude, gps, [(ra, dec)]) + result = celestial.equatorial_to_horizontal_astropy(latitude, longitude, gps, [(ra, dec)]) etoha_test.extend(result) @@ -272,23 +260,23 @@ def test_pyephem_eqtozenaz(self): """ # This is the transform inputs - eq = [(-39.34633914878846, -112.2277168069694, 1295503840, - 3.8662384455822716, -0.31222454326513827), - (53.13143508448587, -49.24074935964933, 985619982, - 3.901575896592809, -0.3926720112815971), - (48.02031016860923, -157.4023812557098, 1126251396, - 3.366278312183976, -1.3610394240813288)] + eq = [ + (-39.34633914878846, -112.2277168069694, 1295503840, 3.8662384455822716, -0.31222454326513827), + (53.13143508448587, -49.24074935964933, 985619982, 
3.901575896592809, -0.3926720112815971), + (48.02031016860923, -157.4023812557098, 1126251396, 3.366278312183976, -1.3610394240813288), + ] # result converted for eq->zenaz - zenaz = [(1.7661757628109793, -0.6043111523005624), - (2.4016701141980032, 2.6012484033836625), - (2.4651261727002005, -1.8828257759425302)] + zenaz = [ + (1.7661757628109793, -0.6043111523005624), + (2.4016701141980032, 2.6012484033836625), + (2.4651261727002005, -1.8828257759425302), + ] eqtozenaz_test = [] # Produce equatorial_to_zenithazimuth_astropy results for latitude, longitude, gps, ra, dec in eq: - result = celestial.equatorial_to_zenithazimuth_astropy( - latitude, longitude, gps, [(ra, dec)]) + result = celestial.equatorial_to_zenithazimuth_astropy(latitude, longitude, gps, [(ra, dec)]) eqtozenaz_test.extend(result) # Check if all inputs are correct, cast to numpy array for certainty @@ -301,24 +289,24 @@ def test_pyephem_zenaztoeq(self): # equatorial inputs from test_pyephem_eqtozenaz transformed into # the shape of zenithazimuth coordinates so that they may be used for # reverse benchmarking purposes. - zeneq = [(-39.34633914878846, -112.2277168069694, - 1295503840, 1.8830208700600348, -2.295442118787375), - (53.13143508448587, -49.24074935964933, 985619982, - 1.9634683380764937, -2.3307795697979126), - (48.02031016860923, -157.4023812557098, 1126251396, - 2.9318357508762256, -1.7954819853890793)] + zeneq = [ + (-39.34633914878846, -112.2277168069694, 1295503840, 1.8830208700600348, -2.295442118787375), + (53.13143508448587, -49.24074935964933, 985619982, 1.9634683380764937, -2.3307795697979126), + (48.02031016860923, -157.4023812557098, 1126251396, 2.9318357508762256, -1.7954819853890793), + ] # result of pyephem hor->eq/zenaz-> eq - efemeq = [(5.620508199785029, -0.3651173667585858), - (5.244630787139936, -0.7866376569183651), - (2.276751381056623, -1.0406498066785745)] + efemeq = [ + (5.620508199785029, -0.3651173667585858), + (5.244630787139936, -0.7866376569183651), + (2.276751381056623, -1.0406498066785745), + ] zenaztoeq_test = [] # Produce zenithazimuth_to_equatorial_astropy results for latitude, longitude, gps, zen, az in zeneq: - result = celestial.zenithazimuth_to_equatorial_astropy( - latitude, longitude, gps, [(zen, az)]) + result = celestial.zenithazimuth_to_equatorial_astropy(latitude, longitude, gps, [(zen, az)]) zenaztoeq_test.extend(result) # Check if all inputs are correct, cast to numpy array for certainty diff --git a/sapphire/tests/transformations/test_clock.py b/sapphire/tests/transformations/test_clock.py index 64d7c7db..9facb6cb 100644 --- a/sapphire/tests/transformations/test_clock.py +++ b/sapphire/tests/transformations/test_clock.py @@ -6,7 +6,6 @@ class DecimalTimeTests(unittest.TestCase): - def test_time_to_decimal(self): """Check time to decimal hours conversion @@ -23,17 +22,14 @@ def test_time_to_decimal(self): class DateTests(unittest.TestCase): - def test_date_to_juliandate(self): self.assertEqual(clock.date_to_juliandate(2010, 12, 25), 2455555.5) def test_datetime_to_juliandate(self): - self.assertEqual(clock.datetime_to_juliandate(datetime.datetime(2010, 12, 25, 12)), - 2455556.0) + self.assertEqual(clock.datetime_to_juliandate(datetime.datetime(2010, 12, 25, 12)), 2455556.0) class ModifiedJulianDateTests(unittest.TestCase): - def test_datetime_to_modifiedjd(self): mjd = clock.datetime_to_modifiedjd(datetime.datetime(2010, 12, 25, 12)) self.assertEqual(mjd, 55555.5) @@ -41,76 +37,78 @@ def test_datetime_to_modifiedjd(self): def test_juliandate_to_modifiedjd(self): 
"""Difference between Julian Date and Modified JD is 2400000.5""" - self.assertEqual(clock.juliandate_to_modifiedjd(2400000.5), 0.) - self.assertEqual(clock.modifiedjd_to_juliandate(0.), 2400000.5) + self.assertEqual(clock.juliandate_to_modifiedjd(2400000.5), 0.0) + self.assertEqual(clock.modifiedjd_to_juliandate(0.0), 2400000.5) for _ in range(5): - modifiedjd = random.uniform(0, 5000000) + modifiedjd = random.uniform(0, 5_000_000) self.assertAlmostEqual( - clock.juliandate_to_modifiedjd( - clock.modifiedjd_to_juliandate(modifiedjd)), - modifiedjd) - juliandate = random.uniform(0, 5000000) + clock.juliandate_to_modifiedjd(clock.modifiedjd_to_juliandate(modifiedjd)), + modifiedjd, + ) + juliandate = random.uniform(0, 5_000_000) self.assertAlmostEqual( - clock.modifiedjd_to_juliandate( - clock.juliandate_to_modifiedjd(juliandate)), - juliandate) + clock.modifiedjd_to_juliandate(clock.juliandate_to_modifiedjd(juliandate)), + juliandate, + ) class JulianDateToDateTimeTests(unittest.TestCase): - def test_juliandate_to_utc(self): - self.assertEqual(clock.juliandate_to_utc(2400000.5), - datetime.datetime(1858, 11, 17)) - self.assertEqual(clock.juliandate_to_utc(2455581.40429), - datetime.datetime(2011, 1, 19, 21, 42, 10, 655997)) + self.assertEqual(clock.juliandate_to_utc(2400000.5), datetime.datetime(1858, 11, 17)) + self.assertEqual(clock.juliandate_to_utc(2455581.40429), datetime.datetime(2011, 1, 19, 21, 42, 10, 655997)) def test_juliandate_to_utc_gap(self): - self.assertEqual(clock.juliandate_to_utc(2299159.5), - datetime.datetime(1582, 10, 4)) - self.assertEqual(clock.juliandate_to_utc(2299160.5), - datetime.datetime(1582, 10, 15)) + self.assertEqual(clock.juliandate_to_utc(2299159.5), datetime.datetime(1582, 10, 4)) + self.assertEqual(clock.juliandate_to_utc(2299160.5), datetime.datetime(1582, 10, 15)) def test_modifiedjd_to_utc(self): - self.assertEqual(clock.modifiedjd_to_utc(55580.90429), - datetime.datetime(2011, 1, 19, 21, 42, 10, 655997)) + self.assertEqual(clock.modifiedjd_to_utc(55580.90429), datetime.datetime(2011, 1, 19, 21, 42, 10, 655997)) class GMSTTests(unittest.TestCase): - def test_utc_to_gmst(self): # Perhaps not perfect test, a few seconds of uncertainty exist.. - self.assertAlmostEqual(clock.utc_to_gmst(datetime.datetime(2010, 12, 25)), - clock.time_to_decimal(datetime.time(6, 13, 35, 852535))) + self.assertAlmostEqual( + clock.utc_to_gmst(datetime.datetime(2010, 12, 25)), + clock.time_to_decimal(datetime.time(6, 13, 35, 852535)), + ) class LSTTests(unittest.TestCase): - def test_gmst_to_lst(self): for _ in range(5): hours = random.uniform(0, 23.934) longitude = random.uniform(-180, 180) - self.assertAlmostEqual(clock.lst_to_gmst(clock.gmst_to_lst(hours, longitude), - longitude), hours) + self.assertAlmostEqual(clock.lst_to_gmst(clock.gmst_to_lst(hours, longitude), longitude), hours) def test_utc_to_lst_gmst(self): - self.assertEqual(clock.utc_to_lst(datetime.datetime(2010, 12, 25), 0), - clock.utc_to_gmst(datetime.datetime(2010, 12, 25))) + self.assertEqual( + clock.utc_to_lst(datetime.datetime(2010, 12, 25), 0), + clock.utc_to_gmst(datetime.datetime(2010, 12, 25)), + ) # Perhaps not perfect test, a few seconds of uncertainty exist.. 
- self.assertAlmostEqual(clock.utc_to_lst(datetime.datetime(2010, 12, 25), 0), - clock.time_to_decimal(datetime.time(6, 13, 35, 852535))) + self.assertAlmostEqual( + clock.utc_to_lst(datetime.datetime(2010, 12, 25), 0), + clock.time_to_decimal(datetime.time(6, 13, 35, 852535)), + ) def test_utc_to_lst_at_longitudes(self): - self.assertAlmostEqual(clock.utc_to_lst(datetime.datetime(2010, 12, 25), 90), - clock.time_to_decimal(datetime.time(12, 13, 35, 852535))) - self.assertAlmostEqual(clock.utc_to_lst(datetime.datetime(2010, 12, 25), 180), - clock.time_to_decimal(datetime.time(18, 13, 35, 852535))) - self.assertAlmostEqual(clock.utc_to_lst(datetime.datetime(2010, 12, 25), 5), - clock.time_to_decimal(datetime.time(6, 33, 35, 852535))) + self.assertAlmostEqual( + clock.utc_to_lst(datetime.datetime(2010, 12, 25), 90), + clock.time_to_decimal(datetime.time(12, 13, 35, 852535)), + ) + self.assertAlmostEqual( + clock.utc_to_lst(datetime.datetime(2010, 12, 25), 180), + clock.time_to_decimal(datetime.time(18, 13, 35, 852535)), + ) + self.assertAlmostEqual( + clock.utc_to_lst(datetime.datetime(2010, 12, 25), 5), + clock.time_to_decimal(datetime.time(6, 33, 35, 852535)), + ) class GPSTimeTests(unittest.TestCase): - def setUp(self): """Setup combinations of calendar dates, timestamps and leap seconds @@ -121,38 +119,38 @@ def setUp(self): t = calendar.timegm(time.strptime(date, '%B %d, %Y')) """ - self.combinations = (('July 1, 2015', 1435708800, 17), - ('January 1, 2014', 1388534400, 16), - ('July 1, 2012', 1341100800, 16), - ('June 30, 2012', 1341014400, 15), - ('January 1, 2009', 1230768000, 15), - ('December 31, 2008', 1230681600, 14), - ('January 1, 2006', 1136073600, 14), - ('December 31, 2005', 1135987200, 13), - ('January 1, 2004', 1072915200, 13), - ('January 1, 1999', 915148800, 13), - ('July 1, 1997', 867715200, 12), - ('January 1, 1996', 820454400, 11), - ('July 1, 1994', 773020800, 10), - ('July 1, 1993', 741484800, 9), - ('July 1, 1992', 709948800, 8), - ('January 1, 1991', 662688000, 7), - ('January 1, 1990', 631152000, 6), - ('January 1, 1988', 567993600, 5), - ('July 1, 1985', 489024000, 4), - ('July 1, 1983', 425865600, 3), - ('July 1, 1982', 394329600, 2), - ('July 1, 1981', 362793600, 1)) + self.combinations = ( + ('July 1, 2015', 1435708800, 17), + ('January 1, 2014', 1388534400, 16), + ('July 1, 2012', 1341100800, 16), + ('June 30, 2012', 1341014400, 15), + ('January 1, 2009', 1230768000, 15), + ('December 31, 2008', 1230681600, 14), + ('January 1, 2006', 1136073600, 14), + ('December 31, 2005', 1135987200, 13), + ('January 1, 2004', 1072915200, 13), + ('January 1, 1999', 915148800, 13), + ('July 1, 1997', 867715200, 12), + ('January 1, 1996', 820454400, 11), + ('July 1, 1994', 773020800, 10), + ('July 1, 1993', 741484800, 9), + ('July 1, 1992', 709948800, 8), + ('January 1, 1991', 662688000, 7), + ('January 1, 1990', 631152000, 6), + ('January 1, 1988', 567993600, 5), + ('July 1, 1985', 489024000, 4), + ('July 1, 1983', 425865600, 3), + ('July 1, 1982', 394329600, 2), + ('July 1, 1981', 362793600, 1), + ) def test_gps_to_utc(self): for date, _, _ in self.combinations: - self.assertEqual(clock.gps_to_utc(clock.gps_from_string(date)), - clock.utc_from_string(date)) + self.assertEqual(clock.gps_to_utc(clock.gps_from_string(date)), clock.utc_from_string(date)) def test_utc_to_gps(self): for date, _, _ in self.combinations: - self.assertEqual(clock.utc_to_gps(clock.utc_from_string(date)), - clock.gps_from_string(date)) + self.assertEqual(clock.utc_to_gps(clock.utc_from_string(date)), 
clock.gps_from_string(date)) def test_utc_from_string(self): for date, timestamp, _ in self.combinations: diff --git a/sapphire/tests/transformations/test_geographic.py b/sapphire/tests/transformations/test_geographic.py index 175c09c6..3b39e634 100644 --- a/sapphire/tests/transformations/test_geographic.py +++ b/sapphire/tests/transformations/test_geographic.py @@ -4,9 +4,8 @@ class GeographicTransformationTests(unittest.TestCase): - def setUp(self): - self.ref_enu = (0., 0., 0.) + self.ref_enu = (0.0, 0.0, 0.0) self.ref_lla = (52.35592417, 4.95114402, 56.10234594) self.transform = geographic.FromWGS84ToENUTransformation(self.ref_lla) @@ -38,6 +37,6 @@ def assert_tuple_almost_equal(self, actual, expected, places=7): self.assertIsInstance(actual, tuple) self.assertIsInstance(expected, tuple) - msg = f"Tuples differ: {actual} != {expected}" + msg = f'Tuples differ: {actual} != {expected}' for actual_value, expected_value in zip(actual, expected): self.assertAlmostEqual(actual_value, expected_value, places=places, msg=msg) diff --git a/sapphire/tests/validate_results.py b/sapphire/tests/validate_results.py index f645254c..8153763c 100644 --- a/sapphire/tests/validate_results.py +++ b/sapphire/tests/validate_results.py @@ -14,8 +14,7 @@ def validate_results(test, expected_path, actual_path): :param actual_path: path to the output from the test. """ - with tables.open_file(expected_path, 'r') as expected_file, \ - tables.open_file(actual_path, 'r') as actual_file: + with tables.open_file(expected_path, 'r') as expected_file, tables.open_file(actual_path, 'r') as actual_file: for expected_node in expected_file.walk_nodes('/', 'Leaf'): try: actual_node = actual_file.get_node(expected_node._v_pathname) @@ -34,8 +33,7 @@ def validate_results(test, expected_path, actual_path): validate_attributes(test, expected_file.root, actual_file.root) -def validate_results_node(test, expected_path, actual_path, expected_node, - actual_node): +def validate_results_node(test, expected_path, actual_path, expected_node, actual_node): """Validate results by comparing two specific nodes :param test: instance of the TestCase. @@ -45,8 +43,7 @@ def validate_results_node(test, expected_path, actual_path, expected_node, :param actual_node: path to the output node from the test. 
""" - with tables.open_file(expected_path, 'r') as expected_file, \ - tables.open_file(actual_path, 'r') as actual_file: + with tables.open_file(expected_path, 'r') as expected_file, tables.open_file(actual_path, 'r') as actual_file: expected = expected_file.get_node(expected_node) try: actual = actual_file.get_node(actual_node) @@ -66,41 +63,53 @@ def validate_results_node(test, expected_path, actual_path, expected_node, def validate_tables(test, expected_node, actual_node): """Verify that two Tables are identical""" - test.assertEqual(expected_node.nrows, actual_node.nrows, - f"Tables '{expected_node._v_pathname}' do not have the same length.") + test.assertEqual( + expected_node.nrows, + actual_node.nrows, + f"Tables '{expected_node._v_pathname}' do not have the same length.", + ) for colname in expected_node.colnames: - test.assertIn(colname, actual_node.colnames, - f"Tables '{expected_node._v_pathname}' do not have the same columns.") + test.assertIn( + colname, + actual_node.colnames, + f"Tables '{expected_node._v_pathname}' do not have the same columns.", + ) expected_col = expected_node.col(colname) actual_col = actual_node.col(colname) assert_array_almost_equal( expected_col, actual_col, - err_msg=f"Tables '{expected_node._v_pathname}' column '{colname}' do not match." + err_msg=f"Tables '{expected_node._v_pathname}' column '{colname}' do not match.", ) def validate_vlarrays(test, expected_node, actual_node): """Verify that two VLArrays are identical""" - test.assertEqual(expected_node.shape, actual_node.shape, - f"VLArrays '{expected_node._v_pathname}' do not have the same shape.") + test.assertEqual( + expected_node.shape, + actual_node.shape, + f"VLArrays '{expected_node._v_pathname}' do not have the same shape.", + ) for expected_array, actual_array in zip(expected_node, actual_node): - test.assertTrue(all(expected_array == actual_array), - f"VLArrays '{expected_node._v_pathname}' do not match.") + test.assertTrue(all(expected_array == actual_array), f"VLArrays '{expected_node._v_pathname}' do not match.") def validate_arrays(test, expected_node, actual_node): """Verify that two Arrays are identical""" - test.assertEqual(expected_node.shape, actual_node.shape, - f"Arrays '{expected_node._v_pathname}' do not have the same shape.") - test.assertTrue(all(array(expected_node.read()) == array(actual_node.read())), - f"Arrays '{expected_node._v_pathname}' do not match.") + test.assertEqual( + expected_node.shape, + actual_node.shape, + f"Arrays '{expected_node._v_pathname}' do not have the same shape.", + ) + test.assertTrue( + all(array(expected_node.read()) == array(actual_node.read())), + f"Arrays '{expected_node._v_pathname}' do not match.", + ) def validate_attributes(test, expected_node, actual_node): """Verify that two nodes have the same user attributes""" - test.assertEqual(expected_node._v_attrs._v_attrnamesuser, - actual_node._v_attrs._v_attrnamesuser) + test.assertEqual(expected_node._v_attrs._v_attrnamesuser, actual_node._v_attrs._v_attrnamesuser) diff --git a/sapphire/time_util.py b/sapphire/time_util.py index b35aad61..3e2b3aa4 100644 --- a/sapphire/time_util.py +++ b/sapphire/time_util.py @@ -6,6 +6,7 @@ UTC/GPS/local time. No more! """ + import calendar import datetime import time @@ -38,7 +39,7 @@ def __init__(self, *args): timetuple = datetime_.utctimetuple() self._gpstimestamp = calendar.timegm(timetuple) else: - raise TypeError("Incorrect arguments") + raise TypeError('Incorrect arguments') def gpstimestamp(self): """Return the GPS date/time as a timestamp. 
@@ -83,4 +84,4 @@ def __str__(self): return self.description() def __repr__(self): - return "%s(%d)" % (self.__class__.__name__, self._gpstimestamp) + return '%s(%d)' % (self.__class__.__name__, self._gpstimestamp) diff --git a/sapphire/transformations/__init__.py b/sapphire/transformations/__init__.py index c202d444..d6a69822 100644 --- a/sapphire/transformations/__init__.py +++ b/sapphire/transformations/__init__.py @@ -23,11 +23,7 @@ """ + from . import angles, axes, base, celestial, clock, geographic -__all__ = ['angles', - 'axes', - 'base', - 'celestial', - 'clock', - 'geographic'] +__all__ = ['angles', 'axes', 'base', 'celestial', 'clock', 'geographic'] diff --git a/sapphire/transformations/angles.py b/sapphire/transformations/angles.py index 97e9ace3..670e808c 100644 --- a/sapphire/transformations/angles.py +++ b/sapphire/transformations/angles.py @@ -1,9 +1,10 @@ -""" Perform various angle related transformations +"""Perform various angle related transformations - Transform between different notations for angles: - Degrees, radians and hours. +Transform between different notations for angles: +Degrees, radians and hours. """ + from numpy import degrees, radians @@ -14,7 +15,7 @@ def hours_to_degrees(angle): :return: angle in degrees """ - return angle * 15. + return angle * 15.0 def hours_to_radians(angle): @@ -34,7 +35,7 @@ def degrees_to_hours(angle): :return: angle in decimal hours """ - return angle / 15. + return angle / 15.0 def radians_to_hours(angle): diff --git a/sapphire/transformations/axes.py b/sapphire/transformations/axes.py index ee979e8c..0be2d23d 100644 --- a/sapphire/transformations/axes.py +++ b/sapphire/transformations/axes.py @@ -1,4 +1,4 @@ -""" Perform various axes related transformations +"""Perform various axes related transformations - Transformation between Cartesian, polar, cylindrical, spherical and compass coordinate systems. @@ -22,6 +22,7 @@ - z: height above x,y-plane. """ + from numpy import arccos, arctan2, array, cos, degrees, radians, sin, sqrt diff --git a/sapphire/transformations/base.py b/sapphire/transformations/base.py index 8ab770f7..6fc0a082 100644 --- a/sapphire/transformations/base.py +++ b/sapphire/transformations/base.py @@ -4,6 +4,7 @@ base 60 (sexagesimal). """ + from numpy import modf @@ -17,7 +18,7 @@ def decimal_to_sexagesimal(decimal): """ fractional, integral = modf(decimal) min_fractional, minutes = modf(fractional * 60) - seconds = min_fractional * 60. + seconds = min_fractional * 60.0 return integral.astype(int), minutes.astype(int), seconds @@ -33,4 +34,4 @@ def sexagesimal_to_decimal(hd, minutes, seconds): :return: decimal hours or degrees. """ - return hd + minutes / 60. + seconds / 3600. + return hd + minutes / 60.0 + seconds / 3600.0 diff --git a/sapphire/transformations/celestial.py b/sapphire/transformations/celestial.py index 7fcd67ff..46d74ce5 100644 --- a/sapphire/transformations/celestial.py +++ b/sapphire/transformations/celestial.py @@ -1,18 +1,19 @@ -""" Perform various Celestial coordinate transformations +"""Perform various Celestial coordinate transformations - This module performs transformations between different - Celestial coordinate systems. +This module performs transformations between different +Celestial coordinate systems. 
- Legacy transformations (all those not marked astropy): - Formulae from: Duffett-Smith1990 - 'Astronomy with your personal computer' - ISBN 0-521-38995-X +Legacy transformations (all those not marked astropy): +Formulae from: Duffett-Smith1990 +'Astronomy with your personal computer' +ISBN 0-521-38995-X - New transformations have been added with _astropy added to function name - They are very exact, in the order of arcsec. - Ethan van Woerkom is the author of the new transformations; contact him - for further information. +New transformations have been added with _astropy added to function name +They are very exact, in the order of arcsec. +Ethan van Woerkom is the author of the new transformations; contact him +for further information. """ + import datetime import warnings @@ -24,8 +25,7 @@ from . import angles, clock -def zenithazimuth_to_equatorial(latitude, longitude, timestamp, zenith, - azimuth): +def zenithazimuth_to_equatorial(latitude, longitude, timestamp, zenith, azimuth): """Convert Zenith Azimuth to Equatorial coordinates (J2000.0) :param latitude,longitude: Position of the observer on Earth in degrees. @@ -62,8 +62,8 @@ def zenithazimuth_to_horizontal(zenith, azimuth): Azimuth the angle in the horizontal plane, from North to East (NESW). """ - altitude = norm_angle(pi / 2. - zenith) - alt_azimuth = norm_angle(pi / 2. - azimuth) + altitude = norm_angle(pi / 2.0 - zenith) + alt_azimuth = norm_angle(pi / 2.0 - azimuth) return altitude, alt_azimuth @@ -142,14 +142,13 @@ def ha_to_ra(ha, lst): :return: Right ascension (ra) in radians. """ - ra = (angles.hours_to_radians(lst) - ha) + ra = angles.hours_to_radians(lst) - ha ra %= 2 * pi return ra -def equatorial_to_zenithazimuth(latitude, longitude, timestamp, - right_ascension, declination): +def equatorial_to_zenithazimuth(latitude, longitude, timestamp, right_ascension, declination): """Convert Equatorial (J2000.0) to Zenith Azimuth coordinates :param latitude,longitude: Position of the observer on Earth in degrees. 
@@ -167,7 +166,7 @@ def equatorial_to_zenithazimuth(latitude, longitude, timestamp, """ lst = clock.gps_to_lst(timestamp, longitude) - ha = (angles.hours_to_radians(lst) - right_ascension) + ha = angles.hours_to_radians(lst) - right_ascension ha %= 2 * pi slat = sin(radians(latitude)) @@ -178,8 +177,7 @@ def equatorial_to_zenithazimuth(latitude, longitude, timestamp, cdec = cos(declination) altitude = arcsin((sdec * slat) + (cdec * clat * cha)) - alt_azimuth = arccos((sdec - (slat * sin(altitude))) / - (clat * cos(altitude))) + alt_azimuth = arccos((sdec - (slat * sin(altitude))) / (clat * cos(altitude))) if sha > 0: alt_azimuth = 2 * pi - alt_azimuth @@ -197,9 +195,8 @@ def equatorial_to_zenithazimuth(latitude, longitude, timestamp, from astropy.coordinates import EarthLocation, SkyCoord from astropy.time import Time - def zenithazimuth_to_equatorial_astropy(latitude, longitude, utc_timestamp, - zenaz_coordinates): - """ Converts iterables of tuples of zenithazimuth + def zenithazimuth_to_equatorial_astropy(latitude, longitude, utc_timestamp, zenaz_coordinates): + """Converts iterables of tuples of zenithazimuth to equatorial coordinates :param latitude: Latitude in decimal degrees @@ -219,14 +216,10 @@ def zenithazimuth_to_equatorial_astropy(latitude, longitude, utc_timestamp, # Normalise angle horizontal_coordinates = norm_angle(horizontal_coordinates) - return horizontal_to_equatorial_astropy(latitude, longitude, - utc_timestamp, - horizontal_coordinates) + return horizontal_to_equatorial_astropy(latitude, longitude, utc_timestamp, horizontal_coordinates) - def equatorial_to_zenithazimuth_astropy(latitude, longitude, - utc_timestamp, - equatorial_coordinates): - """ Converts iterables of tuples of equatorial + def equatorial_to_zenithazimuth_astropy(latitude, longitude, utc_timestamp, equatorial_coordinates): + """Converts iterables of tuples of equatorial to zenithazimuth coordinates :param latitude: Latitude in decimal degrees @@ -240,7 +233,11 @@ def equatorial_to_zenithazimuth_astropy(latitude, longitude, equatorial_coordinates = np.array(equatorial_coordinates) horizontal_coordinates = equatorial_to_horizontal_astropy( - latitude, longitude, utc_timestamp, equatorial_coordinates) + latitude, + longitude, + utc_timestamp, + equatorial_coordinates, + ) # Convert and flip order of zenaz coordinates, done in numpy for speed horizontal_coordinates = np.array(horizontal_coordinates) @@ -252,10 +249,8 @@ def equatorial_to_zenithazimuth_astropy(latitude, longitude, return zenaz_coordinates - def equatorial_to_horizontal_astropy(latitude, longitude, - utc_timestamp, - equatorial_coordinates): - """ Converts iterables of tuples of equatorial coordinates + def equatorial_to_horizontal_astropy(latitude, longitude, utc_timestamp, equatorial_coordinates): + """Converts iterables of tuples of equatorial coordinates to horizontal coordinates :param latitude: Latitude in decimal degrees @@ -271,16 +266,13 @@ def equatorial_to_horizontal_astropy(latitude, longitude, location = EarthLocation(longitude, latitude) t = Time(datetime.datetime.utcfromtimestamp(utc_timestamp)) - equatorial_frame = SkyCoord(equatorial_coordinates, location=location, - obstime=t, unit=u.rad, frame='icrs') + equatorial_frame = SkyCoord(equatorial_coordinates, location=location, obstime=t, unit=u.rad, frame='icrs') horizontal_frame = equatorial_frame.transform_to('altaz') return np.array((horizontal_frame.az.rad, horizontal_frame.alt.rad)).T - def horizontal_to_equatorial_astropy(latitude, longitude, - utc_timestamp, - 
horizontal_coordinates): - """ Converts iterables of tuples of + def horizontal_to_equatorial_astropy(latitude, longitude, utc_timestamp, horizontal_coordinates): + """Converts iterables of tuples of horizontal coordinates to equatorial coordinates :param latitude: Latitude in decimal degrees @@ -294,12 +286,11 @@ def horizontal_to_equatorial_astropy(latitude, longitude, location = EarthLocation(longitude, latitude) t = Time(datetime.datetime.utcfromtimestamp(utc_timestamp)) - horizontal_frame = SkyCoord(horizontal_coordinates, location=location, - obstime=t, unit=u.rad, frame='altaz') + horizontal_frame = SkyCoord(horizontal_coordinates, location=location, obstime=t, unit=u.rad, frame='altaz') equatorial_frame = horizontal_frame.transform_to('icrs') return np.array((equatorial_frame.ra.rad, equatorial_frame.dec.rad)).T except ImportError as e: - warnings.warn(str(e) + "\nImport of astropy failed", ImportWarning) + warnings.warn(str(e) + '\nImport of astropy failed', ImportWarning) diff --git a/sapphire/transformations/clock.py b/sapphire/transformations/clock.py index dba9323c..4fc9baec 100644 --- a/sapphire/transformations/clock.py +++ b/sapphire/transformations/clock.py @@ -1,4 +1,4 @@ -""" Time transformations +"""Time transformations This handles all the wibbly wobbly timey wimey stuff. Such as easy conversions between different time systems. @@ -23,6 +23,7 @@ https://github.com/adrn/apwlib """ + import calendar import datetime import math @@ -32,24 +33,26 @@ from . import angles, base #: Dates of leap second introductions. -LEAP_SECONDS = (('January 1, 2017', 18), - ('July 1, 2015', 17), - ('July 1, 2012', 16), - ('January 1, 2009', 15), - ('January 1, 2006', 14), - ('January 1, 1999', 13), - ('July 1, 1997', 12), - ('January 1, 1996', 11), - ('July 1, 1994', 10), - ('July 1, 1993', 9), - ('July 1, 1992', 8), - ('January 1, 1991', 7), - ('January 1, 1990', 6), - ('January 1, 1988', 5), - ('July 1, 1985', 4), - ('July 1, 1983', 3), - ('July 1, 1982', 2), - ('July 1, 1981', 1)) +LEAP_SECONDS = ( + ('January 1, 2017', 18), + ('July 1, 2015', 17), + ('July 1, 2012', 16), + ('January 1, 2009', 15), + ('January 1, 2006', 14), + ('January 1, 1999', 13), + ('July 1, 1997', 12), + ('January 1, 1996', 11), + ('July 1, 1994', 10), + ('July 1, 1993', 9), + ('July 1, 1992', 8), + ('January 1, 1991', 7), + ('January 1, 1990', 6), + ('January 1, 1988', 5), + ('July 1, 1985', 4), + ('July 1, 1983', 3), + ('July 1, 1982', 2), + ('July 1, 1981', 1), +) def time_to_decimal(time): @@ -59,8 +62,15 @@ def time_to_decimal(time): :return: decimal number representing the input time. """ - return (time.hour + time.minute / 60. + time.second / 3600. + - time.microsecond / 3600000000.) 
+ minutes_per_hour = 60 + seconds_per_hour = 3600 + microseconds_per_hour = 3_600_000_000 + return ( + time.hour + + time.minute / minutes_per_hour + + time.second / seconds_per_hour + + time.microsecond / microseconds_per_hour + ) def decimal_to_time(hours): @@ -99,7 +109,13 @@ def date_to_juliandate(year, month, day): year1 -= 1 month1 = month + 12 - if year1 > 1582 or (year1 == 1582 and month >= 10 and day >= 15): + # Gregorian calendar reform: apply the correction only for dates from 15 October 1582 + gregorian_year = 1582 + gregorian_month = 10 + gregorian_day = 15 + if year1 > gregorian_year or ( + year1 == gregorian_year and (month > gregorian_month or (month == gregorian_month and day >= gregorian_day)) + ): a = int(year1 / 100) b = 2 - a + int(a / 4) else: @@ -123,7 +139,7 @@ def datetime_to_juliandate(dt): """ juliandate = date_to_juliandate(dt.year, dt.month, dt.day) - decimal_time = time_to_decimal(dt.time()) / 24. + decimal_time = time_to_decimal(dt.time()) / 24.0 return juliandate + decimal_time @@ -166,16 +182,15 @@ def juliandate_to_gmst(juliandate): """ jd0 = int(juliandate - 0.5) + 0.5 # Julian Date of previous midnight - h = (juliandate - jd0) * 24. # Hours since mightnight + h = (juliandate - jd0) * 24.0 # Hours since midnight # Days since J2000 (Julian Date 2451545.) - d0 = jd0 - 2451545. - d = juliandate - 2451545. - t = d / 36525. # Centuries since J2000 - gmst = (6.697374558 + 0.06570982441908 * d0 + 1.00273790935 * h + - 0.000026 * t * t) + d0 = jd0 - 2451545.0 + d = juliandate - 2451545.0 + t = d / 36525.0 # Centuries since J2000 + gmst = 6.697374558 + 0.06570982441908 * d0 + 1.00273790935 * h + 0.000026 * t * t - return gmst % 24. + return gmst % 24.0 def utc_to_gmst(dt): @@ -201,8 +216,8 @@ def gmst_to_utc(dt): """ jd = date_to_juliandate(dt.year, dt.month, dt.day) - d = jd - 2451545. - t = d / 36525. + d = jd - 2451545.0 + t = d / 36525.0 t0 = 6.697374558 + (2400.051336 * t) + (0.000025862 * t * t) t0 %= 24 @@ -211,8 +226,7 @@ def gmst_to_utc(dt): time = decimal_to_time(ut) - return dt.replace(hour=time.hour, minute=time.minute, second=time.seconds, - microsecond=time.microsecond) + return dt.replace(hour=time.hour, minute=time.minute, second=time.second, microsecond=time.microsecond) def juliandate_to_utc(juliandate): @@ -225,7 +239,9 @@ def juliandate_to_utc(juliandate): juliandate += 0.5 jd_frac, jd_int = math.modf(juliandate) - if jd_int > 2299160: + julian_gregorian_transition_date = 2299160 + + if jd_int > julian_gregorian_transition_date: a = int((jd_int - 1867216.25) / 36524.25) b = jd_int + 1 + a - int(a / 4) else: @@ -322,8 +338,7 @@ def gps_to_utc(timestamp): :return: UTC timestamp in seconds. """ - offset = next((seconds for date, seconds in LEAP_SECONDS - if timestamp >= utc_from_string(date)), 0) + offset = next((seconds for date, seconds in LEAP_SECONDS if timestamp >= utc_from_string(date)), 0) return timestamp - offset @@ -334,8 +349,7 @@ def utc_to_gps(timestamp): :return: GPS timestamp in seconds.
""" - offset = next((seconds for date, seconds in LEAP_SECONDS - if timestamp >= utc_from_string(date)), 0) + offset = next((seconds for date, seconds in LEAP_SECONDS if timestamp >= utc_from_string(date)), 0) return timestamp + offset @@ -409,4 +423,4 @@ def process_time(time): try: return datetime_to_gps(time) except Exception: - raise RuntimeError('Unable to parse time: ', time) + raise RuntimeError(f'Unable to parse time: {time}') diff --git a/sapphire/transformations/geographic.py b/sapphire/transformations/geographic.py index 904e1162..4e31786a 100644 --- a/sapphire/transformations/geographic.py +++ b/sapphire/transformations/geographic.py @@ -1,9 +1,10 @@ -""" Perform various coordinate transformations +"""Perform various coordinate transformations - This module performs various coordinate transformations, based on some - well-known formulas. +This module performs various coordinate transformations, based on some +well-known formulas. """ + from math import atan2, cos, degrees, radians, sin, sqrt from numpy import array @@ -17,18 +18,18 @@ class WGS84Datum: editors have gone over them to make sure they are correct. """ + # Defining constants - a = 6378137. + a = 6378137.0 f = 1 / 298.257223563 # Derived constants b = a * (1 - f) - e = sqrt(2 * f - f ** 2) + e = sqrt(2 * f - f**2) eprime = sqrt(f * (2 - f) / (1 - f) ** 2) class FromWGS84ToENUTransformation: - """Convert between various geographic coordinate systems This class converts coordinates between LLA, ENU, and ECEF. @@ -89,11 +90,11 @@ def lla_to_ecef(self, coordinates): b = self.geode.b e = self.geode.e - n = a / sqrt(1 - e ** 2 * sin(latitude) ** 2) + n = a / sqrt(1 - e**2 * sin(latitude) ** 2) x = (n + altitude) * cos(latitude) * cos(longitude) y = (n + altitude) * cos(latitude) * sin(longitude) - z = (b ** 2 / a ** 2 * n + altitude) * sin(latitude) + z = (b**2 / a**2 * n + altitude) * sin(latitude) return x, y, z @@ -117,13 +118,12 @@ def ecef_to_lla(self, coordinates): e = self.geode.e eprime = self.geode.eprime - p = sqrt(x ** 2 + y ** 2) + p = sqrt(x**2 + y**2) th = atan2(a * z, b * p) longitude = atan2(y, x) - latitude = atan2((z + eprime ** 2 * b * sin(th) ** 3), - (p - e ** 2 * a * cos(th) ** 3)) - n = a / sqrt(1 - e ** 2 * sin(latitude) ** 2) + latitude = atan2((z + eprime**2 * b * sin(th) ** 3), (p - e**2 * a * cos(th) ** 3)) + n = a / sqrt(1 - e**2 * sin(latitude) ** 2) altitude = p / cos(latitude) - n return degrees(latitude), degrees(longitude), altitude @@ -149,10 +149,13 @@ def ecef_to_enu(self, coordinates): lat = radians(latitude) lon = radians(longitude) - transformation = array([ - [ -sin(lon), cos(lon), 0.], # noqa - [-sin(lat) * cos(lon), -sin(lat) * sin(lon), cos(lat)], - [ cos(lat) * cos(lon), cos(lat) * sin(lon), sin(lat)]]) # noqa + transformation = array( + [ + [-sin(lon), cos(lon), 0.0], + [-sin(lat) * cos(lon), -sin(lat) * sin(lon), cos(lat)], + [cos(lat) * cos(lon), cos(lat) * sin(lon), sin(lat)], + ], + ) coordinates = array([x - xr, y - yr, z - zr]) @@ -174,14 +177,17 @@ def enu_to_ecef(self, coordinates): lat = radians(latitude) lon = radians(longitude) - transformation = array([ - [-sin(lon), -sin(lat) * cos(lon), cos(lat) * cos(lon)], - [ cos(lon), -sin(lat) * sin(lon), cos(lat) * sin(lon)], # noqa - [ 0., cos(lat), sin(lat)]]) # noqa + transformation = array( + [ + [-sin(lon), -sin(lat) * cos(lon), cos(lat) * cos(lon)], + [cos(lon), -sin(lat) * sin(lon), cos(lat) * sin(lon)], + [0.0, cos(lat), sin(lat)], + ], + ) x, y, z = transformation.dot(array(coordinates)) return x + xr, y + yr, z + zr
def __repr__(self): - return f"{self.__class__.__name__}({self.ref_lla!r})" + return f'{self.__class__.__name__}({self.ref_lla!r})' diff --git a/sapphire/utils.py b/sapphire/utils.py index ff188af2..ccd019d6 100644 --- a/sapphire/utils.py +++ b/sapphire/utils.py @@ -5,6 +5,7 @@ """ from bisect import bisect_right +from contextlib import suppress from distutils.spawn import find_executable from functools import wraps from os import environ @@ -47,14 +48,11 @@ def pbar(iterable, length=None, show=True, **kwargs): return iterable if length is None: - try: + with suppress(TypeError): length = len(iterable) - except TypeError: - pass if length: - pb = ProgressBar(max_value=length, - widgets=[Percentage(), Bar(), ETA()], **kwargs) + pb = ProgressBar(max_value=length, widgets=[Percentage(), Bar(), ETA()], **kwargs) return pb(iterable) else: return iterable @@ -130,7 +128,7 @@ def angle_between(zenith1, azimuth1, zenith2, azimuth2): """ dlat = zenith1 - zenith2 dlon = azimuth2 - azimuth1 - a = (sin(dlat / 2) ** 2 + sin(zenith1) * sin(zenith2) * sin(dlon / 2) ** 2) + a = sin(dlat / 2) ** 2 + sin(zenith1) * sin(zenith2) * sin(dlon / 2) ** 2 angle = 2 * arcsin(sqrt(a)) return angle @@ -143,7 +141,7 @@ def vector_length(x, y, z=0): :return: length of vector. """ - return sqrt(x ** 2 + y ** 2 + z ** 2) + return sqrt(x**2 + y**2 + z**2) def distance_between(x1, y1, x2, y2): @@ -173,7 +171,7 @@ def which(program): """ path = find_executable(program) if not path: - raise Exception('The program %s is not available.' % program) + raise RuntimeError(f'The program {program} is not available.') def memoize(method): @@ -182,11 +180,11 @@ def memoize(method): Source: https://stackoverflow.com/a/29954160/1033535 """ + @wraps(method) def memoizer(self, *args, **kwargs): - # Prepare and get reference to cache - attr = f"_memo_{method.__name__}" + attr = f'_memo_{method.__name__}' if not hasattr(self, attr): setattr(self, attr, {}) cache = getattr(self, attr) diff --git a/scripts/corsika/ground_particles.py b/scripts/corsika/ground_particles.py index 9c1d6fec..b1705128 100755 --- a/scripts/corsika/ground_particles.py +++ b/scripts/corsika/ground_particles.py @@ -6,17 +6,16 @@ def plot_ground(x, y, eventheader, title='Ground particles'): - size = 200. 
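As a quick orientation for the transformation APIs touched above, here is a minimal usage sketch; it is not part of the diff. It assumes an installed hisparc-sapphire package, takes the reference coordinates from the fixtures in test_geographic.py, infers the (latitude, longitude, altitude) in degrees and metres convention from those same fixtures, and reads the 15 s GPS-UTC offset for early 2009 off the LEAP_SECONDS table in clock.py:

# Sketch only: exercise the geographic and clock transformations reformatted above.
from sapphire.transformations import clock, geographic

# Reference location used by the tests: (latitude, longitude, altitude) in degrees and metres.
ref_lla = (52.35592417, 4.95114402, 56.10234594)
transform = geographic.FromWGS84ToENUTransformation(ref_lla)

ecef = transform.lla_to_ecef(ref_lla)  # geodetic -> Earth-centred, Earth-fixed (metres)
enu = transform.ecef_to_enu(ecef)      # ECEF -> local east/north/up around the reference
print(enu)                             # close to (0.0, 0.0, 0.0) at the reference point itself

# Leap-second bookkeeping from clock.py: GPS runs ahead of UTC by the accumulated
# offset (15 s for this February 2009 timestamp), so the round trip is lossless here.
utc = 1234567890
gps = clock.utc_to_gps(utc)   # 1234567905
assert clock.gps_to_utc(gps) == utc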
+ size = 200.0 plt.figure(figsize=(9, 9)) - plt.scatter(x, y, c='r', s=2., edgecolors='none') + plt.scatter(x, y, c='r', s=2.0, edgecolors='none') plt.axis('equal') plt.axis([-size, size, -size, size]) plt.xlabel('x [m]') plt.ylabel('y [m]') plt.title(title) -# plt.savefig('/Users/arne/Dropbox/hisparc/Plots/Corsika/scatter_em.pdf') + # plt.savefig('/Users/arne/Dropbox/hisparc/Plots/Corsika/scatter_em.pdf') plt.show() - return def main(): @@ -31,8 +30,7 @@ def main(): x.append(particle.x) y.append(particle.y) event_header = event.get_header() - title = ('Ground particles, Primary: %s, E = %1.0e eV' % - (event_header.particle, event_header.energy)) + title = 'Ground particles, Primary: %s, E = %1.0e eV' % (event_header.particle, event_header.energy) plot_ground(x, y, event_header, title=title) diff --git a/scripts/corsika/qsub_sort_stored_corsika_data.py b/scripts/corsika/qsub_sort_stored_corsika_data.py index c961a049..879a28a5 100644 --- a/scripts/corsika/qsub_sort_stored_corsika_data.py +++ b/scripts/corsika/qsub_sort_stored_corsika_data.py @@ -1,9 +1,10 @@ -""" Convert unsorted CORSIKA HDF5 to sorted HDF5 using Stoomboot""" +"""Convert unsorted CORSIKA HDF5 to sorted HDF5 using Stoomboot""" import glob import logging import os import subprocess +import tempfile import textwrap QUEUE = 'generic' @@ -21,9 +22,13 @@ # To alleviate Stoomboot, make sure the job is not to short. sleep $[ ( $RANDOM % 60 ) + 60 ]""") -logging.basicConfig(filename=LOGFILE, filemode='a', - format='%(asctime)s %(name)s %(levelname)s: %(message)s', - datefmt='%y%m%d_%H%M%S', level=logging.INFO) +logging.basicConfig( + filename=LOGFILE, + filemode='a', + format='%(asctime)s %(name)s %(levelname)s: %(message)s', + datefmt='%y%m%d_%H%M%S', + level=logging.INFO, +) logger = logging.getLogger('qsub_store_corsika_data') @@ -85,12 +90,11 @@ def get_seeds_todo(): return seeds.difference(processed).difference(queued) - def get_script_path(seed): """Create path for script""" script_name = f'sort_{seed}.sh' - script_path = os.path.join('/tmp', script_name) + script_path = os.path.join(tempfile.gettempdir(), script_name) return script_path @@ -98,10 +102,10 @@ def create_script(seed): """Create script as temp file to run on Stoomboot""" script_path = get_script_path(seed) - input = SCRIPT_TEMPLATE.format(seed=os.path.join(DATADIR, seed)) + input_content = SCRIPT_TEMPLATE.format(seed=os.path.join(DATADIR, seed)) with open(script_path, 'w') as script: - script.write(input) + script.write(input_content) os.chmod(script_path, 0o774) return script_path @@ -119,13 +123,10 @@ def submit_job(seed): script_path = create_script(seed) - qsub = ('qsub -q {queue} -V -z -j oe -N {name} {script}' - .format(queue=QUEUE, name=os.path.basename(script_path), - script=script_path)) + qsub = f'qsub -q {QUEUE} -V -z -j oe -N {os.path.basename(script_path)} {script_path}' - result = subprocess.check_output(qsub, stderr=subprocess.STDOUT, - shell=True) - if not result == '': + result = subprocess.check_output(qsub, stderr=subprocess.STDOUT, shell=True) + if result != '': msg = f'{seed} - Error occured: {result}' logger.error(msg) raise Exception(msg) @@ -142,16 +143,14 @@ def check_queue(): """ queued = f'qstat {QUEUE} | grep [RQ] | wc -l' - user_queued = ('qstat -u $USER {queue} | grep [RQ] | wc -l' - .format(queue=QUEUE)) + user_queued = f'qstat -u $USER {QUEUE} | grep [RQ] | wc -l' n_queued = int(subprocess.check_output(queued, shell=True)) n_queued_user = int(subprocess.check_output(user_queued, shell=True)) max_queue = 4000 max_queue_user = 2000 
keep_free = 50 - return min(max_queue - n_queued, - max_queue_user - n_queued_user) - keep_free + return min(max_queue - n_queued, max_queue_user - n_queued_user) - keep_free def run(): diff --git a/scripts/corsika/summary.py b/scripts/corsika/summary.py index cd4d1796..078a444d 100755 --- a/scripts/corsika/summary.py +++ b/scripts/corsika/summary.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -import sys from sapphire import corsika @@ -15,4 +14,4 @@ for particle in event.get_particles(): count += 1 - print(count, " particles") + print(count, ' particles') diff --git a/scripts/data/demo_mpv_fit.py b/scripts/data/demo_mpv_fit.py index 4e1b5a77..a2b523d9 100644 --- a/scripts/data/demo_mpv_fit.py +++ b/scripts/data/demo_mpv_fit.py @@ -4,6 +4,7 @@ Fit the MPV and show the results """ + import datetime import warnings @@ -26,17 +27,15 @@ def main(): for station in station_ids: plt.figure() for did in range(Station(station).n_detectors()): - n, bins = get_histogram_for_station_on_date(station, yesterday, - did) + n, bins = get_histogram_for_station_on_date(station, yesterday, did) find_mpv = FindMostProbableValueInSpectrum(n, bins) mpv, is_fitted = find_mpv.find_mpv() - plt.plot((bins[:-1] + bins[1:]) / 2., n, c=COLORS[did]) + plt.plot((bins[:-1] + bins[1:]) / 2.0, n, c=COLORS[did]) lines = ['dotted', 'solid'] - plt.axvline(mpv + did * (bins[1] - bins[0]) / 20., - c=COLORS[did], ls=lines[is_fitted]) + plt.axvline(mpv + did * (bins[1] - bins[0]) / 20.0, c=COLORS[did], ls=lines[is_fitted]) plt.title(station) - plt.xlim(0, bins[len(bins)/2]) + plt.xlim(0, bins[len(bins) // 2]) plt.yscale('log') diff --git a/scripts/data/hdf5_coincidences_to_csv.py b/scripts/data/hdf5_coincidences_to_csv.py index 8b555a05..b787708c 100755 --- a/scripts/data/hdf5_coincidences_to_csv.py +++ b/scripts/data/hdf5_coincidences_to_csv.py @@ -13,6 +13,7 @@ stored by :class:`sapphire.analysis.coincidences.CoincidencesESD` """ + import csv import tables @@ -39,8 +40,7 @@ def data_to_csv(destination, source, coincidences_group): s_numbers = [s_group.split('_')[-1] for s_group in s_index[:]] csvwriter.writerow(coincidences.colnames) - csvwriter.writerow(['station_number'] + - source.getNode(s_index[0]).events.colnames) + csvwriter.writerow(['station_number'] + source.getNode(s_index[0]).events.colnames) for coincidence in coincidences[:]: csvwriter.writerow(coincidence) @@ -50,7 +50,7 @@ def data_to_csv(destination, source, coincidences_group): csvwriter.writerow([s_numbers[s_idx]] + list(event)) -if __name__ == "__main__": +if __name__ == '__main__': # Path to the csv file to be created.
destination = 'data.csv' diff --git a/scripts/kascade/direction_reconstruction.py b/scripts/kascade/direction_reconstruction.py index d740b254..a152c402 100644 --- a/scripts/kascade/direction_reconstruction.py +++ b/scripts/kascade/direction_reconstruction.py @@ -1,9 +1,9 @@ import os.path -from itertools import izip - import tables +import utils +from myshowerfront import * from scipy import integrate from scipy.interpolate import spline from scipy.special import erf @@ -14,10 +14,7 @@ from artist import GraphArtist, MultiPlot from pylab import * -import utils - -from myshowerfront import * -from sapphire.analysis.direction_reconstruction import BinnedDirectionReconstruction, DirectionReconstruction +from sapphire.analysis.direction_reconstruction import DirectionReconstruction DATADIR = '../simulations/plots' @@ -30,7 +27,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -73,30 +70,32 @@ def plot_uncertainty_mip(table): r2, phi2 = station.calc_r_and_phi_for_detectors(1, 4) THETA = deg2rad(22.5) - DTHETA = deg2rad(5.) - DN = .1 + DTHETA = deg2rad(5.0) + DN = 0.1 LOGENERGY = 15 - DLOGENERGY = .5 + DLOGENERGY = 0.5 figure() x, y, y2 = [], [], [] for N in range(1, 6): x.append(N) - events = table.read_where('(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)') - print(len(events),) + events = table.read_where( + '(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)', + ) + print(len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) print() - print("mip: min_n134, theta_std, phi_std") + print('mip: min_n134, theta_std, phi_std') for u, v, w in zip(x, y, y2): print(u, v, w) print() @@ -109,13 +108,13 @@ def plot_uncertainty_mip(table): phis = linspace(-pi, pi, 50) phi_errsq = mean(rec.rel_phi_errorsq(pi / 8, phis, phi1, phi2, r1, r2)) theta_errsq = mean(rec.rel_theta1_errorsq(pi / 8, phis, phi1, phi2, r1, r2)) - #ey = TIMING_ERROR * std_t(ex) * sqrt(phi_errsq) - #ey2 = TIMING_ERROR * std_t(ex) * sqrt(theta_errsq) + # ey = TIMING_ERROR * std_t(ex) * sqrt(phi_errsq) + # ey2 = TIMING_ERROR * std_t(ex) * sqrt(theta_errsq) R_list = [30, 20, 16, 14, 12] with tables.open_file('master-ch4v2.h5') as data2: mc = my_std_t_for_R(data2, x, R_list) - mc = sqrt(mc ** 2 + 1.2 ** 2 + 2.5 ** 2) + mc = sqrt(mc**2 + 1.2**2 + 2.5**2) print(mc) ey = mc * sqrt(phi_errsq) ey2 = mc * sqrt(theta_errsq) @@ -125,21 +124,24 @@ def plot_uncertainty_mip(table): ey2 = spline(x, ey2, nx) # Plots - plot(x, rad2deg(y), '^', label="Theta") - plot(sx, rad2deg(sy), '^', label="Theta (sim)") - plot(nx, rad2deg(ey2))#, label="Estimate Theta") - plot(x, rad2deg(y2), 'v', label="Phi") - plot(sx, 
rad2deg(sy2), 'v', label="Phi (sim)") - plot(nx, rad2deg(ey))#, label="Estimate Phi") + plot(x, rad2deg(y), '^', label='Theta') + plot(sx, rad2deg(sy), '^', label='Theta (sim)') + plot(nx, rad2deg(ey2)) # , label="Estimate Theta") + plot(x, rad2deg(y2), 'v', label='Phi') + plot(sx, rad2deg(sy2), 'v', label='Phi (sim)') + plot(nx, rad2deg(ey)) # , label="Estimate Phi") # Labels etc. - xlabel(r"$N_{MIP} \pm %.1f$" % DN) - ylabel("Angle reconstruction uncertainty [deg]") - title(r"$\theta = 22.5^\circ \pm %d^\circ \quad %.1f \leq \log(E) \leq %.1f$" % (rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY)) + xlabel(r'$N_{MIP} \pm %.1f$' % DN) + ylabel('Angle reconstruction uncertainty [deg]') + title( + r'$\theta = 22.5^\circ \pm %d^\circ \quad %.1f \leq \log(E) \leq %.1f$' + % (rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY), + ) legend(numpoints=1) xlim(0.5, 4.5) utils.saveplot() - print + print() graph = GraphArtist() graph.plot(x, rad2deg(y), mark='o', linestyle=None) @@ -148,8 +150,8 @@ def plot_uncertainty_mip(table): graph.plot(x, rad2deg(y2), mark='*', linestyle=None) graph.plot(sx, rad2deg(sy2), mark='square*', linestyle=None) graph.plot(nx, rad2deg(ey), mark=None) - graph.set_xlabel(r"$N_\mathrm{MIP} \pm %.1f$" % DN) - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") + graph.set_xlabel(r'$N_\mathrm{MIP} \pm %.1f$' % DN) + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') graph.set_xlimits(max=4.5) graph.set_ylimits(0, 40) graph.set_xticks(range(5)) @@ -165,10 +167,10 @@ def plot_uncertainty_zenith(table): r2, phi2 = station.calc_r_and_phi_for_detectors(1, 4) N = 2 - DTHETA = deg2rad(1.) - DN = .1 + DTHETA = deg2rad(1.0) + DN = 0.1 LOGENERGY = 15 - DLOGENERGY = .5 + DLOGENERGY = 0.5 figure() rcParams['text.usetex'] = False @@ -176,20 +178,22 @@ def plot_uncertainty_zenith(table): for theta in 5, 10, 15, 22.5, 30, 35: x.append(theta) THETA = deg2rad(theta) - events = table.read_where('(min_n134 >= N) & (abs(reference_theta - THETA) <= DTHETA) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)') - print(theta, len(events),) + events = table.read_where( + '(min_n134 >= N) & (abs(reference_theta - THETA) <= DTHETA) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)', + ) + print(theta, len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) print() - print("zenith: theta, theta_std, phi_std") + print('zenith: theta, theta_std, phi_std') for u, v, w in zip(x, y, y2): print(u, v, w) print() @@ -212,108 +216,109 @@ def plot_uncertainty_zenith(table): graph = GraphArtist() # Plots - plot(x, rad2deg(y), '^', label="Theta") + plot(x, rad2deg(y), '^', label='Theta') graph.plot(x, rad2deg(y), mark='o', linestyle=None) - #plot(sx, rad2deg(sy), '^', label="Theta (sim)") - plot(rad2deg(ex), rad2deg(ey2))#, label="Estimate Theta") + # plot(sx, rad2deg(sy), '^', label="Theta (sim)") + plot(rad2deg(ex), rad2deg(ey2)) # , label="Estimate Theta") graph.plot(rad2deg(ex), rad2deg(ey2), mark=None) # Azimuthal angle undefined for zenith = 0 - plot(x[1:], 
rad2deg(y2[1:]), 'v', label="Phi") + plot(x[1:], rad2deg(y2[1:]), 'v', label='Phi') graph.plot(x[1:], rad2deg(y2[1:]), mark='*', linestyle=None) - #plot(sx[1:], rad2deg(sy2[1:]), 'v', label="Phi (sim)") - plot(rad2deg(ex), rad2deg(ey))#, label="Estimate Phi") + # plot(sx[1:], rad2deg(sy2[1:]), 'v', label="Phi (sim)") + plot(rad2deg(ex), rad2deg(ey)) # , label="Estimate Phi") graph.plot(rad2deg(ex), rad2deg(ey), mark=None) - #plot(rad2deg(ex), rad2deg(ey3), label="Estimate Phi * sin(Theta)") + # plot(rad2deg(ex), rad2deg(ey3), label="Estimate Phi * sin(Theta)") # Labels etc. - xlabel(r"Shower zenith angle [deg $\pm %d^\circ$]" % rad2deg(DTHETA)) - graph.set_xlabel(r"Shower zenith angle [\si{\degree}] $\pm \SI{%d}{\degree}$" % rad2deg(DTHETA)) - ylabel("Angle reconstruction uncertainty [deg]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") - title(r"$N_{MIP} \geq %d, \quad %.1f \leq \log(E) \leq %.1f$" % (N, LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY)) + xlabel(r'Shower zenith angle [deg $\pm %d^\circ$]' % rad2deg(DTHETA)) + graph.set_xlabel(r'Shower zenith angle [\si{\degree}] $\pm \SI{%d}{\degree}$' % rad2deg(DTHETA)) + ylabel('Angle reconstruction uncertainty [deg]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') + title(r'$N_{MIP} \geq %d, \quad %.1f \leq \log(E) \leq %.1f$' % (N, LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY)) ylim(0, 60) graph.set_ylimits(0, 60) - xlim(-.5, 37) + xlim(-0.5, 37) legend(numpoints=1) if USE_TEX: rcParams['text.usetex'] = True utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() def plot_uncertainty_core_distance(table): N = 2 THETA = deg2rad(22.5) - DTHETA = deg2rad(5.) - DN = .5 + DTHETA = deg2rad(5.0) + DN = 0.5 DR = 10 LOGENERGY = 15 - DLOGENERGY = .5 + DLOGENERGY = 0.5 figure() x, y, y2 = [], [], [] for R in range(0, 81, 20): x.append(R) - events = table.read_where('(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(r - R) <= DR) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)') - print(len(events),) + events = table.read_where( + '(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(r - R) <= DR) & (abs(log10(k_energy) - LOGENERGY) <= DLOGENERGY)', + ) + print(len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) print() - print("R: theta_std, phi_std") + print('R: theta_std, phi_std') for u, v, w in zip(x, y, y2): print(u, v, w) print() -# # Simulation data + # # Simulation data sx, sy, sy2 = loadtxt(os.path.join(DATADIR, 'DIR-plot_uncertainty_core_distance.txt')) graph = GraphArtist() # Plots - plot(x, rad2deg(y), '^-', label="Theta") + plot(x, rad2deg(y), '^-', label='Theta') graph.plot(x[:-1], rad2deg(y[:-1]), mark='o') - plot(sx, rad2deg(sy), '^-', label="Theta (sim)") + plot(sx, rad2deg(sy), '^-', label='Theta (sim)') graph.plot(sx[:-1], rad2deg(sy[:-1]), mark='square') - plot(x, rad2deg(y2), 'v-', label="Phi") + plot(x, rad2deg(y2), 'v-', label='Phi') graph.plot(x[:-1], rad2deg(y2[:-1]), mark='*') - plot(sx, rad2deg(sy2), 
'v-', label="Phi (sim)") + plot(sx, rad2deg(sy2), 'v-', label='Phi (sim)') graph.plot(sx[:-1], rad2deg(sy2[:-1]), mark='square*') # Labels etc. - xlabel(r"Core distance [m] $\pm %d$" % DR) - graph.set_xlabel(r"Core distance [\si{\meter}] $\pm \SI{%d}{\meter}$" % DR) - ylabel("Angle reconstruction uncertainty [deg]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") - title(r"$N_{MIP} = %d \pm %.1f, \theta = 22.5^\circ \pm %d^\circ, %.1f \leq \log(E) \leq %.1f$" % (N, DN, rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY)) + xlabel(r'Core distance [m] $\pm %d$' % DR) + graph.set_xlabel(r'Core distance [\si{\meter}] $\pm \SI{%d}{\meter}$' % DR) + ylabel('Angle reconstruction uncertainty [deg]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') + title( + r'$N_{MIP} = %d \pm %.1f, \theta = 22.5^\circ \pm %d^\circ, %.1f \leq \log(E) \leq %.1f$' + % (N, DN, rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY), + ) ylim(ymin=0) graph.set_ylimits(min=0) xlim(-2, 62) legend(numpoints=1, loc='best') utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() + # Time of first hit pamflet functions -Q = lambda t, n: ((.5 * (1 - erf(t / sqrt(2)))) ** (n - 1) - * exp(-.5 * t ** 2) / sqrt(2 * pi)) +Q = lambda t, n: ((0.5 * (1 - erf(t / sqrt(2)))) ** (n - 1) * exp(-0.5 * t**2) / sqrt(2 * pi)) -expv_t = vectorize(lambda n: integrate.quad(lambda t: t * Q(t, n) - / n ** -1, - - inf, +inf)) +expv_t = vectorize(lambda n: integrate.quad(lambda t: t * Q(t, n) / n**-1, -inf, +inf)) expv_tv = lambda n: expv_t(n)[0] -expv_tsq = vectorize(lambda n: integrate.quad(lambda t: t ** 2 * Q(t, n) - / n ** -1, - - inf, +inf)) +expv_tsq = vectorize(lambda n: integrate.quad(lambda t: t**2 * Q(t, n) / n**-1, -inf, +inf)) expv_tsqv = lambda n: expv_tsq(n)[0] std_t = lambda n: sqrt(expv_tsqv(n) - expv_tv(n) ** 2) @@ -321,7 +326,7 @@ def plot_uncertainty_core_distance(table): def plot_phi_reconstruction_results_for_MIP(table, N): THETA = deg2rad(22.5) - DTHETA = deg2rad(5.) 
+ DTHETA = deg2rad(5.0) events = table.read_where('(min_n134 >= N) & (abs(reference_theta - THETA) <= DTHETA)') sim_phi = events['reference_phi'] @@ -329,16 +334,15 @@ def plot_phi_reconstruction_results_for_MIP(table, N): figure() plot_2d_histogram(rad2deg(sim_phi), rad2deg(r_phi), 180) - xlabel(r"$\phi_K$ [deg]") - ylabel(r"$\phi_H$ [deg]") - title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ \pm %d^\circ$" % (N, rad2deg(DTHETA))) + xlabel(r'$\phi_K$ [deg]') + ylabel(r'$\phi_H$ [deg]') + title(r'$N_{MIP} \geq %d, \quad \theta = 22.5^\circ \pm %d^\circ$' % (N, rad2deg(DTHETA))) utils.saveplot(N) graph = artist.GraphArtist() bins = linspace(-180, 180, 73) - H, x_edges, y_edges = histogram2d(rad2deg(sim_phi), rad2deg(r_phi), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(sim_phi), rad2deg(r_phi), bins=bins) graph.histogram2d(H, x_edges, y_edges, type='reverse_bw') graph.set_xlabel(r'$\phi_K$ [\si{\degree}]') graph.set_ylabel(r'$\phi_H$ [\si{\degree}]') @@ -356,16 +360,15 @@ def plot_theta_reconstruction_results_for_MIP(table, N): x_edges = linspace(0, 40, 81) y_edges = linspace(0, 40, 81) plot_2d_histogram(rad2deg(sim_theta), rad2deg(r_theta), (x_edges, y_edges)) - xlabel(r"$\theta_K$ [deg]") - ylabel(r"$\theta_H$ [deg]") - title(r"$N_{MIP} \geq %d$" % N) + xlabel(r'$\theta_K$ [deg]') + ylabel(r'$\theta_H$ [deg]') + title(r'$N_{MIP} \geq %d$' % N) utils.saveplot(N) graph = artist.GraphArtist() bins = linspace(0, 40, 41) - H, x_edges, y_edges = histogram2d(rad2deg(sim_theta), rad2deg(r_theta), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(sim_theta), rad2deg(r_theta), bins=bins) graph.histogram2d(H, x_edges, y_edges, type='reverse_bw') graph.set_xlabel(r'$\theta_K$ [\si{\degree}]') graph.set_ylabel(r'$\theta_H$ [\si{\degree}]') @@ -375,7 +378,7 @@ def plot_theta_reconstruction_results_for_MIP(table, N): def boxplot_theta_reconstruction_results_for_MIP(table, N): figure() - DTHETA = deg2rad(1.) + DTHETA = deg2rad(1.0) angles = [0, 5, 10, 15, 22.5, 35] r_dtheta = [] @@ -392,13 +395,13 @@ def boxplot_theta_reconstruction_results_for_MIP(table, N): d75.append(scoreatpercentile(dtheta, 75)) x.append(angle) - #boxplot(r_dtheta, sym='', positions=angles, widths=2.) + # boxplot(r_dtheta, sym='', positions=angles, widths=2.) fill_between(x, d25, d75, color='0.75') plot(x, d50, 'o-', color='black') - xlabel(r"$\theta_K$ [deg]") - ylabel(r"$\theta_H - \theta_K$ [deg]") - title(r"$N_{MIP} \geq %d$" % N) + xlabel(r'$\theta_K$ [deg]') + ylabel(r'$\theta_H - \theta_K$ [deg]') + title(r'$N_{MIP} \geq %d$' % N) axhline(0, color='black') ylim(-20, 25) @@ -410,8 +413,8 @@ def boxplot_theta_reconstruction_results_for_MIP(table, N): graph.draw_horizontal_line(0, linestyle='gray') graph.shade_region(angles, d25, d75) graph.plot(angles, d50, linestyle=None) - graph.set_xlabel(r"$\theta_K$ [\si{\degree}]") - graph.set_ylabel(r"$\theta_H - \theta_K$ [\si{\degree}]") + graph.set_xlabel(r'$\theta_K$ [\si{\degree}]') + graph.set_ylabel(r'$\theta_H - \theta_K$ [\si{\degree}]') graph.set_ylimits(-5, 15) artist.utils.save_graph(graph, suffix=N, dirname='plots') @@ -420,7 +423,7 @@ def boxplot_phi_reconstruction_results_for_MIP(table, N): figure() THETA = deg2rad(22.5) - DTHETA = deg2rad(5.) 
+ DTHETA = deg2rad(5.0) bin_edges = linspace(-180, 180, 18) x, r_dphi = [], [] @@ -439,13 +442,13 @@ def boxplot_phi_reconstruction_results_for_MIP(table, N): d75.append(scoreatpercentile(rad2deg(dphi), 75)) x.append((low + high) / 2) - #boxplot(r_dphi, positions=x, widths=1 * (high - low), sym='') + # boxplot(r_dphi, positions=x, widths=1 * (high - low), sym='') fill_between(x, d25, d75, color='0.75') plot(x, d50, 'o-', color='black') - xlabel(r"$\phi_K$ [deg]") - ylabel(r"$\phi_H - \phi_K$ [deg]") - title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ \pm %d^\circ$" % (N, rad2deg(DTHETA))) + xlabel(r'$\phi_K$ [deg]') + ylabel(r'$\phi_H - \phi_K$ [deg]') + title(r'$N_{MIP} \geq %d, \quad \theta = 22.5^\circ \pm %d^\circ$' % (N, rad2deg(DTHETA))) xticks(linspace(-180, 180, 9)) axhline(0, color='black') @@ -456,8 +459,8 @@ def boxplot_phi_reconstruction_results_for_MIP(table, N): graph.draw_horizontal_line(0, linestyle='gray') graph.shade_region(x, d25, d75) graph.plot(x, d50, linestyle=None) - graph.set_xlabel(r"$\phi_K$ [\si{\degree}]") - graph.set_ylabel(r"$\phi_H - \phi_K$ [\si{\degree}]") + graph.set_xlabel(r'$\phi_K$ [\si{\degree}]') + graph.set_ylabel(r'$\phi_H - \phi_K$ [\si{\degree}]') graph.set_xticks([-180, -90, '...', 180]) graph.set_xlimits(-180, 180) graph.set_ylimits(-23, 23) @@ -466,10 +469,10 @@ def boxplot_phi_reconstruction_results_for_MIP(table, N): def boxplot_arrival_times(table, N): THETA = deg2rad(0) - DTHETA = deg2rad(10.) + DTHETA = deg2rad(10.0) LOGENERGY = 15 - DLOGENERGY = .5 + DLOGENERGY = 0.5 bin_edges = linspace(0, 80, 5) x = [] @@ -503,23 +506,26 @@ def boxplot_arrival_times(table, N): plot(sx, st50, 'o-', color='black', markerfacecolor='none') plot(x, t50, 'o-', color='black') - ax2.xaxis.set_label_text("Core distance [m]") - ax1.yaxis.set_label_text("Arrival time difference $|t_2 - t_1|$ [ns]") - fig.suptitle(r"$N_{MIP} \geq %d, \quad \theta = 0^\circ \pm %d^\circ, \quad %.1f \leq \log(E) \leq %.1f$" % (N, rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY)) + ax2.xaxis.set_label_text('Core distance [m]') + ax1.yaxis.set_label_text('Arrival time difference $|t_2 - t_1|$ [ns]') + fig.suptitle( + r'$N_{MIP} \geq %d, \quad \theta = 0^\circ \pm %d^\circ, \quad %.1f \leq \log(E) \leq %.1f$' + % (N, rad2deg(DTHETA), LOGENERGY - DLOGENERGY, LOGENERGY + DLOGENERGY), + ) ylim(ymax=15) xlim(xmax=80) locator_params(tight=True, nbins=4) - fig.subplots_adjust(left=.1, right=.95) + fig.subplots_adjust(left=0.1, right=0.95) fig.set_size_inches(5, 2.67) utils.saveplot(N) sx = sx.compress(sx < 80) - st25 = st25[:len(sx)] - st50 = st50[:len(sx)] - st75 = st75[:len(sx)] + st25 = st25[: len(sx)] + st50 = st50[: len(sx)] + st75 = st75[: len(sx)] graph = MultiPlot(1, 3, width=r'.3\linewidth', height=r'.4\linewidth') @@ -537,8 +543,8 @@ def boxplot_arrival_times(table, N): graph.set_ylimits(0, 15) graph.set_xlimits(0, 80) - graph.set_xlabel(r"Core distance [\si{\meter}]") - graph.set_ylabel(r"Arrival time difference $|t_2 - t_1|$ [\si{\nano\second}]") + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel(r'Arrival time difference $|t_2 - t_1|$ [\si{\nano\second}]') graph.show_xticklabels_for_all([(0, 0), (0, 1), (0, 2)]) graph.set_xticklabels_position(0, 1, 'right') graph.show_yticklabels(0, 0) @@ -548,8 +554,8 @@ def boxplot_arrival_times(table, N): def boxplot_core_distances_for_mips(table): THETA = deg2rad(22.5) - DTHETA = deg2rad(1.) 
- DN = .5 + DTHETA = deg2rad(1.0) + DN = 0.5 ENERGY = 1e15 DENERGY = 2e14 @@ -561,7 +567,9 @@ def boxplot_core_distances_for_mips(table): r75_list = [] x = [] for N in range(1, 5): - sel = table.read_where('(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(k_energy - ENERGY) <= DENERGY) & (r <= MAX_R)') + sel = table.read_where( + '(abs(min_n134 - N) <= DN) & (abs(reference_theta - THETA) <= DTHETA) & (abs(k_energy - ENERGY) <= DENERGY) & (r <= MAX_R)', + ) r = sel[:]['r'] r25_list.append(scoreatpercentile(r, 25)) r50_list.append(scoreatpercentile(r, 50)) @@ -585,15 +593,18 @@ def boxplot_core_distances_for_mips(table): plot(x, sr50, 'o-', color='black', markerfacecolor='none') plot(x, r50_list, 'o-', color='black') - ax2.xaxis.set_label_text(r"Minimum number of particles $\pm %.1f$" % DN) - ax1.yaxis.set_label_text("Core distance [m]") - fig.suptitle(r"$\theta = 22.5^\circ \pm %d^\circ, \quad %.1f \leq \log(E) \leq %.1f$" % (rad2deg(DTHETA), log10(ENERGY - DENERGY), log10(ENERGY + DENERGY))) + ax2.xaxis.set_label_text(r'Minimum number of particles $\pm %.1f$' % DN) + ax1.yaxis.set_label_text('Core distance [m]') + fig.suptitle( + r'$\theta = 22.5^\circ \pm %d^\circ, \quad %.1f \leq \log(E) \leq %.1f$' + % (rad2deg(DTHETA), log10(ENERGY - DENERGY), log10(ENERGY + DENERGY)), + ) ax1.xaxis.set_ticks([1, 2, 3, 4]) - fig.subplots_adjust(left=.1, right=.95) + fig.subplots_adjust(left=0.1, right=0.95) ylim(ymin=0) - xlim(.8, 4.2) + xlim(0.8, 4.2) fig.set_size_inches(5, 2.67) utils.saveplot() @@ -613,8 +624,8 @@ def boxplot_core_distances_for_mips(table): graph.set_label(0, 2, 'sim + exp') graph.set_ylimits(0, 50) - graph.set_xlabel(r"Minimum number of particles $\pm %.1f$" % DN) - graph.set_ylabel(r"Core distance [\si{\meter}]") + graph.set_xlabel(r'Minimum number of particles $\pm %.1f$' % DN) + graph.set_ylabel(r'Core distance [\si{\meter}]') graph.show_xticklabels_for_all([(0, 0), (0, 1), (0, 2)]) graph.show_yticklabels(0, 0) @@ -623,9 +634,14 @@ def boxplot_core_distances_for_mips(table): def plot_2d_histogram(x, y, bins): H, xedges, yedges = histogram2d(x, y, bins) - imshow(H.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], - origin='lower left', interpolation='lanczos', aspect='auto', - cmap=cm.Greys) + imshow( + H.T, + extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], + origin='lower left', + interpolation='lanczos', + aspect='auto', + cmap=cm.Greys, + ) colorbar() @@ -645,15 +661,15 @@ def plot_fsot_vs_lint_for_zenith(fsot, lint): errors = f_sel['reconstructed_phi'] - f_sel['reference_phi'] errors2 = f_sel['reconstructed_theta'] - f_sel['reference_theta'] - #f_y.append(std(errors)) - #f_y2.append(std(errors2)) + # f_y.append(std(errors)) + # f_y2.append(std(errors2)) f_y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) f_y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) errors = l_sel['reconstructed_phi'] - l_sel['reference_phi'] errors2 = l_sel['reconstructed_theta'] - l_sel['reference_theta'] - #l_y.append(std(errors)) - #l_y2.append(std(errors2)) + # l_y.append(std(errors)) + # l_y2.append(std(errors2)) l_y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) l_y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) @@ -662,14 +678,14 @@ def plot_fsot_vs_lint_for_zenith(fsot, lint): print(x[-1], len(f_sel), len(l_sel)) clf() - plot(x, rad2deg(f_y), label="FSOT phi") - plot(x, rad2deg(f_y2), label="FSOT theta") - plot(x, rad2deg(l_y), 
label="LINT phi") - plot(x, rad2deg(l_y2), label="LINT theta") + plot(x, rad2deg(f_y), label='FSOT phi') + plot(x, rad2deg(f_y2), label='FSOT theta') + plot(x, rad2deg(l_y), label='LINT phi') + plot(x, rad2deg(l_y2), label='LINT theta') legend() - xlabel("Shower zenith angle [deg]") - ylabel("Angle reconstruction uncertainty [deg]") - title(r"$N_{MIP} \geq %d$" % min_N) + xlabel('Shower zenith angle [deg]') + ylabel('Angle reconstruction uncertainty [deg]') + title(r'$N_{MIP} \geq %d$' % min_N) utils.saveplot() graph = GraphArtist() @@ -677,8 +693,8 @@ def plot_fsot_vs_lint_for_zenith(fsot, lint): graph.plot(x, rad2deg(l_y), mark=None, linestyle='dashed') graph.plot(x, rad2deg(f_y2), mark=None) graph.plot(x, rad2deg(l_y2), mark=None, linestyle='dashed') - graph.set_xlabel(r"Shower zenith angle [\si{\degree}]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") + graph.set_xlabel(r'Shower zenith angle [\si{\degree}]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') artist.utils.save_graph(graph, dirname='plots') @@ -692,16 +708,16 @@ def plot_fsot_vs_lint_for_zenith(fsot, lint): except NameError: data = tables.open_file('kascade.h5', 'r') - artist.utils.set_prefix("KAS-") - utils.set_prefix("KAS-") + artist.utils.set_prefix('KAS-') + utils.set_prefix('KAS-') do_reconstruction_plots(data, data.root.reconstructions) do_lint_comparison(data) - artist.utils.set_prefix("KAS-LINT-") - utils.set_prefix("KAS-LINT-") + artist.utils.set_prefix('KAS-LINT-') + utils.set_prefix('KAS-LINT-') do_reconstruction_plots(data, data.root.lint_reconstructions) - artist.utils.set_prefix("KAS-OFFSETS-") - utils.set_prefix("KAS-OFFSETS-") + artist.utils.set_prefix('KAS-OFFSETS-') + utils.set_prefix('KAS-OFFSETS-') do_reconstruction_plots(data, data.root.reconstructions_offsets) - artist.utils.set_prefix("KAS-LINT-OFFSETS-") - utils.set_prefix("KAS-LINT-OFFSETS-") + artist.utils.set_prefix('KAS-LINT-OFFSETS-') + utils.set_prefix('KAS-LINT-OFFSETS-') do_reconstruction_plots(data, data.root.lint_reconstructions_offsets) diff --git a/scripts/kascade/event_generator.py b/scripts/kascade/event_generator.py index d0f13088..dcb2c861 100644 --- a/scripts/kascade/event_generator.py +++ b/scripts/kascade/event_generator.py @@ -2,16 +2,16 @@ import datetime import struct -from numpy.random import randint, random - import MySQLdb +from numpy.random import randint, random + from sapphire.transformations import clock -T0 = 1234567890 # 14 Feb 2009 00:31:30 -H_SHIFT = 13.18 # HiSPARC timeshift +T0 = 1234567890 # 14 Feb 2009 00:31:30 +H_SHIFT = 13.18 # HiSPARC timeshift -K_FILE = "generator-kascade.dat" +K_FILE = 'generator-kascade.dat' def generate_events(timespan, rate, reconstructed_fraction): @@ -25,10 +25,9 @@ def generate_events(timespan, rate, reconstructed_fraction): with open(K_FILE, 'w') as f: writer = csv.writer(f, delimiter=' ') - db = MySQLdb.connect(user='buffer', passwd='Buffer4hisp!', - db='buffer') + db = MySQLdb.connect(user='buffer', passwd='Buffer4hisp!', db='buffer') cursor = db.cursor() - cursor.execute("DELETE FROM message") + cursor.execute('DELETE FROM message') db.commit() for ts, ns in zip(timestamps, nanoseconds): @@ -44,48 +43,49 @@ def store_hisparc_event(cursor, ts, ns): trace = 'xxx' l = len(trace) - msg = struct.pack(">BBHBBBBBHIiHhhhhhhii%ds%dshhhhhhii%ds%ds" % - (l, l, l, l), - 2, # central database - 2, # number of devices - l * 2, # length of two traces - t.second, - t.minute, - t.hour, - t.day, - t.month, - t.year, - int(ns), - 0, # SLVtime - 0, # 
Trigger pattern - 0, # baseline1 - 0, # baseline2 - 0, # npeaks1 - 0, # npeaks2 - 0, # pulseheight1 - 0, # pulseheight2 - 0, # integral1 - 0, # integral2 - trace, trace, - 0, # baseline3 - 0, # baseline4 - 0, # npeaks3 - 0, # npeaks4 - 0, # pulseheight3 - 0, # pulseheight4 - 0, # integral3 - 0, # integral4 - trace, trace) - - cursor.execute("INSERT INTO message (device_id, message) VALUES " - "(601, %s)", (msg,)) + msg = struct.pack( + '>BBHBBBBBHIiHhhhhhhii%ds%dshhhhhhii%ds%ds' % (l, l, l, l), + 2, # central database + 2, # number of devices + l * 2, # length of two traces + t.second, + t.minute, + t.hour, + t.day, + t.month, + t.year, + int(ns), + 0, # SLVtime + 0, # Trigger pattern + 0, # baseline1 + 0, # baseline2 + 0, # npeaks1 + 0, # npeaks2 + 0, # pulseheight1 + 0, # pulseheight2 + 0, # integral1 + 0, # integral2 + trace, + trace, + 0, # baseline3 + 0, # baseline4 + 0, # npeaks3 + 0, # npeaks4 + 0, # pulseheight3 + 0, # pulseheight4 + 0, # integral3 + 0, # integral4 + trace, + trace, + ) + + cursor.execute('INSERT INTO message (device_id, message) VALUES (601, %s)', (msg,)) def store_kascade_event(writer, ts, ns): ns = ns - (ns % 200) - writer.writerow((0, 0, ts, ns, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0)) + writer.writerow((0, 0, ts, ns, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)) if __name__ == '__main__': - generate_events(86400, 4., .1) + generate_events(86400, 4.0, 0.1) diff --git a/scripts/kascade/fit.py b/scripts/kascade/fit.py index 22d65661..6386f219 100644 --- a/scripts/kascade/fit.py +++ b/scripts/kascade/fit.py @@ -8,8 +8,8 @@ def frac_bins(low, high, binsize, nbins=1): binsize = binsize * nbins - low = low - .5 * binsize - high = ceil((high - low) / binsize) * binsize + low + .5 * binsize + low = low - 0.5 * binsize + high = ceil((high - low) / binsize) * binsize + low + 0.5 * binsize return arange(low, high, binsize) @@ -22,20 +22,22 @@ def fit_gauss_to_timings(events, timing_data): dt = [] for t, c in zip(timing_data, events): ph = c['pulseheights'] - if min([ph[i], ph[j]]) / 350. 
>= 2.: + if min([ph[i], ph[j]]) / 350.0 >= 2.0: dt.append(t[i] - t[j]) - n, bins, patches = hist(dt, bins=frac_bins(-40, 40, 2.5), - histtype='step', label="detector %d - %d" - % (i + 1, j + 1)) + n, bins, patches = hist( + dt, + bins=frac_bins(-40, 40, 2.5), + histtype='step', + label='detector %d - %d' % (i + 1, j + 1), + ) b = [(u + v) / 2 for u, v in zip(bins[:-1], bins[1:])] popt, pcov = curve_fit(gauss, b, n) x = linspace(-40, 40, 100) - plot(x, gauss(x, *popt), label="mu: {:.2f}, sigma: {:.2f}".format(popt[1], - popt[2])) + plot(x, gauss(x, *popt), label=f'mu: {popt[1]:.2f}, sigma: {popt[2]:.2f}') legend(prop={'size': 'small'}) title("Delta t's for D >= 2.") - xlabel("Time (ns)") - ylabel("Count") + xlabel('Time (ns)') + ylabel('Count') if __name__ == '__main__': diff --git a/scripts/kascade/master.py b/scripts/kascade/master.py index f5157821..29350315 100644 --- a/scripts/kascade/master.py +++ b/scripts/kascade/master.py @@ -11,7 +11,7 @@ from sapphire.kascade import KascadeCoincidences, StoreKascadeData -class Master(object): +class Master: hisparc_group = '/hisparc/cluster_kascade/station_601' kascade_group = '/kascade' @@ -35,19 +35,18 @@ def store_cluster_instance(self): if 'cluster' not in group._v_attrs: cluster = clusters.SingleStation() - cluster.set_xyalpha_coordinates(65., 20.82, pi) + cluster.set_xyalpha_coordinates(65.0, 20.82, pi) group._v_attrs.cluster = cluster def read_and_store_kascade_data(self): """Read KASCADE data into analysis file""" - print("Reading KASCADE data") + print('Reading KASCADE data') try: - kascade = StoreKascadeData(self.data, self.kascade_filename, - self.kascade_group, self.hisparc_group) - except RuntimeError, msg: + kascade = StoreKascadeData(self.data, self.kascade_filename, self.kascade_group, self.hisparc_group) + except RuntimeError as msg: print(msg) return else: @@ -59,18 +58,18 @@ def search_for_coincidences(self): try: coincidences = KascadeCoincidences(self.data, hisparc, kascade) - except RuntimeError, msg: + except RuntimeError as msg: print(msg) return else: - print("Searching for coincidences") + print('Searching for coincidences') coincidences.search_coincidences(timeshift=-13.180220188, dtlimit=1e-3) - print("Storing coincidences") + print('Storing coincidences') coincidences.store_coincidences() - print("Done.") + print('Done.') def process_events(self, process_cls, destination=None): - print("Processing HiSPARC events") + print('Processing HiSPARC events') c_index = self.data.get_node(self.kascade_group, 'c_index') index = c_index.col('h_idx') @@ -78,12 +77,12 @@ def process_events(self, process_cls, destination=None): process = process_cls(self.data, self.hisparc_group, index) try: process.process_and_store_results(destination) - except RuntimeError, msg: + except RuntimeError as msg: print(msg) return def reconstruct_direction(self, source, destination, correct_offsets=False): - print("Reconstructing shower directions") + print('Reconstructing shower directions') offsets = None if correct_offsets: @@ -91,8 +90,8 @@ def reconstruct_direction(self, source, destination, correct_offsets=False): offsets = process.determine_detector_timing_offsets() try: - reconstruction = KascadeDirectionReconstruction(self.data, destination, min_n134=0.) 
- except RuntimeError, msg: + reconstruction = KascadeDirectionReconstruction(self.data, destination, min_n134=0.0) + except RuntimeError as msg: print(msg) return else: diff --git a/scripts/kascade/plot_matching_events.py b/scripts/kascade/plot_matching_events.py index eb9c92f2..6ee9d03d 100644 --- a/scripts/kascade/plot_matching_events.py +++ b/scripts/kascade/plot_matching_events.py @@ -2,14 +2,13 @@ import datetime import tables +import utils from scipy.optimize import curve_fit from artist import GraphArtist, MultiPlot from pylab import * -import utils - from sapphire.kascade import KascadeCoincidences USE_TEX = True @@ -19,7 +18,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -39,19 +38,19 @@ def plot_nearest_neighbors(data, limit=None): coincidences = KascadeCoincidences(data, hisparc_group, kascade_group, ignore_existing=True) - #dt_opt = find_optimum_dt(coincidences, p0=-13, limit=1000) - #print(dt_opt) + # dt_opt = find_optimum_dt(coincidences, p0=-13, limit=1000) + # print(dt_opt) graph = GraphArtist(axis='semilogy') styles = iter(['solid', 'dashed', 'dashdotted']) uncorrelated = None figure() - #for shift in -12, -13, dt_opt, -14: + # for shift in -12, -13, dt_opt, -14: for shift in -12, -13, -14: - print("Shifting", shift) + print('Shifting', shift) coincidences.search_coincidences(shift, dtlimit=1, limit=limit) - print(".") + print('.') dts = coincidences.coincidences['dt'] n, bins, p = hist(abs(dts) / 1e9, bins=linspace(0, 1, 101), histtype='step', label='%.3f s' % shift) n = [u if u else 1e-99 for u in n] @@ -63,14 +62,14 @@ def plot_nearest_neighbors(data, limit=None): x = (bins[:-1] + bins[1:]) / 2 f = lambda x, N, a: N * exp(-a * x) popt, pcov = curve_fit(f, x, y) - plot(x, f(x, *popt), label=r"$\lambda = %.2f$ Hz" % popt[1]) + plot(x, f(x, *popt), label=r'$\lambda = %.2f$ Hz' % popt[1]) graph.plot(x, f(x, *popt), mark=None) yscale('log') - xlabel("Time difference [s]") - graph.set_xlabel(r"Time difference [\si{\second}]") - ylabel("Counts") - graph.set_ylabel("Counts") + xlabel('Time difference [s]') + graph.set_xlabel(r'Time difference [\si{\second}]') + ylabel('Counts') + graph.set_ylabel('Counts') legend() graph.set_ylimits(min=10) utils.saveplot() @@ -79,10 +78,10 @@ def plot_nearest_neighbors(data, limit=None): def find_optimum_dt(coincidences, p0, delta=1e-3, limit=None): dt_new = p0 - dt_old = Inf + dt_old = inf while abs(dt_new - dt_old) > delta: - print("Trying:", dt_new) + print('Trying:', dt_new) coincidences.search_coincidences(dt_new, dtlimit=1, limit=limit) dts = coincidences.coincidences['dt'] / 1e9 @@ -112,28 +111,28 @@ def plot_residual_time_differences(data): figure() subplot(121) - hist(all_dts / 1e3, bins=arange(-10, 2, .01), histtype='step') - title("July 1 - Aug 6, 2008") - xlabel("Time difference [us]") - ylabel("Counts") + hist(all_dts / 1e3, bins=arange(-10, 2, 0.01), histtype='step') + title('July 1 - Aug 6, 2008') + xlabel('Time difference [us]') + ylabel('Counts') subplot(122) - hist(dts / 1e3, bins=arange(-8, -6, .01), histtype='step') - title("July 2, 2008") - xlabel("Time difference [us]") + hist(dts / 1e3, bins=arange(-8, -6, 0.01), histtype='step') + title('July 2, 2008') + xlabel('Time difference [us]') utils.saveplot() 
graph = MultiPlot(1, 2, width=r'.45\linewidth') - n, bins = histogram(all_dts / 1e3, bins=arange(-10, 2, .01)) + n, bins = histogram(all_dts / 1e3, bins=arange(-10, 2, 0.01)) graph.histogram(0, 1, n, bins) - graph.set_title(0, 1, "Jul 1 - Aug 6, 2008") + graph.set_title(0, 1, 'Jul 1 - Aug 6, 2008') - n, bins = histogram(dts / 1e3, bins=arange(-8, -6, .01)) + n, bins = histogram(dts / 1e3, bins=arange(-8, -6, 0.01)) graph.histogram(0, 0, n, bins) - graph.set_title(0, 0, "Jul 2, 2008") + graph.set_title(0, 0, 'Jul 2, 2008') - graph.set_xlabel(r"Time difference [\si{\micro\second}]") - graph.set_ylabel("Counts") + graph.set_xlabel(r'Time difference [\si{\micro\second}]') + graph.set_ylabel('Counts') graph.set_ylimits(min=0) graph.show_xticklabels_for_all([(0, 0), (0, 1)]) graph.show_yticklabels_for_all([(0, 0), (0, 1)]) diff --git a/scripts/kascade/plot_pulseheight_histogram.py b/scripts/kascade/plot_pulseheight_histogram.py index 0792972f..921666ed 100644 --- a/scripts/kascade/plot_pulseheight_histogram.py +++ b/scripts/kascade/plot_pulseheight_histogram.py @@ -12,14 +12,12 @@ def plot_pulseheight_histogram(data): s = landau.Scintillator() mev_scale = 3.38 / 340 - count_scale = 6e3 / .32 + count_scale = 6e3 / 0.32 clf() - n, bins, patches = hist(ph[:, 0], bins=arange(0, 1501, 10), - histtype='step') + n, bins, patches = hist(ph[:, 0], bins=arange(0, 1501, 10), histtype='step') x = linspace(0, 1500, 1500) - plot(x, s.conv_landau_for_x(x, mev_scale=mev_scale, - count_scale=count_scale)) + plot(x, s.conv_landau_for_x(x, mev_scale=mev_scale, count_scale=count_scale)) plot(x, count_scale * s.landau_pdf(x * mev_scale)) ylim(ymax=25000) xlim(xmax=1500) @@ -33,22 +31,17 @@ def plot_pulseheight_histogram(data): n_trunc = where(n <= 100000, n, 100000) graph.histogram(n_trunc, bins, linestyle='gray') graph.add_pin('data', x=800, location='above right', use_arrow=True) - graph.add_pin(r'$\gamma$', x=90, location='above right', - use_arrow=True) - graph.plot(x, s.conv_landau_for_x(x, mev_scale=mev_scale, - count_scale=count_scale), - mark=None) - graph.add_pin('convolved Landau', x=450, location='above right', - use_arrow=True) - graph.plot(x, count_scale * s.landau_pdf(x * mev_scale), mark=None, - linestyle='black') + graph.add_pin(r'$\gamma$', x=90, location='above right', use_arrow=True) + graph.plot(x, s.conv_landau_for_x(x, mev_scale=mev_scale, count_scale=count_scale), mark=None) + graph.add_pin('convolved Landau', x=450, location='above right', use_arrow=True) + graph.plot(x, count_scale * s.landau_pdf(x * mev_scale), mark=None, linestyle='black') graph.add_pin('Landau', x=380, location='above right', use_arrow=True) - graph.set_xlabel(r"Pulseheight [\adc{}]") - graph.set_ylabel(r"Number of events") + graph.set_xlabel(r'Pulseheight [\adc{}]') + graph.set_ylabel(r'Number of events') graph.set_xlimits(0, 1400) graph.set_ylimits(0, 21000) - graph.save("plots/plot_pulseheight_histogram") + graph.save('plots/plot_pulseheight_histogram') if __name__ == '__main__': diff --git a/scripts/kascade/read_sqldump/CIC.py b/scripts/kascade/read_sqldump/CIC.py index 2e4af7a2..0e371212 100644 --- a/scripts/kascade/read_sqldump/CIC.py +++ b/scripts/kascade/read_sqldump/CIC.py @@ -1,10 +1,10 @@ """ - Process HiSPARC messages from a buffer. - This module processes the CIC event message +Process HiSPARC messages from a buffer. 
+This module processes the CIC event message """ -__author__ = "thevinh" -__date__ = "$17-sep-2009" +__author__ = 'thevinh' +__date__ = '$17-sep-2009' import datetime import struct @@ -19,8 +19,8 @@ class CIC(HiSparc2Event): def __init__(self, message): - """ Initialization - Proceed to unpack the message. + """Initialization + Proceed to unpack the message. """ # invoke constructor of parent class HiSparc2Event.__init__(self, message) @@ -28,13 +28,13 @@ def __init__(self, message): # init the trigger rate attribute self.eventrate = 0 - self.uploadCode = 'CIC' + self.uploadCode = 'CIC' - #--------------------------End of __init__--------------------------# + # --------------------------End of __init__--------------------------# def parseMessage(self): # get database flags - tmp = struct.unpack("B", self.message[0:1])[0] + tmp = struct.unpack('B', self.message[0:1])[0] if tmp <= 3: unpack_legacy_message(self) else: @@ -45,93 +45,122 @@ def parseMessage(self): return self.getEventData() - #--------------------------End of parseMessage--------------------------# + # --------------------------End of parseMessage--------------------------# def unpackMessage(self): - """ Unpack a buffer message - This routine unpacks a buffer message written by the LabVIEW DAQ - software version 3.0 and above. Version 2.1.1 doesn't use a version - identifier in the message. By including one, we can account for - different message formats. - - Hopefully, this code is cleaner and thus easier to understand than - the legacy code. However, you'll always have to be careful with the - format strings. + """Unpack a buffer message + This routine unpacks a buffer message written by the LabVIEW DAQ + software version 3.0 and above. Version 2.1.1 doesn't use a version + identifier in the message. By including one, we can account for + different message formats. + + Hopefully, this code is cleaner and thus easier to understand than + the legacy code. However, you'll always have to be careful with the + format strings. """ # Initialize sequential reading mode self.unpackSeqMessage() - self.version, self.database_id, self.data_reduction, \ - self.eventrate, self.num_devices, self.length, \ - gps_second, gps_minute, gps_hour, gps_day, gps_month, gps_year, \ - self.nanoseconds, self.time_delta, self.trigger_pattern = \ - self.unpackSeqMessage('>2BBfBH5BH3L') + ( + self.version, + self.database_id, + self.data_reduction, + self.eventrate, + self.num_devices, + self.length, + gps_second, + gps_minute, + gps_hour, + gps_day, + gps_month, + gps_year, + self.nanoseconds, + self.time_delta, + self.trigger_pattern, + ) = self.unpackSeqMessage('>2BBfBH5BH3L') # Try to handle NaNs for eventrate. These are handled differently from platform to platform (i.e. MSVC libraries are screwed). This platform-dependent fix is not needed in later versions of python. So, drop this in the near future! 
if str(self.eventrate) in ['-1.#IND', '1.#INF']: self.eventrate = 0 # Only bits 0-19 are defined, zero the rest to make sure - self.trigger_pattern &= 2 ** 20 - 1 + self.trigger_pattern &= 2**20 - 1 - self.datetime = datetime.datetime(gps_year, gps_month, gps_day, - gps_hour, gps_minute, gps_second) + self.datetime = datetime.datetime(gps_year, gps_month, gps_day, gps_hour, gps_minute, gps_second) # Length of a single trace l = self.length / 2 # Read out and save traces and calculated trace parameters - self.mas_stdev1, self.mas_stdev2, self.mas_baseline1, \ - self.mas_baseline2, self.mas_npeaks1, self.mas_npeaks2, \ - self.mas_pulseheight1, self.mas_pulseheight2, self.mas_int1, \ - self.mas_int2, mas_tr1, mas_tr2 = \ - self.unpackSeqMessage('>8H2L%ds%ds' % (l, l)) + ( + self.mas_stdev1, + self.mas_stdev2, + self.mas_baseline1, + self.mas_baseline2, + self.mas_npeaks1, + self.mas_npeaks2, + self.mas_pulseheight1, + self.mas_pulseheight2, + self.mas_int1, + self.mas_int2, + mas_tr1, + mas_tr2, + ) = self.unpackSeqMessage('>8H2L%ds%ds' % (l, l)) self.mas_tr1 = compress(self.unpack_trace(mas_tr1)) self.mas_tr2 = compress(self.unpack_trace(mas_tr2)) # Read out and save secondary data as well, if available if self.num_devices > 1: - self.slv_stdev1, self.slv_stdev2, self.slv_baseline1, \ - self.slv_baseline2, self.slv_npeaks1, self.slv_npeaks2, \ - self.slv_pulseheight1, self.slv_pulseheight2, self.slv_int1, \ - self.slv_int2, slv_tr1, slv_tr2 = \ - self.unpackSeqMessage('>8H2L%ds%ds' % (l, l)) + ( + self.slv_stdev1, + self.slv_stdev2, + self.slv_baseline1, + self.slv_baseline2, + self.slv_npeaks1, + self.slv_npeaks2, + self.slv_pulseheight1, + self.slv_pulseheight2, + self.slv_int1, + self.slv_int2, + slv_tr1, + slv_tr2, + ) = self.unpackSeqMessage('>8H2L%ds%ds' % (l, l)) self.slv_tr1 = compress(self.unpack_trace(slv_tr1)) self.slv_tr2 = compress(self.unpack_trace(slv_tr2)) - #--------------------------End of unpackMessage--------------------------# + # --------------------------End of unpackMessage--------------------------# def unpack_trace(self, raw_trace): - """ Unpack a trace - Traces are stored in a funny way. We have a 12-bit ADC, so two - datapoints can (and are) stored in 3 bytes. This function unravels - traces again. - - DF: I'm wondering: does the LabVIEW program work hard to accomplish - this? If so, why do we do this in the first place? The factor 1.5 - in storage space is hardly worth it, especially considering the - fact that this is only used in the temporary buffer. - - DF: This is legacy code. I've never tried to understand it and will - certainly not touch it until I do. + """Unpack a trace + Traces are stored in a funny way. We have a 12-bit ADC, so two + datapoints can (and are) stored in 3 bytes. This function unravels + traces again. + + DF: I'm wondering: does the LabVIEW program work hard to accomplish + this? If so, why do we do this in the first place? The factor 1.5 + in storage space is hardly worth it, especially considering the + fact that this is only used in the temporary buffer. + + DF: This is legacy code. I've never tried to understand it and will + certainly not touch it until I do. 
""" n = len(raw_trace) if n % 3 != 0: - #return None - raise Exception("Blob length is not divisible by 3!") - a = struct.unpack("%dB" % (n), raw_trace) + # return None + raise Exception('Blob length is not divisible by 3!') + a = struct.unpack('%dB' % (n), raw_trace) trace = [] for i in xrange(0, n, 3): trace.append((a[i] << 4) + ((a[i + 1] & 240) >> 4)) trace.append(((a[i + 1] & 15) << 8) + a[i + 2]) - trace_str = "" + trace_str = '' for i in trace: - trace_str += str(i) + "," + trace_str += str(i) + ',' return trace_str - #--------------------------End of unpack_trace--------------------------# + # --------------------------End of unpack_trace--------------------------# diff --git a/scripts/kascade/read_sqldump/Event.py b/scripts/kascade/read_sqldump/Event.py index a4b10e40..eec0822d 100644 --- a/scripts/kascade/read_sqldump/Event.py +++ b/scripts/kascade/read_sqldump/Event.py @@ -1,12 +1,12 @@ """ - This module creates different types of Events that are specified by the subclasses +This module creates different types of Events that are specified by the subclasses """ -__author__ = "thevinh" -__date__ = "$16-sep-2009" +__author__ = 'thevinh' +__date__ = '$16-sep-2009' -class Event(): +class Event: # the instantiation operation def __init__(self): # init variables here if needed @@ -16,7 +16,7 @@ def __init__(self): self.nanoseconds = 0 self.export_values = 0 - #--------------------------End of __init__--------------------------# + # --------------------------End of __init__--------------------------# def getEventData(self): pass diff --git a/scripts/kascade/read_sqldump/EventExportValues.py b/scripts/kascade/read_sqldump/EventExportValues.py index 4578df49..0b65eeed 100644 --- a/scripts/kascade/read_sqldump/EventExportValues.py +++ b/scripts/kascade/read_sqldump/EventExportValues.py @@ -26,7 +26,8 @@ ('N', 'TR1', 'mas_tr1'), ('N', 'TR2', 'mas_tr2'), ('N', 'TR3', 'slv_tr1'), - ('N', 'TR4', 'slv_tr2')], + ('N', 'TR4', 'slv_tr2'), + ], 'ERR': [('N', 'ERRMSG', 'error_message')], 'CFG': [ ('N', 'CFG_GPS_LAT', 'cfg_gps_latitude'), @@ -115,11 +116,13 @@ ('N', 'CFG_SLV_CH1COMPGAIN', 'cfg_slv_ch1_comp_gain'), ('N', 'CFG_SLV_CH1COMPOFF', 'cfg_slv_ch1_comp_offset'), ('N', 'CFG_SLV_CH2COMPGAIN', 'cfg_slv_ch2_comp_gain'), - ('N', 'CFG_SLV_CH2COMPOFF', 'cfg_slv_ch2_comp_offset')], + ('N', 'CFG_SLV_CH2COMPOFF', 'cfg_slv_ch2_comp_offset'), + ], 'CMP': [ ('N', 'CMP_DEVICE', 'cmp_device'), ('N', 'CMP_COMPARATOR', 'cmp_comparator'), - ('N', 'CMP_COUNT', 'cmp_count_over_threshold')], + ('N', 'CMP_COUNT', 'cmp_count_over_threshold'), + ], 'WTR': [ ('N', 'WTR_TEMP_INSIDE', 'tempInside'), ('N', 'WTR_TEMP_OUTSIDE', 'tempOutside'), @@ -134,5 +137,6 @@ ('N', 'WTR_RAIN_RATE', 'rainRate'), ('Y', 'WTR_HEAT_INDEX', 'heatIndex'), ('Y', 'WTR_DEW_POINT', 'dewPoint'), - ('Y', 'WTR_WIND_CHILL', 'windChill')], + ('Y', 'WTR_WIND_CHILL', 'windChill'), + ], } diff --git a/scripts/kascade/read_sqldump/HiSparc2Event.py b/scripts/kascade/read_sqldump/HiSparc2Event.py index ecd667df..a2a92c31 100644 --- a/scripts/kascade/read_sqldump/HiSparc2Event.py +++ b/scripts/kascade/read_sqldump/HiSparc2Event.py @@ -1,7 +1,7 @@ -""" Process HiSPARC messages from a buffer - This module processes messages from buffer database and - gets out all available data. This data is stored in a data which - can then be uploaded to the eventwarehouse. +"""Process HiSPARC messages from a buffer +This module processes messages from buffer database and +gets out all available data. 
This data is stored in a data which +can then be uploaded to the eventwarehouse. """ import base64 @@ -14,9 +14,9 @@ class HiSparc2Event(Event): def __init__(self, message): - """ Initialization - First, determine message type from the argument. Then, check if - this might be a legacy message. Proceed to unpack the message. + """Initialization + First, determine message type from the argument. Then, check if + this might be a legacy message. Proceed to unpack the message. """ # invoke constructor of parent class @@ -25,12 +25,12 @@ def __init__(self, message): # get the message field in the message table self.message = message[1] - #--------------------------End of __init__--------------------------# + # --------------------------End of __init__--------------------------# def unpackMessage(self): pass - #--------------------------End of unpackMessage--------------------------# + # --------------------------End of unpackMessage--------------------------# def parseMessage(self): self.unpackMessage() @@ -40,14 +40,14 @@ def parseMessage(self): return self.getEventData() - #--------------------------End of parseMessage--------------------------# + # --------------------------End of parseMessage--------------------------# def getEventData(self): - """ Get all event data necessary for an upload. - This function parses the export_values variable declared in the EventExportValues - and figures out what data to collect for an - upload to the eventwarehouse. It returns a list of - dictionaries, one for each data element. + """Get all event data necessary for an upload. + This function parses the export_values variable declared in the EventExportValues + and figures out what data to collect for an + upload to the eventwarehouse. It returns a list of + dictionaries, one for each data element. """ eventdata = [] @@ -58,11 +58,11 @@ def getEventData(self): try: data = self.__getattribute__(value[2]) except AttributeError: - #if not self.version == 21: - # This is not a legacy message. Therefore, it should - # contain all exported variables, but alas, it - # apparently doesn't. - #print 'I missed this variable: ', value[2] + # if not self.version == 21: + # This is not a legacy message. Therefore, it should + # contain all exported variables, but alas, it + # apparently doesn't. + # print('I missed this variable: ', value[2]) continue if data_uploadcode in ['TR1', 'TR2', 'TR3', 'TR4']: @@ -72,41 +72,42 @@ def getEventData(self): # blobvalues are base64-decoded. data = base64.b64encode(data) - eventdata.append({ - "calculated": is_calculated, - "data_uploadcode": data_uploadcode, - "data": data, - }) + eventdata.append( + { + 'calculated': is_calculated, + 'data_uploadcode': data_uploadcode, + 'data': data, + }, + ) return eventdata - #--------------------------End of getEventData--------------------------# + # --------------------------End of getEventData--------------------------# def unpackSeqMessage(self, fmt=None): - """ Sequentially unpack message with a format - This method is used to read from the same buffer multiple times, - sequentially. A private variable will keep track of the current - offset. This is more convenient than keeping track of it yourself - multiple times, or hardcoding offsets. + """Sequentially unpack message with a format + This method is used to read from the same buffer multiple times, + sequentially. A private variable will keep track of the current + offset. This is more convenient than keeping track of it yourself + multiple times, or hardcoding offsets. 
""" if not fmt: # This is an initialization call self._struct_offset = 0 - return + return None if fmt == 'LVstring': # Request for a labview string. That is, first a long for the # length, then the string itself. - length, = self.unpackSeqMessage('>L') - fmt = ">%ds" % length + (length,) = self.unpackSeqMessage('>L') + fmt = '>%ds' % length # For debugging, keeping track of trailing bytes - #print len(self.message[self._struct_offset:]), struct.calcsize(fmt) + # print(len(self.message[self._struct_offset:]), struct.calcsize(fmt)) - data = struct.unpack_from(fmt, self.message, - offset=self._struct_offset) + data = struct.unpack_from(fmt, self.message, offset=self._struct_offset) self._struct_offset += struct.calcsize(fmt) return data - #--------------------------End of unpackSeqMessage--------------------------# + # --------------------------End of unpackSeqMessage--------------------------# diff --git a/scripts/kascade/read_sqldump/legacy.py b/scripts/kascade/read_sqldump/legacy.py index e0a2fff3..524fb9f3 100644 --- a/scripts/kascade/read_sqldump/legacy.py +++ b/scripts/kascade/read_sqldump/legacy.py @@ -1,5 +1,4 @@ import datetime -import random import struct from zlib import compress @@ -19,87 +18,87 @@ def unpack_legacy_message(self): # set the version of this legacy message to 21 (DAQ version 2.1.1) self.version = 21 - tmp = struct.unpack("B", self.blob[0:1])[0] + tmp = struct.unpack('B', self.blob[0:1])[0] if tmp == 1: - self.database = {"local": True, "central": False} + self.database = {'local': True, 'central': False} elif tmp == 2: - self.database = {"local": False, "central": True} + self.database = {'local': False, 'central': True} elif tmp == 3: - self.database = {"local": True, "central": True} - else: # Should not happen - self.database = {"local": False, "central": False} + self.database = {'local': True, 'central': True} + else: # Should not happen + self.database = {'local': False, 'central': False} # Number of devices - self.Ndev = struct.unpack("B", self.blob[1:2])[0] + self.Ndev = struct.unpack('B', self.blob[1:2])[0] # Number of bytes per trace - self.N = struct.unpack(">H", self.blob[2:4])[0] + self.N = struct.unpack('>H', self.blob[2:4])[0] # Seconds - self.second = struct.unpack("B", self.blob[4:5])[0] + self.second = struct.unpack('B', self.blob[4:5])[0] # Minutes - self.minute = struct.unpack("B", self.blob[5:6])[0] + self.minute = struct.unpack('B', self.blob[5:6])[0] # Hour - self.hour = struct.unpack("B", self.blob[6:7])[0] + self.hour = struct.unpack('B', self.blob[6:7])[0] # Day - self.day = struct.unpack("B", self.blob[7:8])[0] + self.day = struct.unpack('B', self.blob[7:8])[0] # Month - self.month = struct.unpack("B", self.blob[8:9])[0] + self.month = struct.unpack('B', self.blob[8:9])[0] # Year - self.year = struct.unpack(">H", self.blob[9:11])[0] + self.year = struct.unpack('>H', self.blob[9:11])[0] # date-time object self.datetime = datetime.datetime( - self.year, - self.month, - self.day, - self.hour, - self.minute, - self.second - ) + self.year, + self.month, + self.day, + self.hour, + self.minute, + self.second, + ) # Get the nanoseconds - self.nanoseconds = struct.unpack(">I", self.blob[11:15])[0] + self.nanoseconds = struct.unpack('>I', self.blob[11:15])[0] # Trigger time of Secondary relative to Primary in ns - self.SLVtime = struct.unpack(">i", self.blob[15:19])[0] + self.SLVtime = struct.unpack('>i', self.blob[15:19])[0] # Trigger pattern # TODO: Unwrap trigger pattern - self.trigger = struct.unpack(">H", self.blob[19:21])[0] + self.trigger = 
struct.unpack('>H', self.blob[19:21])[0] # Baseline from primary detector 1 - self.mas_baseline1 = struct.unpack(">h", self.blob[21:23])[0] + self.mas_baseline1 = struct.unpack('>h', self.blob[21:23])[0] # Baseline from primary detector 2 - self.mas_baseline2 = struct.unpack(">h", self.blob[23:25])[0] + self.mas_baseline2 = struct.unpack('>h', self.blob[23:25])[0] # Number of peaks from primary detector 1 - self.mas_npeaks1 = struct.unpack(">h", self.blob[25:27])[0] + self.mas_npeaks1 = struct.unpack('>h', self.blob[25:27])[0] # Number of peaks from primary detector 2 - self.mas_npeaks2 = struct.unpack(">h", self.blob[27:29])[0] + self.mas_npeaks2 = struct.unpack('>h', self.blob[27:29])[0] # Pulse height from primary detector 1 - self.mas_pulseheight1 = struct.unpack(">h", self.blob[29:31])[0] + self.mas_pulseheight1 = struct.unpack('>h', self.blob[29:31])[0] # Pulse height from primary detector 2 - self.mas_pulseheight2 = struct.unpack(">h", self.blob[31:33])[0] + self.mas_pulseheight2 = struct.unpack('>h', self.blob[31:33])[0] # Integral from primary detector 1 - self.mas_int1 = struct.unpack(">i", self.blob[33:37])[0] + self.mas_int1 = struct.unpack('>i', self.blob[33:37])[0] # Integral from primary detector 2 - self.mas_int2 = struct.unpack(">i", self.blob[37:41])[0] + self.mas_int2 = struct.unpack('>i', self.blob[37:41])[0] # Trace from primary detector 1 - self.mas_tr1 = compress(self.unpack_trace(self.blob[41:41 + self.N / 2])) + self.mas_tr1 = compress(self.unpack_trace(self.blob[41 : 41 + self.N / 2])) # Trace from primary detector 2 - self.mas_tr2 = compress(self.unpack_trace(self.blob[41 + self.N / 2:41 + self.N])) + self.mas_tr2 = compress(self.unpack_trace(self.blob[41 + self.N / 2 : 41 + self.N])) # If secondary is attached: if self.Ndev == 2: - o = 41 + self.N # Offset + o = 41 + self.N # Offset # Baseline from secondary detector 1 - self.slv_baseline1 = struct.unpack(">h", self.blob[o:o + 2])[0] + self.slv_baseline1 = struct.unpack('>h', self.blob[o : o + 2])[0] # Baseline from secondary detector 2 - self.slv_baseline2 = struct.unpack(">h", self.blob[o + 2:o + 4])[0] + self.slv_baseline2 = struct.unpack('>h', self.blob[o + 2 : o + 4])[0] # Number of peaks from secondary detector 1 - self.slv_npeaks1 = struct.unpack(">h", self.blob[o + 4:o + 6])[0] + self.slv_npeaks1 = struct.unpack('>h', self.blob[o + 4 : o + 6])[0] # Number of peaks from secondary detector 2 - self.slv_npeaks2 = struct.unpack(">h", self.blob[o + 6:o + 8])[0] + self.slv_npeaks2 = struct.unpack('>h', self.blob[o + 6 : o + 8])[0] # Pulse height from secondary detector 1 - self.slv_pulseheight1 = struct.unpack(">h", self.blob[o + 8:o + 10])[0] + self.slv_pulseheight1 = struct.unpack('>h', self.blob[o + 8 : o + 10])[0] # Pulse height from secondary detector 2 - self.slv_pulseheight2 = struct.unpack(">h", self.blob[o + 10:o + 12])[0] + self.slv_pulseheight2 = struct.unpack('>h', self.blob[o + 10 : o + 12])[0] # Integral from secondary detector 1 - self.slv_int1 = struct.unpack(">i", self.blob[o + 12:o + 16])[0] + self.slv_int1 = struct.unpack('>i', self.blob[o + 12 : o + 16])[0] # Integral from secondary detector 2 - self.slv_int2 = struct.unpack(">i", self.blob[o + 16:o + 20])[0] + self.slv_int2 = struct.unpack('>i', self.blob[o + 16 : o + 20])[0] # Trace from secondary detector 1 - self.slv_tr1 = compress(self.unpack_trace(self.blob[o + 20:o + 20 + self.N / 2])) + self.slv_tr1 = compress(self.unpack_trace(self.blob[o + 20 : o + 20 + self.N / 2])) # Trace from secondary detector 2 - self.slv_tr2 = 
compress(self.unpack_trace(self.blob[o + 20 + self.N / 2:o + 20 + self.N])) + self.slv_tr2 = compress(self.unpack_trace(self.blob[o + 20 + self.N / 2 : o + 20 + self.N])) diff --git a/scripts/kascade/read_sqldump/read_sqldump.py b/scripts/kascade/read_sqldump/read_sqldump.py index ee688603..d3fee739 100644 --- a/scripts/kascade/read_sqldump/read_sqldump.py +++ b/scripts/kascade/read_sqldump/read_sqldump.py @@ -8,18 +8,19 @@ DATAFILE = '../generator.h5' -mysql_escape_sequences = {r'\0': '\x00', - r"\'": r"'", - r'\"': r'"', - r'\b': '\b', - r'\n': '\n', - r'\r': '\r', - r'\t': '\t', - r'\Z': '\x1a', - r'\\': '\\', - r'\%': '%', - r'\_': '_', - } +mysql_escape_sequences = { + r'\0': '\x00', + r'\'': r"'", + r'\"': r'"', + r'\b': '\b', + r'\n': '\n', + r'\r': '\r', + r'\t': '\t', + r'\Z': '\x1a', + r'\\': '\\', + r'\%': '%', + r'\_': '_', +} def process_dump(path): @@ -28,7 +29,7 @@ def process_dump(path): id = 0 buffer = gzip.open(path) for line in buffer: - match = re.match("INSERT INTO `message` VALUES (.*)", line) + match = re.match('INSERT INTO `message` VALUES (.*)', line) if match: id = process_insert(datafile, match.group(1), id) buffer.close() @@ -41,8 +42,7 @@ def process_insert(datafile, s, id): value_pattern = re.compile(r"""([0-9]+|'(?:\\?.)*?')""") for insert in insert_pattern.finditer(s): values = value_pattern.findall(insert.group(1)) - event_id, type, msg = int(values[0]), int(values[2]),\ - values[4][1:-1] + event_id, type, msg = int(values[0]), int(values[2]), values[4][1:-1] if id != 0: id += 1 @@ -51,15 +51,17 @@ def process_insert(datafile, s, id): msg = re.sub(r'\\.', unescape_mysql_string, msg) if id != event_id: - raise Exception("Regexp error: %d != %d" % (id, event_id)) + raise Exception('Regexp error: %d != %d' % (id, event_id)) event = CIC((type, msg)) event.parseMessage() - event = {'header': {'eventtype_uploadcode': event.uploadCode, - 'datetime': event.datetime, - 'nanoseconds': event.nanoseconds, - }, - 'datalist': event.getEventData(), - } + event = { + 'header': { + 'eventtype_uploadcode': event.uploadCode, + 'datetime': event.datetime, + 'nanoseconds': event.nanoseconds, + }, + 'datalist': event.getEventData(), + } store_event(datafile, 'kascade', 601, event) return id diff --git a/scripts/kascade/read_sqldump/storage.py b/scripts/kascade/read_sqldump/storage.py index d2a278e7..ead8421c 100644 --- a/scripts/kascade/read_sqldump/storage.py +++ b/scripts/kascade/read_sqldump/storage.py @@ -1,4 +1,3 @@ -import csv import os import tables @@ -166,8 +165,7 @@ def open_or_create_file(data_dir, date): """ dir = os.path.join(data_dir, '%d/%d' % (date.year, date.month)) - file = os.path.join(dir, '%d_%d_%d.h5' % (date.year, date.month, - date.day)) + file = os.path.join(dir, '%d_%d_%d.h5' % (date.year, date.month, date.day)) if not os.path.exists(dir): # create dir and parent dirs with mode rwxr-xr-x @@ -189,8 +187,7 @@ def get_or_create_station_group(file, cluster, station_id): try: station = file.get_node(cluster, node_name) except tables.NoSuchNodeError: - station = file.create_group(cluster, node_name, - 'HiSPARC station %d data' % station_id) + station = file.create_group(cluster, node_name, 'HiSPARC station %d data' % station_id) file.flush() return station @@ -213,8 +210,7 @@ def get_or_create_cluster_group(file, cluster): try: cluster = file.get_node(hisparc, node_name) except tables.NoSuchNodeError: - cluster = file.create_group(hisparc, node_name, - 'HiSPARC cluster %s data' % cluster) + cluster = file.create_group(hisparc, node_name, 'HiSPARC cluster %s 
data' % cluster) file.flush() return cluster @@ -232,27 +228,17 @@ def get_or_create_node(file, cluster, node): node = file.get_node(cluster, node) except tables.NoSuchNodeError: if node == 'events': - node = file.create_table(cluster, 'events', HisparcEvent, - 'HiSPARC coincidences table') + node = file.create_table(cluster, 'events', HisparcEvent, 'HiSPARC coincidences table') elif node == 'errors': - node = file.create_table(cluster, 'errors', HisparcError, - 'HiSPARC error messages') + node = file.create_table(cluster, 'errors', HisparcError, 'HiSPARC error messages') elif node == 'comparator': - node = file.create_table(cluster, 'comparator', - HisparcComparatorData, - 'HiSPARC comparator messages') + node = file.create_table(cluster, 'comparator', HisparcComparatorData, 'HiSPARC comparator messages') elif node == 'blobs': - node = file.create_vlarray(cluster, 'blobs', - tables.VLStringAtom(), - 'HiSPARC binary data') + node = file.create_vlarray(cluster, 'blobs', tables.VLStringAtom(), 'HiSPARC binary data') elif node == 'config': - node = file.create_table(cluster, 'config', - HisparcConfiguration, - 'HiSPARC configuration messages') + node = file.create_table(cluster, 'config', HisparcConfiguration, 'HiSPARC configuration messages') elif node == 'weather': - node = file.create_table(cluster, 'weather', - HisparcWeather, - 'HiSPARC weather data') + node = file.create_table(cluster, 'weather', HisparcWeather, 'HiSPARC weather data') file.flush() return node diff --git a/scripts/kascade/read_sqldump/store_events.py b/scripts/kascade/read_sqldump/store_events.py index ab125362..917d79db 100644 --- a/scripts/kascade/read_sqldump/store_events.py +++ b/scripts/kascade/read_sqldump/store_events.py @@ -1,11 +1,6 @@ import base64 import calendar -import datetime import logging -import os.path -import sys - -import tables import storage @@ -30,14 +25,11 @@ def store_event(datafile, cluster, station_id, event): try: upload_codes = eventtype_upload_codes[eventtype] except KeyError: - logger.error('Unknown event type: %s, discarding event' % - eventtype) + logger.error('Unknown event type: %s, discarding event' % eventtype) return - parentnode = storage.get_or_create_station_group(datafile, cluster, - station_id) - table = storage.get_or_create_node(datafile, parentnode, - upload_codes['_tablename']) + parentnode = storage.get_or_create_station_group(datafile, cluster, station_id) + table = storage.get_or_create_node(datafile, parentnode, upload_codes['_tablename']) blobs = storage.get_or_create_node(datafile, parentnode, 'blobs') row = table.row @@ -85,15 +77,11 @@ def store_event(datafile, cluster, station_id, event): if key in data: data[key][index] = value else: - logger.warning('Datatype not known on server side: %s ' - '(%s)' % (key, eventtype)) + logger.warning('Datatype not known on server side: %s (%s)' % (key, eventtype)) + elif uploadcode in data: + data[uploadcode] = value else: - # uploadcode: EVENTRATE, RED, etc. 
- if uploadcode in data: - data[uploadcode] = value - else: - logger.warning('Datatype not known on server side: %s ' - '(%s)' % (uploadcode, eventtype)) + logger.warning('Datatype not known on server side: %s (%s)' % (uploadcode, eventtype)) # write data values to row for key, value in upload_codes.items(): diff --git a/scripts/kascade/read_sqldump/upload_codes.py b/scripts/kascade/read_sqldump/upload_codes.py index be46ee75..a2614227 100644 --- a/scripts/kascade/read_sqldump/upload_codes.py +++ b/scripts/kascade/read_sqldump/upload_codes.py @@ -29,8 +29,7 @@ }, 'CFG': { '_tablename': 'config', - '_blobs': ['CFG_MAS_VERSION', 'CFG_SLV_VERSION', 'CFG_PASSWORD', - 'CFG_BUFFER'], + '_blobs': ['CFG_MAS_VERSION', 'CFG_SLV_VERSION', 'CFG_PASSWORD', 'CFG_BUFFER'], '_has_ext_time': False, 'CFG_GPS_LAT': 'gps_latitude', 'CFG_GPS_LONG': 'gps_longitude', diff --git a/scripts/kascade/reconstruction_efficiency.py b/scripts/kascade/reconstruction_efficiency.py index 86a3ed5c..9ea40cc5 100644 --- a/scripts/kascade/reconstruction_efficiency.py +++ b/scripts/kascade/reconstruction_efficiency.py @@ -1,5 +1,6 @@ import numpy as np import tables +import utils from scipy import optimize, stats @@ -8,8 +9,6 @@ from artist import GraphArtist -import utils - from sapphire.analysis import landau RANGE_MAX = 40000 @@ -17,7 +16,7 @@ LOW, HIGH = 500, 5500 -VNS = .57e-3 * 2.5 +VNS = 0.57e-3 * 2.5 USE_TEX = True @@ -26,7 +25,7 @@ plt.rcParams['font.serif'] = 'Computer Modern' plt.rcParams['font.sans-serif'] = 'Computer Modern' plt.rcParams['font.family'] = 'sans-serif' - plt.rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + plt.rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] plt.rcParams['figure.subplot.left'] = 0.175 plt.rcParams['figure.subplot.bottom'] = 0.175 plt.rcParams['font.size'] = 10 @@ -34,7 +33,7 @@ plt.rcParams['text.usetex'] = True -class ReconstructionEfficiency(object): +class ReconstructionEfficiency: def __init__(self, data): global scintillator self.data = data @@ -63,7 +62,7 @@ def calc_charged_spectrum(self, x, y, p_gamma, p_landau): max_pos = x[y_landau.argmax()] y_gamma = self.gamma_func(x, *p_gamma) - y_gamma_trunc = np.where(x <= 3 * max_pos, y_gamma, 0.) 
+ y_gamma_trunc = np.where(x <= 3 * max_pos, y_gamma, 0.0) y_reduced = y - y_gamma_trunc @@ -71,15 +70,13 @@ def calc_charged_spectrum(self, x, y, p_gamma, p_landau): y_charged_left = y_landau.compress(x <= max_pos) y_charged_right = y_reduced.compress(max_pos < x) - y_charged = np.array(y_charged_left.tolist() + - y_charged_right.tolist()) + y_charged = np.array(y_charged_left.tolist() + y_charged_right.tolist()) return y_charged def full_spectrum_fit(self, x, y, p0_gamma, p0_landau): p_gamma = self.fit_gammas_to_data(x, y, p0_gamma) - p_landau = self.fit_conv_landau_to_data(x, y - self.gamma_func(x, *p_gamma), - p0_landau) + p_landau = self.fit_conv_landau_to_data(x, y - self.gamma_func(x, *p_gamma), p0_landau) p_gamma, p_landau = self.fit_complete(x, y, p_gamma, p_landau) return p_gamma, p_landau @@ -95,24 +92,23 @@ def plot_gamma_landau_fit(self): n, bins = np.histogram(ph0, bins=bins) x = (bins[:-1] + bins[1:]) / 2 - p_gamma, p_landau = self.full_spectrum_fit(x, n, (1., 1.), - (5e3 / .32, 3.38 / 5000, 1.)) - print "FULL FIT" - print p_gamma, p_landau + p_gamma, p_landau = self.full_spectrum_fit(x, n, (1.0, 1.0), (5e3 / 0.32, 3.38 / 5000, 1.0)) + print('FULL FIT') + print(p_gamma, p_landau) n /= 10 p_gamma, p_landau = self.constrained_full_spectrum_fit(x, n, p_gamma, p_landau) - print "CONSTRAINED FIT" - print p_gamma, p_landau + print('CONSTRAINED FIT') + print(p_gamma, p_landau) plt.figure() - print self.calc_charged_fraction(x, n, p_gamma, p_landau) + print(self.calc_charged_fraction(x, n, p_gamma, p_landau)) plt.plot(x * VNS, n) self.plot_landau_and_gamma(x, p_gamma, p_landau) - #plt.plot(x, n - self.gamma_func(x, *p_gamma)) - plt.xlabel("Pulse integral [V ns]") - plt.ylabel("Count") + # plt.plot(x, n - self.gamma_func(x, *p_gamma)) + plt.xlabel('Pulse integral [V ns]') + plt.ylabel('Count') plt.yscale('log') plt.xlim(0, 30) plt.ylim(1e1, 1e4) @@ -122,8 +118,8 @@ def plot_gamma_landau_fit(self): graph = GraphArtist('semilogy') graph.histogram(n, bins * VNS, linestyle='gray') self.artistplot_landau_and_gamma(graph, x, p_gamma, p_landau) - graph.set_xlabel(r"Pulse integral [\si{\volt\nano\second}]") - graph.set_ylabel("Count") + graph.set_xlabel(r'Pulse integral [\si{\volt\nano\second}]') + graph.set_ylabel('Count') graph.set_xlimits(0, 30) graph.set_ylimits(1e1, 1e4) artist.utils.save_graph(graph, dirname='plots') @@ -139,52 +135,51 @@ def plot_spectrum_fit_chisq(self): n, bins = np.histogram(integrals, bins=bins) x = (bins[:-1] + bins[1:]) / 2 - p_gamma, p_landau = self.full_spectrum_fit(x, n, (1., 1.), - (5e3 / .32, 3.38 / 5000, 1.)) - print "FULL FIT" - print p_gamma, p_landau + p_gamma, p_landau = self.full_spectrum_fit(x, n, (1.0, 1.0), (5e3 / 0.32, 3.38 / 5000, 1.0)) + print('FULL FIT') + print(p_gamma, p_landau) - print "charged fraction:", self.calc_charged_fraction(x, n, p_gamma, p_landau) + print('charged fraction:', self.calc_charged_fraction(x, n, p_gamma, p_landau)) landaus = scintillator.conv_landau_for_x(x, *p_landau) gammas = self.gamma_func(x, *p_gamma) fit = landaus + gammas - x_trunc = x.compress((LOW <= x) & (x < HIGH)) - n_trunc = n.compress((LOW <= x) & (x < HIGH)) - fit_trunc = fit.compress((LOW <= x) & (x < HIGH)) + x_trunc = x.compress((x >= LOW) & (x < HIGH)) + n_trunc = n.compress((x >= LOW) & (x < HIGH)) + fit_trunc = fit.compress((x >= LOW) & (x < HIGH)) chisq, pvalue = stats.chisquare(n_trunc, fit_trunc, ddof=5) - chisq /= (len(n_trunc) - 1 - 5) - print "Chi-square statistic:", chisq, pvalue + chisq /= len(n_trunc) - 1 - 5 + print('Chi-square 
statistic:', chisq, pvalue) plt.figure() plt.plot(x * VNS, n) self.plot_landau_and_gamma(x, p_gamma, p_landau) - #plt.plot(x_trunc * VNS, fit_trunc, linewidth=4) + # plt.plot(x_trunc * VNS, fit_trunc, linewidth=4) plt.axvline(LOW * VNS) plt.axvline(HIGH * VNS) - plt.xlabel("Pulse integral [V ns]") - plt.ylabel("Count") + plt.xlabel('Pulse integral [V ns]') + plt.ylabel('Count') plt.yscale('log') plt.xlim(0, 20) plt.ylim(1e2, 1e5) - plt.title(r"$\chi^2_{red}$: %.2f, p-value: %.2e" % (chisq, pvalue)) + plt.title(r'$\chi^2_{red}$: %.2f, p-value: %.2e' % (chisq, pvalue)) utils.saveplot() plt.figure() plt.plot(x_trunc * VNS, n_trunc - fit_trunc) plt.axhline(0) - plt.xlabel("Pulse integral [V ns]") - plt.ylabel("Data - Fit") - plt.title(r"$\chi^2_{red}$: %.2f, p-value: %.2e" % (chisq, pvalue)) + plt.xlabel('Pulse integral [V ns]') + plt.ylabel('Data - Fit') + plt.title(r'$\chi^2_{red}$: %.2f, p-value: %.2e' % (chisq, pvalue)) utils.saveplot(suffix='residuals') def plot_landau_and_gamma(self, x, p_gamma, p_landau): gammas = self.gamma_func(x, *p_gamma) - gamma_trunc = np.where(x * VNS <= 21, gammas, 0.) + gamma_trunc = np.where(x * VNS <= 21, gammas, 0.0) plt.plot(x * VNS, gamma_trunc, label='gamma') @@ -214,36 +209,34 @@ def artistplot_alt_landau_and_gamma(self, graph, x, p_gamma, p_landau): graph.plot(x * VNS, landaus, mark=None, linestyle='dashdotted,gray') def fit_gammas_to_data(self, x, y, p0): - condition = (LOW <= x) & (x < 2000) + condition = (x >= LOW) & (x < 2000) x_trunc = x.compress(condition) y_trunc = y.compress(condition) - popt, pcov = optimize.curve_fit(self.gamma_func, x_trunc, y_trunc, - p0=p0, sigma=np.sqrt(y_trunc)) + popt, pcov = optimize.curve_fit(self.gamma_func, x_trunc, y_trunc, p0=p0, sigma=np.sqrt(y_trunc)) return popt def gamma_func(self, x, N, a): - return N * x ** -a + return N * x**-a def fit_conv_landau_to_data(self, x, y, p0): - popt = optimize.fmin(self.scintillator.residuals, p0, - (x, y, 4500, 5500), disp=0) + popt = optimize.fmin(self.scintillator.residuals, p0, (x, y, 4500, 5500), disp=0) return popt def fit_complete(self, x, y, p_gamma, p_landau): p0 = list(p_gamma) + list(p_landau) - popt = optimize.fmin(self.complete_residuals, p0, - (self.scintillator, x, y, LOW, HIGH), - maxfun=100000, disp=0) + popt = optimize.fmin(self.complete_residuals, p0, (self.scintillator, x, y, LOW, HIGH), maxfun=100000, disp=0) return popt[:2], popt[2:] def constrained_fit_complete(self, x, y, p_gamma, p_landau): N_gamma = p_gamma[0] N_landau = p_landau[0] - popt = optimize.fmin(self.constrained_complete_residuals, - (N_gamma, N_landau), - (self.scintillator, x, y, p_gamma, - p_landau, LOW, HIGH), - maxfun=100000, disp=0) + popt = optimize.fmin( + self.constrained_complete_residuals, + (N_gamma, N_landau), + (self.scintillator, x, y, p_gamma, p_landau, LOW, HIGH), + maxfun=100000, + disp=0, + ) p_gamma[0] = popt[0] p_landau[0] = popt[1] return p_gamma, p_landau @@ -257,13 +250,12 @@ def complete_residuals(self, par, scintillator, x, y, a, b): y_exp_trunc = y_exp.compress((a <= x) & (x < b)) # Make sure no zeroes end up in denominator of chi_squared - y_trunc = np.where(y_trunc != 0., y_trunc, 1.) 
+ y_trunc = np.where(y_trunc != 0.0, y_trunc, 1.0) chisquared = ((y_trunc - y_exp_trunc) ** 2 / y_trunc).sum() return chisquared - def constrained_complete_residuals(self, par, scintillator, x, y, - p_gamma, p_landau, a, b): + def constrained_complete_residuals(self, par, scintillator, x, y, p_gamma, p_landau, a, b): full_par = (par[0], p_gamma[1], par[1], p_landau[1], p_landau[2]) return self.complete_residuals(full_par, scintillator, x, y, a, b) @@ -303,8 +295,7 @@ def determine_charged_fraction(self, integrals, p0): def plot_detection_efficiency(self): integrals, dens = self.get_integrals_and_densities() - popt = self.full_fit_on_data(integrals, - (1., 1., 5e3 / .32, 3.38 / 5000, 1.)) + popt = self.full_fit_on_data(integrals, (1.0, 1.0, 5e3 / 0.32, 3.38 / 5000, 1.0)) x, y, yerr = [], [], [] dens_bins = np.linspace(0, 10, 51) @@ -314,34 +305,32 @@ def plot_detection_efficiency(self): frac = self.determine_charged_fraction(sel, popt) y.append(frac) yerr.append(np.sqrt(frac * len(sel)) / len(sel)) - print (low + high) / 2, len(sel) + print((low + high) / 2, len(sel)) self.plot_full_spectrum_fit_in_density_range(sel, popt, low, high) - print + print() plt.figure() - plt.errorbar(x, y, yerr, fmt='o', label='data', markersize=3.) + plt.errorbar(x, y, yerr, fmt='o', label='data', markersize=3.0) - popt, pcov = optimize.curve_fit(self.conv_p_detection, x, y, p0=(1.,)) - print "Sigma Gauss:", popt + popt, pcov = optimize.curve_fit(self.conv_p_detection, x, y, p0=(1.0,)) + print('Sigma Gauss:', popt) x2 = plt.linspace(0, 10, 101) plt.plot(x2, self.p_detection(x2), label='poisson') plt.plot(x2, self.conv_p_detection(x2, *popt), label='poisson/gauss') - plt.xlabel("Charged particle density [$m^{-2}$]") - plt.ylabel("Detection probability") - plt.ylim(0, 1.) + plt.xlabel('Charged particle density [$m^{-2}$]') + plt.ylabel('Detection probability') + plt.ylim(0, 1.0) plt.legend(loc='best') utils.saveplot() graph = GraphArtist() graph.plot(x2, self.p_detection(x2), mark=None) - graph.plot(x2, self.conv_p_detection(x2, *popt), mark=None, - linestyle='dashed') + graph.plot(x2, self.conv_p_detection(x2, *popt), mark=None, linestyle='dashed') graph.plot(x, y, yerr=yerr, linestyle=None) - graph.set_xlabel( - r"Charged particle density [\si{\per\square\meter}]") - graph.set_ylabel("Detection probability") + graph.set_xlabel(r'Charged particle density [\si{\per\square\meter}]') + graph.set_ylabel('Detection probability') graph.set_xlimits(min=0) graph.set_ylimits(min=0) artist.utils.save_graph(graph, dirname='plots') @@ -363,8 +352,8 @@ def plot_full_spectrum_fit_in_density_range(self, sel, popt, low, high): plt.yscale('log') plt.xlim(0, 50) plt.ylim(ymin=1) - plt.xlabel("Pulse integral [V ns]") - plt.ylabel("Count") + plt.xlabel('Pulse integral [V ns]') + plt.ylabel('Count') plt.legend() suffix = '%.1f-%.1f' % (low, high) suffix = suffix.replace('.', '_') @@ -377,14 +366,16 @@ def plot_full_spectrum_fit_in_density_range(self, sel, popt, low, high): graph.histogram(n, bins * VNS, linestyle='gray') self.artistplot_alt_landau_and_gamma(graph, x, p_gamma, p_landau) graph.histogram(y_charged, bins * VNS) - graph.set_xlabel(r"Pulse integral [\si{\volt\nano\second}]") - graph.set_ylabel("Count") - graph.set_title(r"$\SI{%.1f}{\per\square\meter} \leq \rho_\mathrm{charged}$ < $\SI{%.1f}{\per\square\meter}$" % (low, high)) + graph.set_xlabel(r'Pulse integral [\si{\volt\nano\second}]') + graph.set_ylabel('Count') + graph.set_title( + r'$\SI{%.1f}{\per\square\meter} \leq \rho_\mathrm{charged}$ < 
$\SI{%.1f}{\per\square\meter}$' % (low, high), + ) graph.set_xlimits(0, 30) graph.set_ylimits(1e0, 1e4) artist.utils.save_graph(graph, suffix, dirname='plots') - p_detection = np.vectorize(lambda x: 1 - np.exp(-.5 * x) if x >= 0 else 0.) + p_detection = np.vectorize(lambda x: 1 - np.exp(-0.5 * x) if x >= 0 else 0.0) def conv_p_detection(self, x, sigma): x_step = x[-1] - x[-2] diff --git a/scripts/kascade/rel_gauss.py b/scripts/kascade/rel_gauss.py index 115e089d..f5352b8a 100644 --- a/scripts/kascade/rel_gauss.py +++ b/scripts/kascade/rel_gauss.py @@ -10,7 +10,7 @@ def main(data): fit_to_pulseheights(data, s, num_detector) fit_to_integrals(data, s, num_detector) fit_using_just_gauss_to_integrals(data, num_detector) - print + print() def fit_to_pulseheights(data, s, num_detector): @@ -19,12 +19,12 @@ def fit_to_pulseheights(data, s, num_detector): events = data.root.hisparc.cluster_kascade.station_601.events ph = events.col('pulseheights')[:, num_detector - 1] - print "Fitted to pulseheights, detector", num_detector - popt = do_fit_to_data(ph, s, 1000, 201, (380, .3 * 380)) - print "Relative Gauss width (1st try):", popt[2] / 3.38 + print('Fitted to pulseheights, detector', num_detector) + popt = do_fit_to_data(ph, s, 1000, 201, (380, 0.3 * 380)) + print('Relative Gauss width (1st try):', popt[2] / 3.38) center = 3.38 / popt[1] - popt = do_fit_to_data(ph, s, 1000, 201, (center, .3 * center)) - print "Relative Gauss width (2nd try):", popt[2] / 3.38 + popt = do_fit_to_data(ph, s, 1000, 201, (center, 0.3 * center)) + print('Relative Gauss width (2nd try):', popt[2] / 3.38) def fit_to_integrals(data, s, num_detector): @@ -33,12 +33,12 @@ def fit_to_integrals(data, s, num_detector): events = data.root.hisparc.cluster_kascade.station_601.events intg = events.col('integrals')[:, num_detector - 1] - print "Fitted to integrals, detector", num_detector - popt = do_fit_to_data(intg, s, 20000, 201, (5000, .3 * 5000)) - print "Relative Gauss width (1st try):", popt[2] / 3.38 + print('Fitted to integrals, detector', num_detector) + popt = do_fit_to_data(intg, s, 20000, 201, (5000, 0.3 * 5000)) + print('Relative Gauss width (1st try):', popt[2] / 3.38) center = 3.38 / popt[1] - popt = do_fit_to_data(intg, s, 20000, 201, (center, .3 * center)) - print "Relative Gauss width (2nd try):", popt[2] / 3.38 + popt = do_fit_to_data(intg, s, 20000, 201, (center, 0.3 * center)) + print('Relative Gauss width (2nd try):', popt[2] / 3.38) def fit_using_just_gauss_to_integrals(data, num_detector): @@ -47,22 +47,18 @@ def fit_using_just_gauss_to_integrals(data, num_detector): events = data.root.hisparc.cluster_kascade.station_601.events intg = events.col('integrals')[:, num_detector - 1] - print "Fitted to integrals (Gauss only), detector", num_detector - popt = do_fit_to_data_using_gauss(intg, 20000, 201, - (5000, .3 * 5000)) - print "Relative Gauss width (1st try):", popt[2] / 3.38 + print('Fitted to integrals (Gauss only), detector', num_detector) + popt = do_fit_to_data_using_gauss(intg, 20000, 201, (5000, 0.3 * 5000)) + print('Relative Gauss width (1st try):', popt[2] / 3.38) center = 3.38 / popt[1] - popt = do_fit_to_data_using_gauss(intg, 20000, 201, - (center, .3 * center)) - print "Relative Gauss width (2nd try):", popt[2] / 3.38 + popt = do_fit_to_data_using_gauss(intg, 20000, 201, (center, 0.3 * center)) + print('Relative Gauss width (2nd try):', popt[2] / 3.38) def do_fit_to_data(dataset, s, max_hist_value, n_bins, guess): center, width = guess - n, bins, patches = hist(dataset, bins=linspace(0, 
max_hist_value, - n_bins, 'b'), - histtype='step') + n, bins, patches = hist(dataset, bins=linspace(0, max_hist_value, n_bins, 'b'), histtype='step') yscale('log') x = (bins[:-1] + bins[1:]) / 2 @@ -74,9 +70,7 @@ def do_fit_to_data(dataset, s, max_hist_value, n_bins, guess): guess_count = interp(center, sx, sy) - popt, pcov = scipy.optimize.curve_fit(s.conv_landau_for_x, sx, sy, - p0=(guess_count, 3.38 / center, - 1.)) + popt, pcov = scipy.optimize.curve_fit(s.conv_landau_for_x, sx, sy, p0=(guess_count, 3.38 / center, 1.0)) plot(sx, sy, 'r') plot(x, s.conv_landau_for_x(x, *popt), 'g') @@ -87,9 +81,7 @@ def do_fit_to_data(dataset, s, max_hist_value, n_bins, guess): def do_fit_to_data_using_gauss(dataset, max_hist_value, n_bins, guess): center, width = guess - n, bins, patches = hist(dataset, bins=linspace(0, max_hist_value, - n_bins, 'b'), - histtype='step') + n, bins, patches = hist(dataset, bins=linspace(0, max_hist_value, n_bins, 'b'), histtype='step') yscale('log') x = (bins[:-1] + bins[1:]) / 2 @@ -101,18 +93,14 @@ def do_fit_to_data_using_gauss(dataset, max_hist_value, n_bins, guess): guess_count = interp(center, sx, sy) - f = lambda u, N, scale, sigma: N * scipy.stats.norm.pdf(u * scale, - loc=3.38, - scale=sigma) - popt, pcov = scipy.optimize.curve_fit(f, sx, sy, - p0=(guess_count, 3.38 / center, - 1.)) + f = lambda u, N, scale, sigma: N * scipy.stats.norm.pdf(u * scale, loc=3.38, scale=sigma) + popt, pcov = scipy.optimize.curve_fit(f, sx, sy, p0=(guess_count, 3.38 / center, 1.0)) plot(sx, sy, 'r') plot(x, f(x, *popt), 'g') ylim(ymin=1e1) - print popt + print(popt) return popt diff --git a/scripts/kascade/test_tcc.py b/scripts/kascade/test_tcc.py index 5c53a4b5..6e60f261 100644 --- a/scripts/kascade/test_tcc.py +++ b/scripts/kascade/test_tcc.py @@ -1,13 +1,11 @@ -import time - import tables import artist -from sapphire.analysis.core_reconstruction import CoreReconstruction, PlotCoreReconstruction +from sapphire.analysis.core_reconstruction import CoreReconstruction from sapphire.utils import pbar -X, Y = 65., 20.82 +X, Y = 65.0, 20.82 def get_tcc_values(data, force_new=False): @@ -31,15 +29,15 @@ def get_tcc_values(data, force_new=False): def calculate_tcc(event): - n = array([event[u] for u in 'n1', 'n2', 'n3', 'n4']) - n = where(n < .5, 0, n) + n = array([event[u] for u in ('n1', 'n2', 'n3', 'n4')]) + n = where(n < 0.5, 0, n) if not (n > 0).sum() >= 2: return -999 i_max = len(n) mean_n = n.mean() - if mean_n == 0.: + if mean_n == 0.0: return -998 variance = ((n - mean_n) ** 2).sum() / (i_max - 1) @@ -59,20 +57,18 @@ def plot_core_positions(data): hist(false_core_dist, bins=50, histtype='step', label='uncorrelated') hist(small_tcc_core_dist, bins=50, histtype='step', label='tcc < 10') legend() - xlabel("Core distance [m]") - ylabel("Counts") + xlabel('Core distance [m]') + ylabel('Counts') graph = artist.GraphArtist() n, bins = histogram(core_dist, bins=linspace(0, 200, 51)) graph.histogram(n, bins) - graph.add_pin(r'$T_{CC} \geq 10$', x=18, location='above right', - use_arrow=True) + graph.add_pin(r'$T_{CC} \geq 10$', x=18, location='above right', use_arrow=True) n, bins = histogram(false_core_dist, bins=linspace(0, 200, 51)) graph.histogram(n, bins, linestyle='gray') - graph.add_pin('uncorrelated', x=37, location='above right', - use_arrow=True) - graph.set_xlabel(r"Core distance [\si{\meter}]") - graph.set_ylabel("Counts") + graph.add_pin('uncorrelated', x=37, location='above right', use_arrow=True) + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel('Counts') 
graph.set_xlimits(0, 200) graph.set_ylimits(min=0) graph.save_as_pdf('preview') @@ -93,7 +89,6 @@ def scatter_core_positions(data): def _get_core_dists(data, sel_str, limit=None): - core_pos = _get_core_positions(data, sel_str, limit) x, y = zip(*core_pos) core_dist = sqrt((array(x) - X) ** 2 + (array(y) - Y) ** 2) @@ -133,7 +128,7 @@ def plot_energy(data, sel_str): energy = k_events.read_coordinates(k_idx, field='energy') false_energy = k_events.read_coordinates(k_idx + 1, field='energy') - print len(energy), len(false_energy) + print(len(energy), len(false_energy)) figure() hist(log10(energy), bins=linspace(14, 18, 51), histtype='step', label=sel_str) @@ -143,24 +138,20 @@ def plot_energy(data, sel_str): graph = artist.GraphArtist() n, bins = histogram(log10(energy), bins=linspace(14, 18, 51)) graph.histogram(n, bins) - graph.add_pin(r'$T_{CC} \geq 10$', x=14.6, location='above right', - use_arrow=True) + graph.add_pin(r'$T_{CC} \geq 10$', x=14.6, location='above right', use_arrow=True) n, bins = histogram(log10(false_energy), bins=linspace(14, 18, 51)) graph.histogram(n, bins, linestyle='gray') - graph.add_pin('uncorrelated', x=15.5, location='above right', - use_arrow=True) - graph.set_xlabel(r"$\lg$ energy [$\lg\si{\electronvolt}$]") - graph.set_ylabel("Counts") + graph.add_pin('uncorrelated', x=15.5, location='above right', use_arrow=True) + graph.set_xlabel(r'$\lg$ energy [$\lg\si{\electronvolt}$]') + graph.set_ylabel('Counts') graph.set_xlimits(14, 18) graph.set_ylimits(min=0) graph.save_as_pdf('preview') def reconstruct_shower_sizes(data, tcc): - reconstruction = KascadeCoreReconstruction(data, '/core', - overwrite=True) - reconstruction.reconstruct_core_positions( - '/hisparc/cluster_kascade/station_601', '/kascade', tcc) + reconstruction = KascadeCoreReconstruction(data, '/core', overwrite=True) + reconstruction.reconstruct_core_positions('/hisparc/cluster_kascade/station_601', '/kascade', tcc) class KascadeCoreReconstruction(CoreReconstruction): @@ -174,21 +165,24 @@ def reconstruct_core_positions(self, hisparc_group, kascade_group, tcc): self.cluster = hisparc_group._v_attrs.cluster self._store_cluster_with_results() - for idx, tcc_value in pbar(zip(c_index[:self.N], tcc)): + for idx, tcc_value in pbar(zip(c_index[: self.N], tcc)): hisparc_event = hisparc_table[idx['h_idx']] kascade_event = kascade_table[idx['k_idx']] if tcc_value >= 10: x, y, N = self.reconstruct_core_position(hisparc_event) - self.store_reconstructed_event(hisparc_event, - kascade_event, x, y, N) + self.store_reconstructed_event(hisparc_event, kascade_event, x, y, N) self.results_table.flush() - def store_reconstructed_event(self, hisparc_event, kascade_event, - reconstructed_core_x, - reconstructed_core_y, - reconstructed_shower_size): + def store_reconstructed_event( + self, + hisparc_event, + kascade_event, + reconstructed_core_x, + reconstructed_core_y, + reconstructed_shower_size, + ): dst_row = self.results_table.row dst_row['id'] = hisparc_event['event_id'] @@ -204,13 +198,10 @@ def store_reconstructed_event(self, hisparc_event, kascade_event, dst_row['reference_theta'] = kascade_event['zenith'] dst_row['reference_phi'] = kascade_event['azimuth'] dst_row['reference_core_pos'] = kascade_event['core_pos'] - dst_row['reconstructed_core_pos'] = reconstructed_core_x, \ - reconstructed_core_y + dst_row['reconstructed_core_pos'] = reconstructed_core_x, reconstructed_core_y dst_row['reference_shower_size'] = kascade_event['Num_e'] dst_row['reconstructed_shower_size'] = reconstructed_shower_size - 
dst_row['min_n134'] = min(hisparc_event['n1'], - hisparc_event['n3'], - hisparc_event['n4']) + dst_row['min_n134'] = min(hisparc_event['n1'], hisparc_event['n3'], hisparc_event['n4']) dst_row.append() def get_events_from_coincidence(self, coincidence): @@ -233,7 +224,7 @@ def _station_has_triggered(self, event): reconstruct_shower_sizes(data, tcc) core = data.root.core - #plot_core_positions(data) - #scatter_core_positions(data) - #plot_energy(data, 'tcc >= 10') - #plot_energy(data, 'tcc < 10') + # plot_core_positions(data) + # scatter_core_positions(data) + # plot_energy(data, 'tcc >= 10') + # plot_energy(data, 'tcc < 10') diff --git a/scripts/kascade/utils.py b/scripts/kascade/utils.py index d2c9a7d6..936afffb 100644 --- a/scripts/kascade/utils.py +++ b/scripts/kascade/utils.py @@ -1,10 +1,9 @@ -""" Utility functions """ +"""Utility functions""" import inspect -import numpy as np - import matplotlib.pyplot as plt +import numpy as np __suffix = '' __prefix = '' diff --git a/scripts/sciencepark/detector_locations.py b/scripts/sciencepark/detector_locations.py index 8fceb2d4..a9815e51 100644 --- a/scripts/sciencepark/detector_locations.py +++ b/scripts/sciencepark/detector_locations.py @@ -1,15 +1,11 @@ """Show Science Park detector locations on OpenStreetMap""" -import numpy as np - import pylab as plt import sapphire.api import sapphire.clusters import sapphire.simulations -from sapphire.simulations.ldf import KascadeLdf - DETECTOR_COLORS = ['black', 'r', 'g', 'b'] @@ -24,7 +20,7 @@ def get_cluster(stations): return cluster -def plot_detector_locations(cluster, background_path="backgrounds/ScienceParkMap_0.365.png"): +def plot_detector_locations(cluster, background_path='backgrounds/ScienceParkMap_0.365.png'): plot_scintillators_in_cluster(cluster) draw_background_map(background_path) @@ -34,11 +30,9 @@ def plot_scintillators_in_cluster(cluster): for station in cluster.stations: for i, detector in enumerate(station.detectors): detector_x, detector_y = detector.get_xy_coordinates() - plt.scatter(detector_x, detector_y, marker='h', - c=DETECTOR_COLORS[i], edgecolor='none', s=25) + plt.scatter(detector_x, detector_y, marker='h', c=DETECTOR_COLORS[i], edgecolor='none', s=25) station_x, station_y, station_a = station.get_xyalpha_coordinates() - plt.scatter(station_x, station_y, marker='o', c='m', edgecolor='none', - s=7) + plt.scatter(station_x, station_y, marker='o', c='m', edgecolor='none', s=7) plt.title('Science Park detector locations') plt.xlabel('Easting (meters)') plt.ylabel('Northing (meters)') @@ -51,11 +45,10 @@ def draw_background_map(background_path): bg_scale = 0.365 bg_width = background.shape[1] * bg_scale bg_height = background.shape[0] * bg_scale - plt.imshow(background, aspect='equal', alpha=0.5, - extent=[-bg_width, bg_width, -bg_height, bg_height]) + plt.imshow(background, aspect='equal', alpha=0.5, extent=[-bg_width, bg_width, -bg_height, bg_height]) -if __name__=="__main__": +if __name__ == '__main__': stations = sciencepark_stations() cluster = get_cluster(stations) plt.figure() diff --git a/scripts/sciencepark/direction_analysis_plots.py b/scripts/sciencepark/direction_analysis_plots.py index 439cf056..e206b64e 100644 --- a/scripts/sciencepark/direction_analysis_plots.py +++ b/scripts/sciencepark/direction_analysis_plots.py @@ -1,6 +1,7 @@ import itertools import tables +import utils from scipy.optimize import curve_fit from scipy.stats import chisquare, scoreatpercentile @@ -10,8 +11,6 @@ from pylab import * -import utils - from sapphire import clusters from 
sapphire.analysis.direction_reconstruction import DirectionReconstruction from sapphire.simulations.ldf import KascadeLdf @@ -27,7 +26,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -36,13 +35,15 @@ def main(data): -# plot_sciencepark_cluster() - #plot_all_single_and_cluster_combinations(data) - #hist_phi_single_stations(data) - #hist_theta_single_stations(data) -# plot_N_vs_R(data) -# artistplot_N_vs_R() + # plot_sciencepark_cluster() + # plot_all_single_and_cluster_combinations(data) + # hist_phi_single_stations(data) + # hist_theta_single_stations(data) + # plot_N_vs_R(data) + # artistplot_N_vs_R() plot_fav_single_vs_cluster(data) + + # plot_fav_single_vs_single(data) # plot_fav_uncertainty_single_vs_cluster(data) # plot_fav_uncertainty_single_vs_single(data) @@ -62,8 +63,8 @@ def artistplot_N_vs_R(): graph.plot(R, N, linestyle=None) graph.plot(Rfit, Nfit, mark=None) - graph.set_xlabel(r"Distance [\si{\meter}]") - graph.set_ylabel("Number of coincidences") + graph.set_xlabel(r'Distance [\si{\meter}]') + graph.set_ylabel('Number of coincidences') graph.set_xlimits(min=0) graph.set_ylimits(min=0) @@ -101,10 +102,8 @@ def plot_sciencepark_cluster(): utils.savedata([x_list, y_list]) utils.saveplot() - artist.utils.save_data([x_list, y_list], suffix='detectors', - dirname='plots') - artist.utils.save_data([stations, x_stations, y_stations], - suffix='stations', dirname='plots') + artist.utils.save_data([x_list, y_list], suffix='detectors', dirname='plots') + artist.utils.save_data([stations, x_stations, y_stations], suffix='stations', dirname='plots') def plot_all_single_and_cluster_combinations(data): @@ -138,8 +137,12 @@ def calc_direction_single_vs_cluster(data, station, cluster, limit=None): if limit and len(theta_cluster) >= limit: break - return array(theta_station).flatten(), array(phi_station).flatten(), \ - array(theta_cluster).flatten(), array(phi_cluster).flatten() + return ( + array(theta_station).flatten(), + array(phi_station).flatten(), + array(theta_cluster).flatten(), + array(phi_cluster).flatten(), + ) def calc_direction_single_vs_single(data, station1, station2): @@ -163,30 +166,33 @@ def calc_direction_single_vs_single(data, station1, station2): theta_station1.append(event_station1['reconstructed_theta']) phi_station1.append(event_station1['reconstructed_phi']) - return array(theta_station1).flatten(), array(phi_station1).flatten(), \ - array(theta_station2).flatten(), array(phi_station2).flatten() + return ( + array(theta_station1).flatten(), + array(phi_station1).flatten(), + array(theta_station2).flatten(), + array(phi_station2).flatten(), + ) def plot_direction_single_vs_cluster(data, station, cluster): cluster_str = [str(u) for u in cluster] - theta_station, phi_station, theta_cluster, phi_cluster = \ - calc_direction_single_vs_cluster(data, station, cluster) + theta_station, phi_station, theta_cluster, phi_cluster = calc_direction_single_vs_cluster(data, station, cluster) figsize = list(rcParams['figure.figsize']) figsize[1] = figsize[0] / 2 figure(figsize=figsize) subplot(121) plot(theta_station, theta_cluster, ',') - xlabel(r"$\theta_{%d}$" % station) - ylabel(r"$\theta_{\{%s\}}$" % ','.join(cluster_str)) + xlabel(r'$\theta_{%d}$' % station) + 
ylabel(r'$\theta_{\{%s\}}$' % ','.join(cluster_str)) xlim(0, pi / 2) ylim(0, pi / 2) subplot(122) plot(phi_station, phi_cluster, ',') - xlabel(r"$\phi_{%d}$" % station) - ylabel(r"$\phi_{\{%s\}}$" % ','.join(cluster_str)) + xlabel(r'$\phi_{%d}$' % station) + ylabel(r'$\phi_{\{%s\}}$' % ','.join(cluster_str)) xlim(-pi, pi) ylim(-pi, pi) @@ -202,7 +208,7 @@ def hist_phi_single_stations(data): query = '(N == 1) & s%d' % station phi = reconstructions.read_where(query, field='reconstructed_phi') hist(rad2deg(phi), bins=linspace(-180, 180, 21), histtype='step') - xlabel(r"$\phi$") + xlabel(r'$\phi$') legend([station]) locator_params(tight=True, nbins=4) @@ -218,7 +224,7 @@ def hist_theta_single_stations(data): query = '(N == 1) & s%d' % station theta = reconstructions.read_where(query, field='reconstructed_theta') hist(rad2deg(theta), bins=linspace(0, 45, 21), histtype='step') - xlabel(r"$\theta$") + xlabel(r'$\theta$') legend([station]) locator_params(tight=True, nbins=4) @@ -234,40 +240,38 @@ def plot_N_vs_R(data): observables = data.root.coincidences.observables figure() - #clf() + # clf() global c_x, c_y if 'c_x' in globals(): scatter(c_x, c_y) else: stations_in_coincidence = [] for coincidence_events in c_index: - stations = [observables[u]['station_id'] for u in - coincidence_events] + stations = [observables[u]['station_id'] for u in coincidence_events] stations_in_coincidence.append(stations) c_x = [] c_y = [] for station1, station2 in itertools.combinations(station_ids, 2): - condition = [station1 in u and station2 in u for u in - stations_in_coincidence] + condition = [station1 in u and station2 in u for u in stations_in_coincidence] N = sum(condition) R, phi = cluster.calc_r_and_phi_for_stations(station1, station2) scatter(R, N) c_x.append(R) c_y.append(N) - print R, N, station1, station2 + print(R, N, station1, station2) ldf = KascadeLdf() R = linspace(100, 500) E = linspace(1e14, 1e19, 100) - F = E ** -2.7 + F = E**-2.7 N = [] for r in R: x = [] for f, e in zip(F, E): - Ne = e / 1e15 * 10 ** 4.8 + Ne = e / 1e15 * 10**4.8 density = ldf.get_ldf_value_for_size(r, Ne) - prob = 1 - exp(-.5 * density) + prob = 1 - exp(-0.5 * density) x.append(f * prob) N.append(mean(x)) N = array(N) @@ -279,12 +283,12 @@ def plot_N_vs_R(data): sc_y = c_y.compress(c_x >= 100) popt, pcov = curve_fit(f, sc_x, sc_y, p0=(1e45)) plot(R, f(R, popt[0])) - #ylim(0, 150000) + # ylim(0, 150000) ylim(0, 500000) xlim(0, 500) - xlabel("Distance [m]") - ylabel("Number of coincidences") + xlabel('Distance [m]') + ylabel('Number of coincidences') utils.saveplot() utils.savedata([sc_x, sc_y], suffix='data') @@ -301,42 +305,42 @@ def plot_fav_single_vs_cluster(data): figure() for n, station in enumerate(cluster, 1): - theta_station, phi_station, theta_cluster, phi_cluster = \ - calc_direction_single_vs_cluster(data, station, cluster, 2000) + theta_station, phi_station, theta_cluster, phi_cluster = calc_direction_single_vs_cluster( + data, + station, + cluster, + 2000, + ) subplot(2, 3, n) plot(rad2deg(phi_station), rad2deg(phi_cluster), ',') - xlabel(r"$\phi_{%d}$" % station) + xlabel(r'$\phi_{%d}$' % station) xlim(-180, 180) ylim(-180, 180) locator_params(tight=True, nbins=4) if n == 1: - ylabel(r"$\phi_{\{%s\}}$" % ','.join(cluster_str)) + ylabel(r'$\phi_{\{%s\}}$' % ','.join(cluster_str)) bins = linspace(-180, 180, 37) - H, x_edges, y_edges = histogram2d(rad2deg(phi_station), - rad2deg(phi_cluster), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(phi_station), rad2deg(phi_cluster), bins=bins) 
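For reference, the station-versus-cluster comparison being reformatted here reduces to binning paired angle estimates with numpy.histogram2d; a minimal, self-contained sketch, using made-up angle arrays and leaving the artist GraphArtist calls out:

import numpy as np

# Hypothetical paired azimuth estimates (radians) for one station and the cluster fit.
rng = np.random.default_rng(42)
phi_station = rng.uniform(-np.pi, np.pi, 1000)
phi_cluster = (phi_station + rng.normal(0, 0.3, 1000) + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)

bins = np.linspace(-180, 180, 37)
H, x_edges, y_edges = np.histogram2d(np.degrees(phi_station), np.degrees(phi_cluster), bins=bins)

# H[i, j] counts pairs with the station estimate in x-bin i and the cluster
# estimate in y-bin j; a pronounced diagonal means the two reconstructions agree.
print(H.shape, H.sum())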
graph1.histogram2d(0, n - 1, H, x_edges, y_edges, 'reverse_bw') graph1.set_label(0, n - 1, station, 'upper left', style='fill=white') subplot(2, 3, n + 3) plot(rad2deg(theta_station), rad2deg(theta_cluster), ',') - xlabel(r"$\theta_{%d}$" % station) + xlabel(r'$\theta_{%d}$' % station) xlim(0, 45) ylim(0, 45) locator_params(tight=True, nbins=4) if n == 1: - ylabel(r"$\theta_{\{%s\}}$" % ','.join(cluster_str)) + ylabel(r'$\theta_{\{%s\}}$' % ','.join(cluster_str)) bins = linspace(0, 45, 46) - H, x_edges, y_edges = histogram2d(rad2deg(theta_station), - rad2deg(theta_cluster), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(theta_station), rad2deg(theta_cluster), bins=bins) graph2.histogram2d(0, n - 1, H, x_edges, y_edges, 'reverse_bw') graph2.set_label(0, n - 1, station, 'upper left', style='fill=white') - subplots_adjust(wspace=.4, hspace=.4) + subplots_adjust(wspace=0.4, hspace=0.4) utils.saveplot() graph1.set_xticks_for_all(None, range(-180, 181, 90)) @@ -344,14 +348,14 @@ def plot_fav_single_vs_cluster(data): graph1.show_xticklabels_for_all(None) graph1.show_yticklabels(0, 0) graph1.set_xticklabels_position(0, 1, 'right') - graph1.set_xlabel(r"Azimuthal angle (station) [\si{\degree}]") - graph1.set_ylabel(r"Azimuthal angle (cluster) [\si{\degree}]") + graph1.set_xlabel(r'Azimuthal angle (station) [\si{\degree}]') + graph1.set_ylabel(r'Azimuthal angle (cluster) [\si{\degree}]') graph2.show_xticklabels_for_all(None) graph2.show_yticklabels(0, 0) graph2.set_xticklabels_position(0, 1, 'right') - graph2.set_xlabel(r"Zenith angle (station) [\si{\degree}]") - graph2.set_ylabel(r"Zenith angle (cluster) [\si{\degree}]") + graph2.set_xlabel(r'Zenith angle (station) [\si{\degree}]') + graph2.set_ylabel(r'Zenith angle (cluster) [\si{\degree}]') artist.utils.save_graph(graph1, suffix='phi', dirname='plots') artist.utils.save_graph(graph2, suffix='theta', dirname='plots') @@ -370,8 +374,11 @@ def plot_fav_single_vs_single(data): station1 = cluster[i] station2 = cluster[j] - theta_station1, phi_station1, theta_station2, phi_station2 = \ - calc_direction_single_vs_single(data, station1, station2) + theta_station1, phi_station1, theta_station2, phi_station2 = calc_direction_single_vs_single( + data, + station1, + station2, + ) subplot(3, 3, j * 3 + i + 1) if i > j: @@ -380,24 +387,18 @@ def plot_fav_single_vs_single(data): ylim(-180, 180) bins = linspace(-180, 180, 37) - H, x_edges, y_edges = histogram2d(rad2deg(phi_station1), - rad2deg(phi_station2), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(phi_station1), rad2deg(phi_station2), bins=bins) graph.histogram2d(j, i, H, x_edges, y_edges, 'reverse_bw') - graph.set_label(j, i, r'$\phi$', 'upper left', - style='fill=white') + graph.set_label(j, i, r'$\phi$', 'upper left', style='fill=white') elif i < j: plot(rad2deg(theta_station1), rad2deg(theta_station2), ',') xlim(0, 45) ylim(0, 45) bins = linspace(0, 45, 46) - H, x_edges, y_edges = histogram2d(rad2deg(theta_station1), - rad2deg(theta_station2), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(theta_station1), rad2deg(theta_station2), bins=bins) graph.histogram2d(j, i, H, x_edges, y_edges, 'reverse_bw') - graph.set_label(j, i, r'$\theta$', 'upper left', - style='fill=white') + graph.set_label(j, i, r'$\theta$', 'upper left', style='fill=white') if j == 2: xlabel(station1) @@ -405,13 +406,13 @@ def plot_fav_single_vs_single(data): ylabel(station2) locator_params(tight=True, nbins=4) - #subplot(3, 3, n + 3) - #plot(rad2deg(theta_station1), rad2deg(theta_station2), ',') - 
#xlabel(r"$\theta_{%d}$" % station1) - #ylabel(r"$\theta_{\{%s\}}$" % ','.join(station2_str)) - #xlim(0, 45) - #ylim(0, 45) - #locator_params(tight=True, nbins=4) + # subplot(3, 3, n + 3) + # plot(rad2deg(theta_station1), rad2deg(theta_station2), ',') + # xlabel(r"$\theta_{%d}$" % station1) + # ylabel(r"$\theta_{\{%s\}}$" % ','.join(station2_str)) + # xlim(0, 45) + # ylim(0, 45) + # locator_params(tight=True, nbins=4) utils.saveplot() @@ -424,8 +425,8 @@ def plot_fav_single_vs_single(data): graph.set_yticks(1, 2, range(-180, 181, 90)) graph.set_yticks(0, 2, range(-90, 181, 90)) - graph.set_xlabel(r"Shower angle [\si{\degree}]") - graph.set_ylabel(r"Shower angle [\si{\degree}]") + graph.set_xlabel(r'Shower angle [\si{\degree}]') + graph.set_ylabel(r'Shower angle [\si{\degree}]') for i, station in enumerate(cluster): graph.set_label(i, i, cluster[i], 'center') @@ -443,46 +444,43 @@ def plot_fav_uncertainty_single_vs_cluster(data): figure() for n, station in enumerate(cluster, 1): - theta_station, phi_station, theta_cluster, phi_cluster = \ - calc_direction_single_vs_cluster(data, station, cluster) + theta_station, phi_station, theta_cluster, phi_cluster = calc_direction_single_vs_cluster( + data, + station, + cluster, + ) bins = linspace(0, deg2rad(45), 11) x, y, y2 = [], [], [] for low, high in zip(bins[:-1], bins[1:]): - sel_phi_c = phi_cluster.compress((low <= theta_station) & - (theta_station < high)) - sel_phi_s = phi_station.compress((low <= theta_station) & - (theta_station < high)) - sel_theta_c = theta_cluster.compress((low <= theta_station) & - (theta_station < high)) - sel_theta_s = theta_station.compress((low <= theta_station) & - (theta_station < high)) + sel_phi_c = phi_cluster.compress((low <= theta_station) & (theta_station < high)) + sel_phi_s = phi_station.compress((low <= theta_station) & (theta_station < high)) + sel_theta_c = theta_cluster.compress((low <= theta_station) & (theta_station < high)) + sel_theta_s = theta_station.compress((low <= theta_station) & (theta_station < high)) dphi = sel_phi_s - sel_phi_c dtheta = sel_theta_s - sel_theta_c # make sure phi, theta are between -pi and pi dphi = (dphi + pi) % (2 * pi) - pi dtheta = (dtheta + pi) % (2 * pi) - pi - print rad2deg((low + high) / 2), len(dphi), len(dtheta) + print(rad2deg((low + high) / 2), len(dphi), len(dtheta)) x.append((low + high) / 2) - #y.append(std(dphi)) - #y2.append(std(dtheta)) + # y.append(std(dphi)) + # y2.append(std(dtheta)) y.append((scoreatpercentile(dphi, 83) - scoreatpercentile(dphi, 17)) / 2) y2.append((scoreatpercentile(dtheta, 83) - scoreatpercentile(dtheta, 17)) / 2) ex = linspace(0, deg2rad(45), 50) ephi, etheta = [], [] for theta in ex: - ephi.append(calc_phi_error_for_station_cluster(theta, n, - cluster_ids)) - etheta.append(calc_theta_error_for_station_cluster(theta, n, - cluster_ids)) + ephi.append(calc_phi_error_for_station_cluster(theta, n, cluster_ids)) + etheta.append(calc_theta_error_for_station_cluster(theta, n, cluster_ids)) subplot(2, 3, n) plot(rad2deg(x), rad2deg(y), 'o') plot(rad2deg(ex), rad2deg(ephi)) - xlabel(r"$\theta_{%d}$ [deg]" % station) + xlabel(r'$\theta_{%d}$ [deg]' % station) if n == 1: - ylabel(r"$\phi$ uncertainty [deg]") + ylabel(r'$\phi$ uncertainty [deg]') ylim(0, 100) locator_params(tight=True, nbins=4) @@ -493,10 +491,10 @@ def plot_fav_uncertainty_single_vs_cluster(data): subplot(2, 3, n + 3) plot(rad2deg(x), rad2deg(y2), 'o') plot(rad2deg(ex), rad2deg(etheta)) - xlabel(r"$\theta_{%d}$ [deg]" % station) + xlabel(r'$\theta_{%d}$ [deg]' % station) if n 
== 1: - ylabel(r"$\theta$ uncertainty [deg]") - #ylabel(r"$\theta_{\{%s\}}$" % ','.join(cluster_str)) + ylabel(r'$\theta$ uncertainty [deg]') + # ylabel(r"$\theta_{\{%s\}}$" % ','.join(cluster_str)) ylim(0, 15) locator_params(tight=True, nbins=4) @@ -504,7 +502,7 @@ def plot_fav_uncertainty_single_vs_cluster(data): graph.plot(1, n - 1, rad2deg(ex), rad2deg(etheta), mark=None) graph.set_label(1, n - 1, r'$\theta$, %d' % station) - subplots_adjust(wspace=.3, hspace=.3) + subplots_adjust(wspace=0.3, hspace=0.3) utils.saveplot() graph.set_xlimits_for_all(None, 0, 45) @@ -513,8 +511,8 @@ def plot_fav_uncertainty_single_vs_cluster(data): graph.show_xticklabels_for_all([(1, 0), (0, 1), (1, 2)]) graph.show_yticklabels_for_all([(0, 2), (1, 0)]) - graph.set_xlabel(r"Shower zenith angle [\si{\degree}]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") + graph.set_xlabel(r'Shower zenith angle [\si{\degree}]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') artist.utils.save_graph(graph, dirname='plots') @@ -532,29 +530,28 @@ def plot_fav_uncertainty_single_vs_single(data): station1 = cluster[i] station2 = cluster[j] - theta_station1, phi_station1, theta_station2, phi_station2 = \ - calc_direction_single_vs_single(data, station1, station2) + theta_station1, phi_station1, theta_station2, phi_station2 = calc_direction_single_vs_single( + data, + station1, + station2, + ) bins = linspace(0, deg2rad(45), 11) x, y, y2 = [], [], [] for low, high in zip(bins[:-1], bins[1:]): - sel_phi_c = phi_station2.compress((low <= theta_station1) & - (theta_station1 < high)) - sel_phi_s = phi_station1.compress((low <= theta_station1) & - (theta_station1 < high)) - sel_theta_c = theta_station2.compress((low <= theta_station1) & - (theta_station1 < high)) - sel_theta_s = theta_station1.compress((low <= theta_station1) & - (theta_station1 < high)) + sel_phi_c = phi_station2.compress((low <= theta_station1) & (theta_station1 < high)) + sel_phi_s = phi_station1.compress((low <= theta_station1) & (theta_station1 < high)) + sel_theta_c = theta_station2.compress((low <= theta_station1) & (theta_station1 < high)) + sel_theta_s = theta_station1.compress((low <= theta_station1) & (theta_station1 < high)) dphi = sel_phi_s - sel_phi_c dtheta = sel_theta_s - sel_theta_c # make sure phi, theta are between -pi and pi dphi = (dphi + pi) % (2 * pi) - pi dtheta = (dtheta + pi) % (2 * pi) - pi - print rad2deg((low + high) / 2), len(dphi), len(dtheta) + print(rad2deg((low + high) / 2), len(dphi), len(dtheta)) x.append((low + high) / 2) - #y.append(std(dphi)) - #y2.append(std(dtheta)) + # y.append(std(dphi)) + # y2.append(std(dtheta)) y.append((scoreatpercentile(dphi, 83) - scoreatpercentile(dphi, 17)) / 2) y2.append((scoreatpercentile(dtheta, 83) - scoreatpercentile(dtheta, 17)) / 2) @@ -630,8 +627,7 @@ def calc_phi_error_for_station_cluster(theta, station, cluster): err_cluster = rec.rel_phi_errorsq(theta, phis, phi1, phi2, r1, r2) # errors are already squared!! - err_total = sqrt(STATION_TIMING_ERR ** 2 * err_single + - CLUSTER_TIMING_ERR ** 2 * err_cluster) + err_total = sqrt(STATION_TIMING_ERR**2 * err_single + CLUSTER_TIMING_ERR**2 * err_cluster) return mean(err_total) @@ -649,8 +645,7 @@ def calc_theta_error_for_station_cluster(theta, station, cluster): err_cluster = rec.rel_theta1_errorsq(theta, phis, phi1, phi2, r1, r2) # errors are already squared!! 
- err_total = sqrt(STATION_TIMING_ERR ** 2 * err_single + - CLUSTER_TIMING_ERR ** 2 * err_cluster) + err_total = sqrt(STATION_TIMING_ERR**2 * err_single + CLUSTER_TIMING_ERR**2 * err_cluster) return mean(err_total) @@ -668,8 +663,7 @@ def calc_phi_error_for_station_station(theta, station1, station2): err_single2 = rec.rel_phi_errorsq(theta, phis, phi1, phi2, r1, r2) # errors are already squared!! - err_total = sqrt(STATION_TIMING_ERR ** 2 * err_single1 + - STATION_TIMING_ERR ** 2 * err_single2) + err_total = sqrt(STATION_TIMING_ERR**2 * err_single1 + STATION_TIMING_ERR**2 * err_single2) return mean(err_total) @@ -687,8 +681,7 @@ def calc_theta_error_for_station_station(theta, station1, station2): err_single2 = rec.rel_theta1_errorsq(theta, phis, phi1, phi2, r1, r2) # errors are already squared!! - err_total = sqrt(STATION_TIMING_ERR ** 2 * err_single1 + - STATION_TIMING_ERR ** 2 * err_single2) + err_total = sqrt(STATION_TIMING_ERR**2 * err_single1 + STATION_TIMING_ERR**2 * err_single2) return mean(err_total) @@ -706,16 +699,15 @@ def hist_fav_single_stations(data): theta = reconstructions.read_where(query, field='reconstructed_theta') subplot(2, 3, n) - N, bins, patches = hist(rad2deg(phi), bins=linspace(-180, 180, 21), - histtype='step') + N, bins, patches = hist(rad2deg(phi), bins=linspace(-180, 180, 21), histtype='step') x = (bins[:-1] + bins[1:]) / 2 f = lambda x, a: a popt, pcov = curve_fit(f, x, N, sigma=sqrt(N)) chi2 = chisquare(N, popt[0], ddof=0) - print station, popt, pcov, chi2 + print(station, popt, pcov, chi2) axhline(popt[0]) - xlabel(r"$\phi$") + xlabel(r'$\phi$') legend([station], loc='lower right') locator_params(tight=True, nbins=4) axis('auto') @@ -724,9 +716,8 @@ def hist_fav_single_stations(data): graph1.set_label(0, n - 1, station) subplot(2, 3, n + 3) - N, bins, patches = hist(rad2deg(theta), bins=linspace(0, 45, 21), - histtype='step') - xlabel(r"$\theta$") + N, bins, patches = hist(rad2deg(theta), bins=linspace(0, 45, 21), histtype='step') + xlabel(r'$\theta$') legend([station], loc='lower right') locator_params(tight=True, nbins=4) axis('auto') @@ -734,7 +725,7 @@ def hist_fav_single_stations(data): graph2.histogram(0, n - 1, N, bins) graph2.set_label(0, n - 1, station) - subplots_adjust(wspace=.4) + subplots_adjust(wspace=0.4) utils.saveplot() graph1.set_ylimits_for_all(None, 0, 1500) @@ -760,16 +751,16 @@ def hist_fav_single_stations(data): if __name__ == '__main__': if 'data' not in globals(): # For single station plots - #data = tables.open_file('month-single.h5') + # data = tables.open_file('month-single.h5') # For station / cluster plots - #data = tables.open_file('new.h5') - #data = tables.open_file('newlarge.h5') + # data = tables.open_file('new.h5') + # data = tables.open_file('newlarge.h5') data = tables.open_file('master.h5') # For N vs R plot - #data = tables.open_file('master-large.h5') + # data = tables.open_file('master-large.h5') # No data - #data = None + # data = None - artist.utils.set_prefix("SP-DIR-") - utils.set_prefix("SP-DIR-") + artist.utils.set_prefix('SP-DIR-') + utils.set_prefix('SP-DIR-') main(data) diff --git a/scripts/sciencepark/master-simulations.py b/scripts/sciencepark/master-simulations.py index fde54279..8f9f09d4 100644 --- a/scripts/sciencepark/master-simulations.py +++ b/scripts/sciencepark/master-simulations.py @@ -2,9 +2,8 @@ import re import warnings -import tables - import store_aires_data +import tables from sapphire import clusters from sapphire.simulations import GroundParticlesSimulation, QSubSimulation @@ -14,21 
+13,21 @@ N_CORES = 32 -class Master(object): +class Master: def __init__(self, data_filename): if os.path.exists(data_filename): - warnings.warn("%s already exists, some steps are skipped" % data_filename) + warnings.warn('%s already exists, some steps are skipped' % data_filename) self.data = tables.open_file(data_filename, 'a') def main(self): self.store_shower_data() self.do_cluster_simulations() - #self.do_energies_simulations() + # self.do_energies_simulations() def store_shower_data(self): for angle in [0, 5, 10, 15, 22.5, 30, 35, 45]: self.store_1PeV_data_for_angle(angle) - #for energy, group_name in [('e14', 'E_100TeV'), + # for energy, group_name in [('e14', 'E_100TeV'), # ('e16', 'E_10PeV')]: # self.store_data_for_energy(energy, group_name) @@ -76,8 +75,8 @@ def perform_simulation(self, cluster, shower, output_path=None): try: sim = Simulation(*args, **kwargs) - except RuntimeError, msg: - print msg + except RuntimeError as msg: + print(msg) return else: sim.run() diff --git a/scripts/sciencepark/master-single-station.py b/scripts/sciencepark/master-single-station.py index 282a20ce..a672914d 100644 --- a/scripts/sciencepark/master-single-station.py +++ b/scripts/sciencepark/master-single-station.py @@ -19,19 +19,17 @@ class Master: stations = [501, 503, 506] - datetimerange = (datetime.datetime(2012, 2, 1), - datetime.datetime(2012, 2, 2)) + datetimerange = (datetime.datetime(2012, 2, 1), datetime.datetime(2012, 2, 2)) offsets = [] - def __init__(self, data_path): self.data = tables.open_file(data_path, 'a') self.station_groups = ['/s%d' % u for u in self.stations] self.cluster = clusters.ScienceParkCluster(self.stations) - self.trig_threshold = .5 + self.trig_threshold = 0.5 self.detector_offsets = [] @@ -51,17 +49,16 @@ def download_data(self): start, end = self.datetimerange for station, group_path in zip(self.stations, self.station_groups): - if not group_path in self.data: - print "Downloading data for station", station - download_data(self.data, group_path, station, - start, end, get_blobs=True) + if group_path not in self.data: + print('Downloading data for station', station) + download_data(self.data, group_path, station, start, end, get_blobs=True) def clean_data(self): - print "Cleaning data..." + print('Cleaning data...') for group in self.station_groups: group = self.data.get_node(group) attrs = group._v_attrs - if not 'is_clean' in attrs or not attrs.is_clean: + if 'is_clean' not in attrs or not attrs.is_clean: self.clean_events_in_group(group) attrs.is_clean = True @@ -78,22 +75,20 @@ def clean_events_in_group(self, group): unique_ids.append(row_id) prev = timestamp - tmptable = self.data.create_table(group, 't__events', - description=events.description) + tmptable = self.data.create_table(group, 't__events', description=events.description) rows = events.read_coordinates(unique_ids) tmptable.append(rows) tmptable.flush() self.data.rename_node(tmptable, 'events', overwrite=True) def search_coincidences(self): - print "Searching for coincidences..." 
+ print('Searching for coincidences...') if '/c_index' not in self.data and '/timestamps' not in self.data: c_index, timestamps = [], [] for id, station in enumerate(self.station_groups): station = self.data.get_node(station) for event_id, event in enumerate(station.events): - timestamps.append((event['ext_timestamp'], id, - event_id)) + timestamps.append((event['ext_timestamp'], id, event_id)) c_index.append([len(timestamps) - 1]) timestamps = np.array(timestamps, dtype=np.uint64) self.data.create_array('/', 'timestamps', timestamps) @@ -102,7 +97,7 @@ def search_coincidences(self): self.data.root.c_index.append(coincidence) def process_events(self): - print "Processing events..." + print('Processing events...') attrs = self.data.root._v_attrs if 'is_processed' not in attrs or not attrs.is_processed: for station_id, station_group in enumerate(self.station_groups): @@ -112,23 +107,19 @@ def process_events(self): attrs.is_processed = True def store_coincidences(self): - print "Storing coincidences..." + print('Storing coincidences...') if '/coincidences' not in self.data: group = self.data.create_group('/', 'coincidences') group._v_attrs.cluster = self.cluster self.c_index = [] - self.coincidences = self.data.create_table(group, - 'coincidences', - storage.Coincidence) - self.observables = self.data.create_table(group, 'observables', - storage.EventObservables) + self.coincidences = self.data.create_table(group, 'coincidences', storage.Coincidence) + self.observables = self.data.create_table(group, 'observables', storage.EventObservables) for coincidence in pbar(self.data.root.c_index): self.store_coincidence(coincidence) - c_index = self.data.create_vlarray(group, 'c_index', - tables.UInt32Col()) + c_index = self.data.create_vlarray(group, 'c_index', tables.UInt32Col()) for coincidence in self.c_index: c_index.append(coincidence) c_index.flush() @@ -153,15 +144,12 @@ def store_coincidence(self, coincidence): group = self.data.get_node(self.station_groups[station_id]) event = group.events[event_index] - idx = self.store_event_in_observables(event, coincidence_id, - station_id) + idx = self.store_event_in_observables(event, coincidence_id, station_id) observables_idx.append(idx) - timestamps.append((event['ext_timestamp'], event['timestamp'], - event['nanoseconds'])) + timestamps.append((event['ext_timestamp'], event['timestamp'], event['nanoseconds'])) first_timestamp = sorted(timestamps)[0] - row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = \ - first_timestamp + row['ext_timestamp'], row['timestamp'], row['nanoseconds'] = first_timestamp row.append() self.c_index.append(observables_idx) self.coincidences.flush() @@ -172,11 +160,10 @@ def store_event_in_observables(self, event, coincidence_id, station_id): row['id'] = event_id row['station_id'] = station_id - for key in ('timestamp', 'nanoseconds', 'ext_timestamp', - 'n1', 'n2', 'n3', 'n4', 't1', 't2', 't3', 't4'): + for key in ('timestamp', 'nanoseconds', 'ext_timestamp', 'n1', 'n2', 'n3', 'n4', 't1', 't2', 't3', 't4'): row[key] = event[key] - signals = [event[key] for key in 'n1', 'n2', 'n3', 'n4'] + signals = [event[key] for key in ('n1', 'n2', 'n3', 'n4')] N = sum([1 if u > self.trig_threshold else 0 for u in signals]) row['N'] = N @@ -185,35 +172,48 @@ def store_event_in_observables(self, event, coincidence_id, station_id): return event_id def reconstruct_direction(self): - print "Reconstructing direction..." 
+ print('Reconstructing direction...') if '/reconstructions' not in self.data: - reconstruction = ClusterDirectionReconstruction(self.data, - self.stations, '/reconstructions', - detector_offsets=self.detector_offsets) + reconstruction = ClusterDirectionReconstruction( + self.data, + self.stations, + '/reconstructions', + detector_offsets=self.detector_offsets, + ) reconstruction.reconstruct_angles('/coincidences') def determine_detector_offsets(self): - print "Determing detector offsets..." + print('Determing detector offsets...') for station_id, station_group in enumerate(self.station_groups): process = ProcessEvents(self.data, station_group) offsets = process.determine_detector_timing_offsets() - print "Offsets for station %d: %s" % (station_id, offsets) + print('Offsets for station %d: %s' % (station_id, offsets)) self.detector_offsets.append(offsets) class ClusterDirectionReconstruction(DirectionReconstruction): - reconstruction_description = {'coinc_id': tables.UInt32Col(), - 'N': tables.UInt8Col(), - 'reconstructed_theta': tables.Float32Col(), - 'reconstructed_phi': tables.Float32Col(), - 'min_n134': tables.Float32Col(), - } - reconstruction_coincidence_description = {'id': tables.UInt32Col(), - 'N': tables.UInt8Col(), - } - - - def __init__(self, datafile, stations, results_group=None, min_n134=1., N=None, detector_offsets=None, overwrite=False): + reconstruction_description = { + 'coinc_id': tables.UInt32Col(), + 'N': tables.UInt8Col(), + 'reconstructed_theta': tables.Float32Col(), + 'reconstructed_phi': tables.Float32Col(), + 'min_n134': tables.Float32Col(), + } + reconstruction_coincidence_description = { + 'id': tables.UInt32Col(), + 'N': tables.UInt8Col(), + } + + def __init__( + self, + datafile, + stations, + results_group=None, + min_n134=1.0, + N=None, + detector_offsets=None, + overwrite=False, + ): self.data = datafile self.stations = stations @@ -231,22 +231,19 @@ def _create_reconstruction_group_and_tables(self, results_group, overwrite): if overwrite: self.data.remove_node(results_group, recursive=True) else: - raise RuntimeError("Result group exists, but overwrite is False") + raise RuntimeError('Result group exists, but overwrite is False') head, tail = os.path.split(results_group) group = self.data.create_group(head, tail) - stations_description = {'s%d' % u: tables.BoolCol() for u in - self.stations} + stations_description = {'s%d' % u: tables.BoolCol() for u in self.stations} description = self.reconstruction_description description.update(stations_description) - self.reconstruction = self.data.create_table(group, - 'reconstructions', description) + self.reconstruction = self.data.create_table(group, 'reconstructions', description) description = self.reconstruction_coincidence_description description.update(stations_description) - self.reconstruction_coincidences = \ - self.data.create_table(group, 'coincidences', description) + self.reconstruction_coincidences = self.data.create_table(group, 'coincidences', description) return group @@ -295,7 +292,7 @@ def reconstruct_cluster_stations(self, coincidence): def reconstruct_angle(self, event, offsets=None): """Reconstruct angles from a single event""" - c = 3.00e+8 + c = 3.00e8 if offsets is not None: self._correct_offsets(event, offsets) @@ -307,8 +304,7 @@ def reconstruct_angle(self, event, offsets=None): r1, phi1 = station.calc_r_and_phi_for_detectors(1, 3) r2, phi2 = station.calc_r_and_phi_for_detectors(1, 4) - phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), - (dt2 * r1 * sin(phi1) - dt1 * r2 * 
sin(phi2)) * -1) + phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) theta1 = arcsin(c * dt1 * 1e-9 / (r1 * cos(phi - phi1))) theta2 = arcsin(c * dt2 * 1e-9 / (r2 * cos(phi - phi2))) @@ -327,7 +323,7 @@ def reconstruct_angle(self, event, offsets=None): def reconstruct_cluster_angle(self, events, index_group): """Reconstruct angles from a single event""" - c = 3.00e+8 + c = 3.00e8 t = [int(events[u]['ext_timestamp']) for u in index_group] stations = [events[u]['station_id'] for u in index_group] @@ -338,8 +334,7 @@ def reconstruct_cluster_angle(self, events, index_group): r1, phi1 = self.cluster.calc_r_and_phi_for_stations(stations[0], stations[1]) r2, phi2 = self.cluster.calc_r_and_phi_for_stations(stations[0], stations[2]) - phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), - (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) + phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) theta1 = arcsin(c * dt1 * 1e-9 / (r1 * cos(phi - phi1))) theta2 = arcsin(c * dt2 * 1e-9 / (r2 * cos(phi - phi2))) @@ -355,9 +350,7 @@ def reconstruct_cluster_angle(self, events, index_group): return theta_wgt, phi - def store_reconstructed_event_from_single_station(self, coincidence, event, - reconstructed_theta, - reconstructed_phi): + def store_reconstructed_event_from_single_station(self, coincidence, event, reconstructed_theta, reconstructed_phi): dst_row = self.results_group.reconstructions.row dst_row['coinc_id'] = coincidence['id'] @@ -369,9 +362,14 @@ def store_reconstructed_event_from_single_station(self, coincidence, event, dst_row['s%d' % station] = True dst_row.append() - def store_reconstructed_event_from_cluster(self, coincidence, events, - index_group, reconstructed_theta, - reconstructed_phi): + def store_reconstructed_event_from_cluster( + self, + coincidence, + events, + index_group, + reconstructed_theta, + reconstructed_phi, + ): dst_row = self.results_group.reconstructions.row dst_row['coinc_id'] = coincidence['id'] diff --git a/scripts/sciencepark/master.py b/scripts/sciencepark/master.py index 7cfbb1c0..3b32927b 100644 --- a/scripts/sciencepark/master.py +++ b/scripts/sciencepark/master.py @@ -14,7 +14,7 @@ import sapphire.analysis.coincidences -from sapphire import clusters, storage +from sapphire import clusters from sapphire.analysis.core_reconstruction import CoreReconstruction from sapphire.analysis.direction_reconstruction import DirectionReconstruction from sapphire.analysis.process_events import ProcessEvents @@ -24,8 +24,7 @@ class Master: stations = range(501, 507) - datetimerange = (datetime.datetime(2012, 1, 1), - datetime.datetime(2012, 1, 2)) + datetimerange = (datetime.datetime(2012, 1, 1), datetime.datetime(2012, 1, 2)) def __init__(self, data_path): self.data = tables.open_file(data_path, 'a') @@ -52,16 +51,15 @@ def download_data(self): start, end = self.datetimerange for station, group_path in zip(self.stations, self.station_groups): - if not group_path in self.data: - print "Downloading data for station", station - download_data(self.data, group_path, station, - start, end, get_blobs=True) + if group_path not in self.data: + print('Downloading data for station', station) + download_data(self.data, group_path, station, start, end, get_blobs=True) def clean_data(self): for group in self.station_groups: group = self.data.get_node(group) attrs = group._v_attrs - if not 'is_clean' in attrs or not attrs.is_clean: + if 'is_clean' not in attrs or 
not attrs.is_clean: self.clean_events_in_group(group) attrs.is_clean = True @@ -78,36 +76,36 @@ def clean_events_in_group(self, group): unique_ids.append(row_id) prev = timestamp - tmptable = self.data.create_table(group, 't__events', - description=events.description) + tmptable = self.data.create_table(group, 't__events', description=events.description) rows = events.read_coordinates(unique_ids) tmptable.append(rows) tmptable.flush() self.data.rename_node(tmptable, 'events', overwrite=True) def search_coincidences(self): - print "Searching for coincidences..." + print('Searching for coincidences...') if '/coincidences' not in self.data: - coincidences = sapphire.analysis.coincidences.Coincidences( - self.data, '/coincidences', self.station_groups) + coincidences = sapphire.analysis.coincidences.Coincidences(self.data, '/coincidences', self.station_groups) coincidences.search_coincidences() coincidences.process_events() coincidences.store_coincidences(self.cluster) def reconstruct_direction(self): - print "Reconstructing direction..." - if not '/reconstructions' in self.data: - reconstruction = ClusterDirectionReconstruction(self.data, - self.stations, '/reconstructions', - detector_offsets=self.detector_offsets, - station_offsets=self.station_offsets) + print('Reconstructing direction...') + if '/reconstructions' not in self.data: + reconstruction = ClusterDirectionReconstruction( + self.data, + self.stations, + '/reconstructions', + detector_offsets=self.detector_offsets, + station_offsets=self.station_offsets, + ) reconstruction.reconstruct_angles('/coincidences') def reconstruct_core_position(self): - print "Reconstructing core position..." - if not '/core_reconstructions' in self.data: - reconstruction = CoreReconstruction(self.data, self.stations, - '/core_reconstructions') + print('Reconstructing core position...') + if '/core_reconstructions' not in self.data: + reconstruction = CoreReconstruction(self.data, self.stations, '/core_reconstructions') reconstruction.reconstruct_core_positions('/coincidences') def determine_detector_offsets(self, overwrite=False): @@ -118,7 +116,7 @@ def determine_detector_offsets(self, overwrite=False): for station_id, station_group in enumerate(self.station_groups): process = ProcessEvents(self.data, station_group) offsets = process.determine_detector_timing_offsets() - print "Offsets for station %d: %s" % (station_id, offsets) + print('Offsets for station %d: %s' % (station_id, offsets)) self.detector_offsets.append(offsets) if offsets_group in self.data: self.data.remove_node(offsets_group) @@ -139,27 +137,29 @@ def determine_station_offsets(self, overwrite=False): for station_id, station_group in enumerate(station_groups): coincidences = sapphire.analysis.coincidences.Coincidences( - self.data, coincidence_group=None, - station_groups=[ref_group, station_group]) + self.data, + coincidence_group=None, + station_groups=[ref_group, station_group], + ) c_index, timestamps = coincidences._search_coincidences() dt = [] c_index = [c for c in c_index if len(c) == 2] for i, j in c_index: stations = [timestamps[u][1] for u in [i, j]] - t0, t1 = [int(timestamps[u][0]) for u in [i, j]] + t0, t1 = (int(timestamps[u][0]) for u in [i, j]) if stations[0] > stations[1]: t0, t1 = t1, t0 dt.append(t1 - t0) - print ref_group, station_group, len(dt), + print(ref_group, station_group, len(dt)) y, bins = np.histogram(dt, bins=bins) x = (bins[:-1] + bins[1:]) / 2 - popt, pcov = curve_fit(gauss, x, y, p0=(len(dt), 0, 100.)) - print popt + popt, pcov = curve_fit(gauss, x, 
y, p0=(len(dt), 0, 100.0)) + print(popt) self.station_offsets.append(popt[1]) ref_idx = self.station_groups.index(ref_group) - self.station_offsets.insert(ref_idx, 0.) + self.station_offsets.insert(ref_idx, 0.0) if offsets_group in self.data: self.data.remove_node(offsets_group) @@ -168,20 +168,29 @@ def determine_station_offsets(self, overwrite=False): class ClusterDirectionReconstruction(DirectionReconstruction): - reconstruction_description = {'coinc_id': tables.UInt32Col(), - 'N': tables.UInt8Col(), - 'reconstructed_theta': tables.Float32Col(), - 'reconstructed_phi': tables.Float32Col(), - 'min_n134': tables.Float32Col(), - } - reconstruction_coincidence_description = {'id': tables.UInt32Col(), - 'N': tables.UInt8Col(), - } - - - def __init__(self, datafile, stations, results_group=None, - min_n134=1., N=None, detector_offsets=None, - station_offsets=None, overwrite=False): + reconstruction_description = { + 'coinc_id': tables.UInt32Col(), + 'N': tables.UInt8Col(), + 'reconstructed_theta': tables.Float32Col(), + 'reconstructed_phi': tables.Float32Col(), + 'min_n134': tables.Float32Col(), + } + reconstruction_coincidence_description = { + 'id': tables.UInt32Col(), + 'N': tables.UInt8Col(), + } + + def __init__( + self, + datafile, + stations, + results_group=None, + min_n134=1.0, + N=None, + detector_offsets=None, + station_offsets=None, + overwrite=False, + ): self.data = datafile self.stations = stations @@ -200,22 +209,19 @@ def _create_reconstruction_group_and_tables(self, results_group, overwrite): if overwrite: self.data.remove_node(results_group, recursive=True) else: - raise RuntimeError("Result group exists, but overwrite is False") + raise RuntimeError('Result group exists, but overwrite is False') head, tail = os.path.split(results_group) group = self.data.create_group(head, tail) - stations_description = {'s%d' % u: tables.BoolCol() for u in - self.stations} + stations_description = {'s%d' % u: tables.BoolCol() for u in self.stations} description = self.reconstruction_description description.update(stations_description) - self.reconstruction = self.data.create_table(group, - 'reconstructions', description) + self.reconstruction = self.data.create_table(group, 'reconstructions', description) description = self.reconstruction_coincidence_description description.update(stations_description) - self.reconstruction_coincidences = \ - self.data.create_table(group, 'coincidences', description) + self.reconstruction_coincidences = self.data.create_table(group, 'coincidences', description) return group @@ -265,7 +271,7 @@ def reconstruct_cluster_stations(self, coincidence): def reconstruct_angle(self, event, offsets=None): """Reconstruct angles from a single event""" - c = 3.00e+8 + c = 3.00e8 if offsets is not None: self._correct_offsets(event, offsets) @@ -277,8 +283,7 @@ def reconstruct_angle(self, event, offsets=None): r1, phi1 = station.calc_r_and_phi_for_detectors(1, 3) r2, phi2 = station.calc_r_and_phi_for_detectors(1, 4) - phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), - (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) + phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) theta1 = arcsin(c * dt1 * 1e-9 / (r1 * cos(phi - phi1))) theta2 = arcsin(c * dt2 * 1e-9 / (r2 * cos(phi - phi2))) @@ -297,27 +302,26 @@ def reconstruct_angle(self, event, offsets=None): def reconstruct_cluster_angle(self, events): """Reconstruct angles from three events""" - c = 3.00e+8 + c = 3.00e8 t = [] stations = [] for event in 
events: timestamp = int(event['ext_timestamp']) station = event['station_id'] - arrival_times = [event[u] for u in 't1', 't2', 't3', 't4' if - event[u] != -999.] + arrival_times = [event[u] for u in ('t1', 't2', 't3', 't4') if event[u] != -999.0] arrival_times.sort() # FIXME: should check for three low condition (< 1% ?) if len(arrival_times) >= 2: trigger_offset = arrival_times[1] - arrival_times[0] else: - trigger_offset = 0. + trigger_offset = 0.0 offset = self.station_offsets[station] # FIXME: ext_timestamp (long) + float loses precision try: correction = int(round(offset + trigger_offset)) except: - print arrival_times + print(arrival_times) correction = int(round(offset)) t.append(timestamp - correction) stations.append(station) @@ -328,8 +332,7 @@ def reconstruct_cluster_angle(self, events): r1, phi1 = self.cluster.calc_r_and_phi_for_stations(stations[0], stations[1]) r2, phi2 = self.cluster.calc_r_and_phi_for_stations(stations[0], stations[2]) - phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), - (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) + phi = arctan2((dt2 * r1 * cos(phi1) - dt1 * r2 * cos(phi2)), (dt2 * r1 * sin(phi1) - dt1 * r2 * sin(phi2)) * -1) theta1 = arcsin(c * dt1 * 1e-9 / (r1 * cos(phi - phi1))) theta2 = arcsin(c * dt2 * 1e-9 / (r2 * cos(phi - phi2))) @@ -345,9 +348,7 @@ def reconstruct_cluster_angle(self, events): return theta_wgt, phi - def store_reconstructed_event_from_single_station(self, coincidence, event, - reconstructed_theta, - reconstructed_phi): + def store_reconstructed_event_from_single_station(self, coincidence, event, reconstructed_theta, reconstructed_phi): dst_row = self.results_group.reconstructions.row dst_row['coinc_id'] = coincidence['id'] @@ -359,9 +360,14 @@ def store_reconstructed_event_from_single_station(self, coincidence, event, dst_row['s%d' % station] = True dst_row.append() - def store_reconstructed_event_from_cluster(self, coincidence, events, - index_group, reconstructed_theta, - reconstructed_phi): + def store_reconstructed_event_from_cluster( + self, + coincidence, + events, + index_group, + reconstructed_theta, + reconstructed_phi, + ): dst_row = self.results_group.reconstructions.row dst_row['coinc_id'] = coincidence['id'] diff --git a/scripts/sciencepark/plot_trace.py b/scripts/sciencepark/plot_trace.py index 2f9297d7..1941ca3c 100644 --- a/scripts/sciencepark/plot_trace.py +++ b/scripts/sciencepark/plot_trace.py @@ -1,10 +1,10 @@ import tables +from hisparc.analysis.traces import get_traces + from artist import GraphArtist from pylab import * -from hisparc.analysis.traces import get_traces - def plot_trace(station_group, idx): events = station_group.events @@ -20,15 +20,15 @@ def plot_trace(station_group, idx): plot(x, traces.T) xlim(0, 200) - #line_styles = ['solid', 'dashed', 'dotted', 'dashdotted'] + # line_styles = ['solid', 'dashed', 'dotted', 'dashdotted'] line_styles = ['black', 'black!80', 'black!60', 'black!40'] styles = (u for u in line_styles) graph = GraphArtist(width=r'.5\linewidth') for trace in traces: graph.plot(x, trace / 1000, mark=None, linestyle=styles.next()) - graph.set_xlabel(r"Time [\si{\nano\second}]") - graph.set_ylabel(r"Signal [\si{\volt}]") + graph.set_xlabel(r'Time [\si{\nano\second}]') + graph.set_ylabel(r'Signal [\si{\volt}]') graph.set_xlimits(0, 200) graph.save('plots/traces') diff --git a/scripts/sciencepark/utils.py b/scripts/sciencepark/utils.py index d2c9a7d6..936afffb 100644 --- a/scripts/sciencepark/utils.py +++ b/scripts/sciencepark/utils.py @@ -1,10 +1,9 @@ -""" Utility 
functions """ +"""Utility functions""" import inspect -import numpy as np - import matplotlib.pyplot as plt +import numpy as np __suffix = '' __prefix = '' diff --git a/scripts/simulations/analyze_shower_front.py b/scripts/simulations/analyze_shower_front.py index 9500e337..0aa7a440 100644 --- a/scripts/simulations/analyze_shower_front.py +++ b/scripts/simulations/analyze_shower_front.py @@ -1,5 +1,7 @@ +import matplotlib.pyplot as plt import numpy as np import tables +import utils from scipy.optimize import curve_fit from scipy.stats import scoreatpercentile @@ -7,9 +9,6 @@ from artist import GraphArtist from pylab import * -import matplotlib.pyplot as plt -import utils - USE_TEX = False # For matplotlib plots @@ -17,7 +16,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -28,12 +27,12 @@ def main(): global data data = tables.open_file('master-ch4v2.h5', 'r') - #utils.set_suffix('E_1PeV') + # utils.set_suffix('E_1PeV') - #scatterplot_core_distance_vs_time() - #median_core_distance_vs_time() + # scatterplot_core_distance_vs_time() + # median_core_distance_vs_time() boxplot_core_distance_vs_time() - #hists_core_distance_vs_time() + # hists_core_distance_vs_time() plot_front_passage() @@ -47,10 +46,10 @@ def scatterplot_core_distance_vs_time(): plt.xlim(1e0, 1e2) plt.ylim(1e-3, 1e3) - plt.xlabel("Core distance [m]") - plt.ylabel("Arrival time [ns]") + plt.xlabel('Core distance [m]') + plt.ylabel('Arrival time [ns]') - utils.title("Shower front timing structure") + utils.title('Shower front timing structure') utils.saveplot() @@ -59,10 +58,10 @@ def median_core_distance_vs_time(): plot_and_fit_statistic(lambda a: scoreatpercentile(a, 25)) plot_and_fit_statistic(lambda a: scoreatpercentile(a, 75)) - utils.title("Shower front timing structure (25, 75 %)") + utils.title('Shower front timing structure (25, 75 %)') utils.saveplot() - plt.xlabel("Core distance [m]") - plt.ylabel("Median arrival time [ns]") + plt.xlabel('Core distance [m]') + plt.ylabel('Median arrival time [ns]') legend(loc='lower right') @@ -85,8 +84,7 @@ def plot_and_fit_statistic(func): logf = lambda x, a, b: a * x + b g = lambda x, a, b: 10 ** logf(log10(x), a, b) popt, pcov = curve_fit(logf, logx, logy) - plot(x, g(x, *popt), label="f(x) = {:.2e} * x ^ {:.2e}".format(10 ** popt[1], - popt[0])) + plot(x, g(x, *popt), label=f'f(x) = {10 ** popt[1]:.2e} * x ^ {popt[0]:.2e}') def boxplot_core_distance_vs_time(): @@ -95,7 +93,7 @@ def boxplot_core_distance_vs_time(): sim = data.root.showers.E_1PeV.zenith_0.shower_0 leptons = sim.leptons - #bins = np.logspace(0, 2, 25) + # bins = np.logspace(0, 2, 25) bins = np.linspace(0, 100, 15) x, arrival_time, widths = [], [], [] t25, t50, t75 = [], [], [] @@ -112,17 +110,17 @@ def boxplot_core_distance_vs_time(): fill_between(x, t25, t75, color='0.75') plot(x, t50, 'o-', color='black') - plt.xlabel("Core distance [m]") - plt.ylabel("Arrival time [ns]") + plt.xlabel('Core distance [m]') + plt.ylabel('Arrival time [ns]') - #utils.title("Shower front timing structure") + # utils.title("Shower front timing structure") utils.saveplot() graph = GraphArtist() graph.plot(x, t50, linestyle=None) graph.shade_region(x, t25, t75) - graph.set_xlabel(r"Core distance [\si{\meter}]") - 
graph.set_ylabel(r"Arrival time [\si{\nano\second}]") + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel(r'Arrival time [\si{\nano\second}]') graph.set_ylimits(0, 30) graph.set_xlimits(0, 100) graph.save('plots/front-passage-vs-R') @@ -138,17 +136,20 @@ def hists_core_distance_vs_time(): for low, high in zip(bins[:-1], bins[1:]): sel = electrons.read_where('(low < core_distance) & (core_distance <= high)') arrival_time = sel[:]['arrival_time'] - plt.hist(arrival_time, bins=np.logspace(-2, 3, 50), histtype='step', - label="{:.2f} <= log10(R) < {:.2f}".format(np.log10(low), - np.log10(high))) + plt.hist( + arrival_time, + bins=np.logspace(-2, 3, 50), + histtype='step', + label=f'{np.log10(low):.2f} <= log10(R) < {np.log10(high):.2f}', + ) plt.xscale('log') - plt.xlabel("Arrival Time [ns]") - plt.ylabel("Count") + plt.xlabel('Arrival Time [ns]') + plt.ylabel('Count') plt.legend(loc='upper left') - utils.title("Shower front timing structure") + utils.title('Shower front timing structure') utils.saveplot() @@ -160,15 +161,14 @@ def plot_front_passage(): low = R - dR high = R + dR global t - t = leptons.read_where('(low < core_distance) & (core_distance <= high)', - field='arrival_time') + t = leptons.read_where('(low < core_distance) & (core_distance <= high)', field='arrival_time') n, bins, patches = hist(t, bins=linspace(0, 30, 31), histtype='step') graph = GraphArtist() graph.histogram(n, bins) - graph.set_xlabel(r"Arrival time [\si{\nano\second}]") - graph.set_ylabel("Number of leptons") + graph.set_xlabel(r'Arrival time [\si{\nano\second}]') + graph.set_ylabel('Number of leptons') graph.set_ylimits(min=0) graph.set_xlimits(0, 30) graph.save('plots/front-passage') diff --git a/scripts/simulations/cluster_sim.py b/scripts/simulations/cluster_sim.py index 8f1aa78d..5231273b 100644 --- a/scripts/simulations/cluster_sim.py +++ b/scripts/simulations/cluster_sim.py @@ -1,11 +1,11 @@ -""" HiSPARC detector simulation +"""HiSPARC detector simulation - This simulation takes an Extended Air Shower simulation ground - particles file and uses that to simulate numerous showers hitting a - HiSPARC detector station. Only data of one shower is used, but by - randomly selecting points on the ground as the position of a station, - the effect of the same shower hitting various positions around the - station is simulated. +This simulation takes an Extended Air Shower simulation ground +particles file and uses that to simulate numerous showers hitting a +HiSPARC detector station. Only data of one shower is used, but by +randomly selecting points on the ground as the position of a station, +the effect of the same shower hitting various positions around the +station is simulated. """ @@ -13,11 +13,10 @@ import sys import textwrap -import tables - import clusters +import tables -from simulations import GroundParticlesSimulation, QSubSimulation +from simulations import GroundParticlesSimulation DATAFILE = 'data.h5' @@ -29,10 +28,14 @@ data = tables.open_file(DATAFILE, 'a') if '/simulations' in data: - print - print textwrap.dedent("""\ + print() + print( + textwrap.dedent( + """\ WARNING: previous simulations exist and will be overwritten - Continue? (answer 'yes'; anything else will exit)""") + Continue? (answer 'yes'; anything else will exit)""", + ), + ) try: inp = raw_input() except KeyboardInterrupt: @@ -41,16 +44,18 @@ if inp.lower() == 'yes': data.remove_node('/simulations', recursive=True) else: - print - print "Aborting!" 
+ print() + print('Aborting!') sys.exit(1) sim = 'E_1PeV/zenith_0' cluster = clusters.SimpleCluster() - simulation = GroundParticlesSimulation(cluster, data, - os.path.join('/showers', sim, - 'leptons'), - os.path.join('/simulations', - sim), - R=100, N=100) + simulation = GroundParticlesSimulation( + cluster, + data, + os.path.join('/showers', sim, 'leptons'), + os.path.join('/simulations', sim), + R=100, + N=100, + ) simulation.run() diff --git a/scripts/simulations/core_reconstruction.py b/scripts/simulations/core_reconstruction.py index db3146d9..f81a44db 100644 --- a/scripts/simulations/core_reconstruction.py +++ b/scripts/simulations/core_reconstruction.py @@ -1,12 +1,7 @@ -import os -import sys - -from itertools import combinations - import numpy as np import tables +import utils -from scipy import optimize from scipy.misc import comb from scipy.stats import scoreatpercentile @@ -14,9 +9,6 @@ from pylab import * -import utils - -from sapphire import storage from sapphire.analysis.core_reconstruction import * from sapphire.simulations import ldf @@ -29,7 +21,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. / 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -48,9 +40,7 @@ def do_reconstruction_plots(table): Pnil = lambda x: exp(-0.5 * x) Pp = lambda x: 1 - Pnil(x) -Ptrig = lambda x: comb(4, 2) * Pp(x) ** 2 * Pnil(x) ** 2 + \ - comb(4, 3) * Pp(x) ** 3 * Pnil(x) + \ - comb(4, 4) * Pp(x) ** 4 +Ptrig = lambda x: comb(4, 2) * Pp(x) ** 2 * Pnil(x) ** 2 + comb(4, 3) * Pp(x) ** 3 * Pnil(x) + comb(4, 4) * Pp(x) ** 4 def plot_N_reconstructions_vs_R(table): @@ -87,15 +77,15 @@ def plot_N_reconstructions_vs_R(table): x = array(x) y = array(y) - plot(x, y, label="sim") + plot(x, y, label='sim') kldf = ldf.KascadeLdf() dens = kldf.calculate_ldf_value(x) - plot(x, Ptrig(dens), label="calc") + plot(x, Ptrig(dens), label='calc') legend() - xlabel("Core distance [m]") - ylabel("Reconstruction efficiency") + xlabel('Core distance [m]') + ylabel('Reconstruction efficiency') utils.saveplot() @@ -122,8 +112,8 @@ def plot_core_pos_uncertainty_vs_R(table): fill_between(x, d25, d75, color='0.75') plot(x, d50, 'o-', color='black') - xlabel("Core distance [m]") - ylabel("Core position uncertainty [m]") + xlabel('Core distance [m]') + ylabel('Core position uncertainty [m]') utils.saveplot() @@ -134,12 +124,12 @@ def plot_shower_size_hist(table): hist(log10(reconstructed), bins=200, histtype='step') reference_shower_size = table[0]['reference_shower_size'] - if reference_shower_size == 0.: - reference_shower_size = 10 ** 4.8 + if reference_shower_size == 0.0: + reference_shower_size = 10**4.8 axvline(log10(reference_shower_size)) - xlabel("log shower size") - ylabel("count") + xlabel('log shower size') + ylabel('count') utils.saveplot() @@ -153,29 +143,29 @@ def plot_scatter_reconstructed_core(table, N=None): station = table.attrs.cluster.stations[0] subplot(121) x, y = table.col('reference_core_pos')[:N].T - #scatter(x, y, c='b', s=1, edgecolor='none', zorder=1) + # scatter(x, y, c='b', s=1, edgecolor='none', zorder=1) plot(x, y, ',', c='b', markeredgecolor='b', zorder=1) for detector in station.detectors: x, y = detector.get_xy_coordinates() plt.scatter(x, y, c='r', s=20, edgecolor='none', zorder=2) - xlabel("Distance [m]") - ylabel("Distance [m]") 
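The Ptrig lambda near the top of this core_reconstruction.py hunk combines the Poisson probability that a single detector records at least one particle (the 0.5 factor presumably being the detector's effective area in square metres) with the binomial requirement that at least two of the four detectors fire; a sketch of the same calculation under those assumptions:

import numpy as np
from scipy.special import comb

def p_detector(density):
    # Poisson chance that a detector sees at least one particle when the local
    # particle density is `density` per square metre (0.5 m^2 effective area assumed).
    return 1 - np.exp(-0.5 * density)

def p_station_trigger(density):
    # A station triggers when at least two of its four detectors fire:
    # P(2 of 4) + P(3 of 4) + P(4 of 4).
    p = p_detector(density)
    q = 1 - p
    return comb(4, 2) * p**2 * q**2 + comb(4, 3) * p**3 * q + comb(4, 4) * p**4

print(p_station_trigger(np.array([0.5, 2.0, 10.0])))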
+ xlabel('Distance [m]') + ylabel('Distance [m]') xlim(-60, 60) ylim(-60, 60) - title("simulated") + title('simulated') subplot(122) x, y = table.col('reconstructed_core_pos')[:N].T - #scatter(x, y, c='b', s=1, edgecolor='none', zorder=1) + # scatter(x, y, c='b', s=1, edgecolor='none', zorder=1) plot(x, y, ',', c='b', markeredgecolor='b', zorder=1) for detector in station.detectors: x, y = detector.get_xy_coordinates() plt.scatter(x, y, c='r', s=20, edgecolor='none', zorder=2) - xlabel("Distance [m]") - ylabel("Distance [m]") + xlabel('Distance [m]') + ylabel('Distance [m]') xlim(-60, 60) ylim(-60, 60) - title("reconstructed") + title('reconstructed') utils.saveplot() @@ -204,31 +194,35 @@ def plot_scatter_reconstructed_core(table, N=None): c = CoreReconstruction(data, '/reconstructions/poisson_gauss_20') c.reconstruct_core_positions('/ldfsim/poisson_gauss_20') - c = CoreReconstruction(data, '/reconstructions/poisson_gauss_20_nonull', solver=CorePositionSolverWithoutNullMeasurements(ldf.KascadeLdf())) + c = CoreReconstruction( + data, + '/reconstructions/poisson_gauss_20_nonull', + solver=CorePositionSolverWithoutNullMeasurements(ldf.KascadeLdf()), + ) c.reconstruct_core_positions('/ldfsim/poisson_gauss_20') - #c = CoreReconstruction(data, '/reconstructions/ground_gauss_20') - #c.reconstruct_core_positions('/groundsim/zenith_0/shower_0') + # c = CoreReconstruction(data, '/reconstructions/ground_gauss_20') + # c.reconstruct_core_positions('/groundsim/zenith_0/shower_0') - utils.set_prefix("COR-") + utils.set_prefix('COR-') - utils.set_suffix("-EXACT") + utils.set_suffix('-EXACT') do_reconstruction_plots(data.root.reconstructions.exact) - utils.set_suffix("-GAUSS_10") + utils.set_suffix('-GAUSS_10') do_reconstruction_plots(data.root.reconstructions.gauss_10) - utils.set_suffix("-GAUSS_20") + utils.set_suffix('-GAUSS_20') do_reconstruction_plots(data.root.reconstructions.gauss_20) - utils.set_suffix("-POISSON") + utils.set_suffix('-POISSON') do_reconstruction_plots(data.root.reconstructions.poisson) - utils.set_suffix("-POISSON-GAUSS_20") + utils.set_suffix('-POISSON-GAUSS_20') do_reconstruction_plots(data.root.reconstructions.poisson_gauss_20) - utils.set_suffix("-POISSON-GAUSS_20_NONULL") + utils.set_suffix('-POISSON-GAUSS_20_NONULL') do_reconstruction_plots(data.root.reconstructions.poisson_gauss_20_nonull) - #utils.set_suffix("-GROUND-GAUSS_20") - #do_reconstruction_plots(data.root.reconstructions.ground_gauss_20) + # utils.set_suffix("-GROUND-GAUSS_20") + # do_reconstruction_plots(data.root.reconstructions.ground_gauss_20) diff --git a/scripts/simulations/detector_sim.py b/scripts/simulations/detector_sim.py index 2dec895e..29d46d28 100644 --- a/scripts/simulations/detector_sim.py +++ b/scripts/simulations/detector_sim.py @@ -1,11 +1,11 @@ -""" HiSPARC detector simulation +"""HiSPARC detector simulation - This simulation takes an Extended Air Shower simulation ground - particles file and uses that to simulate numerous showers hitting a - HiSPARC detector station. Only data of one shower is used, but by - randomly selecting points on the ground as the position of a station, - the effect of the same shower hitting various positions around the - station is simulated. +This simulation takes an Extended Air Shower simulation ground +particles file and uses that to simulate numerous showers hitting a +HiSPARC detector station. 
Only data of one shower is used, but by +randomly selecting points on the ground as the position of a station, +the effect of the same shower hitting various positions around the +station is simulated. """ @@ -13,11 +13,10 @@ import sys import textwrap -import tables - import clusters +import tables -from simulations import GroundParticlesSimulation, QSubSimulation +from simulations import GroundParticlesSimulation DATAFILE = 'data-e15.h5' @@ -29,10 +28,14 @@ data = tables.open_file(DATAFILE, 'a') if '/simulations' in data: - print - print textwrap.dedent("""\ + print() + print( + textwrap.dedent( + """\ WARNING: previous simulations exist and will be overwritten - Continue? (answer 'yes'; anything else will exit)""") + Continue? (answer 'yes'; anything else will exit)""", + ), + ) try: inp = raw_input() except KeyboardInterrupt: @@ -41,16 +44,18 @@ if inp.lower() == 'yes': data.remove_node('/simulations', recursive=True) else: - print - print "Aborting!" + print() + print('Aborting!') sys.exit(1) sim = 'E_1PeV/zenith_0' cluster = clusters.SingleStation() - simulation = GroundParticlesSimulation(cluster, data, - os.path.join('/showers', sim, - 'leptons'), - os.path.join('/simulations', - sim), - R=50, N=10000) + simulation = GroundParticlesSimulation( + cluster, + data, + os.path.join('/showers', sim, 'leptons'), + os.path.join('/simulations', sim), + R=50, + N=10000, + ) simulation.run() diff --git a/scripts/simulations/direction_reconstruction.py b/scripts/simulations/direction_reconstruction.py index a2730103..3b7a266d 100644 --- a/scripts/simulations/direction_reconstruction.py +++ b/scripts/simulations/direction_reconstruction.py @@ -3,7 +3,9 @@ from itertools import izip import tables +import utils +from myshowerfront import * from scipy import integrate from scipy.interpolate import spline from scipy.special import erf @@ -14,9 +16,6 @@ from artist import GraphArtist from pylab import * -import utils - -from myshowerfront import * from sapphire.analysis.direction_reconstruction import BinnedDirectionReconstruction, DirectionReconstruction from sapphire.utils import pbar @@ -31,7 +30,7 @@ rcParams['font.serif'] = 'Computer Modern' rcParams['font.sans-serif'] = 'Computer Modern' rcParams['font.family'] = 'sans-serif' - rcParams['figure.figsize'] = [4 * x for x in (1, 2. 
/ 3)] + rcParams['figure.figsize'] = [4 * x for x in (1, 2.0 / 3)] rcParams['figure.subplot.left'] = 0.175 rcParams['figure.subplot.bottom'] = 0.175 rcParams['font.size'] = 10 @@ -77,7 +76,14 @@ def do_full_reconstruction(data, N=None): for binning in 1, 2.5, 5: dest_table = dest + '_%s' % str(binning).replace('.', '_') - rec = BinnedDirectionReconstruction(data, dest_table, min_n134=1, binning=binning, randomize_binning=randomize, N=N) + rec = BinnedDirectionReconstruction( + data, + dest_table, + min_n134=1, + binning=binning, + randomize_binning=randomize, + N=N, + ) rec.reconstruct_angles_for_shower_group(source) @@ -105,11 +111,11 @@ def do_reconstruction_plots(data): save_for_kascade_boxplot_core_distances_for_mips(group) plot_detection_efficiency_vs_R_for_angles(1) plot_detection_efficiency_vs_R_for_angles(2) - #plot_reconstruction_efficiency_vs_R_for_angles(1) - #plot_reconstruction_efficiency_vs_R_for_angles(2) + # plot_reconstruction_efficiency_vs_R_for_angles(1) + # plot_reconstruction_efficiency_vs_R_for_angles(2) artistplot_reconstruction_efficiency_vs_R_for_angles(1) artistplot_reconstruction_efficiency_vs_R_for_angles(2) - #plot_reconstruction_efficiency_vs_R_for_mips() + # plot_reconstruction_efficiency_vs_R_for_mips() def plot_uncertainty_mip(group): @@ -128,31 +134,31 @@ def plot_uncertainty_mip(group): for N in range(1, 5): x.append(N) events = table.read_where('min_n134>=%d' % N) - #query = '(n1 == N) & (n3 == N) & (n4 == N)' - #vents = table.read_where(query) - print len(events), + # query = '(n1 == N) & (n3 == N) & (n4 == N)' + # vents = table.read_where(query) + print(len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) - print "YYY", rad2deg(scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) + print('YYY', rad2deg(scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17))) - plot(x, rad2deg(y), '^', label="Theta") - plot(x, rad2deg(y2), 'v', label="Phi") + plot(x, rad2deg(y), '^', label='Theta') + plot(x, rad2deg(y2), 'v', label='Phi') Sx = x Sy = y Sy2 = y2 - print - print "mip: min_n134, theta_std, phi_std" + print() + print('mip: min_n134, theta_std, phi_std') for u, v, w in zip(x, y, y2): - print u, v, w - print + print(u, v, w) + print() utils.savedata((x, y, y2)) # Uncertainty estimate @@ -165,8 +171,8 @@ def plot_uncertainty_mip(group): mc = my_std_t_for_R(data, x, R_list) for u, v in zip(mc, R_list): - print v, u, sqrt(u ** 2 + 1.2 ** 2), sqrt((.66 * u) ** 2 + 1.2 ** 2) - mc = sqrt(mc ** 2 + 1.2 ** 2) + print(v, u, sqrt(u**2 + 1.2**2), sqrt((0.66 * u) ** 2 + 1.2**2)) + mc = sqrt(mc**2 + 1.2**2) y3 = mc * sqrt(phi_errsq) y4 = mc * sqrt(theta_errsq) @@ -176,26 +182,26 @@ def plot_uncertainty_mip(group): y3 = spline(x, y3, nx) y4 = spline(x, y4, nx) - plot(nx, rad2deg(y), label="Gauss Phi") - plot(nx, rad2deg(y2), label="Gauss Theta") - plot(nx, rad2deg(y3), label="Monte Carlo Phi") - plot(nx, rad2deg(y4), label="Monte Carlo Theta") + plot(nx, rad2deg(y), label='Gauss Phi') + plot(nx, rad2deg(y2), label='Gauss Theta') + plot(nx, rad2deg(y3), label='Monte 
Carlo Phi') + plot(nx, rad2deg(y4), label='Monte Carlo Theta') # Labels etc. - xlabel("Minimum number of particles") - ylabel("Angle reconstruction uncertainty [deg]") - #title(r"$\theta = 22.5^\circ$") + xlabel('Minimum number of particles') + ylabel('Angle reconstruction uncertainty [deg]') + # title(r"$\theta = 22.5^\circ$") legend(numpoints=1) - xlim(.5, 4.5) + xlim(0.5, 4.5) utils.saveplot() - print + print() graph = GraphArtist() graph.plot(Sx, rad2deg(Sy), mark='o', linestyle='only marks') graph.plot(Sx, rad2deg(Sy2), mark='*', linestyle='only marks') graph.plot(nx, rad2deg(y), mark=None, linestyle='dashed,smooth') graph.plot(nx, rad2deg(y2), mark=None, linestyle='dashed,smooth') - graph.set_xlabel("Minimum number of particles") - graph.set_ylabel(r"Reconstruction uncertainty [\si{\degree}]") + graph.set_xlabel('Minimum number of particles') + graph.set_ylabel(r'Reconstruction uncertainty [\si{\degree}]') graph.set_xticks(range(1, 5)) graph.set_ylimits(0, 32) artist.utils.save_graph(graph, dirname='plots') @@ -224,27 +230,27 @@ def plot_uncertainty_zenith(group): x.append(THETA) table = group._f_get_child('zenith_%s' % str(THETA).replace('.', '_')) events = table.read_where('min_n134 >= N') - print THETA, len(events), + print(THETA, len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) - plot(x, rad2deg(y), '^', label="Theta") + plot(x, rad2deg(y), '^', label='Theta') graph.plot(x, rad2deg(y), mark='o', linestyle=None) # Azimuthal angle undefined for zenith = 0 - plot(x[1:], rad2deg(y2[1:]), 'v', label="Phi") + plot(x[1:], rad2deg(y2[1:]), 'v', label='Phi') graph.plot(x[1:], rad2deg(y2[1:]), mark='*', linestyle=None) - print - print "zenith: theta, theta_std, phi_std" + print() + print('zenith: theta, theta_std, phi_std') for u, v, w in zip(x, y, y2): - print u, v, w - print + print(u, v, w) + print() utils.savedata((x, y, y2)) # Uncertainty estimate @@ -256,23 +262,23 @@ def plot_uncertainty_zenith(group): y2.append(mean(rec.rel_theta1_errorsq(t, phis, phi1, phi2, r1, r2))) y = TIMING_ERROR * sqrt(array(y)) y2 = TIMING_ERROR * sqrt(array(y2)) - plot(rad2deg(x), rad2deg(y), label="Estimate Phi") + plot(rad2deg(x), rad2deg(y), label='Estimate Phi') graph.plot(rad2deg(x), rad2deg(y), mark=None) - plot(rad2deg(x), rad2deg(y2), label="Estimate Theta") + plot(rad2deg(x), rad2deg(y2), label='Estimate Theta') graph.plot(rad2deg(x), rad2deg(y2), mark=None) # Labels etc. 
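# Illustrative sketch (not part of the patch): the plot_uncertainty_* functions above
# all estimate the angular spread the same way: wrap the reference-minus-reconstructed
# differences into (-pi, pi], then take half the width of the central 66% interval
# (17th to 83rd percentile) as a robust alternative to the commented-out std() lines.
import numpy as np
from scipy.stats import scoreatpercentile


def angular_spread(reference, reconstructed):
    """Half-width of the 17th-83rd percentile interval of wrapped angle errors [rad]."""
    errors = np.asarray(reference) - np.asarray(reconstructed)
    errors = (errors + np.pi) % (2 * np.pi) - np.pi
    return (scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2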
- xlabel("Shower zenith angle [deg]") - graph.set_xlabel(r"Shower zenith angle [\si{\degree}]") - ylabel("Angle reconstruction uncertainty [deg]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") - #title(r"$N_{MIP} \geq %d$" % N) + xlabel('Shower zenith angle [deg]') + graph.set_xlabel(r'Shower zenith angle [\si{\degree}]') + ylabel('Angle reconstruction uncertainty [deg]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') + # title(r"$N_{MIP} \geq %d$" % N) ylim(0, 100) graph.set_ylimits(0, 60) legend(numpoints=1) utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() def plot_uncertainty_core_distance(group): @@ -286,37 +292,37 @@ def plot_uncertainty_core_distance(group): for R in range(0, 81, 20): x.append(R) events = table.read_where('(min_n134 == N) & (abs(r - R) <= DR)') - print len(events), + print(len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) - print - print "R: theta_std, phi_std" + print() + print('R: theta_std, phi_std') for u, v, w in zip(x, y, y2): - print u, v, w - print + print(u, v, w) + print() utils.savedata((x, y, y2)) # Plots - plot(x, rad2deg(y), '^-', label="Theta") - plot(x, rad2deg(y2), 'v-', label="Phi") + plot(x, rad2deg(y), '^-', label='Theta') + plot(x, rad2deg(y2), 'v-', label='Phi') # Labels etc. 
- xlabel(r"Core distance [m] $\pm %d$" % DR) - ylabel("Angle reconstruction uncertainty [deg]") - #title(r"$N_{MIP} = %d, \theta = 22.5^\circ$" % N) + xlabel(r'Core distance [m] $\pm %d$' % DR) + ylabel('Angle reconstruction uncertainty [deg]') + # title(r"$N_{MIP} = %d, \theta = 22.5^\circ$" % N) ylim(ymin=0) legend(numpoints=1, loc='best') utils.saveplot() - print + print() def plot_uncertainty_size(group): @@ -344,26 +350,26 @@ def plot_uncertainty_size(group): table = group._f_get_child('zenith_22_5') events = table.read_where('min_n134 >= N') - print size, len(events), + print(size, len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) - plot(x, rad2deg(y), '^', label="Theta") + plot(x, rad2deg(y), '^', label='Theta') graph.plot(x, rad2deg(y), mark='o', linestyle=None) - plot(x, rad2deg(y2), 'v', label="Phi") + plot(x, rad2deg(y2), 'v', label='Phi') graph.plot(x, rad2deg(y2), mark='*', linestyle=None) - print - print "stationsize: size, theta_std, phi_std" + print() + print('stationsize: size, theta_std, phi_std') for u, v, w in zip(x, y, y2): - print u, v, w - print + print(u, v, w) + print() # Uncertainty estimate x = linspace(5, 20, 50) @@ -374,22 +380,22 @@ def plot_uncertainty_size(group): y2.append(mean(rec.rel_theta1_errorsq(pi / 8, phis, phi1, phi2, r1=s, r2=s))) y = TIMING_ERROR * sqrt(array(y)) y2 = TIMING_ERROR * sqrt(array(y2)) - plot(x, rad2deg(y), label="Estimate Phi") + plot(x, rad2deg(y), label='Estimate Phi') graph.plot(x, rad2deg(y), mark=None) - plot(x, rad2deg(y2), label="Estimate Theta") + plot(x, rad2deg(y2), label='Estimate Theta') graph.plot(x, rad2deg(y2), mark=None) # Labels etc. 
- xlabel("Station size [m]") - graph.set_xlabel(r"Station size [\si{\meter}]") - ylabel("Angle reconstruction uncertainty [deg]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") + xlabel('Station size [m]') + graph.set_xlabel(r'Station size [\si{\meter}]') + ylabel('Angle reconstruction uncertainty [deg]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') graph.set_ylimits(0, 25) - #title(r"$\theta = 22.5^\circ, N_{MIP} \geq %d$" % N) + # title(r"$\theta = 22.5^\circ, N_{MIP} \geq %d$" % N) legend(numpoints=1) utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() def plot_uncertainty_binsize(group): @@ -416,26 +422,26 @@ def plot_uncertainty_binsize(group): table = group.zenith_22_5 events = table.read_where('min_n134 >= 2') - print bin_size, len(events), + print(bin_size, len(events)) errors = events['reference_theta'] - events['reconstructed_theta'] # Make sure -pi < errors < pi errors = (errors + pi) % (2 * pi) - pi errors2 = events['reference_phi'] - events['reconstructed_phi'] # Make sure -pi < errors2 < pi errors2 = (errors2 + pi) % (2 * pi) - pi - #y.append(std(errors)) - #y2.append(std(errors2)) + # y.append(std(errors)) + # y2.append(std(errors2)) y.append((scoreatpercentile(errors, 83) - scoreatpercentile(errors, 17)) / 2) y2.append((scoreatpercentile(errors2, 83) - scoreatpercentile(errors2, 17)) / 2) - plot(x, rad2deg(y), '^', label="Theta") + plot(x, rad2deg(y), '^', label='Theta') graph.plot(x, rad2deg(y), mark='o', linestyle=None) - plot(x, rad2deg(y2), 'v', label="Phi") + plot(x, rad2deg(y2), 'v', label='Phi') graph.plot(x, rad2deg(y2), mark='*', linestyle=None) - print - print "binsize: size, theta_std, phi_std" + print() + print('binsize: size, theta_std, phi_std') for u, v, w in zip(x, y, y2): - print u, v, w - print + print(u, v, w) + print() # Uncertainty estimate x = linspace(0, 5, 50) @@ -444,40 +450,36 @@ def plot_uncertainty_binsize(group): phi_errorsq = mean(rec.rel_phi_errorsq(pi / 8, phis, phi1, phi2, r1, r2)) theta_errorsq = mean(rec.rel_theta1_errorsq(pi / 8, phis, phi1, phi2, r1, r2)) for t in x: - y.append(sqrt((TIMING_ERROR ** 2 + t ** 2 / 12) * phi_errorsq)) - y2.append(sqrt((TIMING_ERROR ** 2 + t ** 2 / 12) * theta_errorsq)) + y.append(sqrt((TIMING_ERROR**2 + t**2 / 12) * phi_errorsq)) + y2.append(sqrt((TIMING_ERROR**2 + t**2 / 12) * theta_errorsq)) y = array(y) y2 = array(y2) - plot(x, rad2deg(y), label="Estimate Phi") + plot(x, rad2deg(y), label='Estimate Phi') graph.plot(x, rad2deg(y), mark=None) - plot(x, rad2deg(y2), label="Estimate Theta") + plot(x, rad2deg(y2), label='Estimate Theta') graph.plot(x, rad2deg(y2), mark=None) # Labels etc. 
- xlabel("Sampling time [ns]") - graph.set_xlabel(r"Sampling time [\si{\nano\second}]") - ylabel("Angle reconstruction uncertainty [deg]") - graph.set_ylabel(r"Angle reconstruction uncertainty [\si{\degree}]") + xlabel('Sampling time [ns]') + graph.set_xlabel(r'Sampling time [\si{\nano\second}]') + ylabel('Angle reconstruction uncertainty [deg]') + graph.set_ylabel(r'Angle reconstruction uncertainty [\si{\degree}]') graph.set_ylimits(0, 20) - #title(r"$\theta = 22.5^\circ, N_{MIP} \geq %d$" % N) + # title(r"$\theta = 22.5^\circ, N_{MIP} \geq %d$" % N) legend(loc='upper left', numpoints=1) ylim(0, 20) xlim(-0.1, 5.5) utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() + # Time of first hit pamflet functions -Q = lambda t, n: ((.5 * (1 - erf(t / sqrt(2)))) ** (n - 1) - * exp(-.5 * t ** 2) / sqrt(2 * pi)) +Q = lambda t, n: ((0.5 * (1 - erf(t / sqrt(2)))) ** (n - 1) * exp(-0.5 * t**2) / sqrt(2 * pi)) -expv_t = vectorize(lambda n: integrate.quad(lambda t: t * Q(t, n) - / n ** -1, - - inf, +inf)) +expv_t = vectorize(lambda n: integrate.quad(lambda t: t * Q(t, n) / n**-1, -inf, +inf)) expv_tv = lambda n: expv_t(n)[0] -expv_tsq = vectorize(lambda n: integrate.quad(lambda t: t ** 2 * Q(t, n) - / n ** -1, - - inf, +inf)) +expv_tsq = vectorize(lambda n: integrate.quad(lambda t: t**2 * Q(t, n) / n**-1, -inf, +inf)) expv_tsqv = lambda n: expv_tsq(n)[0] std_t = lambda n: sqrt(expv_tsqv(n) - expv_tv(n) ** 2) @@ -492,15 +494,14 @@ def plot_phi_reconstruction_results_for_MIP(group, N): figure() plot_2d_histogram(rad2deg(sim_phi), rad2deg(r_phi), 180) - xlabel(r"$\phi_{simulated}$ [deg]") - ylabel(r"$\phi_{reconstructed}$ [deg]") - #title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ$" % N) + xlabel(r'$\phi_{simulated}$ [deg]') + ylabel(r'$\phi_{reconstructed}$ [deg]') + # title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ$" % N) utils.saveplot(N) graph = artist.GraphArtist() bins = linspace(-180, 180, 73) - H, x_edges, y_edges = histogram2d(rad2deg(sim_phi), rad2deg(r_phi), - bins=bins) + H, x_edges, y_edges = histogram2d(rad2deg(sim_phi), rad2deg(r_phi), bins=bins) graph.histogram2d(H, x_edges, y_edges, type='reverse_bw') graph.set_xlabel(r'$\phi_\mathrm{sim}$ [\si{\degree}]') graph.set_ylabel(r'$\phi_\mathrm{rec}$ [\si{\degree}]') @@ -530,9 +531,9 @@ def boxplot_theta_reconstruction_results_for_MIP(group, N): fill_between(angles, d25, d75, color='0.75') plot(angles, d50, 'o-', color='black') - xlabel(r"$\theta_{simulated}$ [deg]") - ylabel(r"$\theta_{reconstructed} - \theta_{simulated}$ [deg]") - #title(r"$N_{MIP} \geq %d$" % N) + xlabel(r'$\theta_{simulated}$ [deg]') + ylabel(r'$\theta_{reconstructed} - \theta_{simulated}$ [deg]') + # title(r"$N_{MIP} \geq %d$" % N) axhline(0, color='black') ylim(-10, 25) @@ -543,9 +544,9 @@ def boxplot_theta_reconstruction_results_for_MIP(group, N): graph.draw_horizontal_line(0, linestyle='gray') graph.shade_region(angles, d25, d75) graph.plot(angles, d50, linestyle=None) - graph.set_xlabel(r"$\theta_\mathrm{sim}$ [\si{\degree}]") - graph.set_ylabel(r"$\theta_\mathrm{rec} - \theta_\mathrm{sim}$ [\si{\degree}]") - graph.set_title(r"$N_\mathrm{MIP} \geq %d$" % N) + graph.set_xlabel(r'$\theta_\mathrm{sim}$ [\si{\degree}]') + graph.set_ylabel(r'$\theta_\mathrm{rec} - \theta_\mathrm{sim}$ [\si{\degree}]') + graph.set_title(r'$N_\mathrm{MIP} \geq %d$' % N) graph.set_ylimits(-8, 22) artist.utils.save_graph(graph, suffix=N, dirname='plots') @@ -575,9 +576,9 @@ def boxplot_phi_reconstruction_results_for_MIP(group, N): fill_between(x, d25, d75, 
color='0.75') plot(x, d50, 'o-', color='black') - xlabel(r"$\phi_{simulated}$ [deg]") - ylabel(r"$\phi_{reconstructed} - \phi_{simulated}$ [deg]") - #title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ$" % N) + xlabel(r'$\phi_{simulated}$ [deg]') + ylabel(r'$\phi_{reconstructed} - \phi_{simulated}$ [deg]') + # title(r"$N_{MIP} \geq %d, \quad \theta = 22.5^\circ$" % N) xticks(linspace(-180, 180, 9)) axhline(0, color='black') @@ -589,9 +590,9 @@ def boxplot_phi_reconstruction_results_for_MIP(group, N): graph.draw_horizontal_line(0, linestyle='gray') graph.shade_region(x, d25, d75) graph.plot(x, d50, linestyle=None) - graph.set_xlabel(r"$\phi_\mathrm{sim}$ [\si{\degree}]") - graph.set_ylabel(r"$\phi_\mathrm{rec} - \phi_\mathrm{sim}$ [\si{\degree}]") - graph.set_title(r"$N_\mathrm{MIP} \geq %d$" % N) + graph.set_xlabel(r'$\phi_\mathrm{sim}$ [\si{\degree}]') + graph.set_ylabel(r'$\phi_\mathrm{rec} - \phi_\mathrm{sim}$ [\si{\degree}]') + graph.set_title(r'$N_\mathrm{MIP} \geq %d$' % N) graph.set_xticks([-180, -90, '...', 180]) graph.set_xlimits(-180, 180) graph.set_ylimits(-17, 17) @@ -606,7 +607,7 @@ def boxplot_arrival_times(group, N): t3 = sel[:]['t3'] t4 = sel[:]['t4'] ts = concatenate([t1, t3, t4]) - print "Median arrival time delay over all detected events", median(ts) + print('Median arrival time delay over all detected events', median(ts)) figure() @@ -630,9 +631,9 @@ def boxplot_arrival_times(group, N): fill_between(x, t25, t75, color='0.75') plot(x, t50, 'o-', color='black') - xlabel("Core distance [m]") - ylabel("Arrival time delay [ns]") - #title(r"$N_{MIP} \geq %d, \quad \theta = 0^\circ$" % N) + xlabel('Core distance [m]') + ylabel('Arrival time delay [ns]') + # title(r"$N_{MIP} \geq %d, \quad \theta = 0^\circ$" % N) xticks(arange(0, 100.5, 10)) @@ -642,8 +643,8 @@ def boxplot_arrival_times(group, N): graph = GraphArtist() graph.shade_region(x, t25, t75) graph.plot(x, t50, linestyle=None) - graph.set_xlabel(r"Core distance [\si{\meter}]") - graph.set_ylabel(r"Arrival time difference $|t_2 - t_1|$ [\si{\nano\second}]") + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel(r'Arrival time difference $|t_2 - t_1|$ [\si{\nano\second}]') graph.set_xlimits(0, 100) graph.set_ylimits(min=0) artist.utils.save_graph(graph, suffix=N, dirname='plots') @@ -657,8 +658,8 @@ def get_median_core_distances_for_mips(group, N_list): x = [] for N in N_list: sel = table.read_where('min_n134 >= N') - #query = '(n1 == N) & (n3 == N) & (n4 == N)' - #sel = table.read_where(query) + # query = '(n1 == N) & (n3 == N) & (n4 == N)' + # sel = table.read_where(query) r = sel[:]['r'] r_list.append(r) x.append(N) @@ -666,7 +667,7 @@ def get_median_core_distances_for_mips(group, N_list): r25.append(scoreatpercentile(r, 25)) r50.append(scoreatpercentile(r, 50)) r75.append(scoreatpercentile(r, 75)) - print "MIP, median, mean", N, r50[-1], mean(r), std(r) / mean(r) + print('MIP, median, mean', N, r50[-1], mean(r), std(r) / mean(r)) return r50 @@ -693,17 +694,17 @@ def boxplot_core_distances_for_mips(group): plot(x, r50, 'o-', color='black') xticks(range(1, 5)) - xlabel("Minimum number of particles") - ylabel("Core distance [m]") - #title(r"$\theta = 22.5^\circ$") + xlabel('Minimum number of particles') + ylabel('Core distance [m]') + # title(r"$\theta = 22.5^\circ$") utils.saveplot() graph = GraphArtist() graph.shade_region(x, r25, r75) graph.plot(x, r50, linestyle=None) - graph.set_xlabel("Minimum number of particles") - graph.set_ylabel(r"Core distance [\si{\meter}]") + graph.set_xlabel('Minimum number 
of particles') + graph.set_ylabel(r'Core distance [\si{\meter}]') graph.set_ylimits(min=0) graph.set_xticks(range(5)) artist.utils.save_graph(graph, dirname='plots') @@ -731,10 +732,10 @@ def plot_detection_efficiency_vs_R_for_angles(N): figure() graph = GraphArtist() locations = iter(['right', 'left', 'below left']) - positions = iter([.18, .14, .15]) + positions = iter([0.18, 0.14, 0.15]) bin_edges = linspace(0, 100, 20) - x = (bin_edges[:-1] + bin_edges[1:]) / 2. + x = (bin_edges[:-1] + bin_edges[1:]) / 2.0 for angle in [0, 22.5, 35]: angle_str = str(angle).replace('.', '_') @@ -751,22 +752,24 @@ def plot_detection_efficiency_vs_R_for_angles(N): assert (obs_sel['id'] == ids).all() o = obs_sel - sel = obs_sel.compress((o['n1'] >= N) & (o['n3'] >= N) & - (o['n4'] >= N)) + sel = obs_sel.compress((o['n1'] >= N) & (o['n3'] >= N) & (o['n4'] >= N)) shower_results.append(len(sel) / len(obs_sel)) efficiencies.append(mean(shower_results)) plot(x, efficiencies, label=r'$\theta = %s^\circ$' % angle) graph.plot(x, efficiencies, mark=None) - graph.add_pin(r'\SI{%s}{\degree}' % angle, - location=locations.next(), use_arrow=True, - relative_position=positions.next()) - - xlabel("Core distance [m]") - graph.set_xlabel(r"Core distance [\si{\meter}]") - ylabel("Detection efficiency") - graph.set_ylabel("Detection efficiency") - #title(r"$N_{MIP} \geq %d$" % N) + graph.add_pin( + r'\SI{%s}{\degree}' % angle, + location=locations.next(), + use_arrow=True, + relative_position=positions.next(), + ) + + xlabel('Core distance [m]') + graph.set_xlabel(r'Core distance [\si{\meter}]') + ylabel('Detection efficiency') + graph.set_ylabel('Detection efficiency') + # title(r"$N_{MIP} \geq %d$" % N) legend() graph.set_xlimits(0, 100) graph.set_ylimits(0, 1) @@ -781,7 +784,7 @@ def plot_reconstruction_efficiency_vs_R_for_angles(N): figure() bin_edges = linspace(0, 100, 10) - x = (bin_edges[:-1] + bin_edges[1:]) / 2. 
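# Illustrative note (not part of the patch): the reformatted graph.add_pin() call above
# keeps location=locations.next() and relative_position=positions.next(), and the same
# pattern recurs further down. Plain iterators have no .next() method on Python 3, so
# these calls raise AttributeError on the interpreters this project now supports; the
# builtin works in both cases:
#
#     location=next(locations), relative_position=next(positions)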
+ x = (bin_edges[:-1] + bin_edges[1:]) / 2.0 all_data = [] @@ -801,8 +804,7 @@ def plot_reconstruction_efficiency_vs_R_for_angles(N): assert (obs_sel['id'] == ids).all() o = obs_sel - sel = obs_sel.compress((o['n1'] >= N) & (o['n3'] >= N) & - (o['n4'] >= N)) + sel = obs_sel.compress((o['n1'] >= N) & (o['n3'] >= N) & (o['n4'] >= N)) shower_results.append(len(sel)) ssel = reconstructions.read_where('(min_n134 >= N) & (low <= r) & (r < high)') efficiencies.append(len(ssel) / sum(shower_results)) @@ -810,9 +812,9 @@ def plot_reconstruction_efficiency_vs_R_for_angles(N): all_data.append(efficiencies) plot(x, efficiencies, label=r'$\theta = %s^\circ$' % angle) - xlabel("Core distance [m]") - ylabel("Reconstruction efficiency") - #title(r"$N_{MIP} \geq %d$" % N) + xlabel('Core distance [m]') + ylabel('Reconstruction efficiency') + # title(r"$N_{MIP} \geq %d$" % N) legend() utils.saveplot(N) @@ -825,18 +827,21 @@ def artistplot_reconstruction_efficiency_vs_R_for_angles(N): graph = GraphArtist() locations = iter(['above right', 'below left', 'below left']) - positions = iter([.9, .2, .2]) + positions = iter([0.9, 0.2, 0.2]) x = all_data[:, 0] for angle, efficiencies in zip([0, 22.5, 35], all_data[:, 1:].T): graph.plot(x, efficiencies, mark=None) - graph.add_pin(r'\SI{%s}{\degree}' % angle, use_arrow=True, - location=locations.next(), - relative_position=positions.next()) - - graph.set_xlabel(r"Core distance [\si{\meter}]") - graph.set_ylabel("Reconstruction efficiency") + graph.add_pin( + r'\SI{%s}{\degree}' % angle, + use_arrow=True, + location=locations.next(), + relative_position=positions.next(), + ) + + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel('Reconstruction efficiency') graph.set_xlimits(0, 100) graph.set_ylimits(max=1) artist.utils.save_graph(graph, suffix=N, dirname='plots') @@ -848,7 +853,7 @@ def plot_reconstruction_efficiency_vs_R_for_mips(): figure() bin_edges = linspace(0, 100, 10) - x = (bin_edges[:-1] + bin_edges[1:]) / 2. 
+ x = (bin_edges[:-1] + bin_edges[1:]) / 2.0 for N in range(1, 5): shower_group = '/simulations/E_1PeV/zenith_22_5' @@ -868,14 +873,14 @@ def plot_reconstruction_efficiency_vs_R_for_mips(): shower_results.append(len(sel)) ssel = reconstructions.read_where('(min_n134 == N) & (low <= r) & (r < high)') - print sum(shower_results), len(ssel), len(ssel) / sum(shower_results) + print(sum(shower_results), len(ssel), len(ssel) / sum(shower_results)) efficiencies.append(len(ssel) / sum(shower_results)) plot(x, efficiencies, label=r'$N_{MIP} = %d$' % N) - xlabel("Core distance [m]") - ylabel("Reconstruction efficiency") - #title(r"$\theta = 22.5^\circ$") + xlabel('Core distance [m]') + ylabel('Reconstruction efficiency') + # title(r"$\theta = 22.5^\circ$") legend() utils.saveplot() @@ -883,9 +888,14 @@ def plot_reconstruction_efficiency_vs_R_for_mips(): def plot_2d_histogram(x, y, bins): H, xedges, yedges = histogram2d(x, y, bins) - imshow(H.T, extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], - origin='lower left', interpolation='lanczos', aspect='auto', - cmap=cm.Greys) + imshow( + H.T, + extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], + origin='lower left', + interpolation='lanczos', + aspect='auto', + cmap=cm.Greys, + ) colorbar() @@ -906,11 +916,10 @@ def make_datasets_failed_reconstructions_scatter(data): dt1, dt2, phis_sim, phis_rec = [], [], [], [] gdt1, gdt2, gphis_sim, gphis_rec = [], [], [], [] - for event, coincidence in pbar(izip(observables, coincidences), - len(observables)): + for event, coincidence in pbar(izip(observables, coincidences), len(observables)): assert event['id'] == coincidence['id'] - if min(event['n1'], event['n3'], event['n4']) >= 1.: - theta, phi = reconstruct_angle(event, 10.) + if min(event['n1'], event['n3'], event['n4']) >= 1.0: + theta, phi = reconstruct_angle(event, 10.0) assert not isnan(phi) if isnan(theta): @@ -926,41 +935,41 @@ def make_datasets_failed_reconstructions_scatter(data): def plot_failed_and_successful_scatter_plots(): - figure(figsize=(20., 11.5)) + figure(figsize=(20.0, 11.5)) subplot(231) plot(gdt1, rad2deg(gphis_sim), ',', c='green') plot(dt1, rad2deg(phis_sim), ',', c='red') - xlabel(r"$t_1 - t_3$ [ns]") - ylabel(r"$\phi_{sim}$") + xlabel(r'$t_1 - t_3$ [ns]') + ylabel(r'$\phi_{sim}$') xlim(-200, 200) subplot(232) plot(gdt2, rad2deg(gphis_sim), ',', c='green') plot(dt2, rad2deg(phis_sim), ',', c='red') - xlabel(r"$t_1 - t_4$ [ns]") - ylabel(r"$\phi_{sim}$") + xlabel(r'$t_1 - t_4$ [ns]') + ylabel(r'$\phi_{sim}$') xlim(-200, 200) subplot(234) plot(gdt1, rad2deg(gphis_rec), ',', c='green') plot(dt1, rad2deg(phis_rec), ',', c='red') - xlabel(r"$t_1 - t_3$ [ns]") - ylabel(r"$\phi_{rec}$") + xlabel(r'$t_1 - t_3$ [ns]') + ylabel(r'$\phi_{rec}$') xlim(-200, 200) subplot(235) plot(gdt2, rad2deg(gphis_rec), ',', c='green') plot(dt2, rad2deg(phis_rec), ',', c='red') - xlabel(r"$t_1 - t_4$ [ns]") - ylabel(r"$\phi_{rec}$") + xlabel(r'$t_1 - t_4$ [ns]') + ylabel(r'$\phi_{rec}$') xlim(-200, 200) subplot(233) plot(gdt1, gdt2, ',', c='green') plot(dt1, dt2, ',', c='red') - xlabel(r"$t_1 - t_3$ [ns]") - ylabel(r"$t_1 - t_4$ [ns]") + xlabel(r'$t_1 - t_3$ [ns]') + ylabel(r'$t_1 - t_4$ [ns]') xlim(-200, 200) ylim(-200, 200) @@ -982,11 +991,11 @@ def plot_failed_histograms(): subplot(121) hist(c * dt1 / (10 * cos(phis - phi1)), bins=linspace(-20, 20, 100)) - xlabel(r"$c \, \Delta t_1 / (r_1 \cos(\phi - \phi_1))$") + xlabel(r'$c \, \Delta t_1 / (r_1 \cos(\phi - \phi_1))$') subplot(122) hist(c * dt2 / (10 * cos(phis - phi2)), bins=linspace(-20, 20, 
100)) - xlabel(r"$c \, \Delta t_2 / (r_2 \cos(\phi - \phi_2))$") + xlabel(r'$c \, \Delta t_2 / (r_2 \cos(\phi - \phi_2))$') utils.saveplot() @@ -1007,7 +1016,7 @@ def plot_uncertainty_zenith_angular_distance(group): graph = GraphArtist() # Uncertainty estimate x = linspace(0, deg2rad(45), 50) - #x = array([pi / 8]) + # x = array([pi / 8]) phis = linspace(-pi, pi, 50) y, y2 = [], [] for t in x: @@ -1015,29 +1024,29 @@ def plot_uncertainty_zenith_angular_distance(group): y2.append(mean(rec.rel_theta1_errorsq(t, phis, phi1, phi2, r1, r2))) y = TIMING_ERROR * sqrt(array(y)) y2 = TIMING_ERROR * sqrt(array(y2)) - ang_dist = sqrt((y * sin(x)) ** 2 + y2 ** 2) - #plot(rad2deg(x), rad2deg(y), label="Estimate Phi") - #plot(rad2deg(x), rad2deg(y2), label="Estimate Theta") - plot(rad2deg(x), rad2deg(ang_dist), label="Angular distance") + ang_dist = sqrt((y * sin(x)) ** 2 + y2**2) + # plot(rad2deg(x), rad2deg(y), label="Estimate Phi") + # plot(rad2deg(x), rad2deg(y2), label="Estimate Theta") + plot(rad2deg(x), rad2deg(ang_dist), label='Angular distance') graph.plot(rad2deg(x), rad2deg(ang_dist), mark=None) - print rad2deg(x) - print rad2deg(y) - print rad2deg(y2) - print rad2deg(y * sin(x)) - print rad2deg(ang_dist) + print(rad2deg(x)) + print(rad2deg(y)) + print(rad2deg(y2)) + print(rad2deg(y * sin(x))) + print(rad2deg(ang_dist)) # Labels etc. - xlabel("Shower zenith angle [deg]") - ylabel("Angular distance [deg]") - graph.set_xlabel(r"Shower zenith angle [\si{\degree}]") - graph.set_ylabel(r"Angular distance [\si{\degree}]") + xlabel('Shower zenith angle [deg]') + ylabel('Angular distance [deg]') + graph.set_xlabel(r'Shower zenith angle [\si{\degree}]') + graph.set_ylabel(r'Angular distance [\si{\degree}]') graph.set_ylimits(min=6) - #title(r"$N_{MIP} \geq %d$" % N) - #ylim(0, 100) - #legend(numpoints=1) + # title(r"$N_{MIP} \geq %d$" % N) + # ylim(0, 100) + # legend(numpoints=1) utils.saveplot() artist.utils.save_graph(graph, dirname='plots') - print + print() if __name__ == '__main__': @@ -1051,13 +1060,13 @@ def plot_uncertainty_zenith_angular_distance(group): data = tables.open_file('master-ch4v2.h5', 'r') if '/reconstructions' not in data: - print "Reconstructing shower direction..." + print('Reconstructing shower direction...') do_full_reconstruction(data) else: - print "Skipping reconstruction!" 
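# Illustrative note (not part of the patch): a few Python 2 built-ins survive these
# scripts even after the reformat. direction_reconstruction.py still imports and uses
# itertools.izip (kept in the reformatted pbar(izip(...)) line above), and
# detector_sim.py still calls raw_input(); neither exists on Python 3. The drop-in
# replacements are the builtins:
#
#     zip(observables, coincidences)   # instead of izip(...)
#     inp = input()                    # instead of raw_input()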
+ print('Skipping reconstruction!') - utils.set_prefix("DIR-") - artist.utils.set_prefix("DIR-") + utils.set_prefix('DIR-') + artist.utils.set_prefix('DIR-') do_reconstruction_plots(data) # These currently don't work diff --git a/scripts/simulations/discrete_directions.py b/scripts/simulations/discrete_directions.py index d087d540..698b5608 100644 --- a/scripts/simulations/discrete_directions.py +++ b/scripts/simulations/discrete_directions.py @@ -1,24 +1,26 @@ import itertools -import numpy as np - import matplotlib.pyplot as plt +import numpy as np -from sapphire.analysis.direction_reconstruction import (DirectAlgorithm, DirectAlgorithmCartesian2D, - DirectAlgorithmCartesian3D, FitAlgorithm) -from sapphire.clusters import HiSPARCStations, ScienceParkCluster, SingleDiamondStation +from sapphire.analysis.direction_reconstruction import ( + DirectAlgorithmCartesian3D, +) +from sapphire.clusters import HiSPARCStations, ScienceParkCluster TIME_RESOLUTION = 2.5 # nanoseconds -C = .3 # lightspeed m/ns +C = 0.3 # lightspeed m/ns -def generate_discrete_times(station, detector_ids=[0, 2, 3]): +def generate_discrete_times(station, detector_ids=None): """Generates possible arrival times for detectors The times are relative to the first detector, which is assumed to be at t = 0. """ + if detector_ids is None: + detector_ids = [0, 2, 3] r = station_size(station, detector_ids) max_dt = ceil_in_base(r / C, TIME_RESOLUTION) times = np.arange(-max_dt, max_dt, TIME_RESOLUTION) @@ -26,14 +28,15 @@ def generate_discrete_times(station, detector_ids=[0, 2, 3]): return time_combinations -def station_size(station, detector_ids=[0, 2, 3]): +def station_size(station, detector_ids=None): """Get the largest distance between any two detectors in a station :param detectors: list of :class:`sapphire.clusters.Detector` objects """ - r = [station.calc_r_and_phi_for_detectors(d0, d1)[0] - for d0, d1 in itertools.combinations(detector_ids, 2)] + if detector_ids is None: + detector_ids = [0, 2, 3] + r = [station.calc_r_and_phi_for_detectors(d0, d1)[0] for d0, d1 in itertools.combinations(detector_ids, 2)] return max(r) @@ -42,7 +45,6 @@ def ceil_in_base(value, base): if __name__ == '__main__': - station_number = 502 dirrec = DirectAlgorithmCartesian3D() @@ -50,15 +52,17 @@ def ceil_in_base(value, base): station = HiSPARCStations([station_number]).get_station(station_number) except: station = ScienceParkCluster([station_number]).get_station(station_number) - #station = SingleDiamondStation().stations[0] + # station = SingleDiamondStation().stations[0] fig = plt.figure(figsize=(15, 10)) - sets = [plt.subplot2grid((2,3), (0,0), projection="polar"), - plt.subplot2grid((2,3), (1,0), projection="polar"), - plt.subplot2grid((2,3), (0,1), projection="polar"), - plt.subplot2grid((2,3), (1,1), projection="polar")] - combined = plt.subplot2grid((2,3), (1,2), projection="polar") - layout = plt.subplot2grid((2,3), (0,2)) + sets = [ + plt.subplot2grid((2, 3), (0, 0), projection='polar'), + plt.subplot2grid((2, 3), (1, 0), projection='polar'), + plt.subplot2grid((2, 3), (0, 1), projection='polar'), + plt.subplot2grid((2, 3), (1, 1), projection='polar'), + ] + combined = plt.subplot2grid((2, 3), (1, 2), projection='polar') + layout = plt.subplot2grid((2, 3), (0, 2)) # plt.setp(sets[0].get_xticklabels(), visible=False) # plt.setp(sets[2].get_xticklabels(), visible=False) @@ -73,8 +77,7 @@ def ceil_in_base(value, base): layout.axis('equal') layout.scatter(x, y, s=15, marker='o', color='black') for id in [0, 1, 2, 3]: - layout.annotate('%d' % 
id, (x[id], y[id]), xytext=(3, 3), - textcoords='offset points') + layout.annotate('%d' % id, (x[id], y[id]), xytext=(3, 3), textcoords='offset points') layout.set_ylabel('northing (m)') layout.set_xlabel('easting (m)') @@ -85,8 +88,7 @@ def ceil_in_base(value, base): detectors = [station.detectors[id].get_coordinates() for id in ids] x, y, z = zip(*detectors) - theta, phi = itertools.izip(*(dirrec.reconstruct_common((0,) + t, x, y, z) - for t in times)) + theta, phi = itertools.izip(*(dirrec.reconstruct_common((0,) + t, x, y, z) for t in times)) thetaa = np.degrees(np.array([t for t in theta if not np.isnan(t)])) phia = [p for p in phi if not np.isnan(p)] @@ -102,8 +104,7 @@ def ceil_in_base(value, base): x, y, z = zip(*detectors) for t1 in (0, 10, 20, 30): times = ((t1, x) for x in np.arange(-60, 60, TIME_RESOLUTION)) - theta, phi = itertools.izip(*(dirrec.reconstruct_common((0,) + t, x, y, z) - for t in times)) + theta, phi = itertools.izip(*(dirrec.reconstruct_common((0,) + t, x, y, z) for t in times)) thetaa = np.degrees(np.array([t for t in theta if not np.isnan(t)])) phia = [p for p in phi if not np.isnan(p)] sets[i].plot(phia, thetaa, color='red') @@ -113,6 +114,5 @@ def ceil_in_base(value, base): sets[0].set_ylabel('Zenith (degrees)') sets[3].set_xlabel('Azimuth (degrees)') - fig.suptitle('Station: %d - Time resolution: %.1f ns' % - (station_number, TIME_RESOLUTION)) + fig.suptitle('Station: %d - Time resolution: %.1f ns' % (station_number, TIME_RESOLUTION)) plt.show() diff --git a/scripts/simulations/ldf_sim.py b/scripts/simulations/ldf_sim.py index 5ebff2d7..91b75da1 100644 --- a/scripts/simulations/ldf_sim.py +++ b/scripts/simulations/ldf_sim.py @@ -1,11 +1,11 @@ -""" HiSPARC detector simulation +"""HiSPARC detector simulation - This simulation takes an Extended Air Shower simulation ground - particles file and uses that to simulate numerous showers hitting a - HiSPARC detector station. Only data of one shower is used, but by - randomly selecting points on the ground as the position of a station, - the effect of the same shower hitting various positions around the - station is simulated. +This simulation takes an Extended Air Shower simulation ground +particles file and uses that to simulate numerous showers hitting a +HiSPARC detector station. Only data of one shower is used, but by +randomly selecting points on the ground as the position of a station, +the effect of the same shower hitting various positions around the +station is simulated. 
""" @@ -31,14 +31,23 @@ simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/exact', R=60, N=N) simulation.run(max_theta=pi / 3) - simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/gauss_10', R=60, N=N, gauss=.1, trig_threshold=.9) + simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/gauss_10', R=60, N=N, gauss=0.1, trig_threshold=0.9) simulation.run(max_theta=pi / 3) - simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/gauss_20', R=60, N=N, gauss=.2, trig_threshold=.8) + simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/gauss_20', R=60, N=N, gauss=0.2, trig_threshold=0.8) simulation.run(max_theta=pi / 3) simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/poisson', R=60, N=N, use_poisson=True) simulation.run(max_theta=pi / 3) - simulation = KascadeLdfSimulation(cluster, data, '/ldfsim/poisson_gauss_20', R=60, N=N, use_poisson=True, gauss=.2, trig_threshold=.5) + simulation = KascadeLdfSimulation( + cluster, + data, + '/ldfsim/poisson_gauss_20', + R=60, + N=N, + use_poisson=True, + gauss=0.2, + trig_threshold=0.5, + ) simulation.run(max_theta=pi / 3) diff --git a/scripts/simulations/ldf_sim_to_model.py b/scripts/simulations/ldf_sim_to_model.py index 6a4f9c5b..99d5d727 100644 --- a/scripts/simulations/ldf_sim_to_model.py +++ b/scripts/simulations/ldf_sim_to_model.py @@ -1,7 +1,7 @@ -from itertools import izip - import tables +from pylab import * + def plot_ldf_and_models(data, group): global binned_densities @@ -14,8 +14,8 @@ def plot_ldf_and_models(data, group): # plot ldf from ground particles R = particles[:]['core_distance'] N, hbins = histogram(R, bins) - area = [pi * (v ** 2 - u ** 2) for u, v in zip(bins[:-1], bins[1:])] - loglog(x, N / area, label="full") + area = [pi * (v**2 - u**2) for u, v in zip(bins[:-1], bins[1:])] + loglog(x, N / area, label='full') # plot ldf from observables observables = data.root.simulations.E_1PeV.zenith_0.observables @@ -24,10 +24,10 @@ def plot_ldf_and_models(data, group): n2 = observables[:]['n2'] n3 = observables[:]['n3'] n4 = observables[:]['n4'] - density = (n1 + n2 + n3 + n4) / 2. + density = (n1 + n2 + n3 + n4) / 2.0 - R_within_limits = R.compress((R < bins[-1]) & (R >= bins[0])) - density_within_R_limits = density.compress((R < bins[-1]) & (R >= bins[0])) + R_within_limits = R.compress((bins[-1] > R) & (bins[0] <= R)) + density_within_R_limits = density.compress((bins[-1] > R) & (bins[0] <= R)) idxs = searchsorted(bins, R_within_limits) - 1 binned_densities = [[] for i in range(len(bins) - 1)] @@ -35,11 +35,11 @@ def plot_ldf_and_models(data, group): binned_densities[idx].append(value) y = [mean(u) for u in binned_densities] y_err = [std(u) for u in binned_densities] - errorbar(x, y, y_err, label="measured") + errorbar(x, y, y_err, label='measured') - xlabel("Core distance [m]") - ylabel("Particle density [m^{-2}]") - title("Full and measured LDF (E = 1 PeV)") + xlabel('Core distance [m]') + ylabel('Particle density [m^{-2}]') + title('Full and measured LDF (E = 1 PeV)') legend() @@ -63,13 +63,13 @@ def plot_ldf_ldf(data, group): n2 = observables['n2'] n3 = observables['n3'] n4 = observables['n4'] - density.append((n1 + n2 + n3 + n4) / 2.) 
+ density.append((n1 + n2 + n3 + n4) / 2.0) R = array(R) density = array(density) - R_within_limits = R.compress((R < bins[-1]) & (R >= bins[0])) - density_within_R_limits = density.compress((R < bins[-1]) & (R >= bins[0])) + R_within_limits = R.compress((bins[-1] > R) & (bins[0] <= R)) + density_within_R_limits = density.compress((bins[-1] > R) & (bins[0] <= R)) idxs = searchsorted(bins, R_within_limits) - 1 binned_densities = [[] for i in range(len(bins) - 1)] @@ -77,7 +77,7 @@ def plot_ldf_ldf(data, group): binned_densities[idx].append(value) y = [mean(u) for u in binned_densities] y_err = [std(u) for u in binned_densities] - errorbar(x, y, y_err, label="measured ldf_sim") + errorbar(x, y, y_err, label='measured ldf_sim') legend() diff --git a/scripts/simulations/master.py b/scripts/simulations/master.py index 855e8c11..406660bc 100644 --- a/scripts/simulations/master.py +++ b/scripts/simulations/master.py @@ -2,9 +2,8 @@ import re import warnings -import tables - import store_aires_data +import tables from sapphire import clusters from sapphire.simulations import GroundParticlesSimulation, QSubSimulation @@ -14,10 +13,10 @@ N_CORES = 32 -class Master(object): +class Master: def __init__(self, data_filename): if os.path.exists(data_filename): - warnings.warn("%s already exists, some steps are skipped" % data_filename) + warnings.warn('%s already exists, some steps are skipped' % data_filename) self.data = tables.open_file(data_filename, 'a') def main(self): @@ -29,8 +28,7 @@ def main(self): def store_shower_data(self): for angle in [0, 5, 10, 15, 22.5, 30, 35, 45]: self.store_1PeV_data_for_angle(angle) - for energy, group_name in [('e14', 'E_100TeV'), - ('e16', 'E_10PeV')]: + for energy, group_name in [('e14', 'E_100TeV'), ('e16', 'E_10PeV')]: self.store_data_for_energy(energy, group_name) def store_1PeV_data_for_angle(self, angle): @@ -85,7 +83,7 @@ def perform_simulation(self, cluster, shower, output_path=None): try: sim = Simulation(*args, **kwargs) - except RuntimeError, msg: + except RuntimeError as msg: print(msg) return else: diff --git a/scripts/simulations/myshowerfront.py b/scripts/simulations/myshowerfront.py index beddf140..b84444e2 100644 --- a/scripts/simulations/myshowerfront.py +++ b/scripts/simulations/myshowerfront.py @@ -8,7 +8,7 @@ def get_front_arrival_time(sim, R, dR, theta): query = '(R - dR <= core_distance) & (core_distance < R + dR)' - c = 3e-1 # m / ns + c = 3e-1 # m / ns t_list = [] for shower in sim: @@ -32,8 +32,8 @@ def monte_carlo_timings(n, bins, size): t_list = [] while len(t_list) < size: - x = random.uniform(x0, x1) - y = random.uniform(y0, y1) + x = random.default_rng().uniform(x0, x1) + y = random.default_rng().uniform(y0, y1) idx = bins.searchsorted(x) - 1 if y <= n[idx]: t_list.append(x) @@ -47,13 +47,13 @@ def my_std_t(data, N): t = get_front_arrival_time(sim, 30, 5, pi / 8) n, bins = histogram(t, bins=linspace(0, 50, 401)) mct = monte_carlo_timings(n, bins, 10000) - print("Monte Carlo:", N) + print('Monte Carlo:', N) mint_list = [] i = 0 while i < len(mct): try: - values = mct[i:i + N] + values = mct[i : i + N] except IndexError: break if len(values) == N: @@ -70,13 +70,13 @@ def my_std_t_for_R(data, N_list, R_list): t = get_front_arrival_time(sim, R, 5, pi / 8) n, bins = histogram(t, bins=linspace(0, 50, 401)) mct = monte_carlo_timings(n, bins, 10000) - print("Monte Carlo:", N) + print('Monte Carlo:', N) mint_list = [] i = 0 while i < len(mct): try: - values = mct[i:i + N] + values = mct[i : i + N] except IndexError: break if len(values) == N: 
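# Illustrative sketch (not part of the patch): the reformatted monte_carlo_timings in
# myshowerfront.py above calls numpy's random.default_rng() inside the rejection loop,
# which builds and re-seeds a fresh Generator on every draw. The usual pattern is to
# create one Generator up front, which is faster and can be made reproducible with a
# seed; x0, x1, y0 and y1 are assumed to span the histogram range and height as in the
# original function.
import numpy as np


def monte_carlo_timings(n, bins, size, seed=None):
    """Rejection-sample `size` arrival times from a histogram given by (n, bins)."""
    rng = np.random.default_rng(seed)
    x0, x1 = bins[0], bins[-1]
    y0, y1 = 0, max(n)
    t_list = []
    while len(t_list) < size:
        x = rng.uniform(x0, x1)
        y = rng.uniform(y0, y1)
        idx = np.searchsorted(bins, x) - 1
        if y <= n[idx]:
            t_list.append(x)
    return t_list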
@@ -91,13 +91,13 @@ def my_t_draw_something(data, N, num_events): t = get_front_arrival_time(sim, 20, 5, pi / 8) n, bins = histogram(t, bins=linspace(0, 50, 201)) mct = monte_carlo_timings(n, bins, num_events * N) - print("Monte Carlo:", N) + print('Monte Carlo:', N) mint_list = [] i = 0 while i < len(mct): try: - values = mct[i:i + N] + values = mct[i : i + N] except IndexError: break if len(values) == N: @@ -106,26 +106,30 @@ def my_t_draw_something(data, N, num_events): return mint_list -def plot_R(): +def plot_core_distance(): graph = GraphArtist(width=r'.45\linewidth') - n, bins, patches = hist(data.root.simulations.E_1PeV.zenith_22_5.shower_0.coincidences.col('r'), bins=100, histtype='step') + n, bins, patches = hist( + data.root.simulations.E_1PeV.zenith_22_5.shower_0.coincidences.col('r'), + bins=100, + histtype='step', + ) graph.histogram(n, bins, linestyle='black!50') shower = data.root.simulations.E_1PeV.zenith_22_5.shower_0 ids = shower.observables.get_where_list('(n1 >= 1) & (n3 >= 1) & (n4 >= 1)') - R = shower.coincidences.read_coordinates(ids, field='r') + core_distance = shower.coincidences.read_coordinates(ids, field='r') n, bins, patches = hist(R, bins=100, histtype='step') graph.histogram(n, bins) - xlabel("Core distance [m]") - ylabel("Number of events") + xlabel('Core distance [m]') + ylabel('Number of events') - print("mean", mean(R)) - print("median", median(R)) + print('mean', mean(core_distance)) + print('median', median(core_distance)) - graph.set_xlabel(r"Core distance [\si{\meter}]") - graph.set_ylabel("Number of events") + graph.set_xlabel(r'Core distance [\si{\meter}]') + graph.set_ylabel('Number of events') graph.set_xlimits(min=0) graph.set_ylimits(min=0) graph.save('plots/SIM-R') @@ -146,11 +150,11 @@ def plot_arrival_times(): n, bins, patches = hist(mint, bins=linspace(0, 20, 101), histtype='step') graph.histogram(n, bins) - xlabel("Arrival time [ns]") - ylabel("Number of events") + xlabel('Arrival time [ns]') + ylabel('Number of events') - graph.set_xlabel(r"Arrival time [\si{\nano\second}]") - graph.set_ylabel("Number of events") + graph.set_xlabel(r'Arrival time [\si{\nano\second}]') + graph.set_ylabel('Number of events') graph.set_xlimits(0, 20) graph.set_ylimits(min=0) graph.save('plots/SIM-T') @@ -159,8 +163,8 @@ def plot_arrival_times(): if __name__ == '__main__': - if not 'data' in globals(): + if 'data' not in globals(): data = tables.open_file('master-ch4v2.h5') - plot_R() + plot_core_distance() plot_arrival_times() diff --git a/scripts/simulations/plot_coordinate_systems.py b/scripts/simulations/plot_coordinate_systems.py index 67faceb4..697ee18a 100644 --- a/scripts/simulations/plot_coordinate_systems.py +++ b/scripts/simulations/plot_coordinate_systems.py @@ -1,7 +1,7 @@ -from pylab import * - import utils +from pylab import * + def plot_coordinate_systems(): figure() @@ -37,7 +37,7 @@ def transform_coordinates(x, y, alpha): xp, yp, alphap = [], [], [] for u, v, w in zip(x, y, alpha): - r = sqrt(u ** 2 + v ** 2) + r = sqrt(u**2 + v**2) phi = arctan2(v, u) phi += pi - w @@ -91,7 +91,7 @@ def generate_random_coordinates_in_circle(R, N=100): while len(x) < N: u, v = uniform(-10, 10, 2) - if u ** 2 + v ** 2 <= R ** 2: + if u**2 + v**2 <= R**2: x.append(u) y.append(v) alpha.append(uniform(-pi, pi)) diff --git a/scripts/simulations/random_energy.py b/scripts/simulations/random_energy.py index 0f50e3c6..85453ec0 100644 --- a/scripts/simulations/random_energy.py +++ b/scripts/simulations/random_energy.py @@ -1,4 +1,10 @@ -flux = lambda x: x ** -2.7 
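# Note (not part of the patch): in plot_core_distance above (myshowerfront.py) the
# local variable was renamed from R to core_distance, but the hist() call that follows
# the rename still reads the old name R, which is no longer defined in that scope.
# Presumably the selection and the histogram belong together:
#
#     core_distance = shower.coincidences.read_coordinates(ids, field='r')
#     n, bins, patches = hist(core_distance, bins=100, histtype='step')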
+import random + +from numpy import array + + +def flux(x): + return x**-2.7 def random_energy(a, b, size=1): @@ -6,7 +12,7 @@ def random_energy(a, b, size=1): y1 = flux(b) energies = [] - for i in range(size): + for _ in range(size): while True: x = random.uniform(a, b) y = random.uniform(y0, y1) diff --git a/scripts/simulations/store_aires_data.py b/scripts/simulations/store_aires_data.py index 40d2287a..acbf0275 100644 --- a/scripts/simulations/store_aires_data.py +++ b/scripts/simulations/store_aires_data.py @@ -1,20 +1,20 @@ -""" Store AIRES simulation data in HDF5 file +"""Store AIRES simulation data in HDF5 file - This module reads the AIRES binary ground particles file and stores - each particle individually in a HDF5 file, using PyTables. This file - can then be used as input for the detector simulation. +This module reads the AIRES binary ground particles file and stores +each particle individually in a HDF5 file, using PyTables. This file +can then be used as input for the detector simulation. """ + import os import os.path import sys +import aires import tables from numpy import * -import aires - from sapphire.storage import ShowerParticle sys.path.append(os.path.expanduser('~/work/HiSPARC/software/bzr/shower')) @@ -25,12 +25,12 @@ def save_particle(row, p, id): row['id'] = id row['pid'] = p.code - row['core_distance'] = 10 ** p.core_distance + row['core_distance'] = 10**p.core_distance row['polar_angle'] = p.polar_angle row['arrival_time'] = p.arrival_time - row['energy'] = 10 ** p.energy - row['x'] = 10 ** p.core_distance * cos(p.polar_angle) - row['y'] = 10 ** p.core_distance * sin(p.polar_angle) + row['energy'] = 10**p.energy + row['x'] = 10**p.core_distance * cos(p.polar_angle) + row['y'] = 10**p.core_distance * sin(p.polar_angle) row.append() @@ -41,10 +41,10 @@ def store_aires_data(data, group_name, file): print('%s already exists, doing nothing' % group_name) return - print(f"Storing AIRES data ({file}) in {group}") + print(f'Storing AIRES data ({file}) in {group}') if not os.path.exists(file): - raise RuntimeError("File %s does not exist" % file) + raise RuntimeError('File %s does not exist' % file) else: sim = aires.SimulationData('', file) @@ -52,13 +52,16 @@ def store_aires_data(data, group_name, file): shower_group = data.create_group(group, 'shower_%d' % shower_num) print(shower_group) - leptons = data.create_table(shower_group, 'leptons', ShowerParticle, - 'Electrons, positrons, muons and anti-muons') + leptons = data.create_table( + shower_group, + 'leptons', + ShowerParticle, + 'Electrons, positrons, muons and anti-muons', + ) leptons_row = leptons.row for id, p in enumerate(shower.particles()): - if p.name == 'muon' or p.name == 'anti-muon' or \ - p.name == 'electron' or p.name == 'positron': + if p.name == 'muon' or p.name == 'anti-muon' or p.name == 'electron' or p.name == 'positron': save_particle(leptons_row, p, id) leptons.flush() diff --git a/scripts/simulations/test_coordinate_transform.py b/scripts/simulations/test_coordinate_transform.py index aacbf2a1..64052664 100644 --- a/scripts/simulations/test_coordinate_transform.py +++ b/scripts/simulations/test_coordinate_transform.py @@ -6,27 +6,26 @@ def plot_station_and_shower_transforms(event_id): # plot old coordinates coincidence = sim.coincidences[event_id] - r, phi, alpha = coincidence['r'], coincidence['phi'], \ - coincidence['alpha'] + r, phi, alpha = coincidence['r'], coincidence['phi'], coincidence['alpha'] cluster.set_rphialpha_coordinates(r, phi, alpha) plot_cluster(0.2) - scatter(0, 0, c='r', 
alpha=.2) + scatter(0, 0, c='r', alpha=0.2) # plot new coordinates coincidence = test_output.coincidences[event_id] x, y = coincidence['x'], coincidence['y'] cluster.set_rphialpha_coordinates(0, 0, 0) - plot_cluster(1.) - scatter(x, y, c='r', alpha=1.) - scatter(0, 0, c='white', alpha=1.) - scatter(0, 0, c='r', alpha=.2) + plot_cluster(1.0) + scatter(x, y, c='r', alpha=1.0) + scatter(0, 0, c='white', alpha=1.0) + scatter(0, 0, c='r', alpha=0.2) # plot coordinates stored in 'observables' table for event in test_output.observables.read_where('id == %d' % event_id): scatter(event['x'], event['y'], c='lightgreen') - xlabel("[m]") - ylabel("[m]") + xlabel('[m]') + ylabel('[m]') def plot_cluster(alpha): diff --git a/scripts/simulations/toy_energy_densities.py b/scripts/simulations/toy_energy_densities.py index e01878a9..5d041b5b 100644 --- a/scripts/simulations/toy_energy_densities.py +++ b/scripts/simulations/toy_energy_densities.py @@ -18,7 +18,7 @@ def main(self): densities = [] weights = [] for E in np.linspace(1e13, 1e17, 10000): - relative_flux = E ** -2.7 + relative_flux = E**-2.7 Ne = 10 ** (np.log10(E) - 15 + 4.8) self.ldf = KascadeLdf(Ne) min_dens = self.calculate_minimum_density_for_station_at_R(R) @@ -28,12 +28,11 @@ def main(self): weights = np.array(weights) densities = np.array(densities).T - weighted_densities = (np.sum(weights * densities, axis=1) / - np.sum(weights)) + weighted_densities = np.sum(weights * densities, axis=1) / np.sum(weights) plt.plot(R, weighted_densities) plt.yscale('log') - plt.ylabel("Min. density [m^{-2}]") - plt.xlabel("Core distance [m]") + plt.ylabel('Min. density [m^{-2}]') + plt.xlabel('Core distance [m]') plt.axvline(5.77) plt.show() @@ -44,8 +43,7 @@ def calculate_minimum_density_for_station_at_R(self, R): def calculate_densities_for_station_at_R(self, R): densities = [] for detector in self.station.detectors: - densities.append(self.calculate_densities_for_detector_at_R( - detector, R)) + densities.append(self.calculate_densities_for_detector_at_R(detector, R)) return np.array(densities) def calculate_densities_for_detector_at_R(self, detector, R): diff --git a/scripts/simulations/utils.py b/scripts/simulations/utils.py index cf113ac5..1e5ad069 100644 --- a/scripts/simulations/utils.py +++ b/scripts/simulations/utils.py @@ -1,10 +1,9 @@ -""" Utility functions """ +"""Utility functions""" import inspect -import numpy as np - import matplotlib.pyplot as plt +import numpy as np __suffix = '' __prefix = '' @@ -54,4 +53,5 @@ def savedata(data, suffix=''): def title(text): plt.title(text + '\n(%s)' % __suffix) + mylog = np.vectorize(lambda x: np.log10(x) if x > 0 else 0) diff --git a/scripts/tests/check_timings.py b/scripts/tests/check_timings.py index 4e647b96..fa4c3574 100644 --- a/scripts/tests/check_timings.py +++ b/scripts/tests/check_timings.py @@ -10,6 +10,7 @@ new code. """ + import numpy as np import tables @@ -28,9 +29,9 @@ def main(): # Get timings from trace process = process_events.ProcessEvents(data, '/') - t0 = process._reconstruct_time_from_trace(trace, 0.) + t0 = process._reconstruct_time_from_trace(trace, 0.0) process = process_events.ProcessEventsWithLINT(data, '/') - t1 = process._reconstruct_time_from_trace(trace, 0.) 
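# Illustrative sketch (not part of the patch): toy_energy_densities.py above weights
# each sampled energy by the cosmic-ray flux (proportional to E**-2.7) and converts
# energy to shower size via Ne = 10**(log10(E) - 15 + 4.8), i.e. about 10**4.8
# particles at 1 PeV with Ne scaling linearly with E. The weighted average over
# energies then reduces to:
import numpy as np


def flux_weighted_mean(energies, densities_per_energy):
    """Flux-weighted mean of per-energy density curves (one row per energy)."""
    weights = np.asarray(energies) ** -2.7
    densities = np.asarray(densities_per_energy)
    return np.sum(weights[:, np.newaxis] * densities, axis=0) / np.sum(weights)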
+    t1 = process._reconstruct_time_from_trace(trace, 0.0)
     t0 /= 2.5e-9
     t1 /= 2.5e-9
     print(t0, t1)
diff --git a/scripts/tests/coordinate_transform_benchmarks.py b/scripts/tests/coordinate_transform_benchmarks.py
index 70af604e..8e836897 100644
--- a/scripts/tests/coordinate_transform_benchmarks.py
+++ b/scripts/tests/coordinate_transform_benchmarks.py
@@ -6,20 +6,20 @@ transformspeeds tests the speed of the new transformations
 
 """
+
 import datetime
-import random as r
+import random
 import time
 
-import numpy as np
-
 import matplotlib.pyplot as plt
+import numpy as np
 
 from sapphire.transformations import celestial, clock
 from sapphire.utils import angle_between
 
 
 def transformspeeds():
-    print("Running speeds for 100.000 transformations of the astropy functions:")
+    print('Running speeds for 100.000 transformations of the astropy functions:')
     a = np.array([(0, 0)] * 100000)
 
     t0 = time.clock()
@@ -36,7 +36,7 @@ def transformspeeds():
     celestial.zenithazimuth_to_equatorial_astropy(0, 0, 1_000_000_000, a)
     t4 = time.clock() - t0
 
-    print("EQ->HO, EQ-> ZA, HO->EQ, ZA->EQ runtimes:")
+    print('EQ->HO, EQ-> ZA, HO->EQ, ZA->EQ runtimes:')
     print(t1, t2, t3, t4)
 
 
@@ -56,7 +56,7 @@ def angle_between_horizontal(azimuth1, altitude1, azimuth2, altitude2):
     zenith2, azimuth2 = celestial.horizontal_to_zenithazimuth(altitude2, azimuth2)
     dlat = zenith1 - zenith2
     dlon = azimuth2 - azimuth1
-    a = (np.sin(dlat / 2) ** 2 + np.sin(zenith1) * np.sin(zenith2) * np.sin(dlon / 2) ** 2)
+    a = np.sin(dlat / 2) ** 2 + np.sin(zenith1) * np.sin(zenith2) * np.sin(dlon / 2) ** 2
     angle = 2 * np.arcsin(np.sqrt(a))
 
     return angle
@@ -77,28 +77,37 @@ def oldvsnew_diagram():
     refer to him for when something is unclear
     """
     # make random frames, in correct angle range and from utc time 2000-2020
-    frames = []
     # boxes for the four different transformation results
     etoha = []
     etoh = []
     htoe = []
     htoea = []
-    straight = lambda x : x # straight trendline function
+    straight = lambda x: x  # straight trendline function
 
     # Create the data sets for eq to az
-    for i in range(100):
-        frames.append((r.uniform(-90, 90),
-                       r.uniform(-180,180),
-                       r.randint(946684800,1577836800),
-                       r.uniform(0, 2 * np.pi),
-                       r.uniform(-0.5 * np.pi, 0.5 * np.pi)))
+    frames = [
+        (
+            random.uniform(-90, 90),
+            random.uniform(-180, 180),
+            random.randint(946684800, 1577836800),
+            random.uniform(0, 2 * np.pi),
+            random.uniform(-0.5 * np.pi, 0.5 * np.pi),
+        )
+        for _ in range(100)
+    ]
     for i in frames:
         etoha.append(celestial.equatorial_to_zenithazimuth_astropy(i[0], i[1], i[2], [(i[3], i[4])])[0])
         etoh.append(celestial.equatorial_to_zenithazimuth(i[0], i[1], clock.utc_to_gps(i[2]), i[3], i[4]))
 
     # Data sets for hor to eq
     for i in frames:
-        htoe.append(celestial.horizontal_to_equatorial(i[0],
-                    clock.utc_to_lst(datetime.datetime.utcfromtimestamp(i[2]), i[1]), i[4], i[3]))
+        htoe.append(
+            celestial.horizontal_to_equatorial(
+                i[0],
+                clock.utc_to_lst(datetime.datetime.utcfromtimestamp(i[2]), i[1]),
+                i[4],
+                i[3],
+            ),
+        )
         htoea.extend(celestial.horizontal_to_equatorial_astropy(i[0], i[1], i[2], [(i[3], i[4])]))
 
     # Make figs eq -> zenaz
@@ -108,7 +117,7 @@ def oldvsnew_diagram():
     zenrange = [0, np.pi]
     plt.subplot(211)
     plt.title('Zenith')
-    plt.axis(zenrange*2)
+    plt.axis(zenrange * 2)
     plt.xlabel('New (Astropy)')
     plt.ylabel('Old')
@@ -119,18 +128,18 @@ def oldvsnew_diagram():
     plt.subplot(212)
     plt.title('Azimuth')
    azrange = [-np.pi, np.pi]
-    plt.axis(azrange*2)
+    plt.axis(azrange * 2)
     plt.xlabel('New (Astropy)')
     plt.ylabel('Old')
 
     # Make figure and add 1:1 trendline
     plt.plot([co[1] for co in etoha], [co[1] for co in etoh], 'b.', azrange, straight(azrange), '-')
-    plt.tight_layout() # Prevent titles merging
+    plt.tight_layout()  # Prevent titles merging
     plt.subplots_adjust(top=0.85)
 
     # Make histogram of differences
     plt.figure(2)
 
     # Take diff. and convert to arcsec
-    nieuw = (np.array(etoh) - np.array(etoha))
+    nieuw = np.array(etoh) - np.array(etoha)
     nieuw *= 360 * 3600 / (2 * np.pi)
 
     plt.hist([i[0] for i in nieuw], bins=20)
@@ -147,8 +156,7 @@ def oldvsnew_diagram():
     # Make histogram of differences using the absolute distance in arcsec
     # this graph has no wrapping issues
     plt.figure(7)
-    nieuw = np.array([angle_between(etoh[i][0], etoh[i][1], etoha[i][0], etoha[i][1])
-                      for i in range(len(etoh))])
+    nieuw = np.array([angle_between(etoh[i][0], etoh[i][1], etoha[i][0], etoha[i][1]) for i in range(len(etoh))])
     nieuw *= 360 * 3600 / (2 * np.pi)
     plt.hist(nieuw, bins=20)
     plt.title('ZEN+AZ Old-New Error (equatorial_to_zenithazimuth)')
@@ -182,7 +190,7 @@ def oldvsnew_diagram():
     # Make histogram of differences
     plt.figure(5)
     # Take diff. and convert to arcsec
-    nieuw = (np.array(htoe) - np.array(htoea))
+    nieuw = np.array(htoe) - np.array(htoea)
     nieuw *= 360 * 3600 / (2 * np.pi)
     plt.hist([i[1] for i in nieuw], bins=20)
     plt.title('Declination Old-New Error (horizontal_to_equatorial)')
@@ -191,7 +199,7 @@ def oldvsnew_diagram():
 
     plt.figure(6)
     # Take diff. and convert to arcsec
-    nieuw = (np.array(htoe) - np.array(htoea))
+    nieuw = np.array(htoe) - np.array(htoea)
     nieuw *= 360 * 3600 / (2 * np.pi)
     plt.hist([i[0] for i in nieuw], bins=20)
     plt.title('Right Ascension Old-New Error (horizontal_to_equatorial)')
@@ -201,8 +209,9 @@ def oldvsnew_diagram():
     # Make histogram of differences using the absolute distance in arcsec
     # this graph has no wrapping issues
     plt.figure(8)
-    nieuw = np.array([angle_between_horizontal(htoe[i][0], htoe[i][1], htoea[i][0], htoea[i][1])
-                      for i in range(len(htoe))])
+    nieuw = np.array(
+        [angle_between_horizontal(htoe[i][0], htoe[i][1], htoea[i][0], htoea[i][1]) for i in range(len(htoe))],
+    )
     # Take diff. and convert to arcsec
     nieuw /= 2 / np.pi * 360 * 3600
     plt.hist(nieuw, bins=20)
@@ -211,13 +220,13 @@ def oldvsnew_diagram():
     plt.ylabel('Counts')
 
     plt.show()
-    return
 
 
 try:
     # This try-except block contains a pyephem accuracy benchmarking function.
     # It uses this structure to accommodate people without pyephem.
     import ephem
+
     def pyephem_comp():
         """
         This function compares the values from transformations done by our
@@ -235,17 +244,21 @@ def pyephem_comp():
         """
         # Set up randoms equatorial J2000 bodies
         # that we will convert the RAs/Decs of.
-        eq = [] # random frames to use
+        eq = []  # random frames to use
         for i in range(100):
-            eq.append((r.uniform(-90, 90),
-                       r.uniform(-180, 180),
-                       r.randint(946684800, 1577836800),
-                       r.uniform(0, 2 * np.pi),
-                       r.uniform(-0.5 * np.pi, 0.5 * np.pi)))
-        efemeq = [] # store pyephem transformations to equatorial
-        altaz = [] # store pyephem transformations to altaz (horizontal)
-        htoea = [] # store astropy transformations to equatorial
-        etoha = [] # store astropy transformations to horizontal (altaz)
+            eq.append(
+                (
+                    random.uniform(-90, 90),
+                    random.uniform(-180, 180),
+                    random.randint(946684800, 1577836800),
+                    random.uniform(0, 2 * np.pi),
+                    random.uniform(-0.5 * np.pi, 0.5 * np.pi),
+                ),
+            )
+        efemeq = []  # store pyephem transformations to equatorial
+        altaz = []  # store pyephem transformations to altaz (horizontal)
+        htoea = []  # store astropy transformations to equatorial
+        etoha = []  # store astropy transformations to horizontal (altaz)
         for latitude, longitude, utc, ra, dec in eq:
             # Calculate altaz
             # Set observer for each case
@@ -253,7 +266,7 @@ def pyephem_comp():
             obs.lat = str(latitude)
             obs.lon = str(longitude)
             obs.date = datetime.datetime.utcfromtimestamp(utc)
-            obs.pressure = 0 # Crucial to prevent refraction correction!
+            obs.pressure = 0  # Crucial to prevent refraction correction!
 
             # Set body for each case
             coord = ephem.FixedBody()
@@ -266,7 +279,7 @@ def pyephem_comp():
             altaz.append((float(coord.az), float(coord.alt)))
 
             # Also calculate efemeq using eq
-            result = obs.radec_of(ra, dec) # This is of course not ra,dec but
+            result = obs.radec_of(ra, dec)  # This is of course not ra,dec but
             # actually az, alt.
             efemeq.append((float(result[0]), float(result[1])))
@@ -300,7 +313,7 @@ def pyephem_comp():
         # DEC correlation subplot
         plt.subplot(212)
         plt.title('DEC')
-        plt.axis(altdecrange*2)
+        plt.axis(altdecrange * 2)
         plt.xlabel('Pyephem DEC (rad)')
         plt.ylabel('Astropy DEC (rad)')
@@ -315,7 +328,7 @@ def pyephem_comp():
         plt.figure(2)
         plt.title('RA Error Altaz->(Astropy/Pyephem)->RA,DEC')
 
-        nieuw = (np.array(htoea) - np.array(efemeq))
+        nieuw = np.array(htoea) - np.array(efemeq)
         # Get differences in arcsec
         nieuw *= 360 * 3600 / (2 * np.pi)
@@ -339,7 +352,7 @@ def pyephem_comp():
         # Altitude
         plt.subplot(211)
         plt.title('Altitude')
-        plt.axis(altdecrange*2)
+        plt.axis(altdecrange * 2)
         plt.xlabel('Pyephem Altitude (rad)')
         plt.ylabel('Astropy Altitude (rad')
@@ -349,7 +362,7 @@ def pyephem_comp():
         # Azimuth
         plt.subplot(212)
         plt.title('Azimuth')
-        plt.axis(azrarange*2)
+        plt.axis(azrarange * 2)
         plt.xlabel('Pyephem Azimuth (rad)')
         plt.ylabel('Astropy Azimuth (rad)')
@@ -362,7 +375,7 @@ def pyephem_comp():
         # Alt error histogram
         plt.figure(5)
         plt.title('Altitude Error RA,DEC->(pyephem/astropy)->Altaz')
-        nieuw = (np.array(etoha) - np.array(altaz))
+        nieuw = np.array(etoha) - np.array(altaz)
         nieuw *= 360 * 3600 / (2 * np.pi)
 
         plt.hist([co[1] for co in nieuw], bins=20)
@@ -381,8 +394,9 @@ def pyephem_comp():
         # these graphs have no wrapping issues
         plt.figure(7)
 
-        nieuw = np.array([angle_between_horizontal(altaz[i][0], altaz[i][1], etoha[i][0], etoha[i][1])
-                          for i in range(len(etoha))])
+        nieuw = np.array(
+            [angle_between_horizontal(altaz[i][0], altaz[i][1], etoha[i][0], etoha[i][1]) for i in range(len(etoha))],
+        )
         nieuw *= 360 * 3600 / (2 * np.pi)
         plt.hist(nieuw, bins=20)
         plt.title('Alt+Azi Error RA,DEC->(pyephem/astropy)->Altaz')
@@ -390,8 +404,9 @@ def pyephem_comp():
         plt.ylabel('Counts')
 
         plt.figure(8)
-        nieuw = np.array([angle_between_horizontal(efemeq[i][0], efemeq[i][1], htoea[i][0], htoea[i][1])
-                          for i in range(len(htoea))])
+        nieuw = np.array(
+            [angle_between_horizontal(efemeq[i][0], efemeq[i][1], htoea[i][0], htoea[i][1]) for i in range(len(htoea))],
+        )
         # Take difference and convert to arcsec
         nieuw *= 360 * 3600 / (2 * np.pi)
@@ -406,4 +421,4 @@ def pyephem_comp():
 except ImportError:
     # Pyephem is not required so there is a case for when it is not present
     def pyephem_comp():
-        print("Pyephem not present; no comparisons will be done")
+        print('Pyephem not present; no comparisons will be done')
diff --git a/scripts/tests/process_events_without_traces.py b/scripts/tests/process_events_without_traces.py
index 571aa900..9ded1047 100644
--- a/scripts/tests/process_events_without_traces.py
+++ b/scripts/tests/process_events_without_traces.py
@@ -6,6 +6,7 @@ ProcessIndexedEventsWithoutTraces class.
 
 """
+
 import datetime
 
 import tables
@@ -16,9 +17,8 @@ if __name__ == '__main__':
     data = tables.open_file('testdata.h5', 'a')
 
     if '/s501' not in data:
-        download_data(data, '/s501', 501, datetime.datetime(2013, 1, 1),
-                      datetime.datetime(2013, 1, 2), get_blobs=False)
+        download_data(data, '/s501', 501, datetime.datetime(2013, 1, 1), datetime.datetime(2013, 1, 2), get_blobs=False)
 
     process = process_events.ProcessEventsWithoutTraces(data, '/s501')
     process.process_and_store_results(overwrite=True)
     offsets = process.determine_detector_timing_offsets()
-    print("Offsets:", offsets)
+    print('Offsets:', offsets)
diff --git a/scripts/tests/search_coincidences.py b/scripts/tests/search_coincidences.py
index 307dbfd8..c7ce3eb6 100644
--- a/scripts/tests/search_coincidences.py
+++ b/scripts/tests/search_coincidences.py
@@ -5,6 +5,7 @@ This script tests the process of searching for coincidences.
 
 """
+
 import datetime
 
 import tables
@@ -28,6 +29,6 @@
     coincidences.search_and_store_coincidences()
 
     # This is the manual method
-    #coincidences.search_coincidences()
-    #coincidences.process_events(overwrite=True)
-    #coincidences.store_coincidences()
+    # coincidences.search_coincidences()
+    # coincidences.process_events(overwrite=True)
+    # coincidences.store_coincidences()
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index e599df94..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,40 +0,0 @@
-[flake8]
-ignore =
-    E501
-    W503
-    W504
-    N801
-per-file-ignores =
-    sapphire/corsika/units.py:N816
-    sapphire/storage.py:N815
-    sapphire/tests/simulations/test_gammas.py:N806
-    sapphire/tests/transformations/test_celestial.py:N806
-
-[isort]
-profile = black
-line_length = 110
-known_extras =
-    artist
-    pylab
-known_first_party =
-    sapphire
-sections =
-    FUTURE
-    STDLIB
-    THIRDPARTY
-    EXTRAS
-    FIRSTPARTY
-    LOCALFOLDER
-lines_between_types = 1
-
-[coverage:run]
-branch = true
-source = .
-omit =
-    /doc/*
-    /scripts/*
-
-[coverage:report]
-show_missing = true
-skip_empty = true
-skip_covered = true
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 65bd0ee6..00000000
--- a/setup.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from setuptools import find_packages, setup
-
-# set version number and write to sapphire/version.py
-version = '2.0.0'
-
-version_py = """\
-# Created by setup.py. Do not edit.
-__version__ = "{version}"
-"""
-with open('sapphire/version.py', 'w') as f:
-    f.write(version_py.format(version=version))
-
-
-setup(name='hisparc-sapphire',
-      version=version,
-      packages=find_packages(),
-      url='https://github.com/hisparc/sapphire/',
-      bugtrack_url='https://github.com/HiSPARC/sapphire/issues',
-      license='GPLv3',
-      author='David Fokkema, Arne de Laat, Tom Kooij, and others',
-      author_email='davidf@nikhef.nl, arne@delaat.net',
-      description='A framework for the HiSPARC experiment',
-      long_description=open('README.rst').read(),
-      keywords=['HiSPARC', 'Nikhef', 'cosmic rays'],
-      classifiers=[
-          'Intended Audience :: Science/Research',
-          'Intended Audience :: Education',
-          'Operating System :: OS Independent',
-          'Programming Language :: Python',
-          'Programming Language :: Python :: 3',
-          'Topic :: Scientific/Engineering',
-          'Topic :: Education',
-          'License :: OSI Approved :: GNU General Public License v3 (GPLv3)'],
-      entry_points={
-          'console_scripts': [
-              'create_and_store_test_data = sapphire.tests.create_and_store_test_data:main',
-              'update_local_data = sapphire.data.update_local_data:main',
-              'extend_local_data = sapphire.data.extend_local_data:main']},
-      package_data={'sapphire': ['data/*.json',
-                                 'data/*/*.json',
-                                 'data/current/*.tsv',
-                                 'data/detector_timing_offsets/*.tsv',
-                                 'data/electronics/*.tsv',
-                                 'data/gps/*.tsv',
-                                 'data/layout/*.tsv',
-                                 'data/station_timing_offsets/*/*.tsv',
-                                 'data/trigger/*.tsv',
-                                 'data/voltage/*.tsv',
-                                 'corsika/LICENSE',
-                                 'tests/test_data/*.h5',
-                                 'tests/test_data/*.tsv',
-                                 'tests/test_data/*.dat',
-                                 'tests/analysis/test_data/*.h5',
-                                 'tests/corsika/test_data/*.h5',
-                                 'tests/corsika/test_data/*/DAT000000',
-                                 'tests/corsika/test_data/*/*.h5',
-                                 'tests/simulations/test_data/*.h5']},
-      install_requires=['numpy', 'scipy', 'tables>=3.3.0', 'progressbar2>=3.7.0'],
-      extras_require={
-          'dev': ['Sphinx', 'flake8', 'pep8-naming', 'coverage', 'flake8-isort'],
-          'astropy': ["astropy"]},
-)