Add pytorch policy #31

Merged: 18 commits, Nov 28, 2023
337 changes: 337 additions & 0 deletions notebooks/news_recommendation_byom.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion setup.py
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
import os

with open("README.md", "r", encoding="UTF-8") as fh:
long_description = fh.read()
6 changes: 6 additions & 0 deletions src/learn_to_pick/__init__.py
@@ -21,6 +21,10 @@
PickBestSelected,
)

from learn_to_pick.byom.pytorch_policy import PyTorchPolicy

from learn_to_pick.byom.pytorch_feature_embedder import PyTorchFeatureEmbedder


def configure_logger() -> None:
logger = logging.getLogger(__name__)
@@ -50,6 +54,8 @@ def configure_logger() -> None:
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"VwPolicy",
"VwLogger",
"embed",
Empty file.
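The __init__.py changes above export the new classes at the package root. A minimal import sketch (assuming the byom dependencies used by this PR, i.e. torch, parameterfree, and sentence-transformers, are installed):

from learn_to_pick import PyTorchPolicy, PyTorchFeatureEmbedder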
19 changes: 19 additions & 0 deletions src/learn_to_pick/byom/igw.py
@@ -0,0 +1,19 @@
import torch


def IGW(fhat, gamma):
    # Inverse-gap weighting: each action gets probability
    # 1 / (A + gamma * sqrt(A) * gap), where gap is its score distance from the
    # best action; any leftover probability mass goes to the greedy action.
    from math import sqrt

    fhatahat, ahat = fhat.max(dim=1)
    A = fhat.shape[1]
    gamma *= sqrt(A)
    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
    sump = p.sum(dim=1)
    p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
    return torch.multinomial(p, num_samples=1).squeeze(1), ahat


def SamplingIGW(A, P, gamma):
    # Sample one action index per batch row from the IGW distribution over P.
    exploreind, _ = IGW(P, gamma)
    explore = [ind for _, ind in zip(A, exploreind)]
    return explore
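Not part of the diff, but a quick sketch of how the sampler behaves on made-up scores: the greedy action keeps most of the probability mass, and near-ties are explored more often than clearly worse actions.

import torch
from learn_to_pick.byom.igw import IGW, SamplingIGW

scores = torch.tensor([[0.90, 0.85, 0.10]])     # one row, three candidate actions
sampled, greedy = IGW(scores, gamma=4.0)
print(greedy.item())                            # 0: the highest-scoring action
print(sampled.item())                           # usually 0, sometimes 1, rarely 2

# SamplingIGW pairs each batch row of the action tensor with its sampled index
actions = torch.randn(1, 3, 8)                  # (batch, actions, action_features)
print(SamplingIGW(actions, scores, gamma=4.0))  # e.g. [tensor(0)]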
90 changes: 90 additions & 0 deletions src/learn_to_pick/byom/logistic_regression.py
@@ -0,0 +1,90 @@
import parameterfree
import torch
import torch.nn.functional as F


class MLP(torch.nn.Module):
    @staticmethod
    def new_gelu(x):
        # tanh approximation of the GELU activation
        import math

        return (
            0.5
            * x
            * (
                1.0
                + torch.tanh(
                    math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))
                )
            )
        )

def __init__(self, dim):
super().__init__()
self.c_fc = torch.nn.Linear(dim, 4 * dim)
self.c_proj = torch.nn.Linear(4 * dim, dim)
self.dropout = torch.nn.Dropout(0.5)

def forward(self, x):
x = self.c_fc(x)
x = self.new_gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x


class Block(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.layer = MLP(dim)

def forward(self, x):
return x + self.layer(x)


class ResidualLogisticRegressor(torch.nn.Module):
def __init__(self, in_features, depth, device):
super().__init__()
self._in_features = in_features
self._depth = depth
self.blocks = torch.nn.Sequential(*[Block(in_features) for _ in range(depth)])
self.linear = torch.nn.Linear(in_features=in_features, out_features=1)
self.optim = parameterfree.COCOB(self.parameters())
self._device = device

def clone(self):
other = ResidualLogisticRegressor(self._in_features, self._depth, self._device)
other.load_state_dict(self.state_dict())
other.optim = parameterfree.COCOB(other.parameters())
other.optim.load_state_dict(self.optim.state_dict())
return other

def forward(self, X, A):
return self.logits(X, A)

def logits(self, X, A):
# X = batch x features
# A = batch x actionbatch x actionfeatures

Xreshap = X.unsqueeze(1).expand(
-1, A.shape[1], -1
) # batch x actionbatch x features
XA = (
torch.cat((Xreshap, A), dim=-1)
.reshape(X.shape[0], A.shape[1], -1)
.to(self._device)
) # batch x actionbatch x (features + actionfeatures)
return self.linear(self.blocks(XA)).squeeze(2) # batch x actionbatch

def predict(self, X, A):
self.eval()
return torch.special.expit(self.logits(X, A))

def bandit_learn(self, X, A, R):
self.train()
self.optim.zero_grad()
output = self(X, A)
loss = F.binary_cross_entropy_with_logits(output, R)
loss.backward()
self.optim.step()
return loss.item()
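A minimal usage sketch (not from the PR; random tensors, CPU only). Note that in_features must equal the context dimension plus the per-action dimension, and the rewards passed to bandit_learn should lie in [0, 1] since the loss is binary cross-entropy with logits:

import torch
from learn_to_pick.byom.logistic_regression import ResidualLogisticRegressor

batch, n_actions, dim = 4, 3, 8
model = ResidualLogisticRegressor(in_features=2 * dim, depth=2, device="cpu")

X = torch.randn(batch, dim)              # context features
A = torch.randn(batch, n_actions, dim)   # per-action features
R = torch.rand(batch, n_actions)         # scores in [0, 1]

loss = model.bandit_learn(X, A, R)       # one parameter-free (COCOB) update
scores = model.predict(X, A)             # sigmoid scores, shape (batch, n_actions)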
90 changes: 90 additions & 0 deletions src/learn_to_pick/byom/pytorch_feature_embedder.py
@@ -0,0 +1,90 @@
import learn_to_pick as rl_chain
from sentence_transformers import SentenceTransformer
import torch


class PyTorchFeatureEmbedder:
def __init__(self, auto_embed, model=None, *args, **kwargs):
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2")

        self.model = model
        self.auto_embed = auto_embed

def encode(self, stuff):
embeddings = self.model.encode(stuff, convert_to_tensor=True)
normalized = torch.nn.functional.normalize(embeddings)
return normalized

def get_label(self, event: rl_chain.PickBestEvent) -> tuple:
cost = None
if event.selected:
chosen_action = event.selected.index
cost = (
-1.0 * event.selected.score
if event.selected.score is not None
else None
)
prob = event.selected.probability
return chosen_action, cost, prob
else:
return None, None, None

def get_context_and_action_embeddings(self, event: rl_chain.PickBestEvent) -> tuple:
context_emb = rl_chain.embed(event.based_on, self) if event.based_on else None
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)

        action_embs = (
            rl_chain.embed(to_select_from, self, to_select_from_var_name)
            if to_select_from
            else None
        )

if not context_emb or not action_embs:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
)
return context_emb, action_embs

def format(self, event: rl_chain.PickBestEvent):
chosen_action, cost, prob = self.get_label(event)
context_emb, action_embs = self.get_context_and_action_embeddings(event)

context = ""
for context_item in context_emb:
for ns, based_on in context_item.items():
e = " ".join(based_on) if isinstance(based_on, list) else based_on
context += f"{ns}={e} "

if self.auto_embed:
context = self.encode([context])

actions = []
for action in action_embs:
action_str = ""
for ns, action_embedding in action.items():
e = (
" ".join(action_embedding)
if isinstance(action_embedding, list)
else action_embedding
)
action_str += f"{ns}={e} "
actions.append(action_str)

if self.auto_embed:
actions = self.encode(actions).unsqueeze(0)

if cost is None:
return context, actions
else:
return (
torch.Tensor([[-1.0 * cost]]),
context,
actions[:, chosen_action, :].unsqueeze(1),
)
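For context (a sketch, not part of the diff): with auto_embed=True, format turns the namespaced strings into dense tensors via SentenceTransformer, so the policy sees a (1, 384) context tensor and a (1, num_actions, 384) action tensor for the default all-MiniLM-L6-v2 model. The encoding step on its own looks like this (the example strings are illustrative):

from sentence_transformers import SentenceTransformer
import torch

model = SentenceTransformer("all-MiniLM-L6-v2")
emb = model.encode(["article=sports news", "article=politics news"], convert_to_tensor=True)
emb = torch.nn.functional.normalize(emb)   # L2-normalized rows, shape (2, 384)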
88 changes: 88 additions & 0 deletions src/learn_to_pick/byom/pytorch_policy.py
@@ -0,0 +1,88 @@
from learn_to_pick import base, PickBestEvent
from learn_to_pick.byom.logistic_regression import ResidualLogisticRegressor
from learn_to_pick.byom.igw import SamplingIGW
import torch
import os


class PyTorchPolicy(base.Policy[PickBestEvent]):
def __init__(
self,
feature_embedder,
depth: int = 2,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
*args,
**kwargs,
):
print(f"Device: {device}")
super().__init__(*args, **kwargs)
self.workspace = ResidualLogisticRegressor(
feature_embedder.model.get_sentence_embedding_dimension() * 2, depth, device
).to(device)
self.feature_embedder = feature_embedder
self.device = device
self.index = 0
self.loss = None

    def predict(self, event):
        X, A = self.feature_embedder.format(event)
        # TODO: IGW sampling then create the distro so that the one
        # that was sampled here is the one that will def be sampled by
        # the base sampler, and in the future replace the sampler so that it
        # is something that can be plugged in
        p = self.workspace.predict(X, A)
        import math

        explore = SamplingIGW(A, p, math.sqrt(self.index))
        self.index += 1
        # probability 1 for the sampled (explored) action, 0 for the rest
        r = []
        for index in range(p.shape[1]):
            if index == explore[0]:
                r.append((index, 1))
            else:
                r.append((index, 0))
        return r

def learn(self, event):
R, X, A = self.feature_embedder.format(event)
# print(f"R: {R}")
R, X, A = R.to(self.device), X.to(self.device), A.to(self.device)
self.loss = self.workspace.bandit_learn(X, A, R)

def log(self, event):
pass

    def save(self, path) -> None:
        state = {
            "workspace_state_dict": self.workspace.state_dict(),
            "optimizer_state_dict": self.workspace.optim.state_dict(),
            "device": self.device,
            "index": self.index,
            "loss": self.loss,
        }
        print(f"Saving model to {path}")
        dir_name, _ = os.path.split(path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        torch.save(state, path)

def load(self, path) -> None:
import parameterfree

if os.path.exists(path):
print(f"Loading model from {path}")
checkpoint = torch.load(path, map_location=self.device)

self.workspace.load_state_dict(checkpoint["workspace_state_dict"])
self.workspace.optim = parameterfree.COCOB(self.workspace.parameters())
self.workspace.optim.load_state_dict(checkpoint["optimizer_state_dict"])
self.device = checkpoint["device"]
self.workspace.to(self.device)
self.index = checkpoint["index"]
self.loss = checkpoint["loss"]
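Putting it together, a hedged end-to-end sketch (it assumes base.Policy needs no additional constructor arguments and that the default sentence-transformers model is available; the checkpoint path is illustrative):

from learn_to_pick import PyTorchPolicy, PyTorchFeatureEmbedder

fe = PyTorchFeatureEmbedder(auto_embed=True)   # defaults to all-MiniLM-L6-v2
policy = PyTorchPolicy(feature_embedder=fe)    # workspace in_features = 2 * 384, depth = 2

# checkpoint the workspace together with its COCOB optimizer state, then restore it
policy.save("checkpoints/news_policy.pt")
policy.load("checkpoints/news_policy.pt")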