diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py
index 28b86a9..5430c39 100644
--- a/src/learn_to_pick/base.py
+++ b/src/learn_to_pick/base.py
@@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
     return [parser.parse_line(line) for line in input_str.split("\n")]


-def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
-    to_select_from = {
-        k: inputs[k].value
+def get_based_on(inputs: Dict[str, Any]) -> Dict:
+    return {
+        k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
         for k in inputs.keys()
-        if isinstance(inputs[k], _ToSelectFrom)
+        if isinstance(inputs[k], _BasedOn)
     }

-    if not to_select_from:
-        raise ValueError(
-            "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
-        )

-    based_on = {
-        k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
+def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
+    return {
+        k: inputs[k].value
         for k in inputs.keys()
-        if isinstance(inputs[k], _BasedOn)
+        if isinstance(inputs[k], _ToSelectFrom)
     }

-    return based_on, to_select_from
-

 # end helper functions
diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py
index abf70d6..c85eefa 100644
--- a/src/learn_to_pick/pick_best.py
+++ b/src/learn_to_pick/pick_best.py
@@ -35,36 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]):
     def __init__(
         self,
         inputs: Dict[str, Any],
-        to_select_from: Dict[str, Any],
-        based_on: Dict[str, Any],
         selected: Optional[PickBestSelected] = None,
     ):
         super().__init__(inputs=inputs, selected=selected or PickBestSelected())
-        self.to_select_from = to_select_from
-        self.based_on = based_on
-
-    def context(self, model) -> base.Featurized:
-        return base.embed(self.based_on or {}, model)
-
-    def actions(self, model) -> List[base.Featurized]:
-        to_select_from_var_name, to_select_from = next(
-            iter(self.to_select_from.items()), (None, None)
-        )
-
-        action_embs = (
-            (
-                base.embed(to_select_from, model, to_select_from_var_name)
-                if self.to_select_from
-                else None
+        self.to_select_from = base.get_to_select_from(inputs)
+        self.based_on = base.get_based_on(inputs)
+        if not self.to_select_from:
+            raise ValueError(
+                "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
             )
-            if to_select_from
-            else None
-        )
-        if not action_embs:
+        if len(self.to_select_from) > 1:
             raise ValueError(
-                "Context and to_select_from must be provided in the inputs dictionary"
+                "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
             )
-        return action_embs


 class VwTxt:
@@ -77,9 +60,9 @@ def _sparse_2_str(values: base.SparseFeatures) -> str:
         def _to_str(v):
             import numbers

-            return v if isinstance(v, numbers.Number) else f"={v}"
+            return f":{v}" if isinstance(v, numbers.Number) else f"={v}"

-        return " ".join([f"{k}:{_to_str(v)}" for k, v in values.items()])
+        return " ".join([f"{k}{_to_str(v)}" for k, v in values.items()])

     @staticmethod
     def featurized_2_str(obj: base.Featurized) -> str:
@@ -157,11 +140,29 @@ def _generic_namespaces(context, actions):
         for a in actions:
             a["#"] = PickBestFeaturizer._generic_namespace(a)

+    def get_context_and_actions(
+        self, event
+    ) -> Tuple[base.Featurized, List[base.Featurized]]:
+        context = base.embed(event.based_on or {}, self.model)
+        to_select_from_var_name, to_select_from = next(
+            iter(event.to_select_from.items()), (None, None)
+        )
+
+        actions = (
+            (
+                base.embed(to_select_from, self.model, to_select_from_var_name)
+                if event.to_select_from
+                else None
+            )
+            if to_select_from
+            else None
+        )
+        return context, actions
+
     def featurize(
         self, event: PickBestEvent
     ) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
-        context = event.context(self.model)
-        actions = event.actions(self.model)
+        context, actions = self.get_context_and_actions(event)

         if self.auto_embed:
             self._dotproducts(context, actions)
@@ -224,24 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
     """

     def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
-        context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
-        if not actions:
-            raise ValueError(
-                "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
-            )
-
-        if len(list(actions.values())) > 1:
-            raise ValueError(
-                "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
-            )
-
-        if not context:
-            raise ValueError(
-                "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
-            )
-
-        event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
-        return event
+        return PickBestEvent(inputs=inputs)

     def _call_after_predict_before_scoring(
         self,
diff --git a/tests/unit_tests/test_pick_best_call.py b/tests/unit_tests/test_pick_best_call.py
index d35d4b2..30c7343 100644
--- a/tests/unit_tests/test_pick_best_call.py
+++ b/tests/unit_tests/test_pick_best_call.py
@@ -36,7 +36,7 @@ def test_multiple_ToSelectFrom_throws() -> None:
         )


-def test_missing_basedOn_from_throws() -> None:
+def test_missing_basedOn_from_dont_throw() -> None:
     pick = learn_to_pick.PickBest.create(
         llm=fake_llm_caller,
         featurizer=learn_to_pick.PickBestFeaturizer(
@@ -44,8 +44,7 @@ def test_missing_basedOn_from_throws() -> None:
         ),
     )
     actions = ["0", "1", "2"]
-    with pytest.raises(ValueError):
-        pick.run(action=learn_to_pick.ToSelectFrom(actions))
+    pick.run(action=learn_to_pick.ToSelectFrom(actions))


 def test_ToSelectFrom_not_a_list_throws() -> None:
@@ -169,10 +168,10 @@ def test_everything_embedded() -> None:

     expected = "\n".join(
         [
-            f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}",
-            f"|action_dense {action_dense} |action_sparse default_ft:={str1}",
-            f"|action_dense {action_dense} |action_sparse default_ft:={str2}",
-            f"|action_dense {action_dense} |action_sparse default_ft:={str3}",
+            f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft={ctx_str_1}",
+            f"|action_dense {action_dense} |action_sparse default_ft={str1}",
+            f"|action_dense {action_dense} |action_sparse default_ft={str2}",
+            f"|action_dense {action_dense} |action_sparse default_ft={str3}",
         ]
     )  # noqa

@@ -198,10 +197,10 @@ def test_default_auto_embedder_is_off() -> None:

     expected = "\n".join(
         [
-            f"shared |User_sparse default_ft:={ctx_str_1}",
-            f"|action_sparse default_ft:={str1}",
-            f"|action_sparse default_ft:={str2}",
-            f"|action_sparse default_ft:={str3}",
+            f"shared |User_sparse default_ft={ctx_str_1}",
+            f"|action_sparse default_ft={str1}",
+            f"|action_sparse default_ft={str2}",
+            f"|action_sparse default_ft={str3}",
         ]
     )  # noqa

@@ -227,10 +226,10 @@ def test_default_w_embeddings_off() -> None:

     expected = "\n".join(
         [
-            f"shared |User_sparse default_ft:={ctx_str_1}",
-            f"|action_sparse default_ft:={str1}",
-            f"|action_sparse default_ft:={str2}",
-            f"|action_sparse default_ft:={str3}",
+            f"shared |User_sparse default_ft={ctx_str_1}",
+            f"|action_sparse default_ft={str1}",
+            f"|action_sparse default_ft={str2}",
+            f"|action_sparse default_ft={str3}",
         ]
     )  # noqa

@@ -258,9 +257,9 @@ def test_default_w_embeddings_on() -> None:

     expected = "\n".join(
         [
-            f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}",
-            f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ",
-            f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ",
+            f"shared |User_sparse default_ft={ctx_str_1} |@_sparse User={ctx_str_1}",
+            f"|action_sparse default_ft={str1} |{dot_prod} |#_sparse action={str1} ",
+            f"|action_sparse default_ft={str2} |{dot_prod} |#_sparse action={str2} ",
         ]
     )  # noqa

diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py
index b5aafd8..ea34db8 100644
--- a/tests/unit_tests/test_pick_best_text_embedder.py
+++ b/tests/unit_tests/test_pick_best_text_embedder.py
@@ -4,15 +4,15 @@
 import learn_to_pick.base as rl_chain
 import learn_to_pick.pick_best as pick_best_chain
 from learn_to_pick.pick_best import vw_cb_formatter
+from learn_to_pick.base import BasedOn, ToSelectFrom


 def test_pickbest_textembedder_missing_context_not_throws() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_action = {"action": ["0", "1", "2"]}
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_action, based_on={}
+        inputs={"action": ToSelectFrom(["0", "1", "2"])}
     )
     featurizer.featurize(event)

@@ -21,10 +21,8 @@ def test_pickbest_textembedder_missing_actions_throws() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from={}, based_on={"context": "context"}
-    )
     with pytest.raises(ValueError):
+        event = pick_best_chain.PickBestEvent(inputs={"context": BasedOn("context")})
         featurizer.featurize(event)


@@ -32,18 +30,17 @@ def test_pickbest_textembedder_no_label_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action": ["0", "1", "2"]}
     expected = "\n".join(
         [
-            "shared |context_sparse default_ft:=context",
-            "|action_sparse default_ft:=0",
-            "|action_sparse default_ft:=1",
-            "|action_sparse default_ft:=2",
+            "shared |context_sparse default_ft=context",
+            "|action_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )

     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on={"context": "context"}
+        inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -53,20 +50,17 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action": ["0", "1", "2"]}
     expected = "\n".join(
         [
-            "shared |context_sparse default_ft:=context",
-            "|action_sparse default_ft:=0",
-            "|action_sparse default_ft:=1",
-            "|action_sparse default_ft:=2",
+            "shared |context_sparse default_ft=context",
+            "|action_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={},
-        to_select_from=named_actions,
-        based_on={"context": "context"},
+        inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])},
         selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
@@ -77,21 +71,18 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action": ["0", "1", "2"]}
     expected = "\n".join(
         [
-            "shared |context_sparse default_ft:=context",
-            "0:-0.0:1.0 |action_sparse default_ft:=0",
-            "|action_sparse default_ft:=1",
-            "|action_sparse default_ft:=2",
+            "shared |context_sparse default_ft=context",
+            "0:-0.0:1.0 |action_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )

     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={},
-        to_select_from=named_actions,
-        based_on={"context": "context"},
+        inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])},
         selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
@@ -102,15 +93,8 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
-    ctx_str = "ctx"

     encoded_ctx_str = "0:3.0 1:0.0"
-
-    named_actions = {"action": rl_chain.Embed([str1, str2, str3])}
-    context = {"context": rl_chain.Embed(ctx_str)}
     expected = "\n".join(
         [
             f"shared |context_dense {encoded_ctx_str}",
@@ -121,7 +105,11 @@
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context": rl_chain.Embed(BasedOn("ctx")),
+            "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -131,26 +119,25 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
     ctx_str = "ctx"
     encoded_ctx_str = "0:3.0 1:0.0"

-    named_actions = {"action": rl_chain.EmbedAndKeep([str1, str2, str3])}
-    context = {"context": rl_chain.EmbedAndKeep(ctx_str)}

     expected = "\n".join(
         [
-            f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str}",
-            "0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse default_ft:=0",
-            "|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1",
-            "|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
+            f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str}",
+            "0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse default_ft=0",
+            "|action_dense 0:1.0 1:0.0 |action_sparse default_ft=1",
+            "|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context": rl_chain.EmbedAndKeep(BasedOn("ctx")),
+            "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -160,18 +147,20 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]}
-    context = {"context1": "context1", "context2": "context2"}
     expected = "\n".join(
         [
-            "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2 ",
-            "|a_sparse default_ft:=0 |b_sparse default_ft:=0",
-            "|action1_sparse default_ft:=1",
-            "|action1_sparse default_ft:=2",
+            "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2 ",
+            "|a_sparse default_ft=0 |b_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )  # noqa: E501

     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context
+        inputs={
+            "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
+            "context1": BasedOn("context1"),
+            "context2": BasedOn("context2"),
+        }
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -181,19 +170,22 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]}
-    context = {"context1": "context1", "context2": "context2"}
     expected = "\n".join(
         [
-            "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2",
-            "|a_sparse default_ft:=0 |b_sparse default_ft:=0",
-            "|action_sparse default_ft:=1",
-            "|action_sparse default_ft:=2",
+            "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2",
+            "|a_sparse default_ft=0 |b_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
+            "context1": BasedOn("context1"),
+            "context2": BasedOn("context2"),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -203,19 +195,22 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
     featurizer = pick_best_chain.PickBestFeaturizer(
         auto_embed=False, model=MockEncoder()
     )
-    named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]}
-    context = {"context1": "context1", "context2": "context2"}
     expected = "\n".join(
         [
-            "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2",
-            "0:-0.0:1.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
-            "|action_sparse default_ft:=1",
-            "|action_sparse default_ft:=2",
+            "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2",
+            "0:-0.0:1.0 |a_sparse default_ft=0 |b_sparse default_ft=0",
+            "|action_sparse default_ft=1",
+            "|action_sparse default_ft=2",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]),
+            "context1": BasedOn("context1"),
+            "context2": BasedOn("context2"),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -226,20 +221,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None
         auto_embed=False, model=MockEncoder()
     )

-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
-
     ctx_str_1 = "ctx"
     ctx_str_2 = "ctx_"
     encoded_ctx_str_1 = "0:3.0 1:0.0"
     encoded_ctx_str_2 = "0:4.0 1:0.0"

-    named_actions = {"action": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])}
-    context = {
-        "context1": rl_chain.Embed(ctx_str_1),
-        "context2": rl_chain.Embed(ctx_str_2),
-    }
     expected = "\n".join(
         [
             f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2}",
@@ -251,7 +237,12 @@
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context1": BasedOn(rl_chain.Embed(ctx_str_1)),
+            "context2": BasedOn(rl_chain.Embed(ctx_str_2)),
+            "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -264,34 +255,30 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
         auto_embed=False, model=MockEncoder()
     )

-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
-
     ctx_str_1 = "ctx"
     ctx_str_2 = "ctx_"
     encoded_ctx_str_1 = "0:3.0 1:0.0"
     encoded_ctx_str_2 = "0:4.0 1:0.0"

-    named_actions = {
-        "action": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3])
-    }
-    context = {
-        "context1": rl_chain.EmbedAndKeep(ctx_str_1),
-        "context2": rl_chain.EmbedAndKeep(ctx_str_2),
-    }

     expected = "\n".join(
         [
-            f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}",
-            f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
-            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1",
-            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
+            f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}",
+            f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0 |b_sparse default_ft=0",
+            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=1",
+            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)),
+            "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)),
+            "action": ToSelectFrom(
+                rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])
+            ),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -302,31 +289,29 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N
         auto_embed=False, model=MockEncoder()
     )

-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
-
     ctx_str_1 = "ctx"
     ctx_str_2 = "ctx_"
     encoded_ctx_str_2 = "0:4.0 1:0.0"

-    named_actions = {
-        "action": [{"a": str1, "b": rl_chain.Embed(str1)}, str2, rl_chain.Embed(str3)]
-    }
-    context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)}
-
     expected = "\n".join(
         [
-            f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1}",
-            f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0",
-            f"|action_sparse default_ft:=1",
+            f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1}",
+            f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0",
+            f"|action_sparse default_ft=1",
             f"|action_dense 0:1.0 1:0.0",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context1": BasedOn(ctx_str_1),
+            "context2": BasedOn(rl_chain.Embed(ctx_str_2)),
+            "action": ToSelectFrom(
+                [{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]
+            ),
+        },
+        selected=selected,
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -337,33 +322,32 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
         auto_embed=False, model=MockEncoder()
     )

-    str1 = "0"
-    str2 = "1"
-    str3 = "2"
-
     ctx_str_1 = "ctx"
     ctx_str_2 = "ctx_"
     encoded_ctx_str_2 = "0:4.0 1:0.0"

-    named_actions = {
-        "action": [
-            {"a": str1, "b": rl_chain.EmbedAndKeep(str1)},
-            str2,
-            rl_chain.EmbedAndKeep(str3),
-        ]
-    }
-    context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)}
     expected = "\n".join(
         [
-            f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}",
-            f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
-            f"|action_sparse default_ft:=1",
-            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
+            f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}",
+            f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0 |b_sparse default_ft=0",
+            f"|action_sparse default_ft=1",
+            f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2",
         ]
     )  # noqa: E501
     selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context, selected=selected
+        inputs={
+            "context1": BasedOn(ctx_str_1),
+            "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)),
+            "action": ToSelectFrom(
+                [
+                    {"a": "0", "b": rl_chain.EmbedAndKeep("0")},
+                    "1",
+                    rl_chain.EmbedAndKeep("2"),
+                ]
+            ),
+        },
+        selected=selected,
    )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected)
@@ -382,44 +366,44 @@ def test_raw_features_underscored() -> None:
     encoded_ctx_str = f"0:{float(len(ctx_str))} 1:0.0"

     # No embeddings
-    named_actions = {"action": [str1]}
-    context = {"context": ctx_str}
     expected_no_embed = "\n".join(
         [
-            f"shared |context_sparse default_ft:={ctx_str_underscored}",
-            f"|action_sparse default_ft:={str1_underscored}",
+            f"shared |context_sparse default_ft={ctx_str_underscored}",
+            f"|action_sparse default_ft={str1_underscored}",
         ]
     )

     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context
+        inputs={"action": ToSelectFrom([str1]), "context": BasedOn(ctx_str)}
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected_no_embed)

     # Just embeddings
-    named_actions = {"action": rl_chain.Embed([str1])}
-    context = {"context": rl_chain.Embed(ctx_str)}
     expected_embed = "\n".join(
         [f"shared |context_dense {encoded_ctx_str}", f"|action_dense {encoded_str1}"]
     )

     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context
+        inputs={
+            "action": ToSelectFrom(rl_chain.Embed([str1])),
+            "context": BasedOn(rl_chain.Embed(ctx_str)),
+        }
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected_embed)

     # Embeddings and raw features
-    named_actions = {"action": rl_chain.EmbedAndKeep([str1])}
-    context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
     expected_embed_and_keep = "\n".join(
         [
-            f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str_underscored}",
-            f"|action_dense {encoded_str1} |action_sparse default_ft:={str1_underscored}",
+            f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str_underscored}",
+            f"|action_dense {encoded_str1} |action_sparse default_ft={str1_underscored}",
         ]
     )  # noqa: E501

     event = pick_best_chain.PickBestEvent(
-        inputs={}, to_select_from=named_actions, based_on=context
+        inputs={
+            "action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])),
+            "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)),
+        }
     )
     vw_ex_str = vw_cb_formatter(*featurizer.featurize(event))
     assert_vw_ex_equals(vw_ex_str, expected_embed_and_keep)
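
Reviewer note, not part of the patch: a minimal usage sketch of the refactored flow, assuming only the public names already exercised in the tests above (PickBest.create with an llm= argument, picker.run with keyword inputs, and ToSelectFrom/BasedOn exported at the package level like in test_pick_best_call.py); my_llm_caller is a hypothetical stand-in for whatever caller PickBest.create is given. It illustrates that input validation now happens inside PickBestEvent.__init__: BasedOn variables are optional, while exactly one ToSelectFrom variable is still required.

# Illustrative sketch only -- not part of this diff.
import learn_to_pick

# `my_llm_caller` is a placeholder for the caller normally passed to create().
picker = learn_to_pick.PickBest.create(llm=my_llm_caller)

# BasedOn is now optional (see test_missing_basedOn_from_dont_throw); omitting
# ToSelectFrom, or passing more than one, still raises ValueError from
# PickBestEvent.__init__ rather than from PickBest._call_before_predict.
picker.run(
    meal=learn_to_pick.ToSelectFrom(["pizza", "sushi", "salad"]),
    user=learn_to_pick.BasedOn("vegetarian"),
)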