-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from sdss/dbtest
Adds initial connections to a database
- Loading branch information
Showing
12 changed files
with
443 additions
and
23 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# !/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
|
||
from contextvars import ContextVar | ||
|
||
import peewee | ||
from fastapi import Depends, HTTPException | ||
from sdssdb.peewee.sdss5db import database as pdb | ||
from sdssdb.sqlalchemy.sdss5db import database as sdb | ||
|
||
# To make Peewee async-compatible, we need to hack the peewee connection state | ||
# See FastAPI/Peewee docs at https://fastapi.tiangolo.com/how-to/sql-databases-peewee/ | ||
|
||
# create global context variables for the peewee connection state | ||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} | ||
db_state = ContextVar("db_state", default=db_state_default.copy()) | ||
|
||
|
||
async def reset_db_state(): | ||
""" Sub-depdency for get_db that resets the context connection state """ | ||
pdb._state._state.set(db_state_default.copy()) | ||
pdb._state.reset() | ||
|
||
|
||
class PeeweeConnectionState(peewee._ConnectionState): | ||
""" Custom Connection State for Peewee """ | ||
def __init__(self, **kwargs): | ||
super().__setattr__("_state", db_state) | ||
super().__init__(**kwargs) | ||
|
||
def __setattr__(self, name, value): | ||
self._state.get()[name] = value | ||
|
||
def __getattr__(self, name): | ||
return self._state.get()[name] | ||
|
||
|
||
# override the database connection state, after db connection | ||
pdb._state = PeeweeConnectionState() | ||
|
||
|
||
def connect_db(db, orm: str = 'peewee'): | ||
""" Connect to the peewee sdss5db database """ | ||
|
||
from valis.main import settings | ||
profset = db.set_profile(settings.db_server) | ||
if settings.db_remote and not profset: | ||
port = settings.db_port | ||
user = settings.db_user | ||
host = settings.db_host | ||
passwd = settings.db_pass | ||
db.connect_from_parameters(dbname='sdss5db', host=host, port=port, | ||
user=user, password=passwd) | ||
|
||
# raise error if we cannot connect | ||
if not db.connected: | ||
raise HTTPException(status_code=503, detail=f'Could not connect to database via sdssdb {orm}.') | ||
|
||
return db | ||
|
||
|
||
def get_pw_db(db_state=Depends(reset_db_state)): | ||
""" Dependency to connect a database with peewee """ | ||
|
||
# connect to the db, yield None since we don't need the db in peewee | ||
db = connect_db(pdb, orm='peewee') | ||
try: | ||
yield None | ||
finally: | ||
if db: | ||
db.close() | ||
|
||
|
||
def get_sqla_db(): | ||
""" Dependency to connect to a database with sqlalchemy """ | ||
|
||
# connect to the db, yield the db Session object for sql queries | ||
db = connect_db(sdb, orm='sqla') | ||
db = db.Session() | ||
try: | ||
yield db | ||
finally: | ||
if db: | ||
db.close() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# !/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
|
||
# all resuable Pydantic models of the ORMs go here | ||
|
||
import peewee | ||
from typing import Any, Optional | ||
from pydantic import BaseModel, Field | ||
from pydantic.utils import GetterDict | ||
|
||
|
||
class PeeweeGetterDict(GetterDict): | ||
""" Class to convert peewee.ModelSelect into a list """ | ||
def get(self, key: Any, default: Any = None): | ||
res = getattr(self._obj, key, default) | ||
if isinstance(res, peewee.ModelSelect): | ||
return list(res) | ||
return res | ||
|
||
|
||
class OrmBase(BaseModel): | ||
""" Base pydantic model for sqlalchemy ORMs """ | ||
|
||
class Config: | ||
orm_mode = True | ||
|
||
|
||
class PeeweeBase(OrmBase): | ||
""" Base pydantic model for peewee ORMs """ | ||
|
||
class Config: | ||
orm_mode = True | ||
getter_dict = PeeweeGetterDict | ||
|
||
|
||
# class SDSSidStackedBaseA(OrmBase): | ||
# """ Pydantic model for the SQLA vizdb.SDSSidStacked ORM """ | ||
|
||
# sdss_id: int = Field(..., description='the SDSS identifier') | ||
# ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') | ||
# dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') | ||
# catalogid21: Optional[int] = Field(description='the version 21 catalog id') | ||
# catalogid25: Optional[int] = Field(description='the version 25 catalog id') | ||
# catalogid31: Optional[int] = Field(description='the version 31 catalog id') | ||
|
||
|
||
class SDSSidStackedBase(PeeweeBase): | ||
""" Pydantic model for the Peewee vizdb.SDSSidStacked ORM """ | ||
|
||
sdss_id: int = Field(..., description='the SDSS identifier') | ||
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') | ||
dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') | ||
catalogid21: Optional[int] = Field(description='the version 21 catalog id') | ||
catalogid25: Optional[int] = Field(description='the version 25 catalog id') | ||
catalogid31: Optional[int] = Field(description='the version 31 catalog id') | ||
|
||
|
||
class SDSSidFlatBase(PeeweeBase): | ||
""" Pydantic model for the Peewee vizdb.SDSSidFlat ORM """ | ||
|
||
sdss_id: int = Field(..., description='the SDSS identifier') | ||
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid') | ||
dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid') | ||
catalogid: int = Field(..., descrption='the catalogid associated with the given sdss_id') | ||
version_id: int = Field(..., descrption='the version of the catalog for a given catalogid') | ||
n_associated: int = Field(..., description='The total number of sdss_ids associated with that catalogid.') | ||
ra_cat: float = Field(..., descrption='Right Ascension, in degrees, specific to the catalogid') | ||
dec_cat: float = Field(..., descrption='Declination, in degrees, specific to the catalogid') | ||
|
||
|
||
class SDSSidPipesBase(PeeweeBase): | ||
""" Pydantic model for the Peewee vizdb.SDSSidToPipes ORM """ | ||
|
||
sdss_id: int = Field(..., description='the SDSS identifier') | ||
in_boss: bool = Field(..., description='Flag if the sdss_id is in the BHM reductions') | ||
in_apogee: bool = Field(..., description='Flag if the sdss_id is in the Apogee/MWM reductions') | ||
in_astra: bool = Field(..., description='Flag if the sdss_id is in the Astra reductions') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# !/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
|
||
# all resuable queries go here | ||
|
||
from typing import Union | ||
import peewee | ||
import astropy.units as u | ||
from astropy.coordinates import SkyCoord | ||
from sdssdb.peewee.sdss5db import vizdb | ||
|
||
|
||
def convert_coords(ra: Union[str, float], dec: Union[str, float]) -> tuple: | ||
""" Convert sky coordinates to decimal degrees | ||
Convert the input RA, Dec sky coordinates into decimal | ||
degrees. Input format can either be decimal or hmsdms. | ||
Parameters | ||
---------- | ||
ra : str | ||
The Right Ascension | ||
dec : str | ||
The Declination | ||
Returns | ||
------- | ||
tuple | ||
the converted (RA, Dec) | ||
""" | ||
is_hms = set('hms: ') & set(str(ra)) | ||
if is_hms: | ||
ra = str(ra).replace(' ', ':') | ||
dec = str(dec).replace(' ', ':') | ||
unit = ('hourangle', 'degree') if is_hms else ('degree', 'degree') | ||
coord = SkyCoord(f'{ra} {dec}', unit=unit) | ||
ra = round(coord.ra.value, 5) | ||
dec = round(coord.dec.value, 5) | ||
return float(ra), float(dec) | ||
|
||
|
||
def cone_search(ra: Union[str, float], dec: Union[str, float], | ||
radius: float, units: str = 'degree') -> peewee.ModelSelect: | ||
""" Perform a cone search against the vizdb sdss_id_stacked table | ||
Perform a cone search using the peewee ORM for SDSS targets in the | ||
vizdb sdss_id_stacked table. We return the peewee ModelSelect | ||
directly here so it can be easily combined with other queries. | ||
In the route endpoint itself, remember to return wrap this in a list. | ||
Parameters | ||
---------- | ||
ra : Union[str, float] | ||
the Right Ascension coord | ||
dec : Union[str, float] | ||
the Declination coord | ||
radius : float | ||
the cone search radius | ||
units : str, optional | ||
the units of the search radius, by default 'degree' | ||
Returns | ||
------- | ||
peewee.ModelSelect | ||
the ORM query | ||
""" | ||
|
||
# convert ra, dec to decimal if in hms-dms | ||
ra, dec = convert_coords(ra, dec) | ||
|
||
# convert radial units to degrees | ||
radius *= u.Unit(units) | ||
radius = radius.to(u.degree).value | ||
|
||
return vizdb.SDSSidStacked.select().\ | ||
where(vizdb.SDSSidStacked.cone_search(ra, dec, radius, | ||
ra_col='ra_sdss_id', | ||
dec_col='dec_sdss_id')) |
Oops, something went wrong.