From ab29cc2a7ecc254bce487a4a613e4c86c4ee89ef Mon Sep 17 00:00:00 2001
From: Derek Worthen
Date: Tue, 3 Sep 2024 15:33:16 -0700
Subject: [PATCH] Consistent config load_config (#1065)

* Consistent config load_config

- Provide a consistent way to load configuration
- Resolve potential timestamp directories upfront upon config object creation
- Add unit tests for resolving timestamp directories
- Resolves #599
- Resolves #1049

* fix formatting issues

* remove unnecessary path resolution

* fix smoke tests

* update prompts to use load_config

* Update none checks

* Update none checks

* Update searching for config method signature

* Update unit tests

* fix formatting issues
---
 .../patch-20240830151802543194.json           |   4 +
 graphrag/config/__init__.py                   |   9 ++
 graphrag/config/config_file_loader.py         |  17 +--
 graphrag/config/load_config.py                |  65 ++++++++++
 graphrag/config/logging.py                    |   8 +-
 graphrag/index/__main__.py                    |  13 +-
 graphrag/index/api.py                         |  12 +-
 graphrag/index/cli.py                         |  42 ++-----
 graphrag/prompt_tune/cli.py                   |   9 +-
 graphrag/prompt_tune/loader/__init__.py       |   2 -
 graphrag/prompt_tune/loader/config.py         |  61 ----------
 graphrag/query/api.py                         |  15 +--
 graphrag/query/cli.py                         | 113 ++++--------------
 pyproject.toml                                |   1 +
 .../timestamp_dirs/20240812-120000/empty.txt  |   0
 .../config/test_resolve_timestamp_path.py     |  33 +++++
 tests/unit/query/test_infer_data_dir.py       |  32 -----
 17 files changed, 169 insertions(+), 267 deletions(-)
 create mode 100644 .semversioner/next-release/patch-20240830151802543194.json
 create mode 100644 graphrag/config/load_config.py
 delete mode 100644 graphrag/prompt_tune/loader/config.py
 create mode 100644 tests/unit/config/fixtures/timestamp_dirs/20240812-120000/empty.txt
 create mode 100644 tests/unit/config/test_resolve_timestamp_path.py
 delete mode 100644 tests/unit/query/test_infer_data_dir.py

diff --git a/.semversioner/next-release/patch-20240830151802543194.json b/.semversioner/next-release/patch-20240830151802543194.json
new file mode 100644
index 0000000000..d7805109ce
--- /dev/null
+++ b/.semversioner/next-release/patch-20240830151802543194.json
@@ -0,0 +1,4 @@
+{
+  "type": "patch",
+  "description": "Consistent config loading. Resolves #599 and Resolves #1049"
+}
diff --git a/graphrag/config/__init__.py b/graphrag/config/__init__.py
index 118018a98f..c65795f02d 100644
--- a/graphrag/config/__init__.py
+++ b/graphrag/config/__init__.py
@@ -3,6 +3,7 @@

 """The Indexing Engine default config package root."""

+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
 from .create_graphrag_config import (
     create_graphrag_config,
 )
@@ -42,6 +43,8 @@
     TextEmbeddingConfigInput,
     UmapConfigInput,
 )
+from .load_config import load_config
+from .logging import enable_logging_with_config
 from .models import (
     CacheConfig,
     ChunkingConfig,
@@ -65,6 +68,7 @@
     UmapConfig,
 )
 from .read_dotenv import read_dotenv
+from .resolve_timestamp_path import resolve_timestamp_path

 __all__ = [
     "ApiKeyMissingError",
@@ -119,5 +123,10 @@
     "UmapConfig",
     "UmapConfigInput",
     "create_graphrag_config",
+    "enable_logging_with_config",
+    "load_config",
+    "load_config_from_file",
     "read_dotenv",
+    "resolve_timestamp_path",
+    "search_for_config_in_root_dir",
 ]
diff --git a/graphrag/config/config_file_loader.py b/graphrag/config/config_file_loader.py
index 3f045cdc41..667fbe8807 100644
--- a/graphrag/config/config_file_loader.py
+++ b/graphrag/config/config_file_loader.py
@@ -9,13 +9,13 @@

 import yaml

-from . import create_graphrag_config
+from .create_graphrag_config import create_graphrag_config
 from .models.graph_rag_config import GraphRagConfig

 _default_config_files = ["settings.yaml", "settings.yml", "settings.json"]


-def resolve_config_path_with_root(root: str | Path) -> Path:
+def search_for_config_in_root_dir(root: str | Path) -> Path | None:
     """Resolve the config path from the given root directory.

     Parameters
@@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:

     Returns
     -------
-    Path
-        The resolved config file path.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file is not found or cannot be resolved for the directory.
+    Path | None
+        Returns a Path if a config file is found in the root directory,
+        otherwise returns None.
     """
     root = Path(root)

@@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
         if (root / file).is_file():
             return root / file

-    msg = f"Unable to resolve config file for parent directory: {root}"
-    raise FileNotFoundError(msg)
+    return None


 class ConfigFileLoader(ABC):
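Config discovery is now fallible by design: search_for_config_in_root_dir returns None instead of raising, and the caller decides how to fall back. A minimal sketch of the new contract (the ./ragtest root is illustrative):

    from pathlib import Path

    from graphrag.config import search_for_config_in_root_dir

    found = search_for_config_in_root_dir(Path("./ragtest"))
    # -> Path to settings.yaml, settings.yml, or settings.json in that root,
    #    or None when no config file is present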
diff --git a/graphrag/config/load_config.py b/graphrag/config/load_config.py
new file mode 100644
index 0000000000..54c9271fb1
--- /dev/null
+++ b/graphrag/config/load_config.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Default method for loading config."""
+
+from pathlib import Path
+
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
+from .create_graphrag_config import create_graphrag_config
+from .models.graph_rag_config import GraphRagConfig
+from .resolve_timestamp_path import resolve_timestamp_path
+
+
+def load_config(
+    root_dir: str | Path,
+    config_filepath: str | Path | None = None,
+    run_id: str | None = None,
+) -> GraphRagConfig:
+    """Load configuration from a file or create a default configuration.
+
+    If a config file is not found, the default configuration is created.
+
+    Parameters
+    ----------
+    root_dir : str | Path
+        The root directory of the project. Will search for the config file in this directory.
+    config_filepath : str | Path | None
+        The path to the config file.
+        If None, searches for config file in root and
+        if not found creates a default configuration.
+    run_id : str | None
+        The run id to use for resolving timestamp_paths.
+    """
+    root = Path(root_dir).resolve()
+
+    # If the user specified a config file path then it is required
+    if config_filepath:
+        config_path = (root / config_filepath).resolve()
+        if not config_path.exists():
+            msg = f"Specified config file not found: {config_path}"
+            raise FileNotFoundError(msg)
+    # Else, optionally resolve the config path from the root directory if one exists
+    else:
+        config_path = search_for_config_in_root_dir(root)
+
+    if config_path:
+        config = load_config_from_file(config_path)
+    else:
+        config = create_graphrag_config(root_dir=str(root))
+
+    if run_id:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve(), run_id)
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve(), run_id)
+        )
+    else:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve())
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve())
+        )
+
+    return config
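load_config is the single entry point this patch standardizes on. A minimal usage sketch (./ragtest is an illustrative project root that may or may not contain a settings file):

    from graphrag.config import load_config

    # Uses ./ragtest/settings.yaml|.yml|.json when present,
    # otherwise falls back to a default GraphRagConfig.
    config = load_config("./ragtest")

    # storage and reporting base_dirs come back absolute, with any
    # ${timestamp} segment already resolved to a concrete run directory.
    print(config.storage.base_dir)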
diff --git a/graphrag/config/logging.py b/graphrag/config/logging.py
index 84d7369955..99ee459a27 100644
--- a/graphrag/config/logging.py
+++ b/graphrag/config/logging.py
@@ -8,7 +8,6 @@

 from .enums import ReportingType
 from .models.graph_rag_config import GraphRagConfig
-from .resolve_timestamp_path import resolve_timestamp_path


 def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
@@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:


 def enable_logging_with_config(
-    config: GraphRagConfig, timestamp_value: str, verbose: bool = False
+    config: GraphRagConfig, verbose: bool = False
 ) -> tuple[bool, str]:
     """Enable logging to a file based on the config.

@@ -56,10 +55,7 @@ def enable_logging_with_config(
         (True, str) if logging was enabled.
     """
     if config.reporting.type == ReportingType.file:
-        log_path = resolve_timestamp_path(
-            Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
-            timestamp_value,
-        )
+        log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
         enable_logging(log_path, verbose)
         return (True, str(log_path))
     return (False, "")
diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py
index 0530290a63..df505e86a4 100644
--- a/graphrag/index/__main__.py
+++ b/graphrag/index/__main__.py
@@ -63,11 +63,6 @@
         help="Create an initial configuration in the given path.",
         action="store_true",
     )
-    parser.add_argument(
-        "--overlay-defaults",
-        help="Overlay default configuration values on a provided configuration file (--config).",
-        action="store_true",
-    )
     parser.add_argument(
         "--skip-validations",
         help="Skip any preflight validation. Useful when running no LLM steps.",
@@ -75,20 +70,16 @@
     )
     args = parser.parse_args()

-    if args.overlay_defaults and not args.config:
-        parser.error("--overlay-defaults requires --config")
-
     index_cli(
-        root=args.root,
+        root_dir=args.root,
         verbose=args.verbose or False,
         resume=args.resume,
         memprofile=args.memprofile or False,
         nocache=args.nocache or False,
         reporter=args.reporter,
-        config=args.config,
+        config_filepath=args.config,
         emit=args.emit,
         dryrun=args.dryrun or False,
         init=args.init or False,
-        overlay_defaults=args.overlay_defaults or False,
         skip_validations=args.skip_validations or False,
     )
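With --overlay-defaults removed, defaults are always overlaid and a provided settings file simply overrides them, so the two remaining shapes of the command are (paths illustrative):

    # discover settings in the root, or fall back to defaults
    python -m graphrag.index --root ./ragtest

    # explicit settings file, resolved relative to the root; a missing file now raises
    python -m graphrag.index --root ./ragtest --config custom-settings.yaml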
""" -from graphrag.config.enums import CacheType -from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.resolve_timestamp_path import resolve_timestamp_path +from pathlib import Path + +from graphrag.config import CacheType, GraphRagConfig from .cache.noop_pipeline_cache import NoopPipelineCache from .create_pipeline_config import create_pipeline_config @@ -50,11 +50,7 @@ async def build_index( list[PipelineRunResult] The list of pipeline run results """ - try: - resolve_timestamp_path(config.storage.base_dir, run_id) - resume = True - except ValueError as _: - resume = False + resume = Path(config.storage.base_dir).exists() pipeline_config = create_pipeline_config(config) pipeline_cache = ( NoopPipelineCache() if config.cache.type == CacheType.none is None else None diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 6dda401ca8..f32beb946b 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -11,13 +11,7 @@ import warnings from pathlib import Path -from graphrag.config import create_graphrag_config -from graphrag.config.config_file_loader import ( - load_config_from_file, - resolve_config_path_with_root, -) -from graphrag.config.enums import CacheType -from graphrag.config.logging import enable_logging_with_config +from graphrag.config import CacheType, enable_logging_with_config, load_config from .api import build_index from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT @@ -103,17 +97,16 @@ def handle_signal(signum, _): def index_cli( - root: str, + root_dir: str, init: bool, verbose: bool, resume: str | None, memprofile: bool, nocache: bool, reporter: str | None, - config: str | None, + config_filepath: str | None, emit: str | None, dryrun: bool, - overlay_defaults: bool, skip_validations: bool, ): """Run the pipeline with the given config.""" @@ -122,41 +115,30 @@ def index_cli( run_id = resume or time.strftime("%Y%m%d-%H%M%S") if init: - _initialize_project_at(root, progress_reporter) + _initialize_project_at(root_dir, progress_reporter) sys.exit(0) - if overlay_defaults or config: - config_path = ( - Path(root) / config if config else resolve_config_path_with_root(root) - ) - default_config = load_config_from_file(config_path) - else: - try: - config_path = resolve_config_path_with_root(root) - default_config = load_config_from_file(config_path) - except FileNotFoundError: - default_config = create_graphrag_config(root_dir=root) + root = Path(root_dir).resolve() + config = load_config(root, config_filepath, run_id) if nocache: - default_config.cache.type = CacheType.none + config.cache.type = CacheType.none - enabled_logging, log_path = enable_logging_with_config( - default_config, run_id, verbose - ) + enabled_logging, log_path = enable_logging_with_config(config, verbose) if enabled_logging: info(f"Logging enabled at {log_path}", True) else: info( - f"Logging not enabled for config {_redact(default_config.model_dump())}", + f"Logging not enabled for config {_redact(config.model_dump())}", True, ) if skip_validations: - validate_config_names(progress_reporter, default_config) + validate_config_names(progress_reporter, config) info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose) info( - f"Using default configuration: {_redact(default_config.model_dump())}", + f"Using default configuration: {_redact(config.model_dump())}", verbose, ) @@ -170,7 +152,7 @@ def index_cli( outputs = asyncio.run( build_index( - default_config, + config, run_id, memprofile, progress_reporter, diff --git 
a/graphrag/prompt_tune/cli.py b/graphrag/prompt_tune/cli.py
index eb8ff6f49f..92cb581c90 100644
--- a/graphrag/prompt_tune/cli.py
+++ b/graphrag/prompt_tune/cli.py
@@ -5,11 +5,11 @@

 from pathlib import Path

+from graphrag.config import load_config
 from graphrag.index.progress import PrintProgressReporter
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    read_config_parameters,
 )

 from . import api
@@ -53,11 +53,12 @@ async def prompt_tune(
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    graph_config = read_config_parameters(root, reporter, config)
+    root_path = Path(root).resolve()
+    graph_config = load_config(root_path, config)

     prompts = await api.generate_indexing_prompts(
         config=graph_config,
-        root=root,
+        root=str(root_path),
         chunk_size=chunk_size,
         limit=limit,
         selection_method=selection_method,
@@ -70,7 +71,7 @@ async def prompt_tune(
         k=k,
     )

-    output_path = Path(output)
+    output_path = (root_path / output).resolve()
     if output_path:
         reporter.info(f"Writing prompts to {output_path}")
         output_path.mkdir(parents=True, exist_ok=True)
diff --git a/graphrag/prompt_tune/loader/__init__.py b/graphrag/prompt_tune/loader/__init__.py
index 94e64cbe87..bc8026e92d 100644
--- a/graphrag/prompt_tune/loader/__init__.py
+++ b/graphrag/prompt_tune/loader/__init__.py
@@ -3,12 +3,10 @@

 """Fine-tuning config and data loader module."""

-from .config import read_config_parameters
 from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks

 __all__ = [
     "MIN_CHUNK_OVERLAP",
     "MIN_CHUNK_SIZE",
     "load_docs_in_chunks",
-    "read_config_parameters",
 ]
diff --git a/graphrag/prompt_tune/loader/config.py b/graphrag/prompt_tune/loader/config.py
deleted file mode 100644
index 350feacd79..0000000000
--- a/graphrag/prompt_tune/loader/config.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Config loading, parsing and handling module."""
-
-from pathlib import Path
-
-from graphrag.config import create_graphrag_config
-from graphrag.index.progress.types import ProgressReporter
-
-
-def read_config_parameters(
-    root: str, reporter: ProgressReporter, config: str | None = None
-):
-    """Read the configuration parameters from the settings file or environment variables.
-
-    Parameters
-    ----------
-    - root: The root directory where the parameters are.
-    - reporter: The progress reporter.
-    - config: The path to the settings file.
-    """
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-        return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-        return create_graphrag_config(data, root)
-
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-        return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)
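Prompt tuning now reuses the same loader, so config discovery and output paths follow the indexer's rules: a custom settings file is resolved against the root, and prompts are written relative to it. The relevant call path, sketched (tuning parameters omitted; ./ragtest is illustrative):

    from pathlib import Path

    from graphrag.config import load_config

    root_path = Path("./ragtest").resolve()
    graph_config = load_config(root_path, None)  # same discovery rules as indexing
    # graph_config then feeds api.generate_indexing_prompts(
    #     config=graph_config, root=str(root_path), ...)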
diff --git a/graphrag/query/api.py b/graphrag/query/api.py
index 57f5a12305..d4983d9a89 100644
--- a/graphrag/query/api.py
+++ b/graphrag/query/api.py
@@ -24,8 +24,7 @@
 import pandas as pd
 from pydantic import validate_call

-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import GraphRagConfig
 from graphrag.index.progress.types import PrintProgressReporter
 from graphrag.model.entity import Entity
 from graphrag.query.structured_search.base import SearchResult  # noqa: TCH001
@@ -149,7 +148,6 @@ async def global_search_streaming(

 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -196,9 +194,8 @@ async def local_search(

     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
+
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
@@ -227,7 +224,6 @@ async def local_search(

 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search_streaming(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -271,9 +267,8 @@ async def local_search_streaming(

     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
+
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
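The query API drops root_dir entirely: the LanceDB location is derived from the already-resolved config.storage.base_dir. A sketch of calling local search under the new contract (dataframe loading elided; keyword names taken from this diff):

    from graphrag.config import load_config
    from graphrag.query import api

    config = load_config("./ragtest")  # illustrative root
    # response, context_data = await api.local_search(
    #     config=config, nodes=final_nodes, entities=final_entities, ...,
    #     community_level=2, response_type="Multiple Paragraphs",
    #     query="Who is Scrooge?",
    # )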
diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py
index ed7dc4566a..f3847c0c28 100644
--- a/graphrag/query/cli.py
+++ b/graphrag/query/cli.py
@@ -4,17 +4,12 @@
 """Command line interface for the query module."""

 import asyncio
-import re
 import sys
 from pathlib import Path

 import pandas as pd

-from graphrag.config import (
-    GraphRagConfig,
-    create_graphrag_config,
-)
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import load_config, resolve_timestamp_path
 from graphrag.index.progress import PrintProgressReporter

 from . import api
@@ -25,7 +20,7 @@
 def run_global_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -35,10 +30,15 @@ def run_global_search(

     Loads index files required for global search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()

     final_nodes: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_nodes.parquet"
@@ -98,7 +98,7 @@ async def run_streaming_search():
 def run_local_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -108,10 +108,15 @@ def run_local_search(

     Loads index files required for local search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()

     final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
     final_community_reports = pd.read_parquet(
@@ -137,7 +142,6 @@ async def run_streaming_search():
         context_data = None
         get_context_data = True
         async for stream_chunk in api.local_search_streaming(
-            root_dir=root_dir,
             config=config,
             nodes=final_nodes,
             entities=final_entities,
@@ -163,7 +167,6 @@ async def run_streaming_search():
     # not streaming
     response, context_data = asyncio.run(
         api.local_search(
-            root_dir=root_dir,
             config=config,
             nodes=final_nodes,
             entities=final_entities,
@@ -180,77 +183,3 @@ async def run_streaming_search():
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
-
-
-def _configure_paths_and_settings(
-    data_dir: str | None,
-    root_dir: str | None,
-    config_filepath: str | None,
-) -> tuple[str, str | None, GraphRagConfig]:
-    config = _create_graphrag_config(root_dir, config_filepath)
-    if data_dir is None and root_dir is None:
-        msg = "Either data_dir or root_dir must be provided."
-        raise ValueError(msg)
-    if data_dir is None:
-        base_dir = Path(str(root_dir)) / config.storage.base_dir
-        data_dir = str(resolve_timestamp_path(base_dir))
-    return data_dir, root_dir, config
-
-
-def _infer_data_dir(root: str) -> str:
-    output = Path(root) / "output"
-    # use the latest data-run folder
-    if output.exists():
-        expr = re.compile(r"\d{8}-\d{6}")
-        filtered = [f for f in output.iterdir() if f.is_dir() and expr.match(f.name)]
-        folders = sorted(filtered, key=lambda f: f.name, reverse=True)
-        if len(folders) > 0:
-            folder = folders[0]
-            return str((folder / "artifacts").absolute())
-    msg = f"Could not infer data directory from root={root}"
-    raise ValueError(msg)
-
-
-def _create_graphrag_config(
-    root: str | None,
-    config_filepath: str | None,
-) -> GraphRagConfig:
-    """Create a GraphRag configuration."""
-    return _read_config_parameters(root or "./", config_filepath)
-
-
-def _read_config_parameters(root: str, config: str | None):
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open(
-            "rb",
-        ) as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)
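The CLI keeps a data-dir override as an escape hatch: when supplied, it replaces config.storage.base_dir (with any ${timestamp} segment resolved) before the parquet files are read. Sketched invocations (flag spellings assumed from this release's CLI; paths illustrative):

    # query the latest index discovered through the config
    python -m graphrag.query --root ./ragtest --method local "Who is Scrooge?"

    # pin a specific run's artifacts instead
    python -m graphrag.query --root ./ragtest \
        --data ./ragtest/output/20240812-121000/artifacts --method local "Who is Scrooge?"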
diff --git a/pyproject.toml b/pyproject.toml
index 9bb9451789..bcd7b997bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -123,6 +123,7 @@ _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/
 _semversioner_release = "semversioner release"
 _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
 _semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
+semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
 check_format = 'ruff format . --check --preview'
 fix = "ruff  --preview check --fix ."
diff --git a/tests/unit/config/fixtures/timestamp_dirs/20240812-120000/empty.txt b/tests/unit/config/fixtures/timestamp_dirs/20240812-120000/empty.txt
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/config/test_resolve_timestamp_path.py b/tests/unit/config/test_resolve_timestamp_path.py
new file mode 100644
index 0000000000..0af0f11a5e
--- /dev/null
+++ b/tests/unit/config/test_resolve_timestamp_path.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from pathlib import Path
+
+from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+
+
+def test_resolve_timestamp_path_no_timestamp_with_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path, "20240812-121000")
+    assert result == path
+
+
+def test_resolve_timestamp_path_no_timestamp_without_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path)
+    assert result == path
+
+
+def test_resolve_timestamp_path_with_timestamp_and_run_id():
+    path = Path("some/path/${timestamp}/data")
+    expected = Path("some/path/20240812/data")
+    result = resolve_timestamp_path(path, "20240812")
+    assert result == expected
+
+
+def test_resolve_timestamp_path_with_timestamp_and_inferred_directory():
+    cwd = Path(__file__).parent
+    path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
+    expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
+    result = resolve_timestamp_path(path)
+    assert result == expected
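The tests above pin down the contract load_config leans on: a literal ${timestamp} segment is replaced by the given run id, or, with no run id, by the lexically latest matching directory that exists on disk; paths without the token pass through unchanged. For example:

    from pathlib import Path

    from graphrag.config import resolve_timestamp_path

    resolve_timestamp_path(Path("output/${timestamp}/artifacts"), "20240812-121000")
    # -> Path("output/20240812-121000/artifacts")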
diff --git a/tests/unit/query/test_infer_data_dir.py b/tests/unit/query/test_infer_data_dir.py
deleted file mode 100644
index c950eb35ae..0000000000
--- a/tests/unit/query/test_infer_data_dir.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-from pathlib import Path
-
-import pytest
-
-from graphrag.query.cli import _infer_data_dir
-
-
-def test_infer_data_dir():
-    root = "./tests/unit/query/data/defaults"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_hidden_files():
-    """A hidden file, starting with '.', will naturally be selected as latest data directory."""
-    root = "./tests/unit/query/data/hidden"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_non_numeric():
-    root = "./tests/unit/query/data/non-numeric"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_throws_on_no_match():
-    root = "./tests/unit/query/data/empty"
-    with pytest.raises(ValueError):  # noqa PT011 (this is what is actually thrown...)
-        _infer_data_dir(root)