-
Notifications
You must be signed in to change notification settings - Fork 12
/
callbacks.py
78 lines (66 loc) · 3.85 KB
/
callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from transformers.training_args import TrainingArguments
from transformers.utils import logging, is_torch_tpu_available
logger = logging.get_logger(__name__)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
class MyWandbCallback(WandbCallback):
    """`WandbCallback` subclass whose `setup` passes an explicit `code_dir` setting to `wandb.init`."""

    # Fallback for the `code_dir` setting when WANDB_CODE_DIR is not set in the environment.
    # Kept so existing launch scripts that rely on this location behave as before.
    _DEFAULT_CODE_DIR = "/share/project/qiying/projects/vision-tokenizer"

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/guides/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
        - **WANDB_LOG_MODEL** (`str`, *optional*, defaults to `"false"`):
            Whether to log model and checkpoints during training. Can be `"end"`, `"checkpoint"` or `"false"`. If set
            to `"end"`, the model will be uploaded at the end of training. If set to `"checkpoint"`, the checkpoint
            will be uploaded every `args.save_steps` . If set to `"false"`, the model will not be uploaded. Use along
            with [`~transformers.TrainingArguments.load_best_model_at_end`] to upload best model.

            <Deprecated version="5.0">

            Setting `WANDB_LOG_MODEL` as `bool` will be deprecated in version 5 of 🤗 Transformers.

            </Deprecated>
        - **WANDB_WATCH** (`str`, *optional* defaults to `"false"`):
            Can be `"gradients"`, `"all"`, `"parameters"`, or `"false"`. Set to `"all"` to log gradients and
            parameters.
        - **WANDB_PROJECT** (`str`, *optional*, defaults to `"huggingface"`):
            Project name passed to `wandb.init` when this method has to create the run.
        - **WANDB_CODE_DIR** (`str`, *optional*):
            Directory passed as the `code_dir` setting to `wandb.init` when this method has to create the run.
            Falls back to `_DEFAULT_CODE_DIR` when unset.

        Args:
            args: Training arguments; `to_dict()`, `run_name` and `output_dir` are read.
            state: Trainer state; `is_world_process_zero`, `trial_name` and `logging_steps` are read.
            model: The model being trained. Its `config.to_dict()` (if a config is present) is merged into the
                wandb config, and the model is handed to `wandb.watch` when `WANDB_WATCH` requests it.
            **kwargs: Accepted for signature compatibility; unused here.
        """
        if self._wandb is None:
            return
        self._initialized = True
        # Only the main process talks to wandb; every other rank is done at this point.
        if not state.is_world_process_zero:
            return

        logger.info(
            'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
        )
        combined_dict = {**args.to_dict()}

        if hasattr(model, "config") and model.config is not None:
            model_config = model.config.to_dict()
            # Training arguments come last so they win over model-config keys with the same name.
            combined_dict = {**model_config, **combined_dict}

        trial_name = state.trial_name
        init_args = {}
        if trial_name is not None:
            init_args["name"] = trial_name
            init_args["group"] = args.run_name
        elif not (args.run_name is None or args.run_name == args.output_dir):
            init_args["name"] = args.run_name

        if self._wandb.run is None:
            # The code directory used to be a fixed absolute path; it can now be overridden through the
            # environment while keeping the previous location as the default.
            code_dir = os.getenv("WANDB_CODE_DIR", self._DEFAULT_CODE_DIR)
            self._wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"),
                settings=self._wandb.Settings(code_dir=code_dir),
                **init_args,
            )
        # add config parameters (run may have been created manually)
        self._wandb.config.update(combined_dict, allow_val_change=True)

        # define default x-axis (for latest wandb versions)
        if getattr(self._wandb, "define_metric", None):
            self._wandb.define_metric("train/global_step")
            self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

        # keep track of model topology and gradients, unsupported on TPU
        _watch_model = os.getenv("WANDB_WATCH", "false")
        if not is_torch_tpu_available() and _watch_model in ("all", "parameters", "gradients"):
            self._wandb.watch(model, log=_watch_model, log_freq=max(100, state.logging_steps))
        self._wandb.run._label(code="transformers_trainer")
class ModelCallback(TrainerCallback):
    """Trainer callback that mirrors the trainer's step counter onto the model."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Store `state.global_step` on the model as `global_step`, then defer to the parent hook.

        The model is taken from `kwargs["model"]`; a missing key raises `KeyError`.
        Returns whatever the parent class's `on_step_begin` returns.
        """
        current_step = state.global_step
        model = kwargs['model']
        model.global_step = current_step
        parent_hook = super().on_step_begin
        return parent_hook(args, state, control, **kwargs)