Skip to content

Commit

Permalink
Merge pull request #804 from padraic-shafer/int64-index
Browse files Browse the repository at this point in the history
Accept numpy.int64 as databroker catalog key
  • Loading branch information
tacaswell committed Apr 16, 2024
2 parents 23071d5 + 854caaa commit b9ce855
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 1 deletion.
1 change: 1 addition & 0 deletions databroker/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def _(key, db):
@search.register(numbers.Integral)
def _(key, db):
logger.info('Interpreting key = %s as an integer' % key)
key = int(key)
if key > -1:
# Interpret key as a scan_id.
gen = db.hs.find_run_starts(scan_id=key)
Expand Down
4 changes: 3 additions & 1 deletion databroker/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections.abc
import json
import keyword
import numbers
import warnings

from tiled.adapters.utils import IndexCallable
Expand Down Expand Up @@ -251,13 +252,14 @@ def __getitem__(self, key):
# Fall back to partial uid lookup below.
pass
return self._lookup_by_partial_uid(key)
elif isinstance(key, int):
elif isinstance(key, numbers.Integral):
if key > 0:
# CASE 2: Interpret key as a scan_id.
return self._lookup_by_scan_id(key)
else:
# CASE 3: Interpret key as a recently lookup, as in
# `catalog[-1]` is the latest entry.
key = int(key)
return self.values()[key]
elif isinstance(key, slice):
if (key.start is None) or (key.start >= 0):
Expand Down
56 changes: 56 additions & 0 deletions databroker/tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import tempfile
import os
import logging
import numbers
import packaging
import sys
import string
import time as ttime
import uuid
from contextlib import nullcontext as does_not_raise
from datetime import date, timedelta
import itertools
from databroker import (wrap_in_doct, wrap_in_deprecated_doct,
Expand Down Expand Up @@ -237,6 +239,60 @@ def test_indexing(db_empty, RE, hw):
db[-11]


@pytest.mark.parametrize(
"key, expected",
(
# These values are in range...
(np.int64(1), does_not_raise()), # Key is a Scan ID
(np.int64(-1), does_not_raise()), # Key is Nth-last scan
# These values are out of range...
(np.int64(10), pytest.raises(KeyError)), # Scan ID does not exist
(-np.int64(10), pytest.raises(IndexError)), # Abs(key) > number of scans
),
)
def test_int64_indexing(db_empty, RE, hw, key, expected):
"""numpy.int64 can be used as a catalog key; it is treated as an int key"""
db = db_empty
RE.subscribe(db.insert)

uids = []
for i in range(2):
uids.extend(get_uids(RE(count([hw.det]))))

assert not isinstance(key, int)
assert isinstance(key, numbers.Integral)
assert key == int(key)
with expected:
db[key]


@pytest.mark.parametrize(
"key, expected",
(
# >32-bit values are ok...
(np.int64(2**33), does_not_raise()), # Scan ID exists
# But these values are out of range...
(np.int64(2**33 + 2), pytest.raises(KeyError)), # Scan ID does not exist
(-np.int64(2**33), pytest.raises(IndexError)), # Abs(key) > number of scans
),
)
def test_large_int_indexing(db_empty, RE, hw, key, expected):
"""Integer-valued catalog key can exceed 32-bit values"""
db = db_empty
RE.md["scan_id"] = 2**33 - 1
RE.subscribe(db.insert)

uids = []
for i in range(2):
uids.extend(get_uids(RE(count([hw.det]))))

assert not isinstance(key, int)
assert isinstance(key, numbers.Integral)
assert key == int(key)
with expected:
db[key]


def test_full_text_search(db_empty, RE, hw):
db = db_empty
RE.subscribe(db.insert)
Expand Down

0 comments on commit b9ce855

Please sign in to comment.