Commit
refactor sequential probit implementation
gianlucadetommaso committed Jun 23, 2023
1 parent 915a1ea commit e966745
Showing 3 changed files with 124 additions and 79 deletions.
7 changes: 1 addition & 6 deletions fortuna/metric/classification.py
@@ -211,12 +211,7 @@ def brier_score(probs: Array, targets: Union[TargetsLoader, Array]) -> jnp.ndarray:
jnp.ndarray
The Brier score.
"""
if probs.ndim != 2:
raise ValueError(
"""`probs` must be a two-dimensional array of probabilities for each class and each data
point."""
)
if type(targets) == TargetsLoader:
targets = targets.to_array_targets()
targets = jax.nn.one_hot(targets, probs.shape[1])
targets = jax.nn.one_hot(targets, probs.shape[-1])
return jnp.mean(jnp.sum((probs - targets) ** 2, axis=1))
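
With this change, the explicit two-dimensional check on `probs` is dropped and the one-hot encoding uses the last axis of `probs` instead of axis 1. A minimal usage sketch (assuming the fortuna package is installed and the import path matches the file above; the numbers are made up):

    import jax.numpy as jnp
    from fortuna.metric.classification import brier_score

    probs = jnp.array([[0.7, 0.2, 0.1],
                       [0.1, 0.8, 0.1]])  # per-class probabilities
    targets = jnp.array([0, 1])           # integer class labels
    score = brier_score(probs, targets)   # mean over data of sum_k (p_k - onehot_k)^2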
7 changes: 5 additions & 2 deletions fortuna/model_editor/classification.py
@@ -25,6 +25,7 @@ class ProbitClassificationModelEditor(ModelEditor):
top_k: Optional[int] = None
memory: Optional[int] = None
n_final_tokens: Optional[int] = None
init_log_var: float = -10.0

@nn.compact
def __call__(
@@ -36,7 +37,9 @@ def __call__(
x: Any,
has_aux: bool,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]:
log_var = self.param("log_var", nn.initializers.zeros, (1,))
log_var = self.param(
"log_var", nn.initializers.constant(self.init_log_var), (1,)
)
outputs = sequential_probit_scaling(
apply_fn,
model_params,
@@ -46,6 +49,6 @@ def __call__(
freeze_fun=self.freeze_fun,
top_k=self.top_k,
memory=self.memory,
n_final_tokens=self.n_final_tokens
n_final_tokens=self.n_final_tokens,
)
return outputs
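
The editor now exposes `init_log_var` and initializes the learned log-variance with `nn.initializers.constant` rather than zeros, so the probit correction starts out nearly inactive (exp(-10) ≈ 4.5e-5). A minimal standalone Flax sketch of that initialization pattern (the module name and the toy scaling below are illustrative, not part of fortuna):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class ConstantLogVar(nn.Module):
        init_log_var: float = -10.0

        @nn.compact
        def __call__(self, f: jnp.ndarray) -> jnp.ndarray:
            # single learnable scalar, initialized at init_log_var rather than 0
            log_var = self.param(
                "log_var", nn.initializers.constant(self.init_log_var), (1,)
            )
            # with log_var = -10 this is essentially the identity at initialization
            return f / (1 + jnp.pi / 8 * jnp.exp(log_var))

    module = ConstantLogVar()
    variables = module.init(jax.random.PRNGKey(0), jnp.ones((2, 3)))
    out = module.apply(variables, jnp.ones((2, 3)))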
189 changes: 118 additions & 71 deletions fortuna/utils/probit.py
Expand Up @@ -6,21 +6,34 @@
Union,
)

from flax.core import FrozenDict
from jax import (
ShapeDtypeStruct,
jit,
jvp,
lax,
pure_callback,
vjp,
vmap,
)
import jax.numpy as jnp
from jax import vmap, jvp, vjp, jit
from jax.tree_util import tree_map
from jax.tree_util import (
tree_map,
tree_reduce,
)

from fortuna.typing import (
AnyKey,
Array,
InputData,
Params,
)
from fortuna.utils.nested_dicts import nested_get, nested_set
from fortuna.utils.freeze import get_paths_with_label
from flax.core import FrozenDict
from fortuna.utils.grad import value_and_jacobian_squared_row_norm
from functools import partial
from fortuna.utils.nested_dicts import (
nested_get,
nested_set,
)


def probit_scaling(
@@ -55,7 +68,7 @@ def sequential_probit_scaling(
freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None,
top_k: Optional[int] = None,
memory: Optional[int] = None,
n_final_tokens: Optional[int] = None
n_final_tokens: Optional[int] = None,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]:
params = params.unfreeze()

@@ -94,22 +107,31 @@ def _apply_fn(_p, _x, tau):

n_outputs = f.shape[-1]
seq_length = f.shape[1]
if memory is None:
memory = seq_length
if memory <= 0 or memory > seq_length:
raise ValueError(f"`memory` must be greater than 0 and cannot be greater than {seq_length}.")
if n_final_tokens is None:
n_final_tokens = seq_length
if n_final_tokens <= 0 or n_final_tokens > seq_length:
raise ValueError(f"`n_final_tokens` must be greater than 0 and cannot be greater than {seq_length}.")
raise ValueError(
f"`n_final_tokens` must be greater than 0 and cannot be greater than {seq_length}."
)
if memory is None:
memory = n_final_tokens
if memory <= 0 or memory > n_final_tokens:
raise ValueError(
f"`memory` must be greater than 0 and cannot be greater than {n_final_tokens}."
)

block_size = top_k if top_k is not None else n_outputs
tot_size = memory * block_size
batch_size = f.shape[0]

indices = None
if top_k is not None:
indices = vmap(lambda _fx: vmap(lambda _fxtau: jnp.argsort(_fxtau)[-top_k:])(_fx))(f)
indices = vmap(
lambda _fx: vmap(lambda _fxtau: jnp.argsort(_fxtau)[-top_k:])(_fx)
)(f)

x = x[:, None] if not isinstance(x, dict) else tree_map(lambda v: v[:, None], x)

@jit
def compute_cov(new_tau, prev_tau):
new_tau -= 1
prev_tau -= 1
@@ -120,98 +142,123 @@ def _compute_cov(_x, idx):
prev_idx = idx[prev_tau] if idx is not None else None
size = n_outputs if idx is None else len(prev_idx)

new_fun = lambda p: _apply_fn(p, _x, new_tau)[new_idx] if idx is not None else _apply_fn(p, _x, new_tau)
prev_fun = lambda p: _apply_fn(p, _x, prev_tau)[prev_idx] if idx is not None else _apply_fn(p, _x, prev_tau)
new_fun = (
lambda p: _apply_fn(p, _x, new_tau)[new_idx]
if idx is not None
else _apply_fn(p, _x, new_tau)
)
prev_fun = (
lambda p: _apply_fn(p, _x, prev_tau)[prev_idx]
if idx is not None
else _apply_fn(p, _x, prev_tau)
)

J1J2T_op = lambda v: jvp(
new_fun,
(sub_params if params_paths is not None else params,),
vjp(prev_fun, sub_params if params_paths is not None else params)[1](v)
vjp(prev_fun, sub_params if params_paths is not None else params)[1](v),
)[1]

return vmap(J1J2T_op)(jnp.eye(size)).T
return _compute_cov(x, indices)

def compute_P(new_tau):
P = vmap(
lambda tau: compute_cov(new_tau, tau),
out_axes=2
)(jnp.arange(max(seq_length - n_final_tokens + 1, new_tau - memory), new_tau))
return jnp.where(prev_tau != 0, _compute_cov(x, indices), jnp.empty(block_size))

init_tau = seq_length - n_final_tokens + 1

@jit
def compute_P(new_tau, old_taus):
P = vmap(lambda tau: compute_cov(new_tau, tau), out_axes=2)(old_taus)
return P.reshape(P.shape[0], P.shape[1], P.shape[2] * P.shape[3])

@vmap
def get_diag(mat):
return jnp.diag(mat)

@partial(jit, static_argnums=(1,))
def fun(Jinv, tau):
P = compute_P(tau)
if Jinv.shape[1] != P.shape[2]:
Jinv = Jinv[:, -P.shape[2]:, -P.shape[2]:]
def fun(carry, tau):
Jinv, old_taus = carry
S = compute_cov(tau, tau)

P = compute_P(tau, old_taus)
M = jnp.matmul(P, Jinv)
C = S - jnp.matmul(M, P.swapaxes(1, 2))

M = M[:, :, block_size:]
Jinv = Jinv[:, block_size:, block_size:]
Cinv = jnp.linalg.inv(C)
MtCinv = jnp.matmul(M.swapaxes(1, 2), Cinv)

Jinv = jnp.concatenate(
(
jnp.concatenate(
(
Jinv + jnp.matmul(MtCinv, M),
-MtCinv
),
axis=2
),
jnp.concatenate(
(
-MtCinv.swapaxes(1, 2),
Cinv
),
axis=2
)
jnp.concatenate((Jinv + jnp.matmul(MtCinv, M), -MtCinv), axis=2),
jnp.concatenate((-MtCinv.swapaxes(1, 2), Cinv), axis=2),
),
axis=1
axis=1,
)

return Jinv, get_diag(C)
old_taus = jnp.concatenate((old_taus[1:], jnp.array([tau])))
return (Jinv, old_taus), get_diag(C)

C = compute_cov(seq_length - n_final_tokens, seq_length - n_final_tokens)
diagCs = [get_diag(C)]
if seq_length > 1:
Jinv = jnp.linalg.inv(C)
for tau in range(seq_length - n_final_tokens + 2, seq_length + 1):
Jinv, _diagC = fun(Jinv, tau)
diagCs.append(_diagC)
diagCs = jnp.stack(diagCs, axis=1)

if n_final_tokens < seq_length:
diagCs = jnp.concatenate(
(
jnp.max(diagCs, 1, keepdims=True).repeat(seq_length - n_final_tokens, axis=1),
diagCs
),
axis=1
def get_diagCs(_params):
old_taus = jnp.concatenate(
(jnp.zeros(memory - 1, dtype="int32"), jnp.array([init_tau], dtype="int32"))
)
C = compute_cov(old_taus[-1], old_taus[-1])

if n_final_tokens > 1:
Jinv = jnp.linalg.inv(C)
Jinv = jnp.concatenate(
(
jnp.zeros((batch_size, (memory - 1) * block_size, tot_size)),
jnp.concatenate(
(
jnp.zeros(
(batch_size, block_size, (memory - 1) * block_size)
),
Jinv,
),
axis=2,
),
),
axis=1,
)

scales = jnp.max(diagCs, axis=2, keepdims=True)
_, diagCs = lax.scan(
fun, (Jinv, old_taus), jnp.arange(init_tau + 1, seq_length + 1)
)
diagCs = jnp.concatenate(
(get_diag(C)[:, None], diagCs.swapaxes(0, 1)), axis=1
)
else:
diagCs = get_diag(C)[:, None]

if top_k is not None:
scales = jnp.ones_like(f) * scales
for i in range(f.shape[0]):
for j in range(f.shape[1]):
scales = scales.at[i, j, indices[i, j]].set(diagCs[i, j])
return diagCs

diagCs = lax.stop_gradient(get_diagCs(params if sub_params is None else sub_params))

f /= 1 + jnp.pi / 8 * jnp.exp(log_var) * scales
if top_k is not None:
scales = jnp.max(diagCs, axis=2, keepdims=True).repeat(n_outputs, axis=2)
scales = vmap(
lambda i: vmap(
lambda j: scales[i, j]
.at[indices[i, j]]
.set(diagCs[i, j, indices[i, j]])
)(jnp.arange(seq_length - n_final_tokens, seq_length))
)(jnp.arange(batch_size))
else:
scales = diagCs

f = jnp.concatenate(
(
f[:, : seq_length - n_final_tokens],
f[:, seq_length - n_final_tokens :]
/ (1 + jnp.pi / 8 * jnp.exp(log_var) * scales),
),
axis=1,
)

if seq_length == 1:
f = f[:, 0]

if has_aux:
return f, aux
return f


def vmap_jmp(fun, params, mat):
_jvp = lambda s: jvp(fun, (params,), (s,))[1]
return vmap(_jvp)(mat)
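
The main change in this file replaces the Python loop over token positions (which re-jitted the update for every `tau`) with a single `lax.scan` whose carry holds the running inverse `Jinv` and the window of previous positions `old_taus`; each step extends the inverse through the Schur complement of the newly computed covariance block, so the step is traced once instead of being recompiled per token. A small self-contained check of that block-inversion identity (the shapes and the random matrix are illustrative only, not fortuna code):

    import jax
    import jax.numpy as jnp

    def extend_inverse(Jinv, P, S):
        # inverse of [[J, P.T], [P, S]] given Jinv = inv(J), via the Schur complement
        M = P @ Jinv
        C = S - M @ P.T                     # Schur complement of J
        Cinv = jnp.linalg.inv(C)
        MtCinv = M.T @ Cinv
        top = jnp.concatenate((Jinv + MtCinv @ M, -MtCinv), axis=1)
        bottom = jnp.concatenate((-MtCinv.T, Cinv), axis=1)
        return jnp.concatenate((top, bottom), axis=0)

    A = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
    full = A @ A.T + 5 * jnp.eye(5)         # symmetric positive-definite test matrix
    J, P, S = full[:3, :3], full[3:, :3], full[3:, 3:]
    err = jnp.max(jnp.abs(extend_inverse(jnp.linalg.inv(J), P, S) - jnp.linalg.inv(full)))
    # err is float32 round-off, roughly 1e-6

The batched `fun` above assembles the same four blocks with `jnp.matmul` over a leading batch axis and returns the diagonal of the Schur complement as the per-token variance.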
