diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 3c5d727..df91615 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -35,13 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]): def __init__( self, inputs: Dict[str, Any], - to_select_from: Dict[str, Any], - based_on: Dict[str, Any], selected: Optional[PickBestSelected] = None, ): super().__init__(inputs=inputs, selected=selected or PickBestSelected()) - self.to_select_from = to_select_from - self.based_on = based_on + self.to_select_from = base.get_to_select_from(inputs) + self.based_on = base.get_based_on(inputs) + if not self.to_select_from: + raise ValueError( + "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." + ) + if len(self.to_select_from) > 1: + raise ValueError( + "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." + ) class VwTxt: @@ -151,10 +157,6 @@ def get_context_actions( if to_select_from else None ) - if not actions: - raise ValueError( - "Context and to_select_from must be provided in the inputs dictionary" - ) return context, actions def featurize( @@ -223,20 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]): """ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: - to_select_from = base.get_to_select_from(inputs) - if not to_select_from: - raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." - ) - if len(to_select_from) > 1: - raise ValueError( - "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." - ) - return PickBestEvent( - inputs=inputs, - to_select_from=to_select_from, - based_on=base.get_based_on(inputs), - ) + return PickBestEvent(inputs=inputs) def _call_after_predict_before_scoring( self, diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py index 83f327d..02fc114 100644 --- a/tests/unit_tests/test_pick_best_text_embedder.py +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -4,16 +4,14 @@ import learn_to_pick.base as rl_chain import learn_to_pick.pick_best as pick_best_chain from learn_to_pick.pick_best import vw_cb_formatter +from learn_to_pick.base import BasedOn, ToSelectFrom def test_pickbest_textembedder_missing_context_not_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_action = {"action": ["0", "1", "2"]} - event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_action, based_on={} - ) + event = pick_best_chain.PickBestEvent(inputs={"action": ToSelectFrom(["0", "1", "2"])}) featurizer.featurize(event) @@ -21,10 +19,8 @@ def test_pickbest_textembedder_missing_actions_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from={}, based_on={"context": "context"} - ) with pytest.raises(ValueError): + event = pick_best_chain.PickBestEvent(inputs={"context": BasedOn("context")}) featurizer.featurize(event) @@ -32,7 +28,6 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -43,7 +38,7 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on={"context": "context"} + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])} ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -53,7 +48,6 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -64,10 +58,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: ) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( - inputs={}, - to_select_from=named_actions, - based_on={"context": "context"}, - selected=selected, + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -77,7 +69,6 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -89,9 +80,7 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, - to_select_from=named_actions, - based_on={"context": "context"}, + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) @@ -102,15 +91,8 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str = "ctx" encoded_ctx_str = "0:3.0 1:0.0" - - named_actions = {"action": rl_chain.Embed([str1, str2, str3])} - context = {"context": rl_chain.Embed(ctx_str)} expected = "\n".join( [ f"shared |context_dense {encoded_ctx_str}", @@ -121,7 +103,10 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context": rl_chain.Embed(BasedOn("ctx")), + "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])) + }, selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -131,15 +116,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" ctx_str = "ctx" encoded_ctx_str = "0:3.0 1:0.0" - named_actions = {"action": rl_chain.EmbedAndKeep([str1, str2, str3])} - context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected = "\n".join( [ f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str}", @@ -150,7 +130,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context": rl_chain.EmbedAndKeep(BasedOn("ctx")), + "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])) + }, selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -160,18 +143,20 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2 ", "|a_sparse default_ft=0 |b_sparse default_ft=0", - "|action1_sparse default_ft=1", - "|action1_sparse default_ft=2", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -181,8 +166,6 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", @@ -193,7 +176,12 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -203,8 +191,6 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", @@ -215,7 +201,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -226,20 +217,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_1 = "0:3.0 1:0.0" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = {"action": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])} - context = { - "context1": rl_chain.Embed(ctx_str_1), - "context2": rl_chain.Embed(ctx_str_2), - } expected = "\n".join( [ f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2}", @@ -251,7 +233,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(rl_chain.Embed(ctx_str_1)), + "context2": BasedOn(rl_chain.Embed(ctx_str_2)), + "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -264,22 +251,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_1 = "0:3.0 1:0.0" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3]) - } - context = { - "context1": rl_chain.EmbedAndKeep(ctx_str_1), - "context2": rl_chain.EmbedAndKeep(ctx_str_2), - } expected = "\n".join( [ f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", @@ -291,7 +267,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)), + "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), + "action": ToSelectFrom(rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -302,19 +283,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": [{"a": str1, "b": rl_chain.Embed(str1)}, str2, rl_chain.Embed(str3)] - } - context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} - expected = "\n".join( [ f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1}", @@ -326,7 +298,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(ctx_str_1), + "context2": BasedOn(rl_chain.Embed(ctx_str_2)), + "action": ToSelectFrom([{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -337,22 +314,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": [ - {"a": str1, "b": rl_chain.EmbedAndKeep(str1)}, - str2, - rl_chain.EmbedAndKeep(str3), - ] - } - context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)} expected = "\n".join( [ f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", @@ -363,7 +328,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(ctx_str_1), + "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), + "action": ToSelectFrom([{"a": "0", "b": rl_chain.EmbedAndKeep("0")}, "1", rl_chain.EmbedAndKeep("2")]) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -382,8 +352,6 @@ def test_raw_features_underscored() -> None: encoded_ctx_str = f"0:{float(len(ctx_str))} 1:0.0" # No embeddings - named_actions = {"action": [str1]} - context = {"context": ctx_str} expected_no_embed = "\n".join( [ f"shared |context_sparse default_ft={ctx_str_underscored}", @@ -392,26 +360,28 @@ def test_raw_features_underscored() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom([str1]), + "context": BasedOn(ctx_str) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_no_embed) # Just embeddings - named_actions = {"action": rl_chain.Embed([str1])} - context = {"context": rl_chain.Embed(ctx_str)} expected_embed = "\n".join( [f"shared |context_dense {encoded_ctx_str}", f"|action_dense {encoded_str1}"] ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom(rl_chain.Embed([str1])), + "context": BasedOn(rl_chain.Embed(ctx_str)) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_embed) # Embeddings and raw features - named_actions = {"action": rl_chain.EmbedAndKeep([str1])} - context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected_embed_and_keep = "\n".join( [ f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str_underscored}", @@ -419,7 +389,10 @@ def test_raw_features_underscored() -> None: ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])), + "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_embed_and_keep)