Skip to content

Commit

Permalink
Merge pull request #4 from sdss/dbtest
Browse files Browse the repository at this point in the history
Adds initial connections to a database
  • Loading branch information
havok2063 authored Oct 24, 2023
2 parents 35d28eb + 03bc0ab commit ececdeb
Show file tree
Hide file tree
Showing 12 changed files with 443 additions and 23 deletions.
67 changes: 66 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ httpx = "^0.24.0"
astroquery = "^0.4.6"
pandas = "^1.5.3"
SQLAlchemy = "^1.4.35"
sdssdb = "^0.8.0"


[tool.poetry.dev-dependencies]
Expand Down
Empty file added python/valis/db/__init__.py
Empty file.
87 changes: 87 additions & 0 deletions python/valis/db/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#

from contextvars import ContextVar

import peewee
from fastapi import Depends, HTTPException
from sdssdb.peewee.sdss5db import database as pdb
from sdssdb.sqlalchemy.sdss5db import database as sdb

# To make Peewee async-compatible, we need to hack the peewee connection state
# See FastAPI/Peewee docs at https://fastapi.tiangolo.com/how-to/sql-databases-peewee/

# create global context variables for the peewee connection state
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


async def reset_db_state():
""" Sub-depdency for get_db that resets the context connection state """
pdb._state._state.set(db_state_default.copy())
pdb._state.reset()


class PeeweeConnectionState(peewee._ConnectionState):
""" Custom Connection State for Peewee """
def __init__(self, **kwargs):
super().__setattr__("_state", db_state)
super().__init__(**kwargs)

def __setattr__(self, name, value):
self._state.get()[name] = value

def __getattr__(self, name):
return self._state.get()[name]


# override the database connection state, after db connection
pdb._state = PeeweeConnectionState()


def connect_db(db, orm: str = 'peewee'):
""" Connect to the peewee sdss5db database """

from valis.main import settings
profset = db.set_profile(settings.db_server)
if settings.db_remote and not profset:
port = settings.db_port
user = settings.db_user
host = settings.db_host
passwd = settings.db_pass
db.connect_from_parameters(dbname='sdss5db', host=host, port=port,
user=user, password=passwd)

# raise error if we cannot connect
if not db.connected:
raise HTTPException(status_code=503, detail=f'Could not connect to database via sdssdb {orm}.')

return db


def get_pw_db(db_state=Depends(reset_db_state)):
""" Dependency to connect a database with peewee """

# connect to the db, yield None since we don't need the db in peewee
db = connect_db(pdb, orm='peewee')
try:
yield None
finally:
if db:
db.close()


def get_sqla_db():
""" Dependency to connect to a database with sqlalchemy """

# connect to the db, yield the db Session object for sql queries
db = connect_db(sdb, orm='sqla')
db = db.Session()
try:
yield db
finally:
if db:
db.close()


79 changes: 79 additions & 0 deletions python/valis/db/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#

# all resuable Pydantic models of the ORMs go here

import peewee
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
""" Class to convert peewee.ModelSelect into a list """
def get(self, key: Any, default: Any = None):
res = getattr(self._obj, key, default)
if isinstance(res, peewee.ModelSelect):
return list(res)
return res


class OrmBase(BaseModel):
""" Base pydantic model for sqlalchemy ORMs """

class Config:
orm_mode = True


class PeeweeBase(OrmBase):
""" Base pydantic model for peewee ORMs """

class Config:
orm_mode = True
getter_dict = PeeweeGetterDict


# class SDSSidStackedBaseA(OrmBase):
# """ Pydantic model for the SQLA vizdb.SDSSidStacked ORM """

# sdss_id: int = Field(..., description='the SDSS identifier')
# ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
# dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid')
# catalogid21: Optional[int] = Field(description='the version 21 catalog id')
# catalogid25: Optional[int] = Field(description='the version 25 catalog id')
# catalogid31: Optional[int] = Field(description='the version 31 catalog id')


class SDSSidStackedBase(PeeweeBase):
""" Pydantic model for the Peewee vizdb.SDSSidStacked ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid')
catalogid21: Optional[int] = Field(description='the version 21 catalog id')
catalogid25: Optional[int] = Field(description='the version 25 catalog id')
catalogid31: Optional[int] = Field(description='the version 31 catalog id')


class SDSSidFlatBase(PeeweeBase):
""" Pydantic model for the Peewee vizdb.SDSSidFlat ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
ra_sdss_id: float = Field(..., description='Right Ascension of the most recent cross-match catalogid')
dec_sdss_id: float = Field(..., description='Declination of the most recent cross-match catalogid')
catalogid: int = Field(..., descrption='the catalogid associated with the given sdss_id')
version_id: int = Field(..., descrption='the version of the catalog for a given catalogid')
n_associated: int = Field(..., description='The total number of sdss_ids associated with that catalogid.')
ra_cat: float = Field(..., descrption='Right Ascension, in degrees, specific to the catalogid')
dec_cat: float = Field(..., descrption='Declination, in degrees, specific to the catalogid')


class SDSSidPipesBase(PeeweeBase):
""" Pydantic model for the Peewee vizdb.SDSSidToPipes ORM """

sdss_id: int = Field(..., description='the SDSS identifier')
in_boss: bool = Field(..., description='Flag if the sdss_id is in the BHM reductions')
in_apogee: bool = Field(..., description='Flag if the sdss_id is in the Apogee/MWM reductions')
in_astra: bool = Field(..., description='Flag if the sdss_id is in the Astra reductions')

80 changes: 80 additions & 0 deletions python/valis/db/queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#

# all resuable queries go here

from typing import Union
import peewee
import astropy.units as u
from astropy.coordinates import SkyCoord
from sdssdb.peewee.sdss5db import vizdb


def convert_coords(ra: Union[str, float], dec: Union[str, float]) -> tuple:
""" Convert sky coordinates to decimal degrees
Convert the input RA, Dec sky coordinates into decimal
degrees. Input format can either be decimal or hmsdms.
Parameters
----------
ra : str
The Right Ascension
dec : str
The Declination
Returns
-------
tuple
the converted (RA, Dec)
"""
is_hms = set('hms: ') & set(str(ra))
if is_hms:
ra = str(ra).replace(' ', ':')
dec = str(dec).replace(' ', ':')
unit = ('hourangle', 'degree') if is_hms else ('degree', 'degree')
coord = SkyCoord(f'{ra} {dec}', unit=unit)
ra = round(coord.ra.value, 5)
dec = round(coord.dec.value, 5)
return float(ra), float(dec)


def cone_search(ra: Union[str, float], dec: Union[str, float],
radius: float, units: str = 'degree') -> peewee.ModelSelect:
""" Perform a cone search against the vizdb sdss_id_stacked table
Perform a cone search using the peewee ORM for SDSS targets in the
vizdb sdss_id_stacked table. We return the peewee ModelSelect
directly here so it can be easily combined with other queries.
In the route endpoint itself, remember to return wrap this in a list.
Parameters
----------
ra : Union[str, float]
the Right Ascension coord
dec : Union[str, float]
the Declination coord
radius : float
the cone search radius
units : str, optional
the units of the search radius, by default 'degree'
Returns
-------
peewee.ModelSelect
the ORM query
"""

# convert ra, dec to decimal if in hms-dms
ra, dec = convert_coords(ra, dec)

# convert radial units to degrees
radius *= u.Unit(units)
radius = radius.to(u.degree).value

return vizdb.SDSSidStacked.select().\
where(vizdb.SDSSidStacked.cone_search(ra, dec, radius,
ra_col='ra_sdss_id',
dec_col='dec_sdss_id'))
Loading

0 comments on commit ececdeb

Please sign in to comment.