Replies: 1 comment 1 reply
-
What Flax is trying to do with mixed precision is provide very explicit APIs that allow you to control the dtypes used in each part of your code. For larger operations like normalization layers we try to do something sensible in the implementation, where "sensible" means something that is empirically found to be stable in most cases. PyTorch AMP is a more general-purpose transformation: it takes code that doesn't deal with mixed precision types at all and does a best-effort rewrite of the computation so that it uses half-precision types as much as possible. A tool like that could be very valuable in the JAX ecosystem as well, but it has less to do with Flax. In JAX I would imagine the equivalent of AMP to be a functional transformation just like jit or vmap, although it could be provided by a separate library. So you would have something like:
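The code snippet that originally followed is not preserved in this thread. The block below is therefore only a hypothetical sketch of the kind of usage the reply describes: an `amp`-style wrapper (the name and signature are invented for illustration) that composes with `jit` like any other transform. For simplicity it just casts floating-point inputs and outputs rather than doing real per-op autocasting.

```python
import jax
import jax.numpy as jnp

def amp(fn, half_dtype=jnp.bfloat16):
    """Toy stand-in for an AMP-like transform: cast floating-point array
    arguments to half precision before calling `fn`, and cast floating-point
    outputs back to float32. A real transform would instead rewrite the
    traced computation op by op."""
    def _cast(x, dtype):
        if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    def wrapped(*args, **kwargs):
        args, kwargs = jax.tree_util.tree_map(lambda x: _cast(x, half_dtype), (args, kwargs))
        out = fn(*args, **kwargs)
        return jax.tree_util.tree_map(lambda x: _cast(x, jnp.float32), out)

    return wrapped

# Composes like any other functional transform:
@jax.jit
@amp
def loss_fn(w, x):
    return jnp.mean((x @ w) ** 2)
```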
TL;DR: I am proposing a JAX interpreter for mixed precision training that automatically converts dtypes as appropriate. This topic asks how such a transform fits within the roadmap for mixed precision training in Flax, if it has already been considered, or pitches it otherwise.
Background
Training with reduced precision can lead to problems with certain operations such as normalization layers, where large reduction ops can cause overflows. For that reason, the AMP implementation in PyTorch keeps a whitelist and a blacklist of operations which respectively should and should not be carried out in lower precision, falling back to f32 when necessary. Currently there seems to be no way to do that in Flax: I can set the dtype of module parameters (`param_dtype`) and output types (`dtype`), but, for example, normalization layers will still perform reduction operations at the precision of their inputs.
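For concreteness, a minimal illustration of the two existing knobs mentioned above (the specific layer and dtypes are just an example); they control parameter storage and the computation/output dtype, but, per the point above, do not by themselves force the internal reductions into float32:

```python
import jax.numpy as jnp
import flax.linen as nn

# param_dtype: dtype the parameters are stored in
# dtype: dtype of the layer's computation/output
norm = nn.LayerNorm(dtype=jnp.bfloat16, param_dtype=jnp.float32)
```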
My understanding of the state of AMP in Flax and comparison with PyTorch
PR #1803 is introducing `computation_dtype` (alongside the already existing `param_dtype` and `dtype`), which is a step towards solving this issue. My worry, however, is that keeping track of those parameters and propagating them all the way through our NN definitions is cumbersome and introduces a lot of boilerplate (see the sketch below). It is also not very backwards-friendly: it avoids breakage, but a library implementing, for example, a ResNet50 backbone cannot be used in mixed precision training until it releases a patch that propagates `computation_dtype` from the top-level call. A PyTorch-style solution could be to traverse the tree of nested modules after creation of the root module (ResNet50) and recursively change their `computation_dtype`. This is, however, not very effective, because with the `@linen.compact` idiom many submodules are defined in `Module.__call__` and not stored statefully. This is in contrast to the imperative approach of PyTorch, where all learnable parameters have to be defined in `Module.__init__` and can thus be traversed.
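To illustrate the propagation burden (using only the already existing `dtype`/`param_dtype` arguments, since the exact shape of the `computation_dtype` API from PR #1803 isn't settled here), every intermediate module has to accept and forward the dtype arguments for the leaf layers to ever see them:

```python
from typing import Any

import jax.numpy as jnp
import flax.linen as nn

Dtype = Any

class Block(nn.Module):
    features: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Every leaf layer needs the dtypes forwarded explicitly.
        x = nn.Dense(self.features, dtype=self.dtype, param_dtype=self.param_dtype)(x)
        return nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype)(x)

class Backbone(nn.Module):
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Forgetting to forward the dtypes at any level silently leaves that
        # subtree at its default precision.
        x = Block(128, dtype=self.dtype, param_dtype=self.param_dtype)(x)
        return Block(64, dtype=self.dtype, param_dtype=self.param_dtype)(x)
```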
My suggestion
My suggestion for this problem is to define a context manager/decorator which implements mixed precision via JAX interpreter mechanics. We would have a whitelist of operations which should run in reduced precision and a blacklist of those which shouldn't. Some operations, like reshaping, are type-agnostic and would be in neither list. The interpreter then considers the entire computation graph and introduces casts from/to lower precision where appropriate (a rough sketch follows below). This way the technicalities of precision choices can be abstracted away from the model definition, following the same philosophy as `xmap`: define models at a high level of abstraction and adjust the details post hoc, via JAX transforms. To illustrate the savings in terms of boilerplate, this approach would make all of `computation_dtype`, `param_dtype` and `dtype` obsolete (in the context of automatic mixed precision training, that is).
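To make the idea concrete, here is a rough, hypothetical sketch of such a transform built on a jaxpr walk. The names `autocast`, `WHITELIST` and `BLACKLIST` are invented for illustration, the op lists are deliberately incomplete, and a real implementation would also need to handle higher-order primitives (scan, cond, custom derivatives), pytree inputs and gradient scaling:

```python
import jax
import jax.numpy as jnp

WHITELIST = {"dot_general", "conv_general_dilated"}      # run in half precision
BLACKLIST = {"reduce_sum", "reduce_max", "exp", "log"}   # keep in float32

def _cast(x, dtype):
    # Only touch floating-point values; integers, bools, etc. pass through.
    return x.astype(dtype) if jnp.issubdtype(jnp.result_type(x), jnp.floating) else x

def autocast(fn, half_dtype=jnp.bfloat16):
    def wrapped(*args):
        closed = jax.make_jaxpr(fn)(*args)
        env = {}

        def read(v):
            # Literals carry their value inline; everything else lives in env.
            return v.val if hasattr(v, "val") else env[v]

        for var, val in zip(closed.jaxpr.invars, args):
            env[var] = val
        for var, val in zip(closed.jaxpr.constvars, closed.consts):
            env[var] = val

        for eqn in closed.jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            name = eqn.primitive.name
            if name in WHITELIST:
                invals = [_cast(x, half_dtype) for x in invals]
            elif name in BLACKLIST:
                invals = [_cast(x, jnp.float32) for x in invals]
            else:
                # Type-agnostic ops: just keep the floating dtypes consistent.
                floats = [jnp.result_type(x) for x in invals
                          if jnp.issubdtype(jnp.result_type(x), jnp.floating)]
                if floats:
                    invals = [_cast(x, jnp.result_type(*floats)) for x in invals]
            outvals = eqn.primitive.bind(*invals, **eqn.params)
            outvals = outvals if eqn.primitive.multiple_results else [outvals]
            for var, val in zip(eqn.outvars, outvals):
                env[var] = val

        outs = [read(v) for v in closed.jaxpr.outvars]
        return outs[0] if len(outs) == 1 else tuple(outs)
    return wrapped

# Usage: precision is chosen per primitive, not threaded through the model code.
@autocast
def fn(x, w):
    y = x @ w               # dot_general -> bfloat16
    return jnp.sum(y ** 2)  # reduce_sum  -> float32
```

The point of the sketch is only that precision handling lives in a transform rather than in every Module; op coverage, loss scaling and interaction with jit/pmap would of course need proper design.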
Notes
Note 1: I imagine an idea like this has likely been discussed before, so I would be happy to know where it fits within the bigger picture of where Flax is going with mixed precision training. I originally intended this topic to just ask about the state of the work, but I figured I may as well pitch the solution as I see it, to better understand the trade-offs between the different approaches.
Note 2: It's certainly up for discussion whether this kind of machinery belongs in Flax, Optax or elsewhere entirely; please feel free to move it as appropriate.