Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Jun 24, 2023
1 parent 529f9aa commit 42d2117
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
8 changes: 5 additions & 3 deletions benchmarks/transformers/prob_model_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
accuracy,
expected_calibration_error,
)
from fortuna.model_editor import ProbitModelEditor
from fortuna.prob_model import (
ADVIPosteriorApproximator,
DeepEnsemblePosteriorApproximator,
Expand All @@ -51,7 +52,6 @@
SNGPPosteriorApproximator,
SWAGPosteriorApproximator,
)
from fortuna.model_editor import ProbitModelEditor
from fortuna.prob_model.fit_config.hyperparameters import FitHyperparameters
from fortuna.prob_model.posterior.posterior_approximations import (
ADVI_NAME,
Expand Down Expand Up @@ -400,7 +400,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
if args.enable_probit_model_editor:
model_editor = ProbitModelEditor(
freeze_fun=lambda p, v: True if "classifier" in p else False,
init_log_var=0.
init_log_var=0.0,
)

### TRAINING
Expand Down Expand Up @@ -498,7 +498,9 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
)

if args.enable_probit_model_editor:
logger.info(f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}")
logger.info(
f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}"
)

### IN-D PERFORMANCE
test_inputs_loader = test_data_loader.to_inputs_loader()
Expand Down
2 changes: 1 addition & 1 deletion fortuna/model_editor/probit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def __call__(
top_k=self.top_k,
memory=self.memory,
n_final_tokens=self.n_final_tokens,
stop_gradient=self.stop_gradient
stop_gradient=self.stop_gradient,
)
return outputs
11 changes: 8 additions & 3 deletions fortuna/utils/probit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def sequential_probit_scaling(
top_k: Optional[int] = None,
memory: Optional[int] = None,
n_final_tokens: Optional[int] = None,
stop_gradient: bool = False
stop_gradient: bool = False,
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]:
params = params.unfreeze()

Expand Down Expand Up @@ -162,7 +162,9 @@ def _compute_cov(_x, idx):

return vmap(J1J2T_op)(jnp.eye(size)).T

return jnp.where(prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size))
return jnp.where(
prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size)
)

init_tau = seq_length - n_final_tokens + 1

Expand Down Expand Up @@ -201,7 +203,10 @@ def fun(carry, tau):

def get_diagCs(_params):
old_taus = jnp.concatenate(
(jnp.zeros(memory - 1, dtype="int32") - 1, jnp.array([init_tau], dtype="int32"))
(
jnp.zeros(memory - 1, dtype="int32") - 1,
jnp.array([init_tau], dtype="int32"),
)
)
C = compute_cov(old_taus[-1], old_taus[-1])

Expand Down

0 comments on commit 42d2117

Please sign in to comment.