-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(codegen): Unit test generation init
- Loading branch information
Showing
11 changed files
with
663 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import logging | ||
|
||
from seer.automation.codebase.repo_client import RepoClient | ||
from seer.automation.codegen.codegen_event_manager import CodegenEventManager | ||
from seer.automation.codegen.models import CodegenContinuation | ||
from seer.automation.codegen.state import CodegenContinuationState | ||
from seer.automation.models import RepoDefinition | ||
from seer.automation.pipeline import PipelineContext | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
RepoExternalId = str | ||
RepoInternalId = int | ||
RepoKey = RepoExternalId | RepoInternalId | ||
RepoIdentifiers = tuple[RepoExternalId, RepoInternalId] | ||
|
||
|
||
class CodegenContext(PipelineContext): | ||
state: CodegenContinuationState | ||
event_manager: CodegenEventManager | ||
repo: RepoDefinition | ||
|
||
def __init__( | ||
self, | ||
state: CodegenContinuationState, | ||
): | ||
request = state.get().request | ||
|
||
self.repo = request.repo | ||
self.state = state | ||
self.event_manager = CodegenEventManager(state) | ||
|
||
logger.info(f"CodegenContext initialized with run_id {self.run_id}") | ||
|
||
@classmethod | ||
def from_run_id(cls, run_id: int): | ||
state = CodegenContinuationState.from_id(run_id, model=CodegenContinuation) | ||
with state.update() as cur: | ||
cur.mark_triggered() | ||
|
||
return cls(state) | ||
|
||
@property | ||
def run_id(self) -> int: | ||
return self.state.get().run_id | ||
|
||
@property | ||
def signals(self) -> list[str]: | ||
return self.state.get().signals | ||
|
||
@signals.setter | ||
def signals(self, value: list[str]): | ||
with self.state.update() as state: | ||
state.signals = value | ||
|
||
def get_repo_client(self, repo_name: str | None = None): | ||
""" | ||
Gets a repo client for the current single repo or for a given repo name. | ||
If there are more than 1 repos, a repo name must be provided. | ||
""" | ||
return RepoClient.from_repo_definition(self.repo, "read") | ||
|
||
def get_file_contents( | ||
self, path: str, repo_name: str | None = None, ignore_local_changes: bool = False | ||
) -> str | None: | ||
repo_client = self.get_repo_client() | ||
|
||
file_contents = repo_client.get_file_content(path) | ||
|
||
if not ignore_local_changes: | ||
cur_state = self.state.get() | ||
current_file_changes = list(filter(lambda x: x.path == path, cur_state.file_changes)) | ||
for file_change in current_file_changes: | ||
file_contents = file_change.apply(file_contents) | ||
|
||
return file_contents |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import dataclasses | ||
from datetime import datetime | ||
|
||
from seer.automation.codegen.models import CodegenStatus | ||
from seer.automation.codegen.state import CodegenContinuationState | ||
from seer.automation.models import FileChange | ||
|
||
|
||
@dataclasses.dataclass | ||
class CodegenEventManager: | ||
state: CodegenContinuationState | ||
|
||
def mark_running(self): | ||
with self.state.update() as cur: | ||
cur.status = CodegenStatus.IN_PROGRESS | ||
|
||
def mark_completed(self): | ||
with self.state.update() as cur: | ||
cur.completed_at = datetime.now() | ||
cur.status = CodegenStatus.COMPLETED | ||
|
||
def add_log(self, message: str): | ||
pass | ||
|
||
def append_file_change(self, file_change: FileChange): | ||
with self.state.update() as current_state: | ||
current_state.file_changes.append(file_change) | ||
|
||
def on_error( | ||
self, error_msg: str = "Something went wrong", should_completely_error: bool = True | ||
): | ||
with self.state.update() as cur: | ||
cur.status = CodegenStatus.ERRORED |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import datetime | ||
from enum import Enum | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from seer.automation.component import BaseComponentOutput, BaseComponentRequest | ||
from seer.automation.models import FileChange, RepoDefinition | ||
|
||
|
||
class CodegenStatus(str, Enum): | ||
PENDING = "pending" | ||
IN_PROGRESS = "in_progress" | ||
COMPLETED = "completed" | ||
ERRORED = "errored" | ||
|
||
|
||
class CodegenState(BaseModel): | ||
run_id: int = -1 | ||
file_changes: list[FileChange] = Field(default_factory=list) | ||
status: CodegenStatus = CodegenStatus.PENDING | ||
last_triggered_at: datetime.datetime = Field(default_factory=datetime.datetime.now) | ||
updated_at: datetime.datetime = Field(default_factory=datetime.datetime.now) | ||
completed_at: datetime.datetime | None = None | ||
signals: list[str] = Field(default_factory=list) | ||
|
||
|
||
class CodeUnitTestOutput(BaseComponentOutput): | ||
diffs: list[FileChange] | ||
|
||
|
||
class CodegenUnitTestsRequest(BaseModel): | ||
repo: RepoDefinition | ||
pr_id: int # The PR number | ||
|
||
|
||
class CodegenContinuation(CodegenState): | ||
request: CodegenUnitTestsRequest | ||
|
||
def mark_triggered(self): | ||
self.last_triggered_at = datetime.datetime.now() | ||
|
||
def mark_updated(self): | ||
self.updated_at = datetime.datetime.now() | ||
|
||
|
||
class CodeUnitTestRequest(BaseComponentRequest): | ||
diff: str | ||
|
||
|
||
class CodegenUnitTestsResponse(BaseModel): | ||
run_id: int | ||
|
||
|
||
class CodegenUnitTestsStateRequest(BaseModel): | ||
run_id: int | ||
|
||
|
||
class CodegenUnitTestsStateResponse(BaseModel): | ||
run_id: int | ||
status: CodegenStatus | ||
changes: list[FileChange] | ||
triggered_at: datetime.datetime | ||
updated_at: datetime.datetime | ||
completed_at: datetime.datetime | None = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import textwrap | ||
|
||
from seer.automation.autofix.components.coding.models import PlanStepsPromptXml, PlanTaskPromptXml | ||
|
||
|
||
class CodingUnitTestPrompts: | ||
@staticmethod | ||
def format_system_msg(): | ||
return textwrap.dedent( | ||
"""\ | ||
You are an exceptional principal engineer that is amazing at writing unit tests given a change request against codebases. | ||
You have access to tools that allow you to search a codebase to find the relevant code snippets and view relevant files. You can use these tools as many times as you want to find the relevant code snippets. | ||
# Guidelines: | ||
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you. | ||
- You also MUST think step-by-step before giving the final answer.""" | ||
) | ||
|
||
@staticmethod | ||
def format_plan_step_msg(diff_str: str): | ||
return textwrap.dedent( | ||
"""\ | ||
You are given the below code changes as a diff: | ||
{diff_str} | ||
# Your goal: | ||
Provide the most actionable and effective steps to add unit tests to ensure test coverage for all the changes in the diff. | ||
Since you are an exceptional principal engineer, your unit tests should not just add trivial tests, but should add meaningful ones that test all changed functionality. Your list of steps should be detailed enough so that following it exactly will lead to complete test coverage of the code changed in the given diff. | ||
When ready with your final answer, detail the precise plan to add unit tests. | ||
# Guidelines: | ||
- No placeholders are allowed, the unit test must be clear and detailed. | ||
- Make sure you use the tools provided to look through the codebase and at the files you are changing before outputting your suggested fix. | ||
- The unit tests must be comprehensive. Do not provide temporary examples, placeholders or incomplete ones. | ||
- In your suggested unit tests, whenever you are providing code, provide explicit diffs to show the exact changes that need to be made. | ||
- All your changes should be in test files. | ||
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you. | ||
- You also MUST think step-by-step before giving the final answer.""" | ||
).format( | ||
diff_str=diff_str, | ||
) | ||
|
||
@staticmethod | ||
def format_find_unit_test_pattern_step_msg(diff_str: str): | ||
return textwrap.dedent( | ||
"""\ | ||
You are given the below code changes as a diff: | ||
{diff_str} | ||
# Your goal: | ||
Look at existing unit tests in the code and succinctly describe, in clear terms, the main highlights of how they are to designed. | ||
Since you are an exceptional principal engineer, your description should not be trivial. Your description should be detailed enough so that following it exactly will lead to writing good and executable unit tests that follow the same design pattern. | ||
# Guidelines: | ||
- You do not have to explain each test and what it is testing. Just identify the basic libraries used as well as how the tests are structured. | ||
- If the codebase has no relevant tests then return the exact phrase "No relevant tests in the codebase" | ||
- Make sure you use the tools provided to look through the codebase and at the files that contain existing unit tests, even if they are not fully related to the changes in the given diff. | ||
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you. | ||
- You also MUST think step-by-step before giving the final answer.""" | ||
).format( | ||
diff_str=diff_str, | ||
) | ||
|
||
@staticmethod | ||
def format_unit_test_msg(diff_str, test_design_hint): | ||
example = PlanTaskPromptXml( | ||
file_path="path/to/file.py", | ||
repo_name="owner/repo", | ||
type="Either 'file_change', 'file_create', or 'file_delete'", | ||
description="Describe what you are doing here in detail like you are explaining it to a software engineer.", | ||
diff=textwrap.dedent( | ||
"""\ | ||
# Here provide the EXACT unified diff of the code change required to accomplish this step. | ||
# You must prefix lines that are removed with a '-' and lines that are added with a '+'. Context lines around the change are required and must be prefixed with a space. | ||
# Make sure the diff is complete and the code is EXACTLY matching the files you see. | ||
# For example: | ||
--- a/path/to/file.py | ||
+++ b/path/to/file.py | ||
@@ -1,3 +1,3 @@ | ||
return 'fab' | ||
y = 2 | ||
x = 1 | ||
-def foo(): | ||
+def foo(): | ||
return 'foo' | ||
def bar(): | ||
return 'bar' | ||
""" | ||
), | ||
commit_message="Provide a commit message that describes the unit test you are adding or changing", | ||
) | ||
|
||
prompt_obj = PlanStepsPromptXml( | ||
tasks=[ | ||
example, | ||
example, | ||
] | ||
) | ||
|
||
return textwrap.dedent( | ||
"""\ | ||
You are given a code diff: | ||
{diff_str} | ||
You are also given the following test design guidelines: | ||
{test_design_hint} | ||
# Your goal: | ||
Write unit tests that cover the changes in the diff. You should first explain in clear and definite terms what you are adding. Then add the unit test such that lines modified, added or deleted are covered. Create multiple test files if required and cover code changed in all the files. | ||
When ready with your final answer, detail the explanation of the test wrapped with a <explanation></explanation> block. Your output must follow the format properly according to the following guidelines: | ||
{steps_example_str} | ||
# Guidelines: | ||
_ Closely follow the guidelines provided to design the tests | ||
- Each file change must be a separate step and be explicit and clear. | ||
- Before adding new files check if a file exists with same name and if it can be edited instead | ||
- You MUST include exact file paths for each step you provide. If you cannot, find the correct path. | ||
- No placeholders are allowed, the steps must be clear and detailed. | ||
- Make sure you use the tools provided to look through the codebase and at the files you are changing before outputting the steps. | ||
- The plan must be comprehensive. Do not provide temporary examples, placeholders or incomplete steps. | ||
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you. | ||
- You also MUST think step-by-step before giving the final answer.""" | ||
).format( | ||
diff_str=diff_str, | ||
test_design_hint=test_design_hint, | ||
steps_example_str=prompt_obj.to_prompt_str(), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import dataclasses | ||
from typing import cast | ||
|
||
from pydantic import BaseModel | ||
|
||
from seer.automation.codegen.models import CodegenContinuation | ||
from seer.automation.state import DbState, DbStateRunTypes | ||
from seer.db import DbRunState, Session | ||
|
||
|
||
@dataclasses.dataclass | ||
class CodegenContinuationState(DbState[CodegenContinuation]): | ||
@classmethod | ||
def from_id(cls, id: int, model: type[BaseModel]) -> "CodegenContinuationState": | ||
return cast( | ||
CodegenContinuationState, super().from_id(id, model, type=DbStateRunTypes.UNIT_TEST) | ||
) | ||
|
||
def set(self, state: CodegenContinuation): | ||
state.mark_updated() | ||
|
||
with Session() as session: | ||
db_state = DbRunState( | ||
id=self.id, | ||
value=state.model_dump(mode="json"), | ||
updated_at=state.updated_at, | ||
last_triggered_at=state.last_triggered_at, | ||
) | ||
session.merge(db_state) | ||
session.commit() |
Oops, something went wrong.