From 7900b5cdf65b89351185e5e729db9971fb96331d Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 2 Nov 2023 14:03:32 -0400 Subject: [PATCH 1/3] don't replace to_select_from options in input, add picked to AutoScoring input, escape newlines --- src/learn_to_pick/base.py | 15 ++++++++++----- src/learn_to_pick/pick_best.py | 3 --- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 595c7cd..3688dcf 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -282,6 +282,7 @@ def format_with_ignoring_extra_args(prompt, inputs): def score_response( self, inputs: Dict[str, Any], picked: Any, event: Event ) -> float: + inputs.update({"picked": picked}) p = AutoSelectionScorer.format_with_ignoring_extra_args(self.prompt, inputs) ranking = self.llm.predict(p) ranking = ranking.strip() @@ -440,14 +441,14 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: if self.metrics: self.metrics.on_decision() - next_chain_inputs, picked, event = self._call_after_predict_before_scoring( + next_inputs, picked, event = self._call_after_predict_before_scoring( inputs=inputs, event=event, prediction=prediction ) for callback_func in self.callbacks_before_scoring: try: - next_chain_inputs, event = callback_func( - inputs=next_chain_inputs, picked=picked, event=event + next_inputs, event = callback_func( + inputs=next_inputs, picked=picked, event=event ) except Exception as e: logger.info(f"Callback function {callback_func} failed, error: {e}") @@ -456,7 +457,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: try: if self._can_use_selection_scorer(): score = self.selection_scorer.score_response( - inputs=next_chain_inputs, picked=picked, event=event + inputs=next_inputs, picked=picked, event=event ) except Exception as e: logger.info( @@ -471,7 +472,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: self.policy.learn(event=event) self.policy.log(event=event) - event.outputs = next_chain_inputs + event.outputs = next_inputs return {"picked": picked, "picked_metadata": event} @@ -479,13 +480,17 @@ def _embed_string_type( item: Union[str, _Embed], model: Any, namespace: Optional[str] = None ) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a string or an _Embed object.""" + import re keep_str = "" if isinstance(item, _Embed): encoded = _stringify_embedding(model.encode(item.value)) + # TODO these should be moved to pick_best if item.keep: keep_str = item.value.replace(" ", "_") + " " + keep_str = re.sub(r"[\t\n\r\f\v]+", " ", keep_str) elif isinstance(item, str): encoded = item.replace(" ", "_") + encoded = re.sub(r"[\t\n\r\f\v]+", " ", encoded) else: raise ValueError(f"Unsupported type {type(item)} for embedding") diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index b80790e..89bbb97 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -297,10 +297,7 @@ def _call_after_predict_before_scoring( selected = PickBestSelected(index=sampled_action, probability=sampled_prob) event.selected = selected - # only one key, value pair in event.to_select_from - key, value = next(iter(event.to_select_from.items())) next_inputs = inputs.copy() - next_inputs[key] = value[event.selected.index] # only one key, value pair in event.to_select_from value = next(iter(event.to_select_from.values())) From ad6551028b748b4ba945e87374e26683ba2091c8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 2 Nov 2023 20:16:10 -0400 Subject: [PATCH 2/3] fix names and add input validation --- src/learn_to_pick/base.py | 19 ++++++++++++++----- src/learn_to_pick/pick_best.py | 6 ++++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 3688dcf..c6c736d 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -259,9 +259,9 @@ def get_default_system_prompt() -> str: @staticmethod def get_default_prompt() -> str: - human_template = """Given this based_on "{pick_best_selected_based_on}" \ + human_template = """Given this based_on "{selected_based_on}" \ as the most important attribute, rank how good or bad this text is: \ - "{pick_best_selected}".""" + "{picked}".""" default_system_prompt = AutoSelectionScorer.get_default_system_prompt() return default_system_prompt + human_template @@ -282,7 +282,6 @@ def format_with_ignoring_extra_args(prompt, inputs): def score_response( self, inputs: Dict[str, Any], picked: Any, event: Event ) -> float: - inputs.update({"picked": picked}) p = AutoSelectionScorer.format_with_ignoring_extra_args(self.prompt, inputs) ranking = self.llm.predict(p) ranking = ranking.strip() @@ -326,8 +325,8 @@ class RLLoop(Generic[TEvent]): """ # Define the default values as class attributes - selected_input_key = "pick_best_selected" - selected_based_on_input_key = "pick_best_selected_based_on" + selected_based_on_input_key = "selected_based_on" + selected_input_key = "picked" def __init__( self, @@ -435,6 +434,16 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: raise ValueError( "Either a dictionary positional argument or keyword arguments should be provided" ) + + if self.selected_based_on_input_key in inputs: + raise ValueError( + f"The input key {self.selected_based_on_input_key} is reserved. Please use a different key." + ) + + if self.selected_input_key in inputs: + raise ValueError( + f"The input key {self.selected_input_key} is reserved. Please use a different key." + ) event: TEvent = self._call_before_predict(inputs=inputs) prediction = self.policy.predict(event=event) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 89bbb97..e0b53fc 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -306,12 +306,14 @@ def _call_after_predict_before_scoring( if event.selected else event.to_select_from.values() ) - next_inputs[self.selected_based_on_input_key] = str(event.based_on) - next_inputs[self.selected_input_key] = v + picked = {} for k, v in event.to_select_from.items(): picked[k] = v[event.selected.index] + next_inputs[self.selected_based_on_input_key] = str(event.based_on) + next_inputs[self.selected_input_key] = str(picked) + return next_inputs, picked, event def _call_after_scoring_before_learning( From a62686f681157e5ed1d3fc7ef8bd4798098286aa Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 2 Nov 2023 20:33:35 -0400 Subject: [PATCH 3/3] black formatting --- src/learn_to_pick/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index c6c736d..6612b75 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -434,12 +434,12 @@ def run(self, *args, **kwargs) -> Dict[str, Any]: raise ValueError( "Either a dictionary positional argument or keyword arguments should be provided" ) - + if self.selected_based_on_input_key in inputs: raise ValueError( f"The input key {self.selected_based_on_input_key} is reserved. Please use a different key." ) - + if self.selected_input_key in inputs: raise ValueError( f"The input key {self.selected_input_key} is reserved. Please use a different key." @@ -490,6 +490,7 @@ def _embed_string_type( ) -> Dict[str, Union[str, List[str]]]: """Helper function to embed a string or an _Embed object.""" import re + keep_str = "" if isinstance(item, _Embed): encoded = _stringify_embedding(model.encode(item.value))