Skip to content

Commit

Permalink
raw -> default_ft
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 15, 2023
1 parent 36fd0aa commit c333c0a
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 107 deletions.
4 changes: 2 additions & 2 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,10 @@ def _embed_string_type(
result[namespace] = DenseFeatures(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_")
result[namespace] = {"raw": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
elif isinstance(item, str):
encoded = item.replace(" ", "_")
result[namespace] = {"raw": re.sub(r"[\t\n\r\f\v]+", " ", encoded)}
result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", encoded)}
else:
raise ValueError(f"Unsupported type {type(item)} for embedding")

Expand Down
12 changes: 6 additions & 6 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def __init__(
def _dotproducts(self, context, actions):
_context_dense = base.Featurized()
for ns in context.sparse.keys():
if "raw" in context.sparse[ns]:
_context_dense[ns] = self.model.encode(context.sparse[ns]["raw"])
if "default_ft" in context.sparse[ns]:
_context_dense[ns] = self.model.encode(context.sparse[ns]["default_ft"])

_actions_dense = [base.Featurized() for _ in range(len(actions))]
for _action, action in zip(_actions_dense, actions):
for ns in action.sparse.keys():
if "raw" in action.sparse[ns]:
_action[ns] = self.model.encode(action.sparse[ns]["raw"])
if "default_ft" in action.sparse[ns]:
_action[ns] = self.model.encode(action.sparse[ns]["default_ft"])

context_names = list(_context_dense.dense.keys())
context_matrix = np.stack(list(_context_dense.dense.values()))
Expand All @@ -146,8 +146,8 @@ def _dotproducts(self, context, actions):
def _generic_namespace(self, featurized):
result = base.SparseFeatures()
for ns in featurized.sparse.keys():
if "raw" in featurized.sparse[ns]:
result[ns] = featurized.sparse[ns]["raw"]
if "default_ft" in featurized.sparse[ns]:
result[ns] = featurized.sparse[ns]["default_ft"]
return result

def _generic_namespaces(self, context, actions):
Expand Down
30 changes: 15 additions & 15 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,10 @@ def test_everything_embedded() -> None:

expected = "\n".join(
[
f"shared |User_dense {encoded_ctx_str_1} |User_sparse raw:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse raw:={str1}",
f"|action_dense {action_dense} |action_sparse raw:={str2}",
f"|action_dense {action_dense} |action_sparse raw:={str3}",
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str2}",
f"|action_dense {action_dense} |action_sparse default_ft:={str3}",
]
) # noqa

Expand All @@ -198,10 +198,10 @@ def test_default_auto_embedder_is_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}",
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
]
) # noqa

Expand All @@ -227,10 +227,10 @@ def test_default_w_embeddings_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}",
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
]
) # noqa

Expand Down Expand Up @@ -258,9 +258,9 @@ def test_default_w_embeddings_on() -> None:

expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse raw:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse raw:={str2} |{dot_prod} |#_sparse action:={str2} ",
f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ",
]
) # noqa

Expand Down
86 changes: 43 additions & 43 deletions tests/unit_tests/test_pick_best_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def test_pickbest_textembedder_no_label_no_emb() -> None:
named_actions = {"action": ["0", "1", "2"]}
expected = "\n".join(
[
"shared |context_sparse raw:=context",
"|action_sparse raw:=0",
"|action_sparse raw:=1",
"|action_sparse raw:=2",
"shared |context_sparse default_ft:=context",
"|action_sparse default_ft:=0",
"|action_sparse default_ft:=1",
"|action_sparse default_ft:=2",
]
)

Expand All @@ -56,10 +56,10 @@ def test_pickbest_textembedder_w_label_no_score_no_emb() -> None:
named_actions = {"action": ["0", "1", "2"]}
expected = "\n".join(
[
"shared |context_sparse raw:=context",
"|action_sparse raw:=0",
"|action_sparse raw:=1",
"|action_sparse raw:=2",
"shared |context_sparse default_ft:=context",
"|action_sparse default_ft:=0",
"|action_sparse default_ft:=1",
"|action_sparse default_ft:=2",
]
)
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
Expand All @@ -80,10 +80,10 @@ def test_pickbest_textembedder_w_full_label_no_emb() -> None:
named_actions = {"action": ["0", "1", "2"]}
expected = "\n".join(
[
"shared |context_sparse raw:=context",
"0:-0.0:1.0 |action_sparse raw:=0",
"|action_sparse raw:=1",
"|action_sparse raw:=2",
"shared |context_sparse default_ft:=context",
"0:-0.0:1.0 |action_sparse default_ft:=0",
"|action_sparse default_ft:=1",
"|action_sparse default_ft:=2",
]
)

Expand Down Expand Up @@ -142,10 +142,10 @@ def test_pickbest_textembedder_w_full_label_w_embed_and_keep() -> None:
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected = "\n".join(
[
f"shared |context_dense {encoded_ctx_str} |context_sparse raw:={ctx_str}",
"0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse raw:=0",
"|action_dense 0:1.0 1:0.0 |action_sparse raw:=1",
"|action_dense 0:1.0 1:0.0 |action_sparse raw:=2",
f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str}",
"0:-0.0:1.0 |action_dense 0:1.0 1:0.0 |action_sparse default_ft:=0",
"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1",
"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
]
) # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
Expand All @@ -164,10 +164,10 @@ def test_pickbest_textembedder_more_namespaces_no_label_no_emb() -> None:
context = {"context1": "context1", "context2": "context2"}
expected = "\n".join(
[
"shared |context1_sparse raw:=context1 |context2_sparse raw:=context2 ",
"|a_sparse raw:=0 |b_sparse raw:=0",
"|action1_sparse raw:=1",
"|action1_sparse raw:=2",
"shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2 ",
"|a_sparse default_ft:=0 |b_sparse default_ft:=0",
"|action1_sparse default_ft:=1",
"|action1_sparse default_ft:=2",
]
) # noqa: E501
event = pick_best_chain.PickBestEvent(
Expand All @@ -185,10 +185,10 @@ def test_pickbest_textembedder_more_namespaces_w_label_no_emb() -> None:
context = {"context1": "context1", "context2": "context2"}
expected = "\n".join(
[
"shared |context1_sparse raw:=context1 |context2_sparse raw:=context2",
"|a_sparse raw:=0 |b_sparse raw:=0",
"|action_sparse raw:=1",
"|action_sparse raw:=2",
"shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2",
"|a_sparse default_ft:=0 |b_sparse default_ft:=0",
"|action_sparse default_ft:=1",
"|action_sparse default_ft:=2",
]
) # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0)
Expand All @@ -207,10 +207,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_no_emb() -> None:
context = {"context1": "context1", "context2": "context2"}
expected = "\n".join(
[
"shared |context1_sparse raw:=context1 |context2_sparse raw:=context2",
"0:-0.0:1.0 |a_sparse raw:=0 |b_sparse raw:=0",
"|action_sparse raw:=1",
"|action_sparse raw:=2",
"shared |context1_sparse default_ft:=context1 |context2_sparse default_ft:=context2",
"0:-0.0:1.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
"|action_sparse default_ft:=1",
"|action_sparse default_ft:=2",
]
) # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
Expand Down Expand Up @@ -282,10 +282,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_full_embed_and_kee
}
expected = "\n".join(
[
f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse raw:={ctx_str_1} |context2_sparse raw:={ctx_str_2}",
f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse raw:=0 |b_sparse raw:=0",
f"|action_dense 0:1.0 1:0.0 |action_sparse raw:=1",
f"|action_dense 0:1.0 1:0.0 |action_sparse raw:=2",
f"shared |context1_dense {encoded_ctx_str_1} |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}",
f"0:-0.0:1.0 |a_dense 0:1.0 1:0.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=1",
f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
]
) # noqa: E501

Expand Down Expand Up @@ -317,9 +317,9 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emb() -> N

expected = "\n".join(
[
f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse raw:={ctx_str_1}",
f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse raw:=0",
f"|action_sparse raw:=1",
f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1}",
f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0",
f"|action_sparse default_ft:=1",
f"|action_dense 0:1.0 1:0.0",
]
) # noqa: E501
Expand Down Expand Up @@ -355,10 +355,10 @@ def test_pickbest_textembedder_more_namespaces_w_full_label_w_partial_emakeep()
context = {"context1": ctx_str_1, "context2": rl_chain.EmbedAndKeep(ctx_str_2)}
expected = "\n".join(
[
f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse raw:={ctx_str_1} |context2_sparse raw:={ctx_str_2}",
f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse raw:=0 |b_sparse raw:=0",
f"|action_sparse raw:=1",
f"|action_dense 0:1.0 1:0.0 |action_sparse raw:=2",
f"shared |context2_dense {encoded_ctx_str_2} |context1_sparse default_ft:={ctx_str_1} |context2_sparse default_ft:={ctx_str_2}",
f"0:-0.0:1.0 |b_dense 0:1.0 1:0.0 |a_sparse default_ft:=0 |b_sparse default_ft:=0",
f"|action_sparse default_ft:=1",
f"|action_dense 0:1.0 1:0.0 |action_sparse default_ft:=2",
]
) # noqa: E501
selected = pick_best_chain.PickBestSelected(index=0, probability=1.0, score=0.0)
Expand Down Expand Up @@ -386,8 +386,8 @@ def test_raw_features_underscored() -> None:
context = {"context": ctx_str}
expected_no_embed = "\n".join(
[
f"shared |context_sparse raw:={ctx_str_underscored}",
f"|action_sparse raw:={str1_underscored}",
f"shared |context_sparse default_ft:={ctx_str_underscored}",
f"|action_sparse default_ft:={str1_underscored}",
]
)

Expand All @@ -414,8 +414,8 @@ def test_raw_features_underscored() -> None:
context = {"context": rl_chain.EmbedAndKeep(ctx_str)}
expected_embed_and_keep = "\n".join(
[
f"shared |context_dense {encoded_ctx_str} |context_sparse raw:={ctx_str_underscored}",
f"|action_dense {encoded_str1} |action_sparse raw:={str1_underscored}",
f"shared |context_dense {encoded_ctx_str} |context_sparse default_ft:={ctx_str_underscored}",
f"|action_dense {encoded_str1} |action_sparse default_ft:={str1_underscored}",
]
) # noqa: E501
event = pick_best_chain.PickBestEvent(
Expand Down
Loading

0 comments on commit c333c0a

Please sign in to comment.