Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Nov 20, 2023
1 parent 45e273e commit 97d9f0c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
20 changes: 13 additions & 7 deletions src/learn_to_pick/byom/pytorch_feature_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from learn_to_pick import PickBestFeaturizer


class PyTorchFeatureEmbedder:
def __init__(self, auto_embed=False, model=None, *args, **kwargs):
if model is None:
Expand All @@ -10,34 +11,39 @@ def __init__(self, auto_embed=False, model=None, *args, **kwargs):
self.model = model
self.featurizer = PickBestFeaturizer(auto_embed=auto_embed)


def encode(self, stuff):
embeddings = self.model.encode(stuff, convert_to_tensor=True)
normalized = torch.nn.functional.normalize(embeddings)
return normalized


def convert_features_to_text(self, features):
def process_feature(feature):
if isinstance(feature, dict):
return " ".join([f"{k}_{process_feature(v)}" for k, v in feature.items()])
return " ".join(
[f"{k}_{process_feature(v)}" for k, v in feature.items()]
)
elif isinstance(feature, list):
return " ".join([process_feature(elem) for elem in feature])
else:
return str(feature)

return process_feature(features)


def format(self, event):
# TODO: handle dense
context_featurized, actions_featurized, selected = self.featurizer.featurize(event)
context_featurized, actions_featurized, selected = self.featurizer.featurize(
event
)

context_sparse = self.encode([self.convert_features_to_text(context_featurized.sparse)])
context_sparse = self.encode(
[self.convert_features_to_text(context_featurized.sparse)]
)

actions_sparse = []
for action_featurized in actions_featurized:
actions_sparse.append(self.convert_features_to_text(action_featurized.sparse))
actions_sparse.append(
self.convert_features_to_text(action_featurized.sparse)
)
actions_sparse = self.encode(actions_sparse).unsqueeze(0)

if selected.score is not None:
Expand Down
1 change: 0 additions & 1 deletion src/learn_to_pick/byom/pytorch_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def predict(self, event):
# print(f"returning: {r}")
return r


def learn(self, event):
R, X, A = self.feature_embedder.format(event)
# print(f"R: {R}")
Expand Down

0 comments on commit 97d9f0c

Please sign in to comment.