-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move completions cli to its own file
- Loading branch information
Showing
6 changed files
with
284 additions
and
263 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,280 @@ | ||
import json | ||
|
||
import click | ||
import pandas as pd | ||
import rich | ||
import tqdm | ||
|
||
from log10._httpx_utils import _try_get | ||
from log10.cli_utils import generate_markdown_report, generate_results_table | ||
from log10.completions.completions import ( | ||
_check_model_support, | ||
_compare, | ||
_create_dataframe_from_comparison_data, | ||
_get_completion, | ||
_get_completions_url, | ||
_render_comparison_table, | ||
_render_completions_table, | ||
_write_completions, | ||
) | ||
from log10.llm import Log10Config | ||
from log10.prompt_analyzer import PromptAnalyzer, convert_suggestion_to_markdown, display_prompt_analyzer_suggestions | ||
|
||
|
||
# Module-level Log10 configuration shared by every CLI command below
# (provides `url` and `org_id` used to build API request URLs).
_log10_config = Log10Config()
|
||
|
||
@click.command()
@click.option("--limit", default=25, help="Specify the maximum number of completions to retrieve.")
@click.option("--offset", default=0, help="Set the starting point (offset) from where to begin fetching completions.")
@click.option(
    "--timeout", default=10, help="Set the maximum time (in seconds) allowed for the HTTP request to complete."
)
@click.option("--tags", default="", help="Filter completions by specific tags. Separate multiple tags with commas.")
@click.option(
    "--from",
    "from_date",
    type=click.DateTime(),
    help="Define the start date for fetching completions (inclusive). Use the format: YYYY-MM-DD.",
)
@click.option(
    "--to",
    "to_date",
    type=click.DateTime(),
    help="Set the end date for fetching completions (inclusive). Use the format: YYYY-MM-DD.",
)
def list_completions(limit, offset, timeout, tags, from_date, to_date):
    """
    List completions
    """
    # Build the paginated listing URL from the shared Log10 configuration.
    fetch_url = _get_completions_url(
        limit, offset, tags, from_date, to_date, _log10_config.url, _log10_config.org_id
    )

    # Fetch one page of completions and render it as a table.
    response = _try_get(fetch_url, timeout)
    payload = response.json()
    _render_completions_table(payload["data"], payload["total"])
|
||
|
||
@click.command()
@click.option("--id", prompt="Enter completion id", help="Completion ID")
def get_completion(id):
    """
    Get a completion by id
    """
    # Fetch the completion record and pretty-print its payload as JSON.
    completion = _get_completion(id).json()["data"]
    rich.print_json(json.dumps(completion, indent=4))
|
||
|
||
@click.command()
@click.option("--limit", default="", help="Specify the maximum number of completions to retrieve.")
@click.option("--offset", default="", help="Set the starting point (offset) from where to begin fetching completions.")
@click.option(
    "--timeout", default=10, help="Set the maximum time (in seconds) allowed for the HTTP request to complete."
)
@click.option("--tags", default="", help="Filter completions by specific tags. Separate multiple tags with commas.")
@click.option(
    "--from",
    "from_date",
    type=click.DateTime(),
    help="Define the start date for fetching completions (inclusive). Use the format: YYYY-MM-DD.",
)
@click.option(
    "--to",
    "to_date",
    type=click.DateTime(),
    help="Set the end date for fetching completions (inclusive). Use the format: YYYY-MM-DD.",
)
@click.option("--compact", is_flag=True, help="Enable to download only the compact version of the output.")
@click.option("--file", "-f", default="completions.jsonl", help="Specify the filename and path for the output file.")
def download_completions(limit, offset, timeout, tags, from_date, to_date, compact, file):
    """
    Download completions to a jsonl file
    """
    base_url = _log10_config.url
    org_id = _log10_config.org_id

    # Probe with a single-item request to learn the total number of completions.
    init_url = _get_completions_url(1, 0, tags, from_date, to_date, base_url, org_id)
    # Pass the user-supplied timeout here too, consistent with the batch requests below.
    res = _try_get(init_url, timeout)
    if res.status_code != 200:
        rich.print(f"Error: {res.json()}")
        return

    total_completions = res.json()["total"]
    # Empty-string defaults mean "unset": default offset to 0 and limit to everything.
    offset = int(offset) if offset else 0
    limit = int(limit) if limit else total_completions
    rich.print(f"Download total completions: {limit}/{total_completions}")
    if not click.confirm("Do you want to continue?"):
        return

    # Download completions in batches, appending each batch to the output file.
    batch_size = 10
    end = offset + limit if offset + limit < total_completions else total_completions
    # Use a context manager so the progress bar is always closed, even on errors.
    with tqdm.tqdm(total=limit) as pbar:
        for batch in range(offset, end, batch_size):
            # The final batch may be smaller than batch_size.
            current_batch_size = batch_size if batch + batch_size < end else end - batch
            download_url = _get_completions_url(
                current_batch_size, batch, tags, from_date, to_date, base_url, org_id, printout=False
            )
            res = _try_get(download_url, timeout)
            _write_completions(res, file, compact)
            pbar.update(current_batch_size)
|
||
|
||
@click.command()
@click.option("--ids", default="", help="Completion IDs. Separate multiple ids with commas.")
@click.option("--tags", default="", help="Filter completions by specific tags. Separate multiple tags with commas.")
@click.option("--limit", help="Specify the maximum number of completions to retrieve filtered by tags.")
@click.option(
    "--offset", help="Set the starting point (offset) from where to begin fetching completions filtered by tags."
)
@click.option("--models", default="", help="Comma separated list of models to compare")
@click.option("--temperature", default=0.2, help="Temperature")
@click.option("--max_tokens", default=512, help="Max tokens")
@click.option("--top_p", default=1.0, help="Top p")
@click.option("--analyze_prompt", is_flag=True, help="Run prompt analyzer on the messages.")
@click.option("--file", "-f", help="Specify the filename for the report in markdown format.")
def benchmark_models(ids, tags, limit, offset, models, temperature, max_tokens, top_p, file, analyze_prompt):
    """
    Compare completions using different models and generate report
    """
    # --- Validate mutually-exclusive / dependent options -------------------
    if ids and tags:
        raise click.UsageError("--ids and --tags cannot be set together.")
    if (limit or offset) and not tags:
        raise click.UsageError("--limit and --offset can only be used with --tags.")
    if tags:
        if not limit:
            limit = 5
        if not offset:
            offset = 0

    if not models:
        raise click.UsageError("--models must be set to compare.")
    # Parse the model list once and reuse it below.
    compare_models = [m for m in models.split(",") if m]
    for model in compare_models:
        if not _check_model_support(model):
            raise click.UsageError(f"Model {model} is not supported.")

    # --- Resolve the completion ids to benchmark ---------------------------
    completion_ids = []
    if ids:
        completion_ids = [id for id in ids.split(",") if id]
    elif tags:
        base_url = _log10_config.url
        org_id = _log10_config.org_id
        url = _get_completions_url(limit, offset, tags, None, None, base_url, org_id)
        res = _try_get(url)
        completions = res.json()["data"]
        completion_ids = [completion["id"] for completion in completions]
        if not completion_ids:
            # BUG FIX: the exception was previously constructed but never raised,
            # so an empty tag filter silently produced an empty report.
            raise SystemExit(f"No completions found for tags: {tags}")

    # --- Run the comparison for each completion ----------------------------
    data = []
    skipped_completion_ids = []
    for id in completion_ids:
        # get message from id
        completion_data = _get_completion(id).json()["data"]

        # skip completion if status is not finished or kind is not chat
        if completion_data["status"] != "finished" or completion_data["kind"] != "chat":
            rich.print(f"Skip completion {id}. Status is not finished or kind is not chat.")
            skipped_completion_ids.append(id)
            continue

        original_model_request = completion_data["request"]
        original_model_response = completion_data["response"]
        original_model = original_model_response["model"]
        benchmark_data = {
            "completion_id": id,
            "original_request": original_model_request,
            f"{original_model} (original model)": {
                "content": original_model_response["choices"][0]["message"]["content"],
                "usage": original_model_response["usage"],
                "duration": completion_data["duration"],
            },
        }
        messages = original_model_request["messages"]
        compare_models_data = _compare(compare_models, messages, temperature, max_tokens, top_p)
        benchmark_data.update(compare_models_data)
        data.append(benchmark_data)

    # --- Optional prompt analysis ------------------------------------------
    prompt_analysis_data = {}
    if analyze_prompt:
        rich.print("Analyzing prompts")
        for item in data:
            completion_id = item["completion_id"]
            prompt_messages = item["original_request"]["messages"]
            all_messages = "\n\n".join([m["content"] for m in prompt_messages])
            analyzer = PromptAnalyzer()
            suggestions = analyzer.analyze(all_messages)
            prompt_analysis_data[completion_id] = suggestions

    # Empty accumulator frame; rows for each completion/model pair are concatenated in.
    all_df = pd.DataFrame(
        columns=[
            "Completion ID",
            "Prompt Messages",
            "Model",
            "Content",
            "Prompt Tokens",
            "Completion Tokens",
            "Total Tokens",
            "Duration (ms)",
        ]
    )

    #
    # Display or save the results
    #
    if not file:
        # display in terminal using rich
        for ret in data:
            _render_comparison_table(ret)
            if analyze_prompt:
                completion_id = ret["completion_id"]
                suggestions = prompt_analysis_data[completion_id]
                rich.print(f"Prompt Analysis for completion_id: {completion_id}")
                display_prompt_analyzer_suggestions(suggestions)
    else:
        # generate markdown report and save to file
        for ret in data:
            df = _create_dataframe_from_comparison_data(ret)
            all_df = pd.concat([all_df, df])
        # Pivot to one row per completion with one column per model's content.
        pivot_df = all_df.pivot(index="Completion ID", columns="Model", values="Content")
        pivot_df["Prompt Messages"] = all_df.groupby("Completion ID")["Prompt Messages"].first()
        # Reorder the columns so "Prompt Messages" comes first.
        cols = pivot_df.columns.tolist()
        cols = [cols[-1]] + cols[:-1]
        pivot_df = pivot_df[cols]

        pivot_table = generate_results_table(pivot_df, section_name="model comparison")
        all_results_table = generate_results_table(all_df, section_name="All Results")

        prompt_analysis_markdown = ""
        if analyze_prompt:
            prompt_analysis_markdown = "## Prompt Analysis\n\n"
            for completion_id, suggestions in prompt_analysis_data.items():
                prompt_messages = all_df[all_df["Completion ID"] == completion_id]["Prompt Messages"].values[0]
                prompt_analysis_markdown += (
                    f"### Prompt Analysis for completion_id: {completion_id}\n\n{prompt_messages}\n\n"
                )
                prompt_analysis_markdown += convert_suggestion_to_markdown(suggestions)

        # generate the list of skipped completions ids
        skipped_completion_markdown = ""
        if skipped_completion_ids:
            skipped_completion_ids_str = ", ".join(skipped_completion_ids)
            skipped_completion_markdown += "## Skipped Completion IDs\n\n"
            skipped_completion_markdown += f"Skipped completions: {skipped_completion_ids_str}\n\n"

        generate_markdown_report(
            file, [pivot_table, prompt_analysis_markdown, all_results_table, skipped_completion_markdown]
        )
        rich.print(f"Report saved to {file}")
File renamed without changes.
File renamed without changes.
Oops, something went wrong.