Skip to content

Commit

Permalink
fixed issues around psycopg and sqlalchemy conn
Browse files Browse the repository at this point in the history
  • Loading branch information
wangpatrick57 committed Sep 2, 2024
1 parent e094409 commit 5ec9ff6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
12 changes: 6 additions & 6 deletions dbms/postgres/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Optional

import click
from sqlalchemy import Connection
import sqlalchemy

from benchmark.tpch.load_info import TpchLoadInfo
from dbms.load_info_base_class import LoadInfoBaseClass
Expand All @@ -36,9 +36,9 @@
DEFAULT_POSTGRES_DBNAME,
DEFAULT_POSTGRES_PORT,
SHARED_PRELOAD_LIBRARIES,
conn_execute,
create_conn,
create_sqlalchemy_conn,
sql_file_execute,
sqlalchemy_conn_execute,
)
from util.shell import subprocess_run

Expand Down Expand Up @@ -249,7 +249,7 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None:
def _load_benchmark_into_dbdata(
dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float
) -> None:
with create_conn(use_psycopg=False) as conn:
with create_sqlalchemy_conn() as conn:
if benchmark_name == "tpch":
load_info = TpchLoadInfo(dbgym_cfg, scale_factor)
else:
Expand All @@ -261,13 +261,13 @@ def _load_benchmark_into_dbdata(


def _load_into_dbdata(
dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass
dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, load_info: LoadInfoBaseClass
) -> None:
sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath())

# truncate all tables first before even loading a single one
for table, _ in load_info.get_tables_and_fpaths():
conn_execute(conn, f"TRUNCATE {table} CASCADE")
sqlalchemy_conn_execute(conn, f"TRUNCATE {table} CASCADE")
# then, load the tables
for table, table_fpath in load_info.get_tables_and_fpaths():
with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv:
Expand Down
45 changes: 25 additions & 20 deletions util/pg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
from typing import Any, List
from typing import Any, List, NewType, Union

import pglast
import psycopg
from sqlalchemy import Connection, Engine, create_engine, text
from sqlalchemy.engine import CursorResult
import sqlalchemy
from sqlalchemy import create_engine, text

from misc.utils import DBGymConfig, open_and_save

Expand All @@ -16,7 +16,9 @@
SHARED_PRELOAD_LIBRARIES = "boot,pg_hint_plan,pg_prewarm"


def conn_execute(conn: Connection, sql: str) -> CursorResult[Any]:
def sqlalchemy_conn_execute(
conn: sqlalchemy.Connection, sql: str
) -> sqlalchemy.engine.CursorResult[Any]:
return conn.execute(text(sql))


Expand All @@ -34,9 +36,11 @@ def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> list[str]:
return queries


def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) -> None:
def sql_file_execute(
dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, filepath: Path
) -> None:
for sql in sql_file_queries(dbgym_cfg, filepath):
conn_execute(conn, sql)
sqlalchemy_conn_execute(conn, sql)


# The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres
Expand All @@ -50,17 +54,18 @@ def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True) -
return connstr_prefix + "://" + connstr_suffix


def create_conn(
pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True
) -> Connection:
connstr = get_connstr(use_psycopg=use_psycopg, pgport=pgport)
if use_psycopg:
psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None)
engine = create_engine(connstr, creator=lambda: psycopg_conn)
return engine.connect()
else:
engine = create_engine(
connstr,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
return engine.connect()
def create_psycopg_conn(pgport: int = DEFAULT_POSTGRES_PORT) -> psycopg.Connection[Any]:
connstr = get_connstr(use_psycopg=True, pgport=pgport)
psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None)
return psycopg_conn


def create_sqlalchemy_conn(
pgport: int = DEFAULT_POSTGRES_PORT,
) -> sqlalchemy.Connection:
connstr = get_connstr(use_psycopg=True, pgport=pgport)
engine = create_engine(
connstr,
execution_options={"isolation_level": "AUTOCOMMIT"},
)
return engine.connect()

0 comments on commit 5ec9ff6

Please sign in to comment.