Commit 2ec66ab: fix format

cheng-tan committed Nov 14, 2023
1 parent 063faaf commit 2ec66ab
Showing 6 changed files with 81 additions and 53 deletions.
9 changes: 3 additions & 6 deletions src/learn_to_pick/__init__.py
@@ -21,13 +21,10 @@
     PickBestSelected,
 )
 
-from learn_to_pick.byom.pytorch_policy import (
-    PyTorchPolicy
-)
+from learn_to_pick.byom.pytorch_policy import PyTorchPolicy
 
-from learn_to_pick.byom.pytorch_feature_embedder import (
-    PyTorchFeatureEmbedder
-)
+from learn_to_pick.byom.pytorch_feature_embedder import PyTorchFeatureEmbedder
 
+
 def configure_logger() -> None:
     logger = logging.getLogger(__name__)
7 changes: 5 additions & 2 deletions src/learn_to_pick/byom/igw.py
@@ -1,16 +1,19 @@
 import torch
 
+
 def IGW(fhat, gamma):
     from math import sqrt
+
     fhatahat, ahat = fhat.max(dim=1)
     A = fhat.shape[1]
     gamma *= sqrt(A)
-    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
+    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
     sump = p.sum(dim=1)
     p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
     return torch.multinomial(p, num_samples=1).squeeze(1), ahat
 
 
+
 def SamplingIGW(A, P, gamma):
     exploreind, _ = IGW(P, gamma)
-    explore = [ ind for _, ind in zip(A, exploreind) ]
+    explore = [ind for _, ind in zip(A, exploreind)]
     return explore
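
For context, IGW here is inverse gap weighting, the exploration rule used in SquareCB-style contextual bandits: an action's sampling probability shrinks with its score gap to the greedy action, gamma (scaled by sqrt(A) for A actions) controls how greedy the distribution is, and any leftover probability mass goes to the greedy action before sampling. A minimal sketch of the sampling path (not part of the commit; the scores and gamma below are made up):

import torch

from learn_to_pick.byom.igw import IGW, SamplingIGW

# One decision over three candidate actions; fhat holds predicted rewards.
fhat = torch.tensor([[0.1, 0.7, 0.2]])
gamma = 10.0  # larger gamma -> more probability mass on the greedy action

sampled, greedy = IGW(fhat, gamma)
print(sampled, greedy)  # e.g. tensor([1]) tensor([1])

# SamplingIGW pairs the sampled indices with an action batch of the same length.
actions = [["news", "sports", "music"]]  # placeholder action rows
print(SamplingIGW(actions, fhat, gamma))  # e.g. [tensor(1)]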
35 changes: 27 additions & 8 deletions src/learn_to_pick/byom/logistic_regression.py
@@ -2,16 +2,27 @@
 import torch
 import torch.nn.functional as F
 
+
 class MLP(torch.nn.Module):
     @staticmethod
     def new_gelu(x):
         import math
-        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+        return (
+            0.5
+            * x
+            * (
+                1.0
+                + torch.tanh(
+                    math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
+                )
+            )
+        )
 
     def __init__(self, dim):
         super().__init__()
-        self.c_fc = torch.nn.Linear(dim, 4 * dim)
-        self.c_proj = torch.nn.Linear(4 * dim, dim)
+        self.c_fc = torch.nn.Linear(dim, 4 * dim)
+        self.c_proj = torch.nn.Linear(4 * dim, dim)
         self.dropout = torch.nn.Dropout(0.5)
 
     def forward(self, x):

@@ -21,6 +32,7 @@ def forward(self, x):
         x = self.dropout(x)
         return x
 
+
 class Block(torch.nn.Module):
     def __init__(self, dim):
         super().__init__()

@@ -29,13 +41,14 @@ def __init__(self, dim):
     def forward(self, x):
         return x + self.layer(x)
 
+
 class ResidualLogisticRegressor(torch.nn.Module):
     def __init__(self, in_features, depth, device):
         super().__init__()
         self._in_features = in_features
         self._depth = depth
-        self.blocks = torch.nn.Sequential(*[ Block(in_features) for _ in range(depth) ])
-        self.linear = torch.nn.Linear(in_features=in_features, out_features=1)
+        self.blocks = torch.nn.Sequential(*[Block(in_features) for _ in range(depth)])
+        self.linear = torch.nn.Linear(in_features=in_features, out_features=1)
         self.optim = parameterfree.COCOB(self.parameters())
         self._device = device
 
@@ -53,9 +66,15 @@ def logits(self, X, A):
         # X = batch x features
         # A = batch x actionbatch x actionfeatures
 
-        Xreshap = X.unsqueeze(1).expand(-1, A.shape[1], -1) # batch x actionbatch x features
-        XA = torch.cat((Xreshap, A), dim=-1).reshape(X.shape[0], A.shape[1], -1).to(self._device) # batch x actionbatch x (features + actionfeatures)
-        return self.linear(self.blocks(XA)).squeeze(2) # batch x actionbatch
+        Xreshap = X.unsqueeze(1).expand(
+            -1, A.shape[1], -1
+        )  # batch x actionbatch x features
+        XA = (
+            torch.cat((Xreshap, A), dim=-1)
+            .reshape(X.shape[0], A.shape[1], -1)
+            .to(self._device)
+        )  # batch x actionbatch x (features + actionfeatures)
+        return self.linear(self.blocks(XA)).squeeze(2)  # batch x actionbatch
 
     def predict(self, X, A):
         self.eval()
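
Note that new_gelu is the tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), so it should agree with PyTorch's built-in up to floating-point error. A quick sanity check (not part of the commit; assumes torch >= 1.12 for the approximate argument):

import torch
import torch.nn.functional as F

from learn_to_pick.byom.logistic_regression import MLP

x = torch.randn(4, 8)
# Both sides compute the tanh-approximate GELU.
assert torch.allclose(MLP.new_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-6)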
15 changes: 9 additions & 6 deletions src/learn_to_pick/byom/pytorch_feature_embedder.py
@@ -2,12 +2,11 @@
 from sentence_transformers import SentenceTransformer
 import torch
 
-class PyTorchFeatureEmbedder(): #rl_chain.Embedder[rl_chain.PickBestEvent]
-    def __init__(
-        self, auto_embed, model = None, *args, **kwargs
-    ):
+
+class PyTorchFeatureEmbedder:
+    def __init__(self, auto_embed, model=None, *args, **kwargs):
         if model is None:
-            model = model = SentenceTransformer('all-MiniLM-L6-v2')
+            model = model = SentenceTransformer("all-MiniLM-L6-v2")
 
         self.model = model
         self.auto_embed = auto_embed

@@ -84,4 +83,8 @@ def format(self, event: rl_chain.PickBestEvent):
         if cost is None:
             return context, actions
         else:
-            return torch.Tensor([[-1.0 * cost]]), context, actions[:,chosen_action,:].unsqueeze(1)
+            return (
+                torch.Tensor([[-1.0 * cost]]),
+                context,
+                actions[:, chosen_action, :].unsqueeze(1),
+            )
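
The first hunk above also shows the constructor's default: when no model is passed, a SentenceTransformer("all-MiniLM-L6-v2") is created (the doubled model = model = assignment is in the source as-is). A brief construction sketch (not part of the commit; all-mpnet-base-v2 is only an example of an alternative model):

from sentence_transformers import SentenceTransformer

import learn_to_pick

# Default: builds SentenceTransformer("all-MiniLM-L6-v2") internally.
fe = learn_to_pick.PyTorchFeatureEmbedder(auto_embed=True)

# Any other sentence-transformers model can be supplied explicitly.
custom_fe = learn_to_pick.PyTorchFeatureEmbedder(
    auto_embed=True, model=SentenceTransformer("all-mpnet-base-v2")
)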
27 changes: 14 additions & 13 deletions src/learn_to_pick/byom/pytorch_policy.py
@@ -10,14 +10,15 @@ def __init__(
         self,
         feature_embedder,
         depth: int = 2,
-        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
+        device: str = "cuda" if torch.cuda.is_available() else "cpu",
         *args,
         **kwargs,
     ):
         print(f"Device: {device}")
         super().__init__(*args, **kwargs)
         self.workspace = ResidualLogisticRegressor(
-            feature_embedder.model.get_sentence_embedding_dimension() * 2, depth, device).to(device)
+            feature_embedder.model.get_sentence_embedding_dimension() * 2, depth, device
+        ).to(device)
         self.feature_embedder = feature_embedder
         self.device = device
         self.index = 0

@@ -34,6 +35,7 @@ def predict(self, event):
         p = self.workspace.predict(X, A)
         # print(f"p: {p}")
         import math
+
         explore = SamplingIGW(A, p, math.sqrt(self.index))
         self.index += 1
         # print(f"explore: {explore}")

@@ -58,30 +60,29 @@ def log(self, event):
 
     def save(self, path) -> None:
         state = {
-            'workspace_state_dict': self.workspace.state_dict(),
-            'optimizer_state_dict': self.workspace.optim.state_dict(),
-            'device': self.device,
-            'index': self.index,
-            'loss': self.loss
+            "workspace_state_dict": self.workspace.state_dict(),
+            "optimizer_state_dict": self.workspace.optim.state_dict(),
+            "device": self.device,
+            "index": self.index,
+            "loss": self.loss,
         }
         print(f"Saving model to {path}")
         dir, _ = os.path.split(path)
         if dir and not os.path.exists(dir):
             os.makedirs(dir, exist_ok=True)
         torch.save(state, path)
 
-
     def load(self, path) -> None:
         import parameterfree
 
         if os.path.exists(path):
             print(f"Loading model from {path}")
             checkpoint = torch.load(path, map_location=self.device)
 
-            self.workspace.load_state_dict(checkpoint['workspace_state_dict'])
+            self.workspace.load_state_dict(checkpoint["workspace_state_dict"])
             self.workspace.optim = parameterfree.COCOB(self.workspace.parameters())
-            self.workspace.optim.load_state_dict(checkpoint['optimizer_state_dict'])
-            self.device = checkpoint['device']
+            self.workspace.optim.load_state_dict(checkpoint["optimizer_state_dict"])
+            self.device = checkpoint["device"]
             self.workspace.to(self.device)
-            self.index = checkpoint['index']
-            self.loss = checkpoint['loss']
+            self.index = checkpoint["index"]
+            self.loss = checkpoint["loss"]
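
save and load persist the workspace weights together with the COCOB optimizer state, the exploration index, and the last loss, so training can resume where it left off. A minimal round-trip sketch (not part of the commit; the checkpoint path is arbitrary), mirroring what the unit test below verifies:

import learn_to_pick

fe = learn_to_pick.PyTorchFeatureEmbedder(auto_embed=True)
policy = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
# ... run some picks and learn from their scores ...
policy.save("checkpoints/policy.checkpoint")  # creates the directory if missing

restored = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
restored.load("checkpoints/policy.checkpoint")  # restores model, optimizer, index, loss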
41 changes: 23 additions & 18 deletions tests/unit_tests/test_byom.py
@@ -7,25 +7,21 @@
 import learn_to_pick
 
 
-CHECKPOINT_DIR = 'test_models'
+CHECKPOINT_DIR = "test_models"
 
+
 @pytest.fixture
 def remove_checkpoint():
     yield
     if os.path.isdir(CHECKPOINT_DIR):
         shutil.rmtree(CHECKPOINT_DIR)
 
+
 class CustomSelectionScorer(learn_to_pick.SelectionScorer):
     def get_score(self, user, time_of_day, article):
         preferences = {
-            'Tom': {
-                'morning': 'politics',
-                'afternoon': 'music'
-            },
-            'Anna': {
-                'morning': 'sports',
-                'afternoon': 'politics'
-            }
+            "Tom": {"morning": "politics", "afternoon": "music"},
+            "Anna": {"morning": "sports", "afternoon": "politics"},
         }
 
         return int(preferences[user][time_of_day] == article)

@@ -58,16 +54,19 @@ def run(self, pytorch_picker, T):
             user = self._choose_user()
             time_of_day = self._choose_time_of_day()
             pytorch_picker.run(
-                article = learn_to_pick.ToSelectFrom(self.articles),
-                user = learn_to_pick.BasedOn(user),
-                time_of_day = learn_to_pick.BasedOn(time_of_day),
+                article=learn_to_pick.ToSelectFrom(self.articles),
+                user=learn_to_pick.BasedOn(user),
+                time_of_day=learn_to_pick.BasedOn(time_of_day),
             )
 
+
 def verify_same_models(model1, model2):
     for p1, p2 in zip(model1.parameters(), model2.parameters()):
         assert torch.equal(p1, p2), "The models' parameters are not equal."
 
-    for (name1, buffer1), (name2, buffer2) in zip(model1.named_buffers(), model2.named_buffers()):
+    for (name1, buffer1), (name2, buffer2) in zip(
+        model1.named_buffers(), model2.named_buffers()
+    ):
         assert name1 == name2, "Buffer names do not match."
         assert torch.equal(buffer1, buffer2), f"The buffers {name1} are not equal."
 
@@ -86,7 +85,7 @@ def verify_same_optimizers(optimizer1, optimizer2):
         return False
 
     for key in state_dict1:
-        if key == 'state':
+        if key == "state":
             if state_dict1[key].keys() != state_dict2[key].keys():
                 return False
             for subkey in state_dict1[key]:

@@ -104,26 +103,32 @@ def test_save_load(remove_checkpoint):
     sim2 = Simulator()
 
     fe = learn_to_pick.PyTorchFeatureEmbedder(auto_embed=True)
-    first_model_path = f'{CHECKPOINT_DIR}/first.checkpoint'
+    first_model_path = f"{CHECKPOINT_DIR}/first.checkpoint"
 
     torch.manual_seed(0)
     first_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
     second_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
 
     torch.manual_seed(0)
 
-    first_picker = learn_to_pick.PickBest.create(policy=first_byom, selection_scorer=CustomSelectionScorer())
+    first_picker = learn_to_pick.PickBest.create(
+        policy=first_byom, selection_scorer=CustomSelectionScorer()
+    )
     sim1.run(first_picker, 5)
     first_byom.save(first_model_path)
 
     second_byom.load(first_model_path)
-    second_picker = learn_to_pick.PickBest.create(policy=second_byom, selection_scorer=CustomSelectionScorer())
+    second_picker = learn_to_pick.PickBest.create(
+        policy=second_byom, selection_scorer=CustomSelectionScorer()
+    )
    sim1.run(second_picker, 5)
 
     torch.manual_seed(0)
     all_byom = learn_to_pick.PyTorchPolicy(feature_embedder=fe)
     torch.manual_seed(0)
-    all_picker = learn_to_pick.PickBest.create(policy=all_byom, selection_scorer=CustomSelectionScorer())
+    all_picker = learn_to_pick.PickBest.create(
+        policy=all_byom, selection_scorer=CustomSelectionScorer()
+    )
     sim2.run(all_picker, 10)
 
     verify_same_models(second_byom.workspace, all_byom.workspace)
