From 182b471cdd189ce2dbd88b36d82dc90554371385 Mon Sep 17 00:00:00 2001 From: miili Date: Sun, 12 Nov 2023 18:39:32 +0100 Subject: [PATCH] quadtree: using dataclasses --- lassie/models/semblance.py | 2 +- lassie/octree.py | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lassie/models/semblance.py b/lassie/models/semblance.py index 75f5ef78..21cf3ac6 100644 --- a/lassie/models/semblance.py +++ b/lassie/models/semblance.py @@ -111,7 +111,7 @@ async def maxima_node_idx(self, nparallel: int = 6) -> np.ndarray: if self._node_idx_max is None: self._node_idx_max = await asyncio.to_thread( parstack.argmax, - self.semblance.astype(np.float64), + np.ascontiguousarray(self.semblance), nparallel=nparallel, ) return self._node_idx_max diff --git a/lassie/octree.py b/lassie/octree.py index 30e8823d..66d89ee3 100644 --- a/lassie/octree.py +++ b/lassie/octree.py @@ -4,6 +4,7 @@ import logging import struct from collections import defaultdict +from dataclasses import dataclass from functools import cached_property from hashlib import sha1 from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence @@ -57,19 +58,20 @@ class NodeSplitError(Exception): ... -class Node(BaseModel): +@dataclass(slots=True) +class Node: east: float north: float depth: float size: float semblance: float = 0.0 - tree: Octree | None = Field(None, exclude=True) - children: tuple[Node, ...] = Field(default=(), exclude=True) + tree: Octree | None = None + children: tuple[Node, ...] = () - _hash: bytes | None = PrivateAttr(None) - _children_cached: tuple[Node, ...] = PrivateAttr(()) - _location: Location | None = PrivateAttr(None) + _hash: bytes | None = None + _children_cached: tuple[Node, ...] = () + _location: Location | None = None def split(self) -> tuple[Node, ...]: if not self.tree: @@ -82,7 +84,7 @@ def split(self) -> tuple[Node, ...]: half_size = self.size / 2 self._children_cached = tuple( - Node.model_construct( + Node( east=self.east + east * half_size / 2, north=self.north + north * half_size / 2, depth=self.depth + depth * half_size / 2, @@ -287,9 +289,7 @@ def _get_root_nodes(self, length: float) -> list[Node]: depth_nodes = np.arange(ext_depth // ln) * ln + ln / 2 + self.depth_bounds[0] return [ - Node.model_construct( - east=east, north=north, depth=depth, size=ln, tree=self - ) + Node(east=east, north=north, depth=depth, size=ln, tree=self) for east in east_nodes for north in north_nodes for depth in depth_nodes