-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_index.py
152 lines (126 loc) · 5.4 KB
/
create_index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# NOTE: This script is still a work in progress and may not fully work yet.
import asyncio
import logging
import os
import yaml
from typing import Dict, Any, Optional, Union
from llama_index.llms.bedrock import Bedrock
from llama_index.embeddings.bedrock import BedrockEmbedding
from llama_index.core import Settings, VectorStoreIndex
from rag.index_builder import IndexBuilder
# Root logging setup: every record at INFO or above goes both to the console
# stream and to the file "index_builder.log" (a relative path, so it is
# resolved against the current working directory).
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_log_handlers = [logging.StreamHandler(), logging.FileHandler("index_builder.log")]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)

# Module-level logger used by everything in this script.
logger = logging.getLogger(__name__)
# Fallback values consulted when config.yaml leaves a setting out.
# Built section by section; insertion order is sources, llm_config,
# index_settings.
DEFAULT_CONFIG: Dict[str, Any] = {}

# Content sources to index: plain URLs and GitHub organizations.
DEFAULT_CONFIG["sources"] = {"urls": [], "github_orgs": []}

# Embedding model, AWS region/credentials and the GitHub token.
# Credentials and token default to None (i.e. "not configured").
DEFAULT_CONFIG["llm_config"] = {
    "embed_model": "cohere.embed-english-v3",
    "aws_region": "us-east-1",
    "aws_access_key_id": None,
    "aws_secret_access_key": None,
    "github_token": None,
}

# Flags forwarded to the index builder; both are off by default.
DEFAULT_CONFIG["index_settings"] = {
    "force_reload": False,
    "incremental": False,
}
def load_config() -> Dict[str, Any]:
    """Load ``config.yaml`` next to this script and resolve env placeholders.

    Any string value, at any depth of nested mappings and lists, that starts
    with ``OS_ENV_`` is replaced by the value of the environment variable
    named by the rest of the string (``OS_ENV_GITHUB_TOKEN`` is replaced by
    the value of ``GITHUB_TOKEN``). Mapping keys are never substituted.

    Returns:
        Dict[str, Any]: The configuration with placeholders substituted. An
            empty YAML document yields an empty dictionary.

    Raises:
        EnvironmentError: If a placeholder names an environment variable that
            is not set.
        ValueError: If the top level of the YAML document is not a mapping.
        OSError: If ``config.yaml`` cannot be opened.
    """
    env_prefix = "OS_ENV_"
    config_path = os.path.join(os.path.dirname(__file__), "config.yaml")

    # Explicit encoding: the platform default is not UTF-8 everywhere, which
    # would make non-ASCII config values load differently per machine.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    # safe_load returns None for an empty document; honour the declared
    # return type instead of leaking None to callers that index the result.
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Top level of {config_path} must be a mapping, "
            f"got {type(config).__name__}"
        )

    def process_env_vars(item):
        """Recursively replace ``OS_ENV_*`` strings with environment values."""
        if isinstance(item, str) and item.startswith(env_prefix):
            env_var = item[len(env_prefix):]
            value = os.getenv(env_var)
            if value is None:
                raise EnvironmentError(
                    f"Required environment variable {env_var} not set"
                )
            return value
        if isinstance(item, dict):
            return {k: process_env_vars(v) for k, v in item.items()}
        if isinstance(item, list):
            return [process_env_vars(v) for v in item]
        # Numbers, booleans, None and ordinary strings pass through unchanged.
        return item

    return process_env_vars(config)
async def build_index() -> Optional[VectorStoreIndex]:
    """Build or load the vector index using configuration settings.

    Steps performed:
        1. Load the configuration via ``load_config()``.
        2. Overlay each configuration section (``sources``, ``llm_config``,
           ``index_settings``) on top of ``DEFAULT_CONFIG`` so that a missing
           section, a missing key, or a section written as YAML null all
           fall back to the defaults.
        3. Install a ``BedrockEmbedding`` as ``Settings.embed_model``.
        4. Await ``IndexBuilder.build_or_load_index`` with the configured
           URLs, GitHub organizations, token and reload/incremental flags.
        5. Log the statistics reported by ``IndexBuilder.get_index_stats``.

    Returns:
        Optional[VectorStoreIndex]: The value returned by
            ``build_or_load_index``. Failures are not reported as ``None``;
            they are logged and re-raised (see below).

    Raises:
        Exception: Any exception raised while loading the configuration or
            building the index is logged with its traceback and re-raised
            unchanged, e.g. ``EnvironmentError`` for a missing environment
            variable referenced by the configuration.
    """

    def section(config: Dict[str, Any], name: str) -> Dict[str, Any]:
        """Return DEFAULT_CONFIG[name] overlaid with the user's section."""
        # ``or {}`` covers both an absent section and an explicit YAML null.
        return {**DEFAULT_CONFIG[name], **(config.get(name) or {})}

    try:
        logger.info("Loading configuration...")
        # Tolerate a loader that hands back None for an empty config file.
        config = load_config() or {}

        sources = section(config, "sources")
        llm_config = section(config, "llm_config")
        index_settings = section(config, "index_settings")

        # Configure the embedding model. Credentials left as None are passed
        # straight through to BedrockEmbedding.
        # NOTE(review): presumably None falls back to the default AWS
        # credential chain; confirm against the llama-index Bedrock docs.
        Settings.embed_model = BedrockEmbedding(
            model=llm_config["embed_model"],
            region_name=llm_config["aws_region"],
            aws_access_key_id=llm_config["aws_access_key_id"],
            aws_secret_access_key=llm_config["aws_secret_access_key"],
        )

        # Initialize the index builder
        index_builder = IndexBuilder()

        # Build or load the index
        logger.info("Building/loading index...")
        index = await index_builder.build_or_load_index(
            urls=sources["urls"],
            orgs=sources["github_orgs"],
            github_token=llm_config["github_token"],
            force_reload=index_settings["force_reload"],
            incremental=index_settings["incremental"],
        )

        # Get and log index statistics
        stats = index_builder.get_index_stats()
        logger.info("Index statistics:")
        logger.info("- Last update: %s", stats["last_update"])
        logger.info("- Total documents: %s", stats["total_documents"])
        logger.info("- Processed URLs: %s", stats["processed_urls"])
        logger.info("- Processed organizations: %s", stats["processed_orgs"])
        logger.info("Index build completed successfully")
        return index
    except Exception as e:
        # Log with traceback, then let the caller decide how to react.
        logger.exception("Error building index: %s", e)
        raise
if __name__ == "__main__":
    # Standalone entry point for the index builder script: drive the
    # build_index() coroutine to completion on a fresh event loop. Any
    # failure is logged here and then re-raised so the interpreter still
    # reports the traceback.
    try:
        asyncio.run(build_index())
    except Exception as exc:
        logger.error(f"Failed to build index: {exc!s}")
        raise