Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write some rough notes on scale propagation directions #13

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 128 additions & 0 deletions docs/scaleprop_proposal_20231110.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Scale propagation proposal

AutoScale relies on predicting the scale of tensors in computational graphs. Two key questions we have are:

- What rules & approximations can we use to predict scale?
- Therefore what is the design contract for a scaled op.
- When should we requantise, in order to reset to an empirical scale?

In this proposal, we mainly explore the first question, with a bias for simplicity over fidelity.

## Design alternatives

### 1. Scale is a worst-case bound

Choose scale such that there is no situation where in-range inputs give out-of-range outputs.

Note that the following maths assumes that data is in the range `[-1, 1]`.

| Operation | Scaling rule |
| --- | --- |
| `R = quantise(a)` | `R.scale = max(abs(a))` |
| `R = add(A, B)` | `R.scale = A.scale + B.scale` |
| `R = sub(A, B)` | `R.scale = A.scale + B.scale` |
| `R = mul(A, B)` | `R.scale = A.scale * B.scale` |
| `R = dot(A, B)` | `R.scale = A.scale * B.scale * inner_dim` |
| `R = div(A, B)` | undefined |
| `R = pow(A, B)` | undefined unless all `b > 0`, then `R.scale = pow(A.scale, B.scale)` |

**Thoughts:**
- Easy to define
- No distributional assumptions (only that scale defines the max)
- Should be relatively consistent (e.g. `A + A` behaves the same as `2 * A`)
- Increases risk of underflow, if scale recalculation isn't frequent enough
- May require too-frequent recalculation

**Example of scale-pessimism:**

ScaledTensors `A` and `B` each contain 1000 Gaussian values, std=1. The scale of each might be 3 (due to the 3-sigma rule). We run `R = A * B`, so `R.scale = 9` from the table above. An average case scaling rule would have set `R.scale = 3`.

### 2. Scale is an average-case bound

Choose scale that, under simple distributional assumptions, predicts the actual scale of the output.

Note that the following maths assumes that `1` is the midpoint of the range of tensor data (e.g. floating point formats).

| Operation | Scaling rule |
| --- | --- |
| `R = quantise(a)` | `R.scale = sqrt((a**2).mean())` |
| `R = add(A, B)` | `R.scale = sqrt(A.scale**2 + B.scale**2)` |
| `R = sub(A, B)` | `R.scale = sqrt(A.scale**2 + B.scale**2)` |
| `R = mul(A, B)` | `R.scale = A.scale * B.scale` |
| `R = dot(A, B)` | `R.scale = A.scale * B.scale * sqrt(inner_dim)` |
| `R = div(A, B)` | undefined |
| `R = pow(A, B)` | undefined |

**Thoughts:**
- Somewhat easy to define "just assume that your inputs are IID-Gaussian (or IID-something-else)", if that helps.
- There are some inconsistencies (`A + A` doesn't behave the same as `2 * A`), but this is true for finite-precision numerics anyway!
- For undefined cases, more thought required!

### 3. Use worse-case scale, track average-case scale

Track both worst-case scale and average-case scale in `ScaledTensor`. The worst-case scale is the one used for quantisation, while the average-case scale is metadata. This scheme behaves like worst-case scaling, and uses the difference between the average-case scale and the worst-case scale to determine when to requantise.

For example:

```
R = dot(A, B)

dtype -- E4M3
A.shape, B.shape -- (4096, 4096), (4096, 4096)
A.scale -- (64, 2) (worst, average)
B.scale -- (16, 1) (worst, average)
```

If we ran-scaled, we would set an output scale of `(64*16*4096, 2*1*sqrt(4096)) = (4194304, 128)`. Since the ratio of worst to average scales is `32768`, we are worried about underflow in our E4M3 format (which has a ratio of max to min normal that is less than this). Therefore we requantise `A` (the worse offender) on the way in to the op. Perhaps the new `A.scale = (8, 4)`, so our output scale is a more reasonable `(524288, 256)` with a ratio of `2048`.

_While writing this example, I had to go quite extreme — this is due to the E4M3 format, which still has quite a wide range. Requantisation would automatically happen much more frequently in integer or low-E formats, for this reason._

### A. Notes on undefined cases

There are cases where scale tracking alone isn't enough to save you. Consider an implementation of LayerNorm:

```
Y = (X - X.mean()) / (X.var() + 1e-6).sqrt()
```

We know that `Y.scale = 1` is fine. But we'd need some sophisticated theorem-proving to get this from the bunch of primitives!

Tracing through the computation:

```
(X - X.mean()).scale ~= X.scale
(X.var() + 1e-6).scale ~= X.scale**2
(X.var() + 1e-6).sqrt().scale - (undefined, need to know >=0), then, yes ~= X.scale
((X - X.mean()) / (X.var() + 1e-6).sqrt()).scale - (undefined, need to relate numer to denom)
```

Propagating extra information through ScaledTensor could help some cases, e.g. a minimum bound on the values would ease `sqrt()`, `pow()`. But others, like LayerNorm's `div()` seem harder.

Some options:
- (More) theorem proving
- User-side promises `Y = with_scale(1.0, (X - X.mean()) / (X.var() + 1e-6).sqrt())`
- Extract subgraphs containing these operations, where we use regular (unscaled) tensors & lift to higher-precision

## Proposal #1 (most similar to unit scaling)

- Scale is set to an estimate of the uncentered standard deviation
- Ops assume inputs are IID Gaussian
- For floating point types, `dequantise(X) = X.data * X.scale`, since `1` is the center of the log-range
- For integer types, perhaps `dequantise(X) = X.data * X.scale * 4 / INT_MAX`, to provide 4-sigma headroom(?)
- When an op cannot make a reasonable estimate of scale (e.g. `pow`, `div`, `sqrt`), perform the operation in higher precision ("global" setting), and requantise the output.
- Remove unnecessary `dequantise(quantise(X))`, e.g. sqrt->div in the LayerNorm example above
- Provide `with_scale()` to the user to make a promise about scale, overriding the above logic

This does not address when to reset to an emprical scale because are estimates are too weak. It also suggests we shouldn't worry too much about minor inconsistencies e.g. `(A + A).scale = sqrt(2)*A.scale` versus `(2*A).scale = 2*A.scale`.

## Proposal #2 (worst-case with average-case tracking for renormalisation)

- Scaled tensors keep `scale` (used for quantistaion, based on worst case analysis) and `expected_scale` (based on average case)
- Quantise sets `scale = abs(max(x))`, `expected_scale = sqrt(mean(x**2))`
- (Being a bit sloppy here, the quantisation scale should use `dtype.max`)
- Ops compute the new `scale` based on worst-case logic, and estimate `expected_scale` based on average-case logic
- Every input to every scaled op includes a runtime-conditional requantisation based on the ratio between scale and expected scale, and the range of the element dtype
- _Alternative: perhaps this doesn't have to be runtime-conditional, if we don't propagate `expected_scale`, and instead statically propagate the ratio, `scale/expected_scale` (e.g. this increases by `*sqrt(inner_dim)` for a dot product)_
- When an op cannot make a reasonable estimate of scale (e.g. `pow`, `div`, `sqrt`), perform the operation in higher precision ("global" setting), and requantise the output
- Remove unnecessary `dequantise(quantise(X))`, e.g. sqrt->div in the LayerNorm example above
- Provide `with_scale()` to the user to make a promise about scale, overriding the above logic
115 changes: 115 additions & 0 deletions docs/scaleprop_proposal_b_20231110.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Scale propagation proposal B

AutoScale relies on predicting the scale of tensors in computational graphs. Two key questions we have are:

- What rules & approximations can we use to predict scale?
- Therefore what is the design contract for a scaled op.
- When should we requantise, in order to reset to an empirical scale?

In this proposal, we explore the two problems jointly, prioritising practical robustness over a more principled approach.

## High-level idea

1. For the first few iterations, run in higher-precision and calculate empirical "propagation scales" for each op, as well as collecting statistics for our rescaling strategy
2. These propagation scales are a multiplier which is applied to the input tensor-scales to get the output tensor-scale (_not_ simply the scale of the output tensor, see below for details)
3. After these initial iterations, a rescaling strategy will decide for each op how often it wishes to update its propagation scale (via empirical re-calculation), based on the statistics collected
4. More unstable ops (those where the tensor scale changes significantly) will be updated more often
5. The simplest version of the strategy is "on/off" - select a % of ops which will never be updated, and the rest update every `n` iterations
6. As the propagation scale is a _multiplier_, even when an op's propagation scale is fixed the benefits of any re-scalings elsewhere propagate
7. This is unlike Transformer Engine, where e.g. a drop in the scale of `grad(final_layer)` can require _every_ bwd-pass tensor to re-scale; in auto-scale only that op needs a new scale

## Design Principles

> quantization research is like printers. Nobody cares about printers. Nobody likes printers. But everybody is happy if printers do their job. --Tim Dettmers, 2022

- There should be as few assumptions as possible about the user's workload
- The default mode of operation should be very likely to give full accuracy on almost any model/workload, at the expense of speed
- There should be a clear, simple path presented to users for optimising to get the desired speedups, at the expense of accuracy
- There should be a reliable way of indicating to users if we might have over/underflowed
- The approach should be explainable in a few sentences

Plus, a key consideration: for big expensive runs, users don't have the opportunity to do a high-precision baseline to ensure auto-scale hasn't degraded things.
On this basis, we want to give users confidence that our method a) is fundamentally conservative, b) can flag up if it thinks it's seen a numerics issue.
Unlike unit scaling, autoscale users shouldn't have to cross their fingers and hope things turn out ok!

## Assumptions

We require the following assumptions for autoscale to work/be used effectively:
- the computational graph contains sub-graphs that are re-used (e.g. repeated layers, training loops)
- for re-used computational sub-graphs, no operation experiences a large, sudden shift in the way it propagates the scale of its inputs
- tensor dtypes that are sufficient for the first few iterations of a sub-graph are sufficient throughout (though an advanced scheme could in theory change dtypes)

Note that we do not need to assume:
- anything about the distributions of tensors
- anything about the type of workload - it need not be ML

## Usage

### (Contrived) Example

```python
x = torch.randn(batch_size, 16, dtype=torch.float32) * 3
linear = auto_scale.nn.Linear(16, 10, bias=False)
nn.init.normal_(linear.weight, std=5)

with auto_scale.scaling_context(analysis_iters=6, rescale_strategy=auto_scale.strategy.on_off(op_freq=1/3, loop_freq=1/20)):
# in practice there would be a loop in here that runs at least `analysis_iters` times
x = auto_scale.scale(x, dtype=torch.float8_e4m3) # now a scaled float8 tensor, scale=3
linear = auto_scale.scale(linear, dtype=torch.float8_e4m3) # scaled float8, scale=5
y = linear(x) # y is also scaled float8, scale=3*sqrt(16)*5=60 # where the propagation scale is calculated empirically as sqrt(16)
z = auto_scale.unscale(y, torch.float32)

print(x, x.type, x.scale)
print(linear.prop_scale)
print(y, y.type, y.scale)
print(z, z.dtype, z.std())
```
outputs:
```
tensor([...]) scaled_fp8 3
4 # Initially calculated empirically as `norm(linear.weight.value @ x.value)` (where `value` is the non-scale part of the scaled tensor), may then be re-calculated every n steps or frozen
tensor([...]) scaled_fp8 60
tensor([...]) float32 60
```

### Explanation

`with auto_scale.scaling_context(analysis_iters=6, rescale_strategy=auto_scale.strategy.on_off(op_freq=1/3, loop_freq=1/20))`

- the first 6 times scaled-tensor ops within this context are used, propagation scales will be calculated empirically in float32. In addition, statistics will be captured (e.g. abs max/min) for use by the rescaling strategy
- after 6 uses, the on/off rescaling strategy will be used - in this case 1/3 of the ops will be re-scaled every 20 steps, and the rest frozen (e.g. `scale`/`unscale` frozen, re-scale `linear`)
- the choice of which op to use is determined by some logic within the rescaling strategy. In this case it could just be based on how much the norm of the empirical propagation scale changes for each op over the 6 steps.
- other more complex rescaling strategies could be used to e.g. have different frequencies for different ops, or change those frequencies dynamically throughout training

`x = auto_scale.scale(x, dtype=torch.float8_e4m3) # now a scaled float8 tensor, scale=3`

- the empirical scaling will set the tensor-scale here to the norm of the input tensor, and divide by that tensor-scale to get the tensor-value
- this is done in float32 for the first 6 iterations, then scaled float8
- if this op were frozen for subsequent operations, it would repeat this for every cast assuming the same tensor-scale(/norm)
- some work required to figure out how best to associate there "propagation scales" with their corresponding functions in software

`y = linear(x) # y is also scaled float8, scale=3*sqrt(16)*5=60`

- this is a special implementation of a linear layer defined in the auto-scale library, designed to handle scaled tensors
- the logic is: `y.scale = x.scale * linear.weight.scale * linear.prop_scale`; `y.value = x.value @ linear.weight.value / linear.prop_scale`
- where the empirical calculation of the op's propagation scale is `linear.prop_scale = norm(y.value)` (done in fp32 for analysis phase, slightly more involved for re-scaling after that)
- the above example assumes the norm is std, but it could also be e.g. amax
- note that the transformer-engine way of doing this would be: `y.scale = linear.prop_scale`; `y.value = (x.scale * y.scale / linear.prop_scale) * x.value @ linear.weight.value`, `linear.prop_scale = y.scale * norm(y.value)`. This does not facilitate propagation of scale, which is the key difference.

## Over/underflow warnings

When we come to re-scale an op and the difference in old/new scales is sufficiently large that over/underflows may have caused a degradation, we should have some mechanism for flagging a warning to the user.
How we determine this requires some thought.
Of course, we aim to avoid this scenario - the frequency with which we decide to re-compute scales will be determined precisely to stop this happening. But if the user pushes their rescaling strategy too hard, then this should warn them.

The warnings should be reasonably conservative, with the hope that we can always signal to users if we've had a range issue.
For this reason we may also wish to implement a feature where for the final step of training _every_ op re-computes its scale (especially those which were frozen throughout), in order that they all check for the possibility that the model's final scales are inappropriate.

## Additional notes

- the hope here is that even though this method locks the user into having more empirical scale-calculations than a unit-scaling-like method, in practice only a small number of ops actually change the way they propagate scale significantly throughout training, and even these only need rescaling every e.g. 100 steps.
- if this is true, a sensible scaling stragety should be able to freeze most ops, and set the `loop_freq` to something fairly low. The overheads in such a system should be negligible.
- the default rescaling strategy hyperparameters provided to users should be conservative - for the on/off strategy we may simply set the default to be "on everywhere". The path to then getting speedups is then clear - start dropping the frequencies.
- this system is simpler than a unit-scaling-like method and has fewer assumptions & scary edge-cases
- most ops can be implemented in a very similar way to those described above
- this can be wrapped up into a graph transform, but similar to the unit scaling library it doesn't have to be. This way users have the options of e.g. just auto-scaling their matmuls and still getting the benefits. It's a less intrusive change.
Loading