diff --git a/src/climetlab/readers/grib/output.py b/src/climetlab/readers/grib/output.py index c7438104..5a6931a4 100644 --- a/src/climetlab/readers/grib/output.py +++ b/src/climetlab/readers/grib/output.py @@ -61,22 +61,8 @@ def __getitem__(self, key): return self.handle.get(key) -class GribOutput: - def __init__(self, file, split_output=False, template=None, **kwargs): - self._files = {} - self.fileobj = None - self.filename = None - - if isinstance(file, IOBase): - self.fileobj = file - split_output = False - else: - self.filename = file - - if split_output: - self.split_output = re.findall(r"\{(.*?)\}", self.filename) - else: - self.split_output = None +class GribCoder: + def __init__(self, template=None, **kwargs): self.template = template self._bbox = {} @@ -87,20 +73,7 @@ def __init__(self, file, split_output=False, template=None, **kwargs): def _normalize_kwargs_names(self, **kwargs): return kwargs - def f(self, handle): - if self.fileobj: - return self.fileobj, None - - if self.split_output: - path = self.filename.format(**{k: handle.get(k) for k in self.split_output}) - else: - path = self.filename - - if path not in self._files: - self._files[path] = open(path, "wb") - return self._files[path], path - - def write( + def encode( self, values, check_nans=False, @@ -161,20 +134,7 @@ def write( if "generatingProcessIdentifier" in metadata: handle.set("generatingProcessIdentifier", metadata["generatingProcessIdentifier"]) - file, path = self.f(handle) - handle.write(file) - - return handle, path - - def close(self): - for f in self._files.values(): - f.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, trace): - self.close() + return handle def update_metadata(self, handle, metadata, compulsary): # TODO: revisit that logic @@ -389,5 +349,73 @@ def _gg_field(self, values, metadata): return f"reduced_gg_{levtype}_{N}_grib{edition}" +class GribOutput: + def __init__(self, file, split_output=False, template=None, **kwargs): + self._files = {} + self.fileobj = None + self.filename = None + + if isinstance(file, IOBase): + self.fileobj = file + split_output = False + else: + self.filename = file + + if split_output: + self.split_output = re.findall(r"\{(.*?)\}", self.filename) + else: + self.split_output = None + + self._coder = GribCoder(template=template, **kwargs) + + def close(self): + for f in self._files.values(): + f.close() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, trace): + self.close() + + def write( + self, + values, + check_nans=False, + metadata={}, + template=None, + **kwargs, + ): + handle = self._coder.encode( + values, + check_nans=check_nans, + metadata=metadata, + template=template, + **kwargs, + ) + + file, path = self.f(handle) + handle.write(file) + + return handle, path + + def f(self, handle): + if self.fileobj: + return self.fileobj, None + + if self.split_output: + path = self.filename.format(**{k: handle.get(k) for k in self.split_output}) + else: + path = self.filename + + if path not in self._files: + self._files[path] = open(path, "wb") + return self._files[path], path + + def new_grib_output(*args, **kwargs): return GribOutput(*args, **kwargs) + + +def new_grib_coder(*args, **kwargs): + return GribCoder(*args, **kwargs)