Skip to content

Commit

Permalink
Fourier Coefficient Extractor (#38)
Browse files Browse the repository at this point in the history
* Add function

* Enhance docstring

* Change rounding default

* Add coefficient extractor to docs
  • Loading branch information
Ceyron authored Sep 6, 2024
1 parent a7817b2 commit 2222786
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/api/utilities/spectral.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@

---

::: exponax.spectral.get_fourier_coefficients

---

::: exponax.spectral.build_scaling_array
141 changes: 141 additions & 0 deletions exponax/_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,144 @@ def power_in_bucket(p, k):
# spectrum /= jnp.sum(spectrum, axis=-1, keepdims=True)

return spectrum


def get_fourier_coefficients(
state: Float[Array, "C ... N"],
*,
scaling_compensation_mode: Optional[
Literal["norm_compensation", "reconstruction", "coef_extraction"]
] = "coef_extraction",
round: Optional[int] = 5,
indexing: str = "ij",
) -> Complex[Array, "C ... (N//2)+1"]:
"""
Extract the Fourier coefficients of a state in Fourier space.
It correctly compensates the scaling used in `exponax.fft` such that the
coefficient values can be directly read off from the array.
**Arguments:**
- `state`: The state following the `Exponax` convention with a leading
channel axis and then one, two, or three subsequent spatial axes, each
of the same length N.
- `scaling_compensation_mode`: The mode of the scaling array to use to
compensate the scaling of the Fourier transform. The mode
`"norm_compensation"` would produce the coefficient array as produced if
`jnp.fft.rfftn` was applied with `norm="forward"`, instead of the
default of `norm="backward"` which is also the default used in
`Exponax`. The mode `"reconstruction"` is similar to that but
compensates for the fact that the rfft only has half of the coefficients
along the right-most axis. The mode `"coef_extraction"` allows to read
of the coefficient e.g. at index [i, j] (in 2D) directly wheras in the
other modes, one would require to consider both the positive and
negative wavenumbers. Can be set to `None` to not apply any scaling
compensation. See also [`exponax.spectral.build_scaling_array`][] for
more information.
- `round`: The number of decimals to round the coefficients to. Default is
`5` which compensates for the rounding errors created by the FFT in
single precision such that all coefficients that should not carry any
energy also have zero value. Set to `None` to not round.
- `indexing`: The indexing scheme to use for `jax.numpy.meshgrid`.
**Returns:**
- `coefficients`: The Fourier coefficients of the state.
!!! warning
Do not use the results of this function together with the `exponax.viz`
utilities since they will periodically wrap the boundary condition which
is not needed in Fourier space.
!!! tip
Use this function to visualize the coefficients in higher dimensions.
For example in 2D
```python
state_2d = ... # shape (1, N, N)
coef_2d = exponax.spectral.get_fourier_coefficients(state_2d)
# shape (1, N, (N//2)+1)
plt.imshow(
jnp.log10(jnp.abs(coef_2d[0])),
)
```
And in 3D (requires the [`vape4d`](https://github.com/KeKsBoTer/vape4d)
volume renderer to be installed - only works on GPU devices).
```python
state_3d = ... # shape (1, N, N, N)
coef_3d = exponax.spectral.get_fourier_coefficients(
state_3d, round=None,
)
images = ex.viz.volume_render_state_3d(
jnp.log10(jnp.abs(coef_3d)), vlim=(-8, 2),
)
plt.imshow(images[0])
```
To have the major half to the real-valued axis more prominent, consider
flipping it via
```python
coef_3d_flipped = jnp.flip(coef_3d, axis=-1)
```
!!! tip
**Interpretation Guide** In general for a FFT following the NumPy
conventions, we have:
* Positive amplitudes on cosine signals have positive coefficients in
the real part of both the positive and the negative wavenumber.
* Positive amplitudes on sine signals have negative coefficients in the
imaginary part of the positive wavenumber and positive coefficients
in the imaginary part of the negative wavenumber.
As such, if the output of this function on a 1D state was
```python
array([[3.0 + 0.0j, 0.0 - 1.5j, 0.3 + 0.8j, 0.0 + 0.0j,]])
```
This would correspond to a signal with:
* A constant offset of +3.0
* A first sine mode with amplitude +1.5
* A second cosine mode with amplitude +0.3
* A second sine mode with amplitude -0.8
In higher dimensions, the interpretation arise out of the tensor
product. Also be aware that for a `(1, N, N)` state, the coefficients
are in the shape `(1, N, (N//2)+1)`.
"""
state_hat = fft(state)
if scaling_compensation_mode is not None:
scaling = build_scaling_array(
state.ndim - 1,
state.shape[-1],
mode=scaling_compensation_mode,
indexing=indexing,
)
coefficients = state_hat / scaling
else:
coefficients = state_hat

if round is not None:
coefficients = jnp.round(coefficients, round)

return coefficients

0 comments on commit 2222786

Please sign in to comment.