Skip to content

Commit

Permalink
WIP: TimeVaryingInputs2D
Browse files Browse the repository at this point in the history
[skip ci][ci skip]
  • Loading branch information
Sbozzolo committed Jan 25, 2024
1 parent 887ece6 commit 70d645e
Show file tree
Hide file tree
Showing 7 changed files with 289 additions and 33 deletions.
121 changes: 121 additions & 0 deletions src/shared_utilities/FileReader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,4 +470,125 @@ exist in the DateTime calendar.
"""
to_datetime(date) = CFTime.reinterpret(DateTime, date)

struct TemporalDataHandler{PDT <: PrescribedDataTemporal, SPACE}
prescibed_data::PDT
space::SPACE
end

"""
first_time(data_handler::TemporalDataHandler)
Return the time of the first snapshot in seconds.
"""
function first_time(data_handler::TemporalDataHandler)
first_time_datetime = data_handler.prescibed_data.file_info.all_dates[begin]
date_ref = data_handler.prescribed_data.sim_info.date_ref
return Second(first_time_datetime - date_ref - t_start).value
end

"""
last_time(data_handler::TemporalDataHandler)
Return the time of the last snapshot in seconds.
"""
function last_time(data_handler::TemporalDataHandler)
last_time_datetime = data_handler.prescibed_data.file_info.all_dates[end]
date_ref = data_handler.prescribed_data.sim_info.date_ref
return Second(last_time_datetime - date_ref - t_start).value
end

"""
previous_time(data_handler::TemporalDataHandler, time)
Return the time in seconds of the snapshot before the given `time`.
"""
function previous_time(data_handler::TemporalDataHandler, time)
sim_info = data_handler.prescibed_data.sim_info
time_to_datetime = to_datetime(
sim_info.date_ref +
Second(round(sim_info.t_start)) +
Second(round(time)),
)
return searchsortedfirst(
data_handler.prescibed_data.file_info.all_dates,
time_to_datetime,
) - 1
end

"""
previous_time(data_handler::TemporalDataHandler, time)
Return the time in seconds of the snapshot after the given `time`.
"""
function next_time(data_handler::TemporalDataHandler, time)
sim_info = data_handler.prescibed_data.sim_info
time_to_datetime = to_datetime(
sim_info.date_ref +
Second(round(sim_info.t_start)) +
Second(round(time)),
)
return searchsortedfirst(
data_handler.prescibed_data.file_info.all_dates
time_to_datetime,
)
end

"""
previous_snapshot!(data_handler::TemporalDataHandler, time)
Return the first data snapshot from `data_handler` before the given `time`.
`previous_snapshot!` potentially modifies the internal state of `data_handler` and it might be a
very expensive operation.
"""
function previous_snapshot!(data_handler::TemporalDataHandler, time)
# Time in seconds

# Get the current date from `time`
sim_info = data_handler.prescibed_data.sim_info
sim_date = to_datetime(
sim_info.date_ref + Second(round(sim_info.t_start)) + Second(round(t)),
)
# Use next date if it's closest to current time
# This maintains `all_dates[date_idx]` <= `sim_date` < `all_dates[date_idx + 1]`
if sim_date >= to_datetime(next_date_in_file(prescibed_data.file_info))
read_data_fields!(
data_handler.prescibed_data,
sim_date,
data_handler.space,
)
end

return prescribed_data.file_state.data_fields[1]
end

"""
next_snapshot!(data_handler::TemporalDataHandler, time)
Return the first data snapshot from `data_handler` after the given `time`.
`next_snapshot!` potentially modifies the internal state of `data_handler` and it might be a
very expensive operation.
"""
function next_snapshot!(data_handler::TemporalDataHandler, time)
# Time in seconds

# Get the current date from `time`
sim_info = data_handler.prescibed_data.sim_info
sim_date = to_datetime(
sim_info.date_ref + Second(round(sim_info.t_start)) + Second(round(t)),
)
# Use next date if it's closest to current time
# This maintains `all_dates[date_idx]` <= `sim_date` < `all_dates[date_idx + 1]`
if sim_date >= to_datetime(next_date_in_file(model_albedo.albedo_info))
read_data_fields!(
data_handler.prescibed_data,
sim_date,
data_handler.space,
)
end

return prescribed_data.file_state.data_fields[2]
end

end
22 changes: 21 additions & 1 deletion src/shared_utilities/TimeVaryingInputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,39 @@ When passing single-site data
When a `times` and `vals` are passed, `times` have to be sorted and the two arrays have to
have the same length.
=======
When the input is a function, the signature of the function can be `func(time, args...;
kwargs...)`. The function will be called with the additional arguments and keyword arguments
passed to `evaluate!`. This can be used to access the state and the cache and use those to
set the output field.
For example:
```julia
CO2fromp(time, Y, p) = p.atmos.co2
input = TimeVaryingInput(CO2fromY)
evaluate!(dest, input, t, Y, p)
```
"""
function TimeVaryingInput end

"""
evaluate!(dest, input, time)
evaluate!(dest, input, time, args...; kwargs...)
Evaluate the `input` at the given `time`, writing the output in-place to `dest`.
Depending on the details of `input`, this function might do I/O and communication.
Extra arguments
================
`args` and `kwargs` are used only when the `input` is a non-interpolating function, e.g.,
an analytic one. In that case, `args` and `kwargs` are passed down to the function itself.
"""
function evaluate! end

include("analytic_time_varying_input.jl")
include("interpolating_time_varying_inputs.jl") # Shared stuff
include("interpolating_time_varying_input0d.jl")
include("interpolating_time_varying_input2d.jl")

end
15 changes: 11 additions & 4 deletions src/shared_utilities/analytic_time_varying_input.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
struct AnalyticTimeVaryingInput{F <: Function} <: AbstractTimeVaryingInput
# func here as to be GPU-compatible (e.g., splines are not)
# func here has to be GPU-compatible (e.g., splines are not)
func::F
end

function TimeVaryingInput(input::Function; method = nothing, device = nothing)
# _kwargs... is needed to seamlessly support the other TimeVaryingInputs.
function TimeVaryingInput(input::Function; _kwargs...)
isnothing(method) ||
@warn "Interpolation method is ignored for analytical functions"
return AnalyticTimeVaryingInput(input)
end

function evaluate!(dest, input::AnalyticTimeVaryingInput, time)
dest .= input.func(time)
function evaluate!(
dest,
input::AnalyticTimeVaryingInput,
time,
args...;
kwargs...,
)
dest .= input.func(time, args...; kwargs...)
return nothing
end
27 changes: 10 additions & 17 deletions src/shared_utilities/interpolating_time_varying_input0d.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
import CUDA

"""
NearestNeighbor
Return the value corresponding to the point closest to the input time.
"""
struct NearestNeighbor <: AbstractInterpolationMethod end

"""
LinearInterpolation
Perform linear interpolation between the two neighboring points.
"""
struct LinearInterpolation <: AbstractInterpolationMethod end

"""
InterpolatingTimeVaryingInput0D
Expand Down Expand Up @@ -67,7 +51,13 @@ function Adapt.adapt_structure(to, itp::InterpolatingTimeVaryingInput0D)
)
end

function evaluate!(destination, itp::InterpolatingTimeVaryingInput0D, time)
function evaluate!(
destination,
itp::InterpolatingTimeVaryingInput0D,
time,
args...;
kwargs...,
)
time in itp || error("TimeVaryingInput does not cover time $time")
if ClimaComms.device(itp.context) isa ClimaComms.CUDADevice
CUDA.@cuda evaluate!(parent(destination), itp, time, itp.method)
Expand Down Expand Up @@ -101,6 +91,7 @@ function TimeVaryingInput(
)
end

<<<<<<< HEAD:src/shared_utilities/interpolating_time_varying_input0d.jl
"""
in(time, itp::InterpolatingTimeVaryingInput0D)
Expand All @@ -111,6 +102,8 @@ function Base.in(time, itp::InterpolatingTimeVaryingInput0D)
end


=======
>>>>>>> a55fa815 (WIP: TimeVaryingInputs2D):src/shared_utilities/interpolating_time_varying_input1d.jl
function evaluate!(
dest,
itp::InterpolatingTimeVaryingInput0D,
Expand Down
91 changes: 91 additions & 0 deletions src/shared_utilities/interpolating_time_varying_input2d.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
InterpolatingTimeVaryingInput2D
The constructor for InterpolatingTimeVaryingInput2D is not supposed to be used directly, unless you
know what you are doing. The constructor does not perform any check and does not take care of
GPU compatibility. It is responsibility of the user-facing constructor TimeVaryingInput() to do so.
"""
struct InterpolatingTimeVaryingInput2D{
DH <: TemporalDataHandler,
AA <: AbstractArray,
M <: AbstractInterpolationMethod,
CC <: ClimaComms.AbstractCommsContext,
R <: Tuple,
} <: AbstractTimeVaryingInput
"""Object that has all the information on how to deal with files, data, and so on.
Having to deal with files, it lives on the CPU."""
data_handler::DH

"""Interpolation method"""
method::M

"""ClimaComms context"""
context::CC

"""Range of times over which the interpolator is defined. range is always defined on the
CPU. Used by the in() function."""
range::R
end

function TimeVaryingInput(
data_handler::TemporalDataHandler,
method = LinearInterpolation(),
context = ClimaComms.context(),
)
range = (first_time(data_handler), last_time(data_handler))
return InterpolatingTimeVaryingInput2D(data_handler, method, context, range)
end

function evaluate!(
dest,
itp::InterpolatingTimeVaryingInput2D,
time,
args...;
kwargs...,
)
time in itp || error("TimeVaryingInput does not cover time $time")
evaluate!(parent(dest), itp, time, itp.method)
return nothing
end

function evaluate!(
dest,
itp::InterpolatingTimeVaryingInput2D,
time,
::NearestNeighbor,
args...;
kwargs...,
)
t0, t1 =
previous_time(itp.data_handler, time), next_time(itp.data_handler, time)

# The closest snapshot could be either the previous or the next one
if (time - t0) < (t1 - time)
dest .= previous_snapshot!(itp.data_handler, time)
else
dest .= next_snapshot!(itp.data_handler, time)
end
end

function evaluate!(
dest,
itp::InterpolatingTimeVaryingInput2D,
time,
::LinearInterpolation,
args...;
kwargs...,
)
# Linear interpolation is:
# y = y0 + (y1 - y0) * (time - t0) / (t1 - t0)
#
# Define coeff = (time - t0) / (t1 - t0)
#
# y = (1 - coeff) * y0 + coeff * y1

t0, t1 =
previous_time(itp.data_handler, time), next_time(itp.data_handler, time)
coeff = (time - t0) / (t1 - t0)
dest .=
(1 - coeff) * previous_snapshot!(itp.data_handler, time) .+
coeff * next_snapshot!(itp.data_handler, time)
end
Loading

0 comments on commit 70d645e

Please sign in to comment.