diff --git a/.github/workflows/tests_ci.yml b/.github/workflows/tests_ci.yml index 948b2a31..46295d8b 100644 --- a/.github/workflows/tests_ci.yml +++ b/.github/workflows/tests_ci.yml @@ -36,6 +36,10 @@ jobs: run: | ./scripts/check_format.sh + - name: Static type checking + run: | + mypy --config-file scripts/mypy.ini . + - name: Run unit tests run: | . "$HOME/.cargo/env" diff --git a/.gitignore b/.gitignore index 383aa46f..4d6abb6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ +.mypy_cache/ .conda/ .idea/ test_clean_scratchspace/ diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmark/cli.py b/benchmark/cli.py index 2edf5d7f..cd58d55e 100644 --- a/benchmark/cli.py +++ b/benchmark/cli.py @@ -6,7 +6,7 @@ @click.group(name="benchmark") @click.pass_obj -def benchmark_group(dbgym_cfg: DBGymConfig): +def benchmark_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("benchmark") diff --git a/benchmark/tpch/__init__.py b/benchmark/tpch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmark/tpch/cli.py b/benchmark/tpch/cli.py index 975fd769..270cb629 100644 --- a/benchmark/tpch/cli.py +++ b/benchmark/tpch/cli.py @@ -1,6 +1,5 @@ import logging -import os -import shutil +from pathlib import Path import click @@ -10,7 +9,6 @@ link_result, workload_name_fn, ) -from util.pg import * from util.shell import subprocess_run benchmark_tpch_logger = logging.getLogger("benchmark/tpch") @@ -19,7 +17,7 @@ @click.group(name="tpch") @click.pass_obj -def tpch_group(dbgym_cfg: DBGymConfig): +def tpch_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("tpch") @@ -28,7 +26,7 @@ def tpch_group(dbgym_cfg: DBGymConfig): @click.pass_obj # The reason generate data is separate from create dbdata is because generate-data is generic # to all DBMSs while create dbdata is specific to a single DBMS. 
-def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float): +def tpch_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None: _clone(dbgym_cfg) _generate_data(dbgym_cfg, scale_factor) @@ -59,7 +57,7 @@ def tpch_workload( seed_end: int, query_subset: str, scale_factor: float, -): +) -> None: assert ( seed_start <= seed_end ), f"seed_start ({seed_start}) must be <= seed_end ({seed_end})" @@ -72,7 +70,7 @@ def _get_queries_dname(seed: int, scale_factor: float) -> str: return f"queries_{seed}_sf{get_scale_factor_string(scale_factor)}" -def _clone(dbgym_cfg: DBGymConfig): +def _clone(dbgym_cfg: DBGymConfig) -> None: expected_symlink_dpath = ( dbgym_cfg.cur_symlinks_build_path(mkdir=True) / "tpch-kit.link" ) @@ -102,7 +100,7 @@ def _get_tpch_kit_dpath(dbgym_cfg: DBGymConfig) -> Path: def _generate_queries( dbgym_cfg: DBGymConfig, seed_start: int, seed_end: int, scale_factor: float -): +) -> None: tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg) data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True) benchmark_tpch_logger.info( @@ -132,7 +130,7 @@ def _generate_queries( ) -def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float): +def _generate_data(dbgym_cfg: DBGymConfig, scale_factor: float) -> None: tpch_kit_dpath = _get_tpch_kit_dpath(dbgym_cfg) data_path = dbgym_cfg.cur_symlinks_data_path(mkdir=True) expected_tables_symlink_dpath = ( @@ -162,7 +160,7 @@ def _generate_workload( seed_end: int, query_subset: str, scale_factor: float, -): +) -> None: symlink_data_dpath = dbgym_cfg.cur_symlinks_data_path(mkdir=True) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) expected_workload_symlink_dpath = symlink_data_dpath / (workload_name + ".link") @@ -177,6 +175,8 @@ def _generate_workload( queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 0] elif query_subset == "odd": queries = [f"{i}" for i in range(1, 22 + 1) if i % 2 == 1] + else: + assert False with open(real_dpath / "order.txt", "w") as f: for seed in range(seed_start, seed_end + 1): diff --git a/benchmark/tpch/load_info.py b/benchmark/tpch/load_info.py index 2c84ac2b..1076c1f7 100644 --- a/benchmark/tpch/load_info.py +++ b/benchmark/tpch/load_info.py @@ -1,3 +1,6 @@ +from pathlib import Path +from typing import Optional + from dbms.load_info_base_class import LoadInfoBaseClass from misc.utils import DBGymConfig, get_scale_factor_string @@ -55,11 +58,11 @@ def __init__(self, dbgym_cfg: DBGymConfig, scale_factor: float): table_fpath = tables_dpath / f"{table}.tbl" self._tables_and_fpaths.append((table, table_fpath)) - def get_schema_fpath(self): + def get_schema_fpath(self) -> Path: return self._schema_fpath - def get_tables_and_fpaths(self): + def get_tables_and_fpaths(self) -> list[tuple[str, Path]]: return self._tables_and_fpaths - def get_constraints_fpath(self): + def get_constraints_fpath(self) -> Optional[Path]: return self._constraints_fpath diff --git a/dbms/__init__.py b/dbms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbms/cli.py b/dbms/cli.py index b71bed18..990f096c 100644 --- a/dbms/cli.py +++ b/dbms/cli.py @@ -1,11 +1,12 @@ import click from dbms.postgres.cli import postgres_group +from misc.utils import DBGymConfig @click.group(name="dbms") @click.pass_obj -def dbms_group(dbgym_cfg): +def dbms_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("dbms") diff --git a/dbms/load_info_base_class.py b/dbms/load_info_base_class.py index a5aec24e..40df2590 100644 --- a/dbms/load_info_base_class.py +++ b/dbms/load_info_base_class.py @@ -1,3 +1,7 @@ 
+from pathlib import Path +from typing import Optional + + class LoadInfoBaseClass: """ A base class for providing info for DBMSs to load the data of a benchmark @@ -5,12 +9,12 @@ class LoadInfoBaseClass: copy the comments or type annotations or else they might become out of sync. """ - def get_schema_fpath(self) -> str: + def get_schema_fpath(self) -> Path: raise NotImplemented - def get_tables_and_fpaths(self) -> list[(str, str)]: + def get_tables_and_fpaths(self) -> list[tuple[str, Path]]: raise NotImplemented # If the subclassing benchmark does not have constraints, you can return None here - def get_constraints_fpath(self) -> str | None: + def get_constraints_fpath(self) -> Optional[Path]: raise NotImplemented diff --git a/dbms/postgres/__init__.py b/dbms/postgres/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dbms/postgres/cli.py b/dbms/postgres/cli.py index 140f7e7c..2d968a98 100644 --- a/dbms/postgres/cli.py +++ b/dbms/postgres/cli.py @@ -10,9 +10,10 @@ import shutil import subprocess from pathlib import Path +from typing import Optional import click -from sqlalchemy import Connection +import sqlalchemy from benchmark.tpch.load_info import TpchLoadInfo from dbms.load_info_base_class import LoadInfoBaseClass @@ -35,9 +36,9 @@ DEFAULT_POSTGRES_DBNAME, DEFAULT_POSTGRES_PORT, SHARED_PRELOAD_LIBRARIES, - conn_execute, - create_conn, + create_sqlalchemy_conn, sql_file_execute, + sqlalchemy_conn_execute, ) from util.shell import subprocess_run @@ -47,7 +48,7 @@ @click.group(name="postgres") @click.pass_obj -def postgres_group(dbgym_cfg: DBGymConfig): +def postgres_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("postgres") @@ -61,7 +62,7 @@ def postgres_group(dbgym_cfg: DBGymConfig): is_flag=True, help="Include this flag to rebuild Postgres even if it already exists.", ) -def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool): +def postgres_build(dbgym_cfg: DBGymConfig, rebuild: bool) -> None: _build_repo(dbgym_cfg, rebuild) @@ -94,14 +95,14 @@ def postgres_dbdata( dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float, - pgbin_path: Path, + pgbin_path: Optional[Path], intended_dbdata_hardware: str, - dbdata_parent_dpath: Path, -): + dbdata_parent_dpath: Optional[Path], +) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) @@ -138,7 +139,7 @@ def _get_repo_symlink_path(dbgym_cfg: DBGymConfig) -> Path: return dbgym_cfg.cur_symlinks_build_path("repo.link") -def _build_repo(dbgym_cfg: DBGymConfig, rebuild): +def _build_repo(dbgym_cfg: DBGymConfig, rebuild: bool) -> None: expected_repo_symlink_dpath = _get_repo_symlink_path(dbgym_cfg) if not rebuild and expected_repo_symlink_dpath.exists(): dbms_postgres_logger.info( @@ -209,7 +210,7 @@ def _create_dbdata( dbms_postgres_logger.info(f"Created dbdata in {dbdata_tgz_symlink_path}") -def _generic_dbdata_setup(dbgym_cfg: DBGymConfig): +def _generic_dbdata_setup(dbgym_cfg: DBGymConfig) -> None: # get necessary vars pgbin_real_dpath = _get_pgbin_symlink_path(dbgym_cfg).resolve() assert pgbin_real_dpath.exists() @@ -247,8 +248,8 @@ def _generic_dbdata_setup(dbgym_cfg: DBGymConfig): def _load_benchmark_into_dbdata( dbgym_cfg: DBGymConfig, benchmark_name: str, scale_factor: float -): - with 
create_conn(use_psycopg=False) as conn: +) -> None: + with create_sqlalchemy_conn() as conn: if benchmark_name == "tpch": load_info = TpchLoadInfo(dbgym_cfg, scale_factor) else: @@ -260,23 +261,27 @@ def _load_benchmark_into_dbdata( def _load_into_dbdata( - dbgym_cfg: DBGymConfig, conn: Connection, load_info: LoadInfoBaseClass -): + dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, load_info: LoadInfoBaseClass +) -> None: sql_file_execute(dbgym_cfg, conn, load_info.get_schema_fpath()) # truncate all tables first before even loading a single one for table, _ in load_info.get_tables_and_fpaths(): - conn_execute(conn, f"TRUNCATE {table} CASCADE") + sqlalchemy_conn_execute(conn, f"TRUNCATE {table} CASCADE") # then, load the tables for table, table_fpath in load_info.get_tables_and_fpaths(): with open_and_save(dbgym_cfg, table_fpath, "r") as table_csv: - with conn.connection.dbapi_connection.cursor() as cur: + assert conn.connection.dbapi_connection is not None + cur = conn.connection.dbapi_connection.cursor() + try: with cur.copy(f"COPY {table} FROM STDIN CSV DELIMITER '|'") as copy: while data := table_csv.read(8192): copy.write(data) + finally: + cur.close() constraints_fpath = load_info.get_constraints_fpath() - if constraints_fpath != None: + if constraints_fpath is not None: sql_file_execute(dbgym_cfg, conn, constraints_fpath) diff --git a/dependencies/requirements.txt b/dependencies/requirements.txt index ba32594c..2bf61344 100644 --- a/dependencies/requirements.txt +++ b/dependencies/requirements.txt @@ -1,6 +1,7 @@ absl-py==2.1.0 aiosignal==1.3.1 astunparse==1.6.3 +async-timeout==4.0.3 attrs==23.2.0 black==24.2.0 cachetools==5.3.2 @@ -25,7 +26,7 @@ google-auth-oauthlib==1.0.0 google-pasta==0.2.0 greenlet==3.0.3 grpcio==1.60.0 -gymnasium==0.28.1 +gymnasium==0.29.1 h5py==3.10.0 hyperopt==0.2.7 idna==3.6 @@ -44,6 +45,7 @@ MarkupSafe==2.1.4 ml-dtypes==0.2.0 mpmath==1.3.0 msgpack==1.0.7 +mypy==1.11.2 mypy-extensions==1.0.0 networkx==3.2.1 numpy==1.26.3 @@ -56,7 +58,7 @@ nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu11==11.7.99 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu11==8.5.0.96 -nvidia-cudnn-cu12==8.9.2.26 +nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu11==10.9.0.58 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu11==10.2.10.91 @@ -66,7 +68,7 @@ nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu11==11.7.4.91 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu11==2.14.3 -nvidia-nccl-cu12==2.19.3 +nvidia-nccl-cu12==2.20.5 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu11==11.7.91 nvidia-nvtx-cu12==12.1.105 @@ -74,6 +76,7 @@ oauthlib==3.2.2 opt-einsum==3.3.0 packaging==23.2 pandas==2.2.0 +pandas-stubs==2.2.2.240807 pathspec==0.12.1 pglast==6.2 platformdirs==4.2.0 @@ -92,6 +95,7 @@ pytz==2023.4 PyYAML==6.0.1 ray==2.9.3 record-keeper==0.9.32 +redis==5.0.3 referencing==0.33.0 requests==2.31.0 requests-oauthlib==1.3.1 @@ -112,9 +116,12 @@ tensorflow-io-gcs-filesystem==0.36.0 termcolor==2.4.0 threadpoolctl==3.2.0 tomli==2.0.1 -torch==2.0.0 +torch==2.4.0 tqdm==4.66.1 -triton==2.0.0 +triton==3.0.0 +types-python-dateutil==2.9.0.20240821 +types-pytz==2024.1.0.20240417 +types-PyYAML==6.0.12.20240808 typing_extensions==4.9.0 tzdata==2023.4 urllib3==2.2.0 @@ -122,4 +129,3 @@ virtualenv==20.25.0 Werkzeug==3.0.1 wrapt==1.14.1 zipp==3.17.0 -redis==5.0.3 diff --git a/manage/cli.py b/manage/cli.py index 3f3cba2e..af839570 100644 --- a/manage/cli.py +++ b/manage/cli.py @@ -3,81 +3,34 @@ import shutil from itertools import chain from pathlib import Path -from typing import List, Set import click 
-import yaml -from misc.utils import DBGymConfig, is_child_path, parent_dpath_of_path +from misc.utils import ( + DBGymConfig, + get_runs_path_from_workspace_path, + get_symlinks_path_from_workspace_path, + is_child_path, + parent_dpath_of_path, +) task_logger = logging.getLogger("task") task_logger.setLevel(logging.INFO) -@click.group(name="manage") -def manage_group(): - pass - - -@click.command(name="show") -@click.argument("keys", nargs=-1) -@click.pass_obj -def manage_show(dbgym_cfg, keys): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml - - # Traverse the YAML. - for key in keys: - config_yaml = config_yaml[key] - - # Pretty-print the requested YAML value. - output_str = None - if type(config_yaml) != dict: - output_str = config_yaml - else: - output_str = yaml.dump(config_yaml, default_flow_style=False) - if len(keys) > 0: - output_str = " " + output_str.replace("\n", "\n ") - output_str = output_str.rstrip() - print(output_str) - - task_logger.info(f"Read: {Path(config_path)}") - - -@click.command(name="write") -@click.argument("keys", nargs=-1) -@click.argument("value_type") -@click.argument("value") -@click.pass_obj -def manage_write(dbgym_cfg, keys, value_type, value): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml - - # Traverse the YAML. - root_yaml = config_yaml - for key in keys[:-1]: - config_yaml = config_yaml[key] - - # Modify the requested YAML value and write the YAML file. - assert type(config_yaml[keys[-1]]) != dict - config_yaml[keys[-1]] = getattr(__builtins__, value_type)(value) - new_yaml = yaml.dump(root_yaml, default_flow_style=False).rstrip() - Path(config_path).write_text(new_yaml) - - task_logger.info(f"Updated: {Path(config_path)}") - - -@click.command(name="standardize") -@click.pass_obj -def manage_standardize(dbgym_cfg): - config_path = dbgym_cfg.path - config_yaml = dbgym_cfg.yaml +# This is used in test_clean.py. It's defined here to avoid a circular import. +class MockDBGymConfig: + def __init__(self, scratchspace_path: Path): + self.dbgym_workspace_path = scratchspace_path + self.dbgym_symlinks_path = get_symlinks_path_from_workspace_path( + scratchspace_path + ) + self.dbgym_runs_path = get_runs_path_from_workspace_path(scratchspace_path) - # Write the YAML file. - new_yaml = yaml.dump(config_yaml, default_flow_style=False).rstrip() - Path(config_path).write_text(new_yaml) - task_logger.info(f"Updated: {Path(config_path)}") +@click.group(name="manage") +def manage_group() -> None: + pass @click.command("clean") @@ -88,13 +41,13 @@ def manage_standardize(dbgym_cfg): default="safe", help='The mode to clean the workspace (default="safe"). "aggressive" means "only keep run_*/ folders referenced by a file in symlinks/". "safe" means "in addition to that, recursively keep any run_*/ folders referenced by any symlinks in run_*/ folders we are keeping."', ) -def manage_clean(dbgym_cfg: DBGymConfig, mode: str): +def manage_clean(dbgym_cfg: DBGymConfig, mode: str) -> None: clean_workspace(dbgym_cfg, mode=mode, verbose=True) @click.command("count") @click.pass_obj -def manage_count(dbgym_cfg: DBGymConfig): +def manage_count(dbgym_cfg: DBGymConfig) -> None: num_files = _count_files_in_workspace(dbgym_cfg) print( f"The workspace ({dbgym_cfg.dbgym_workspace_path}) has {num_files} total files/dirs/symlinks." 
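The MockDBGymConfig moved into this file above only needs the three workspace paths that the cleaning code reads, which is why clean_workspace() and _count_files_in_workspace() (typed below) accept DBGymConfig | MockDBGymConfig. A minimal sketch of how the tests exercise it, using a made-up scratch path:

    from pathlib import Path

    from manage.cli import MockDBGymConfig, clean_workspace

    # Hypothetical throwaway workspace; the real tests build one under manage/tests/.
    cfg = MockDBGymConfig(Path("/tmp/dbgym_scratch"))
    clean_workspace(cfg, mode="safe", verbose=True)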
@@ -102,7 +55,7 @@ def manage_count(dbgym_cfg: DBGymConfig): def add_symlinks_in_dpath( - symlinks_stack: List[Path], root_dpath: Path, processed_symlinks: Set[Path] + symlinks_stack: list[Path], root_dpath: Path, processed_symlinks: set[Path] ) -> None: """ Will modify symlinks_stack and processed_symlinks. @@ -117,7 +70,7 @@ def add_symlinks_in_dpath( processed_symlinks.add(file_path) -def _count_files_in_workspace(dbgym_cfg: DBGymConfig) -> int: +def _count_files_in_workspace(dbgym_cfg: DBGymConfig | MockDBGymConfig) -> int: """ Counts the number of files (regular file or dir or symlink) in the workspace. """ @@ -136,7 +89,9 @@ def _count_files_in_workspace(dbgym_cfg: DBGymConfig) -> int: return total_count -def clean_workspace(dbgym_cfg: DBGymConfig, mode: str = "safe", verbose=False) -> None: +def clean_workspace( + dbgym_cfg: DBGymConfig | MockDBGymConfig, mode: str = "safe", verbose: bool = False +) -> None: """ Clean all [workspace]/task_runs/run_*/ directories that are not referenced by any "active symlinks". If mode is "aggressive", "active symlinks" means *only* the symlinks directly in [workspace]/symlinks/. @@ -144,9 +99,9 @@ def clean_workspace(dbgym_cfg: DBGymConfig, mode: str = "safe", verbose=False) - any symlinks referenced in task_runs/run_*/ directories we have already decided to keep. """ # This stack holds the symlinks that are left to be processed - symlink_fpaths_to_process = [] + symlink_fpaths_to_process: list[Path] = [] # This set holds the symlinks that have already been processed to avoid infinite loops - processed_symlinks = set() + processed_symlinks: set[Path] = set() # 1. Initialize paths to process if dbgym_cfg.dbgym_symlinks_path.exists(): @@ -237,8 +192,5 @@ def clean_workspace(dbgym_cfg: DBGymConfig, mode: str = "safe", verbose=False) - ) -manage_group.add_command(manage_show) -manage_group.add_command(manage_write) -manage_group.add_command(manage_standardize) manage_group.add_command(manage_clean) manage_group.add_command(manage_count) diff --git a/manage/tests/test_clean.py b/manage/tests/test_clean.py index 2ba24249..20beefbf 100644 --- a/manage/tests/test_clean.py +++ b/manage/tests/test_clean.py @@ -4,13 +4,10 @@ import shutil import unittest from pathlib import Path +from typing import Any, NewType, cast -from manage.cli import clean_workspace -from misc.utils import ( - get_runs_path_from_workspace_path, - get_symlinks_path_from_workspace_path, - path_exists_dont_follow_symlinks, -) +from manage.cli import MockDBGymConfig, clean_workspace +from misc.utils import path_exists_dont_follow_symlinks # This is here instead of on `if __name__ == "__main__"` because we often run individual tests, which # does not go through the `if __name__ == "__main__"` codepath. @@ -18,13 +15,7 @@ logging.basicConfig(level=logging.INFO) -class MockDBGymConfig: - def __init__(self, scratchspace_path: Path): - self.dbgym_workspace_path = scratchspace_path - self.dbgym_symlinks_path = get_symlinks_path_from_workspace_path( - scratchspace_path - ) - self.dbgym_runs_path = get_runs_path_from_workspace_path(scratchspace_path) +FilesystemStructure = NewType("FilesystemStructure", dict[str, Any]) class CleanTests(unittest.TestCase): @@ -33,17 +24,23 @@ class CleanTests(unittest.TestCase): losing important files. 
""" + scratchspace_path: Path = Path() + @staticmethod - def create_structure(root_path: Path, structure: dict) -> None: + def create_structure(root_path: Path, structure: FilesystemStructure) -> None: def create_structure_internal( - root_path: Path, cur_path: Path, structure: dict + root_path: Path, cur_path: Path, structure: FilesystemStructure ) -> None: for path, content in structure.items(): full_path: Path = cur_path / path if isinstance(content, dict): # Directory full_path.mkdir(parents=True, exist_ok=True) - create_structure_internal(root_path, full_path, content) + create_structure_internal( + root_path, + full_path, + FilesystemStructure(cast(dict[str, Any], content)), + ) elif isinstance(content, tuple) and content[0] == "file": assert len(content) == 1 full_path.touch() @@ -58,9 +55,9 @@ def create_structure_internal( create_structure_internal(root_path, root_path, structure) @staticmethod - def verify_structure(root_path: Path, structure: dict) -> bool: + def verify_structure(root_path: Path, structure: FilesystemStructure) -> bool: def verify_structure_internal( - root_path: Path, cur_path: Path, structure: dict + root_path: Path, cur_path: Path, structure: FilesystemStructure ) -> bool: # Check for the presence of each item specified in the structure for name, item in structure.items(): @@ -72,7 +69,11 @@ def verify_structure_internal( if not new_cur_path.is_dir(): logging.debug(f"expected {new_cur_path} to be a directory") return False - if not verify_structure_internal(root_path, new_cur_path, item): + if not verify_structure_internal( + root_path, + new_cur_path, + FilesystemStructure(cast(dict[str, Any], item)), + ): return False elif isinstance(item, tuple) and item[0] == "file": if not new_cur_path.is_file(): @@ -111,36 +112,41 @@ def verify_structure_internal( @staticmethod def make_workspace_structure( - symlinks_structure: dict, task_runs_structure: dict - ) -> dict: + symlinks_structure: FilesystemStructure, + task_runs_structure: FilesystemStructure, + ) -> FilesystemStructure: """ This function exists so that it's easier to refactor the tests in case we ever change how the workspace is organized. 
""" - return { - "symlinks": symlinks_structure, - "task_runs": task_runs_structure, - } + return FilesystemStructure( + { + "symlinks": symlinks_structure, + "task_runs": task_runs_structure, + } + ) @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: cls.scratchspace_path = Path.cwd() / "manage/tests/test_clean_scratchspace/" - def setUp(self): + def setUp(self) -> None: if self.scratchspace_path.exists(): shutil.rmtree(self.scratchspace_path) - def tearDown(self): + def tearDown(self) -> None: if self.scratchspace_path.exists(): shutil.rmtree(self.scratchspace_path) - def test_structure_helpers(self): - structure = { - "dir1": {"file1.txt": ("file",), "dir2": {"file2.txt": ("file",)}}, - "dir3": {"nested_link_to_dir1": ("symlink", "dir1")}, - "link_to_dir1": ("symlink", "dir1"), - "link_to_file2": ("symlink", "dir1/dir2/file2.txt"), - } + def test_structure_helpers(self) -> None: + structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",), "dir2": {"file2.txt": ("file",)}}, + "dir3": {"nested_link_to_dir1": ("symlink", "dir1")}, + "link_to_dir1": ("symlink", "dir1"), + "link_to_file2": ("symlink", "dir1/dir2/file2.txt"), + } + ) CleanTests.create_structure(self.scratchspace_path, structure) self.assertTrue(CleanTests.verify_structure(self.scratchspace_path, structure)) @@ -214,44 +220,46 @@ def test_structure_helpers(self): CleanTests.verify_structure(self.scratchspace_path, wrong_link_structure) ) - def test_nonexistent_workspace(self): + def test_nonexistent_workspace(self) -> None: clean_workspace(MockDBGymConfig(self.scratchspace_path)) - def test_no_symlinks_dir_and_no_task_runs_dir(self): - starting_structure = {} - ending_structure = {} + def test_no_symlinks_dir_and_no_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure({}) + ending_structure = FilesystemStructure({}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_symlinks_dir_and_yes_task_runs_dir(self): - starting_structure = {"task_runs": {"file1.txt": ("file",)}} - ending_structure = {"task_runs": {}} + def test_no_symlinks_dir_and_yes_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure( + {"task_runs": {"file1.txt": ("file",)}} + ) + ending_structure = FilesystemStructure({"task_runs": {}}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_yes_symlinks_dir_and_no_task_runs_dir(self): - starting_structure = {"symlinks": {}} - ending_structure = {"symlinks": {}} + def test_yes_symlinks_dir_and_no_task_runs_dir(self) -> None: + starting_structure = FilesystemStructure({"symlinks": {}}) + ending_structure = FilesystemStructure({"symlinks": {}}) CleanTests.create_structure(self.scratchspace_path, starting_structure) clean_workspace(MockDBGymConfig(self.scratchspace_path)) self.assertTrue( CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {} + def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({}) starting_structure = 
CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {} - ending_task_runs_structure = {} + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -262,14 +270,14 @@ def test_no_symlinks_in_dir_and_no_task_runs_in_dir(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_no_links_in_symlinks(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {"run_0": {}} + def test_no_links_in_symlinks(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({"run_0": {}}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {} - ending_task_runs_structure = {} + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -280,14 +288,20 @@ def test_no_links_in_symlinks(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_file_directly_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/file1.txt")} - starting_task_runs_structure = {"file1.txt": ("file",), "file2.txt": ("file",)} + def test_link_to_file_directly_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + {"file1.txt": ("file",), "file2.txt": ("file",)} + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/file1.txt")} - ending_task_runs_structure = {"file1.txt": ("file",)} + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure({"file1.txt": ("file",)}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -298,17 +312,25 @@ def test_link_to_file_directly_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_dir_directly_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": {"file1.txt": ("file",)}, - "dir2": {"file2.txt": ("file",)}, - } + def test_link_to_dir_directly_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",)}, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = {"dir1": {"file1.txt": ("file",)}} + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",)}} + ) 
ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -319,21 +341,25 @@ def test_link_to_dir_directly_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_file_in_dir_in_task_runs(self): - starting_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { - "dir1": {"file1.txt": ("file",)}, - "dir2": {"file2.txt": ("file",)}, - } + def test_link_to_file_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"file1.txt": ("file",)}, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = {"dir1": {"file1.txt": ("file",)}} + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",)}} + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -344,19 +370,27 @@ def test_link_to_file_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_dir_in_dir_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1/dir2")} - starting_task_runs_structure = { - "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, - "dir3": {"file3.txt": ("file",)}, - } + def test_link_to_dir_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/dir2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, + "dir3": {"file3.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1/dir2")} - ending_task_runs_structure = { - "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/dir2")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"dir2": {"file1.txt": ("file",)}, "file2.txt": ("file",)}, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -367,12 +401,16 @@ def test_link_to_dir_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_link_crashes(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/symlink2")} - starting_task_runs_structure = { - "symlink2": ("symlink", "task_runs/file1.txt"), - "file1.txt": ("file",), - } + def test_link_to_link_crashes(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "symlink2": ("symlink", "task_runs/file1.txt"), + "file1.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( 
starting_symlinks_structure, starting_task_runs_structure ) @@ -381,21 +419,29 @@ def test_link_to_link_crashes(self): with self.assertRaises(AssertionError): clean_workspace(MockDBGymConfig(self.scratchspace_path)) - def test_safe_mode_link_to_dir_with_link(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - "file2.txt": ("file",), - } + def test_safe_mode_link_to_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -406,31 +452,35 @@ def test_safe_mode_link_to_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_safe_mode_link_to_file_in_dir_with_link(self): - starting_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/file2.txt"), - }, - "file2.txt": ("file",), - "file3.txt": ("file",), - } + def test_safe_mode_link_to_file_in_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/file2.txt"), + }, + "file2.txt": ("file",), + "file3.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/file2.txt"), - }, - "file2.txt": ("file",), - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/file2.txt"), + }, + "file2.txt": ("file",), + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -441,25 +491,33 @@ def test_safe_mode_link_to_file_in_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, - 
"dir2": { - "file2.txt": ("file",), - }, - "file3.txt": ("file",), - } + def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, + "dir2": { + "file2.txt": ("file",), + }, + "file3.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, - "dir2": { - "file2.txt": ("file",), - }, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/dir2/file2.txt")}, + "dir2": { + "file2.txt": ("file",), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -470,20 +528,28 @@ def test_safe_mode_link_to_dir_with_link_to_file_in_dir_in_task_runs(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_aggressive_mode_link_to_dir_with_link(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, - "file1.txt": ("file",), - "file2.txt": ("file",), - } + def test_aggressive_mode_link_to_dir_with_link(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file1.txt")}, + "file1.txt": ("file",), + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": {"symlink2": ("symlink", None)}, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", None)}, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -494,14 +560,16 @@ def test_aggressive_mode_link_to_dir_with_link(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_link_to_link_to_file_gives_error(self): - starting_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/symlink2") - } - starting_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "task_runs/file2.txt")}, - "file2.txt": ("file",), - } + def test_link_to_link_to_file_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "task_runs/file2.txt")}, + "file2.txt": ("file",), + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -512,13 +580,15 @@ def test_link_to_link_to_file_gives_error(self): with self.assertRaises(AssertionError): 
clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") - def test_multi_link_loop_gives_error(self): - starting_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/symlink2") - } - starting_task_runs_structure = { - "dir1": {"symlink2": ("symlink", "symlinks/symlink1")}, - } + def test_multi_link_loop_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/symlink2")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": {"symlink2": ("symlink", "symlinks/symlink1")}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -529,9 +599,11 @@ def test_multi_link_loop_gives_error(self): with self.assertRaises(RuntimeError): clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") - def test_link_self_loop_gives_error(self): - starting_symlinks_structure = {"symlink1": ("symlink", "symlinks/symlink1")} - starting_task_runs_structure = dict() + def test_link_self_loop_gives_error(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "symlinks/symlink1")} + ) + starting_task_runs_structure = FilesystemStructure({}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) @@ -544,32 +616,40 @@ def test_link_self_loop_gives_error(self): def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs( self, - ): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir2/file2.txt"), - }, - "dir2": { - "file2.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + ) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir2/file2.txt"), + }, + "dir2": { + "file2.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir2/file2.txt"), - }, - "dir2": { - "file2.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir2/file2.txt"), + }, + "dir2": { + "file2.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -582,24 +662,32 @@ def test_dont_loop_infinitely_if_there_are_cycles_between_different_dirs_in_runs def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_itself( self, - ): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + ) -> 
None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -610,24 +698,32 @@ def test_dont_loop_infinitely_if_there_is_a_dir_in_runs_that_links_to_a_file_in_ CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/file1.txt"), - }, - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/file1.txt"), + }, + } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -638,22 +734,28 @@ def test_dont_loop_infinitely_if_there_is_loop_amongst_symlinks(self): CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_broken_symlink_has_no_effect(self): - starting_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "task_runs/dir1/non_existent_file.txt"), - }, - "dir2": {"file2.txt": ("file",)}, - } + def test_broken_symlink_has_no_effect(self) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "task_runs/dir1/non_existent_file.txt"), + }, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - ending_symlinks_structure = {"symlink1": ("symlink", "task_runs/dir1")} - ending_task_runs_structure = { 
- "dir1": {"file1.txt": ("file",), "symlink2": ("symlink", None)} - } + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1")} + ) + ending_task_runs_structure = FilesystemStructure( + {"dir1": {"file1.txt": ("file",), "symlink2": ("symlink", None)}} + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -667,35 +769,41 @@ def test_broken_symlink_has_no_effect(self): # The idea behind this test is that we shouldn't be following links outside of task_runs, even on safe mode def test_link_to_folder_outside_runs_that_contains_link_to_other_run_doesnt_save_other_run( self, - ): - starting_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - starting_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "external/dir3/file3.txt"), - }, - "dir2": {"file2.txt": ("file",)}, - } + ) -> None: + starting_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + starting_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "external/dir3/file3.txt"), + }, + "dir2": {"file2.txt": ("file",)}, + } + ) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - starting_structure["external"] = { - "dir3": { - "file3.txt": ("file",), - "symlink3": ("symlink", "task_runs/dir2/file2.txt"), + starting_structure["external"] = FilesystemStructure( + { + "dir3": { + "file3.txt": ("file",), + "symlink3": ("symlink", "task_runs/dir2/file2.txt"), + } } - } - ending_symlinks_structure = { - "symlink1": ("symlink", "task_runs/dir1/file1.txt") - } - ending_task_runs_structure = { - "dir1": { - "file1.txt": ("file",), - "symlink2": ("symlink", "external/dir3/file3.txt"), + ) + ending_symlinks_structure = FilesystemStructure( + {"symlink1": ("symlink", "task_runs/dir1/file1.txt")} + ) + ending_task_runs_structure = FilesystemStructure( + { + "dir1": { + "file1.txt": ("file",), + "symlink2": ("symlink", "external/dir3/file3.txt"), + } } - } + ) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) @@ -709,19 +817,19 @@ def test_link_to_folder_outside_runs_that_contains_link_to_other_run_doesnt_save CleanTests.verify_structure(self.scratchspace_path, ending_structure) ) - def test_outside_task_runs_doesnt_get_deleted(self): - starting_symlinks_structure = {} - starting_task_runs_structure = {"dir1": {}} + def test_outside_task_runs_doesnt_get_deleted(self) -> None: + starting_symlinks_structure = FilesystemStructure({}) + starting_task_runs_structure = FilesystemStructure({"dir1": {}}) starting_structure = CleanTests.make_workspace_structure( starting_symlinks_structure, starting_task_runs_structure ) - starting_structure["external"] = {"file1.txt": ("file",)} - ending_symlinks_structure = {} - ending_task_runs_structure = {} + starting_structure["external"] = FilesystemStructure({"file1.txt": ("file",)}) + ending_symlinks_structure = FilesystemStructure({}) + ending_task_runs_structure = FilesystemStructure({}) ending_structure = CleanTests.make_workspace_structure( ending_symlinks_structure, ending_task_runs_structure ) - ending_structure["external"] = {"file1.txt": ("file",)} + ending_structure["external"] = FilesystemStructure({"file1.txt": ("file",)}) CleanTests.create_structure(self.scratchspace_path, starting_structure) 
clean_workspace(MockDBGymConfig(self.scratchspace_path), mode="safe") diff --git a/misc/__init__.py b/misc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/misc/utils.py b/misc/utils.py index 4a78c352..dd110405 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -1,13 +1,11 @@ import os import shutil import subprocess -import sys from datetime import datetime from enum import Enum from pathlib import Path -from typing import Tuple +from typing import IO, Any, Callable, Optional, Tuple -import click import redis import yaml @@ -34,21 +32,20 @@ # Helper functions that both this file and other files use -def get_symlinks_path_from_workspace_path(workspace_path): +def get_symlinks_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "symlinks" -def get_tmp_path_from_workspace_path(workspace_path): +def get_tmp_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "tmp" -def get_runs_path_from_workspace_path(workspace_path): +def get_runs_path_from_workspace_path(workspace_path: Path) -> Path: return workspace_path / "task_runs" def get_scale_factor_string(scale_factor: float | str) -> str: - assert type(scale_factor) is float or type(scale_factor) is str - if scale_factor == SCALE_FACTOR_PLACEHOLDER: + if type(scale_factor) is str and scale_factor == SCALE_FACTOR_PLACEHOLDER: return scale_factor else: if float(int(scale_factor)) == scale_factor: @@ -57,46 +54,46 @@ def get_scale_factor_string(scale_factor: float | str) -> str: return str(scale_factor).replace(".", "point") -def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: +def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float | str) -> str: return f"{benchmark_name}_sf{get_scale_factor_string(scale_factor)}_pristine_dbdata.tgz" # Other parameters -BENCHMARK_NAME_PLACEHOLDER = "[benchmark_name]" -WORKLOAD_NAME_PLACEHOLDER = "[workload_name]" -SCALE_FACTOR_PLACEHOLDER = "[scale_factor]" +BENCHMARK_NAME_PLACEHOLDER: str = "[benchmark_name]" +WORKLOAD_NAME_PLACEHOLDER: str = "[workload_name]" +SCALE_FACTOR_PLACEHOLDER: str = "[scale_factor]" # Paths of config files in the codebase. These are always relative paths. # The reason these can be relative paths instead of functions taking in codebase_path as input is because relative paths are relative to the codebase root DEFAULT_HPO_SPACE_PATH = PROTOX_EMBEDDING_PATH / "default_hpo_space.json" DEFAULT_SYSKNOBS_PATH = PROTOX_AGENT_PATH / "default_sysknobs.yaml" DEFAULT_BOOT_CONFIG_FPATH = POSTGRES_PATH / "default_boot_config.yaml" -default_benchmark_config_path = ( +default_benchmark_config_path: Callable[[str], Path] = ( lambda benchmark_name: PROTOX_PATH / f"default_{benchmark_name}_benchmark_config.yaml" ) -default_benchbase_config_path = ( +default_benchbase_config_path: Callable[[str], Path] = ( lambda benchmark_name: PROTOX_PATH / f"default_{benchmark_name}_benchbase_config.xml" ) # Generally useful functions -workload_name_fn = ( +workload_name_fn: Callable[[float | str, int, int, str], str] = ( lambda scale_factor, seed_start, seed_end, query_subset: f"workload_sf{get_scale_factor_string(scale_factor)}_{seed_start}_{seed_end}_{query_subset}" ) # Standard names of files/directories. These can refer to either the actual file/directory or a link to the file/directory. # Since they can refer to either the actual or the link, they do not have ".link" in them. 
-traindata_fname = ( +traindata_fname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedding_traindata.parquet" ) -default_embedder_dname = ( +default_embedder_dname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_embedder" ) -default_hpoed_agent_params_fname = ( +default_hpoed_agent_params_fname: Callable[[str, str], str] = ( lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}_hpoed_agent_params.json" ) -default_tuning_steps_dname = ( +default_tuning_steps_dname: Callable[[str, str, bool], str] = ( lambda benchmark_name, workload_name, boot_enabled_during_tune: f"{benchmark_name}_{workload_name}{'_boot' if boot_enabled_during_tune else ''}_tuning_steps" ) @@ -113,7 +110,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: # folder called run_*/dbgym_agent_protox_tune/tuning_steps. However, replay itself generates an output.log file, which goes in # run_*/dbgym_agent_protox_tune/tuning_steps/. The bug was that my replay function was overwriting the output.log file of the # tuning run. By naming all symlinks "*.link", we avoid the possibility of subtle bugs like this happening. -default_traindata_path = ( +default_traindata_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -121,7 +118,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (traindata_fname(benchmark_name, workload_name) + ".link") ) -default_embedder_path = ( +default_embedder_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -129,7 +126,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (default_embedder_dname(benchmark_name, workload_name) + ".link") ) -default_hpoed_agent_params_path = ( +default_hpoed_agent_params_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -137,7 +134,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (default_hpoed_agent_params_fname(benchmark_name, workload_name) + ".link") ) -default_workload_path = ( +default_workload_path: Callable[[Path, str, str], Path] = ( lambda workspace_path, benchmark_name, workload_name: get_symlinks_path_from_workspace_path( workspace_path ) @@ -145,7 +142,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (workload_name + ".link") ) -default_pristine_dbdata_snapshot_path = ( +default_pristine_dbdata_snapshot_path: Callable[[Path, str, float | str], Path] = ( lambda workspace_path, benchmark_name, scale_factor: get_symlinks_path_from_workspace_path( workspace_path ) @@ -153,10 +150,10 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "data" / (get_dbdata_tgz_name(benchmark_name, scale_factor) + ".link") ) -default_dbdata_parent_dpath = lambda workspace_path: get_tmp_path_from_workspace_path( - workspace_path +default_dbdata_parent_dpath: Callable[[Path], Path] = ( + lambda workspace_path: get_tmp_path_from_workspace_path(workspace_path) ) -default_pgbin_path = ( +default_pgbin_path: Callable[[Path], Path] = ( lambda workspace_path: get_symlinks_path_from_workspace_path(workspace_path) / "dbgym_dbms_postgres" / "build" 
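The pattern in this file is to keep the existing module-level lambdas but give the variables explicit Callable[...] annotations, so mypy knows their parameter and return types at every call site. A tiny standalone sketch of that pattern; the name and path layout here are invented for illustration:

    from pathlib import Path
    from typing import Callable

    # Annotating the variable tells mypy the lambda takes a Path and returns a Path.
    default_log_path: Callable[[Path], Path] = (
        lambda workspace_path: workspace_path / "logs" / "latest.log"
    )

    log_path = default_log_path(Path("/tmp/workspace"))  # checked as Path -> Path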
@@ -166,7 +163,7 @@ def get_dbdata_tgz_name(benchmark_name: str, scale_factor: float) -> str: / "postgres" / "bin" ) -default_tuning_steps_dpath = ( +default_tuning_steps_dpath: Callable[[Path, str, str, bool], Path] = ( lambda workspace_path, benchmark_name, workload_name, boot_enabled_during_tune: get_symlinks_path_from_workspace_path( workspace_path ) @@ -201,7 +198,7 @@ def __init__(self, dbgym_config_path: Path): # Parse the YAML file. contents: str = dbgym_config_path.read_text() - yaml_config: dict = yaml.safe_load(contents) + yaml_config: dict[str, Any] = yaml.safe_load(contents) # Require dbgym_workspace_path to be absolute. # All future paths should be constructed from dbgym_workspace_path. @@ -211,8 +208,8 @@ def __init__(self, dbgym_config_path: Path): self.path: Path = dbgym_config_path self.cur_path_list: list[str] = ["dbgym"] - self.root_yaml: dict = yaml_config - self.cur_yaml: dict = self.root_yaml + self.root_yaml: dict[str, Any] = yaml_config + self.cur_yaml: dict[str, Any] = self.root_yaml # Set and create paths. self.dbgym_repo_path = Path(os.getcwd()) @@ -247,11 +244,11 @@ def __init__(self, dbgym_config_path: Path): # `append_group()` is used to mark the "codebase path" of an invocation of the CLI. The "codebase path" is # explained further in the documentation. - def append_group(self, name) -> None: + def append_group(self, name: str) -> None: self.cur_path_list.append(name) self.cur_yaml = self.cur_yaml.get(name, {}) - def cur_source_path(self, *dirs) -> Path: + def cur_source_path(self, *dirs: str) -> Path: cur_path = self.dbgym_repo_path assert self.cur_path_list[0] == "dbgym" for folder in self.cur_path_list[1:]: @@ -260,7 +257,7 @@ def cur_source_path(self, *dirs) -> Path: cur_path = cur_path / dir return cur_path - def cur_symlinks_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_path(self, *dirs: str, mkdir: bool = False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_symlinks_path / flattened_structure for dir in dirs: @@ -269,7 +266,7 @@ def cur_symlinks_path(self, *dirs, mkdir=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_task_runs_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_path(self, *dirs: str, mkdir: bool = False) -> Path: flattened_structure = "_".join(self.cur_path_list) cur_path = self.dbgym_this_run_path / flattened_structure for dir in dirs: @@ -278,27 +275,27 @@ def cur_task_runs_path(self, *dirs, mkdir=False) -> Path: cur_path.mkdir(parents=True, exist_ok=True) return cur_path - def cur_symlinks_bin_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_bin_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("bin", *dirs, mkdir=mkdir) - def cur_symlinks_build_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_build_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("build", *dirs, mkdir=mkdir) - def cur_symlinks_data_path(self, *dirs, mkdir=False) -> Path: + def cur_symlinks_data_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_symlinks_path("data", *dirs, mkdir=mkdir) - def cur_task_runs_build_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_build_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("build", *dirs, mkdir=mkdir) - def cur_task_runs_data_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_data_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("data", 
*dirs, mkdir=mkdir) - def cur_task_runs_artifacts_path(self, *dirs, mkdir=False) -> Path: + def cur_task_runs_artifacts_path(self, *dirs: str, mkdir: bool = False) -> Path: return self.cur_task_runs_path("artifacts", *dirs, mkdir=mkdir) def conv_inputpath_to_realabspath( - dbgym_cfg: DBGymConfig, inputpath: os.PathLike + dbgym_cfg: DBGymConfig, inputpath: os.PathLike[str] ) -> Path: """ Convert any user inputted path to a real, absolute path @@ -329,7 +326,7 @@ def conv_inputpath_to_realabspath( return realabspath -def is_base_git_dir(cwd) -> bool: +def is_base_git_dir(cwd: str) -> bool: """ Returns whether we are in the base directory of some git repository """ @@ -394,7 +391,7 @@ def basename_of_path(dpath: Path) -> str: # TODO(phw2): refactor to use Path -def is_child_path(child_path: os.PathLike, parent_dpath: os.PathLike) -> bool: +def is_child_path(child_path: os.PathLike[str], parent_dpath: os.PathLike[str]) -> bool: """ Checks whether child_path refers to a file/dir/link that is a child of the dir referred to by parent_dpath If the two paths are equal, this function returns FALSE @@ -408,7 +405,7 @@ def is_child_path(child_path: os.PathLike, parent_dpath: os.PathLike) -> bool: ) -def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode="r"): +def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode: str = "r") -> IO[Any]: """ Open a file and "save" it to [workspace]/task_runs/run_*/. It takes in a str | Path to match the interface of open(). @@ -448,7 +445,7 @@ def open_and_save(dbgym_cfg: DBGymConfig, open_fpath: Path, mode="r"): def extract_from_task_run_fordpath( dbgym_cfg: DBGymConfig, task_run_fordpath: Path -) -> Tuple[Path, str, Path, str]: +) -> tuple[Path, str, Path, str]: """ The task_runs/ folder is organized like task_runs/run_*/[codebase]/[org]/any/path/you/want. This function extracts the [codebase] and [org] components @@ -481,7 +478,7 @@ def extract_from_task_run_fordpath( # TODO(phw2): really look at the clean PR to see what it changed # TODO(phw2): after merging agent-train, refactor some code in agent-train to use save_file() instead of open_and_save() -def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> Path: +def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> None: """ If an external function takes in a file/directory as input, you will not be able to call open_and_save(). In these situations, just call save_file(). @@ -544,7 +541,9 @@ def save_file(dbgym_cfg: DBGymConfig, fpath: Path) -> Path: # TODO(phw2): refactor our manual symlinking in postgres/cli.py to use link_result() instead def link_result( - dbgym_cfg: DBGymConfig, result_fordpath: Path, custom_result_name: str | None = None + dbgym_cfg: DBGymConfig, + result_fordpath: Path, + custom_result_name: Optional[str] = None, ) -> Path: """ result_fordpath must be a "result", meaning it was generated inside dbgym_cfg.dbgym_this_run_path. 
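As an aside on the misc/utils.py typing changes: the sketch below (all names hypothetical, not from the repo) illustrates the two recurring patterns, giving a module-level lambda an explicit Callable type so mypy knows its parameter and return types, and the parent/child containment contract stated in the is_child_path docstring. It is a minimal illustration, not the repo's implementation.

import os
from pathlib import Path
from typing import Callable

# Hypothetical helper mirroring the Callable-annotated lambdas in misc/utils.py.
example_link_name: Callable[[str, str], str] = (
    lambda benchmark_name, workload_name: f"{benchmark_name}_{workload_name}.link"
)


def is_child_path_sketch(
    child_path: os.PathLike[str], parent_dpath: os.PathLike[str]
) -> bool:
    # Resolve both paths, then require strict containment (equal paths -> False),
    # matching the contract stated in the is_child_path docstring.
    child = Path(child_path).resolve()
    parent = Path(parent_dpath).resolve()
    return child != parent and parent in child.parents


if __name__ == "__main__":
    print(example_link_name("tpch", "tpch_sf1"))                   # tpch_tpch_sf1.link
    print(is_child_path_sketch(Path("/tmp/a/b"), Path("/tmp/a")))  # True
    print(is_child_path_sketch(Path("/tmp/a"), Path("/tmp/a")))    # False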
@@ -564,7 +563,7 @@ def link_result( assert is_child_path(result_fordpath, dbgym_cfg.dbgym_this_run_path) assert not os.path.islink(result_fordpath) - if custom_result_name != None: + if custom_result_name is not None: result_name = custom_result_name else: if os.path.isfile(result_fordpath): diff --git a/scripts/mypy.ini b/scripts/mypy.ini new file mode 100644 index 00000000..98ef8d68 --- /dev/null +++ b/scripts/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +strict = True +ignore_missing_imports = True diff --git a/scripts/read_parquet.py b/scripts/read_parquet.py index 161aec35..7158ce6a 100644 --- a/scripts/read_parquet.py +++ b/scripts/read_parquet.py @@ -1,9 +1,10 @@ import sys +from pathlib import Path import pandas as pd -def read_and_print_parquet(file_path): +def read_and_print_parquet(file_path: Path) -> None: # Read the Parquet file into a DataFrame df = pd.read_parquet(file_path) @@ -14,7 +15,7 @@ def read_and_print_parquet(file_path): if __name__ == "__main__": # Specify the path to the Parquet file - parquet_file_path = sys.argv[0] + parquet_file_path = Path(sys.argv[1]) # Call the function to read and print the Parquet file read_and_print_parquet(parquet_file_path) diff --git a/task.py b/task.py index 7871fdc4..37ac3a69 100644 --- a/task.py +++ b/task.py @@ -16,7 +16,7 @@ @click.group() @click.pass_context -def task(ctx): +def task(ctx: click.Context) -> None: """💩💩💩 CMU-DB Database Gym: github.com/cmu-db/dbgym 💩💩💩""" dbgym_config_path = Path(os.getenv("DBGYM_CONFIG_PATH", "dbgym_config.yaml")) ctx.obj = DBGymConfig(dbgym_config_path) diff --git a/tune/cli.py b/tune/cli.py index 7d1f98f1..5b5dec02 100644 --- a/tune/cli.py +++ b/tune/cli.py @@ -6,7 +6,7 @@ @click.group(name="tune") @click.pass_obj -def tune_group(dbgym_cfg: DBGymConfig): +def tune_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("tune") diff --git a/tune/protox/agent/__init__.py b/tune/protox/agent/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/agent/agent_env.py b/tune/protox/agent/agent_env.py index b5af657b..4a69c2ef 100644 --- a/tune/protox/agent/agent_env.py +++ b/tune/protox/agent/agent_env.py @@ -1,6 +1,6 @@ import copy import inspect -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import gymnasium as gym import numpy as np @@ -12,7 +12,7 @@ def __init__(self, env: gym.Env[Any, Any]): super().__init__(env) self.class_attributes = dict(inspect.getmembers(self.__class__)) - def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: observations, info = self.env.reset(**kwargs) self._check_val(event="reset", observations=observations) self._observations = observations @@ -20,7 +20,7 @@ def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: def step( self, actions: NDArray[np.float32] - ) -> Tuple[Any, float, bool, bool, dict[str, Any]]: + ) -> tuple[Any, float, bool, bool, dict[str, Any]]: self._actions = actions observations, rewards, term, trunc, infos = self.env.step(actions) @@ -50,7 +50,7 @@ def __getattr__(self, name: str) -> Any: return self.getattr_recursive(name) - def _get_all_attributes(self) -> Dict[str, Any]: + def _get_all_attributes(self) -> dict[str, Any]: """Get all (inherited) instance and class attributes :return: all_attributes @@ -97,7 +97,7 @@ def getattr_depth_check(self, name: str, already_found: bool) -> str: def check_array_value( self, name: str, value: NDArray[np.float32] - ) -> List[Tuple[str, str]]: + ) -> list[tuple[str, str]]: """ Check
for inf and NaN for a single numpy array. diff --git a/tune/protox/agent/base_class.py b/tune/protox/agent/base_class.py index 3f999335..e681cbb3 100644 --- a/tune/protox/agent/base_class.py +++ b/tune/protox/agent/base_class.py @@ -6,6 +6,7 @@ import numpy as np from numpy.typing import NDArray +from misc.utils import TuningMode from tune.protox.agent.agent_env import AgentEnv from tune.protox.agent.noise import ActionNoise from tune.protox.env.logger import Logger @@ -75,7 +76,9 @@ def _setup_learn( return total_timesteps @abstractmethod - def learn(self, env: AgentEnv, total_timesteps: int) -> None: + def learn( + self, env: AgentEnv, total_timesteps: int, tuning_mode: TuningMode + ) -> None: """ Return a trained model. diff --git a/tune/protox/agent/buffers.py b/tune/protox/agent/buffers.py index d2b7e351..d4de74d4 100644 --- a/tune/protox/agent/buffers.py +++ b/tune/protox/agent/buffers.py @@ -12,7 +12,7 @@ class ReplayBufferSamples(NamedTuple): next_observations: th.Tensor dones: th.Tensor rewards: th.Tensor - infos: List[dict[str, Any]] + infos: list[dict[str, Any]] class ReplayBuffer: @@ -68,7 +68,7 @@ def add( action: NDArray[np.float32], reward: float, done: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: # Reshape to handle multi-dim and discrete action spaces, see GH #970 #1392 action = action.reshape((self.action_dim)) diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index 4bd23ae5..7c84366f 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -1,7 +1,5 @@ import glob import json -import os -import shutil import socket import xml.etree.ElementTree as ET from pathlib import Path @@ -11,8 +9,12 @@ import numpy as np import torch from gymnasium.wrappers import FlattenObservation # type: ignore -from gymnasium.wrappers import NormalizeObservation, NormalizeReward +from gymnasium.wrappers import ( # type: ignore[attr-defined] + NormalizeObservation, + NormalizeReward, +) from torch import nn +from torch.optim import Adam # type: ignore[attr-defined] from misc.utils import ( DBGymConfig, @@ -62,7 +64,7 @@ def _parse_activation_fn(act_type: str) -> type[nn.Module]: raise ValueError(f"Unsupported activation type {act_type}") -def _get_signal(signal_folder: Union[str, Path]) -> Tuple[int, str]: +def _get_signal(signal_folder: Union[str, Path]) -> tuple[int, str]: MIN_PORT = 5434 MAX_PORT = 5500 @@ -85,7 +87,7 @@ def _get_signal(signal_folder: Union[str, Path]) -> Tuple[int, str]: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) continue - with open(f"{signal_folder}/{port}.signal", "w") as f: # type: IO[Any] + with open(f"{signal_folder}/{port}.signal", "w") as f: f.write(str(port)) f.close() @@ -124,8 +126,9 @@ def _modify_benchbase_config( def _gen_noise_scale( vae_config: dict[str, Any], hpo_params: dict[str, Any] -) -> Callable[[ProtoAction, torch.Tensor], ProtoAction]: - def f(p: ProtoAction, n: torch.Tensor) -> ProtoAction: +) -> Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]: + def f(p: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: + assert n is not None if hpo_params["scale_noise_perturb"]: return ProtoAction( torch.clamp( @@ -143,7 +146,7 @@ def _build_utilities( tuning_mode: TuningMode, pgport: int, hpo_params: dict[str, Any], -) -> Tuple[Logger, RewardUtility, PostgresConn, Workload]: +) -> tuple[Logger, RewardUtility, PostgresConn, Workload]: logger = Logger( dbgym_cfg, hpo_params["trace"], @@ -202,7 +205,7 @@ def _build_actions( hpo_params: dict[str, Any], 
workload: Workload, logger: Logger, -) -> Tuple[HolonSpace, LSC]: +) -> tuple[HolonSpace, LSC]: sysknobs = LatentKnobSpace( logger=logger, tables=hpo_params["benchmark_config"]["tables"], @@ -335,7 +338,7 @@ def _build_env( workload: Workload, reward_utility: RewardUtility, logger: Logger, -) -> Tuple[TargetResetWrapper, AgentEnv]: +) -> tuple[TargetResetWrapper, AgentEnv]: env = gym.make( "Postgres-v0", @@ -434,9 +437,7 @@ def _build_agent( policy_weight_adjustment=hpo_params["policy_weight_adjustment"], ) - actor_optimizer = torch.optim.Adam( - actor.parameters(), lr=hpo_params["learning_rate"] - ) + actor_optimizer = Adam(actor.parameters(), lr=hpo_params["learning_rate"]) critic = ContinuousCritic( observation_space=observation_space, @@ -462,7 +463,7 @@ def _build_agent( action_dim=critic_action_dim, ) - critic_optimizer = torch.optim.Adam( + critic_optimizer = Adam( critic.parameters(), lr=hpo_params["learning_rate"] * hpo_params["critic_lr_scale"], ) @@ -539,7 +540,7 @@ def build_trial( seed: int, hpo_params: dict[str, Any], ray_trial_id: Optional[str] = None, -) -> Tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]: +) -> tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]: # The massive trial builder. port, signal = _get_signal(hpo_params["pgconn_info"]["pgbin_path"]) diff --git a/tune/protox/agent/cli.py b/tune/protox/agent/cli.py index 98f7bb22..fcc85ee1 100644 --- a/tune/protox/agent/cli.py +++ b/tune/protox/agent/cli.py @@ -8,7 +8,7 @@ @click.group("agent") @click.pass_obj -def agent_group(dbgym_cfg: DBGymConfig): +def agent_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("agent") diff --git a/tune/protox/agent/hpo.py b/tune/protox/agent/hpo.py index 05ca46ef..b4acddb5 100644 --- a/tune/protox/agent/hpo.py +++ b/tune/protox/agent/hpo.py @@ -6,7 +6,7 @@ import time from datetime import datetime from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional, Type, Union import click import numpy as np @@ -54,26 +54,26 @@ class AgentHPOArgs: def __init__( self, - benchmark_name, - workload_name, - embedder_path, - benchmark_config_path, - benchbase_config_path, - sysknobs_path, - pristine_dbdata_snapshot_path, - dbdata_parent_dpath, - pgbin_path, - workload_path, - seed, - agent, - max_concurrent, - num_samples, - tune_duration_during_hpo, - workload_timeout, - query_timeout, - enable_boot_during_hpo, - boot_config_fpath_during_hpo, - build_space_good_for_boot, + benchmark_name: str, + workload_name: str, + embedder_path: Path, + benchmark_config_path: Path, + benchbase_config_path: Path, + sysknobs_path: Path, + pristine_dbdata_snapshot_path: Path, + dbdata_parent_dpath: Path, + pgbin_path: Path, + workload_path: Path, + seed: int, + agent: str, + max_concurrent: int, + num_samples: int, + tune_duration_during_hpo: float, + workload_timeout: float, + query_timeout: float, + enable_boot_during_hpo: bool, + boot_config_fpath_during_hpo: Path, + build_space_good_for_boot: bool, ): self.benchmark_name = benchmark_name self.workload_name = workload_name @@ -119,35 +119,38 @@ def __init__( ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @click.option( "--embedder-path", + type=Path, default=None, help=f"The path to the directory that contains an `embedder.pth` file with a trained encoder and decoder as well as a `config` file. 
The default is {default_embedder_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--benchbase-config-path", - default=None, type=Path, + default=None, help=f"The path to the .xml config file for BenchBase, used to run OLTP workloads. The default is {default_benchbase_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--sysknobs-path", + type=Path, default=DEFAULT_SYSKNOBS_PATH, help=f"The path to the file configuring the space of system knobs the tuner can tune.", ) @click.option( "--pristine-dbdata-snapshot-path", - default=None, type=Path, + default=None, help=f"The path to the .tgz snapshot of the dbdata directory to use as a starting point for tuning. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( @@ -158,57 +161,62 @@ def __init__( ) @click.option( "--dbdata-parent-dpath", - default=None, type=Path, + default=None, help=f"The path to the parent directory of the dbdata which will be actively tuned. The default is {default_dbdata_parent_dpath(WORKSPACE_PATH_PLACEHOLDER)}.", ) @click.option( "--pgbin-path", - default=None, type=Path, + default=None, help=f"The path to the bin containing Postgres executables. The default is {default_pgbin_path(WORKSPACE_PATH_PLACEHOLDER)}.", ) @click.option( "--workload-path", - default=None, type=Path, + default=None, help=f"The path to the directory that specifies the workload (such as its queries and order of execution). The default is {default_workload_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). The default is a random value.", ) @click.option( - "--agent", default="wolp", help=f"The RL algorithm to use for the tuning agent." + "--agent", + type=str, + default="wolp", + help=f"The RL algorithm to use for the tuning agent.", ) @click.option( "--max-concurrent", + type=int, default=1, help=f"The max # of concurrent agent models to train. Note that unlike in HPO, all will use the same hyperparameters. This just helps control for other sources of randomness.", ) @click.option( "--num-samples", + type=int, default=40, help=f"The # of times to specific hyperparameter configs to sample from the hyperparameter search space and train agent models with.", ) @click.option( "--tune-duration-during-hpo", - default=4, type=float, + default=4.0, help="The number of hours to run each hyperparamer config tuning trial for.", ) @click.option( "--workload-timeout", - default=DEFAULT_WORKLOAD_TIMEOUT, type=int, + default=DEFAULT_WORKLOAD_TIMEOUT, help="The timeout (in seconds) of a workload. We run the workload once per DBMS configuration. For OLAP workloads, certain configurations may be extremely suboptimal, so we need to time out the workload.", ) @click.option( "--query-timeout", - default=30, type=int, + default=30, help="The timeout (in seconds) of a query. 
See the help of --workload-timeout for the motivation of this.", ) @click.option( @@ -218,8 +226,8 @@ def __init__( ) @click.option( "--boot-config-fpath-during-hpo", - default=DEFAULT_BOOT_CONFIG_FPATH, type=Path, + default=DEFAULT_BOOT_CONFIG_FPATH, help="The path to the file configuring Boot when running HPO. When tuning, you may use a different Boot config.", ) # Building a space good for Boot is subtly different from whether we enable Boot during HPO. @@ -240,58 +248,58 @@ def __init__( help="Whether to avoid certain options that are known to not perform well when Boot is enabled. See the codebase for why this is subtly different from --enable-boot-during-hpo.", ) def hpo( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - embedder_path, - benchmark_config_path, - benchbase_config_path, - sysknobs_path, - pristine_dbdata_snapshot_path, - intended_dbdata_hardware, - dbdata_parent_dpath, - pgbin_path, - workload_path, - seed, - agent, - max_concurrent, - num_samples, - tune_duration_during_hpo, - workload_timeout, - query_timeout, + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + embedder_path: Optional[Path], + benchmark_config_path: Optional[Path], + benchbase_config_path: Optional[Path], + sysknobs_path: Path, + pristine_dbdata_snapshot_path: Optional[Path], + intended_dbdata_hardware: str, + dbdata_parent_dpath: Optional[Path], + pgbin_path: Optional[Path], + workload_path: Optional[Path], + seed: Optional[int], + agent: str, + max_concurrent: int, + num_samples: int, + tune_duration_during_hpo: float, + workload_timeout: int, + query_timeout: int, enable_boot_during_hpo: bool, boot_config_fpath_during_hpo: Path, build_space_good_for_boot: bool, -): +) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if embedder_path == None: + if embedder_path is None: embedder_path = default_embedder_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if benchbase_config_path == None: + if benchbase_config_path is None: benchbase_config_path = default_benchbase_config_path(benchmark_name) - if pristine_dbdata_snapshot_path == None: + if pristine_dbdata_snapshot_path is None: pristine_dbdata_snapshot_path = default_pristine_dbdata_snapshot_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, scale_factor ) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if workload_path == None: + if workload_path is None: workload_path = default_workload_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if seed == None: - seed = random.randint(0, 1e8) + if seed is None: + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths embedder_path = conv_inputpath_to_realabspath(dbgym_cfg, embedder_path) @@ -358,15 +366,15 @@ def build_space( benchmark_config: dict[str, Any], workload_path: Path, embedder_path: list[Path], - pgconn_info: dict[str, str], + pgconn_info: dict[str, Path], benchbase_config: dict[str, Any] = {}, - 
tune_duration_during_hpo: int = 30, + tune_duration_during_hpo: float = 30.0, seed: int = 0, enable_boot_during_hpo: bool = False, - boot_config_fpath_during_hpo: Path = None, + boot_config_fpath_during_hpo: Path = Path(), build_space_good_for_boot: bool = False, - workload_timeouts: list[int] = [600], - query_timeouts: list[int] = [30], + workload_timeouts: list[float] = [600.0], + query_timeouts: list[float] = [30.0], ) -> dict[str, Any]: return { @@ -543,7 +551,7 @@ def __init__( ), "If we're doing HPO, we will create multiple TuneTrial() objects. We thus need to differentiate them somehow." else: assert ( - ray_trial_id == None + ray_trial_id is None ), "If we're not doing HPO, we (currently) will create only one TuneTrial() object. For clarity, we set ray_trial_id to None since ray_trial_id should not be used in this case." self.ray_trial_id = ray_trial_id @@ -551,7 +559,7 @@ def setup(self, hpo_params: dict[str, Any]) -> None: # Attach mythril directory to the search path. sys.path.append(os.path.expanduser(self.dbgym_cfg.dbgym_repo_path)) - torch.set_default_dtype(torch.float32) # type: ignore + torch.set_default_dtype(torch.float32) # type: ignore[no-untyped-call] seed = ( hpo_params["seed"] if hpo_params["seed"] != -1 @@ -640,7 +648,7 @@ def step(self) -> dict[Any, Any]: def cleanup(self) -> None: self.logger.flush() - self.env.close() # type: ignore + self.env.close() # type: ignore[no-untyped-call] if Path(self.signal).exists(): os.remove(self.signal) @@ -650,7 +658,10 @@ def cleanup(self) -> None: # Using a function to create a class is Ray's recommended way of doing this (see # https://discuss.ray.io/t/using-static-variables-to-control-trainable-subclass-in-ray-tune/808/4) # If you don't create the class with a function, it doesn't work due to how Ray serializes classes -def create_tune_opt_class(dbgym_cfg_param): +global_dbgym_cfg: DBGymConfig + + +def create_tune_opt_class(dbgym_cfg_param: DBGymConfig) -> Type[Trainable]: global global_dbgym_cfg global_dbgym_cfg = dbgym_cfg_param @@ -697,20 +708,23 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: workload_timeouts = [hpo_args.workload_timeout] query_timeouts = [hpo_args.query_timeout] - benchbase_config = ( - { - "oltp_config": { - "oltp_num_terminals": hpo_args.oltp_num_terminals, - "oltp_duration": hpo_args.oltp_duration, - "oltp_sf": hpo_args.oltp_sf, - "oltp_warmup": hpo_args.oltp_warmup, - }, - "benchbase_path": hpo_args.benchbase_path, - "benchbase_config_path": hpo_args.benchbase_config_path, - } - if is_oltp - else {} - ) + assert not is_oltp + benchbase_config: dict[str, Any] = {} + # This is commented out because OLTP is currently not implemented. + # benchbase_config = ( + # { + # "oltp_config": { + # "oltp_num_terminals": hpo_args.oltp_num_terminals, + # "oltp_duration": hpo_args.oltp_duration, + # "oltp_sf": hpo_args.oltp_sf, + # "oltp_warmup": hpo_args.oltp_warmup, + # }, + # "benchbase_path": hpo_args.benchbase_path, + # "benchbase_config_path": hpo_args.benchbase_config_path, + # } + # if is_oltp + # else {} + # ) space = build_space( sysknobs, @@ -738,7 +752,7 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: ) # Scheduler. - scheduler = FIFOScheduler() # type: ignore + scheduler = FIFOScheduler() # type: ignore[no-untyped-call] # Search. 
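For context on the create_tune_opt_class() typing above: declaring a module-level global with a bare annotation and giving the factory a Type[...] return type is a general pattern. The sketch below is a simplified, Ray-free illustration with assumed names; it shows only the typing pattern, not the actual Trainable subclass.

from typing import Type


class Base:
    def run(self) -> str:
        raise NotImplementedError


# Annotated at module scope, assigned inside the factory.
global_setting: str


def create_class(setting: str) -> Type[Base]:
    global global_setting
    global_setting = setting

    class Derived(Base):
        def run(self) -> str:
            return f"running with {global_setting}"

    return Derived


if __name__ == "__main__":
    cls = create_class("demo-config")
    print(cls().run())  # running with demo-config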
search = BasicVariantGenerator(max_concurrent=hpo_args.max_concurrent) @@ -761,7 +775,7 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None: sync_config=SyncConfig(), verbose=2, log_to_file=True, - storage_path=dbgym_cfg.cur_task_runs_path("hpo_ray_results", mkdir=True), + storage_path=str(dbgym_cfg.cur_task_runs_path("hpo_ray_results", mkdir=True)), ) tuner = ray.tune.Tuner( diff --git a/tune/protox/agent/off_policy_algorithm.py b/tune/protox/agent/off_policy_algorithm.py index fd33004e..36567b29 100644 --- a/tune/protox/agent/off_policy_algorithm.py +++ b/tune/protox/agent/off_policy_algorithm.py @@ -43,7 +43,7 @@ def __init__( replay_buffer: ReplayBuffer, learning_starts: int = 100, batch_size: int = 256, - train_freq: Tuple[int, str] = (1, "step"), + train_freq: tuple[int, str] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, seed: Optional[int] = None, @@ -62,7 +62,7 @@ def __init__( # Save train freq parameter, will be converted later to TrainFreq object self.train_freq = self._convert_train_freq(train_freq) - def _convert_train_freq(self, train_freq: Tuple[int, str]) -> TrainFreq: + def _convert_train_freq(self, train_freq: tuple[int, str]) -> TrainFreq: """ Convert `train_freq` parameter (int or tuple) to a TrainFreq object. @@ -91,7 +91,7 @@ def _store_transition( new_obs: NDArray[np.float32], reward: float, dones: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: """ Store transition in the replay buffer. @@ -135,7 +135,7 @@ def _sample_action( self, learning_starts: int, action_noise: Optional[ActionNoise] = None, - ) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: + ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: raise NotImplementedError() def collect_rollouts( diff --git a/tune/protox/agent/policies.py b/tune/protox/agent/policies.py index 85fedb8a..3464539d 100644 --- a/tune/protox/agent/policies.py +++ b/tune/protox/agent/policies.py @@ -83,7 +83,7 @@ def __init__( self, observation_space: spaces.Space[Any], action_space: spaces.Space[Any], - net_arch: List[int], + net_arch: list[int], features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, weight_init: Optional[str] = None, @@ -150,7 +150,7 @@ def __init__( self, observation_space: spaces.Space[Any], action_space: spaces.Space[Any], - net_arch: List[int], + net_arch: list[int], features_dim: int, activation_fn: Type[nn.Module] = nn.ReLU, weight_init: Optional[str] = None, @@ -178,7 +178,7 @@ def __init__( self.add_module(f"qf{idx}", q_net) self.q_networks.append(q_net) - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + def forward(self, obs: th.Tensor, actions: th.Tensor) -> tuple[th.Tensor, ...]: with th.set_grad_enabled(True): features = self.extract_features(obs) qvalue_input = th.cat([features, actions], dim=1) diff --git a/tune/protox/agent/replay.py b/tune/protox/agent/replay.py index 6c59ba5c..90a4b7ce 100644 --- a/tune/protox/agent/replay.py +++ b/tune/protox/agent/replay.py @@ -9,7 +9,9 @@ import json import logging import pickle +from datetime import datetime from pathlib import Path +from typing import Any, Optional, Set, cast import click import pandas as pd @@ -28,8 +30,9 @@ from tune.protox.agent.build_trial import build_trial from tune.protox.env.pg_env import PostgresEnv from tune.protox.env.space.holon_space import HolonSpace +from tune.protox.env.space.primitive.index import IndexAction from tune.protox.env.space.utils import fetch_server_indexes, fetch_server_knobs -from 
tune.protox.env.types import HolonAction +from tune.protox.env.types import ActionsInfo, HolonAction from tune.protox.env.workload import Workload REPLAY_DATA_FNAME = "replay_data.csv" @@ -38,11 +41,13 @@ class ReplayArgs: def __init__( self, - workload_timeout_during_replay: bool, + # If it's None, it'll get set later on inside replay_tuning_run(). + workload_timeout_during_replay: Optional[float], replay_all_variations: bool, simulated: bool, - cutoff: float, - blocklist: list, + # If it's None, it'll get set later on inside replay_tuning_run(). + cutoff: Optional[float], + blocklist: list[str], ): self.workload_timeout_during_replay = workload_timeout_during_replay self.replay_all_variations = replay_all_variations @@ -73,6 +78,7 @@ def __init__( ) @click.option( "--scale-factor", + type=float, default=1.0, help="The scale factor used when generating the data of the benchmark.", ) @@ -83,14 +89,14 @@ def __init__( ) @click.option( "--tuning-steps-dpath", - default=None, type=Path, + default=None, help="The path to the `tuning_steps` directory to be replayed.", ) @click.option( "--workload-timeout-during-replay", + type=float, default=None, - type=int, # You can make it use the workload timeout used during tuning if you want. # I just made it use the workload timeout from HPO because I don't currently persist the tuning HPO params. help="The timeout (in seconds) of a workload when replaying. By default, it will be equal to the workload timeout used during HPO.", @@ -107,14 +113,14 @@ def __init__( ) @click.option( "--cutoff", - default=None, type=float, + default=None, help='Only evaluate configs up to cutoff hours. None means "evaluate all configs".', ) @click.option( "--blocklist", + type=list[str], default=[], - type=list, help="Ignore running queries in the blocklist.", ) def replay( @@ -125,17 +131,17 @@ def replay( query_subset: str, scale_factor: float, boot_enabled_during_tune: bool, - tuning_steps_dpath: Path, - workload_timeout_during_replay: bool, + tuning_steps_dpath: Optional[Path], + workload_timeout_during_replay: Optional[float], replay_all_variations: bool, simulated: bool, - cutoff: float, - blocklist: list, + cutoff: Optional[float], + blocklist: list[str], ) -> None: # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if tuning_steps_dpath == None: + if tuning_steps_dpath is None: tuning_steps_dpath = default_tuning_steps_dpath( dbgym_cfg.dbgym_workspace_path, benchmark_name, @@ -161,7 +167,7 @@ def replay( def replay_tuning_run( dbgym_cfg: DBGymConfig, tuning_steps_dpath: Path, replay_args: ReplayArgs -): +) -> None: """ Replay a single tuning run (as in one tuning_steps/ folder). """ @@ -174,7 +180,7 @@ def _is_tuning_step_line(line: str) -> bool: hpo_params = json.load(f) # Set defaults that depend on hpo_params - if replay_args.workload_timeout_during_replay == None: + if replay_args.workload_timeout_during_replay is None: replay_args.workload_timeout_during_replay = hpo_params["workload_timeout"][ str(TuningMode.HPO) ] @@ -190,6 +196,7 @@ def _is_tuning_step_line(line: str) -> bool: # This finds all the [time] folders in tuning_steps/ (except "baseline" since we ignore that in `_is_tuning_step_line()`), # so you could just do `ls tuning_steps/` if you wanted to. 
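A small sketch (hypothetical names) of the pattern the retyped ReplayArgs relies on: fields that may be omitted are annotated Optional, and later code compares them with `is None`, the check mypy understands for narrowing Optional types, before filling in defaults.

from typing import Optional


class ReplayArgsSketch:
    def __init__(self, workload_timeout: Optional[float], cutoff: Optional[float]):
        # None means "fill in from hpo_params later".
        self.workload_timeout = workload_timeout
        self.cutoff = cutoff


def resolve_defaults(args: ReplayArgsSketch, hpo_params: dict[str, float]) -> None:
    # `is None` (not `== None`) is the idiomatic comparison for Optional values.
    if args.workload_timeout is None:
        args.workload_timeout = hpo_params["workload_timeout"]
    if args.cutoff is None:
        args.cutoff = hpo_params["cutoff"]


if __name__ == "__main__":
    args = ReplayArgsSketch(workload_timeout=None, cutoff=1.5)
    resolve_defaults(args, {"workload_timeout": 600.0, "cutoff": 24.0})
    print(args.workload_timeout, args.cutoff)  # 600.0 1.5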
folders = [] + start_time: Optional[datetime] = None start_found = False output_log_fpath = tuning_steps_dpath / "output.log" with open_and_save(dbgym_cfg, output_log_fpath, "r") as f: @@ -209,8 +216,9 @@ def _is_tuning_step_line(line: str) -> bool: time_since_start = parse( line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] ) + assert type(start_time) is datetime if ( - replay_args.cutoff == None + replay_args.cutoff is None or (time_since_start - start_time).total_seconds() < replay_args.cutoff * 3600 ): @@ -225,8 +233,8 @@ def _is_tuning_step_line(line: str) -> bool: _, _, agent_env, _, _ = build_trial( dbgym_cfg, TuningMode.REPLAY, hpo_params["seed"], hpo_params ) - pg_env: PostgresEnv = agent_env.unwrapped - action_space: HolonSpace = pg_env.action_space + pg_env: PostgresEnv = cast(PostgresEnv, agent_env.unwrapped) + action_space: HolonSpace = cast(HolonSpace, pg_env.action_space) # Reset things. if not replay_args.simulated: @@ -241,7 +249,9 @@ def _is_tuning_step_line(line: str) -> bool: num_lines += 1 # A convenience wrapper around execute_workload() which fills in the arguments properly and processes the return values. - def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: + def _execute_workload_wrapper( + actions_info: ActionsInfo, + ) -> tuple[int, int, bool, float]: logging.info( f"\n\nfetch_server_knobs(): {fetch_server_knobs(pg_env.pg_conn.conn(), action_space.get_knob_space().tables, action_space.get_knob_space().knobs, pg_env.workload.queries)}\n\n" ) @@ -267,6 +277,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: # will not have had a chance to run at all. Based on the behavior of `_mutilate_action_with_metrics()`, we select # an arbitrary variation fo the queries that have not executed at all. best_observed_holon_action = actions_info["best_observed_holon_action"] + assert best_observed_holon_action is not None actions = [best_observed_holon_action] variation_names = ["BestObserved"] @@ -299,8 +310,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: current_step = 0 start_found = False start_time = None - maximal_repo = None - existing_index_acts = [] + existing_index_acts: set[IndexAction] = set() for line in f: # Keep going until we've found the start. @@ -316,19 +326,10 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: continue elif _is_tuning_step_line(line): - if _is_tuning_step_line(line): - repo = eval(line.split("Running ")[-1])[-1] - time_since_start = parse( - line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] - ) - elif "Found new maximal state with" in line: - repo = eval(maximal_repo.split("Running ")[-1])[-1] - time_since_start = parse( - maximal_repo.split("DEBUG:")[-1] - .split(" Running")[0] - .split("[")[0] - ) - maximal_repo = None + repo = eval(line.split("Running ")[-1])[-1] + time_since_start = parse( + line.split("DEBUG:")[-1].split(" Running")[0].split("[")[0] + ) # Get the original runtime as well as whether any individual queries and/or the full workload timed out. 
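The cast() calls introduced for pg_env and action_space above follow a general pattern; here is a minimal, dependency-free sketch (hypothetical classes) of when cast() is appropriate: the concrete runtime type is known to the author, the declared type is broader, and cast() informs mypy without adding any runtime check.

from typing import cast


class Animal:
    pass


class Dog(Animal):
    def bark(self) -> str:
        return "woof"


def get_animal() -> Animal:
    return Dog()  # declared broadly, concretely a Dog


if __name__ == "__main__":
    dog = cast(Dog, get_animal())  # no runtime effect; purely for the type checker
    print(dog.bark())              # woof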
run_raw_csv_fpath = tuning_steps_dpath / repo / "run.raw.csv" @@ -367,7 +368,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: with open_and_save( dbgym_cfg, tuning_steps_dpath / repo / "action.pkl", "rb" ) as f: - actions_info = pickle.load(f) + actions_info: ActionsInfo = pickle.load(f) all_holon_action_variations = actions_info[ "all_holon_action_variations" ] @@ -451,6 +452,7 @@ def _execute_workload_wrapper(actions_info: list["HolonAction"]) -> list[float]: ) # Perform some validity checks and then add this tuning step's data to `run_data``. + assert isinstance(start_time, datetime) this_step_run_data = { "step": current_step, "time_since_start": (time_since_start - start_time).total_seconds(), diff --git a/tune/protox/agent/torch_layers.py b/tune/protox/agent/torch_layers.py index 941478c4..1be91a1e 100644 --- a/tune/protox/agent/torch_layers.py +++ b/tune/protox/agent/torch_layers.py @@ -33,14 +33,14 @@ def init_layer( def create_mlp( input_dim: int, output_dim: int, - net_arch: List[int], + net_arch: list[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False, with_bias: bool = True, weight_init: Optional[str] = None, bias_zero: bool = False, final_layer_adjust: float = 1.0, -) -> List[nn.Module]: +) -> list[nn.Module]: """ Create a multi layer perceptron (MLP), which is a collection of fully-connected layers each followed by an activation function. diff --git a/tune/protox/agent/tune.py b/tune/protox/agent/tune.py index 2ec6045b..c9a3467c 100644 --- a/tune/protox/agent/tune.py +++ b/tune/protox/agent/tune.py @@ -89,7 +89,7 @@ def tune( """IMPORTANT: The "tune" here is the one in "tune a DBMS". This is *different* from the "tune" in ray.tune.TuneConfig, which means to "tune hyperparameters".""" # Set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if hpoed_agent_params_path == None: + if hpoed_agent_params_path is None: hpoed_agent_params_path = default_hpoed_agent_params_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) @@ -120,7 +120,7 @@ def tune( ) # Set defaults that depend on hpo_params - if tune_duration_during_tune == None: + if tune_duration_during_tune is None: tune_duration_during_tune = hpo_params["tune_duration"][str(TuningMode.HPO)] # Set the hpo_params that are allowed to differ between HPO, tuning, and replay. 
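Related to the isinstance assertions added around start_time above, a short sketch (hypothetical names) of how an assert both documents an invariant and narrows an Optional value so that the arithmetic after it type-checks:

from datetime import datetime, timedelta
from typing import Optional


def seconds_since(start_time: Optional[datetime], now: datetime) -> float:
    # start_time is Optional because it is only discovered partway through a log;
    # the assert narrows Optional[datetime] to datetime for mypy.
    assert isinstance(start_time, datetime)
    return (now - start_time).total_seconds()


if __name__ == "__main__":
    t0 = datetime(2024, 1, 1, 12, 0, 0)
    print(seconds_since(t0, t0 + timedelta(minutes=5)))  # 300.0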
diff --git a/tune/protox/agent/wolp/__init__.py b/tune/protox/agent/wolp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/agent/wolp/policies.py b/tune/protox/agent/wolp/policies.py index c4294f0d..4882cbd3 100644 --- a/tune/protox/agent/wolp/policies.py +++ b/tune/protox/agent/wolp/policies.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from gymnasium import spaces from numpy.typing import NDArray -from torch.optim import Optimizer +from torch.optim import Optimizer # type: ignore[attr-defined] from tune.protox.agent.buffers import ReplayBufferSamples from tune.protox.agent.noise import ActionNoise @@ -98,7 +98,7 @@ def discriminate( embed_actions: th.Tensor, actions_dim: th.Tensor, env_actions: list[HolonAction], - ) -> Tuple[list[HolonAction], th.Tensor]: + ) -> tuple[list[HolonAction], th.Tensor]: states_tile = states.repeat_interleave(actions_dim, dim=0) if use_target: next_q_values = th.cat( @@ -140,7 +140,7 @@ def wolp_act( action_noise: Optional[Union[ActionNoise, th.Tensor]] = None, neighbor_parameters: NeighborParameters = DEFAULT_NEIGHBOR_PARAMETERS, random_act: bool = False, - ) -> Tuple[list[HolonAction], th.Tensor]: + ) -> tuple[list[HolonAction], th.Tensor]: # Get the tensor representation. start_time = time.time() if not isinstance(states, th.Tensor): @@ -244,7 +244,9 @@ def train_critic( self.critic_optimizer.zero_grad() assert not th.isnan(critic_loss).any() critic_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore + th.nn.utils.clip_grad_norm_( + list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True + ) self.critic.check_grad() self.critic_optimizer.step() return critic_loss @@ -282,7 +284,9 @@ def train_actor(self, replay_data: ReplayBufferSamples) -> Any: self.actor_optimizer.zero_grad() assert not th.isnan(actor_loss).any() actor_loss.backward() # type: ignore - th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore + th.nn.utils.clip_grad_norm_( + list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True + ) self.actor.check_grad() self.actor_optimizer.step() return actor_loss diff --git a/tune/protox/agent/wolp/wolp.py b/tune/protox/agent/wolp/wolp.py index ba519258..6b4f5c8e 100644 --- a/tune/protox/agent/wolp/wolp.py +++ b/tune/protox/agent/wolp/wolp.py @@ -47,12 +47,12 @@ def __init__( replay_buffer: ReplayBuffer, learning_starts: int = 100, batch_size: int = 100, - train_freq: Tuple[int, str] = (1, "episode"), + train_freq: tuple[int, str] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, target_action_noise: Optional[ActionNoise] = None, seed: Optional[int] = None, - neighbor_parameters: Dict[str, Any] = {}, + neighbor_parameters: dict[str, Any] = {}, ray_trial_id: Optional[str] = None, ): super().__init__( @@ -77,7 +77,7 @@ def _store_transition( new_obs: NDArray[np.float32], reward: float, dones: bool, - infos: Dict[str, Any], + infos: dict[str, Any], ) -> None: """ Store transition in the replay buffer. @@ -124,7 +124,7 @@ def _sample_action( self, learning_starts: int, action_noise: Optional[ActionNoise] = None, - ) -> Tuple[NDArray[np.float32], NDArray[np.float32]]: + ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: """ Sample an action according to the exploration policy. 
This is either done by sampling the probability distribution of the policy, diff --git a/tune/protox/cli.py b/tune/protox/cli.py index 15160f89..4c09f383 100644 --- a/tune/protox/cli.py +++ b/tune/protox/cli.py @@ -7,7 +7,7 @@ @click.group(name="protox") @click.pass_obj -def protox_group(dbgym_cfg: DBGymConfig): +def protox_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("protox") diff --git a/tune/protox/embedding/__init__.py b/tune/protox/embedding/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/embedding/analyze.py b/tune/protox/embedding/analyze.py index d895e9ca..8a1cc44d 100644 --- a/tune/protox/embedding/analyze.py +++ b/tune/protox/embedding/analyze.py @@ -7,6 +7,7 @@ import shutil import time from pathlib import Path +from typing import Any, Optional import numpy as np import torch @@ -28,20 +29,21 @@ from tune.protox.embedding.trainer import StratifiedRandomSampler from tune.protox.embedding.vae import VAELoss, gen_vae_collate from tune.protox.env.space.latent_space.latent_index_space import LatentIndexSpace +from tune.protox.env.types import ProtoAction, TableAttrAccessSetsMap from tune.protox.env.workload import Workload STATS_FNAME = "stats.txt" RANGES_FNAME = "ranges.txt" -def compute_num_parts(num_samples: int): +def compute_num_parts(num_samples: int) -> int: # TODO(phw2): in the future, implement running different parts in parallel, set OMP_NUM_THREADS accordingly, and investigate the effect of having more parts # TODO(phw2): if having more parts is effective, figure out a good way to specify num_parts (can it be determined automatically or should it be a CLI arg?) # TODO(phw2): does anything bad happen if num_parts doesn't evenly divide num_samples? return 1 -def redist_trained_models(dbgym_cfg: DBGymConfig, num_parts: int): +def redist_trained_models(dbgym_cfg: DBGymConfig, num_parts: int) -> None: """ Redistribute all embeddings_*/ folders inside the run_*/ folder into num_parts subfolders """ @@ -64,7 +66,7 @@ def analyze_all_embeddings_parts( num_parts: int, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Analyze all part*/ dirs _in parallel_ """ @@ -83,7 +85,7 @@ def _analyze_embeddings_part( part_i: int, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Analyze (meaning create both stats.txt and ranges.txt) all the embedding models in the part[part_i]/ dir """ @@ -107,7 +109,7 @@ def _create_stats_for_part( part_dpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Creates a stats.txt file inside each embeddings_*/models/epoch*/ dir inside this part*/ dir TODO(wz2): what does stats.txt contain? @@ -124,9 +126,7 @@ def _create_stats_for_part( ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - models = itertools.chain(*[part_dpath.rglob("config")]) - models = [m for m in models] - print(f"models={models}") + models = [m for m in itertools.chain(*[part_dpath.rglob("config")])] for model_config in tqdm.tqdm(models): if ((Path(model_config).parent) / "FAILED").exists(): print("Detected failure in: ", model_config) @@ -192,7 +192,7 @@ def _create_stats_for_part( vae_loss = VAELoss(config["loss_fn"], max_attrs, max_cat_features) # Construct the accumulator. 
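The accumulator initialized just below is one of several places where the diff annotates an empty literal; mypy reports "Need type annotation" when a variable is bound to a bare `{}` or `[]` and the element types cannot be inferred. A minimal sketch with hypothetical names:

from typing import Any

# Without the annotations, mypy cannot infer the key/value types of the empty
# literals and asks for an explicit annotation.
accumulated_stats: dict[str, list[Any]] = {}
class_counts: dict[str, int] = {}

accumulated_stats.setdefault("recon_0", []).append(0.25)
class_counts["idx_a"] = class_counts.get("idx_a", 0) + 1

if __name__ == "__main__":
    print(accumulated_stats)  # {'recon_0': [0.25]}
    print(class_counts)       # {'idx_a': 1}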
- accumulated_stats = {} + accumulated_stats: dict[str, list[Any]] = {} for class_idx in class_mapping: accumulated_stats[f"recon_{class_idx}"] = [] @@ -320,7 +320,7 @@ def _create_ranges_for_part( part_dpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Create the ranges.txt for all models in part_dpath TODO(wz2): what does ranges.txt contain? @@ -341,7 +341,7 @@ def _create_ranges_for_embedder( embedder_fpath: Path, generic_args: EmbeddingTrainGenericArgs, analyze_args: EmbeddingAnalyzeArgs, -): +) -> None: """ Create the ranges.txt file corresponding to a specific part*/embeddings_*/models/epoch*/embedder_*.pth file """ @@ -376,9 +376,9 @@ def _create_ranges_for_embedder( lambda x: torch.nn.Sigmoid()(x) * config["output_scale"] ) - def index_noise_scale(x, n): + def index_noise_scale(x: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: assert n is None - return torch.clamp(x, 0.0, config["output_scale"]) + return ProtoAction(torch.clamp(x, 0.0, config["output_scale"])) max_attrs, max_cat_features = fetch_vae_parameters_from_workload( workload, len(tables) @@ -392,10 +392,10 @@ def index_noise_scale(x, n): tables=tables, max_num_columns=max_num_columns, max_indexable_attributes=workload.max_indexable(), - seed=np.random.randint(1, 1e10), + seed=np.random.randint(1, int(1e10)), rel_metadata=copy.deepcopy(modified_attrs), attributes_overwrite=copy.deepcopy(modified_attrs), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), vae=vae, index_space_aux_type=False, index_space_aux_include=False, @@ -418,7 +418,7 @@ def index_noise_scale(x, n): ranges_fpath = epoch_dpath / RANGES_FNAME with open(ranges_fpath, "w") as f: for _ in tqdm.tqdm(range(num_segments), total=num_segments, leave=False): - classes = {} + classes: dict[str, int] = {} with torch.no_grad(): points = ( torch.rand(analyze_args.num_points_to_sample, config["latent_dim"]) @@ -444,18 +444,18 @@ def index_noise_scale(x, n): if idx_class not in classes: classes[idx_class] = 0 classes[idx_class] += 1 - classes = sorted( + sorted_classes = sorted( [(k, v) for k, v in classes.items()], key=lambda x: x[1], reverse=True ) if analyze_args.num_classes_to_keep != 0: - classes = classes[: analyze_args.num_classes_to_keep] + sorted_classes = sorted_classes[: analyze_args.num_classes_to_keep] f.write(f"Generating range {base} - {base + output_scale}\n") f.write( "\n".join( [ f"{k}: {v / analyze_args.num_points_to_sample}" - for (k, v) in classes + for (k, v) in sorted_classes ] ) ) diff --git a/tune/protox/embedding/cli.py b/tune/protox/embedding/cli.py index 264afed9..0e8829b9 100644 --- a/tune/protox/embedding/cli.py +++ b/tune/protox/embedding/cli.py @@ -7,7 +7,7 @@ @click.group("embedding") @click.pass_obj -def embedding_group(dbgym_cfg: DBGymConfig): +def embedding_group(dbgym_cfg: DBGymConfig) -> None: dbgym_cfg.append_group("embedding") diff --git a/tune/protox/embedding/datagen.py b/tune/protox/embedding/datagen.py index 53defc2b..aa75d280 100644 --- a/tune/protox/embedding/datagen.py +++ b/tune/protox/embedding/datagen.py @@ -8,14 +8,16 @@ from itertools import chain, combinations from multiprocessing import Pool from pathlib import Path +from typing import Any, NewType, Optional, cast import click import numpy as np import pandas as pd +import psycopg import yaml from sklearn.preprocessing import quantile_transform -from dbms.postgres.cli import create_conn, start_postgres, stop_postgres +from dbms.postgres.cli import start_postgres, stop_postgres 
from misc.utils import ( BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER, @@ -37,8 +39,14 @@ ) from tune.protox.embedding.loss import COST_COLUMNS from tune.protox.env.space.primitive_space.index_space import IndexSpace -from tune.protox.env.types import QueryType +from tune.protox.env.types import ( + QuerySpec, + QueryType, + TableAttrAccessSetsMap, + TableAttrListMap, +) from tune.protox.env.workload import Workload +from util.pg import create_psycopg_conn from util.shell import subprocess_run # FUTURE(oltp) @@ -49,6 +57,11 @@ # pass +QueryBatches = NewType( + "QueryBatches", list[tuple[str, list[tuple[QueryType, str]], Any]] +) + + # click steup @click.command() @click.pass_obj @@ -74,6 +87,7 @@ ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @@ -86,8 +100,8 @@ # TODO(phw2): need to run pgtune before gathering data @click.option( "--pristine-dbdata-snapshot-path", - default=None, type=Path, + default=None, help=f"The path to the .tgz snapshot of the dbdata directory to build an embedding space over. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( @@ -98,60 +112,60 @@ ) @click.option( "--dbdata-parent-dpath", - default=None, type=Path, + default=None, help=f"The path to the parent directory of the dbdata which will be actively tuned. The default is {default_pristine_dbdata_snapshot_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, SCALE_FACTOR_PLACEHOLDER)}.", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--workload-path", - default=None, type=Path, + default=None, help=f"The path to the directory that specifies the workload (such as its queries and order of execution). The default is {default_workload_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). The default is a random value.", ) # dir gen args @click.option( "--leading-col-tbls", - default=None, type=str, + default=None, help='All tables included here will have indexes created s.t. each column is represented equally often as the "leading column" of the index.', ) # TODO(wz2): what if we sample tbl_sample_limit / len(cols) for tables in leading_col_tbls? this way, tbl_sample_limit will always represent the total # of indexes created on that table. currently the description of the param is a bit weird as you can see @click.option( "--default-sample-limit", - default=2048, type=int, + default=2048, help="The default sample limit of all tables, used unless override sample limit is specified. If the table is in --leading-col-tbls, sample limit is # of indexes to sample per column for that table table. If the table is in --leading-col-tbls, sample limit is the # of indexes to sample total for that table.", ) @click.option( "--override-sample-limits", - default=None, type=str, + default=None, help='Override the sample limit for specific tables. An example input would be "lineitem,32768,orders,4096".', ) # TODO(wz2): if I'm just outputting out.parquet instead of the full directory, do we even need file limit at all? 
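The QueryBatches alias defined above uses the standard typing.NewType pattern: a named, mypy-distinct alias for a structurally complicated type, with no runtime cost. A generic sketch (UserId is a hypothetical example, not from the repo):

from typing import NewType

UserId = NewType("UserId", int)


def lookup(user_id: UserId) -> str:
    return f"user-{user_id}"


if __name__ == "__main__":
    uid = UserId(42)    # explicit construction; at runtime uid is just an int
    print(lookup(uid))  # user-42
    # lookup(42) would be flagged by mypy: "int" is not "UserId"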
@click.option( "--file-limit", - default=1024, type=int, + default=1024, help="The max # of data points (one data point = one hypothetical index) per file", ) @click.option( "--max-concurrent", - default=None, type=int, + default=None, help="The max # of concurrent threads that will be creating hypothetical indexes. The default is `nproc`.", ) # TODO(wz2): when would we not want to generate costs? @@ -160,33 +174,33 @@ # file gen args @click.option("--table-shape", is_flag=True, help="TODO(wz2)") @click.option("--dual-class", is_flag=True, help="TODO(wz2)") -@click.option("--pad-min", default=None, type=int, help="TODO(wz2)") -@click.option("--rebias", default=0, type=float, help="TODO(wz2)") +@click.option("--pad-min", type=int, default=None, help="TODO(wz2)") +@click.option("--rebias", type=float, default=0, help="TODO(wz2)") def datagen( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - pgbin_path, - pristine_dbdata_snapshot_path, - intended_dbdata_hardware, - dbdata_parent_dpath, - benchmark_config_path, - workload_path, - seed, - leading_col_tbls, - default_sample_limit, - override_sample_limits, - file_limit, - max_concurrent, - no_generate_costs, - table_shape, - dual_class, - pad_min, - rebias, -): + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + pgbin_path: Optional[Path], + pristine_dbdata_snapshot_path: Optional[Path], + intended_dbdata_hardware: str, + dbdata_parent_dpath: Optional[Path], + benchmark_config_path: Optional[Path], + workload_path: Optional[Path], + seed: Optional[int], + leading_col_tbls: str, + default_sample_limit: int, + override_sample_limits: Optional[str], + file_limit: int, + max_concurrent: Optional[int], + no_generate_costs: bool, + table_shape: bool, + dual_class: bool, + pad_min: int, + rebias: float, +) -> None: """ Samples the effects of indexes on the workload as estimated by HypoPG. Outputs all this data as a .parquet file in the run_*/ dir. 
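The reordered options above also standardize on passing `type=` explicitly so the value click hands to the now-annotated callback matches its annotation, with omittable options annotated Optional and defaulted inside the function. A minimal sketch under assumed option names:

from pathlib import Path
from typing import Optional

import click


@click.command()
@click.option("--scale-factor", type=float, default=1.0)
@click.option("--workload-path", type=Path, default=None)
def demo(scale_factor: float, workload_path: Optional[Path]) -> None:
    # Resolve the Optional option to a concrete default, as the datagen/hpo
    # commands do for their path arguments.
    if workload_path is None:
        workload_path = Path(f"workloads/sf{scale_factor}")
    click.echo(f"scale_factor={scale_factor} workload_path={workload_path}")


if __name__ == "__main__":
    demo()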
@@ -201,26 +215,27 @@ def datagen( # TODO(phw2): figure out whether different scale factors use the same config # TODO(phw2): figure out what parts of the config should be taken out (like stuff about tables) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if workload_path == None: + if workload_path is None: workload_path = default_workload_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) - if pgbin_path == None: + if pgbin_path is None: pgbin_path = default_pgbin_path(dbgym_cfg.dbgym_workspace_path) - if pristine_dbdata_snapshot_path == None: + if pristine_dbdata_snapshot_path is None: pristine_dbdata_snapshot_path = default_pristine_dbdata_snapshot_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, scale_factor ) - if dbdata_parent_dpath == None: + if dbdata_parent_dpath is None: dbdata_parent_dpath = default_dbdata_parent_dpath( dbgym_cfg.dbgym_workspace_path ) - if max_concurrent == None: + if max_concurrent is None: max_concurrent = os.cpu_count() - if seed == None: - seed = random.randint(0, 1e8) + assert max_concurrent is not None + if seed is None: + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths workload_path = conv_inputpath_to_realabspath(dbgym_cfg, workload_path) @@ -246,22 +261,21 @@ def datagen( assert False # Process the "data structure" args - leading_col_tbls = [] if leading_col_tbls == None else leading_col_tbls.split(",") + leading_col_tbls_parsed: list[str] = ( + [] if leading_col_tbls is None else leading_col_tbls.split(",") + ) # I chose to only use the "," delimiter in override_sample_limits_str, so the dictionary is encoded as [key],[value],[key],[value] # I felt this was better than introducing a new delimiter which might conflict with the name of a table - if override_sample_limits == None: - override_sample_limits = dict() - else: - override_sample_limits_str = override_sample_limits - override_sample_limits = dict() - override_sample_limits_str_split = override_sample_limits_str.split(",") + override_sample_limits_parsed: dict[str, int] = dict() + if override_sample_limits is not None: + override_sample_limits_str_split = override_sample_limits.split(",") assert ( len(override_sample_limits_str_split) % 2 == 0 - ), f'override_sample_limits ("{override_sample_limits_str}") does not have an even number of values' + ), f'override_sample_limits ("{override_sample_limits}") does not have an even number of values' for i in range(0, len(override_sample_limits_str_split), 2): tbl = override_sample_limits_str_split[i] limit = int(override_sample_limits_str_split[i + 1]) - override_sample_limits[tbl] = limit + override_sample_limits_parsed[tbl] = limit # Group args together to reduce the # of parameters we pass into functions # I chose to group them into separate objects instead because it felt hacky to pass a giant args object into every function @@ -276,9 +290,9 @@ def datagen( dbdata_parent_dpath, ) dir_gen_args = EmbeddingDirGenArgs( - leading_col_tbls, + leading_col_tbls_parsed, default_sample_limit, - override_sample_limits, + override_sample_limits_parsed, file_limit, max_concurrent, no_generate_costs, @@ -331,14 +345,14 @@ class EmbeddingDatagenGenericArgs: def __init__( self, - benchmark_name, - workload_name, - scale_factor, - benchmark_config_path, - seed, - workload_path, - pristine_dbdata_snapshot_path, - dbdata_parent_dpath, + 
benchmark_name: str, + workload_name: str, + scale_factor: float, + benchmark_config_path: Path, + seed: int, + workload_path: Path, + pristine_dbdata_snapshot_path: Path, + dbdata_parent_dpath: Path, ): self.benchmark_name = benchmark_name self.workload_name = workload_name @@ -355,12 +369,12 @@ class EmbeddingDirGenArgs: def __init__( self, - leading_col_tbls, - default_sample_limit, - override_sample_limits, - file_limit, - max_concurrent, - no_generate_costs, + leading_col_tbls: list[str], + default_sample_limit: int, + override_sample_limits: dict[str, int], + file_limit: int, + max_concurrent: int, + no_generate_costs: bool, ): self.leading_col_tbls = leading_col_tbls self.default_sample_limit = default_sample_limit @@ -373,25 +387,31 @@ def __init__( class EmbeddingFileGenArgs: """Same comment as EmbeddingDatagenGenericArgs""" - def __init__(self, table_shape, dual_class, pad_min, rebias): + def __init__( + self, table_shape: bool, dual_class: bool, pad_min: int, rebias: float + ): self.table_shape = table_shape self.dual_class = dual_class self.pad_min = pad_min self.rebias = rebias -def get_traindata_dir(dbgym_cfg): +def get_traindata_dir(dbgym_cfg: DBGymConfig) -> Path: return dbgym_cfg.dbgym_this_run_path / "traindata_dir" -def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args, dir_gen_args): +def _gen_traindata_dir( + dbgym_cfg: DBGymConfig, + generic_args: EmbeddingDatagenGenericArgs, + dir_gen_args: EmbeddingDirGenArgs, +) -> None: with open_and_save(dbgym_cfg, generic_args.benchmark_config_path, "r") as f: benchmark_config = yaml.safe_load(f) - max_num_columns = benchmark_config["protox"]["max_num_columns"] - tables = benchmark_config["protox"]["tables"] - attributes = benchmark_config["protox"]["attributes"] - query_spec = benchmark_config["protox"]["query_spec"] + max_num_columns: int = benchmark_config["protox"]["max_num_columns"] + tables: list[str] = benchmark_config["protox"]["tables"] + attributes: TableAttrListMap = benchmark_config["protox"]["attributes"] + query_spec: QuerySpec = benchmark_config["protox"]["query_spec"] workload = Workload( dbgym_cfg, tables, attributes, query_spec, generic_args.workload_path, pid=None @@ -403,10 +423,10 @@ def _gen_traindata_dir(dbgym_cfg: DBGymConfig, generic_args, dir_gen_args): results = [] job_id = 0 for tbl in tables: - cols = ( + cols: list[Optional[str]] = ( [None] if tbl not in dir_gen_args.leading_col_tbls - else modified_attrs[tbl] + else cast(list[Optional[str]], modified_attrs[tbl]) ) for colidx, col in enumerate(cols): if col is None: @@ -455,7 +475,7 @@ def _combine_traindata_dir_into_parquet( dbgym_cfg: DBGymConfig, generic_args: EmbeddingDatagenGenericArgs, file_gen_args: EmbeddingFileGenArgs, -): +) -> None: tbl_dirs = {} with open_and_save(dbgym_cfg, generic_args.benchmark_config_path, "r") as f: benchmark_config = yaml.safe_load(f) @@ -561,14 +581,10 @@ def read(file: Path) -> pd.DataFrame: link_result(dbgym_cfg, traindata_path) -def _all_subsets(ss): - return chain(*map(lambda x: combinations(ss, x), range(0, len(ss) + 1))) +_INDEX_SERVER_COUNTS: dict[str, int] = {} -_INDEX_SERVER_COUNTS = {} - - -def _fetch_server_indexes(connection): +def _fetch_server_indexes(connection: psycopg.Connection[Any]) -> None: global _INDEX_SERVER_COUNTS query = """ SELECT t.relname as table_name, i.relname as index_name @@ -595,26 +611,28 @@ def _fetch_server_indexes(connection): # return models -def _write(data, output_dir, batch_num): +def _write(data: list[dict[str, Any]], output_dir: Path, batch_num: int) -> None: df = 
pd.DataFrame(data) - cols = [c for c in df if "col" in c and "str" not in c] + cols = [c for c in df.columns if "col" in c and "str" not in c] df[cols] = df[cols].astype(int) - df.to_parquet(f"{output_dir}/{batch_num}.parquet") + df.to_parquet(output_dir / f"{batch_num}.parquet") del df -def _augment_query_data(workload, data): +def _augment_query_data(workload: Workload, data: dict[str, float]) -> dict[str, float]: for qstem, value in workload.queries_mix.items(): if qstem in data: data[qstem] *= value return data -def _execute_explains(cursor, batches, models): - data = {} - ou_model_data = {} +def _execute_explains( + cursor: psycopg.Cursor[Any], batches: QueryBatches, models: Optional[dict[Any, Any]] +) -> dict[str, float]: + data: dict[str, float] = {} + ou_model_data: dict[str, list[Any]] = {} - def acquire_model_data(q, plan): + def acquire_model_data(q: str, plan: dict[str, Any]) -> None: nonlocal ou_model_data node_tag = plan["Node Type"] node_tag = node_tag.replace(" ", "") @@ -700,15 +718,23 @@ def acquire_model_data(q, plan): return data -def _extract_refs(generate_costs, target, cursor, workload, models): +def _extract_refs( + generate_costs: bool, + target: Optional[str], + cursor: psycopg.Cursor[Any], + workload: Workload, + models: Optional[dict[Any, Any]], +) -> tuple[dict[str, float], dict[str, float]]: ref_qs = {} table_ref_qs = {} if generate_costs: # Get reference costs. - batches = [ - (q, workload.queries[q], workload.query_aliases[q]) - for q in workload.queries.keys() - ] + batches = QueryBatches( + [ + (q, workload.queries[q], workload.query_aliases[q]) + for q in workload.queries.keys() + ] + ) ref_qs = _execute_explains(cursor, batches, models) ref_qs = _augment_query_data(workload, ref_qs) @@ -717,28 +743,30 @@ def _extract_refs(generate_costs, target, cursor, workload, models): table_ref_qs = ref_qs else: qs = workload.queries_for_table(target) - batches = [(q, workload.queries[q], workload.query_aliases[q]) for q in qs] + batches = QueryBatches( + [(q, workload.queries[q], workload.query_aliases[q]) for q in qs] + ) table_ref_qs = _execute_explains(cursor, batches, models) table_ref_qs = _augment_query_data(workload, table_ref_qs) return ref_qs, table_ref_qs def _produce_index_data( - dbgym_cfg, - tables, - attributes, - query_spec, - workload_path, - max_num_columns, - seed, - generate_costs, - sample_limit, - target, - leading_col, - leading_col_name, - p, - output, -): + dbgym_cfg: DBGymConfig, + tables: list[str], + attributes: TableAttrListMap, + query_spec: QuerySpec, + workload_path: Path, + max_num_columns: int, + seed: int, + generate_costs: bool, + sample_limit: int, + target: Optional[str], + leading_col: Optional[int], + leading_col_name: Optional[str], + p: int, + output: Path, +) -> None: models = None # FUTURE(oltp) @@ -746,9 +774,7 @@ def _produce_index_data( # models = load_ou_models(model_dir) # Construct workload. 
- workload = Workload( - dbgym_cfg, tables, attributes, query_spec, workload_path, pid=str(p) - ) + workload = Workload(dbgym_cfg, tables, attributes, query_spec, workload_path, pid=p) modified_attrs = workload.column_usages() np.random.seed(seed) @@ -763,7 +789,7 @@ def _produce_index_data( seed=seed, rel_metadata=copy.deepcopy(modified_attrs), attributes_overwrite=copy.deepcopy(modified_attrs), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), index_space_aux_type=False, index_space_aux_include=False, deterministic_policy=False, @@ -780,7 +806,7 @@ def _produce_index_data( # there are no indexes to generate. return - with create_conn() as connection: + with create_psycopg_conn() as connection: _fetch_server_indexes(connection) if generate_costs: try: @@ -792,8 +818,7 @@ def _produce_index_data( reference_qs, table_reference_qs = _extract_refs( generate_costs, target, cursor, workload, models ) - cached_refs = {} - accum_data = [] + accum_data: list[dict[str, Any]] = [] # Repeatedly... for i in range(sample_limit): @@ -810,7 +835,7 @@ def _produce_index_data( ) ia = idxs.to_action(act) - accum = { + accum: dict[str, Any] = { "table": ia.tbl_name, } if generate_costs: @@ -847,10 +872,12 @@ def _produce_index_data( else: qs_for_tbl = workload.queries_for_table(ia.tbl_name) - batches = [ - (q, workload.queries[q], workload.query_aliases[q]) - for q in qs_for_tbl - ] + batches = QueryBatches( + [ + (q, workload.queries[q], workload.query_aliases[q]) + for q in qs_for_tbl + ] + ) data = _execute_explains(cursor, batches, models) data = _augment_query_data(workload, data) if models is None: @@ -889,6 +916,7 @@ def _produce_index_data( for i in range(max_num_columns): accum[f"col{i}"] = 0 + assert ia.col_idxs is not None for i, col_idx in enumerate(ia.col_idxs): accum[f"col{i}"] = col_idx + 1 diff --git a/tune/protox/embedding/loss.py b/tune/protox/embedding/loss.py index b3f28843..5fbd85c6 100644 --- a/tune/protox/embedding/loss.py +++ b/tune/protox/embedding/loss.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from pytorch_metric_learning import losses # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import losses +from pytorch_metric_learning.utils import common_functions as c_f COST_COLUMNS = [ "quant_mult_cost_improvement", @@ -24,11 +24,11 @@ def get_loss(distance_fn: str) -> nn.Module: def get_bias_fn( config: dict[str, Any] ) -> Callable[ - [torch.Tensor, torch.Tensor], Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + [torch.Tensor, torch.Tensor], Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] ]: def bias_fn( data: torch.Tensor, labels: torch.Tensor - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: red_index = COST_COLUMNS.index(config["cost_reduction_type"]) distance_scale = config["distance_scale"] if distance_scale == "auto": @@ -74,7 +74,7 @@ def _distance_cost( targets: torch.Tensor, bias: Callable[ [torch.Tensor, torch.Tensor], - Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor], ], output_scale: float, ) -> Any: diff --git a/tune/protox/embedding/select.py b/tune/protox/embedding/select.py index 1e28dce0..613730b6 100644 --- a/tune/protox/embedding/select.py +++ b/tune/protox/embedding/select.py @@ -2,10 +2,12 @@ import os import shutil from pathlib import Path +from typing import Any, Optional import numpy as np import 
pandas as pd import tqdm +from pandas import DataFrame from misc.utils import DBGymConfig, default_embedder_dname, link_result from tune.protox.embedding.analyze import RANGES_FNAME, STATS_FNAME @@ -15,12 +17,6 @@ ) -class DotDict(dict): - __getattr__ = dict.get - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def select_best_embeddings( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, @@ -28,9 +24,7 @@ def select_best_embeddings( ) -> None: data = _load_data(dbgym_cfg, select_args) - if generic_args.traindata_path is not None and os.path.exists( - generic_args.traindata_path - ): + if generic_args.traindata_path is not None and generic_args.traindata_path.exists(): raw_data = pd.read_parquet(generic_args.traindata_path) data = _attach(data, raw_data, select_args.idx_limit) @@ -55,6 +49,8 @@ def select_best_embeddings( if select_args.flatten_idx == -1: for tup in df.itertuples(): + assert type(tup.path) is str + assert type(tup.root) is str shutil.copytree( tup.path, curated_dpath / tup.path, @@ -69,6 +65,8 @@ def select_best_embeddings( info_txt = open(curated_dpath / "info.txt", "w") for loop_i, tup in enumerate(df.itertuples()): + assert type(tup.path) is str + assert type(tup.root) is str epoch = int(str(tup.path).split("epoch")[-1]) model_dpath = curated_dpath / f"model{idx}" shutil.copytree(tup.path, model_dpath) @@ -97,8 +95,8 @@ def select_best_embeddings( info_txt.close() -def _load_data(dbgym_cfg, select_args): - data = [] +def _load_data(dbgym_cfg: DBGymConfig, select_args: EmbeddingSelectArgs) -> DataFrame: + stat_infos = [] stats = [s for s in dbgym_cfg.dbgym_this_run_path.rglob(STATS_FNAME)] print(f"stats={stats}") for stat in stats: @@ -126,7 +124,7 @@ def _load_data(dbgym_cfg, select_args): with open(stat.parent.parent.parent / "config", "r") as f: config = json.load(f) - def recurse_set(source, target): + def recurse_set(source: dict[Any, Any], target: dict[Any, Any]) -> None: for k, v in source.items(): if isinstance(v, dict): recurse_set(v, target) @@ -147,10 +145,9 @@ def recurse_set(source, target): info["ranges_file"] = str(Path(stat).parent / RANGES_FNAME) - data.append(info) + stat_infos.append(info) - print(f"data={data}") - data = pd.DataFrame(data) + data = DataFrame(stat_infos) data = data.loc[:, ~(data == data.iloc[0]).all()] if "output_scale" not in data: @@ -162,15 +159,17 @@ def recurse_set(source, target): return data -def _attach(data, raw_data, num_limit=0): +def _attach(data: DataFrame, raw_data: DataFrame, num_limit: int = 0) -> DataFrame: # As the group index goes up, the perf should go up (i.e., bounds should tighten) - filtered_data = {} + filtered_data: dict[tuple[float, float], DataFrame] = {} new_data = [] for tup in tqdm.tqdm(data.itertuples(), total=data.shape[0]): - tup = DotDict({k: getattr(tup, k) for k in data.columns}) - if raw_data is not None and Path(tup.ranges_file).exists(): + tup_dict = {k: getattr(tup, k) for k in data.columns} + if raw_data is not None and Path(tup_dict["ranges_file"]).exists(): - def compute_dist_score(current_dists, base, upper): + def compute_dist_score( + current_dists: dict[str, float], base: float, upper: float + ) -> float: nonlocal filtered_data key = (base, upper) if key not in filtered_data: @@ -202,15 +201,16 @@ def compute_dist_score(current_dists, base, upper): return error # don't use open_and_save() because we generated ranges in this run - with open(tup.ranges_file, "r") as f: - errors = [] - drange = (None, None) - current_dists = {} + with 
open(tup_dict["ranges_file"], "r") as f: + errors: list[float] = [] + drange: tuple[Optional[float], Optional[float]] = (None, None) + current_dists: dict[str, float] = {} for line in f: if "Generating range" in line: if len(current_dists) > 0: assert drange[0] is not None + assert drange[1] is not None errors.append( compute_dist_score(current_dists, drange[0], drange[1]) ) @@ -219,9 +219,12 @@ def compute_dist_score(current_dists, base, upper): break if drange[0] is None: - drange = (1.0 - tup.bias_separation, 1.01) + drange = (1.0 - tup_dict["bias_separation"], 1.01) else: - drange = (drange[0] - tup.bias_separation, drange[0]) + drange = ( + drange[0] - tup_dict["bias_separation"], + drange[0], + ) current_dists = {} else: @@ -232,19 +235,21 @@ def compute_dist_score(current_dists, base, upper): if len(current_dists) > 0: # Put the error in. errors.append( - compute_dist_score(current_dists, 0.0, tup.bias_separation) + compute_dist_score( + current_dists, 0.0, tup_dict["bias_separation"] + ) ) - tup["idx_class_errors"] = ",".join( + tup_dict["idx_class_errors"] = ",".join( [str(np.round(e, 2)) for e in errors] ) for i, e in enumerate(errors): - tup[f"idx_class_error{i}"] = np.round(e, 2) + tup_dict[f"idx_class_error{i}"] = np.round(e, 2) if len(errors) > 0: - tup["idx_class_mean_error"] = np.mean(errors) - tup["idx_class_total_error"] = np.sum(errors) - tup["idx_class_min_error"] = np.min(errors) - tup["idx_class_max_error"] = np.max(errors) - new_data.append(dict(tup)) - return pd.DataFrame(new_data) + tup_dict["idx_class_mean_error"] = np.mean(errors) + tup_dict["idx_class_total_error"] = np.sum(errors) + tup_dict["idx_class_min_error"] = np.min(errors) + tup_dict["idx_class_max_error"] = np.max(errors) + new_data.append(tup_dict) + return DataFrame(new_data) diff --git a/tune/protox/embedding/train.py b/tune/protox/embedding/train.py index 69eba251..0f8116e8 100644 --- a/tune/protox/embedding/train.py +++ b/tune/protox/embedding/train.py @@ -1,6 +1,7 @@ import logging import random from pathlib import Path +from typing import Optional import click import numpy as np @@ -11,6 +12,7 @@ DEFAULT_HPO_SPACE_PATH, WORKLOAD_NAME_PLACEHOLDER, WORKSPACE_PATH_PLACEHOLDER, + DBGymConfig, conv_inputpath_to_realabspath, default_benchmark_config_path, default_traindata_path, @@ -59,75 +61,82 @@ ) @click.option( "--scale-factor", + type=float, default=1.0, help=f"The scale factor used when generating the data of the benchmark.", ) @click.option( "--benchmark-config-path", - default=None, type=Path, + default=None, help=f"The path to the .yaml config file for the benchmark. The default is {default_benchmark_config_path(BENCHMARK_NAME_PLACEHOLDER)}.", ) @click.option( "--traindata-path", - default=None, type=Path, + default=None, help=f"The path to the .parquet file containing the training data to use to train the embedding models. The default is {default_traindata_path(WORKSPACE_PATH_PLACEHOLDER, BENCHMARK_NAME_PLACEHOLDER, WORKLOAD_NAME_PLACEHOLDER)}.", ) @click.option( "--seed", - default=None, type=int, + default=None, help="The seed used for all sources of randomness (random, np, torch, etc.). 
The default is a random value.", ) # train args @click.option( "--hpo-space-path", + type=Path, default=DEFAULT_HPO_SPACE_PATH, - type=str, help="The path to the .json file defining the search space for hyperparameter optimization (HPO).", ) @click.option( "--train-max-concurrent", - default=1, type=int, + default=1, help="The max # of concurrent embedding models to train during hyperparameter optimization. This is usually set lower than `nproc` to reduce memory pressure.", ) @click.option("--iterations-per-epoch", default=1000, help=f"TODO(wz2)") @click.option( "--num-samples", + type=int, default=40, help=f"The # of times to specific hyperparameter configs to sample from the hyperparameter search space and train embedding models with.", ) -@click.option("--train-size", default=0.99, help=f"TODO(wz2)") +@click.option("--train-size", type=float, default=0.99, help=f"TODO(wz2)") # analyze args @click.option( - "--start-epoch", default=0, help="The epoch to start analyzing models at." + "--start-epoch", type=int, default=0, help="The epoch to start analyzing models at." ) @click.option( "--batch-size", + type=int, default=8192, help=f"The size of batches to use to build {STATS_FNAME}.", ) @click.option( "--num-batches", + type=int, default=100, help=f'The number of batches to use to build {STATS_FNAME}. Setting it to -1 indicates "use all batches".', ) @click.option( "--max-segments", + type=int, default=15, help=f"The maximum # of segments in the latent space when creating {RANGES_FNAME}.", ) @click.option( "--num-points-to-sample", + type=int, default=8192, help=f"The number of points to sample when creating {RANGES_FNAME}.", ) @click.option( "--num-classes-to-keep", + type=int, default=5, help=f"The number of classes to keep for each segment when creating {RANGES_FNAME}.", ) @@ -158,41 +167,41 @@ help="The number of indexes whose errors to compute during _attach().", ) @click.option( - "--num-curate", default=1, help="The number of models to curate" + "--num-curate", type=int, default=1, help="The number of models to curate" ) # TODO(wz2): why would we want to curate more than one? @click.option( "--allow-all", is_flag=True, help="Whether to curate within or across parts." ) -@click.option("--flatten-idx", default=0, help="TODO(wz2)") +@click.option("--flatten-idx", type=int, default=0, help="TODO(wz2)") def train( - dbgym_cfg, - benchmark_name, - seed_start, - seed_end, - query_subset, - scale_factor, - benchmark_config_path, - traindata_path, - seed, - hpo_space_path, - train_max_concurrent, - iterations_per_epoch, - num_samples, - train_size, - start_epoch, - batch_size, - num_batches, - max_segments, - num_points_to_sample, - num_classes_to_keep, - recon, - latent_dim, - bias_sep, - idx_limit, - num_curate, - allow_all, - flatten_idx, -): + dbgym_cfg: DBGymConfig, + benchmark_name: str, + seed_start: int, + seed_end: int, + query_subset: str, + scale_factor: float, + benchmark_config_path: Optional[Path], + traindata_path: Optional[Path], + seed: Optional[int], + hpo_space_path: Path, + train_max_concurrent: int, + iterations_per_epoch: int, + num_samples: int, + train_size: float, + start_epoch: int, + batch_size: int, + num_batches: int, + max_segments: int, + num_points_to_sample: int, + num_classes_to_keep: int, + recon: float, + latent_dim: int, + bias_sep: float, + idx_limit: int, + num_curate: int, + allow_all: bool, + flatten_idx: int, +) -> None: """ Trains embeddings with num_samples samples of the hyperparameter space. 
Analyzes the accuracy of all epochs of all hyperparameter space samples. @@ -200,16 +209,16 @@ def train( """ # set args to defaults programmatically (do this before doing anything else in the function) workload_name = workload_name_fn(scale_factor, seed_start, seed_end, query_subset) - if traindata_path == None: + if traindata_path is None: traindata_path = default_traindata_path( dbgym_cfg.dbgym_workspace_path, benchmark_name, workload_name ) # TODO(phw2): figure out whether different scale factors use the same config # TODO(phw2): figure out what parts of the config should be taken out (like stuff about tables) - if benchmark_config_path == None: + if benchmark_config_path is None: benchmark_config_path = default_benchmark_config_path(benchmark_name) - if seed == None: - seed = random.randint(0, 1e8) + if seed is None: + seed = random.randint(0, int(1e8)) # Convert all input paths to absolute paths benchmark_config_path = conv_inputpath_to_realabspath( diff --git a/tune/protox/embedding/train_all.py b/tune/protox/embedding/train_all.py index e8358387..9f0aed3a 100644 --- a/tune/protox/embedding/train_all.py +++ b/tune/protox/embedding/train_all.py @@ -24,7 +24,8 @@ from ray.tune.schedulers import FIFOScheduler from ray.tune.search import ConcurrencyLimiter from ray.tune.search.hyperopt import HyperOptSearch -from sklearn.model_selection import train_test_split # type: ignore +from sklearn.model_selection import train_test_split +from torch.optim import Adam # type: ignore[attr-defined] from torch.utils.data import TensorDataset from typing_extensions import ParamSpec @@ -46,7 +47,7 @@ from tune.protox.env.workload import Workload -def fetch_vae_parameters_from_workload(w: Workload, ntables: int) -> Tuple[int, int]: +def fetch_vae_parameters_from_workload(w: Workload, ntables: int) -> tuple[int, int]: max_indexable = w.max_indexable() max_cat_features = max( ntables, max_indexable + 1 @@ -59,7 +60,7 @@ def fetch_index_parameters( dbgym_cfg: DBGymConfig, data: dict[str, Any], workload_path: Path, -) -> Tuple[int, int, TableAttrListMap, dict[TableColTuple, int]]: +) -> tuple[int, int, TableAttrListMap, dict[TableColTuple, int]]: tables = data["tables"] attributes = data["attributes"] query_spec = data["query_spec"] @@ -90,11 +91,11 @@ def fetch_index_parameters( def load_input_data( dbgym_cfg: DBGymConfig, traindata_path: Path, - train_size: int, + train_size: float, max_attrs: int, require_cost: bool, seed: int, -) -> Tuple[TensorDataset, Any, Any, Optional[TensorDataset], int]: +) -> tuple[TensorDataset, Any, Any, Optional[TensorDataset], int]: # Load the input data. 
columns = [] columns += ["tbl_index", "idx_class"] @@ -115,7 +116,7 @@ def load_input_data( gc.collect() gc.collect() - if train_size == 1: + if train_size == 1.0: train_dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y)) del x gc.collect() @@ -126,7 +127,7 @@ def load_input_data( train_x, val_x, train_y, val_y = train_test_split( x, y, - test_size=1 - train_size, + test_size=1.0 - train_size, train_size=train_size, random_state=seed, shuffle=True, @@ -161,7 +162,7 @@ def create_vae_model( "sigmoid": nn.Sigmoid, }[config["mean_output_act"]] - torch.set_float32_matmul_precision("high") # type: ignore + torch.set_float32_matmul_precision("high") model = VAE( max_categorical=max_cat_features, input_dim=cat_input, @@ -182,7 +183,7 @@ def train_all_embeddings( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, train_all_args: EmbeddingTrainAllArgs, -): +) -> None: """ Trains all num_samples models using different samples of the hyperparameter space, writing their results to different embedding_*/ folders in the run_*/ folder @@ -226,7 +227,9 @@ def train_all_embeddings( sync_config=SyncConfig(), verbose=2, log_to_file=True, - storage_path=dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True), + storage_path=str( + dbgym_cfg.cur_task_runs_path("embedding_ray_results", mkdir=True) + ), ) resources = {"cpu": 1} @@ -270,7 +273,7 @@ def _hpo_train( dbgym_cfg: DBGymConfig, generic_args: EmbeddingTrainGenericArgs, train_all_args: EmbeddingTrainAllArgs, -): +) -> None: sys.path.append(os.fspath(dbgym_cfg.dbgym_repo_path)) # Explicitly set the number of torch threads. @@ -352,11 +355,11 @@ def _build_trainer( traindata_path: Path, trial_dpath: Path, benchmark_config_path: Path, - train_size: int, + train_size: float, workload_path: Path, - dataloader_num_workers=0, - disable_tqdm=False, -): + dataloader_num_workers: int = 0, + disable_tqdm: bool = False, +) -> tuple[VAETrainer, Callable[..., Optional[dict[str, Any]]]]: max_cat_features = 0 max_attrs = 0 @@ -401,7 +404,7 @@ def _build_trainer( models = {"trunk": trunk, "embedder": model} optimizers = { - "embedder_optimizer": torch.optim.Adam( + "embedder_optimizer": Adam( model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"] ), } @@ -442,7 +445,7 @@ def _build_trainer( def clip_grad() -> None: if config["grad_clip_amount"] is not None: - torch.nn.utils.clip_grad_norm_( # type: ignore + torch.nn.utils.clip_grad_norm_( model.parameters(), config["grad_clip_amount"] ) @@ -513,9 +516,9 @@ def epoch_end(*args: P.args, **kwargs: P.kwargs) -> Optional[dict[str, Any]]: trainer.switch_eval() pbar = None if suppress else tqdm.tqdm(total=len(val_dl)) - for i, curr_batch in enumerate(val_dl): # type: ignore + for i, curr_batch in enumerate(val_dl): # Get the losses. 
- trainer.calculate_loss(curr_batch) # type: ignore + trainer.calculate_loss(curr_batch) if isinstance(trainer.losses["metric_loss"], torch.Tensor): total_metric_loss.append(trainer.losses["metric_loss"].item()) else: diff --git a/tune/protox/embedding/train_args.py b/tune/protox/embedding/train_args.py index f4a955f9..c86a6392 100644 --- a/tune/protox/embedding/train_args.py +++ b/tune/protox/embedding/train_args.py @@ -1,16 +1,19 @@ +from pathlib import Path + + class EmbeddingTrainGenericArgs: """Same comment as EmbeddingDatagenGenericArgs""" def __init__( self, - benchmark_name, - workload_name, - scale_factor, - benchmark_config_path, - traindata_path, - seed, - workload_path, - ): + benchmark_name: str, + workload_name: str, + scale_factor: float, + benchmark_config_path: Path, + traindata_path: Path, + seed: int, + workload_path: Path, + ) -> None: self.benchmark_name = benchmark_name self.workload_name = workload_name self.scale_factor = scale_factor @@ -25,12 +28,12 @@ class EmbeddingTrainAllArgs: def __init__( self, - hpo_space_path, - train_max_concurrent, - iterations_per_epoch, - num_samples, - train_size, - ): + hpo_space_path: Path, + train_max_concurrent: int, + iterations_per_epoch: int, + num_samples: int, + train_size: float, + ) -> None: self.hpo_space_path = hpo_space_path self.train_max_concurrent = train_max_concurrent self.iterations_per_epoch = iterations_per_epoch @@ -43,13 +46,13 @@ class EmbeddingAnalyzeArgs: def __init__( self, - start_epoch, - batch_size, - num_batches, - max_segments, - num_points_to_sample, - num_classes_to_keep, - ): + start_epoch: int, + batch_size: int, + num_batches: int, + max_segments: int, + num_points_to_sample: int, + num_classes_to_keep: int, + ) -> None: self.start_epoch = start_epoch self.batch_size = batch_size self.num_batches = num_batches @@ -62,8 +65,15 @@ class EmbeddingSelectArgs: """Same comment as EmbeddingDatagenGenericArgs""" def __init__( - self, recon, latent_dim, bias_sep, idx_limit, num_curate, allow_all, flatten_idx - ): + self, + recon: float, + latent_dim: int, + bias_sep: float, + idx_limit: int, + num_curate: int, + allow_all: bool, + flatten_idx: int, + ) -> None: self.recon = recon self.latent_dim = latent_dim self.bias_sep = bias_sep diff --git a/tune/protox/embedding/trainer.py b/tune/protox/embedding/trainer.py index 19648aa5..6b85fcba 100644 --- a/tune/protox/embedding/trainer.py +++ b/tune/protox/embedding/trainer.py @@ -6,8 +6,8 @@ import torch import tqdm from numpy.typing import NDArray -from pytorch_metric_learning import trainers # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import trainers +from pytorch_metric_learning.utils import common_functions as c_f from torch.utils.data import Sampler @@ -26,7 +26,7 @@ def __init__( self.elem_per_class = 0 assert self.batch_size > 0 - def compute(self) -> Tuple[dict[int, Tuple[int, NDArray[Any]]], int, int]: + def compute(self) -> tuple[dict[int, tuple[int, NDArray[Any]]], int, int]: r = {} for c in range(self.max_class): lc = np.argwhere(self.labels == c) @@ -80,7 +80,7 @@ def __init__( bias_fn: Optional[ Callable[ [torch.Tensor, torch.Tensor], - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], ] ], *args: Any, @@ -90,7 +90,7 @@ def __init__( self.failed = False self.fail_msg: Optional[str] = None self.fail_data: Optional[ - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor] ] = None self.disable_tqdm = disable_tqdm self.bias_fn = bias_fn @@ -117,7 +117,7 @@ def maybe_get_vae_loss( ) return 0 - def calculate_loss(self, curr_batch: Tuple[torch.Tensor, torch.Tensor]) -> None: + def calculate_loss(self, curr_batch: tuple[torch.Tensor, torch.Tensor]) -> None: data, labels = curr_batch if labels.shape[1] == 1: # Flatten labels if it's a class. @@ -170,7 +170,7 @@ def train(self, start_epoch: int = 1, num_epochs: int = 1) -> None: if not self.disable_tqdm: pbar = tqdm.tqdm(range(self.iterations_per_epoch)) else: - pbar = range(self.iterations_per_epoch) # type: ignore + pbar = range(self.iterations_per_epoch) for self.iteration in pbar: self.forward_and_backward() @@ -232,7 +232,7 @@ def train(self, start_epoch: int = 1, num_epochs: int = 1) -> None: def compute_embeddings(self, base_output: Any) -> None: assert False - def get_batch(self) -> Tuple[torch.Tensor, torch.Tensor]: + def get_batch(self) -> tuple[torch.Tensor, torch.Tensor]: self.dataloader_iter, curr_batch = c_f.try_next_on_generator(self.dataloader_iter, self.dataloader) # type: ignore data, labels = self.data_and_label_getter(curr_batch) return data, labels diff --git a/tune/protox/embedding/utils.py b/tune/protox/embedding/utils.py index a631c24f..0e369158 100644 --- a/tune/protox/embedding/utils.py +++ b/tune/protox/embedding/utils.py @@ -1,6 +1,6 @@ from typing import Any -from hyperopt import hp # type: ignore +from hyperopt import hp def f_unpack_dict(dct: dict[str, Any]) -> dict[str, Any]: diff --git a/tune/protox/embedding/vae.py b/tune/protox/embedding/vae.py index c1a657ac..9040d49c 100644 --- a/tune/protox/embedding/vae.py +++ b/tune/protox/embedding/vae.py @@ -3,16 +3,16 @@ import torch import torch.nn as nn import torch.nn.functional as F -from pytorch_metric_learning import losses, reducers # type: ignore -from pytorch_metric_learning.utils import common_functions as c_f # type: ignore +from pytorch_metric_learning import losses, reducers +from pytorch_metric_learning.utils import common_functions as c_f def gen_vae_collate( max_categorical: int, infer: bool = False -) -> Callable[[list[Any]], Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: +) -> Callable[[list[Any]], Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]]: def vae_collate( batch: list[Any], - ) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if infer: x = torch.as_tensor(batch).type(torch.int64) else: @@ -120,7 +120,7 @@ def forward( embeddings: torch.Tensor, labels: Any = None, indices_tuple: Any = None, - ref_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ref_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None, ref_labels: Any = None, is_eval: bool = False, ) -> Any: @@ -149,7 +149,7 @@ def compute_loss( preds: torch.Tensor, unused0: Any, unused1: Any, - tdata: Optional[Tuple[torch.Tensor, torch.Tensor]], + tdata: Optional[tuple[torch.Tensor, torch.Tensor]], *args: Any, **kwargs: Any ) -> Any: @@ -308,7 +308,7 @@ def init(layer: nn.Module) -> None: else: init_fn(layer.weight) - modules = [encoder, decoder] + modules: list[nn.Module] = [encoder, decoder] for module in modules: if module is not None: module.apply(init) @@ -353,16 +353,16 @@ def get_collate(self) -> Callable[[torch.Tensor], torch.Tensor]: def forward( self, x: torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, bool], 
Tuple[torch.Tensor, bool]]: + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, + ) -> Union[tuple[torch.Tensor, torch.Tensor, bool], tuple[torch.Tensor, bool]]: return self._compute(x, bias=bias, require_full=True) def latents( self, x: torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, require_full: bool = False, - ) -> Tuple[torch.Tensor, bool]: + ) -> tuple[torch.Tensor, bool]: rets = self._compute(x, bias=bias, require_full=False) assert len(rets) == 2 return rets[0], rets[1] @@ -370,9 +370,9 @@ def latents( def _compute( self, x: torch.Tensor, - bias: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + bias: Optional[Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]] = None, require_full: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, bool], Tuple[torch.Tensor, bool]]: + ) -> Union[tuple[torch.Tensor, torch.Tensor, bool], tuple[torch.Tensor, bool]]: latents: torch.Tensor = self.encoder(x) latents = latents * self.output_scale diff --git a/tune/protox/env/logger.py b/tune/protox/env/logger.py index 627e6a3c..6cf2a4fe 100644 --- a/tune/protox/env/logger.py +++ b/tune/protox/env/logger.py @@ -9,7 +9,7 @@ import numpy as np from plumbum import local -from torch.utils.tensorboard import SummaryWriter # type: ignore +from torch.utils.tensorboard.writer import SummaryWriter from typing_extensions import ParamSpec from misc.utils import DBGymConfig @@ -25,16 +25,17 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T: ret = f(*args, **kwargs) # TODO(wz2): This is a hack to get a logger instance. - assert hasattr(args[0], "logger"), print(args[0], type(args[0])) + first_arg = args[0] # Ignore the indexing type error + assert hasattr(first_arg, "logger"), print(first_arg, type(first_arg)) - if args[0].logger is None: + if first_arg.logger is None: # If there is no logger, just return. 
return ret - assert isinstance(args[0].logger, Logger) - if args[0].logger is not None: - cls_name = type(args[0]).__name__ - args[0].logger.record(f"{cls_name}_{key}", time.time() - start) + assert isinstance(first_arg.logger, Logger) + if first_arg.logger is not None: + cls_name = type(first_arg).__name__ + first_arg.logger.record(f"{cls_name}_{key}", time.time() - start) return ret return wrapped_f @@ -81,7 +82,7 @@ def __init__( self.writer: Union[SummaryWriter, None] = None if self.trace: self.tensorboard_dpath.mkdir(parents=True, exist_ok=True) - self.writer = SummaryWriter(self.tensorboard_dpath) # type: ignore + self.writer = SummaryWriter(self.tensorboard_dpath) # type: ignore[no-untyped-call] self.iteration = 1 self.iteration_data: dict[str, Any] = {} @@ -144,14 +145,14 @@ def advance(self) -> None: for key, value in self.iteration_data.items(): if isinstance(value, str): # str is considered a np.ScalarType - self.writer.add_text(key, value, self.iteration) # type: ignore + self.writer.add_text(key, value, self.iteration) # type: ignore[no-untyped-call] else: - self.writer.add_scalar(key, value, self.iteration) # type: ignore + self.writer.add_scalar(key, value, self.iteration) # type: ignore[no-untyped-call] del self.iteration_data self.iteration_data = {} self.iteration += 1 - self.writer.flush() # type: ignore + self.writer.flush() # type: ignore[no-untyped-call] def record(self, key: str, value: Any) -> None: stack = inspect.stack(context=2) @@ -168,4 +169,4 @@ def flush(self) -> None: if self.trace: assert self.writer self.advance() - self.writer.flush() # type: ignore + self.writer.flush() # type: ignore[no-untyped-call] diff --git a/tune/protox/env/lsc/__init__.py b/tune/protox/env/lsc/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/lsc/lsc_wrapper.py b/tune/protox/env/lsc/lsc_wrapper.py index 5f5d464c..5d4ff5a5 100644 --- a/tune/protox/env/lsc/lsc_wrapper.py +++ b/tune/protox/env/lsc/lsc_wrapper.py @@ -14,7 +14,7 @@ def __init__(self, lsc: LSC, env: gym.Env[Any, Any], logger: Optional[Logger]): self.lsc = lsc self.logger = logger - def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, dict[str, Any]]: state, info = self.env.reset(*args, **kwargs) self.lsc.reset() @@ -27,7 +27,7 @@ def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: def step( self, *args: Any, **kwargs: Any - ) -> Tuple[Any, float, bool, bool, dict[str, Any]]: + ) -> tuple[Any, float, bool, bool, dict[str, Any]]: state, reward, term, trunc, info = self.env.step(*args, **kwargs) # Remember the LSC when we considered this action. 
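Aside (not part of the patch): the logger.py hunk above binds `args[0]` to a local (`first_arg`) before probing for a `logger` attribute, which is the usual way to keep mypy happy when a decorator inspects the receiver of the method it wraps. Below is a minimal, self-contained sketch of that timing-decorator pattern; the `SupportsRecord` protocol and the guard for argument-less callables are illustrative stand-ins, not repo code, and the only detail taken from the patch is the `f"{ClassName}_{key}"` record key.

import time
from typing import Callable, Optional, Protocol, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


class SupportsRecord(Protocol):
    # Structural stand-in for the repo's Logger: anything with record(key, value).
    def record(self, key: str, value: float) -> None: ...


def time_record(key: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
    def wrap(f: Callable[P, T]) -> Callable[P, T]:
        def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T:
            start = time.time()
            ret = f(*args, **kwargs)
            if not args:
                # Decorating a free function with no positional args: nothing to record.
                return ret
            # Bind args[0] once so the attribute lookup and narrowing apply to a stable local.
            first_arg = args[0]
            recorder: Optional[SupportsRecord] = getattr(first_arg, "logger", None)
            if recorder is not None:
                recorder.record(f"{type(first_arg).__name__}_{key}", time.time() - start)
            return ret

        return wrapped_f

    return wrap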
diff --git a/tune/protox/env/mqo/__init__.py b/tune/protox/env/mqo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/mqo/mqo_wrapper.py b/tune/protox/env/mqo/mqo_wrapper.py index 84baa36f..965f7952 100644 --- a/tune/protox/env/mqo/mqo_wrapper.py +++ b/tune/protox/env/mqo/mqo_wrapper.py @@ -163,7 +163,7 @@ def __init__( self.logger = logger def _update_best_observed( - self, query_metric_data: dict[str, BestQueryRun], force_overwrite=False + self, query_metric_data: dict[str, BestQueryRun], force_overwrite: bool = False ) -> None: if query_metric_data is not None: for qid, best_run in query_metric_data.items(): @@ -176,6 +176,7 @@ def _update_best_observed( None, ) if self.logger: + assert best_run.runtime is not None self.logger.get_logger(__name__).debug( f"[best_observe] {qid}: {best_run.runtime/1e6} (force: {force_overwrite})" ) @@ -198,7 +199,7 @@ def _update_best_observed( def step( # type: ignore self, action: HolonAction, - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: # Step based on the "global" action. assert isinstance(self.unwrapped, PostgresEnv) success, info = self.unwrapped.step_before_execution(action) @@ -307,6 +308,7 @@ def transmute( ) # Execute. + assert self.logger is not None self.logger.get_logger(__name__).info("MQOWrapper called step_execute()") success, info = self.unwrapped.step_execute(success, runs, info) if info["query_metric_data"]: @@ -319,6 +321,7 @@ def transmute( with torch.no_grad(): # Pass the mutilated action back through. assert isinstance(self.action_space, HolonSpace) + assert info["actions_info"] is not None info["actions_info"][ "best_observed_holon_action" ] = best_observed_holon_action @@ -326,9 +329,12 @@ def transmute( [best_observed_holon_action] ) - return self.unwrapped.step_post_execute(success, action, info) + obs, reward, term, trunc, info = self.step_post_execute(success, action, info) + # Since we called step_post_execute() with soft=False, we expect infos[1] (reward) to not be None. + assert reward is not None + return (obs, reward, term, trunc, info) - def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, EnvInfoDict]: # type: ignore + def reset(self, *args: Any, **kwargs: Any) -> tuple[Any, EnvInfoDict]: # type: ignore assert isinstance(self.unwrapped, PostgresEnv) # First have to shift to the new state. state, info = self.unwrapped.reset(*args, **kwargs) @@ -412,6 +418,7 @@ def reset(self, *args: Any, **kwargs: Any) -> Tuple[Any, EnvInfoDict]: # type: # Update the reward baseline. 
if self.unwrapped.reward_utility: + assert self.unwrapped.baseline_metric self.unwrapped.reward_utility.set_relative_baseline( self.unwrapped.baseline_metric, prev_result=metric, diff --git a/tune/protox/env/pg_env.py b/tune/protox/env/pg_env.py index 3e267d53..de298170 100644 --- a/tune/protox/env/pg_env.py +++ b/tune/protox/env/pg_env.py @@ -14,6 +14,7 @@ from tune.protox.env.space.state.space import StateSpace from tune.protox.env.space.utils import fetch_server_indexes, fetch_server_knobs from tune.protox.env.types import ( + ActionsInfo, EnvInfoDict, HolonAction, HolonStateContainer, @@ -78,7 +79,7 @@ def _restore_last_snapshot(self) -> None: @time_record("reset") def reset( # type: ignore self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None - ) -> Tuple[Any, EnvInfoDict]: + ) -> tuple[Any, EnvInfoDict]: reset_start = time.time() if self.logger: self.logger.get_logger(__name__).info( @@ -212,7 +213,7 @@ def reset( # type: ignore return self.current_state, info @time_record("step_before_execution") - def step_before_execution(self, action: HolonAction) -> Tuple[bool, EnvInfoDict]: + def step_before_execution(self, action: HolonAction) -> tuple[bool, EnvInfoDict]: # Log the action in debug mode. if self.logger: self.logger.get_logger(__name__).debug( @@ -248,13 +249,14 @@ def step_before_execution(self, action: HolonAction) -> Tuple[bool, EnvInfoDict] def step_execute( self, setup_success: bool, - all_holon_action_variations: list[Tuple[str, HolonAction]], + all_holon_action_variations: list[tuple[str, HolonAction]], info: EnvInfoDict, - ) -> Tuple[bool, EnvInfoDict]: + ) -> tuple[bool, EnvInfoDict]: if setup_success: assert isinstance(self.observation_space, StateSpace) assert isinstance(self.action_space, HolonSpace) # Evaluate the benchmark. + assert self.logger is not None self.logger.get_logger(__name__).info( f"\n\nfetch_server_knobs(): {fetch_server_knobs(self.pg_conn.conn(), self.action_space.get_knob_space().tables, self.action_space.get_knob_space().knobs, self.workload.queries)}\n\n" ) @@ -302,9 +304,12 @@ def step_execute( "query_metric_data": query_metric_data, "reward": reward, "results_dpath": results_dpath, - "actions_info": { - "all_holon_action_variations": all_holon_action_variations, - }, + "actions_info": ActionsInfo( + { + "all_holon_action_variations": all_holon_action_variations, + "best_observed_holon_action": None, + } + ), } ) ) @@ -316,8 +321,17 @@ def step_post_execute( success: bool, action: HolonAction, info: EnvInfoDict, + # If "soft" is true, it means we're calling step_post_execute() from reset(). If it's false, it means we're calling step_post_execute() from step(). soft: bool = False, - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, Optional[float], bool, bool, EnvInfoDict]: + # If we're calling step_post_execute() from reset(), we expect info["metric"] and info["reward"] to be None. + if not soft: + assert info["reward"] is not None + assert info["metric"] is not None + else: + assert info["reward"] is None + assert info["metric"] is None + if self.workload.oltp_workload and self.horizon > 1: # If horizon = 1, then we're going to reset anyways. So easier to just untar the original archive. # Restore the crisp and clean snapshot. @@ -328,6 +342,7 @@ def step_post_execute( if not soft: if not self.workload.oltp_workload: # Update the workload metric timeout if we've succeeded. + assert info["metric"] is not None self.workload.set_workload_timeout(info["metric"]) # Get the current view of the state container. 
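Aside (not part of the patch): with the hunk above, `step_post_execute()` may legitimately return `None` for the reward and metric when it is invoked from `reset()` (`soft=True`), so callers on the `step()` path assert the reward back down to a plain `float`. A hedged sketch of that Optional-narrowing contract, using hypothetical function names:

from typing import Optional


def post_execute(success: bool, soft: bool = False) -> tuple[bool, Optional[float]]:
    # Hypothetical stand-in for step_post_execute(): reset()-style calls
    # (soft=True) carry no reward yet, so the reward slot is Optional.
    reward: Optional[float] = None if soft else (1.0 if success else 0.0)
    return success, reward


def step(success: bool) -> tuple[bool, float]:
    ok, reward = post_execute(success, soft=False)
    # soft=False, so the reward must be populated; the assert both documents the
    # contract and lets mypy narrow Optional[float] down to float.
    assert reward is not None
    return ok, reward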
@@ -361,11 +376,14 @@ def step_post_execute( def step( # type: ignore self, action: HolonAction - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: assert self.tuning_mode != TuningMode.REPLAY success, info = self.step_before_execution(action) success, info = self.step_execute(success, [("PerQuery", action)], info) - return self.step_post_execute(success, action, info) + obs, reward, term, trunc, info = self.step_post_execute(success, action, info) + # Since we called step_post_execute() with soft=False, we expect infos[1] (reward) to not be None. + assert reward is not None + return (obs, reward, term, trunc, info) @time_record("shift_state") def shift_state( diff --git a/tune/protox/env/space/__init__.py b/tune/protox/env/space/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/space/holon_space.py b/tune/protox/env/space/holon_space.py index ee51a5de..870ecedd 100644 --- a/tune/protox/env/space/holon_space.py +++ b/tune/protox/env/space/holon_space.py @@ -39,7 +39,7 @@ def _latent_assert_check( carprod_neighbors: list[HolonAction], carprod_embeds: torch.Tensor, first_drift: int, - ): + ) -> None: zero = self.to_latent([carprod_neighbors[0]])[0] last = self.to_latent([carprod_neighbors[-1]])[0] first_d = self.to_latent([carprod_neighbors[first_drift]])[0] @@ -81,9 +81,9 @@ def __init__( self.space_dims: Optional[list[int]] = None self.logger = logger - def get_spaces(self) -> list[Tuple[str, HolonSubSpace]]: + def get_spaces(self) -> list[tuple[str, HolonSubSpace]]: r = cast( - list[Tuple[str, HolonSubSpace]], + list[tuple[str, HolonSubSpace]], [(s.name, s) for s in self.spaces if hasattr(s, "name")], ) assert len(r) == 3 @@ -98,7 +98,7 @@ def null_action(self, sc: HolonStateContainer) -> HolonAction: def split_action( self, action: HolonAction - ) -> list[Tuple[HolonSubSpace, HolonSubAction]]: + ) -> list[tuple[HolonSubSpace, HolonSubAction]]: return [ (cast(LatentKnobSpace, self.spaces[0]), action[0]), (cast(LatentIndexSpace, self.spaces[1]), action[1]), @@ -230,21 +230,21 @@ def neighborhood( self, raw_action: ProtoAction, neighbor_parameters: NeighborParameters = DEFAULT_NEIGHBOR_PARAMETERS, - ) -> Tuple[list[HolonAction], ProtoAction, torch.Tensor]: + ) -> tuple[list[HolonAction], ProtoAction, torch.Tensor]: env_acts = [] - emb_acts: List[torch.Tensor] = [] + emb_acts: list[torch.Tensor] = [] ndims = [] env_action = self.from_latent(raw_action) for proto in env_action: # Figure out the neighbors for each subspace. - envs_neighbors = [] - embed_neighbors = [] + envs_neighbors: list[Any] = [] + embed_neighbors: list[Any] = [] # TODO(wz2,PROTOX_DELTA): For pseudo-backwards compatibility, we meld the knob + query space together. # In this way, we don't actually generate knob x query cartesian product. # Rather, we directly fuse min(knob_neighbors, query_neighbors) together and then cross with indexes. 
- meld_groups = [ + meld_groups: list[list[Any]] = [ [self.get_knob_space(), self.get_query_space()], [self.get_index_space()], ] @@ -329,7 +329,7 @@ def generate_state_container( prev_state_container: Optional[HolonStateContainer], action: Optional[HolonAction], connection: Connection[Any], - queries: dict[str, list[Tuple[QueryType, str]]], + queries: dict[str, list[tuple[QueryType, str]]], ) -> HolonStateContainer: t = tuple( s.generate_state_container( @@ -346,7 +346,7 @@ def generate_state_container( def generate_action_plan( self, action: HolonAction, state_container: HolonStateContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: outputs = [ space.generate_action_plan(action[i], state_container[i], **kwargs) for i, space in enumerate(self.spaces) @@ -359,7 +359,7 @@ def generate_action_plan( def generate_plan_from_config( self, config: HolonStateContainer, sc: HolonStateContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: outputs = [ space.generate_delta_action_plan(config[i], sc[i], **kwargs) for i, space in enumerate(self.spaces) diff --git a/tune/protox/env/space/latent_space/latent_index_space.py b/tune/protox/env/space/latent_space/latent_index_space.py index 33d59466..9afa38b4 100644 --- a/tune/protox/env/space/latent_space/latent_index_space.py +++ b/tune/protox/env/space/latent_space/latent_index_space.py @@ -39,7 +39,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, ) -> None: @@ -250,7 +250,7 @@ def generate_state_container( def generate_action_plan( self, action: IndexSpaceRawSample, sc: IndexSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: assert check_subspace(self, action) sql_commands = [] @@ -277,7 +277,7 @@ def generate_action_plan( def generate_delta_action_plan( self, action: IndexSpaceContainer, sc: IndexSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: assert isinstance(action, list) acts = [] sql_commands = [] diff --git a/tune/protox/env/space/latent_space/latent_knob_space.py b/tune/protox/env/space/latent_space/latent_knob_space.py index 6d1a97ea..caa923ee 100644 --- a/tune/protox/env/space/latent_space/latent_knob_space.py +++ b/tune/protox/env/space/latent_space/latent_knob_space.py @@ -181,7 +181,7 @@ def generate_state_container( def generate_action_plan( self, action: KnobSpaceAction, sc: KnobSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: config_changes = [] sql_commands = [] require_cleanup = False @@ -235,5 +235,5 @@ def generate_action_plan( def generate_delta_action_plan( self, action: KnobSpaceAction, sc: KnobSpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return self.generate_action_plan(action, sc, **kwargs) diff --git a/tune/protox/env/space/latent_space/latent_query_space.py b/tune/protox/env/space/latent_space/latent_query_space.py index 0bfa40b4..a2668206 100644 --- a/tune/protox/env/space/latent_space/latent_query_space.py +++ b/tune/protox/env/space/latent_space/latent_query_space.py @@ -41,12 +41,12 @@ def generate_state_container( def generate_action_plan( self, action: QuerySpaceAction, sc: 
QuerySpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return [], [] def generate_delta_action_plan( self, action: QuerySpaceAction, sc: QuerySpaceContainer, **kwargs: Any - ) -> Tuple[list[str], list[str]]: + ) -> tuple[list[str], list[str]]: return [], [] def extract_query(self, action: QuerySpaceAction) -> QuerySpaceKnobAction: diff --git a/tune/protox/env/space/latent_space/lsc_index_space.py b/tune/protox/env/space/latent_space/lsc_index_space.py index e1425081..87290dcf 100644 --- a/tune/protox/env/space/latent_space/lsc_index_space.py +++ b/tune/protox/env/space/latent_space/lsc_index_space.py @@ -35,7 +35,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, lsc: Optional[LSC] = None, diff --git a/tune/protox/env/space/primitive/index.py b/tune/protox/env/space/primitive/index.py index ae31a486..070bf092 100644 --- a/tune/protox/env/space/primitive/index.py +++ b/tune/protox/env/space/primitive/index.py @@ -7,7 +7,7 @@ class IndexAction(object): IA = TypeVar("IA", bound="IndexAction") index_name_counter = 0 - index_name_map: dict["IndexAction", int] = dict() + index_name_map: dict["IndexAction", str] = dict() def __init__( self, @@ -81,7 +81,7 @@ def sql(self, add: bool, allow_fail: bool = False) -> str: # A given index name (like "index5") maps one-to-one to the function of an # index (i.e. its table, columns, etc.). - def get_index_name(self): + def get_index_name(self) -> str: if self not in IndexAction.index_name_map: IndexAction.index_name_map[self] = f"index{IndexAction.index_name_counter}" IndexAction.index_name_counter += 1 diff --git a/tune/protox/env/space/primitive/knob.py b/tune/protox/env/space/primitive/knob.py index f71e397f..a09ce942 100644 --- a/tune/protox/env/space/primitive/knob.py +++ b/tune/protox/env/space/primitive/knob.py @@ -27,7 +27,7 @@ def full_knob_name( return knob_name -def _parse_setting_dtype(type_str: str) -> Tuple[SettingType, Any]: +def _parse_setting_dtype(type_str: str) -> tuple[SettingType, Any]: return { "boolean": (SettingType.BOOLEAN, np.int32), "integer": (SettingType.INTEGER, np.int32), @@ -44,7 +44,7 @@ class KnobMetadata(TypedDict, total=False): type: str min: float max: float - quantize: bool + quantize: int log_scale: int unit: int values: list[str] @@ -229,7 +229,7 @@ def _flatdim_knob(space: Knob) -> int: return 1 -def _categorical_elems(type_str: str) -> Tuple[SettingType, int]: +def _categorical_elems(type_str: str) -> tuple[SettingType, int]: return { "scanmethod_enum_categorical": (SettingType.SCANMETHOD_ENUM_CATEGORICAL, 2), }[type_str] diff --git a/tune/protox/env/space/primitive_space/index_policy.py b/tune/protox/env/space/primitive_space/index_policy.py index 99390e57..dd03d209 100644 --- a/tune/protox/env/space/primitive_space/index_policy.py +++ b/tune/protox/env/space/primitive_space/index_policy.py @@ -42,7 +42,7 @@ def __init__( self.index_space_aux_include = index_space_aux_include def spaces(self, seed: int) -> Sequence[spaces.Space[Any]]: - aux: List[spaces.Space[Any]] = [ + aux: list[spaces.Space[Any]] = [ # One-hot encoding for the tables. spaces.Discrete(self.num_tables, seed=seed), # Ordering. Note that we use the postgres style ordinal notation. 0 is illegal/end-of-index. 
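Aside (not part of the patch): `IndexPolicy.spaces()` above assembles a list of per-slot spaces, with column ordinal 0 reserved to mean end-of-index. A rough sketch of that shape, assuming `gymnasium.spaces` is what the repo imports as `spaces`; the function and parameter names here are illustrative only.

from typing import Any, Sequence

from gymnasium import spaces


def index_spaces(
    num_tables: int, max_key_columns: int, seed: int
) -> Sequence[spaces.Space[Any]]:
    # One slot selects the table; each key slot then selects a column ordinal,
    # where 0 is reserved as the "no more columns" terminator (hence the +1).
    table_choice: spaces.Space[Any] = spaces.Discrete(num_tables, seed=seed)
    key_slots: list[spaces.Space[Any]] = [
        spaces.Discrete(max_key_columns + 1, seed=seed) for _ in range(max_key_columns)
    ]
    return [table_choice, *key_slots]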
@@ -67,7 +67,7 @@ def spaces(self, seed: int) -> Sequence[spaces.Space[Any]]: ) ] - return cast(List[spaces.Space[Any]], aux_type + aux + aux_include) + return cast(list[spaces.Space[Any]], aux_type + aux + aux_include) def to_action(self, act: IndexSpaceRawSample) -> IndexAction: # First index is the index type. diff --git a/tune/protox/env/space/primitive_space/index_space.py b/tune/protox/env/space/primitive_space/index_space.py index ca4b7a06..f8e8ff41 100644 --- a/tune/protox/env/space/primitive_space/index_space.py +++ b/tune/protox/env/space/primitive_space/index_space.py @@ -107,7 +107,7 @@ def null_action(self) -> IndexSpaceRawSample: action[0] = 1.0 return self.policy.sample_dist(action, self.np_random, sample_num_columns=False) - def to_jsonable(self, sample_n) -> List[str]: # type: ignore + def to_jsonable(self, sample_n) -> list[str]: # type: ignore # Emit the representation of an index. ias = [self.to_action(sample) for sample in sample_n] return [ia.__repr__() for ia in ias] diff --git a/tune/protox/env/space/state/structure.py b/tune/protox/env/space/state/structure.py index 04dbffdd..c5b2ab19 100644 --- a/tune/protox/env/space/state/structure.py +++ b/tune/protox/env/space/state/structure.py @@ -32,7 +32,7 @@ def __init__( self.normalize = normalize if self.normalize: - self.internal_spaces: Dict[str, gym.spaces.Space[Any]] = { + self.internal_spaces: dict[str, gym.spaces.Space[Any]] = { k: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(s.critic_dim(),)) for k, s in action_space.get_spaces() } @@ -116,7 +116,7 @@ def construct_offline( else: index_state = np.zeros(index_space.critic_dim(), dtype=np.float32) - state = {} + state: dict[str, Any] = {} if knob_state is not None: state["knobs"] = knob_state if query_state is not None: diff --git a/tune/protox/env/space/utils.py b/tune/protox/env/space/utils.py index c1f79a4c..0977a906 100644 --- a/tune/protox/env/space/utils.py +++ b/tune/protox/env/space/utils.py @@ -224,7 +224,7 @@ def fetch_server_knobs( def fetch_server_indexes( connection: Connection[Any], tables: list[str] -) -> typing.Tuple[TableAttrListMap, ServerTableIndexMetadata]: +) -> tuple[TableAttrListMap, ServerTableIndexMetadata]: rel_metadata = TableAttrListMap({t: [] for t in tables}) existing_indexes = ServerTableIndexMetadata({}) with connection.cursor(row_factory=dict_row) as cursor: diff --git a/tune/protox/env/target_reset/__init__.py b/tune/protox/env/target_reset/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/target_reset/target_reset_wrapper.py b/tune/protox/env/target_reset/target_reset_wrapper.py index 800ec60a..edfcf520 100644 --- a/tune/protox/env/target_reset/target_reset_wrapper.py +++ b/tune/protox/env/target_reset/target_reset_wrapper.py @@ -36,7 +36,7 @@ def _get_state(self) -> HolonStateContainer: def step( # type: ignore self, *args: Any, **kwargs: Any - ) -> Tuple[Any, float, bool, bool, EnvInfoDict]: + ) -> tuple[Any, float, bool, bool, EnvInfoDict]: """Steps through the environment, normalizing the rewards returned.""" obs, rews, terms, truncs, infos = self.env.step(*args, **kwargs) query_metric_data = infos.get("query_metric_data", None) @@ -81,7 +81,7 @@ def step( # type: ignore ] return obs, rews, terms, truncs, infos - def reset(self, **kwargs: Any) -> Tuple[Any, dict[str, Any]]: + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: if len(self.tracked_states) == 0: # First time. 
state, info = self.env.reset(**kwargs) diff --git a/tune/protox/env/types.py b/tune/protox/env/types.py index 976317ed..35d3d8a0 100644 --- a/tune/protox/env/types.py +++ b/tune/protox/env/types.py @@ -75,12 +75,12 @@ class ServerIndexMetadata(TypedDict, total=False): QuerySpaceContainer: TypeAlias = KnobSpaceContainer # ([idx_type], [table_encoding], [key1_encoding], ... [key#_encoding], [include_mask]) -IndexSpaceRawSample = NewType("IndexSpaceRawSample", Tuple[Any, ...]) +IndexSpaceRawSample = NewType("IndexSpaceRawSample", tuple[Any, ...]) # [IndexAction(index1), ...] IndexSpaceContainer = NewType("IndexSpaceContainer", list["IndexAction"]) # (table_name, column_name) -TableColTuple = NewType("TableColTuple", Tuple[str, str]) +TableColTuple = NewType("TableColTuple", tuple[str, str]) # {table: [att1, att2, ...], ...} TableAttrListMap = NewType("TableAttrListMap", dict[str, list[str]]) @@ -91,7 +91,7 @@ class ServerIndexMetadata(TypedDict, total=False): # {table: set[ (att1, att3), (att3, att4), ... ], ...} # This maps a table to a set of attributes accessed together. TableAttrAccessSetsMap = NewType( - "TableAttrAccessSetsMap", dict[str, set[Tuple[str, ...]]] + "TableAttrAccessSetsMap", dict[str, set[tuple[str, ...]]] ) # {qid: {table: scan_method, ...}, ...} @@ -101,11 +101,11 @@ class ServerIndexMetadata(TypedDict, total=False): # {qid: {table: [alias1, alias2, ...], ...}, ...} QueryTableAliasMap = NewType("QueryTableAliasMap", dict[str, TableAliasMap]) # {qid: [(query_type1, query_str1), (query_type2, query_str2), ...], ...} -QueryMap = NewType("QueryMap", dict[str, list[Tuple[QueryType, str]]]) +QueryMap = NewType("QueryMap", dict[str, list[tuple[QueryType, str]]]) HolonAction = NewType( "HolonAction", - Tuple[ + tuple[ KnobSpaceAction, IndexSpaceRawSample, QuerySpaceAction, @@ -114,7 +114,7 @@ class ServerIndexMetadata(TypedDict, total=False): HolonStateContainer = NewType( "HolonStateContainer", - Tuple[ + tuple[ KnobSpaceContainer, IndexSpaceContainer, QuerySpaceContainer, @@ -153,12 +153,12 @@ class TargetResetConfig(TypedDict, total=False): class QuerySpec(TypedDict, total=False): benchbase: bool oltp_workload: bool - query_transactional: Union[str, Path] - query_directory: Union[str, Path] - query_order: Union[str, Path] + query_transactional: Path + query_directory: Path + query_order: Path - execute_query_directory: Union[str, Path] - execute_query_order: Union[str, Path] + execute_query_directory: Path + execute_query_order: Path tbl_include_subsets_prune: bool tbl_fold_subsets: bool @@ -166,6 +166,11 @@ class QuerySpec(TypedDict, total=False): tbl_fold_iterations: int +class ActionsInfo(TypedDict): + all_holon_action_variations: list[tuple[str, HolonAction]] + best_observed_holon_action: Optional[HolonAction] + + class EnvInfoDict(TypedDict, total=False): # Original baseline metric. baseline_metric: float @@ -182,19 +187,19 @@ class EnvInfoDict(TypedDict, total=False): prior_pgconf: Optional[Union[str, Path]] # Changes made to the DBMS during this step. - attempted_changes: Tuple[list[str], list[str]] + attempted_changes: tuple[list[str], list[str]] # Metric of this step. - metric: float + metric: Optional[float] # Reward of this step. - reward: float + reward: Optional[float] # Whether any queries timed out or the workload as a whole timed out. did_anything_time_out: bool # Query metric data. query_metric_data: Optional[dict[str, BestQueryRun]] # Information about the actions that were executed this step. - # The actions are in a format usable by replay. 
(TODO(phw2)) - actions_info: Tuple["KnobSpaceAction", "IndexAction", "QuerySpaceAction"] + # The actions are in a format usable by replay. + actions_info: Optional[ActionsInfo] # ProtoAction of the altered step action. maximal_embed: ProtoAction diff --git a/tune/protox/env/util/__init__.py b/tune/protox/env/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tune/protox/env/util/execute.py b/tune/protox/env/util/execute.py index 6ec5d695..fbbe9a4c 100644 --- a/tune/protox/env/util/execute.py +++ b/tune/protox/env/util/execute.py @@ -36,7 +36,7 @@ def _time_query( connection: psycopg.Connection[Any], query: str, timeout: float, -) -> Tuple[float, bool, Any]: +) -> tuple[float, bool, Any]: did_time_out = False has_explain = "EXPLAIN" in query explain_data = None @@ -77,7 +77,7 @@ def _acquire_metrics_around_query( query: str, query_timeout: float = 0.0, observation_space: Optional[StateSpace] = None, -) -> Tuple[float, bool, Any, Any]: +) -> tuple[float, bool, Any, Any]: _force_statement_timeout(connection, 0) if observation_space and observation_space.require_metrics(): initial_metrics = observation_space.construct_online(connection) diff --git a/tune/protox/env/util/pg_conn.py b/tune/protox/env/util/pg_conn.py index 233b49bc..3faf66a6 100644 --- a/tune/protox/env/util/pg_conn.py +++ b/tune/protox/env/util/pg_conn.py @@ -71,7 +71,7 @@ def __init__( self._conn: Optional[psycopg.Connection[Any]] = None - def get_connstr(self): + def get_connstr(self) -> str: return f"host=localhost port={self.pgport} user={DBGYM_POSTGRES_USER} password={DBGYM_POSTGRES_PASS} dbname={DBGYM_POSTGRES_DBNAME}" def conn(self) -> psycopg.Connection[Any]: @@ -272,7 +272,7 @@ def _set_up_boot( mu_hyp_opt: float, mu_hyp_time: int, mu_hyp_stdev: float, - ): + ) -> None: """ Sets up Boot on the currently running Postgres instances. Uses instance vars of PostgresConn for configuration. @@ -302,7 +302,7 @@ def _set_up_boot( self.logger.get_logger(__name__).debug("Set up boot") @time_record("psql") - def psql(self, sql: str) -> Tuple[int, Optional[str]]: + def psql(self, sql: str) -> tuple[int, Optional[str]]: low_sql = sql.lower() def cancel_fn(conn_str: str) -> None: @@ -358,11 +358,11 @@ def cancel_fn(conn_str: str) -> None: self.disconnect() return 0, None - def restore_pristine_snapshot(self): - self._restore_snapshot(self.pristine_dbdata_snapshot_fpath) + def restore_pristine_snapshot(self) -> bool: + return self._restore_snapshot(self.pristine_dbdata_snapshot_fpath) - def restore_checkpointed_snapshot(self): - self._restore_snapshot(self.checkpoint_dbdata_snapshot_fpath) + def restore_checkpointed_snapshot(self) -> bool: + return self._restore_snapshot(self.checkpoint_dbdata_snapshot_fpath) @time_record("restore") def _restore_snapshot( diff --git a/tune/protox/env/util/reward.py b/tune/protox/env/util/reward.py index ba01b8a0..9b5046d7 100644 --- a/tune/protox/env/util/reward.py +++ b/tune/protox/env/util/reward.py @@ -52,7 +52,7 @@ def set_relative_baseline( def parse_tps_avg_p99_for_metric( self, parent: Union[Path, str] - ) -> Tuple[float, float, float]: + ) -> tuple[float, float, float]: files = [f for f in Path(parent).rglob("*.summary.json")] assert len(files) == 1 @@ -99,7 +99,7 @@ def __call__( metric: Optional[float] = None, update: bool = True, did_error: bool = False, - ) -> Tuple[float, float]: + ) -> tuple[float, float]: # TODO: we need to get the memory consumption of indexes. if the index usage # exceeds the limit, then kill the reward function. 
may also want to penalize diff --git a/tune/protox/env/util/workload_analysis.py b/tune/protox/env/util/workload_analysis.py index de4be450..7df507bf 100644 --- a/tune/protox/env/util/workload_analysis.py +++ b/tune/protox/env/util/workload_analysis.py @@ -1,9 +1,8 @@ -from enum import Enum, unique from typing import Iterator, Optional, Tuple -import pglast # type: ignore +import pglast from pglast import stream -from pglast.visitors import Continue, Visitor # type: ignore +from pglast.visitors import Continue, Visitor from tune.protox.env.types import ( AttrTableListMap, @@ -73,7 +72,7 @@ def extract_aliases(stmts: pglast.ast.Node) -> TableAliasMap: def extract_sqltypes( stmts: pglast.ast.Node, pid: Optional[int] -) -> list[Tuple[QueryType, str]]: +) -> list[tuple[QueryType, str]]: sqls = [] for stmt in stmts: sql_type = QueryType.UNKNOWN @@ -115,7 +114,7 @@ def extract_columns( tables: list[str], all_attributes: AttrTableListMap, query_aliases: TableAliasMap, -) -> Tuple[TableAttrSetMap, list[TableColTuple]]: +) -> tuple[TableAttrSetMap, list[TableColTuple]]: tbl_col_usages: TableAttrSetMap = TableAttrSetMap({t: set() for t in tables}) def traverse_extract_columns( diff --git a/tune/protox/env/workload.py b/tune/protox/env/workload.py index f56b931b..58d27c59 100644 --- a/tune/protox/env/workload.py +++ b/tune/protox/env/workload.py @@ -5,10 +5,10 @@ import tempfile import time from pathlib import Path -from typing import Any, Optional, Tuple, Union, cast +from typing import IO, Any, Optional, Tuple, Union, cast import numpy as np -import pglast # type: ignore +import pglast from plumbum import local from misc.utils import DBGymConfig, open_and_save @@ -50,14 +50,11 @@ class Workload(object): # However, when creating a Workload object for unittesting, we just want to call open() def _open_for_reading( self, - path, - mode="r", - ): - # when opening for writing we always use open() so we don't need this function, which is - # why we assert here - # I still chose to make mode an argument just to make the interface identical to open()/open_and_save() - assert mode == "r" - if self.dbgym_cfg != None: + path: Path, + ) -> IO[Any]: + # When opening for writing we always use open() so we don't need this function, which is + # why we hardcode the mode as "r". + if self.dbgym_cfg is not None: return open_and_save(self.dbgym_cfg, path) else: return open(path) @@ -65,7 +62,7 @@ def _open_for_reading( def _crunch( self, all_attributes: AttrTableListMap, - sqls: list[Tuple[str, Path, float]], + sqls: list[tuple[str, Path, float]], pid: Optional[int], query_spec: QuerySpec, ) -> None: @@ -82,7 +79,7 @@ def _crunch( self.tbl_filter_queries_usage: dict[TableColTuple, set[str]] = {} # Build the SQL and table usage information. - self.queries_mix = {} + self.queries_mix: dict[str, float] = {} self.query_aliases = {} self.query_usages = TableAttrListMap({t: [] for t in self.tables}) tbl_include_subsets = TableAttrAccessSetsMap( @@ -93,7 +90,7 @@ def _crunch( self.order.append(stem) self.queries_mix[stem] = ratio - with self._open_for_reading(sql_file, "r") as q: + with self._open_for_reading(sql_file) as q: sql = q.read() assert not sql.startswith("/*") @@ -162,7 +159,7 @@ def _crunch( ) if do_tbl_include_subsets_prune: - self.tbl_include_subsets = {} + self.tbl_include_subsets = TableAttrAccessSetsMap({}) # First prune any "fully enclosed". 
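Reads in Workload now generally go through _open_for_reading, which routes through open_and_save() when a DBGymConfig is present and falls back to plain open() for the dbgym_cfg=None unittest path. A rough standalone sketch of that dispatch, assuming open_and_save() returns an ordinary file object as its use above implies; read_workload_file is an illustrative name, not a helper from the patch:

from pathlib import Path
from typing import IO, Any, Optional

from misc.utils import DBGymConfig, open_and_save  # repo helpers referenced by the patch

def read_workload_file(dbgym_cfg: Optional[DBGymConfig], path: Path) -> str:
    f: IO[Any]
    if dbgym_cfg is not None:
        # Normal operation: route the read through open_and_save()
        # (presumably so the file is tracked alongside the run).
        f = open_and_save(dbgym_cfg, path)
    else:
        # Unittests construct Workload with dbgym_cfg=None and read directly.
        f = open(path)
    with f:
        return f.read()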
for tbl, attrsets in tbl_include_subsets.items(): self.tbl_include_subsets[tbl] = set( @@ -217,7 +214,8 @@ def _crunch( def __init__( self, - dbgym_cfg: DBGymConfig, + # dbgym_cfg is only optional so we can set it to None for unittests. Don't set it to None during normal operation. + dbgym_cfg: Optional[DBGymConfig], tables: list[str], attributes: TableAttrListMap, query_spec: QuerySpec, @@ -255,7 +253,7 @@ def __init__( sqls = [] order_or_txn_fname = "txn.txt" if self.oltp_workload else "order.txt" workload_order_or_txn_fpath = self.workload_path / order_or_txn_fname - with self._open_for_reading(workload_order_or_txn_fpath, "r") as f: + with self._open_for_reading(workload_order_or_txn_fpath) as f: lines = f.read().splitlines() sqls = [ ( @@ -268,7 +266,7 @@ def __init__( # TODO(phw2): pass "query_transactional" somewhere other than query_spec, just like "query_order" is if "query_transactional" in query_spec: - with self._open_for_reading(query_spec["query_transactional"], "r") as f: + with self._open_for_reading(query_spec["query_transactional"]) as f: lines = f.read().splitlines() splits = [line.split(",") for line in lines] sqls = [ @@ -286,7 +284,7 @@ def __init__( # TODO(phw2): pass "execute_query_order" somewhere other than query_spec, just like "query_order" is if "execute_query_order" in query_spec: - with open_and_save(dbgym_cfg, query_spec["execute_query_order"], "r") as f: + with self._open_for_reading(query_spec["execute_query_order"]) as f: lines = f.read().splitlines() sqls = [ ( @@ -336,7 +334,12 @@ def max_indexable(self) -> int: def compute_total_workload_runtime( qid_runtime_data: dict[str, BestQueryRun] ) -> float: - return sum(best_run.runtime for best_run in qid_runtime_data.values()) / 1.0e6 + total_runtime: float = 0.0 + for best_run in qid_runtime_data.values(): + assert best_run.runtime is not None + total_runtime += best_run.runtime + total_runtime /= 1.0e6 + return total_runtime @time_record("execute") def execute_workload( @@ -344,16 +347,16 @@ def execute_workload( pg_conn: PostgresConn, actions: list[HolonAction] = [], variation_names: list[str] = [], - results_dpath: Optional[Union[str, Path]] = None, + results_dpath: Optional[Path] = None, observation_space: Optional[StateSpace] = None, action_space: Optional[HolonSpace] = None, reset_metrics: Optional[dict[str, BestQueryRun]] = None, override_workload_timeout: Optional[float] = None, query_timeout: Optional[int] = None, - workload_qdir: Optional[Tuple[Union[str, Path], Union[str, Path]]] = None, + workload_qdir: Optional[tuple[Path, Path]] = None, blocklist: list[str] = [], first: bool = False, - ) -> Tuple[int, bool, dict[str, Any]]: + ) -> tuple[int, bool, dict[str, Any]]: this_execution_workload_timeout = ( self.workload_timeout if not override_workload_timeout @@ -375,7 +378,7 @@ def execute_workload( ][0], ) ql_knobs = cast( - list[Tuple[LatentQuerySpace, QuerySpaceAction]], + list[tuple[LatentQuerySpace, QuerySpaceAction]], [ [ (t, v) @@ -390,7 +393,7 @@ def execute_workload( if workload_qdir is not None and workload_qdir[0] is not None: # Load actual queries to execute. 
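The rewrite of compute_total_workload_runtime above replaces the one-line sum() with an explicit loop so that each best_run.runtime, an evidently Optional field on BestQueryRun, is narrowed with an assert before the addition; the 1.0e6 divisor is kept, which suggests runtimes are accumulated in microseconds. A small sketch of the same narrowing pattern over made-up data (total_runtime_seconds is an illustrative name, not from the patch):

from typing import Optional

def total_runtime_seconds(runtimes_us: dict[str, Optional[float]]) -> float:
    total_us = 0.0
    for qid, runtime in runtimes_us.items():
        # The assert narrows Optional[float] to float for mypy and also flags
        # any query that finished without a recorded runtime.
        assert runtime is not None, f"{qid} has no recorded runtime"
        total_us += runtime
    return total_us / 1.0e6

total_runtime_seconds({"Q1": 1.2e6, "Q2": 3.4e6})  # 4.6 seconds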
workload_dir, workload_qlist = workload_qdir - with self._open_for_reading(workload_qlist, "r") as f: + with self._open_for_reading(workload_qlist) as f: psql_order = [ (f"Q{i+1}", Path(workload_dir) / l.strip()) for i, l in enumerate(f.readlines()) @@ -400,7 +403,7 @@ def execute_workload( actual_sql_files = {k: str(v) for (k, v) in psql_order} actual_queries = {} for qid, qpat in psql_order: - with self._open_for_reading(qpat, "r") as f: + with self._open_for_reading(qpat) as f: query = f.read() actual_queries[qid] = [(QueryType.SELECT, query)] else: @@ -651,7 +654,7 @@ def execute( reset_metrics: Optional[dict[str, BestQueryRun]] = None, update: bool = True, first: bool = False, - ) -> Tuple[bool, float, float, Union[str, Path], bool, dict[str, BestQueryRun]]: + ) -> tuple[bool, float, float, Union[str, Path], bool, dict[str, BestQueryRun]]: success = True if self.logger: self.logger.get_logger(__name__).info("Starting to run benchmark...") @@ -673,6 +676,7 @@ def execute( # Execute benchbase if specified. success = self._execute_benchbase(benchbase_config, results_dpath) # We can only create a state if we succeeded. + assert self.dbgym_cfg is not None success = observation_space.check_benchbase(self.dbgym_cfg, results_dpath) else: num_timed_out_queries, did_workload_time_out, query_metric_data = ( diff --git a/tune/protox/tests/test_index_space.py b/tune/protox/tests/test_index_space.py index 02225649..9ccfd73e 100644 --- a/tune/protox/tests/test_index_space.py +++ b/tune/protox/tests/test_index_space.py @@ -6,18 +6,19 @@ from tune.protox.env.space.primitive_space import IndexSpace from tune.protox.env.space.utils import check_subspace +from tune.protox.env.types import IndexSpaceRawSample from tune.protox.env.workload import Workload class IndexSpaceTests(unittest.TestCase): @staticmethod def load( - config_path=Path( + config_path: Path = Path( "tune/protox/tests/unittest_benchmark_configs/unittest_tpch.yaml" ).resolve(), - aux_type=True, - aux_include=True, - ): + aux_type: bool = True, + aux_include: bool = True, + ) -> tuple[Workload, IndexSpace]: # don't call open_and_save() because this is a unittest with open(config_path, "r") as f: benchmark_config = yaml.safe_load(f) @@ -51,7 +52,7 @@ def load( ) return w, i - def test_null_action(self): + def test_null_action(self) -> None: w, i = IndexSpaceTests.load() null_action = i.null_action() self.assertTrue(check_subspace(i, null_action)) @@ -60,19 +61,19 @@ def test_null_action(self): null_action = i.null_action() self.assertTrue(check_subspace(i, null_action)) - def test_sample(self): + def test_sample(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): self.assertTrue(check_subspace(i, i.sample())) - def test_sample_table(self): + def test_sample_table(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): mask = {"table_idx": 2} ia = i.to_action(i.sample(mask)) self.assertEqual(ia.tbl_name, "lineitem") - def test_sample_table_col(self): + def test_sample_table_col(self) -> None: w, i = IndexSpaceTests.load(aux_type=False, aux_include=False) for _ in range(100): mask = {"table_idx": 2, "col_idx": 1} @@ -80,12 +81,14 @@ def test_sample_table_col(self): self.assertEqual(ia.tbl_name, "lineitem") self.assertEqual(ia.columns[0], "l_partkey") - def test_neighborhood(self): + def test_neighborhood(self) -> None: w, i = IndexSpaceTests.load(aux_type=True, aux_include=True) _, isa = IndexSpaceTests.load(aux_type=False, aux_include=False) act = 
isa.sample(mask={"table_idx": 2, "col_idx": 1}) - act = (0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)) + act = IndexSpaceRawSample( + tuple([0, *act, np.zeros(i.max_inc_columns, dtype=np.float32)]) + ) self.assertTrue(check_subspace(i, act)) neighbors = i.policy.structural_neighbors(act) diff --git a/tune/protox/tests/test_primitive.py b/tune/protox/tests/test_primitive.py index f9c2bd29..d7590d80 100644 --- a/tune/protox/tests/test_primitive.py +++ b/tune/protox/tests/test_primitive.py @@ -7,7 +7,7 @@ class PrimitivesTests(unittest.TestCase): - def test_linear_knob(self): + def test_linear_knob(self) -> None: k = Knob( table_name=None, query_name="q", @@ -30,7 +30,7 @@ def test_linear_knob(self): self.assertEqual(k.project_scraped_setting(0.58), 0.5) self.assertEqual(round(k.project_scraped_setting(0.62), 2), 0.6) - def test_log_knob(self): + def test_log_knob(self) -> None: k = Knob( table_name=None, query_name="q", @@ -53,7 +53,7 @@ def test_log_knob(self): self.assertEqual(k.project_scraped_setting(24), 32.0) self.assertEqual(k.project_scraped_setting(1024), 1024.0) - def test_latent_knob(self): + def test_latent_knob(self) -> None: k = LatentKnob( table_name=None, query_name="q", @@ -85,7 +85,7 @@ def test_latent_knob(self): self.assertEqual(k.shift_offset(0.5, 1), 0.6) self.assertEqual(k.shift_offset(0.5, -2), 0.3) - def test_ia(self): + def test_ia(self) -> None: ia1 = IndexAction( idx_type="btree", tbl="tbl", @@ -95,7 +95,7 @@ def test_ia(self): raw_repr=None, bias=0.0, ) - IndexAction.index_counter = 0 + IndexAction.index_name_counter = 0 self.assertEqual( ia1.sql(add=True), "CREATE INDEX index0 ON tbl USING btree (a,b,c) INCLUDE (d,e)", diff --git a/tune/protox/tests/test_workload.py b/tune/protox/tests/test_workload.py index fb46fea3..04a0f980 100644 --- a/tune/protox/tests/test_workload.py +++ b/tune/protox/tests/test_workload.py @@ -1,16 +1,18 @@ import json import unittest from pathlib import Path +from typing import Any, Tuple import yaml from tune.protox.env.space.primitive_space import IndexSpace +from tune.protox.env.types import TableAttrAccessSetsMap, TableColTuple from tune.protox.env.workload import Workload class WorkloadTests(unittest.TestCase): @staticmethod - def load(config_file: str, workload_path: Path): + def load(config_file: str, workload_path: Path) -> tuple[Workload, IndexSpace]: # don't call open_and_save() because this is a unittest with open(config_file, "r") as f: benchmark_config = yaml.safe_load(f) @@ -37,19 +39,21 @@ def load(config_file: str, workload_path: Path): seed=0, rel_metadata=w.column_usages(), attributes_overwrite=w.column_usages(), - tbl_include_subsets={}, + tbl_include_subsets=TableAttrAccessSetsMap({}), index_space_aux_type=True, index_space_aux_include=True, deterministic_policy=True, ) return w, i - def diff_classmapping(self, ref, target): + def diff_classmapping( + self, ref: dict[TableColTuple, int], target: dict[TableColTuple, int] + ) -> None: for k, v in ref.items(): self.assertTrue(k in target, msg=f"{k} is missing.") self.assertTrue(v == target[k]) - def test_tpch(self): + def test_tpch(self) -> None: with open("tune/protox/tests/unittest_ref_models/ref_tpch_model.txt", "r") as f: ref = json.load(f)["class_mapping"] ref = {(v["relname"], v["ord_column"]): int(k) for k, v in ref.items()} @@ -60,7 +64,7 @@ def test_tpch(self): ) self.assertEqual(i.class_mapping, ref) - def test_job(self): + def test_job(self) -> None: # don't call open_and_save() because this is a unittest with open( 
"tune/protox/tests/unittest_ref_models/ref_job_full_model.txt", "r" @@ -74,7 +78,7 @@ def test_job(self): ) self.assertEqual(i.class_mapping, ref) - def test_dsb(self): + def test_dsb(self) -> None: # don't call open_and_save() because this is a unittest with open("tune/protox/tests/unittest_ref_models/ref_dsb_model.txt", "r") as f: ref = json.load(f)["class_mapping"] @@ -86,7 +90,7 @@ def test_dsb(self): ) self.diff_classmapping(ref, i.class_mapping) - def test_tpcc(self): + def test_tpcc(self) -> None: # don't call open_and_save() because this is a unittest with open("tune/protox/tests/unittest_ref_models/ref_tpcc_model.txt", "r") as f: ref = json.load(f)["class_mapping"] diff --git a/tune/protox/tests/test_workload_utils.py b/tune/protox/tests/test_workload_utils.py index b1e63cf7..be2fd9a8 100644 --- a/tune/protox/tests/test_workload_utils.py +++ b/tune/protox/tests/test_workload_utils.py @@ -2,7 +2,12 @@ import pglast -from tune.protox.env.util.workload_analysis import * +from tune.protox.env.types import AttrTableListMap, QueryType +from tune.protox.env.util.workload_analysis import ( + extract_aliases, + extract_columns, + extract_sqltypes, +) class WorkloadUtilsTests(unittest.TestCase): @@ -16,69 +21,71 @@ class WorkloadUtilsTests(unittest.TestCase): "nation", "region", ] - TPCH_ALL_ATTRIBUTES = { - "r_regionkey": ["region"], - "r_name": ["region"], - "r_comment": ["region"], - "n_nationkey": ["nation"], - "n_name": ["nation"], - "n_regionkey": ["nation"], - "n_comment": ["nation"], - "p_partkey": ["part"], - "p_name": ["part"], - "p_mfgr": ["part"], - "p_brand": ["part"], - "p_type": ["part"], - "p_size": ["part"], - "p_container": ["part"], - "p_retailprice": ["part"], - "p_comment": ["part"], - "s_suppkey": ["supplier"], - "s_name": ["supplier"], - "s_address": ["supplier"], - "s_nationkey": ["supplier"], - "s_phone": ["supplier"], - "s_acctbal": ["supplier"], - "s_comment": ["supplier"], - "ps_partkey": ["partsupp"], - "ps_suppkey": ["partsupp"], - "ps_availqty": ["partsupp"], - "ps_supplycost": ["partsupp"], - "ps_comment": ["partsupp"], - "c_custkey": ["customer"], - "c_name": ["customer"], - "c_address": ["customer"], - "c_nationkey": ["customer"], - "c_phone": ["customer"], - "c_acctbal": ["customer"], - "c_mktsegment": ["customer"], - "c_comment": ["customer"], - "o_orderkey": ["orders"], - "o_custkey": ["orders"], - "o_orderstatus": ["orders"], - "o_totalprice": ["orders"], - "o_orderdate": ["orders"], - "o_orderpriority": ["orders"], - "o_clerk": ["orders"], - "o_shippriority": ["orders"], - "o_comment": ["orders"], - "l_orderkey": ["lineitem"], - "l_partkey": ["lineitem"], - "l_suppkey": ["lineitem"], - "l_linenumber": ["lineitem"], - "l_quantity": ["lineitem"], - "l_extendedprice": ["lineitem"], - "l_discount": ["lineitem"], - "l_tax": ["lineitem"], - "l_returnflag": ["lineitem"], - "l_linestatus": ["lineitem"], - "l_shipdate": ["lineitem"], - "l_commitdate": ["lineitem"], - "l_receiptdate": ["lineitem"], - "l_shipinstruct": ["lineitem"], - "l_shipmode": ["lineitem"], - "l_comment": ["lineitem"], - } + TPCH_ALL_ATTRIBUTES = AttrTableListMap( + { + "r_regionkey": ["region"], + "r_name": ["region"], + "r_comment": ["region"], + "n_nationkey": ["nation"], + "n_name": ["nation"], + "n_regionkey": ["nation"], + "n_comment": ["nation"], + "p_partkey": ["part"], + "p_name": ["part"], + "p_mfgr": ["part"], + "p_brand": ["part"], + "p_type": ["part"], + "p_size": ["part"], + "p_container": ["part"], + "p_retailprice": ["part"], + "p_comment": ["part"], + "s_suppkey": 
["supplier"], + "s_name": ["supplier"], + "s_address": ["supplier"], + "s_nationkey": ["supplier"], + "s_phone": ["supplier"], + "s_acctbal": ["supplier"], + "s_comment": ["supplier"], + "ps_partkey": ["partsupp"], + "ps_suppkey": ["partsupp"], + "ps_availqty": ["partsupp"], + "ps_supplycost": ["partsupp"], + "ps_comment": ["partsupp"], + "c_custkey": ["customer"], + "c_name": ["customer"], + "c_address": ["customer"], + "c_nationkey": ["customer"], + "c_phone": ["customer"], + "c_acctbal": ["customer"], + "c_mktsegment": ["customer"], + "c_comment": ["customer"], + "o_orderkey": ["orders"], + "o_custkey": ["orders"], + "o_orderstatus": ["orders"], + "o_totalprice": ["orders"], + "o_orderdate": ["orders"], + "o_orderpriority": ["orders"], + "o_clerk": ["orders"], + "o_shippriority": ["orders"], + "o_comment": ["orders"], + "l_orderkey": ["lineitem"], + "l_partkey": ["lineitem"], + "l_suppkey": ["lineitem"], + "l_linenumber": ["lineitem"], + "l_quantity": ["lineitem"], + "l_extendedprice": ["lineitem"], + "l_discount": ["lineitem"], + "l_tax": ["lineitem"], + "l_returnflag": ["lineitem"], + "l_linestatus": ["lineitem"], + "l_shipdate": ["lineitem"], + "l_commitdate": ["lineitem"], + "l_receiptdate": ["lineitem"], + "l_shipinstruct": ["lineitem"], + "l_shipmode": ["lineitem"], + "l_comment": ["lineitem"], + } + ) TPCH_Q1 = """ select l_returnflag, @@ -104,10 +111,10 @@ class WorkloadUtilsTests(unittest.TestCase): """ @staticmethod - def pglast_parse(sql): + def pglast_parse(sql: str) -> pglast.ast.Node: return pglast.parse_sql(sql) - def test_extract_aliases(self): + def test_extract_aliases(self) -> None: sql = "select * from t1 as t1_alias; select * from t1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) @@ -116,21 +123,21 @@ def test_extract_aliases(self): self.assertTrue("t1" in aliases and len(aliases) == 1) self.assertEqual(set(aliases["t1"]), set(["t1", "t1_alias"])) - def test_extract_aliases_ignores_views_in_create_view(self): + def test_extract_aliases_ignores_views_in_create_view(self) -> None: sql = "create view view1 (view1_c1) as select c1 from t1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) # all tables have only one alias so we can do this simpler assertion code self.assertEqual(aliases, {"t1": ["t1"]}) - def test_extract_aliases_doesnt_ignore_views_that_are_used(self): + def test_extract_aliases_doesnt_ignore_views_that_are_used(self) -> None: sql = "create view view1 (view1_c1) as select c1 from t1; select * from view1;" stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) # all tables have only one alias so we can do this simpler assertion code self.assertEqual(aliases, {"t1": ["t1"], "view1": ["view1"]}) - def test_extract_sqltypes(self): + def test_extract_sqltypes(self) -> None: sql = """ select * from t1; update t1 set t1.c1 = 0 where t1.c1 = 1; @@ -150,7 +157,7 @@ def test_extract_sqltypes(self): self.assertEqual(sqltypes[1][0], QueryType.INS_UPD_DEL) self.assertEqual(sqltypes[2][0], QueryType.CREATE_VIEW) - def test_extract_columns(self): + def test_extract_columns(self) -> None: sql = WorkloadUtilsTests.TPCH_Q1 tables = WorkloadUtilsTests.TPCH_TABLES all_attributes = WorkloadUtilsTests.TPCH_ALL_ATTRIBUTES @@ -194,7 +201,7 @@ def test_extract_columns(self): ), ) - def test_extract_columns_with_cte(self): + def test_extract_columns_with_cte(self) -> None: sql = """ with cte1 as ( select t1.c1 @@ -205,7 +212,7 @@ def test_extract_columns_with_cte(self): from cte1; """ tables = 
["t1"] - all_attributes = {"c1": "t1", "c2": "t1"} + all_attributes = AttrTableListMap({"c1": ["t1"], "c2": ["t1"]}) stmts = WorkloadUtilsTests.pglast_parse(sql) aliases = extract_aliases(stmts) self.assertEqual(len(stmts), 1) diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/util/pg.py b/util/pg.py index 8c5f1e78..5358b937 100644 --- a/util/pg.py +++ b/util/pg.py @@ -1,10 +1,10 @@ from pathlib import Path -from typing import List +from typing import Any, List, NewType, Union import pglast import psycopg -from sqlalchemy import Connection, Engine, create_engine, text -from sqlalchemy.engine import CursorResult +import sqlalchemy +from sqlalchemy import create_engine, text from misc.utils import DBGymConfig, open_and_save @@ -16,11 +16,13 @@ SHARED_PRELOAD_LIBRARIES = "boot,pg_hint_plan,pg_prewarm" -def conn_execute(conn: Connection, sql: str) -> CursorResult: +def sqlalchemy_conn_execute( + conn: sqlalchemy.Connection, sql: str +) -> sqlalchemy.engine.CursorResult[Any]: return conn.execute(text(sql)) -def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> List[str]: +def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> list[str]: with open_and_save(dbgym_cfg, filepath) as f: lines: list[str] = [] for line in f: @@ -29,18 +31,21 @@ def sql_file_queries(dbgym_cfg: DBGymConfig, filepath: Path) -> List[str]: if len(line.strip()) == 0: continue lines.append(line) - queries = "".join(lines) - return pglast.split(queries) + queries_str = "".join(lines) + queries: list[str] = pglast.split(queries_str) + return queries -def sql_file_execute(dbgym_cfg: DBGymConfig, conn: Connection, filepath: Path) -> None: +def sql_file_execute( + dbgym_cfg: DBGymConfig, conn: sqlalchemy.Connection, filepath: Path +) -> None: for sql in sql_file_queries(dbgym_cfg, filepath): - conn_execute(conn, sql) + sqlalchemy_conn_execute(conn, sql) # The reason pgport is an argument is because when doing agnet HPO, we want to run multiple instances of Postgres # at the same time. 
In this situation, they need to have different ports -def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> str: +def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg: bool = True) -> str: connstr_suffix = f"{DBGYM_POSTGRES_USER}:{DBGYM_POSTGRES_PASS}@localhost:{pgport}/{DBGYM_POSTGRES_DBNAME}" # use_psycopg means whether or not we use the psycopg.connect() function # counterintuively, you *don't* need psycopg in the connection string if you *are* @@ -49,13 +54,18 @@ def get_connstr(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> str: return connstr_prefix + "://" + connstr_suffix -def create_conn(pgport: int = DEFAULT_POSTGRES_PORT, use_psycopg=True) -> Connection: - connstr = get_connstr(use_psycopg=use_psycopg, pgport=pgport) - if use_psycopg: - return psycopg.connect(connstr, autocommit=True, prepare_threshold=None) - else: - engine: Engine = create_engine( - connstr, - execution_options={"isolation_level": "AUTOCOMMIT"}, - ) - return engine.connect() +def create_psycopg_conn(pgport: int = DEFAULT_POSTGRES_PORT) -> psycopg.Connection[Any]: + connstr = get_connstr(use_psycopg=True, pgport=pgport) + psycopg_conn = psycopg.connect(connstr, autocommit=True, prepare_threshold=None) + return psycopg_conn + + +def create_sqlalchemy_conn( + pgport: int = DEFAULT_POSTGRES_PORT, +) -> sqlalchemy.Connection: + connstr = get_connstr(use_psycopg=False, pgport=pgport) + engine: sqlalchemy.Engine = create_engine( + connstr, + execution_options={"isolation_level": "AUTOCOMMIT"}, + ) + return engine.connect() diff --git a/util/shell.py b/util/shell.py index ab06f4c3..d20097ec 100644 --- a/util/shell.py +++ b/util/shell.py @@ -1,18 +1,21 @@ import logging import os import subprocess +from pathlib import Path +from typing import Optional shell_util_logger = logging.getLogger("shell_util") shell_util_logger.setLevel(logging.INFO) -def subprocess_run(c, cwd=None, check_returncode=True, dry_run=False, verbose=True): +def subprocess_run( + c: str, + cwd: Optional[Path] = None, + check_returncode: bool = True, + verbose: bool = True, +) -> subprocess.Popen[str]: cwd_msg = f"(cwd: {cwd if cwd is not None else os.getcwd()})" - if dry_run: - shell_util_logger.info(f"Dry run {cwd_msg}: {c}") - return - if verbose: shell_util_logger.info(f"Running {cwd_msg}: {c}") @@ -27,6 +30,7 @@ def subprocess_run(c, cwd=None, check_returncode=True, dry_run=False, verbose=Tr ) as proc: while True: loop = proc.poll() is None + assert proc.stdout is not None for line in proc.stdout: if verbose: print(line, end="", flush=True)
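Splitting the old create_conn() into create_psycopg_conn() and create_sqlalchemy_conn() gives every call site a concrete return type (psycopg.Connection[Any] or sqlalchemy.Connection) instead of a single function whose actual return type depended on the use_psycopg flag. A hedged usage sketch, assuming a local Postgres is already running on DEFAULT_POSTGRES_PORT with the DBGYM_POSTGRES_* credentials that get_connstr() bakes in:

from util.pg import (
    create_psycopg_conn,
    create_sqlalchemy_conn,
    sqlalchemy_conn_execute,
)

# psycopg path: an autocommit connection with the plain cursor API.
psycopg_conn = create_psycopg_conn()
with psycopg_conn.cursor() as cur:
    cur.execute("SHOW shared_preload_libraries")
    print(cur.fetchone())
psycopg_conn.close()

# SQLAlchemy path: an AUTOCOMMIT connection driven through the typed helper.
sqlalchemy_conn = create_sqlalchemy_conn()
result = sqlalchemy_conn_execute(sqlalchemy_conn, "SELECT version()")
print(result.fetchone())
sqlalchemy_conn.close()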