diff --git a/benchmarks/transformers/prob_model_text_classification.py b/benchmarks/transformers/prob_model_text_classification.py index 09ff2828..5b33c7f6 100644 --- a/benchmarks/transformers/prob_model_text_classification.py +++ b/benchmarks/transformers/prob_model_text_classification.py @@ -36,6 +36,7 @@ accuracy, expected_calibration_error, ) +from fortuna.model_editor import ProbitModelEditor from fortuna.prob_model import ( ADVIPosteriorApproximator, DeepEnsemblePosteriorApproximator, @@ -51,7 +52,6 @@ SNGPPosteriorApproximator, SWAGPosteriorApproximator, ) -from fortuna.model_editor import ProbitModelEditor from fortuna.prob_model.fit_config.hyperparameters import FitHyperparameters from fortuna.prob_model.posterior.posterior_approximations import ( ADVI_NAME, @@ -400,7 +400,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: if args.enable_probit_model_editor: model_editor = ProbitModelEditor( freeze_fun=lambda p, v: True if "classifier" in p else False, - init_log_var=0. + init_log_var=0.0, ) ### TRAINING @@ -498,7 +498,9 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path: ) if args.enable_probit_model_editor: - logger.info(f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}") + logger.info( + f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}" + ) ### IN-D PERFORMANCE test_inputs_loader = test_data_loader.to_inputs_loader() diff --git a/fortuna/model_editor/probit.py b/fortuna/model_editor/probit.py index 569f79be..ef8fc2a3 100644 --- a/fortuna/model_editor/probit.py +++ b/fortuna/model_editor/probit.py @@ -51,6 +51,6 @@ def __call__( top_k=self.top_k, memory=self.memory, n_final_tokens=self.n_final_tokens, - stop_gradient=self.stop_gradient + stop_gradient=self.stop_gradient, ) return outputs diff --git a/fortuna/utils/probit.py b/fortuna/utils/probit.py index 9b1fdd9a..4e6aaaad 100644 --- a/fortuna/utils/probit.py +++ b/fortuna/utils/probit.py @@ -69,7 +69,7 @@ def sequential_probit_scaling( top_k: Optional[int] = None, memory: Optional[int] = None, n_final_tokens: Optional[int] = None, - stop_gradient: bool = False + stop_gradient: bool = False, ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]: params = params.unfreeze() @@ -162,7 +162,9 @@ def _compute_cov(_x, idx): return vmap(J1J2T_op)(jnp.eye(size)).T - return jnp.where(prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size)) + return jnp.where( + prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size) + ) init_tau = seq_length - n_final_tokens + 1 @@ -201,7 +203,10 @@ def fun(carry, tau): def get_diagCs(_params): old_taus = jnp.concatenate( - (jnp.zeros(memory - 1, dtype="int32") - 1, jnp.array([init_tau], dtype="int32")) + ( + jnp.zeros(memory - 1, dtype="int32") - 1, + jnp.array([init_tau], dtype="int32"), + ) ) C = compute_cov(old_taus[-1], old_taus[-1])