Skip to content

Commit

Permalink
Merge pull request #27 from VowpalWabbit/cleanup
Browse files Browse the repository at this point in the history
Rename scorer keys, input validation, escape newlines/tabs/etc
  • Loading branch information
olgavrou authored Nov 3, 2023
2 parents a45671c + a62686f commit c8ad86a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
33 changes: 24 additions & 9 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def get_default_system_prompt() -> str:

@staticmethod
def get_default_prompt() -> str:
human_template = """Given this based_on "{pick_best_selected_based_on}" \
human_template = """Given this based_on "{selected_based_on}" \
as the most important attribute, rank how good or bad this text is: \
"{pick_best_selected}"."""
"{picked}"."""
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
return default_system_prompt + human_template

Expand Down Expand Up @@ -325,8 +325,8 @@ class RLLoop(Generic[TEvent]):
"""

# Define the default values as class attributes
selected_input_key = "pick_best_selected"
selected_based_on_input_key = "pick_best_selected_based_on"
selected_based_on_input_key = "selected_based_on"
selected_input_key = "picked"

def __init__(
self,
Expand Down Expand Up @@ -435,19 +435,29 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
"Either a dictionary positional argument or keyword arguments should be provided"
)

if self.selected_based_on_input_key in inputs:
raise ValueError(
f"The input key {self.selected_based_on_input_key} is reserved. Please use a different key."
)

if self.selected_input_key in inputs:
raise ValueError(
f"The input key {self.selected_input_key} is reserved. Please use a different key."
)

event: TEvent = self._call_before_predict(inputs=inputs)
prediction = self.policy.predict(event=event)
if self.metrics:
self.metrics.on_decision()

next_chain_inputs, picked, event = self._call_after_predict_before_scoring(
next_inputs, picked, event = self._call_after_predict_before_scoring(
inputs=inputs, event=event, prediction=prediction
)

for callback_func in self.callbacks_before_scoring:
try:
next_chain_inputs, event = callback_func(
inputs=next_chain_inputs, picked=picked, event=event
next_inputs, event = callback_func(
inputs=next_inputs, picked=picked, event=event
)
except Exception as e:
logger.info(f"Callback function {callback_func} failed, error: {e}")
Expand All @@ -456,7 +466,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
try:
if self._can_use_selection_scorer():
score = self.selection_scorer.score_response(
inputs=next_chain_inputs, picked=picked, event=event
inputs=next_inputs, picked=picked, event=event
)
except Exception as e:
logger.info(
Expand All @@ -471,21 +481,26 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
self.policy.learn(event=event)
self.policy.log(event=event)

event.outputs = next_chain_inputs
event.outputs = next_inputs
return {"picked": picked, "picked_metadata": event}


def _embed_string_type(
item: Union[str, _Embed], model: Any, namespace: Optional[str] = None
) -> Dict[str, Union[str, List[str]]]:
"""Helper function to embed a string or an _Embed object."""
import re

keep_str = ""
if isinstance(item, _Embed):
encoded = _stringify_embedding(model.encode(item.value))
# TODO these should be moved to pick_best
if item.keep:
keep_str = item.value.replace(" ", "_") + " "
keep_str = re.sub(r"[\t\n\r\f\v]+", " ", keep_str)
elif isinstance(item, str):
encoded = item.replace(" ", "_")
encoded = re.sub(r"[\t\n\r\f\v]+", " ", encoded)
else:
raise ValueError(f"Unsupported type {type(item)} for embedding")

Expand Down
9 changes: 4 additions & 5 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,7 @@ def _call_after_predict_before_scoring(
selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
event.selected = selected

# only one key, value pair in event.to_select_from
key, value = next(iter(event.to_select_from.items()))
next_inputs = inputs.copy()
next_inputs[key] = value[event.selected.index]

# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
Expand All @@ -309,12 +306,14 @@ def _call_after_predict_before_scoring(
if event.selected
else event.to_select_from.values()
)
next_inputs[self.selected_based_on_input_key] = str(event.based_on)
next_inputs[self.selected_input_key] = v

picked = {}
for k, v in event.to_select_from.items():
picked[k] = v[event.selected.index]

next_inputs[self.selected_based_on_input_key] = str(event.based_on)
next_inputs[self.selected_input_key] = str(picked)

return next_inputs, picked, event

def _call_after_scoring_before_learning(
Expand Down

0 comments on commit c8ad86a

Please sign in to comment.