Commit
add stop gradient flag
gianlucadetommaso committed Jun 24, 2023
1 parent e966745 commit 529f9aa
Showing 6 changed files with 27 additions and 55 deletions.
19 changes: 15 additions & 4 deletions benchmarks/transformers/prob_model_text_classification.py
@@ -36,7 +36,6 @@
accuracy,
expected_calibration_error,
)
from fortuna.model_editor.classification import ProbitClassificationModelEditor
from fortuna.prob_model import (
ADVIPosteriorApproximator,
DeepEnsemblePosteriorApproximator,
@@ -52,6 +51,7 @@
SNGPPosteriorApproximator,
SWAGPosteriorApproximator,
)
from fortuna.model_editor import ProbitModelEditor
from fortuna.prob_model.fit_config.hyperparameters import FitHyperparameters
from fortuna.prob_model.posterior.posterior_approximations import (
ADVI_NAME,
@@ -214,6 +214,9 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
parser.add_argument("--sgmcmc_polynomial_schedule_gamma", type=float, default=0.55)
parser.add_argument("--sgmcmc_preconditioner", type=strbool, default=False)
parser.add_argument("--sghmc_momentum_decay", type=float, default=0.01)
# model editor
parser.add_argument("--enable_probit_model_editor", type=strbool, default=False)
parser.add_argument("--init_probit_log_var", type=float, default=-5)
# optimizer
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--adam_eps", type=float, default=1e-8)
@@ -393,6 +396,13 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
),
}

model_editor = None
if args.enable_probit_model_editor:
model_editor = ProbitModelEditor(
freeze_fun=lambda p, v: True if "classifier" in p else False,
init_log_var=0.
)

### TRAINING
prob_model = ProbClassifier(
model=model,
@@ -401,9 +411,7 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
],
prior=IsotropicGaussianPrior(log_var=args.prior_log_var),
output_calibrator=None,
model_editor=ProbitClassificationModelEditor(
freeze_fun=lambda p, v: True if "classifier" in p else False, top_k=10
),
model_editor=model_editor,
)

fit_config = FitConfig(
@@ -489,6 +497,9 @@ def unpack_model_tar(model_ckpt_path: pathlib.Path) -> pathlib.Path:
"Either restore_checkpoint_path or num_train_epochs > 0 should be specified."
)

if args.enable_probit_model_editor:
logger.info(f"Probit log-variance: {prob_model.posterior.state.get().params['model_editor']['params']['log_var']}")

### IN-D PERFORMANCE
test_inputs_loader = test_data_loader.to_inputs_loader()
test_targets = test_data_loader.to_array_targets()
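Taken together, the benchmark changes above follow one pattern: construct a ProbitModelEditor only when --enable_probit_model_editor is set, otherwise pass model_editor=None so the probabilistic model is left unedited. Below is a minimal sketch of that wiring; the flag names and the freeze function come from the diff, while build_prob_classifier is a hypothetical helper and the remaining ProbClassifier arguments (prior, posterior approximator, output calibrator) are left at their defaults for brevity.

import argparse

from fortuna.model_editor import ProbitModelEditor
from fortuna.prob_model import ProbClassifier


def build_prob_classifier(model, args: argparse.Namespace) -> ProbClassifier:
    # Build the editor only when the CLI flag is on.
    model_editor = None
    if args.enable_probit_model_editor:
        model_editor = ProbitModelEditor(
            # Restrict editing to the classification head, as in the diff's freeze_fun.
            freeze_fun=lambda path, value: "classifier" in path,
            # The diff hard-codes 0.0 here, even though --init_probit_log_var defaults to -5.
            init_log_var=0.0,
        )
    return ProbClassifier(model=model, model_editor=model_editor)

After fitting, the learned probit log-variance can be read back from the posterior state, which is what the new logging line above does.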

This file was deleted.

@@ -4,6 +4,6 @@ hparams:
per_device_eval_batch_size: 32
per_device_train_batch_size: 32
learning_rate: 2e-05
num_warmup_steps: 10000
num_warmup_steps: 500
prior_log_var: 100.0
weight_decay: 0.01
1 change: 1 addition & 0 deletions fortuna/model_editor/__init__.py
@@ -0,0 +1 @@
from fortuna.model_editor.probit import ProbitModelEditor
@@ -20,12 +20,13 @@
from fortuna.utils.probit import sequential_probit_scaling


class ProbitClassificationModelEditor(ModelEditor):
class ProbitModelEditor(ModelEditor):
freeze_fun: Optional[Callable[[Tuple[AnyKey, ...], Array], str]] = None
top_k: Optional[int] = None
memory: Optional[int] = None
n_final_tokens: Optional[int] = None
init_log_var: float = -10.0
init_log_var: float = -5.0
stop_gradient: bool = False

@nn.compact
def __call__(
@@ -50,5 +51,6 @@ def __call__(
top_k=self.top_k,
memory=self.memory,
n_final_tokens=self.n_final_tokens,
stop_gradient=self.stop_gradient
)
return outputs
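The editor class is renamed from ProbitClassificationModelEditor to ProbitModelEditor, its default init_log_var moves from -10.0 to -5.0, and the new stop_gradient field is forwarded to sequential_probit_scaling. A hypothetical instantiation, with illustrative values for the optional fields:

from fortuna.model_editor import ProbitModelEditor

# All fields are optional; the values below are illustrative, not recommendations.
editor = ProbitModelEditor(
    freeze_fun=lambda path, value: "classifier" in path,  # restrict editing to the head
    n_final_tokens=1,     # apply the sequential probit scaling to the last token only
    init_log_var=-5.0,    # matches the new default initial log-variance
    stop_gradient=True,   # detach the covariance term from the gradient (the new flag)
)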
9 changes: 6 additions & 3 deletions fortuna/utils/probit.py
@@ -69,6 +69,7 @@ def sequential_probit_scaling(
top_k: Optional[int] = None,
memory: Optional[int] = None,
n_final_tokens: Optional[int] = None,
stop_gradient: bool = False
) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Dict]]:
params = params.unfreeze()

@@ -161,7 +162,7 @@ def _compute_cov(_x, idx):

return vmap(J1J2T_op)(jnp.eye(size)).T

return jnp.where(prev_tau != 0, _compute_cov(x, indices), jnp.empty(block_size))
return jnp.where(prev_tau != -1, _compute_cov(x, indices), jnp.empty(block_size))

init_tau = seq_length - n_final_tokens + 1

@@ -200,7 +201,7 @@ def fun(carry, tau):

def get_diagCs(_params):
old_taus = jnp.concatenate(
(jnp.zeros(memory - 1, dtype="int32"), jnp.array([init_tau], dtype="int32"))
(jnp.zeros(memory - 1, dtype="int32") - 1, jnp.array([init_tau], dtype="int32"))
)
C = compute_cov(old_taus[-1], old_taus[-1])

@@ -233,7 +234,9 @@ def get_diagCs(_params):

return diagCs

diagCs = lax.stop_gradient(get_diagCs(params if sub_params is None else sub_params))
diagCs = get_diagCs(params if sub_params is None else sub_params)
if stop_gradient:
diagCs = lax.stop_gradient(diagCs)

if top_k is not None:
scales = jnp.max(diagCs, axis=2, keepdims=True).repeat(n_outputs, axis=2)
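In sequential_probit_scaling itself, the covariance diagonals were previously always detached via lax.stop_gradient; the new flag makes that detachment optional. The diff also changes the sentinel marking "no previous tau" from 0 to -1 and initializes the padding entries of old_taus to -1 to match. The conditional detachment is the standard JAX pattern sketched below on a toy function; loss and cov_diag are made-up names, not part of fortuna.utils.probit.

import jax.numpy as jnp
from jax import grad, lax


def loss(log_var: jnp.ndarray, stop_gradient: bool = False) -> jnp.ndarray:
    cov_diag = jnp.exp(log_var)  # stands in for get_diagCs(...)
    if stop_gradient:
        # Reverse-mode autodiff treats this as a constant, so no gradient
        # flows back into log_var through the covariance term.
        cov_diag = lax.stop_gradient(cov_diag)
    return jnp.sum(cov_diag)


print(grad(loss)(jnp.array(0.0)))                                     # 1.0
print(grad(lambda lv: loss(lv, stop_gradient=True))(jnp.array(0.0)))  # 0.0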
