Skip to content

Commit

Permalink
Remove AbstractTOMLDict
Browse files Browse the repository at this point in the history
  • Loading branch information
nefrathenrici committed May 31, 2024
1 parent ef8c2e4 commit b64fce3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 56 deletions.
7 changes: 0 additions & 7 deletions docs/src/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,6 @@
CurrentModule = ClimaParams
```

## Parameter dictionaries

```@docs
AbstractTOMLDict
ParamDict
```

## File parsing and parameter logging

### User facing functions:
Expand Down
1 change: 0 additions & 1 deletion src/ClimaParams.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module ClimaParams
using TOML
using DocStringExtensions

export AbstractTOMLDict
export ParamDict

export float_type,
Expand Down
86 changes: 38 additions & 48 deletions src/file_parsing.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,3 @@
"""
AbstractTOMLDict{FT <: AbstractFloat}
Abstract parameter dict. One subtype:
- [`ParamDict`](@ref)
"""
abstract type AbstractTOMLDict{FT <: AbstractFloat} end

const NAMESTYPE =
Union{AbstractVector{S}, NTuple{N, S} where {N}} where {S <: AbstractString}

Expand All @@ -21,27 +13,29 @@ Uses the name to search
$(DocStringExtensions.FIELDS)
"""
struct ParamDict{FT} <: AbstractTOMLDict{FT}
struct ParamDict{FT <: AbstractFloat}
"dictionary representing a default/merged parameter TOML file"
data::Dict
"either a nothing, or a dictionary representing an override parameter TOML file"
override_dict::Union{Nothing, Dict}
end

"""
float_type(::AbstractTOMLDict)
float_type(::ParamDict)
The float type from the parameter dict.
"""
float_type(::AbstractTOMLDict{FT}) where {FT} = FT
float_type(::ParamDict{FT}) where {FT} = FT

Base.iterate(pd::ParamDict, state) = Base.iterate(pd.data, state)
Base.iterate(pd::ParamDict) = Base.iterate(pd.data)

Base.getindex(pd::ParamDict, i) = getindex(pd.data, i)

Base.print(td::ParamDict, io = stdout) = TOML.print(io, td.data)

Check warning on line 35 in src/file_parsing.jl

View check run for this annotation

Codecov / codecov/patch

src/file_parsing.jl#L35

Added line #L35 was not covered by tests

"""
log_component!(pd::AbstractTOMLDict, names, component)
log_component!(pd::ParamDict, names, component)
Adds a new key,val pair: `("used_in",component)` to each
named parameter in `pd`.
Expand Down Expand Up @@ -76,7 +70,7 @@ enforces `val` to be of type as specified in the toml file
Default type of `String` is used if no type is provided.
"""
function _get_typed_value(
pd::AbstractTOMLDict,
pd::ParamDict,
val,
valname::AbstractString,
valtype,
Expand All @@ -103,7 +97,7 @@ function _get_typed_value(
end

"""
get_values(pd::AbstractTOMLDict, names)
get_values(pd::ParamDict, names)
Gets the values of the parameters in `names` from the TOML dict `pd`.
"""
Expand All @@ -129,13 +123,13 @@ end

"""
get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::Union{String,Vector{String}},
component::String
)
get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Union{Dict, Vector{Pair}, NTuple{N, Pair}, Vararg{Pair}},
component::String
)
Expand All @@ -149,15 +143,15 @@ parameter names to variable names in code. Then, this function retrieves all par
from the long names and returns a NamedTuple where the keys are the variable names.
"""
function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::AbstractString,
component = nothing,
)
return get_parameter_values(pd, [names], component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
names::NAMESTYPE,
component::Union{AbstractString, Nothing} = nothing,
)
Expand All @@ -168,15 +162,15 @@ function get_parameter_values(
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Union{AbstractVector{Pair{S, S}}, NTuple{N, Pair}},
component = nothing,
) where {S, N}
return get_parameter_values(pd, Dict(name_map), component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Vararg{Pair};
component = nothing,
)
Expand All @@ -188,7 +182,7 @@ function get_parameter_values(
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Dict{S, S},
component = nothing,
) where {S <: AbstractString}
Expand All @@ -201,15 +195,15 @@ function get_parameter_values(
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::NamedTuple,
component = nothing,
)
return get_parameter_values(pd, Dict(pairs(name_map)), component)
end

function get_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
name_map::Dict{Symbol, Symbol},
component = nothing,
)
Expand Down Expand Up @@ -287,14 +281,14 @@ key "used_in" (i.e. were these parameters used within the model run).
Throws warnings in each where parameters are not used. Also throws
an error if `strict == true` .
"""
check_override_parameter_usage(pd::AbstractTOMLDict, strict::Bool) =
check_override_parameter_usage(pd::ParamDict, strict::Bool) =
check_override_parameter_usage(pd, strict, pd.override_dict)

check_override_parameter_usage(pd::AbstractTOMLDict, strict::Bool, ::Nothing) =
check_override_parameter_usage(pd::ParamDict, strict::Bool, ::Nothing) =
nothing

function check_override_parameter_usage(
pd::AbstractTOMLDict,
pd::ParamDict,
strict::Bool,
override_dict,
)
Expand Down Expand Up @@ -329,12 +323,12 @@ function check_override_parameter_usage(
end

"""
write_log_file(pd::AbstractTOMLDict, filepath)
write_log_file(pd::ParamDict, filepath)
Writes a log file of all used parameters of `pd` at
the `filepath`. This file can be used to rerun the experiment.
"""
function write_log_file(pd::AbstractTOMLDict, filepath::AbstractString)
function write_log_file(pd::ParamDict, filepath::AbstractString)
used_parameters = Dict()
for (key, val) in pd.data
if "used_in" in keys(val)
Expand All @@ -349,7 +343,7 @@ end

"""
log_parameter_information(
pd::AbstractTOMLDict,
pd::ParamDict,
filepath;
strict::Bool = false
)
Expand All @@ -360,7 +354,7 @@ override parameters are all used.
If `strict = true`, errors if override parameters are unused.
"""
function log_parameter_information(
pd::AbstractTOMLDict,
pd::ParamDict,
filepath::AbstractString;
strict::Bool = false,
)
Expand All @@ -372,17 +366,17 @@ end

"""
merge_override_default_values(
override_toml_dict::AbstractTOMLDict{FT},
default_toml_dict::AbstractTOMLDict{FT}
override_toml_dict::ParamDict,
default_toml_dict::ParamDict
) where {FT}
Combines the `default_toml_dict` with the `override_toml_dict`,
precedence is given to override information.
"""
function merge_override_default_values(
override_toml_dict::PDT,
default_toml_dict::PDT,
) where {FT, PDT <: AbstractTOMLDict{FT}}
override_toml_dict::ParamDict{FT},
default_toml_dict::ParamDict{FT},
) where {FT}
data = default_toml_dict.data
override_dict = override_toml_dict.override_dict
for (key, val) in override_toml_dict.data
Expand All @@ -394,7 +388,7 @@ function merge_override_default_values(
end
end
end
return PDT(data, override_dict)
return ParamDict{FT}(data, override_dict)
end

"""
Expand Down Expand Up @@ -425,17 +419,13 @@ function create_toml_dict(
return merge_override_default_values(override_toml_dict, default_toml_dict)
end

# Extend Base.print to AbstractTOMLDict
Base.print(td::AbstractTOMLDict, io = stdout) = TOML.print(io, td.data)


"""
get_tagged_parameter_names(pd::AbstractTOMLDict, tag::AbstractString)
get_tagged_parameter_names(pd::AbstractTOMLDict, tags::Vector{AbstractString})
get_tagged_parameter_names(pd::ParamDict, tag::AbstractString)
get_tagged_parameter_names(pd::ParamDict, tags::Vector{AbstractString})
Returns a list of the parameters with a given tag.
"""
function get_tagged_parameter_names(pd::AbstractTOMLDict, tag::AbstractString)
function get_tagged_parameter_names(pd::ParamDict, tag::AbstractString)
data = pd.data
ret_values = String[]
for (key, val) in data
Expand All @@ -447,7 +437,7 @@ function get_tagged_parameter_names(pd::AbstractTOMLDict, tag::AbstractString)
end

get_tagged_parameter_names(
pd::AbstractTOMLDict,
pd::ParamDict,
tags::Vector{S},
) where {S <: AbstractString} =
vcat(map(x -> get_tagged_parameter_names(pd, x), tags)...)
Expand All @@ -464,16 +454,16 @@ function fuzzy_match(s1::AbstractString, s2::AbstractString)
end

"""
get_tagged_parameter_values(pd::AbstractTOMLDict, tag::AbstractString)
get_tagged_parameter_values(pd::AbstractTOMLDict, tags::Vector{AbstractString})
get_tagged_parameter_values(pd::ParamDict, tag::AbstractString)
get_tagged_parameter_values(pd::ParamDict, tags::Vector{AbstractString})
Returns a list of name-value Pairs of the parameters with the given tag(s).
"""
get_tagged_parameter_values(pd::AbstractTOMLDict, tag::AbstractString) =
get_tagged_parameter_values(pd::ParamDict, tag::AbstractString) =
get_parameter_values(pd, get_tagged_parameter_names(pd, tag))

get_tagged_parameter_values(
pd::AbstractTOMLDict,
pd::ParamDict,
tags::Vector{S},
) where {S <: AbstractString} =
merge(map(x -> get_tagged_parameter_values(pd, x), tags)...)

0 comments on commit b64fce3

Please sign in to comment.