From 9791d179dda1a176c1d504730cf4fb3fc7be8c73 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Wed, 15 Nov 2023 22:01:58 -0500 Subject: [PATCH 1/8] := -> = --- src/learn_to_pick/pick_best.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index abf70d6..c4883c1 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -77,9 +77,9 @@ def _sparse_2_str(values: base.SparseFeatures) -> str: def _to_str(v): import numbers - return v if isinstance(v, numbers.Number) else f"={v}" + return f":{v}" if isinstance(v, numbers.Number) else f"={v}" - return " ".join([f"{k}:{_to_str(v)}" for k, v in values.items()]) + return " ".join([f"{k}{_to_str(v)}" for k, v in values.items()]) @staticmethod def featurized_2_str(obj: base.Featurized) -> str: From 76e04cb632e09c73322202a2c1060f0f1d304400 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Wed, 15 Nov 2023 22:02:09 -0500 Subject: [PATCH 2/8] tests fix --- tests/unit_tests/test_pick_best_call.py | 30 +++---- .../test_pick_best_text_embedder.py | 86 +++++++++---------- 2 files changed, 58 insertions(+), 58 deletions(-) diff --git a/tests/unit_tests/test_pick_best_call.py b/tests/unit_tests/test_pick_best_call.py index d35d4b2..b283092 100644 --- a/tests/unit_tests/test_pick_best_call.py +++ b/tests/unit_tests/test_pick_best_call.py @@ -169,10 +169,10 @@ def test_everything_embedded() -> None: expected = "\n".join( [ - f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}", - f"|action_dense {action_dense} |action_sparse default_ft:={str1}", - f"|action_dense {action_dense} |action_sparse default_ft:={str2}", - f"|action_dense {action_dense} |action_sparse default_ft:={str3}", + f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft={ctx_str_1}", + f"|action_dense {action_dense} |action_sparse default_ft={str1}", + f"|action_dense {action_dense} |action_sparse default_ft={str2}", + f"|action_dense {action_dense} |action_sparse default_ft={str3}", ] ) # noqa @@ -198,10 +198,10 @@ def test_default_auto_embedder_is_off() -> None: expected = "\n".join( [ - f"shared |User_sparse default_ft:={ctx_str_1}", - f"|action_sparse default_ft:={str1}", - f"|action_sparse default_ft:={str2}", - f"|action_sparse default_ft:={str3}", + f"shared |User_sparse default_ft={ctx_str_1}", + f"|action_sparse default_ft={str1}", + f"|action_sparse default_ft={str2}", + f"|action_sparse default_ft={str3}", ] ) # noqa @@ -227,10 +227,10 @@ def test_default_w_embeddings_off() -> None: expected = "\n".join( [ - f"shared |User_sparse default_ft:={ctx_str_1}", - f"|action_sparse default_ft:={str1}", - f"|action_sparse default_ft:={str2}", - f"|action_sparse default_ft:={str3}", + f"shared |User_sparse default_ft={ctx_str_1}", + f"|action_sparse default_ft={str1}", + f"|action_sparse default_ft={str2}", + f"|action_sparse default_ft={str3}", ] ) # noqa @@ -258,9 +258,9 @@ def test_default_w_embeddings_on() -> None: expected = "\n".join( [ - f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}", - f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ", - f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ", + f"shared |User_sparse default_ft={ctx_str_1} |@_sparse User={ctx_str_1}", + f"|action_sparse default_ft={str1} |{dot_prod} |#_sparse action={str1} ", + f"|action_sparse default_ft={str2} |{dot_prod} |#_sparse action={str2} ", ] ) # noqa diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py index b5aafd8..83f327d 100644 --- a/tests/unit_tests/test_pick_best_text_embedder.py +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -35,10 +35,10 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ - "shared |context_sparse default_ft:=context", - "|action_sparse default_ft:=0", - "|action_sparse default_ft:=1", - "|action_sparse default_ft:=2", + "shared |context_sparse default_ft=context", + "|action_sparse default_ft=0", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) @@ -56,10 +56,10 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ - "shared |context_sparse default_ft:=context", - "|action_sparse default_ft:=0", - "|action_sparse default_ft:=1", - "|action_sparse default_ft:=2", + "shared |context_sparse default_ft=context", + "|action_sparse default_ft=0", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) @@ -80,10 +80,10 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ - "shared |context_sparse default_ft:=context", - "0:-0.0:1.0 |action_sparse default_ft:=0", - "|action_sparse default_ft:=1", - "|action_sparse default_ft:=2", + "shared |context_sparse default_ft=context", + "0:-0.0:1.0 |action_sparse default_ft=0", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) @@ -142,10 +142,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected = "\n".join( [ - f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str}", - "0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse default_ft:=0", - "|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1", - "|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2", + f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str}", + "0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse default_ft=0", + "|action_dense 0:1.0 1:0.0 |action_sparse default_ft=1", + "|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2", ] ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) @@ -164,10 +164,10 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ - "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2 ", - "|a_sparse default_ft:=0 |b_sparse default_ft:=0", - "|action1_sparse default_ft:=1", - "|action1_sparse default_ft:=2", + "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2 ", + "|a_sparse default_ft=0 |b_sparse default_ft=0", + "|action1_sparse default_ft=1", + "|action1_sparse default_ft=2", ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( @@ -185,10 +185,10 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ - "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2", - "|a_sparse default_ft:=0 |b_sparse default_ft:=0", - "|action_sparse default_ft:=1", - "|action_sparse default_ft:=2", + "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", + "|a_sparse default_ft=0 |b_sparse default_ft=0", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) @@ -207,10 +207,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ - "shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2", - "0:-0.0:1.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0", - "|action_sparse default_ft:=1", - "|action_sparse default_ft:=2", + "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", + "0:-0.0:1.0 |a_sparse default_ft=0 |b_sparse default_ft=0", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) @@ -282,10 +282,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee } expected = "\n".join( [ - f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}", - f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0", - f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1", - f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2", + f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", + f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0 |b_sparse default_ft=0", + f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=1", + f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2", ] ) # noqa: E501 @@ -317,9 +317,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N expected = "\n".join( [ - f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1}", - f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0", - f"|action_sparse default_ft:=1", + f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1}", + f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0", + f"|action_sparse default_ft=1", f"|action_dense 0:1.0 1:0.0", ] ) # noqa: E501 @@ -355,10 +355,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)} expected = "\n".join( [ - f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}", - f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0", - f"|action_sparse default_ft:=1", - f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2", + f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", + f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft=0 |b_sparse default_ft=0", + f"|action_sparse default_ft=1", + f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft=2", ] ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) @@ -386,8 +386,8 @@ def test_raw_features_underscored() -> None: context = {"context": ctx_str} expected_no_embed = "\n".join( [ - f"shared |context_sparse default_ft:={ctx_str_underscored}", - f"|action_sparse default_ft:={str1_underscored}", + f"shared |context_sparse default_ft={ctx_str_underscored}", + f"|action_sparse default_ft={str1_underscored}", ] ) @@ -414,8 +414,8 @@ def test_raw_features_underscored() -> None: context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected_embed_and_keep = "\n".join( [ - f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str_underscored}", - f"|action_dense {encoded_str1} |action_sparse default_ft:={str1_underscored}", + f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str_underscored}", + f"|action_dense {encoded_str1} |action_sparse default_ft={str1_underscored}", ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( From 7b34c514f98bb5adc840d63123d49791a3d09e87 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Wed, 15 Nov 2023 22:09:39 -0500 Subject: [PATCH 3/8] get_context_actions to featurized --- src/learn_to_pick/pick_best.py | 47 ++++++++++++++++------------------ 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index c4883c1..5b787d1 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -43,29 +43,6 @@ def __init__( self.to_select_from = to_select_from self.based_on = based_on - def context(self, model) -> base.Featurized: - return base.embed(self.based_on or {}, model) - - def actions(self, model) -> List[base.Featurized]: - to_select_from_var_name, to_select_from = next( - iter(self.to_select_from.items()), (None, None) - ) - - action_embs = ( - ( - base.embed(to_select_from, model, to_select_from_var_name) - if self.to_select_from - else None - ) - if to_select_from - else None - ) - if not action_embs: - raise ValueError( - "Context and to_select_from must be provided in the inputs dictionary" - ) - return action_embs - class VwTxt: @staticmethod @@ -157,11 +134,31 @@ def _generic_namespaces(context, actions): for a in actions: a["#"] = PickBestFeaturizer._generic_namespace(a) + def get_context_actions(self, event) -> Tuple[base.Featurized, List[base.Featurized]]: + context = base.embed(event.based_on or {}, self.model) + to_select_from_var_name, to_select_from = next( + iter(event.to_select_from.items()), (None, None) + ) + + actions = ( + ( + base.embed(to_select_from, self.model, to_select_from_var_name) + if event.to_select_from + else None + ) + if to_select_from + else None + ) + if not actions: + raise ValueError( + "Context and to_select_from must be provided in the inputs dictionary" + ) + return context, actions + def featurize( self, event: PickBestEvent ) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]: - context = event.context(self.model) - actions = event.actions(self.model) + context, actions = self.get_context_actions(event) if self.auto_embed: self._dotproducts(context, actions) From 9f39ea9ba4ac636a3c3e06bacb22063e9b729213 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Wed, 15 Nov 2023 22:11:48 -0500 Subject: [PATCH 4/8] black --- src/learn_to_pick/pick_best.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 5b787d1..8ed4b36 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -134,7 +134,9 @@ def _generic_namespaces(context, actions): for a in actions: a["#"] = PickBestFeaturizer._generic_namespace(a) - def get_context_actions(self, event) -> Tuple[base.Featurized, List[base.Featurized]]: + def get_context_actions( + self, event + ) -> Tuple[base.Featurized, List[base.Featurized]]: context = base.embed(event.based_on or {}, self.model) to_select_from_var_name, to_select_from = next( iter(event.to_select_from.items()), (None, None) From ca7f469a4bd754f6d942f0117a1cde46ff500c72 Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Wed, 15 Nov 2023 23:56:27 -0500 Subject: [PATCH 5/8] exceptions in one place --- src/learn_to_pick/base.py | 21 ++++++++------------- src/learn_to_pick/pick_best.py | 20 ++++++++------------ tests/unit_tests/test_pick_best_call.py | 5 ++--- 3 files changed, 18 insertions(+), 28 deletions(-) diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index 28b86a9..5430c39 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam return [parser.parse_line(line) for line in input_str.split("\n")] -def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]: - to_select_from = { - k: inputs[k].value +def get_based_on(inputs: Dict[str, Any]) -> Dict: + return { + k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value for k in inputs.keys() - if isinstance(inputs[k], _ToSelectFrom) + if isinstance(inputs[k], _BasedOn) } - if not to_select_from: - raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." - ) - based_on = { - k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value +def get_to_select_from(inputs: Dict[str, Any]) -> Dict: + return { + k: inputs[k].value for k in inputs.keys() - if isinstance(inputs[k], _BasedOn) + if isinstance(inputs[k], _ToSelectFrom) } - return based_on, to_select_from - # end helper functions diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 8ed4b36..3c5d727 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -223,24 +223,20 @@ class PickBest(base.RLLoop[PickBestEvent]): """ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: - context, actions = base.get_based_on_and_to_select_from(inputs=inputs) - if not actions: + to_select_from = base.get_to_select_from(inputs) + if not to_select_from: raise ValueError( "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." ) - - if len(list(actions.values())) > 1: + if len(to_select_from) > 1: raise ValueError( "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." ) - - if not context: - raise ValueError( - "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on." - ) - - event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context) - return event + return PickBestEvent( + inputs=inputs, + to_select_from=to_select_from, + based_on=base.get_based_on(inputs), + ) def _call_after_predict_before_scoring( self, diff --git a/tests/unit_tests/test_pick_best_call.py b/tests/unit_tests/test_pick_best_call.py index b283092..30c7343 100644 --- a/tests/unit_tests/test_pick_best_call.py +++ b/tests/unit_tests/test_pick_best_call.py @@ -36,7 +36,7 @@ def test_multiple_ToSelectFrom_throws() -> None: ) -def test_missing_basedOn_from_throws() -> None: +def test_missing_basedOn_from_dont_throw() -> None: pick = learn_to_pick.PickBest.create( llm=fake_llm_caller, featurizer=learn_to_pick.PickBestFeaturizer( @@ -44,8 +44,7 @@ def test_missing_basedOn_from_throws() -> None: ), ) actions = ["0", "1", "2"] - with pytest.raises(ValueError): - pick.run(action=learn_to_pick.ToSelectFrom(actions)) + pick.run(action=learn_to_pick.ToSelectFrom(actions)) def test_ToSelectFrom_not_a_list_throws() -> None: From 86f382761e9863236fdcf9c2f3a5add01ca80d8f Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Thu, 16 Nov 2023 01:41:05 -0500 Subject: [PATCH 6/8] all cb actions checks to PickBestEvent ctr --- src/learn_to_pick/pick_best.py | 33 ++-- .../test_pick_best_text_embedder.py | 167 ++++++++---------- 2 files changed, 81 insertions(+), 119 deletions(-) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 3c5d727..df91615 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -35,13 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]): def __init__( self, inputs: Dict[str, Any], - to_select_from: Dict[str, Any], - based_on: Dict[str, Any], selected: Optional[PickBestSelected] = None, ): super().__init__(inputs=inputs, selected=selected or PickBestSelected()) - self.to_select_from = to_select_from - self.based_on = based_on + self.to_select_from = base.get_to_select_from(inputs) + self.based_on = base.get_based_on(inputs) + if not self.to_select_from: + raise ValueError( + "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." + ) + if len(self.to_select_from) > 1: + raise ValueError( + "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." + ) class VwTxt: @@ -151,10 +157,6 @@ def get_context_actions( if to_select_from else None ) - if not actions: - raise ValueError( - "Context and to_select_from must be provided in the inputs dictionary" - ) return context, actions def featurize( @@ -223,20 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]): """ def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent: - to_select_from = base.get_to_select_from(inputs) - if not to_select_from: - raise ValueError( - "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from." - ) - if len(to_select_from) > 1: - raise ValueError( - "Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from." - ) - return PickBestEvent( - inputs=inputs, - to_select_from=to_select_from, - based_on=base.get_based_on(inputs), - ) + return PickBestEvent(inputs=inputs) def _call_after_predict_before_scoring( self, diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py index 83f327d..02fc114 100644 --- a/tests/unit_tests/test_pick_best_text_embedder.py +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -4,16 +4,14 @@ import learn_to_pick.base as rl_chain import learn_to_pick.pick_best as pick_best_chain from learn_to_pick.pick_best import vw_cb_formatter +from learn_to_pick.base import BasedOn, ToSelectFrom def test_pickbest_textembedder_missing_context_not_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_action = {"action": ["0", "1", "2"]} - event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_action, based_on={} - ) + event = pick_best_chain.PickBestEvent(inputs={"action": ToSelectFrom(["0", "1", "2"])}) featurizer.featurize(event) @@ -21,10 +19,8 @@ def test_pickbest_textembedder_missing_actions_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from={}, based_on={"context": "context"} - ) with pytest.raises(ValueError): + event = pick_best_chain.PickBestEvent(inputs={"context": BasedOn("context")}) featurizer.featurize(event) @@ -32,7 +28,6 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -43,7 +38,7 @@ def test_pickbest_textembedder_no_label_no_emb() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on={"context": "context"} + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])} ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -53,7 +48,6 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -64,10 +58,8 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: ) selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( - inputs={}, - to_select_from=named_actions, - based_on={"context": "context"}, - selected=selected, + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -77,7 +69,6 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": ["0", "1", "2"]} expected = "\n".join( [ "shared |context_sparse default_ft=context", @@ -89,9 +80,7 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None: selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, - to_select_from=named_actions, - based_on={"context": "context"}, + inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) @@ -102,15 +91,8 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str = "ctx" encoded_ctx_str = "0:3.0 1:0.0" - - named_actions = {"action": rl_chain.Embed([str1, str2, str3])} - context = {"context": rl_chain.Embed(ctx_str)} expected = "\n".join( [ f"shared |context_dense {encoded_ctx_str}", @@ -121,7 +103,10 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context": rl_chain.Embed(BasedOn("ctx")), + "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])) + }, selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -131,15 +116,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" ctx_str = "ctx" encoded_ctx_str = "0:3.0 1:0.0" - named_actions = {"action": rl_chain.EmbedAndKeep([str1, str2, str3])} - context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected = "\n".join( [ f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str}", @@ -150,7 +130,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context": rl_chain.EmbedAndKeep(BasedOn("ctx")), + "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])) + }, selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -160,18 +143,20 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action1": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2 ", "|a_sparse default_ft=0 |b_sparse default_ft=0", - "|action1_sparse default_ft=1", - "|action1_sparse default_ft=2", + "|action_sparse default_ft=1", + "|action_sparse default_ft=2", ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -181,8 +166,6 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", @@ -193,7 +176,12 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -203,8 +191,6 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - named_actions = {"action": [{"a": "0", "b": "0"}, "1", "2"]} - context = {"context1": "context1", "context2": "context2"} expected = "\n".join( [ "shared |context1_sparse default_ft=context1 |context2_sparse default_ft=context2", @@ -215,7 +201,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), + "context1": BasedOn("context1"), + "context2": BasedOn("context2") + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -226,20 +217,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_1 = "0:3.0 1:0.0" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = {"action": rl_chain.Embed([{"a": str1, "b": str1}, str2, str3])} - context = { - "context1": rl_chain.Embed(ctx_str_1), - "context2": rl_chain.Embed(ctx_str_2), - } expected = "\n".join( [ f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2}", @@ -251,7 +233,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(rl_chain.Embed(ctx_str_1)), + "context2": BasedOn(rl_chain.Embed(ctx_str_2)), + "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -264,22 +251,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_1 = "0:3.0 1:0.0" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": rl_chain.EmbedAndKeep([{"a": str1, "b": str1}, str2, str3]) - } - context = { - "context1": rl_chain.EmbedAndKeep(ctx_str_1), - "context2": rl_chain.EmbedAndKeep(ctx_str_2), - } expected = "\n".join( [ f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", @@ -291,7 +267,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)), + "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), + "action": ToSelectFrom(rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -302,19 +283,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": [{"a": str1, "b": rl_chain.Embed(str1)}, str2, rl_chain.Embed(str3)] - } - context = {"context1": ctx_str_1, "context2": rl_chain.Embed(ctx_str_2)} - expected = "\n".join( [ f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1}", @@ -326,7 +298,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(ctx_str_1), + "context2": BasedOn(rl_chain.Embed(ctx_str_2)), + "action": ToSelectFrom([{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -337,22 +314,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() auto_embed=False, model=MockEncoder() ) - str1 = "0" - str2 = "1" - str3 = "2" - ctx_str_1 = "ctx" ctx_str_2 = "ctx_" encoded_ctx_str_2 = "0:4.0 1:0.0" - named_actions = { - "action": [ - {"a": str1, "b": rl_chain.EmbedAndKeep(str1)}, - str2, - rl_chain.EmbedAndKeep(str3), - ] - } - context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)} expected = "\n".join( [ f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft={ctx_str_1} |context2_sparse default_ft={ctx_str_2}", @@ -363,7 +328,12 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() ) # noqa: E501 selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context, selected=selected + inputs={ + "context1": BasedOn(ctx_str_1), + "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), + "action": ToSelectFrom([{"a": "0", "b": rl_chain.EmbedAndKeep("0")}, "1", rl_chain.EmbedAndKeep("2")]) + }, + selected=selected ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -382,8 +352,6 @@ def test_raw_features_underscored() -> None: encoded_ctx_str = f"0:{float(len(ctx_str))} 1:0.0" # No embeddings - named_actions = {"action": [str1]} - context = {"context": ctx_str} expected_no_embed = "\n".join( [ f"shared |context_sparse default_ft={ctx_str_underscored}", @@ -392,26 +360,28 @@ def test_raw_features_underscored() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom([str1]), + "context": BasedOn(ctx_str) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_no_embed) # Just embeddings - named_actions = {"action": rl_chain.Embed([str1])} - context = {"context": rl_chain.Embed(ctx_str)} expected_embed = "\n".join( [f"shared |context_dense {encoded_ctx_str}", f"|action_dense {encoded_str1}"] ) event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom(rl_chain.Embed([str1])), + "context": BasedOn(rl_chain.Embed(ctx_str)) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_embed) # Embeddings and raw features - named_actions = {"action": rl_chain.EmbedAndKeep([str1])} - context = {"context": rl_chain.EmbedAndKeep(ctx_str)} expected_embed_and_keep = "\n".join( [ f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft={ctx_str_underscored}", @@ -419,7 +389,10 @@ def test_raw_features_underscored() -> None: ] ) # noqa: E501 event = pick_best_chain.PickBestEvent( - inputs={}, to_select_from=named_actions, based_on=context + inputs={ + "action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])), + "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)) + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_embed_and_keep) From 9643dff294c66e6f60b630ccd6341f8b80e62f6f Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Thu, 16 Nov 2023 01:42:06 -0500 Subject: [PATCH 7/8] black --- .../test_pick_best_text_embedder.py | 67 +++++++++++-------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/tests/unit_tests/test_pick_best_text_embedder.py b/tests/unit_tests/test_pick_best_text_embedder.py index 02fc114..ea34db8 100644 --- a/tests/unit_tests/test_pick_best_text_embedder.py +++ b/tests/unit_tests/test_pick_best_text_embedder.py @@ -11,7 +11,9 @@ def test_pickbest_textembedder_missing_context_not_throws() -> None: featurizer = pick_best_chain.PickBestFeaturizer( auto_embed=False, model=MockEncoder() ) - event = pick_best_chain.PickBestEvent(inputs={"action": ToSelectFrom(["0", "1", "2"])}) + event = pick_best_chain.PickBestEvent( + inputs={"action": ToSelectFrom(["0", "1", "2"])} + ) featurizer.featurize(event) @@ -59,7 +61,7 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None: selected = pick_best_chain.PickBestSelected(index=0, probability=1.0) event = pick_best_chain.PickBestEvent( inputs={"context": BasedOn("context"), "action": ToSelectFrom(["0", "1", "2"])}, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -105,8 +107,9 @@ def test_pickbest_textembedder_w_full_label_w_emb() -> None: event = pick_best_chain.PickBestEvent( inputs={ "context": rl_chain.Embed(BasedOn("ctx")), - "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])) - }, selected=selected + "action": rl_chain.Embed(ToSelectFrom(["0", "1", "2"])), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -132,8 +135,9 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None: event = pick_best_chain.PickBestEvent( inputs={ "context": rl_chain.EmbedAndKeep(BasedOn("ctx")), - "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])) - }, selected=selected + "action": rl_chain.EmbedAndKeep(ToSelectFrom(["0", "1", "2"])), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -155,8 +159,8 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - } + "context2": BasedOn("context2"), + } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -179,9 +183,9 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - }, - selected=selected + "context2": BasedOn("context2"), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -204,9 +208,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None: inputs={ "action": ToSelectFrom([{"a": "0", "b": "0"}, "1", "2"]), "context1": BasedOn("context1"), - "context2": BasedOn("context2") - }, - selected=selected + "context2": BasedOn("context2"), + }, + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -236,9 +240,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_emb() -> None inputs={ "context1": BasedOn(rl_chain.Embed(ctx_str_1)), "context2": BasedOn(rl_chain.Embed(ctx_str_2)), - "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])) + "action": ToSelectFrom(rl_chain.Embed([{"a": "0", "b": "0"}, "1", "2"])), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -270,9 +274,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee inputs={ "context1": BasedOn(rl_chain.EmbedAndKeep(ctx_str_1)), "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), - "action": ToSelectFrom(rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"])) + "action": ToSelectFrom( + rl_chain.EmbedAndKeep([{"a": "0", "b": "0"}, "1", "2"]) + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -301,9 +307,11 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N inputs={ "context1": BasedOn(ctx_str_1), "context2": BasedOn(rl_chain.Embed(ctx_str_2)), - "action": ToSelectFrom([{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")]) + "action": ToSelectFrom( + [{"a": "0", "b": rl_chain.Embed("0")}, "1", rl_chain.Embed("2")] + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -331,9 +339,15 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep() inputs={ "context1": BasedOn(ctx_str_1), "context2": BasedOn(rl_chain.EmbedAndKeep(ctx_str_2)), - "action": ToSelectFrom([{"a": "0", "b": rl_chain.EmbedAndKeep("0")}, "1", rl_chain.EmbedAndKeep("2")]) + "action": ToSelectFrom( + [ + {"a": "0", "b": rl_chain.EmbedAndKeep("0")}, + "1", + rl_chain.EmbedAndKeep("2"), + ] + ), }, - selected=selected + selected=selected, ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected) @@ -360,10 +374,7 @@ def test_raw_features_underscored() -> None: ) event = pick_best_chain.PickBestEvent( - inputs={ - "action": ToSelectFrom([str1]), - "context": BasedOn(ctx_str) - } + inputs={"action": ToSelectFrom([str1]), "context": BasedOn(ctx_str)} ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) assert_vw_ex_equals(vw_ex_str, expected_no_embed) @@ -375,7 +386,7 @@ def test_raw_features_underscored() -> None: event = pick_best_chain.PickBestEvent( inputs={ "action": ToSelectFrom(rl_chain.Embed([str1])), - "context": BasedOn(rl_chain.Embed(ctx_str)) + "context": BasedOn(rl_chain.Embed(ctx_str)), } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) @@ -391,7 +402,7 @@ def test_raw_features_underscored() -> None: event = pick_best_chain.PickBestEvent( inputs={ "action": ToSelectFrom(rl_chain.EmbedAndKeep([str1])), - "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)) + "context": BasedOn(rl_chain.EmbedAndKeep(ctx_str)), } ) vw_ex_str = vw_cb_formatter(*featurizer.featurize(event)) From 2f6d21e2a1d065adfb0e842d58a4120c092543ed Mon Sep 17 00:00:00 2001 From: Alexey Taymanov Date: Thu, 16 Nov 2023 01:54:36 -0500 Subject: [PATCH 8/8] naming cleanup --- src/learn_to_pick/pick_best.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index df91615..c85eefa 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -140,7 +140,7 @@ def _generic_namespaces(context, actions): for a in actions: a["#"] = PickBestFeaturizer._generic_namespace(a) - def get_context_actions( + def get_context_and_actions( self, event ) -> Tuple[base.Featurized, List[base.Featurized]]: context = base.embed(event.based_on or {}, self.model) @@ -162,7 +162,7 @@ def get_context_actions( def featurize( self, event: PickBestEvent ) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]: - context, actions = self.get_context_actions(event) + context, actions = self.get_context_and_actions(event) if self.auto_embed: self._dotproducts(context, actions)