Skip to content

Commit

Permalink
Cleanup how database manager is instantiated
Browse files Browse the repository at this point in the history
  • Loading branch information
nothingface0 committed Jun 5, 2024
1 parent 9dabba0 commit f4e8397
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 66 deletions.
48 changes: 43 additions & 5 deletions db.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,43 @@ class DQM2MirrorDB:
def __str__(self):
return f"{self.__class__.__name__}: {self.db_uri}"

def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False):
def __init__(
self,
log: logging.Logger,
username: str = "postgres",
password: str = "postgres",
host: str = "postgres",
port: int = 5432,
db_name: str = "postgres",
server: bool = False,
):
"""
The server flag will determine if table creation will take place or not, upon
initialization.
"""
self.log = log
self.password: str = password
self.username: str = username
self.host: str = host
self.port: int = port
self.db_name: str = db_name

self.log: logging.Logger = log
self.log.info("\n\n DQM2MirrorDB ===== init ")
self.db_uri = db_uri
self.db_uri: str = self.format_db_uri(
host=self.host,
port=self.port,
username=self.username,
password=self.password,
db_name=self.db_name,
)

if not self.db_uri:
if self.host == ":memory:":
self.db_uri = ":memory:"

self.log.info(
f"Connecting to database {self.db_name} on {self.username}@{self.host}:{self.port}"
)

self.engine = sqlalchemy.create_engine(
url=self.db_uri,
poolclass=sqlalchemy.pool.QueuePool,
Expand All @@ -123,7 +148,7 @@ def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False
)
if not database_exists(self.engine.url):
raise DatabaseNotFoundError(
f"Database name was not found when connecting to '{self.db_uri}'"
f"Database {self.db_name} was not found on '{self.host}:{self.port}'"
)

self.Session = sessionmaker(bind=self.engine)
Expand All @@ -132,6 +157,19 @@ def __init__(self, log: logging.Logger, db_uri: str = None, server: bool = False
self.db_meta = sqlalchemy.MetaData(bind=self.engine)
self.db_meta.reflect()

@staticmethod
def format_db_uri(
username: str = "postgres",
password: str = "postgres",
host: str = "postgres",
port: int = 5432,
db_name="postgres",
) -> str:
"""
Helper function to format the DB URI for SQLAclhemy
"""
return f"postgresql://{username}:{password}@{host}:{port}/{db_name}"

def create_tables(self):
"""
Initialize the databases
Expand Down
60 changes: 25 additions & 35 deletions dqmsquare_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,31 @@
TZ = pytz.timezone(TIMEZONE)


def format_db_uri(
username: str = "postgres",
password: str = "postgres",
host: str = "postgres",
port: int = 5432,
db_name="postgres",
) -> str:
"""
Helper function to format the DB URI for SQLAclhemy
"""
return f"postgresql://{username}:{password}@{host}:{port}/{db_name}"


def load_cfg() -> dict:
"""
Prepare configuration, using .env file
"""

load_dotenv()
# No leading slash: cinder/dqmsquare
mount_path = os.path.join("cinder", "dqmsquare")

### default values === >
### default values
cfg = {}
cfg["VERSION"] = "1.3.1"
cfg["VERSION"] = "1.3.2"

cfg["ENV"] = os.environ.get("ENV", "development")

# How often to try to get CMSSW jobs info
# sec, int
cfg["GRABBER_SLEEP_TIME_INFO"] = os.environ.get("GRABBER_SLEEP_TIME_INFO", 5)
cfg["GRABBER_SLEEP_TIME_INFO"] = int(os.environ.get("GRABBER_SLEEP_TIME_INFO", 5))

# How often to ping the cluster machines for their status.
# Keep it above 30 secs.
# sec, int
cfg["GRABBER_SLEEP_TIME_STATUS"] = os.environ.get("GRABBER_SLEEP_TIME_STATUS", 30)
cfg["GRABBER_SLEEP_TIME_STATUS"] = int(
os.environ.get("GRABBER_SLEEP_TIME_STATUS", 30)
)

cfg["LOGGER_ROTATION_TIME"] = 24 # h, int
cfg["LOGGER_MAX_N_LOG_FILES"] = 10 # int
Expand All @@ -61,7 +51,7 @@ def load_cfg() -> dict:
cfg["FFF_PORT"] = "9215"

# Flask server config
cfg["SERVER_DEBUG"] = os.environ.get("SERVER_DEBUG", False)
cfg["SERVER_DEBUG"] = bool(os.environ.get("SERVER_DEBUG", False))
# MACHETE
if isinstance(cfg["SERVER_DEBUG"], str):
cfg["SERVER_DEBUG"] = True if cfg["SERVER_DEBUG"] == "True" else False
Expand Down Expand Up @@ -89,15 +79,15 @@ def load_cfg() -> dict:
"CMSWEB_FRONTEND_PROXY_URL",
# If value is not found in .env
(
"https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu"
"https://cmsweb.cern.ch/dqm/dqm-square-origin"
if cfg["ENV"] == "testbed"
else (
"https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu"
"https://cmsweb.cern.ch/dqm/dqm-square-origin"
if cfg["ENV"] == "production"
else (
"https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu"
"https://cmsweb.cern.ch/dqm/dqm-square-origin"
if cfg["ENV"] == "test4"
else "https://cmsweb-testbed.cern.ch/dqm/dqm-square-origin-rubu"
else "https://cmsweb.cern.ch/dqm/dqm-square-origin"
)
)
),
Expand Down Expand Up @@ -159,20 +149,20 @@ def load_cfg() -> dict:
if isinstance(cfg["GRABBER_DEBUG"], str):
cfg["GRABBER_DEBUG"] = True if cfg["GRABBER_DEBUG"] == "True" else False

cfg["DB_PLAYBACK_URI"] = format_db_uri(
username=os.environ.get("POSTGRES_USERNAME", "postgres"),
password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
host=os.environ.get("POSTGRES_HOST", "127.0.0.1"),
port=os.environ.get("POSTGRES_PORT", 5432),
db_name=os.environ.get("POSTGRES_PLAYBACK_DB_NAME", "postgres"),
)
cfg["DB_PRODUCTION_URI"] = format_db_uri(
username=os.environ.get("POSTGRES_USERNAME", "postgres"),
password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
host=os.environ.get("POSTGRES_HOST", "127.0.0.1"),
port=os.environ.get("POSTGRES_PORT", 5432),
db_name=os.environ.get("POSTGRES_PRODUCTION_DB_NAME", "postgres_production"),
cfg["DB_PLAYBACK_USERNAME"] = os.environ.get("POSTGRES_USERNAME", "postgres")
cfg["DB_PLAYBACK_PASSWORD"] = os.environ.get("POSTGRES_PASSWORD", "postgres")
cfg["DB_PLAYBACK_HOST"] = os.environ.get("POSTGRES_HOST", "127.0.0.1")
cfg["DB_PLAYBACK_PORT"] = os.environ.get("POSTGRES_PORT", 5432)
cfg["DB_PLAYBACK_NAME"] = os.environ.get("POSTGRES_PLAYBACK_DB_NAME", "postgres")

cfg["DB_PRODUCTION_USERNAME"] = os.environ.get("POSTGRES_USERNAME", "postgres")
cfg["DB_PRODUCTION_PASSWORD"] = os.environ.get("POSTGRES_PASSWORD", "postgres")
cfg["DB_PRODUCTION_HOST"] = os.environ.get("POSTGRES_HOST", "127.0.0.1")
cfg["DB_PRODUCTION_PORT"] = os.environ.get("POSTGRES_PORT", 5432)
cfg["DB_PRODUCTION_NAME"] = os.environ.get(
"POSTGRES_PRODUCTION_DB_NAME", "postgres_production"
)

cfg["TIMEZONE"] = TIMEZONE
return cfg

Expand Down
38 changes: 28 additions & 10 deletions grabber.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def get_cluster_status(db: DQM2MirrorDB, cluster: str = "playback"):
Function that queries the gateway playback machine periodically to get the status of the
production or playback cluster machines.
"""
logger.debug(f"Requesting {cluster} cluster status.")
url = urljoin(
cfg["CMSWEB_FRONTEND_PROXY_URL"] + "/",
"cr/exe?" + urlencode({"cluster": cluster, "what": "get_cluster_status"}),
Expand All @@ -163,6 +164,7 @@ def get_cluster_status(db: DQM2MirrorDB, cluster: str = "playback"):
raise Exception(
f"Failed to fetch {cluster} status. Got ({response.status_code}) {response.text}"
)
logger.debug(f"Got {cluster} cluster status.")

try:
response = response.json()
Expand All @@ -184,9 +186,9 @@ def get_latest_info_from_hosts(hosts: list[str], db: DQM2MirrorDB) -> None:


if __name__ == "__main__":
run_modes = ["playback", "production"]
playback_machines = cfg["FFF_PLAYBACK_MACHINES"]
production_machines = cfg["FFF_PRODUCTION_MACHINES"]
run_modes: list[str] = ["playback", "production"]
playback_machines: list[str] = cfg["FFF_PLAYBACK_MACHINES"]
production_machines: list[str] = cfg["FFF_PRODUCTION_MACHINES"]

if len(sys.argv) > 1 and sys.argv[1] == "playback":
set_log_handler(
Expand Down Expand Up @@ -224,25 +226,41 @@ def get_latest_info_from_hosts(hosts: list[str], db: DQM2MirrorDB) -> None:
logger.info(f"Configured logger for grabber, level={level}")

### global variables and auth cookies
cmsweb_proxy_url = cfg["CMSWEB_FRONTEND_PROXY_URL"]
cert_path = [cfg["SERVER_GRID_CERT_PATH"], cfg["SERVER_GRID_KEY_PATH"]]
cmsweb_proxy_url: str = cfg["CMSWEB_FRONTEND_PROXY_URL"]
cert_path: list[str] = [cfg["SERVER_GRID_CERT_PATH"], cfg["SERVER_GRID_KEY_PATH"]]

env_secret = os.environ.get("DQM_FFF_SECRET")
env_secret: str = os.environ.get("DQM_FFF_SECRET")
if env_secret:
fff_secret = env_secret
logger.debug("Found secret in environmental variables")
else:
logger.warning("No secret found in environmental variables")

# Trailing whitespace in secret leads to crashes, strip it
cookies = {str(cfg["FFF_SECRET_NAME"]): env_secret.strip()}
cookies: dict[str, str] = {str(cfg["FFF_SECRET_NAME"]): env_secret.strip()}

# DB CONNECTION
db_playback, db_production = None, None
db_playback: DQM2MirrorDB = None
db_production: DQM2MirrorDB = None

if "playback" in run_modes:
db_playback = DQM2MirrorDB(logger, cfg["DB_PLAYBACK_URI"])
db_playback = DQM2MirrorDB(
log=logger,
host=cfg.get("DB_PLAYBACK_HOST"),
port=cfg.get("DB_PLAYBACK_PORT"),
username=cfg.get("DB_PLAYBACK_USERNAME"),
password=cfg.get("DB_PLAYBACK_PASSWORD"),
db_name=cfg.get("DB_PLAYBACK_NAME"),
)
if "production" in run_modes:
db_production = DQM2MirrorDB(logger, cfg["DB_PRODUCTION_URI"])
db_production = DQM2MirrorDB(
log=logger,
host=cfg.get("DB_PRODUCTION_HOST"),
port=cfg.get("DB_PRODUCTION_PORT"),
username=cfg.get("DB_PRODUCTION_USERNAME"),
password=cfg.get("DB_PRODUCTION_PASSWORD"),
db_name=cfg.get("DB_PRODUCTION_NAME"),
)

logger.info("Starting loop for modes " + str(run_modes))

Expand Down
22 changes: 19 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
log.info("start_server() call ... ")


def create_app(cfg):
def create_app(cfg: dict):
app = Flask(
__name__, static_url_path=os.path.join("/", cfg["SERVER_URL_PREFIX"], "static")
)
Expand Down Expand Up @@ -56,8 +56,24 @@ def create_app(cfg):
).strip()
}

db_playback = DQM2MirrorDB(log, cfg["DB_PLAYBACK_URI"], server=True)
db_production = DQM2MirrorDB(log, cfg["DB_PRODUCTION_URI"], server=True)
db_playback = DQM2MirrorDB(
log=log,
host=cfg.get("DB_PLAYBACK_HOST"),
port=cfg.get("DB_PLAYBACK_PORT"),
username=cfg.get("DB_PLAYBACK_USERNAME"),
password=cfg.get("DB_PLAYBACK_PASSWORD"),
db_name=cfg.get("DB_PLAYBACK_NAME"),
server=True,
)
db_production = DQM2MirrorDB(
log=log,
host=cfg.get("DB_PRODUCTION_HOST"),
port=cfg.get("DB_PRODUCTION_PORT"),
username=cfg.get("DB_PRODUCTION_USERNAME"),
password=cfg.get("DB_PRODUCTION_PASSWORD"),
db_name=cfg.get("DB_PRODUCTION_NAME"),
server=True,
)
databases = {
"playback": db_playback,
"production": db_production,
Expand Down
13 changes: 8 additions & 5 deletions tests/test_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy import create_engine, text
from sqlalchemy_utils import create_database, database_exists, drop_database
from custom_logger import dummy_log
from dqmsquare_cfg import format_db_uri, TZ
from dqmsquare_cfg import TZ


def format_entry_to_db_entry(graph_entry: list, datetime_cols: list):
Expand All @@ -29,21 +29,24 @@ def format_entry_to_db_entry(graph_entry: list, datetime_cols: list):


@pytest.fixture
def testing_database() -> DQM2MirrorDB:
db_uri = format_db_uri(
def testing_database():
db_uri = DQM2MirrorDB.format_db_uri(
username=os.environ.get("POSTGRES_USERNAME", "postgres"),
password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
host=os.environ.get("POSTGRES_HOST", "127.0.0.1"),
port=os.environ.get("POSTGRES_PORT", 5432),
db_name="postgres_test",
)

engine = create_engine(db_uri)
if not database_exists(engine.url):
create_database(db_uri)
db = DQM2MirrorDB(
log=dummy_log(),
db_uri=db_uri,
username=os.environ.get("POSTGRES_USERNAME", "postgres"),
password=os.environ.get("POSTGRES_PASSWORD", "postgres"),
host=os.environ.get("POSTGRES_HOST", "127.0.0.1"),
port=os.environ.get("POSTGRES_PORT", 5432),
db_name="postgres_test",
server=False,
)

Expand Down
Loading

0 comments on commit f4e8397

Please sign in to comment.