Skip to content

Commit

Permalink
feat(codegen): Unit test generation init (#1135)
Browse files Browse the repository at this point in the history
Introduces a new `codegen` folder with a unit test generation agent that
was made for the codecov hack week project

Uses a new `CodegenContext` similar to the autofix context but new


![image](https://github.com/user-attachments/assets/c2da7079-c48d-433a-8218-78ca676574cc)

Lack of unit tests is due to this being a WIP agent, will follow up &
harden
  • Loading branch information
jennmueng authored Sep 10, 2024
1 parent ec794e1 commit a730a1d
Show file tree
Hide file tree
Showing 13 changed files with 667 additions and 2 deletions.
28 changes: 28 additions & 0 deletions src/seer/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
RepoAccessCheckResponse,
)
from seer.automation.codebase.repo_client import RepoClient
from seer.automation.codegen.models import (
CodegenUnitTestsRequest,
CodegenUnitTestsResponse,
CodegenUnitTestsStateRequest,
CodegenUnitTestsStateResponse,
)
from seer.automation.codegen.tasks import codegen_unittest, get_unittest_state
from seer.automation.summarize.issue import run_summarize_issue
from seer.automation.summarize.models import SummarizeIssueRequest, SummarizeIssueResponse
from seer.automation.utils import raise_if_no_genai_consent
Expand Down Expand Up @@ -214,6 +221,27 @@ def autofix_evaluation_start_endpoint(data: AutofixEvaluationRequest) -> Autofix
return AutofixEndpointResponse(started=True, run_id=-1)


@json_api(blueprint, "/v1/automation/codegen/unit-tests")
def codegen_unit_tests_endpoint(data: CodegenUnitTestsRequest) -> CodegenUnitTestsResponse:
return codegen_unittest(data)


@json_api(blueprint, "/v1/automation/codegen/unit-tests/state")
def codegen_unit_tests_state_endpoint(
data: CodegenUnitTestsStateRequest,
) -> CodegenUnitTestsStateResponse:
state = get_unittest_state(data)

return CodegenUnitTestsStateResponse(
run_id=state.run_id,
status=state.status,
changes=state.file_changes,
triggered_at=state.last_triggered_at,
updated_at=state.updated_at,
completed_at=state.completed_at,
)


@json_api(blueprint, "/v1/automation/summarize/issue")
def summarize_issue_endpoint(data: SummarizeIssueRequest) -> SummarizeIssueResponse:
return run_summarize_issue(data)
Expand Down
5 changes: 3 additions & 2 deletions src/seer/automation/autofix/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from seer.automation.codebase.code_search import CodeSearcher
from seer.automation.codebase.models import MatchXml
from seer.automation.codebase.utils import cleanup_dir
from seer.automation.codegen.codegen_context import CodegenContext

logger = logging.getLogger(__name__)


class BaseTools:
context: AutofixContext
context: AutofixContext | CodegenContext
retrieval_top_k: int

def __init__(self, context: AutofixContext, retrieval_top_k: int = 8):
def __init__(self, context: AutofixContext | CodegenContext, retrieval_top_k: int = 8):
self.context = context
self.retrieval_top_k = retrieval_top_k

Expand Down
13 changes: 13 additions & 0 deletions src/seer/automation/codebase/repo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,16 @@ def get_index_file_set(
file_set.add(file.path)

return file_set

def get_pr_diff_content(self, pr_url: str) -> str:
requester = self.repo._requester
headers = {
"Authorization": f"{requester.auth.token_type} {requester.auth.token}", # type: ignore
"Accept": "application/vnd.github.diff",
}

data = requests.get(pr_url, headers=headers)

data.raise_for_status() # Raise an exception for HTTP errors

return data.text
76 changes: 76 additions & 0 deletions src/seer/automation/codegen/codegen_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging

from seer.automation.codebase.repo_client import RepoClient
from seer.automation.codegen.codegen_event_manager import CodegenEventManager
from seer.automation.codegen.models import CodegenContinuation
from seer.automation.codegen.state import CodegenContinuationState
from seer.automation.models import RepoDefinition
from seer.automation.pipeline import PipelineContext

logger = logging.getLogger(__name__)

RepoExternalId = str
RepoInternalId = int
RepoKey = RepoExternalId | RepoInternalId
RepoIdentifiers = tuple[RepoExternalId, RepoInternalId]


class CodegenContext(PipelineContext):
state: CodegenContinuationState
event_manager: CodegenEventManager
repo: RepoDefinition

def __init__(
self,
state: CodegenContinuationState,
):
request = state.get().request

self.repo = request.repo
self.state = state
self.event_manager = CodegenEventManager(state)

logger.info(f"CodegenContext initialized with run_id {self.run_id}")

@classmethod
def from_run_id(cls, run_id: int):
state = CodegenContinuationState.from_id(run_id, model=CodegenContinuation)
with state.update() as cur:
cur.mark_triggered()

return cls(state)

@property
def run_id(self) -> int:
return self.state.get().run_id

@property
def signals(self) -> list[str]:
return self.state.get().signals

@signals.setter
def signals(self, value: list[str]):
with self.state.update() as state:
state.signals = value

def get_repo_client(self, repo_name: str | None = None):
"""
Gets a repo client for the current single repo or for a given repo name.
If there are more than 1 repos, a repo name must be provided.
"""
return RepoClient.from_repo_definition(self.repo, "read")

def get_file_contents(
self, path: str, repo_name: str | None = None, ignore_local_changes: bool = False
) -> str | None:
repo_client = self.get_repo_client()

file_contents = repo_client.get_file_content(path)

if not ignore_local_changes:
cur_state = self.state.get()
current_file_changes = list(filter(lambda x: x.path == path, cur_state.file_changes))
for file_change in current_file_changes:
file_contents = file_change.apply(file_contents)

return file_contents
33 changes: 33 additions & 0 deletions src/seer/automation/codegen/codegen_event_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import dataclasses
from datetime import datetime

from seer.automation.codegen.models import CodegenStatus
from seer.automation.codegen.state import CodegenContinuationState
from seer.automation.models import FileChange


@dataclasses.dataclass
class CodegenEventManager:
state: CodegenContinuationState

def mark_running(self):
with self.state.update() as cur:
cur.status = CodegenStatus.IN_PROGRESS

def mark_completed(self):
with self.state.update() as cur:
cur.completed_at = datetime.now()
cur.status = CodegenStatus.COMPLETED

def add_log(self, message: str):
pass

def append_file_change(self, file_change: FileChange):
with self.state.update() as current_state:
current_state.file_changes.append(file_change)

def on_error(
self, error_msg: str = "Something went wrong", should_completely_error: bool = True
):
with self.state.update() as cur:
cur.status = CodegenStatus.ERRORED
64 changes: 64 additions & 0 deletions src/seer/automation/codegen/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import datetime
from enum import Enum

from pydantic import BaseModel, Field

from seer.automation.component import BaseComponentOutput, BaseComponentRequest
from seer.automation.models import FileChange, RepoDefinition


class CodegenStatus(str, Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERRORED = "errored"


class CodegenState(BaseModel):
run_id: int = -1
file_changes: list[FileChange] = Field(default_factory=list)
status: CodegenStatus = CodegenStatus.PENDING
last_triggered_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
updated_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
completed_at: datetime.datetime | None = None
signals: list[str] = Field(default_factory=list)


class CodeUnitTestOutput(BaseComponentOutput):
diffs: list[FileChange]


class CodegenUnitTestsRequest(BaseModel):
repo: RepoDefinition
pr_id: int # The PR number


class CodegenContinuation(CodegenState):
request: CodegenUnitTestsRequest

def mark_triggered(self):
self.last_triggered_at = datetime.datetime.now()

def mark_updated(self):
self.updated_at = datetime.datetime.now()


class CodeUnitTestRequest(BaseComponentRequest):
diff: str


class CodegenUnitTestsResponse(BaseModel):
run_id: int


class CodegenUnitTestsStateRequest(BaseModel):
run_id: int


class CodegenUnitTestsStateResponse(BaseModel):
run_id: int
status: CodegenStatus
changes: list[FileChange]
triggered_at: datetime.datetime
updated_at: datetime.datetime
completed_at: datetime.datetime | None = None
133 changes: 133 additions & 0 deletions src/seer/automation/codegen/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import textwrap

from seer.automation.autofix.components.coding.models import PlanStepsPromptXml, PlanTaskPromptXml


class CodingUnitTestPrompts:
@staticmethod
def format_system_msg():
return textwrap.dedent(
"""\
You are an exceptional principal engineer that is amazing at writing unit tests given a change request against codebases.
You have access to tools that allow you to search a codebase to find the relevant code snippets and view relevant files. You can use these tools as many times as you want to find the relevant code snippets.
# Guidelines:
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you.
- You also MUST think step-by-step before giving the final answer."""
)

@staticmethod
def format_plan_step_msg(diff_str: str):
return textwrap.dedent(
"""\
You are given the below code changes as a diff:
{diff_str}
# Your goal:
Provide the most actionable and effective steps to add unit tests to ensure test coverage for all the changes in the diff.
Since you are an exceptional principal engineer, your unit tests should not just add trivial tests, but should add meaningful ones that test all changed functionality. Your list of steps should be detailed enough so that following it exactly will lead to complete test coverage of the code changed in the given diff.
When ready with your final answer, detail the precise plan to add unit tests.
# Guidelines:
- No placeholders are allowed, the unit test must be clear and detailed.
- Make sure you use the tools provided to look through the codebase and at the files you are changing before outputting your suggested fix.
- The unit tests must be comprehensive. Do not provide temporary examples, placeholders or incomplete ones.
- In your suggested unit tests, whenever you are providing code, provide explicit diffs to show the exact changes that need to be made.
- All your changes should be in test files.
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you.
- You also MUST think step-by-step before giving the final answer."""
).format(
diff_str=diff_str,
)

@staticmethod
def format_find_unit_test_pattern_step_msg(diff_str: str):
return textwrap.dedent(
"""\
You are given the below code changes as a diff:
{diff_str}
# Your goal:
Look at existing unit tests in the code and succinctly describe, in clear terms, the main highlights of how they are to designed.
Since you are an exceptional principal engineer, your description should not be trivial. Your description should be detailed enough so that following it exactly will lead to writing good and executable unit tests that follow the same design pattern.
# Guidelines:
- You do not have to explain each test and what it is testing. Just identify the basic libraries used as well as how the tests are structured.
- If the codebase has no relevant tests then return the exact phrase "No relevant tests in the codebase"
- Make sure you use the tools provided to look through the codebase and at the files that contain existing unit tests, even if they are not fully related to the changes in the given diff.
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you.
- You also MUST think step-by-step before giving the final answer."""
).format(
diff_str=diff_str,
)

@staticmethod
def format_unit_test_msg(diff_str, test_design_hint):
example = PlanTaskPromptXml(
file_path="path/to/file.py",
repo_name="owner/repo",
type="Either 'file_change', 'file_create', or 'file_delete'",
description="Describe what you are doing here in detail like you are explaining it to a software engineer.",
diff=textwrap.dedent(
"""\
# Here provide the EXACT unified diff of the code change required to accomplish this step.
# You must prefix lines that are removed with a '-' and lines that are added with a '+'. Context lines around the change are required and must be prefixed with a space.
# Make sure the diff is complete and the code is EXACTLY matching the files you see.
# For example:
--- a/path/to/file.py
+++ b/path/to/file.py
@@ -1,3 +1,3 @@
return 'fab'
y = 2
x = 1
-def foo():
+def foo():
return 'foo'
def bar():
return 'bar'
"""
),
commit_message="Provide a commit message that describes the unit test you are adding or changing",
)

prompt_obj = PlanStepsPromptXml(
tasks=[
example,
example,
]
)

return textwrap.dedent(
"""\
You are given a code diff:
{diff_str}
You are also given the following test design guidelines:
{test_design_hint}
# Your goal:
Write unit tests that cover the changes in the diff. You should first explain in clear and definite terms what you are adding. Then add the unit test such that lines modified, added or deleted are covered. Create multiple test files if required and cover code changed in all the files.
When ready with your final answer, detail the explanation of the test wrapped with a <explanation></explanation> block. Your output must follow the format properly according to the following guidelines:
{steps_example_str}
# Guidelines:
_ Closely follow the guidelines provided to design the tests
- Each file change must be a separate step and be explicit and clear.
- Before adding new files check if a file exists with same name and if it can be edited instead
- You MUST include exact file paths for each step you provide. If you cannot, find the correct path.
- No placeholders are allowed, the steps must be clear and detailed.
- Make sure you use the tools provided to look through the codebase and at the files you are changing before outputting the steps.
- The plan must be comprehensive. Do not provide temporary examples, placeholders or incomplete steps.
- EVERY TIME before you use a tool, think step-by-step each time before using the tools provided to you.
- You also MUST think step-by-step before giving the final answer."""
).format(
diff_str=diff_str,
test_design_hint=test_design_hint,
steps_example_str=prompt_obj.to_prompt_str(),
)
Loading

0 comments on commit a730a1d

Please sign in to comment.