diff --git a/poetry.lock b/poetry.lock index 459f8f4..bae44a9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -416,6 +416,17 @@ category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "h5py" +version = "3.9.0" +description = "Read and write HDF5 files from Python" +category = "main" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +numpy = ">=1.17.3" + [[package]] name = "html5lib" version = "1.1" @@ -790,6 +801,14 @@ category = "dev" optional = false python-versions = ">=2.6" +[[package]] +name = "peewee" +version = "3.16.3" +description = "a little orm" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "pexpect" version = "4.8.0" @@ -801,6 +820,14 @@ python-versions = "*" [package.dependencies] ptyprocess = ">=0.5" +[[package]] +name = "pgpasslib" +version = "1.1.0" +description = "Library for getting passwords from PostgreSQL password files" +category = "main" +optional = false +python-versions = "*" + [[package]] name = "pickleshare" version = "0.7.5" @@ -867,6 +894,14 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" [package.extras] test = ["ipaddress", "mock", "unittest2", "enum34", "pywin32", "wmi"] +[[package]] +name = "psycopg2-binary" +version = "2.9.7" +description = "psycopg2 - Python-PostgreSQL Database Adapter" +category = "main" +optional = false +python-versions = ">=3.6" + [[package]] name = "ptyprocess" version = "0.7.0" @@ -1216,6 +1251,31 @@ sdsstools = ">=0.4.5" dev = ["Sphinx (>=3.0.0)", "sphinx-bootstrap-theme (>=0.4.12)", "recommonmark (>=0.6)", "sphinx-argparse (>=0.2.5)", "sphinx-issues (>=1.2.0)", "importlib-metadata (>=1.6.0)", "jinja2 (<3.1)", "six (>=1.14)", "ipython (>=7.9.0)", "matplotlib (>=3.1.4)", "flake8 (>=3.7.9)", "doc8 (>=0.8.0)", "pytest (>=5.2.2)", "pytest-cov (>=2.8.1)", "pytest-mock (>=1.13.0)", "pytest-sugar (>=0.9.2)", "isort (>=4.3.21)", "codecov (>=2.0.15)", "coverage[toml] (>=5.0)", "coveralls (>=1.7)", "ipdb (>=0.12.3)", "sdsstools[dev] (>=0.4.0)", "invoke (>=1.3.0)", "twine (>=3.1.4)", "wheel (>=0.33.6)"] docs = ["Sphinx (>=3.0.0)", "sphinx-bootstrap-theme (>=0.4.12)", "recommonmark (>=0.6)", "sphinx-argparse (>=0.2.5)", "sphinx-issues (>=1.2.0)", "importlib-metadata (>=1.6.0)", "jinja2 (<3.1)", "six (>=1.14)"] +[[package]] +name = "sdssdb" +version = "0.8.0" +description = "SDSS product for database management" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +h5py = ">=3.8.0" +numpy = ">=1.18.2" +peewee = ">=3.9.6" +pgpasslib = ">=1.1.0" +psycopg2-binary = ">=2.7.7" +pygments = "*" +pyyaml = ">=5.1" +sdsstools = ">=0.1.11" +six = ">=1.12.0" +sqlalchemy = ">=1.3.6" + +[package.extras] +all = ["progressbar2 (>=3.46.1)", "pydot (>=1.4.1)", "astropy (>=4.0.0)", "pandas (>=1.0.0)", "inflect (>=4.1.0)"] +dev = ["pytest (>=5.2)", "pytest-cov (>=2.4.0)", "pytest-sugar (>=0.8.0)", "ipython (>=7.13.0)", "ipdb (>=0.13.2)", "pytest-postgresql (>=2.2.1)", "factory-boy (>=2.12.0)", "pytest-factoryboy (>=2.0.3)", "astropy (>=4.0.0)", "pydot (>=1.4.2)"] +docs = ["Sphinx (>=3.0.0,<4.0.0)", "sphinx-bootstrap-theme (>=0.4.12)", "jinja2 (==3.0.0)"] + [[package]] name = "sdsstools" version = "1.0.2" @@ -1580,7 +1640,7 @@ docs = ["Sphinx"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "a11fa078812138b3bb784a1f209acab65d12e6b597ca18749bdd2d84f17200e1" +content-hash = "20013247d33d7e402f34f0832e0bd9ebb722aa3c1c6281dbc578b855d3ee219d" [metadata.files] aiofiles = [ @@ -1789,6 +1849,7 @@ h11 = [ {file = "h11-0.12.0-py3-none-any.whl", hash = "sha256:36a3cb8c0a032f56e2da7084577878a035d3b61d104230d4bd49c0c6b555a9c6"}, {file = "h11-0.12.0.tar.gz", hash = "sha256:47222cb6067e4a307d535814917cd98fd0a57b6788ce715755fa2b6c28b56042"}, ] +h5py = [] html5lib = [] httpcore = [] httpx = [] @@ -2054,10 +2115,12 @@ pbr = [ {file = "pbr-5.8.1-py2.py3-none-any.whl", hash = "sha256:27108648368782d07bbf1cb468ad2e2eeef29086affd14087a6d04b7de8af4ec"}, {file = "pbr-5.8.1.tar.gz", hash = "sha256:66bc5a34912f408bb3925bf21231cb6f59206267b7f63f3503ef865c1a292e25"}, ] +peewee = [] pexpect = [ {file = "pexpect-4.8.0-py2.py3-none-any.whl", hash = "sha256:0b48a55dcb3c05f3329815901ea4fc1537514d6ba867a152b581d69ae3710937"}, {file = "pexpect-4.8.0.tar.gz", hash = "sha256:fc65a43959d153d0114afe13997d439c22823a27cefceb5ff35c2178c6784c0c"}, ] +pgpasslib = [] pickleshare = [ {file = "pickleshare-0.7.5-py2.py3-none-any.whl", hash = "sha256:9649af414d74d4df115d5d718f82acb59c9d418196b7b4290ed47a12ce62df56"}, {file = "pickleshare-0.7.5.tar.gz", hash = "sha256:87683d47965c1da65cdacaf31c8441d12b8044cdec9aca500cd78fc2c683afca"}, @@ -2144,6 +2207,7 @@ psutil = [ {file = "psutil-5.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:7d190ee2eaef7831163f254dc58f6d2e2a22e27382b936aab51c835fc080c3d3"}, {file = "psutil-5.9.0.tar.gz", hash = "sha256:869842dbd66bb80c3217158e629d6fceaecc3a3166d3d1faee515b05dd26ca25"}, ] +psycopg2-binary = [] ptyprocess = [ {file = "ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35"}, {file = "ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220"}, @@ -2332,6 +2396,7 @@ rstcheck = [ ] sdss-access = [] sdss-tree = [] +sdssdb = [] sdsstools = [] secretstorage = [] setuptools-scm = [ diff --git a/pyproject.toml b/pyproject.toml index 33a79a3..d0127d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ httpx = "^0.24.0" astroquery = "^0.4.6" pandas = "^1.5.3" SQLAlchemy = "^1.4.35" +sdssdb = "^0.8.0" [tool.poetry.dev-dependencies] diff --git a/python/valis/db/__init__.py b/python/valis/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/valis/db/db.py b/python/valis/db/db.py new file mode 100644 index 0000000..356a743 --- /dev/null +++ b/python/valis/db/db.py @@ -0,0 +1,87 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# + +from contextvars import ContextVar + +import peewee +from fastapi import Depends, HTTPException +from sdssdb.peewee.sdss5db import database as pdb +from sdssdb.sqlalchemy.sdss5db import database as sdb + +# To make Peewee async-compatible, we need to hack the peewee connection state +# See FastAPI/Peewee docs at https://fastapi.tiangolo.com/how-to/sql-databases-peewee/ + +# create global context variables for the peewee connection state +db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} +db_state = ContextVar("db_state", default=db_state_default.copy()) + + +async def reset_db_state(): + """ Sub-depdency for get_db that resets the context connection state """ + pdb._state._state.set(db_state_default.copy()) + pdb._state.reset() + + +class PeeweeConnectionState(peewee._ConnectionState): + """ Custom Connection State for Peewee """ + def __init__(self, **kwargs): + super().__setattr__("_state", db_state) + super().__init__(**kwargs) + + def __setattr__(self, name, value): + self._state.get()[name] = value + + def __getattr__(self, name): + return self._state.get()[name] + + +# override the database connection state, after db connection +pdb._state = PeeweeConnectionState() + + +def connect_db(db, orm: str = 'peewee'): + """ Connect to the peewee sdss5db database """ + + from valis.main import settings + profset = db.set_profile(settings.db_server) + if settings.db_remote and not profset: + port = settings.db_port + user = settings.db_user + host = settings.db_host + passwd = settings.db_pass + db.connect_from_parameters(dbname='sdss5db', host=host, port=port, + user=user, password=passwd) + + # raise error if we cannot connect + if not db.connected: + raise HTTPException(status_code=503, detail=f'Could not connect to database via sdssdb {orm}.') + + return db + + +def get_pw_db(db_state=Depends(reset_db_state)): + """ Dependency to connect a database with peewee """ + + # connect to the db, yield None since we don't need the db in peewee + db = connect_db(pdb, orm='peewee') + try: + yield None + finally: + if db: + db.close() + + +def get_sqla_db(): + """ Dependency to connect to a database with sqlalchemy """ + + # connect to the db, yield the db Session object for sql queries + db = connect_db(sdb, orm='sqla') + db = db.Session() + try: + yield db + finally: + if db: + db.close() + + diff --git a/python/valis/db/models.py b/python/valis/db/models.py new file mode 100644 index 0000000..7c07b4d --- /dev/null +++ b/python/valis/db/models.py @@ -0,0 +1,79 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# + +# all resuable Pydantic models of the ORMs go here + +import peewee +from typing import Any, Optional +from pydantic import BaseModel, Field +from pydantic.utils import GetterDict + + +class PeeweeGetterDict(GetterDict): + """ Class to convert peewee.ModelSelect into a list """ + def get(self, key: Any, default: Any = None): + res = getattr(self._obj, key, default) + if isinstance(res, peewee.ModelSelect): + return list(res) + return res + + +class OrmBase(BaseModel): + """ Base pydantic model for sqlalchemy ORMs """ + + class Config: + orm_mode = True + + +class PeeweeBase(OrmBase): + """ Base pydantic model for peewee ORMs """ + + class Config: + orm_mode = True + getter_dict = PeeweeGetterDict + + +# class SDSSidStackedBaseA(OrmBase): +# """ Pydantic model for the SQLA vizdb.SDSSidStacked ORM """ + +# sdss_id: int = Field(..., description='the SDSS identifier') +# ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') +# dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') +# catalogid21: Optional[int] = Field(description='the version 21 catalog id') +# catalogid25: Optional[int] = Field(description='the version 25 catalog id') +# catalogid31: Optional[int] = Field(description='the version 31 catalog id') + + +class SDSSidStackedBase(PeeweeBase): + """ Pydantic model for the Peewee vizdb.SDSSidStacked ORM """ + + sdss_id: int = Field(..., description='the SDSS identifier') + ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') + dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') + catalogid21: Optional[int] = Field(description='the version 21 catalog id') + catalogid25: Optional[int] = Field(description='the version 25 catalog id') + catalogid31: Optional[int] = Field(description='the version 31 catalog id') + + +class SDSSidFlatBase(PeeweeBase): + """ Pydantic model for the Peewee vizdb.SDSSidFlat ORM """ + + sdss_id: int = Field(..., description='the SDSS identifier') + ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') + dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') + catalogid: int = Field(..., descrption='the catalogid associated with the given sdss_id') + version_id: int = Field(..., descrption='the version of the catalog for a given catalogid') + n_associated: int = Field(..., description='The total number of sdss_ids associated with that catalogid.') + ra_cat: float = Field(..., descrption='Right Ascension, in degrees, specific to the catalogid') + dec_cat: float = Field(..., descrption='Declination, in degrees, specific to the catalogid') + + +class SDSSidPipesBase(PeeweeBase): + """ Pydantic model for the Peewee vizdb.SDSSidToPipes ORM """ + + sdss_id: int = Field(..., description='the SDSS identifier') + in_boss: bool = Field(..., description='Flag if the sdss_id is in the BHM reductions') + in_apogee: bool = Field(..., description='Flag if the sdss_id is in the Apogee/MWM reductions') + in_astra: bool = Field(..., description='Flag if the sdss_id is in the Astra reductions') + diff --git a/python/valis/db/queries.py b/python/valis/db/queries.py new file mode 100644 index 0000000..857fb6f --- /dev/null +++ b/python/valis/db/queries.py @@ -0,0 +1,80 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# + +# all resuable queries go here + +from typing import Union +import peewee +import astropy.units as u +from astropy.coordinates import SkyCoord +from sdssdb.peewee.sdss5db import vizdb + + +def convert_coords(ra: Union[str, float], dec: Union[str, float]) -> tuple: + """ Convert sky coordinates to decimal degrees + + Convert the input RA, Dec sky coordinates into decimal + degrees. Input format can either be decimal or hmsdms. + + Parameters + ---------- + ra : str + The Right Ascension + dec : str + The Declination + + Returns + ------- + tuple + the converted (RA, Dec) + """ + is_hms = set('hms: ') & set(str(ra)) + if is_hms: + ra = str(ra).replace(' ', ':') + dec = str(dec).replace(' ', ':') + unit = ('hourangle', 'degree') if is_hms else ('degree', 'degree') + coord = SkyCoord(f'{ra} {dec}', unit=unit) + ra = round(coord.ra.value, 5) + dec = round(coord.dec.value, 5) + return float(ra), float(dec) + + +def cone_search(ra: Union[str, float], dec: Union[str, float], + radius: float, units: str = 'degree') -> peewee.ModelSelect: + """ Perform a cone search against the vizdb sdss_id_stacked table + + Perform a cone search using the peewee ORM for SDSS targets in the + vizdb sdss_id_stacked table. We return the peewee ModelSelect + directly here so it can be easily combined with other queries. + + In the route endpoint itself, remember to return wrap this in a list. + + Parameters + ---------- + ra : Union[str, float] + the Right Ascension coord + dec : Union[str, float] + the Declination coord + radius : float + the cone search radius + units : str, optional + the units of the search radius, by default 'degree' + + Returns + ------- + peewee.ModelSelect + the ORM query + """ + + # convert ra, dec to decimal if in hms-dms + ra, dec = convert_coords(ra, dec) + + # convert radial units to degrees + radius *= u.Unit(units) + radius = radius.to(u.degree).value + + return vizdb.SDSSidStacked.select().\ + where(vizdb.SDSSidStacked.cone_search(ra, dec, radius, + ra_col='ra_sdss_id', + dec_col='dec_sdss_id')) diff --git a/python/valis/main.py b/python/valis/main.py index d212327..924cc5d 100644 --- a/python/valis/main.py +++ b/python/valis/main.py @@ -16,6 +16,7 @@ import os import pathlib from typing import Dict +from functools import lru_cache from fastapi import Depends, FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -23,7 +24,7 @@ from fastapi.staticfiles import StaticFiles import valis -from valis.routes import access, auth, envs, files, info, maskbits, mocs, target +from valis.routes import access, auth, envs, files, info, maskbits, mocs, target, query from valis.routes.auth import set_auth from valis.routes.base import release from valis.settings import Settings, read_valis_config @@ -74,15 +75,20 @@ "name": "mocs", "description": "Access SDSS surveys MOCs", }, + { + "name": "query", + "description": "Query the SDSS databases", + }, ] -from functools import lru_cache + @lru_cache def get_settings(): """ Get the valis settings """ cfg = read_valis_config() return Settings(**cfg) + settings = get_settings() # create the application @@ -93,7 +99,7 @@ def get_settings(): # add CORS for cross-domain, for any sdss.org or sdss.utah.edu domain app.add_middleware(CORSMiddleware, allow_origin_regex="^https://.*\.sdss5?\.(org|utah\.edu)$", - allow_origins=settings.valis_allow_origin, + allow_origins=settings.allow_origin, allow_credentials=True, allow_methods=['*'], allow_headers=['*']) # mount the MOCs to a static path @@ -102,10 +108,12 @@ def get_settings(): hips_dir.mkdir(parents=True, exist_ok=True) app.mount("/static/mocs", StaticFiles(directory=hips_dir, html=True, follow_symlink=True), name="static") + @app.get("/", summary='Hello World route', response_model=Dict[str, str]) def hello(release=Depends(release)): return {"Hello SDSS": "This is the FastAPI World", 'release': release} + app.include_router(access.router, prefix='/paths', tags=['paths'], dependencies=[Depends(set_auth)]) app.include_router(envs.router, prefix='/envs', tags=['envs'], dependencies=[Depends(set_auth)]) app.include_router(files.router, prefix='/file', tags=['file'], dependencies=[Depends(set_auth)]) @@ -114,6 +122,7 @@ def hello(release=Depends(release)): app.include_router(target.router, prefix='/target', tags=['target']) app.include_router(maskbits.router, prefix='/maskbits', tags=['maskbits']) app.include_router(mocs.router, prefix='/mocs', tags=['mocs']) +app.include_router(query.router, prefix='/query', tags=['query']) def hack_auth(dd): @@ -142,6 +151,7 @@ def hack_auth(dd): dd[k] = v.replace(f'Body_{b2}', 'CredForm') return dd + def custom_openapi(): """ Custom OpenAPI spec to remove "release" POST body param from GET requests in the docs """ # cache the openapi schema diff --git a/python/valis/routes/maskbits.py b/python/valis/routes/maskbits.py index a3189c4..6f23374 100644 --- a/python/valis/routes/maskbits.py +++ b/python/valis/routes/maskbits.py @@ -48,7 +48,7 @@ def get_file() -> pathlib.Path: raise HTTPException(status_code=404, detail='Could not find a valid sdssMaskbits.par file. Check proper file paths.') -def read_maskbits(path: pathlib.Path = Depends(get_file)) -> np.recarray: +def read_maskbits(path: pathlib.Path = Depends(get_file)) -> np.recarray: """ Read the maskbits yanny file Read the sdssMasbits.par file with the yanny reader. @@ -76,7 +76,8 @@ def read_maskbits(path: pathlib.Path = Depends(get_file)) -> np.recarray: return data['MASKBITS'] -def make_table(schema: str = Query(..., description='The name of the SDSS flag', example='MANGA_DRP2QUAL'), masks: np.recarray = Depends(read_maskbits)): +def make_table(schema: str = Query(..., description='The name of the SDSS flag', + example='MANGA_DRP2QUAL'), masks: np.recarray = Depends(read_maskbits)): """ Dependency to return an Astropy Table from the maskbits data _extended_summary_ @@ -94,9 +95,7 @@ def make_table(schema: str = Query(..., description='The name of the SDSS flag', _description_ """ tt = Table(masks) - if schema: - return tt[tt["flag"] == schema] - return tt + return tt[tt["flag"] == schema] if schema else tt class MaskBitResponse(BaseModel): @@ -106,6 +105,7 @@ class MaskBitResponse(BaseModel): labels: List[str] = Field(None, description='A list of mask labels') bits: List[int] = Field(None, description='A list of integer mask bits') + class MaskSchema(BaseModel): """ SDSS maskbit schema response model """ bit: List[int] = Field(None, description='A list of integer mask bits') @@ -117,25 +117,29 @@ class MaskSchema(BaseModel): class Maskbits(Base): """ API routes for interacting with SDSS maskbits, defined in sdssMaskbits.par """ - @router.get("/list", summary='List the available maskbits schema / flags', response_model=MaskBitResponse, response_model_exclude_unset=True) - async def get_schema(self, masks = Depends(read_maskbits)) -> dict: + @router.get("/list", summary='List the available maskbits schema / flags', + response_model=MaskBitResponse, response_model_exclude_unset=True) + async def get_schema(self, masks=Depends(read_maskbits)) -> dict: """ Get a list of available SDSS maskbits schema or flag names """ return {"schema": sorted({i[0].decode('utf-8') for i in masks})} - @router.get("/schema", summary='Get the maskbits for a given schema / flag', response_model=Dict[str, MaskSchema]) + @router.get("/schema", summary='Get the maskbits for a given schema / flag', + response_model=Dict[str, MaskSchema]) async def get_bits(self, schema: str, tab: Table = Depends(make_table)) -> dict: """ Get the SDSS maskbit schema for a given flag name """ return {schema: {c: tab[c].tolist() for c in tab.columns if c != 'flag'}} - @router.get("/bits/value", summary='Convert a list of bits into a maskbit value', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/bits/value", summary='Convert a list of bits into a maskbit value', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def bits_to_value(self, bits: Union[List[int], None] = Query([], description='A list of integer bits', example=[2, 8])) -> dict: """ Convert a list of integer bits into a maskbit value""" print('bits', bits, type(bits)) return {'value': sum(1 << int(i) for i in bits)} - @router.get("/bits/labels", summary='Convert a list of bits into their labels', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/bits/labels", summary='Convert a list of bits into their labels', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def bits_to_labels(self, bits: Union[List[int], None] = Query([], description='A list of integer bits', example=[2, 8]), tab: Table = Depends(make_table)) -> dict: """ Convert a list of integer bits into their labels for a given schema """ @@ -148,7 +152,8 @@ async def bits_to_labels(self, bits: Union[List[int], None] = Query([], descript else: return {'labels': labels.tolist()} - @router.get("/labels/value", summary='Convert a list of labels into a maskbit value', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/labels/value", summary='Convert a list of labels into a maskbit value', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def labels_to_value(self, labels: Union[List[str], None] = Query([], description='A list of mask labels', example=['BADIFU', 'SCATFAIL']), tab: Table = Depends(make_table)) -> dict: """ Convert a list of mask labels into a maskbit value for a given schema """ @@ -162,7 +167,8 @@ async def labels_to_value(self, labels: Union[List[str], None] = Query([], descr else: return {'value': sum(1 << int(i) for i in bits)} - @router.get("/labels/bits", summary='Convert a list of labels into their bits', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/labels/bits", summary='Convert a list of labels into their bits', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def labels_to_bits(self, labels: Union[List[str], None] = Query([], description='A list of mask labels', example=['BADIFU', 'SCATFAIL']), tab: Table = Depends(make_table)) -> dict: """ Convert a list of mask labels into their respective bits for a given schema """ @@ -176,14 +182,16 @@ async def labels_to_bits(self, labels: Union[List[str], None] = Query([], descri else: return {'bits': bits.tolist()} - @router.get("/value/bits", summary='Decompose a maskbit value into a list of bits', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/value/bits", summary='Decompose a maskbit value into a list of bits', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def value_to_bits(self, value: int = Query(..., description='A maskbit value', example=260), tab: Table = Depends(make_table)) -> dict: """ Decompose a maskbit value into its list of bits for a given schema """ bits = tab['bit'] return {'bits': [int(i) for i in bits if value & 1 << int(i)]} - @router.get("/value/labels", summary='Decompose a maskbit value into a list of labels', response_model=MaskBitResponse, response_model_exclude_unset=True) + @router.get("/value/labels", summary='Decompose a maskbit value into a list of labels', + response_model=MaskBitResponse, response_model_exclude_unset=True) async def value_to_labels(self, value: int = Query(..., description='A maskbit value', example=260), tab: Table = Depends(make_table)) -> dict: """ Decompose a maskbit value into its list of labels for a given schema """ diff --git a/python/valis/routes/query.py b/python/valis/routes/query.py new file mode 100644 index 0000000..8388f37 --- /dev/null +++ b/python/valis/routes/query.py @@ -0,0 +1,57 @@ +# !/usr/bin/env python +# -*- coding: utf-8 -*- +# + +from enum import Enum +from typing import List, Union +from fastapi import APIRouter, Depends, Query +from fastapi_utils.cbv import cbv + +from valis.routes.base import Base +from valis.db.db import get_pw_db +from valis.db.models import SDSSidStackedBase +from valis.db.queries import cone_search + + +class SearchCoordUnits(str, Enum): + """ Units of coordinate search radius """ + degree: str = "degree" + arcmin: str = "arcmin" + arcsec: str = "arcsec" + + +router = APIRouter() + + +@cbv(router) +class QueryRoutes(Base): + """ API routes for performing queries against sdss5db """ + + # @router.get('/test_sqla', summary='Perform a cone search for SDSS targets with sdss_ids', + # response_model=List[SDSSidStackedBaseA]) + # async def test_search(self, + # ra=Query(..., description='right ascension in degrees', example=315.01417), + # dec=Query(..., description='declination in degrees', example=35.299), + # radius=Query(..., description='the search radius in degrees', example=0.01), + # db=Depends(get_sqla_db)): + # """ Example for writing a route with a sqlalchemy ORM """ + # from sdssdb.sqlalchemy.sdss5db import vizdb + + # return db.query(vizdb.SDSSidStacked).\ + # filter(vizdb.SDSSidStacked.cone_search(ra, dec, radius, ra_col='ra_sdss_id', dec_col='dec_sdss_id')).all() + + @router.get('/main', summary='Main query for Search UI') + async def main_search(self): + pass + + + @router.get('/cone', summary='Perform a cone search for SDSS targets with sdss_ids', + response_model=List[SDSSidStackedBase], dependencies=[Depends(get_pw_db)]) + async def cone_search(self, + ra: Union[float, str] = Query(..., description='Right Ascension in degrees or hmsdms', example=315.01417), + dec: Union[float, str] = Query(..., description='Declination in degrees or hmsdms', example=35.299), + radius: float = Query(..., description='Search radius in specified units', example=0.01), + units: SearchCoordUnits = Query('degree', description='Units of search radius', example='degree')): + """ Perform a cone search """ + return list(cone_search(ra, dec, radius, units=units)) + diff --git a/python/valis/settings.py b/python/valis/settings.py index 3ddb4e9..9e92dcd 100644 --- a/python/valis/settings.py +++ b/python/valis/settings.py @@ -32,9 +32,18 @@ class EnvEnum(str, Enum): class Settings(BaseSettings): valis_env: EnvEnum = EnvEnum.dev - valis_allow_origin: Union[str, List[AnyHttpUrl]] = Field([], env="VALIS_ALLOW_ORIGIN") - - @validator('valis_allow_origin') + allow_origin: Union[str, List[AnyHttpUrl]] = Field([], env="VALIS_ALLOW_ORIGIN") + db_server: str = 'pipelines' + db_remote: bool = False + db_port: int = 5432 + db_user: str = None + db_host: str = 'localhost' + db_pass: str = None + + class Config: + env_prefix = "valis_" + + @validator('allow_origin') def must_be_list(cls, v): if not isinstance(v, list): return v.split(',') if ',' in v else [v] diff --git a/tests/test_queries.py b/tests/test_queries.py new file mode 100644 index 0000000..316cc84 --- /dev/null +++ b/tests/test_queries.py @@ -0,0 +1,24 @@ +# encoding: utf-8 +# + +import pytest +from valis.db.queries import convert_coords + + +@pytest.mark.parametrize('ra, dec, exp', + [('315.01417', '35.299', (315.01417, 35.299)), + ('315.01417', '-35.299', (315.01417, -35.299)), + (315.01417, -35.299, (315.01417, -35.299)), + ('21h00m03.4008s', '+35d17m56.4s', (315.01417, 35.299)), + ('21:00:03.4008', '+35:17:56.4', (315.01417, 35.299)), + ('21 00 03.4008', '+35 17 56.4', (315.01417, 35.299)), + ], + ids=['dec1', 'dec2', 'dec3', 'hms1', 'hms2', 'hms3']) +def test_convert_coords(ra, dec, exp): + """ test we can convert coordinates correctly """ + coord = convert_coords(ra, dec) + assert coord == exp + + + + diff --git a/tests/test_targets.py b/tests/test_targets.py index 51def4e..742360c 100644 --- a/tests/test_targets.py +++ b/tests/test_targets.py @@ -12,14 +12,14 @@ def test_resolve_target_name(client): data = get_data(response) assert data['coordinate']['value'] == [230.50745896, 43.53232817] assert data['coordinate']['unit'] == 'deg' - assert data['name'] == "2MASX J15220182+4331560" + assert data['name'] in ("LEDA 2223006", "2MASX J15220182+433156") def test_resolve_target_coord(client): response = client.get("/target/resolve/coord?coord=230.50745896&coord=43.53232817") data = get_data(response) assert len(data) == 1 - assert data[0]['main_id'] == "2MASX J15220182+4331560" + assert data[0]['main_id'] in ("LEDA 2223006", "2MASX J15220182+433156") assert data[0]['ra'] == "15 22 01.7901" assert data[0]['dec'] == "+43 31 56.381" assert data[0]['distance_result']['value'] == 0