Skip to content

Commit

Permalink
Quicker spectrum in 1D
Browse files Browse the repository at this point in the history
  • Loading branch information
Ceyron committed Sep 12, 2024
1 parent 2222786 commit 034b4c5
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions exponax/_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,17 +909,21 @@ def get_spectrum(
mode="reconstruction", # because of rfft
)

if power:
magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2
else:
magnitude = jnp.abs(state_hat_scaled)

if num_spatial_dims == 1:
# 1D does not need any binning and can be returned directly
return magnitude

wavenumbers_mesh = build_wavenumbers(num_spatial_dims, num_points)
wavenumbers_1d = build_wavenumbers(1, num_points)
wavenumbers_norm = jnp.linalg.norm(wavenumbers_mesh, axis=0, keepdims=True)

dk = wavenumbers_1d[0, 1] - wavenumbers_1d[0, 0]

if power:
magnitude = 0.5 * jnp.abs(state_hat_scaled) ** 2
else:
magnitude = jnp.abs(state_hat_scaled)

spectrum = []

def power_in_bucket(p, k):
Expand Down

0 comments on commit 034b4c5

Please sign in to comment.