Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add black formatting ci check #19

Merged
merged 4 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/black_checker.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# CI check: fail the build when any tracked Python file is not formatted
# with black. Indentation reconstructed to valid GitHub Actions YAML —
# the scraped diff had lost all leading whitespace, which would make this
# workflow file unparseable.
name: Black Formatting Check

on:
  push:
    branches:
      - main
  # Run on pull requests targeting any branch.
  pull_request:
    branches:
      - '*'

jobs:
  black-check:
    # Run inside the official python:3.8 image so the pinned black version
    # from the project's dev extras is used, not whatever the runner ships.
    container:
      image: python:3.8
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Install dependencies
        run: |
          pip install .[dev]

      - name: Check Black formatting
        run: |
          # --check exits non-zero on any file that would be reformatted.
          black --check .
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
image: python:3.8
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/checkout@v2
- name: Run Tests
shell: bash
run: |
Expand Down
22 changes: 9 additions & 13 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,19 @@
name="learn_to_pick",
version="0.1",
install_requires=[
'numpy',
'pandas',
'vowpal-wabbit-next',
'sentence-transformers',
'torch',
'pyskiplist',
'parameterfree',
"numpy",
"pandas",
"vowpal-wabbit-next",
"sentence-transformers",
"torch",
"pyskiplist",
"parameterfree",
],
extras_require={
'dev': [
'pytest'
]
},
extras_require={"dev": ["pytest", "black==23.10.0"]},
author="VowpalWabbit",
description="",
packages=find_packages(where="src"),
package_dir={"": "src"},
url="https://github.com/VowpalWabbit/learn_to_pick",
python_requires='>=3.8',
python_requires=">=3.8",
)
48 changes: 24 additions & 24 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
Union,
)

from learn_to_pick.metrics import (
MetricsTrackerAverage,
MetricsTrackerRollingWindow,
)
from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger

Expand Down Expand Up @@ -234,14 +231,12 @@ class SelectionScorer(Generic[TEvent], ABC):
"""

@abstractmethod
def score_response(
self, inputs: Dict[str, Any], picked: Any, event: TEvent
) -> Any:
def score_response(self, inputs: Dict[str, Any], picked: Any, event: TEvent) -> Any:
"""
Calculate and return the score for the selected response.

This is an abstract method and should be implemented by subclasses.
The method defines a blueprint for applying scoring logic based on the provided
The method defines a blueprint for applying scoring logic based on the provided
inputs, the selection made by the policy, and additional metadata from the event.

Args:
Expand All @@ -256,10 +251,12 @@ def score_response(


class AutoSelectionScorer(SelectionScorer[Event]):
def __init__(self,
llm,
prompt: Union[Any, None] = None,
scoring_criteria_template_str: Optional[str] = None):
def __init__(
self,
llm,
prompt: Union[Any, None] = None,
scoring_criteria_template_str: Optional[str] = None,
):
self.llm = llm
self.prompt = prompt
if prompt is None and scoring_criteria_template_str is None:
Expand All @@ -285,16 +282,19 @@ def get_default_prompt() -> str:
@staticmethod
def format_with_ignoring_extra_args(prompt, inputs):
import string

# Extract placeholders from the prompt
placeholders = [field[1] for field in string.Formatter().parse(str(prompt)) if field[1]]
placeholders = [
field[1] for field in string.Formatter().parse(str(prompt)) if field[1]
]

# Keep only the inputs that have corresponding placeholders in the prompt
relevant_inputs = {k: v for k, v in inputs.items() if k in placeholders}

return prompt.format(**relevant_inputs)

def score_response(
self, inputs: Dict[str, Any], picked: Any, event: Event
self, inputs: Dict[str, Any], picked: Any, event: Event
) -> float:
p = AutoSelectionScorer.format_with_ignoring_extra_args(self.prompt, inputs)
ranking = self.llm.predict(p)
Expand Down Expand Up @@ -337,6 +337,7 @@ class RLLoop(Generic[TEvent]):
Notes:
By default the class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
"""

# Define the default values as class attributes
selected_input_key = "pick_best_selected"
selected_based_on_input_key = "pick_best_selected_based_on"
Expand Down Expand Up @@ -409,7 +410,6 @@ def save_progress(self) -> None:
"""
self.policy.save()


def _can_use_selection_scorer(self) -> bool:
"""
Returns whether the chain can use the selection scorer to score responses or not.
Expand All @@ -422,10 +422,7 @@ def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:

@abstractmethod
def _call_after_predict_before_scoring(
self,
inputs: Dict[str, Any],
event: Event,
prediction: List[Tuple[int, float]],
self, inputs: Dict[str, Any], event: Event, prediction: List[Tuple[int, float]]
) -> Tuple[Dict[str, Any], Event]:
...

Expand All @@ -448,7 +445,9 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
elif kwargs and not args:
inputs = kwargs
else:
raise ValueError("Either a dictionary positional argument or keyword arguments should be provided")
raise ValueError(
"Either a dictionary positional argument or keyword arguments should be provided"
)

event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
Expand All @@ -461,11 +460,11 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:

for callback_func in self.callbacks_before_scoring:
try:
next_chain_inputs, event = callback_func(inputs=next_chain_inputs, picked=picked, event=event)
except Exception as e:
logger.info(
f"Callback function {callback_func} failed, error: {e}"
next_chain_inputs, event = callback_func(
inputs=next_chain_inputs, picked=picked, event=event
)
except Exception as e:
logger.info(f"Callback function {callback_func} failed, error: {e}")

score = None
try:
Expand All @@ -489,6 +488,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
event.outputs = next_chain_inputs
return {"picked": picked, "picked_metadata": event}


def is_stringtype_instance(item: Any) -> bool:
"""Helper function to check if an item is a string."""
return isinstance(item, str) or (
Expand Down
4 changes: 3 additions & 1 deletion src/learn_to_pick/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def on_feedback(self, value: float) -> None:
self.sum -= old_val

if self.step > 0 and self.feedback_count % self.step == 0:
self.history.append({"step": self.feedback_count, "score": self.sum / len(self.queue)})
self.history.append(
{"step": self.feedback_count, "score": self.sum / len(self.queue)}
)

def to_pandas(self) -> "pd.DataFrame":
import pandas as pd
Expand Down
13 changes: 4 additions & 9 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def create(
logger.warning(
f"{[k for k, v in policy_args.items() if v]} will be ignored since nontrivial policy is provided, please set those arguments in the policy directly if needed"
)

if policy_args["model_save_dir"] is None:
policy_args["model_save_dir"] = "./"
if policy_args["reset_model"] is None:
Expand All @@ -370,7 +370,7 @@ def create_policy(
vw_cmd: Optional[List[str]] = None,
model_save_dir: str = "./",
reset_model: bool = False,
rl_logs: Optional[Union[str, os.PathLike]] = None
rl_logs: Optional[Union[str, os.PathLike]] = None,
):
if not featurizer:
featurizer = PickBestFeaturizer(auto_embed=False)
Expand All @@ -384,20 +384,15 @@ def create_policy(
)
else:
interactions += ["--interactions=::"]
vw_cmd = [
"--cb_explore_adf",
"--coin",
"--squarecb",
"--quiet",
]
vw_cmd = ["--cb_explore_adf", "--coin", "--squarecb", "--quiet"]

if featurizer.auto_embed:
interactions += [
"--interactions=@#",
"--ignore_linear=@",
"--ignore_linear=#",
]

vw_cmd = interactions + vw_cmd

return base.VwPolicy(
Expand Down
32 changes: 8 additions & 24 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,8 @@ def score_response(


def test_everything_embedded() -> None:
featurizer = learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
)
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller, featurizer=featurizer
)
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)

str1 = "0"
str2 = "1"
Expand All @@ -187,12 +183,8 @@ def test_everything_embedded() -> None:


def test_default_auto_embedder_is_off() -> None:
featurizer = learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
)
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller, featurizer=featurizer
)
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)

str1 = "0"
str2 = "1"
Expand All @@ -213,12 +205,8 @@ def test_default_auto_embedder_is_off() -> None:


def test_default_w_embeddings_off() -> None:
featurizer = learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
)
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller, featurizer=featurizer
)
featurizer = learn_to_pick.PickBestFeaturizer(auto_embed=False, model=MockEncoder())
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)

str1 = "0"
str2 = "1"
Expand All @@ -242,9 +230,7 @@ def test_default_w_embeddings_on() -> None:
featurizer = learn_to_pick.PickBestFeaturizer(
auto_embed=True, model=MockEncoderReturnsList()
)
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller, featurizer=featurizer
)
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)

str1 = "0"
str2 = "1"
Expand All @@ -268,9 +254,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings() -> None:
featurizer = learn_to_pick.PickBestFeaturizer(
auto_embed=True, model=MockEncoderReturnsList()
)
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller, featurizer=featurizer
)
pick = learn_to_pick.PickBest.create(llm=fake_llm_caller, featurizer=featurizer)

str1 = "0"
str2 = "1"
Expand Down
11 changes: 2 additions & 9 deletions tests/unit_tests/test_pick_best_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
encoded_ctx_str_2 = rl_chain.stringify_embedding(list(encoded_keyword + ctx_str_2))

named_actions = {
"action1": [
{"a": str1, "b": rl_chain.Embed(str1)},
str2,
rl_chain.Embed(str3),
]
"action1": [{"a": str1, "b": rl_chain.Embed(str1)}, str2, rl_chain.Embed(str3)]
}
context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {encoded_str1} \n|action1 {str2} \n|action1 {encoded_str3} """ # noqa: E501
Expand Down Expand Up @@ -296,10 +292,7 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
rl_chain.EmbedAndKeep(str3),
]
}
context = {
"context1": ctx_str_1,
"context2": rl_chain.EmbedAndKeep(ctx_str_2),
}
context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)}
expected = f"""shared |context1 {ctx_str_1} |context2 {ctx_str_2 + " " + encoded_ctx_str_2} \n0:-0.0:1.0 |a {str1} |b {str1 + " " + encoded_str1} \n|action1 {str2} \n|action1 {str3 + " " + encoded_str3} """ # noqa: E501

selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
Expand Down
20 changes: 4 additions & 16 deletions tests/unit_tests/test_rl_loop_base_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def test_context_w_namespace_w_some_emb() -> None:
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1,
"test_namespace2": str2 + " " + encoded_str2,
}
{"test_namespace": str1, "test_namespace2": str2 + " " + encoded_str2}
]
assert (
base.embed(
Expand Down Expand Up @@ -337,18 +334,9 @@ def test_action_w_namespace_w_emb_w_more_than_one_item_in_first_dict() -> None:
== expected
)
expected_embed_and_keep = [
{
"test_namespace": str1 + " " + encoded_str1,
"test_namespace2": str1,
},
{
"test_namespace": str2 + " " + encoded_str2,
"test_namespace2": str2,
},
{
"test_namespace": str3 + " " + encoded_str3,
"test_namespace2": str3,
},
{"test_namespace": str1 + " " + encoded_str1, "test_namespace2": str1},
{"test_namespace": str2 + " " + encoded_str2, "test_namespace2": str2},
{"test_namespace": str3 + " " + encoded_str3, "test_namespace2": str3},
]
assert (
base.embed(
Expand Down
Loading