Skip to content

Commit

Permalink
Require string representation of infinity
Browse files Browse the repository at this point in the history
Closes #208
  • Loading branch information
nsmith- committed Jun 18, 2024
1 parent 8197e4f commit 88a50b4
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 54 deletions.
51 changes: 31 additions & 20 deletions src/correction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,30 @@ namespace {
}
throw std::runtime_error("Error: could not find variable " + std::string(name) + " in inputs");
}

double parse_edge(const rapidjson::Value& edge) {
if ( edge.IsDouble() ) {
return edge.GetDouble();
} else if ( edge.IsString() ) {
std::string_view str = edge.GetString();
if (str == "inf" or str == "+inf") return std::numeric_limits<double>::infinity();
else if (str == "-inf") return -std::numeric_limits<double>::infinity();
}
throw std::runtime_error("Invalid edge type");
}

std::vector<double> parse_bin_edges(const rapidjson::Value::ConstArray& edges) {
std::vector<double> result;
result.reserve(edges.Size());
for (const auto& edge : edges) {
double val = parse_edge(edge);
if ( result.size() > 0 && result.back() >= val ) {
throw std::runtime_error("binning edges are not monotone increasing");
}
result.push_back(val);
}
return result;
}
} // end of anonymous namespace

Variable::Variable(const JSONObject& json) :
Expand Down Expand Up @@ -235,7 +259,7 @@ void Variable::validate(const Type& t) const {

Variable Variable::from_string(const char * data) {
rapidjson::Document json;
rapidjson::ParseResult ok = json.Parse<rapidjson::kParseNanAndInfFlag>(data);
rapidjson::ParseResult ok = json.Parse(data);
if (!ok) {
throw std::runtime_error(
std::string("JSON parse error: ") + rapidjson::GetParseError_En(ok.Code())
Expand Down Expand Up @@ -283,7 +307,7 @@ Formula::Formula(const JSONObject& json, const std::vector<Variable>& inputs, bo

Formula::Ref Formula::from_string(const char * data, std::vector<Variable>& inputs) {
rapidjson::Document json;
rapidjson::ParseResult ok = json.Parse<rapidjson::kParseNanAndInfFlag>(data);
rapidjson::ParseResult ok = json.Parse(data);
if (!ok) {
throw std::runtime_error(
std::string("JSON parse error: ") + rapidjson::GetParseError_En(ok.Code())
Expand Down Expand Up @@ -403,14 +427,7 @@ Binning::Binning(const JSONObject& json, const Correction& context)
// set bins_
const auto &edgesObj = json.getRequiredValue("edges");
if ( edgesObj.IsArray() ) { // non-uniform binning
std::vector<double> edges;
rapidjson::Value::ConstArray edgesArr = edgesObj.GetArray();
for (const auto& edge : edgesArr) {
if ( ! edge.IsDouble() ) { throw std::runtime_error("Invalid edges array type"); }
double val = edge.GetDouble();
if ( edges.size() > 0 && edges.back() >= val ) { throw std::runtime_error("binning edges are not monotone increasing"); }
edges.push_back(val);
}
std::vector<double> edges = parse_bin_edges(edgesObj.GetArray());
if ( edges.size() != content.Size() + 1 ) {
throw std::runtime_error("Inconsistency in Binning: number of content nodes does not match binning");
}
Expand Down Expand Up @@ -470,13 +487,7 @@ MultiBinning::MultiBinning(const JSONObject& json, const Correction& context)
for (const auto& dimension : edges) {
const auto& input = inputs[idx];
if ( dimension.IsArray() ) { // non-uniform binning
std::vector<double> dim_edges;
dim_edges.reserve(dimension.GetArray().Size());
for (const auto& item : dimension.GetArray()) {
double val = item.GetDouble();
if ( dim_edges.size() > 0 && dim_edges.back() >= val ) { throw std::runtime_error("binning edges are not monotone increasing"); }
dim_edges.push_back(val);
}
std::vector<double> dim_edges = parse_bin_edges(dimension.GetArray());
if ( ! input.IsString() ) { throw std::runtime_error("invalid multibinning input type"); }
axes_.push_back({input_index(input.GetString(), context.inputs()), 0, _NonUniformBins(std::move(dim_edges))});
} else if ( dimension.IsObject() ) { // UniformBinning
Expand Down Expand Up @@ -776,14 +787,14 @@ std::unique_ptr<CorrectionSet> CorrectionSet::from_file(const std::string& fn) {
#ifdef WITH_ZLIB
gzFile_s* fpz = gzopen(fn.c_str(), "r");
rapidjson::GzFileReadStream is(fpz, readBuffer, sizeof(readBuffer));
ok = json.ParseStream<rapidjson::kParseNanAndInfFlag>(is);
ok = json.ParseStream(is);
gzclose(fpz);
#else
throw std::runtime_error("Gzip-compressed JSON files are only supported if ZLIB is found when the package is built");
#endif
} else {
rapidjson::FileReadStream is(fp, readBuffer, sizeof(readBuffer));
ok = json.ParseStream<rapidjson::kParseNanAndInfFlag>(is);
ok = json.ParseStream(is);
fclose(fp);
}
if (!ok) {
Expand All @@ -798,7 +809,7 @@ std::unique_ptr<CorrectionSet> CorrectionSet::from_file(const std::string& fn) {

std::unique_ptr<CorrectionSet> CorrectionSet::from_string(const char * data) {
rapidjson::Document json;
rapidjson::ParseResult ok = json.Parse<rapidjson::kParseNanAndInfFlag>(data);
rapidjson::ParseResult ok = json.Parse(data);
if (!ok) {
throw std::runtime_error(
std::string("JSON parse error: ") + rapidjson::GetParseError_En(ok.Code())
Expand Down
65 changes: 32 additions & 33 deletions src/correctionlib/schemav2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import math
import sys
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Annotated, Dict, List, Optional, Set, Tuple, Union

from pydantic import (
AfterValidator,
BaseModel,
ConfigDict,
Field,
Expand Down Expand Up @@ -184,35 +186,44 @@ def validate_edges(cls, high: float, info: ValidationInfo) -> float:
return high


Infinity = Literal["inf", "+inf", "-inf"]
Edges = List[Union[float, Infinity]]


def validate_nonuniform_edges(edges: Edges) -> Edges:
for edge in edges:
if edge in ("inf", "+inf", "-inf"):
continue
if isinstance(edge, float):
if not math.isfinite(edge):
raise ValueError(
f"Edges array contains non-finite values: {edges}. Replace infinities with 'inf' or '-inf'. NaN is not allowed."
)
floatedges = [float(x) for x in edges]
for lo, hi in zip(floatedges[:-1], floatedges[1:]):
if lo >= hi:
raise ValueError(f"Binning edges not monotonically increasing: {edges}")
return edges


NonUniformBinning = Annotated[Edges, AfterValidator(validate_nonuniform_edges)]


class Binning(Model):
"""1-dimensional binning in an input variable"""

nodetype: Literal["binning"]
input: str = Field(
description="The name of the correction input variable this binning applies to"
)
edges: Union[List[float], UniformBinning] = Field(
edges: Union[NonUniformBinning, UniformBinning] = Field(
description="Edges of the binning, either as a list of monotonically increasing floats or as an instance of UniformBinning. edges[i] <= x < edges[i+1] => f(x, ...) = content[i](...)"
)
content: List[Content]
flow: Union[Content, Literal["clamp", "error"]] = Field(
description="Overflow behavior for out-of-bounds values"
)

@field_validator("edges")
@classmethod
def validate_edges(
cls, edges: Union[List[float], UniformBinning]
) -> Union[List[float], UniformBinning]:
if isinstance(edges, list):
for lo, hi in zip(edges[:-1], edges[1:]):
if hi <= lo:
raise ValueError(
f"Binning edges not monotonically increasing: {edges}"
)

return edges

@field_validator("content")
@classmethod
def validate_content(
Expand All @@ -234,8 +245,10 @@ def summarize(
) -> None:
nodecount["Binning"] += 1
inputstats[self.input].overflow &= self.flow != "error"
low = self.edges[0] if isinstance(self.edges, list) else self.edges.low
high = self.edges[-1] if isinstance(self.edges, list) else self.edges.high
low = float(self.edges[0]) if isinstance(self.edges, list) else self.edges.low
high = (
float(self.edges[-1]) if isinstance(self.edges, list) else self.edges.high
)
inputstats[self.input].min = min(inputstats[self.input].min, low)
inputstats[self.input].max = max(inputstats[self.input].max, high)
for item in self.content:
Expand All @@ -253,7 +266,7 @@ class MultiBinning(Model):
description="The names of the correction input variables this binning applies to",
min_length=1,
)
edges: List[Union[List[float], UniformBinning]] = Field(
edges: List[Union[NonUniformBinning, UniformBinning]] = Field(
description="Bin edges for each input"
)
content: List[Content] = Field(
Expand All @@ -266,20 +279,6 @@ class MultiBinning(Model):
description="Overflow behavior for out-of-bounds values"
)

@field_validator("edges")
@classmethod
def validate_edges(
cls, edges: List[Union[List[float], UniformBinning]]
) -> List[Union[List[float], UniformBinning]]:
for i, dim in enumerate(edges):
if isinstance(dim, list):
for lo, hi in zip(dim[:-1], dim[1:]):
if hi <= lo:
raise ValueError(
f"MultiBinning edges for axis {i} are not monotone increasing: {dim}"
)
return edges

@field_validator("content")
@classmethod
def validate_content(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_evaluator():
{
"nodetype": "binning",
"input": "pt",
"edges": [0, 20, 40, float("inf")],
"edges": [0, 20, 40, "inf"],
"flow": "error",
"content": [
schema.Category.model_validate(
Expand Down

0 comments on commit 88a50b4

Please sign in to comment.