Skip to content

Commit

Permalink
additions, using buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Apr 29, 2024
1 parent e6ba246 commit 79a3a83
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 10 deletions.
14 changes: 6 additions & 8 deletions anemoi/datasets/create/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from anemoi.datasets import MissingDateError
from anemoi.datasets import open_dataset
from anemoi.datasets.create.persistent import PersistentDict
from anemoi.datasets.create.persistent import build_storage
from anemoi.datasets.data.misc import as_first_date
from anemoi.datasets.data.misc import as_last_date
from anemoi.datasets.dates.groups import Groups
Expand Down Expand Up @@ -574,7 +574,7 @@ def check_missing_dates(missing_dates):

self.name = name
storage_path = os.path.join(self.path + ".tmp_data", name)
self.tmp_storage = PersistentDict(directory=storage_path, create=True)
self.tmp_storage = build_storage(directory=storage_path, create=True)

def initialise(self):
self.tmp_storage.delete()
Expand Down Expand Up @@ -611,9 +611,9 @@ def _variables_with_nans(self):
return None

def allow_nan(self, name):
    """Tell whether NaN values are acceptable for the variable `name`.

    When ``self._variables_with_nans`` is available (not ``None``), the
    answer is a plain membership test. An empty collection is a valid answer
    meaning "no variable may contain NaNs", which is why the check is
    ``is not None`` rather than a truthiness test.

    When it is ``None``, a warning is emitted and NaNs are assumed to be
    allowed.
    """
    # Read the attribute once; every later use goes through the local.
    variables_with_nans = self._variables_with_nans
    if variables_with_nans is not None:
        return name in variables_with_nans
    warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, Assuming nans allowed for {name}.")
    return True

def run(self, parts):
Expand All @@ -634,7 +634,8 @@ def run(self, parts):
missing_dates = [date]

dates = sorted(dates_ok + missing_dates)
self.tmp_storage[dates] = [dates_ok, missing_dates, i, stats]
self.tmp_storage.add([dates_ok, missing_dates, i, stats], key=dates)
self.tmp_storage.flush()

LOG.info(f"Dataset {self.path} additions run.")

Expand Down Expand Up @@ -716,7 +717,6 @@ def finalise(self):
assert sums.shape == has_nans.shape

mean = sums / count
print(f"Aggregate {sums=} {count=} -> {mean=}")
assert sums.shape == mean.shape

x = squares / count - mean * mean
Expand Down Expand Up @@ -752,8 +752,6 @@ def _write(self, summary):
def check_statistics(self):
ds = open_dataset(self.path)
ref = ds.statistics
for k in ["mean", "stdev", "minimum", "maximum"]:
print("✅❌", k, ds.statistics[k], self.summary[k])
for k in ds.statistics:
assert np.all(np.isclose(ref[k], self.summary[k], rtol=1e-4, atol=1e-4)), (
k,
Expand Down
41 changes: 41 additions & 0 deletions anemoi/datasets/create/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def add_provenance(self, **kwargs):
with open(os.path.join(self.dirname, "provenance.json"), "w") as f:
json.dump(out, f)

def add(self, elt, *, key):
    """Store `elt` under `key`; keyword-friendly alias for item assignment."""
    self.__setitem__(key, elt)

def __setitem__(self, key, elt):
h = hashlib.sha256(str(key).encode("utf-8")).hexdigest()
path = os.path.join(self.dirname, f"{h}.pickle")
Expand All @@ -76,6 +79,44 @@ def __setitem__(self, key, elt):

LOG.debug(f"Written {self.name} data for len {key} in {path}")

def flush(self):
    """Do nothing; this method performs no buffering of its own."""
    return None


class BufferedPersistentDict(PersistentDict):
    """Batching front end that writes to a wrapped ``PersistentDict``.

    Elements given to :meth:`add` are held in memory and written to
    ``self.storage`` in batches. Each batch becomes a single stored entry
    whose key is the sorted list of buffered keys and whose value is the list
    of elements in that same order. :meth:`items` unpacks the batches back
    into individual ``(key, element)`` pairs.

    Callers must call :meth:`flush` once they are done adding, otherwise the
    elements still held in the buffer are never written.

    NOTE(review): ``PersistentDict.__init__`` is not called on this object;
    only the methods defined here are known to work, since they all delegate
    to ``self.storage``. Confirm before relying on other inherited methods.
    """

    def __init__(self, buffer_size=1000, **kwargs):
        """Create the buffer and the wrapped storage.

        `buffer_size` is the number of buffered elements that triggers an
        automatic flush; `kwargs` are forwarded to ``PersistentDict``.
        """
        self.buffer_size = buffer_size
        self.elements = []
        self.keys = []
        self.storage = PersistentDict(**kwargs)

    def add(self, elt, *, key):
        """Buffer `elt` under `key`, flushing once the buffer is full."""
        self.elements.append(elt)
        self.keys.append(key)
        # ">=" so the buffer never holds more than `buffer_size` entries.
        if len(self.keys) >= self.buffer_size:
            self.flush()

    def flush(self):
        """Write the buffered elements as one batch and empty the buffer.

        Does nothing when the buffer is empty, so repeated calls do not
        create empty entries in the underlying storage.
        """
        if not self.keys:
            return

        # Sort keys and elements *together*: items() re-associates them by
        # position, so sorting only the keys would pair elements with the
        # wrong key whenever keys were added out of order.
        pairs = sorted(zip(self.keys, self.elements), key=lambda pair: pair[0])
        keys = [key for key, _ in pairs]
        elements = [elt for _, elt in pairs]

        self.storage.add(elements, key=keys)
        self.elements = []
        self.keys = []

    def items(self):
        """Yield every ``(key, element)`` pair written to the storage so far."""
        for keys, elements in self.storage.items():
            yield from zip(keys, elements)

    def delete(self):
        """Delegate to the wrapped storage's ``delete``."""
        self.storage.delete()

    def create(self):
        """Delegate to the wrapped storage's ``create``."""
        self.storage.create()


def build_storage(directory, create=True):
    """Return a ``BufferedPersistentDict`` built with `directory` and `create`."""
    storage = BufferedPersistentDict(directory=directory, create=create)
    return storage


if __name__ == "__main__":
N = 3
Expand Down
2 changes: 0 additions & 2 deletions anemoi/datasets/create/statistics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False):
sums[i] = np.nansum(values, axis=1)
squares[i] = np.nansum(values * values, axis=1)
count[i] = np.sum(~np.isnan(values), axis=1)
print(".❌", i, count[i], sums[i], squares[i], ":", values.shape, array.shape)
has_nans[i] = np.isnan(values).any()

return {
Expand Down Expand Up @@ -311,7 +310,6 @@ def aggregate(self):
count = np.nansum(self.count, axis=0)
has_nans = np.any(self.has_nans, axis=0)
mean = sums / count
print(f"Aggregate {sums=} {count=} -> {mean=}")

assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape

Expand Down

0 comments on commit 79a3a83

Please sign in to comment.