ref(ai-autofix): Better design for document chunk models (#276)
jennmueng authored Mar 3, 2024
1 parent 0311570 commit df843e2
Showing 7 changed files with 55 additions and 61 deletions.
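
At a glance, the refactor renames the document chunk models and moves repo_id out of the in-memory models, so a repository id is only attached when a chunk is written to the database. A rough mapping, inferred from the diffs below rather than stated in the commit message:

    DocumentChunk                   -> BaseDocumentChunk      (drops the repo_id field)
    DocumentChunkWithEmbedding      -> EmbeddedDocumentChunk  (to_db_model now takes repo_id as an argument)
    DocumentChunkWithEmbeddingAndId -> StoredDocumentChunk    (keeps id and repo_id, which come from the database row)
    Document                        -> same name, but no longer carries repo_id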
4 changes: 2 additions & 2 deletions src/seer/automation/autofix/autofix.py
@@ -32,7 +32,7 @@
)
from seer.automation.autofix.tools import BaseTools, CodeActionTools
from seer.automation.autofix.utils import escape_multi_xml, extract_xml_element_text
from seer.automation.codebase.models import DocumentChunkWithEmbedding
from seer.automation.codebase.models import StoredDocumentChunk
from seer.automation.codebase.tasks import update_codebase_index

logger = logging.getLogger("autofix")
@@ -457,7 +457,7 @@ def parser(x: str | None) -> list[str] | None:
return ""

context_dump = ""
unique_chunks: dict[str, DocumentChunkWithEmbedding] = {}
unique_chunks: dict[str, StoredDocumentChunk] = {}
for query in queries:
retrived_chunks = self.context.query(query, top_k=4)
for chunk in retrived_chunks:
4 changes: 2 additions & 2 deletions src/seer/automation/autofix/autofix_context.py
@@ -2,7 +2,7 @@

from seer.automation.autofix.models import RepoDefinition, Stacktrace
from seer.automation.codebase.codebase_index import CodebaseIndex
from seer.automation.codebase.models import DocumentChunkWithEmbeddingAndId
from seer.automation.codebase.models import StoredDocumentChunk
from seer.automation.utils import get_embedding_model
from seer.db import DbDocumentChunk, Session

@@ -80,7 +80,7 @@ def query(
for db_chunk in db_chunks:
chunks_by_repo_id.setdefault(db_chunk.repo_id, []).append(db_chunk)

populated_chunks: list[DocumentChunkWithEmbeddingAndId] = []
populated_chunks: list[StoredDocumentChunk] = []
for _repo_id, db_chunks in chunks_by_repo_id.items():
codebase = self.get_codebase(_repo_id)
populated_chunks.extend(codebase._populate_chunks(db_chunks))
65 changes: 32 additions & 33 deletions src/seer/automation/codebase/codebase_index.py
@@ -2,16 +2,16 @@
import uuid

import numpy as np
from langsmith import RunTree, traceable
from langsmith import traceable
from tqdm import tqdm

from seer.automation.autofix.models import FileChange, RepoDefinition, Stacktrace
from seer.automation.codebase.models import (
BaseDocumentChunk,
Document,
DocumentChunk,
DocumentChunkWithEmbedding,
DocumentChunkWithEmbeddingAndId,
EmbeddedDocumentChunk,
RepositoryInfo,
StoredDocumentChunk,
)
from seer.automation.codebase.parser import DocumentParser
from seer.automation.codebase.repo_client import RepoClient
@@ -114,6 +114,18 @@ def create(cls, organization: int, project: int, repo: RepoDefinition, run_id: u
tmp_dir, tmp_repo_dir = repo_client.load_repo_to_tmp_dir(head_sha)
logger.debug(f"Loaded repository to {tmp_repo_dir}")
try:
documents = read_directory(tmp_repo_dir)

logger.debug(f"Read {len(documents)} documents:")
documents_by_language = group_documents_by_language(documents)
for language, docs in documents_by_language.items():
logger.debug(f" {language}: {len(docs)}")

doc_parser = DocumentParser(get_embedding_model())
chunks = doc_parser.process_documents(documents)
embedded_chunks = cls._embed_chunks(chunks)
logger.debug(f"Processed {len(chunks)} chunks")

with Session() as session:
db_repo_info = DbRepositoryInfo(
organization=organization,
@@ -124,26 +136,16 @@
)
session.add(db_repo_info)
session.flush()
logger.debug(f"Inserted repository info with id {db_repo_info.id}")

documents = read_directory(tmp_repo_dir, repo_id=db_repo_info.id)

logger.debug(f"Read {len(documents)} documents:")
documents_by_language = group_documents_by_language(documents)
for language, docs in documents_by_language.items():
logger.debug(f" {language}: {len(docs)}")
db_chunks = [
chunk.to_db_model(repo_id=db_repo_info.id) for chunk in embedded_chunks
]

doc_parser = DocumentParser(get_embedding_model())
chunks = doc_parser.process_documents(documents)
embedded_chunks = cls._embed_chunks(chunks)
logger.debug(f"Processed {len(chunks)} chunks")

repo_info = RepositoryInfo.from_db(db_repo_info)

db_chunks = [chunk.to_db_model() for chunk in embedded_chunks]
session.add_all(db_chunks)
session.commit()

repo_info = RepositoryInfo.from_db(db_repo_info)

logger.debug(f"Create Step: Inserted {len(chunks)} chunks into the database")

return cls(organization, project, repo_client, repo_info, run_id)
@@ -174,15 +176,17 @@ def update(self):
logger.debug(f"Loaded repository to {tmp_repo_dir}")

try:
documents = read_specific_files(tmp_repo_dir, changed_files, repo_id=self.repo_info.id)
documents = read_specific_files(tmp_repo_dir, changed_files)

doc_parser = DocumentParser(get_embedding_model())
chunks = doc_parser.process_documents(documents)
embedded_chunks = self._embed_chunks(chunks)
logger.debug(f"Processed {len(chunks)} chunks")

with Session() as session:
db_chunks = [chunk.to_db_model() for chunk in embedded_chunks]
db_chunks = [
chunk.to_db_model(repo_id=self.repo_info.id) for chunk in embedded_chunks
]
session.add_all(db_chunks)

if removed_files:
@@ -204,7 +208,7 @@ def update(self):
cleanup_dir(tmp_dir)

@classmethod
def _embed_chunks(cls, chunks: list[DocumentChunk]) -> list[DocumentChunkWithEmbedding]:
def _embed_chunks(cls, chunks: list[BaseDocumentChunk]) -> list[EmbeddedDocumentChunk]:
logger.debug(f"Embedding {len(chunks)} chunks...")
embeddings_list: list[np.ndarray] = []

@@ -223,7 +227,7 @@ def _embed_chunks(cls, chunks: list[DocumentChunk]) -> list[DocumentChunkWithEmb
embedded_chunks = []
for i, chunk in enumerate(chunks):
embedded_chunks.append(
DocumentChunkWithEmbedding(
EmbeddedDocumentChunk(
**chunk.model_dump(),
embedding=embeddings[i],
)
@@ -277,9 +281,7 @@ def get_document(self, path: str, ignore_local_changes=False) -> Document | None
logger.warning(f"Unsupported language for {path}")
return None

document = Document(
path=path, text=document_content, repo_id=self.repo_info.id, language=language
)
document = Document(path=path, text=document_content, language=language)

content = document_content
if not ignore_local_changes:
@@ -330,7 +332,7 @@ def update_document_temporarily(self, document: Document):

db_chunks: list[DbDocumentChunk] = []
for chunk in embedded_chunks:
db_chunk = chunk.to_db_model()
db_chunk = chunk.to_db_model(repo_id=self.repo_info.id)
db_chunk.namespace = str(self.run_id)
session.add_all(db_chunks)
session.commit()
@@ -364,15 +366,13 @@ def process_stacktrace(self, stacktrace: Stacktrace):
frame.filename = valid_path
break

def _populate_chunks(
self, chunks: list[DbDocumentChunk]
) -> list[DocumentChunkWithEmbeddingAndId]:
def _populate_chunks(self, chunks: list[DbDocumentChunk]) -> list[StoredDocumentChunk]:
### This seems awfully wasteful to chunk and hash a document for each returned chunk but I guess we are offloading the work to when it's needed?
assert self.repo_info is not None, "Repository info is not set"

doc_parser = DocumentParser(get_embedding_model())

matched_chunks: list[DocumentChunkWithEmbeddingAndId] = []
matched_chunks: list[StoredDocumentChunk] = []
for chunk in chunks:
content = self._get_file_content_with_cache(chunk.path, self.repo_info.sha)

@@ -385,7 +385,6 @@ def _populate_chunks(
Document(
path=chunk.path,
text=content,
repo_id=self.repo_info.id,
language=chunk.language,
)
)
@@ -396,7 +395,7 @@ def _populate_chunks(
continue

matched_chunks.append(
DocumentChunkWithEmbeddingAndId(
StoredDocumentChunk(
id=chunk.id,
path=chunk.path,
index=chunk.index,
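
The net effect on the indexing flow: documents are now read, parsed, and embedded before the database session is opened, and the repository id is attached only when chunks are converted to database rows. A condensed sketch of the reworked create() path, simplified from the hunks above (error handling and temp-dir cleanup omitted):

    documents = read_directory(tmp_repo_dir)               # no repo_id argument any more
    doc_parser = DocumentParser(get_embedding_model())
    chunks = doc_parser.process_documents(documents)        # list[BaseDocumentChunk]
    embedded_chunks = cls._embed_chunks(chunks)             # list[EmbeddedDocumentChunk]

    with Session() as session:
        db_repo_info = DbRepositoryInfo(
            organization=organization,
            project=project,
            # remaining columns as in the (collapsed) hunk above
        )
        session.add(db_repo_info)
        session.flush()                                     # assigns db_repo_info.id

        # repo_id is supplied at persistence time instead of living on each chunk
        db_chunks = [chunk.to_db_model(repo_id=db_repo_info.id) for chunk in embedded_chunks]
        session.add_all(db_chunks)
        session.commit()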
13 changes: 6 additions & 7 deletions src/seer/automation/codebase/models.py
@@ -10,11 +10,10 @@
class Document(BaseModel):
path: str
text: str
repo_id: int
language: str


class DocumentChunk(BaseModel):
class BaseDocumentChunk(BaseModel):
id: Optional[int] = None
content: str
context: Optional[str]
@@ -23,7 +22,6 @@ class DocumentChunk(BaseModel):
path: str
index: int
token_count: int
repo_id: int

def get_dump_for_embedding(self):
return """{context}{content}""".format(
@@ -58,13 +56,13 @@ def __repr__(self):
return self.__str__()


class DocumentChunkWithEmbedding(DocumentChunk):
class EmbeddedDocumentChunk(BaseDocumentChunk):
embedding: np.ndarray

def to_db_model(self) -> DbDocumentChunk:
def to_db_model(self, repo_id: int) -> DbDocumentChunk:
return DbDocumentChunk(
id=self.id,
repo_id=self.repo_id,
repo_id=repo_id,
path=self.path,
index=self.index,
hash=self.hash,
@@ -80,8 +78,9 @@ class Config:
}


class DocumentChunkWithEmbeddingAndId(DocumentChunkWithEmbedding):
class StoredDocumentChunk(EmbeddedDocumentChunk):
id: int
repo_id: int


class RepositoryInfo(BaseModel):
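
Consolidated, the chunk models in models.py now form a three-level hierarchy. The sketch below pulls together the hunks above and lists only what is visible in this diff; unchanged fields that fall outside the shown hunks (e.g. hash and language), plus Config and helper methods, are elided:

    class Document(BaseModel):                 # repo_id field removed
        path: str
        text: str
        language: str

    class BaseDocumentChunk(BaseModel):        # was DocumentChunk; no repo_id field
        id: Optional[int] = None
        content: str
        context: Optional[str]
        path: str
        index: int
        token_count: int
        # ...hash, language, and helper methods unchanged

    class EmbeddedDocumentChunk(BaseDocumentChunk):        # was DocumentChunkWithEmbedding
        embedding: np.ndarray

        def to_db_model(self, repo_id: int) -> DbDocumentChunk:
            # repo_id arrives as an argument instead of being read off the model
            return DbDocumentChunk(
                id=self.id,
                repo_id=repo_id,
                path=self.path,
                index=self.index,
                hash=self.hash,
                # remaining columns as in the (truncated) hunk above
            )

    class StoredDocumentChunk(EmbeddedDocumentChunk):      # was DocumentChunkWithEmbeddingAndId
        id: int        # required once the chunk exists in the database
        repo_id: int   # the only model that still records which repository it belongs to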
13 changes: 6 additions & 7 deletions src/seer/automation/codebase/parser.py
@@ -8,7 +8,7 @@
from sentence_transformers import SentenceTransformer
from tree_sitter import Node

from seer.automation.codebase.models import Document, DocumentChunk
from seer.automation.codebase.models import BaseDocumentChunk, Document
from seer.utils import class_method_lru_cache

logger = logging.getLogger("autofix")
@@ -313,27 +313,26 @@ def _extract_declaration(
def _get_parser(self, language: str):
return tree_sitter_languages.get_parser(language)

def _chunk_document(self, document: Document) -> list[DocumentChunk]:
def _chunk_document(self, document: Document) -> list[BaseDocumentChunk]:
tree = self._get_parser(document.language).parse(bytes(document.text, "utf-8"))

chunked_documents = self._chunk_nodes_by_whitespace(tree.root_node, document.language)

chunks: list[DocumentChunk] = []
chunks: list[BaseDocumentChunk] = []

for i, tmp_chunk in enumerate(chunked_documents):
context_text = tmp_chunk.get_context(tree.root_node)
chunk_text = tmp_chunk.get_content(tree.root_node)
embedding_dump = tmp_chunk.get_dump_for_embedding(tree.root_node)

chunk = DocumentChunk(
chunk = BaseDocumentChunk(
index=i,
context=context_text,
content=chunk_text.strip("\n"),
path=document.path,
# Hash should be unique to the file, it is used in comparing which chunks changed
hash=self._generate_sha256_hash(f"[{document.path}][{i}]\n{embedding_dump}"),
token_count=self._get_str_token_count(embedding_dump),
repo_id=document.repo_id,
language=document.language,
)

@@ -344,7 +343,7 @@ def _chunk_document(self, document: Document) -> list[DocumentChunk]:
def _generate_sha256_hash(self, text: str):
return hashlib.sha256(text.encode("utf-8"), usedforsecurity=False).hexdigest()

def process_document(self, document: Document) -> list[DocumentChunk]:
def process_document(self, document: Document) -> list[BaseDocumentChunk]:
"""
Process a document by chunking it into smaller pieces and extracting metadata about each chunk.
"""
@@ -354,7 +353,7 @@ def process_document(self, document: Document) -> list[DocumentChunk]:

return chunks

def process_documents(self, documents: list[Document]) -> list[DocumentChunk]:
def process_documents(self, documents: list[Document]) -> list[BaseDocumentChunk]:
"""
Process a list of documents by chunking them into smaller pieces and extracting metadata about each chunk.
"""
11 changes: 4 additions & 7 deletions src/seer/automation/codebase/utils.py
@@ -76,7 +76,6 @@ def get_language_from_path(path: str) -> str | None:

def read_directory(
path: str,
repo_id: int,
parent_tmp_dir: str | None = None,
max_file_size=2 * 1024 * 1024, # 2 MB
) -> list[Document]:
@@ -92,7 +91,7 @@ def read_directory(
dir_children = []
for entry in os.scandir(path):
if entry.is_dir(follow_symlinks=False):
dir_children.extend(read_directory(entry.path, repo_id, path_to_remove))
dir_children.extend(read_directory(entry.path, path_to_remove))
elif entry.is_file() and entry.stat().st_size < max_file_size:
language = get_language_from_path(entry.path)

@@ -108,13 +107,11 @@ def read_directory(
if truncated_path.startswith("/"):
truncated_path = truncated_path[1:]

dir_children.append(
Document(path=truncated_path, text=text, repo_id=repo_id, language=language)
)
dir_children.append(Document(path=truncated_path, text=text, language=language))
return dir_children


def read_specific_files(repo_path: str, files: list[str], repo_id: int) -> list[Document]:
def read_specific_files(repo_path: str, files: list[str]) -> list[Document]:
"""
Reads the contents of specific files and returns a list of Document objects.
@@ -134,7 +131,7 @@ def read_specific_files(repo_path: str, files: list[str], repo_id: int) -> list[
with open(file_path, "r", encoding="utf-8") as f:
text = f.read()

documents.append(Document(path=file, text=text, repo_id=repo_id, language=language))
documents.append(Document(path=file, text=text, language=language))
return documents


6 changes: 3 additions & 3 deletions tests/automation/codebase/test_parser.py
@@ -4,7 +4,7 @@
from sentence_transformers import SentenceTransformer
from tree_sitter import Node, Parser

from seer.automation.codebase.models import DocumentChunk
from seer.automation.codebase.models import BaseDocumentChunk
from seer.automation.codebase.parser import Document, DocumentParser, ParentDeclaration, TempChunk


@@ -23,7 +23,7 @@ def test_document_parser_process_document(self):
mock_document.path = "test.py"
mock_document.repo_id = 1

expected_chunks = [MagicMock(spec=DocumentChunk)]
expected_chunks = [MagicMock(spec=BaseDocumentChunk)]
self.document_parser.process_document = MagicMock(return_value=expected_chunks)

result_chunks = self.document_parser.process_document(mock_document)
@@ -33,7 +33,7 @@ def test_document_parser_process_documents(self):
def test_document_parser_process_documents(self):
# Test processing of multiple documents
mock_documents = [MagicMock(spec=Document) for _ in range(2)]
expected_chunks = [MagicMock(spec=DocumentChunk), MagicMock(spec=DocumentChunk)]
expected_chunks = [MagicMock(spec=BaseDocumentChunk), MagicMock(spec=BaseDocumentChunk)]
self.document_parser.process_documents = MagicMock(return_value=expected_chunks)

result_chunks = self.document_parser.process_documents(mock_documents)