Add half-precision (bfloat16, float16) support to train & validate scripts. Should push dtype handling into model factory / pretrained load at some point...
rwightman committed Jan 7, 2025
1 parent 6f80214 commit 92f610c
Showing 3 changed files with 55 additions and 26 deletions.
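The pattern these changes introduce: resolve the --model-dtype string to a torch.dtype with getattr, move the model to that dtype, and cast image batches to match (AMP remains float32-only). A minimal sketch of the same flow outside the scripts; the resolve_model_dtype helper and the resnet18 choice are illustrative, not part of the commit:

import torch
import timm

def resolve_model_dtype(name):
    # Hypothetical helper mirroring the inline checks added to train.py / validate.py.
    if name is None:
        return None  # keep the default float32 behaviour
    assert name in ('float32', 'float16', 'bfloat16')
    return getattr(torch, name)  # e.g. 'bfloat16' -> torch.bfloat16

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = resolve_model_dtype('bfloat16')

# The commit message notes dtype handling should eventually move into the
# model factory / pretrained load; for now the scripts cast after creation.
model = timm.create_model('resnet18', pretrained=False)
model.to(device=device, dtype=model_dtype)

# Inputs are cast to the same dtype; integer targets are left as-is.
x = torch.randn(2, 3, 224, 224, device=device, dtype=model_dtype)
with torch.no_grad():
    out = model(x)
print(out.dtype)  # torch.bfloat16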
26 changes: 13 additions & 13 deletions timm/data/loader.py
@@ -77,18 +77,18 @@ class PrefetchLoader:

def __init__(
self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
device=torch.device('cuda'),
img_dtype=torch.float32,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):

loader: torch.utils.data.DataLoader,
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
channels: int = 3,
device: torch.device = torch.device('cuda'),
img_dtype: Optional[torch.dtype] = None,
fp16: bool = False,
re_prob: float = 0.,
re_mode: str = 'const',
re_count: int = 1,
re_num_splits: int = 0,
):
mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
normalization_shape = (1, channels, 1, 1)
@@ -98,7 +98,7 @@ def __init__(
if fp16:
# fp16 arg is deprecated, but will override dtype arg if set for bwd compat
img_dtype = torch.float16
self.img_dtype = img_dtype
self.img_dtype = img_dtype or torch.float32
self.mean = torch.tensor(
[x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
self.std = torch.tensor(
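For reference, the prefetcher folds the uint8-to-float conversion and the normalization into one step by keeping mean/std in the 0-255 range and in the target image dtype. A standalone sketch of that step, assuming the same (1, channels, 1, 1) broadcast shape; the normalize_batch helper is illustrative:

import torch

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

def normalize_batch(batch_u8, img_dtype=torch.bfloat16, device='cpu'):
    # Mean/std are pre-scaled by 255 so the uint8 input never needs a
    # separate divide-by-255; shapes broadcast over (N, C, H, W).
    shape = (1, 3, 1, 1)
    mean = torch.tensor([m * 255 for m in IMAGENET_DEFAULT_MEAN],
                        device=device, dtype=img_dtype).view(shape)
    std = torch.tensor([s * 255 for s in IMAGENET_DEFAULT_STD],
                       device=device, dtype=img_dtype).view(shape)
    return batch_u8.to(device=device, dtype=img_dtype).sub_(mean).div_(std)

batch = torch.randint(0, 256, (4, 3, 224, 224), dtype=torch.uint8)
out = normalize_batch(batch)
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([4, 3, 224, 224])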
36 changes: 28 additions & 8 deletions train.py
@@ -178,6 +178,8 @@
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--model-dtype', default=None, type=str,
help='Model dtype override (non-AMP) (default: float32)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--synchronize-step', action='store_true', default=False,
@@ -434,10 +436,18 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0

model_dtype = None
if args.model_dtype:
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
model_dtype = getattr(torch, args.model_dtype)
if model_dtype == torch.float16:
_logger.warning('float16 is not recommended for training, for half precision bfloat16 is recommended.')

# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
if args.amp:
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
use_amp = 'apex'
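The assert above makes --model-dtype and AMP mutually exclusive unless the model stays in float32, and float16 training only triggers a warning. A condensed view of the allowed combinations (standalone sketch, not the script's exact control flow):

import torch

def check_precision_args(model_dtype, use_amp):
    # The updated scripts keep AMP on float32 weights; a half-precision
    # model dtype means running the whole model in that dtype instead.
    if use_amp:
        assert model_dtype is None or model_dtype == torch.float32, \
            'float32 model dtype must be used with AMP'
    if model_dtype == torch.float16:
        print('Warning: bfloat16 is recommended over float16 for training.')

check_precision_args(torch.bfloat16, use_amp=False)  # pure bfloat16 training
check_precision_args(None, use_amp=True)             # classic AMP, float32 weights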
@@ -517,7 +527,7 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))

# move model to GPU, enable channels last layout if set
model.to(device=device)
model.to(device=device, dtype=model_dtype) # FIXME move model device & dtype into create_model
if args.channels_last:
model.to(memory_format=torch.channels_last)

@@ -587,7 +597,7 @@ def main():
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if utils.is_primary(args):
_logger.info('AMP not enabled. Training in float32.')
_logger.info(f'AMP not enabled. Training in {model_dtype}.')

# optionally resume from a checkpoint
resume_epoch = None
@@ -732,6 +742,7 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
img_dtype=model_dtype,
device=device,
use_prefetcher=args.prefetcher,
use_multi_epochs_loader=args.use_multi_epochs_loader,
@@ -756,6 +767,7 @@ def main():
distributed=args.distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
img_dtype=model_dtype,
device=device,
use_prefetcher=args.prefetcher,
)
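Both loader calls now forward img_dtype=model_dtype, so the prefetcher yields batches already in the model's dtype. A rough sketch of the same idea through timm.data.create_loader; the FakeData placeholder and the exact keyword set beyond img_dtype, device, and use_prefetcher are assumptions based on the call above:

import torch
from timm.data import create_loader
from torchvision.datasets import FakeData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.bfloat16

# Placeholder dataset; create_loader attaches its own eval transform to it.
dataset_eval = FakeData(size=256, image_size=(3, 224, 224), num_classes=10)

loader_eval = create_loader(
    dataset_eval,
    input_size=(3, 224, 224),
    batch_size=32,
    is_training=False,
    use_prefetcher=True,
    img_dtype=model_dtype,  # prefetched batches come out in this dtype
    device=device,
)
# Iterating loader_eval should then produce image tensors of model_dtype.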
@@ -823,9 +835,13 @@ def main():
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
assert not args.wandb_resume_id or args.resume
wandb.init(project=args.experiment, config=args, tags=args.wandb_tags,
resume='must' if args.wandb_resume_id else None,
id=args.wandb_resume_id if args.wandb_resume_id else None)
wandb.init(
project=args.experiment,
config=args,
tags=args.wandb_tags,
resume='must' if args.wandb_resume_id else None,
id=args.wandb_resume_id if args.wandb_resume_id else None,
)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
@@ -879,6 +895,7 @@ def main():
output_dir=output_dir,
amp_autocast=amp_autocast,
loss_scaler=loss_scaler,
model_dtype=model_dtype,
model_ema=model_ema,
mixup_fn=mixup_fn,
num_updates_total=num_epochs * updates_per_epoch,
@@ -897,6 +914,7 @@ def main():
args,
device=device,
amp_autocast=amp_autocast,
model_dtype=model_dtype,
)

if model_ema is not None and not args.model_ema_force_cpu:
@@ -979,6 +997,7 @@ def train_one_epoch(
output_dir=None,
amp_autocast=suppress,
loss_scaler=None,
model_dtype=None,
model_ema=None,
mixup_fn=None,
num_updates_total=None,
@@ -1015,7 +1034,7 @@ def train_one_epoch(
accum_steps = last_accum_steps

if not args.prefetcher:
input, target = input.to(device), target.to(device)
input, target = input.to(device=device, dtype=model_dtype), target.to(device=device)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
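When the prefetcher is off, the cast happens in the loop itself: images take the model dtype while integer class targets are only moved to the device. A reduced sketch of one optimizer step under that convention; the tiny model and random batch are stand-ins:

import torch
import torch.nn as nn

device = torch.device('cpu')
model_dtype = torch.bfloat16

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
model.to(device=device, dtype=model_dtype)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

input = torch.rand(8, 3, 32, 32)            # as produced by a plain DataLoader
target = torch.randint(0, 10, (8,))

input = input.to(device=device, dtype=model_dtype)  # cast images to model dtype
target = target.to(device=device)                   # targets stay int64

loss = loss_fn(model(input), target)
loss.backward()
optimizer.step()
optimizer.zero_grad()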
@@ -1142,6 +1161,7 @@ def validate(
args,
device=torch.device('cuda'),
amp_autocast=suppress,
model_dtype=None,
log_suffix=''
):
batch_time_m = utils.AverageMeter()
@@ -1157,8 +1177,8 @@
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.to(device)
target = target.to(device)
input = input.to(device=device, dtype=model_dtype)
target = target.to(device=device)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)

19 changes: 14 additions & 5 deletions validate.py
@@ -123,6 +123,8 @@
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
parser.add_argument('--model-dtype', default=None, type=str,
help='Model dtype override (non-AMP) (default: float32)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@@ -168,10 +170,16 @@ def validate(args):

device = torch.device(args.device)

model_dtype = None
if args.model_dtype:
assert args.model_dtype in ('float32', 'float16', 'bfloat16')
model_dtype = getattr(torch, args.model_dtype)

# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
assert model_dtype is None or model_dtype == torch.float32, 'float32 model dtype must be used with AMP'
if args.amp_impl == 'apex':
assert has_apex, 'AMP impl specified as APEX but APEX is not installed.'
assert args.amp_dtype == 'float16'
@@ -184,7 +192,7 @@ def validate(args):
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Validating in mixed precision with native PyTorch AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
_logger.info(f'Validating in {model_dtype}. AMP not enabled.')

if args.fuser:
set_jit_fuser(args.fuser)
@@ -231,7 +239,7 @@ def validate(args):
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)

model = model.to(device)
model = model.to(device=device, dtype=model_dtype) # FIXME move model device & dtype into create_model
if args.channels_last:
model = model.to(memory_format=torch.channels_last)

@@ -299,6 +307,7 @@ def validate(args):
crop_border_pixels=args.crop_border_pixels,
pin_memory=args.pin_mem,
device=device,
img_dtype=model_dtype,
tf_preprocessing=args.tf_preprocessing,
)

@@ -310,7 +319,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device=device, dtype=model_dtype)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
@@ -319,8 +328,8 @@
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.to(device)
input = input.to(device)
target = target.to(device=device)
input = input.to(device=device, dtype=model_dtype)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
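Taken together, a half-precision run of validate.py builds the warmup tensor and every batch in the model dtype and leaves amp_autocast as a no-op. A condensed sketch of that path; the model and shapes are placeholders:

from contextlib import suppress

import torch
import timm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_dtype = torch.bfloat16
amp_autocast = suppress  # no AMP when the model itself runs in half precision

model = timm.create_model('resnet18', num_classes=10)
model = model.to(device=device, dtype=model_dtype)
model = model.to(memory_format=torch.channels_last)
model.eval()

with torch.no_grad():
    # Warmup pass to reduce variability of the first timed batch.
    input = torch.randn(4, 3, 224, 224).to(device=device, dtype=model_dtype)
    input = input.contiguous(memory_format=torch.channels_last)
    with amp_autocast():
        output = model(input)
print(output.dtype)  # torch.bfloat16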

