Skip to content

Commit

Permalink
Consistent config load_config (#1065)
Browse files Browse the repository at this point in the history
* Consistent config load_config

- Provide a consistent way to load configuration
- Resolve potential timestamp directories upfront
    upon config object creation
- Add unit tests for resolving timestamp directories
- Resolves #599
- Resolves #1049

* fix formatting issues

* remove unnecessary path resolution

* fix smoke tests

* update prompts to use load_config

* Update none checks

* Update none checks

* Update searching for config method signature

* Update unit tests

* fix formatting issues
  • Loading branch information
dworthen committed Sep 3, 2024
1 parent 3f98002 commit ab29cc2
Show file tree
Hide file tree
Showing 17 changed files with 169 additions and 267 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20240830151802543194.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Consistent config loading. Resolves #599 and Resolves #1049"
}
9 changes: 9 additions & 0 deletions graphrag/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""The Indexing Engine default config package root."""

from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
from .create_graphrag_config import (
create_graphrag_config,
)
Expand Down Expand Up @@ -42,6 +43,8 @@
TextEmbeddingConfigInput,
UmapConfigInput,
)
from .load_config import load_config
from .logging import enable_logging_with_config
from .models import (
CacheConfig,
ChunkingConfig,
Expand All @@ -65,6 +68,7 @@
UmapConfig,
)
from .read_dotenv import read_dotenv
from .resolve_timestamp_path import resolve_timestamp_path

__all__ = [
"ApiKeyMissingError",
Expand Down Expand Up @@ -119,5 +123,10 @@
"UmapConfig",
"UmapConfigInput",
"create_graphrag_config",
"enable_logging_with_config",
"load_config",
"load_config_from_file",
"read_dotenv",
"resolve_timestamp_path",
"search_for_config_in_root_dir",
]
17 changes: 6 additions & 11 deletions graphrag/config/config_file_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

import yaml

from . import create_graphrag_config
from .create_graphrag_config import create_graphrag_config
from .models.graph_rag_config import GraphRagConfig

_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]


def resolve_config_path_with_root(root: str | Path) -> Path:
def search_for_config_in_root_dir(root: str | Path) -> Path | None:
"""Resolve the config path from the given root directory.
Parameters
Expand All @@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
Returns
-------
Path
The resolved config file path.
Raises
------
FileNotFoundError
If the config file is not found or cannot be resolved for the directory.
Path | None
returns a Path if there is a config in the root directory
Otherwise returns None.
"""
root = Path(root)

Expand All @@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
if (root / file).is_file():
return root / file

msg = f"Unable to resolve config file for parent directory: {root}"
raise FileNotFoundError(msg)
return None


class ConfigFileLoader(ABC):
Expand Down
65 changes: 65 additions & 0 deletions graphrag/config/load_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Default method for loading config."""

from pathlib import Path

from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
from .create_graphrag_config import create_graphrag_config
from .models.graph_rag_config import GraphRagConfig
from .resolve_timestamp_path import resolve_timestamp_path


def load_config(
    root_dir: str | Path,
    config_filepath: str | Path | None = None,
    run_id: str | None = None,
) -> GraphRagConfig:
    """Load configuration from a file or create a default configuration.

    The configuration source is chosen in this order:

    1. If ``config_filepath`` is given, that file is required and is loaded.
    2. Otherwise the root directory is searched for a config file.
    3. If no config file is found, a default configuration is created.

    After loading, ``storage.base_dir`` and ``reporting.base_dir`` are joined
    onto the resolved root, passed through ``resolve_timestamp_path`` and
    stored back on the config as strings.

    Parameters
    ----------
    root_dir : str | Path
        The root directory of the project. Will search for the config file
        in this directory.
    config_filepath : str | Path | None
        The path to the config file, interpreted relative to ``root_dir``
        (an absolute path is used as-is).
        If None, searches for config file in root and
        if not found creates a default configuration.
    run_id : str | None
        The run id to use for resolving timestamp paths. When falsy,
        ``resolve_timestamp_path`` is called without it.

    Returns
    -------
    GraphRagConfig
        The loaded (or default) configuration with resolved base directories.

    Raises
    ------
    FileNotFoundError
        If ``config_filepath`` is specified but does not exist.
    """
    root = Path(root_dir).resolve()

    if config_filepath:
        # A user-specified config file is required: fail loudly if missing
        # rather than silently falling back to another configuration.
        config_path = (root / config_filepath).resolve()
        if not config_path.exists():
            msg = f"Specified Config file not found: {config_path}"
            raise FileNotFoundError(msg)
    else:
        # No explicit path: optionally pick up a config file from the root
        # directory (None when the root contains no config file).
        config_path = search_for_config_in_root_dir(root)

    if config_path:
        config = load_config_from_file(config_path)
    else:
        config = create_graphrag_config(root_dir=str(root))

    storage_dir = (root / config.storage.base_dir).resolve()
    reporting_dir = (root / config.reporting.base_dir).resolve()

    # The second argument is only passed when a run id is available, so the
    # helper's own default applies otherwise.
    if run_id:
        config.storage.base_dir = str(resolve_timestamp_path(storage_dir, run_id))
        config.reporting.base_dir = str(
            resolve_timestamp_path(reporting_dir, run_id)
        )
    else:
        config.storage.base_dir = str(resolve_timestamp_path(storage_dir))
        config.reporting.base_dir = str(resolve_timestamp_path(reporting_dir))

    return config
8 changes: 2 additions & 6 deletions graphrag/config/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from .enums import ReportingType
from .models.graph_rag_config import GraphRagConfig
from .resolve_timestamp_path import resolve_timestamp_path


def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
Expand All @@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:


def enable_logging_with_config(
config: GraphRagConfig, timestamp_value: str, verbose: bool = False
config: GraphRagConfig, verbose: bool = False
) -> tuple[bool, str]:
"""Enable logging to a file based on the config.
Expand All @@ -56,10 +55,7 @@ def enable_logging_with_config(
(True, str) if logging was enabled.
"""
if config.reporting.type == ReportingType.file:
log_path = resolve_timestamp_path(
Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
timestamp_value,
)
log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
enable_logging(log_path, verbose)
return (True, str(log_path))
return (False, "")
13 changes: 2 additions & 11 deletions graphrag/index/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,23 @@
help="Create an initial configuration in the given path.",
action="store_true",
)
parser.add_argument(
"--overlay-defaults",
help="Overlay default configuration values on a provided configuration file (--config).",
action="store_true",
)
parser.add_argument(
"--skip-validations",
help="Skip any preflight validation. Useful when running no LLM steps.",
action="store_true",
)
args = parser.parse_args()

if args.overlay_defaults and not args.config:
parser.error("--overlay-defaults requires --config")

index_cli(
root=args.root,
root_dir=args.root,
verbose=args.verbose or False,
resume=args.resume,
memprofile=args.memprofile or False,
nocache=args.nocache or False,
reporter=args.reporter,
config=args.config,
config_filepath=args.config,
emit=args.emit,
dryrun=args.dryrun or False,
init=args.init or False,
overlay_defaults=args.overlay_defaults or False,
skip_validations=args.skip_validations or False,
)
12 changes: 4 additions & 8 deletions graphrag/index/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
Backwards compatibility is not guaranteed at this time.
"""

from graphrag.config.enums import CacheType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from pathlib import Path

from graphrag.config import CacheType, GraphRagConfig

from .cache.noop_pipeline_cache import NoopPipelineCache
from .create_pipeline_config import create_pipeline_config
Expand Down Expand Up @@ -50,11 +50,7 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
try:
resolve_timestamp_path(config.storage.base_dir, run_id)
resume = True
except ValueError as _:
resume = False
resume = Path(config.storage.base_dir).exists()
pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
Expand Down
42 changes: 12 additions & 30 deletions graphrag/index/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@
import warnings
from pathlib import Path

from graphrag.config import create_graphrag_config
from graphrag.config.config_file_loader import (
load_config_from_file,
resolve_config_path_with_root,
)
from graphrag.config.enums import CacheType
from graphrag.config.logging import enable_logging_with_config
from graphrag.config import CacheType, enable_logging_with_config, load_config

from .api import build_index
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
Expand Down Expand Up @@ -103,17 +97,16 @@ def handle_signal(signum, _):


def index_cli(
root: str,
root_dir: str,
init: bool,
verbose: bool,
resume: str | None,
memprofile: bool,
nocache: bool,
reporter: str | None,
config: str | None,
config_filepath: str | None,
emit: str | None,
dryrun: bool,
overlay_defaults: bool,
skip_validations: bool,
):
"""Run the pipeline with the given config."""
Expand All @@ -122,41 +115,30 @@ def index_cli(
run_id = resume or time.strftime("%Y%m%d-%H%M%S")

if init:
_initialize_project_at(root, progress_reporter)
_initialize_project_at(root_dir, progress_reporter)
sys.exit(0)

if overlay_defaults or config:
config_path = (
Path(root) / config if config else resolve_config_path_with_root(root)
)
default_config = load_config_from_file(config_path)
else:
try:
config_path = resolve_config_path_with_root(root)
default_config = load_config_from_file(config_path)
except FileNotFoundError:
default_config = create_graphrag_config(root_dir=root)
root = Path(root_dir).resolve()
config = load_config(root, config_filepath, run_id)

if nocache:
default_config.cache.type = CacheType.none
config.cache.type = CacheType.none

enabled_logging, log_path = enable_logging_with_config(
default_config, run_id, verbose
)
enabled_logging, log_path = enable_logging_with_config(config, verbose)
if enabled_logging:
info(f"Logging enabled at {log_path}", True)
else:
info(
f"Logging not enabled for config {_redact(default_config.model_dump())}",
f"Logging not enabled for config {_redact(config.model_dump())}",
True,
)

if skip_validations:
validate_config_names(progress_reporter, default_config)
validate_config_names(progress_reporter, config)

info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
info(
f"Using default configuration: {_redact(default_config.model_dump())}",
f"Using default configuration: {_redact(config.model_dump())}",
verbose,
)

Expand All @@ -170,7 +152,7 @@ def index_cli(

outputs = asyncio.run(
build_index(
default_config,
config,
run_id,
memprofile,
progress_reporter,
Expand Down
9 changes: 5 additions & 4 deletions graphrag/prompt_tune/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

from pathlib import Path

from graphrag.config import load_config
from graphrag.index.progress import PrintProgressReporter
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
read_config_parameters,
)

from . import api
Expand Down Expand Up @@ -53,11 +53,12 @@ async def prompt_tune(
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
reporter = PrintProgressReporter("")
graph_config = read_config_parameters(root, reporter, config)
root_path = Path(root).resolve()
graph_config = load_config(root_path, config)

prompts = await api.generate_indexing_prompts(
config=graph_config,
root=root,
root=str(root_path),
chunk_size=chunk_size,
limit=limit,
selection_method=selection_method,
Expand All @@ -70,7 +71,7 @@ async def prompt_tune(
k=k,
)

output_path = Path(output)
output_path = (root_path / output).resolve()
if output_path:
reporter.info(f"Writing prompts to {output_path}")
output_path.mkdir(parents=True, exist_ok=True)
Expand Down
2 changes: 0 additions & 2 deletions graphrag/prompt_tune/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@

"""Fine-tuning config and data loader module."""

from .config import read_config_parameters
from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks

__all__ = [
"MIN_CHUNK_OVERLAP",
"MIN_CHUNK_SIZE",
"load_docs_in_chunks",
"read_config_parameters",
]
Loading

0 comments on commit ab29cc2

Please sign in to comment.