Skip to content

Commit

Permalink
make cli work
Browse files Browse the repository at this point in the history
  • Loading branch information
wenzhe-log10 committed Feb 12, 2024
1 parent 173e1ad commit 7eb48ff
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 38 deletions.
47 changes: 22 additions & 25 deletions log10/feedback/feedback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging

import click
Expand All @@ -7,15 +8,6 @@
from log10.llm import Log10Config


# def create(name: str, task_schema: dict) -> httpx.Response:
# """
# Example:
# >>> from log10.feedback import feedback, feedback_task
# >>> task = feedback_task.create(name="summarization", task_schema={...})
# >>> task_id = task.id
# >>> fb = feedback.create(task_id=task_id, rate={...})
# """

load_dotenv()

logging.basicConfig(
Expand All @@ -34,35 +26,40 @@ def __init__(self, log10_config: Log10Config = None):
self._http_client = httpx.Client()

def _post_request(self, url: str, json_payload: dict) -> httpx.Response:
headers = {"x-log10-token": self._log10_config.token, "x-log10-organization": self._log10_config.org_id, "Content-Type": "application/json"}
headers = {
"x-log10-token": self._log10_config.token,
"x-log10-organization": self._log10_config.org_id,
"Content-Type": "application/json",
}
json_payload["organization_id"] = self._log10_config.org_id
try:
res = self._http_client.post(self._log10_config.url + url, headers=headers, json=json_payload)
res.raise_for_status()
return res
except Exception as e:
logger.error(e)
logger.error(e.response.json()["error"])
raise

def create(self, task_id: str, rate: dict, completion_tags_selector: list[str], comment: str = None) -> httpx.Response:
"""
Example:
>>> from log10.feedback import Feedback
>>> fb = Feedback()
>>> fb.create(task_id="task_id", rate={...})
"""
json_payload = {"task_id": task_id, "json_values": rate, "completion_tags_selector": completion_tags_selector}
def create(
self, task_id: str, values: dict, completion_tags_selector: list[str], comment: str = None
) -> httpx.Response:
json_payload = {
"task_id": task_id,
"json_values": values,
"completion_tags_selector": completion_tags_selector,
}
res = self._post_request(self.feedback_create_url, json_payload)
return res


@click.command()
@click.option("--task_id", prompt="Enter task id", help="Task ID")
@click.option("--rate", prompt="Enter task rate", help="Rate in JSON format")
def create_feedback(task_id, rate):
@click.option("--values", prompt="Enter task values", help="Feedback in JSON format")
@click.option("--completion_tags_selector", prompt="Enter completion tags selector", help="Completion tags selector")
def create_feedback(task_id, values, completion_tags_selector):
click.echo("Creating feedback")
click.echo(f"Task ID: {task_id}")
click.echo(f"Rate: {rate}")
# fb = Feedback()
# feedback = fb.create(task_id=task_id, rate=rate)
# click.echo(feedback)
tags = completion_tags_selector.split(",")
values = json.loads(values)
feedback = Feedback().create(task_id=task_id, values=values, completion_tags_selector=tags)
click.echo(feedback.json())
23 changes: 10 additions & 13 deletions log10/feedback/feedback_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging

import click
Expand All @@ -17,7 +18,6 @@
logger.setLevel(logging.INFO)



class FeedbackTask:
feedback_task_create_url = "api/v1/feedback_task"

Expand All @@ -26,23 +26,22 @@ def __init__(self, log10_config: Log10Config = None):
self._http_client = httpx.Client()

def _post_request(self, url: str, json_payload: dict) -> httpx.Response:
headers = {"x-log10-token": self._log10_config.token, "Content-Type": "application/json", "x-log10-organization": self._log10_config.org_id}
headers = {
"x-log10-token": self._log10_config.token,
"Content-Type": "application/json",
"x-log10-organization": self._log10_config.org_id,
}
json_payload["organization_id"] = self._log10_config.org_id
try:
res = self._http_client.post(self._log10_config.url + url, headers=headers, json=json_payload)
res.raise_for_status()
return res
except Exception as e:
logger.error(e)
logger.error(e.response.json()["error"])
raise

def create(self, task_schema: dict, name: str = None, instruction: str = None) -> httpx.Response:
"""
Example:
>>> from log10.feedback.feedback_task import FeedbackTask
>>> feedback_task = FeedbackTask()
>>> task = feedback_task.create(name="summarization", task_schema={...})
"""
json_payload = {"json_schema": task_schema}
if name:
json_payload["name"] = name
Expand All @@ -59,8 +58,6 @@ def create(self, task_schema: dict, name: str = None, instruction: str = None) -
@click.option("--task_schema", prompt="Enter feedback task schema", help="Task schema")
def create_feedback_task(name, task_schema):
click.echo("Creating feedback task")
click.echo(f"Name: {name}")
click.echo(f"Task Schema: {task_schema}")
# fb_task = FeedbackTask()
# task = fb_task.create(name=name, task_schema=task_schema)
# click.echo(task)
task_schema = json.loads(task_schema)
task = FeedbackTask().create(name=name, task_schema=task_schema)
click.echo(f"Use this task_id to add feedback: {task.json()['id']}")

0 comments on commit 7eb48ff

Please sign in to comment.