diff --git a/tests/test_utils.py b/tests/test_utils.py index abc790d9..51f94ade 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,6 +11,8 @@ from pathlib import Path from typing import Any, Dict, NamedTuple, Optional, Tuple, Union +import numpy as np + import pytest import torch import torch.distributed as dist @@ -240,3 +242,47 @@ def split_tensor_for_distributed_test( if move_to_device: x = x.to(device=device_id) return x + + +def fixed_init_tensor( + shape: torch.Size, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, + dtype: torch.dtype = torch.float, +): + """ + Utility for generating deterministic tensors of a given shape. In general stuff + like torch.ones, torch.eye, etc can result in trivial outputs. This utility + generates a range tensor [min_val, max_val) of a specified dtype, applies + a sine function if nonlinear=True, then reshapes to the appropriate shape. + """ + n_elements = np.prod(shape) + step_size = (max_val - min_val) / n_elements + x = torch.arange(min_val, max_val, step_size, dtype=dtype).reshape(shape) + if nonlinear: + return torch.sin(x) + return x + + +@torch.no_grad +def fixed_init_model( + model: nn.Module, + min_val: Union[float, int] = 0.0, + max_val: Union[float, int] = 1.0, + nonlinear: bool = False, +): + """ + This utility initializes all parameters of a model deterministically using the + function fixed_init_tensor above. See that docstring for details of each parameter. + """ + for name, param in model.named_parameters(): + param.copy_( + fixed_init_tensor( + param.shape, + min_val=min_val, + max_val=max_val, + nonlinear=nonlinear, + dtype=param.dtype, + ) + )