Skip to content

Commit

Permalink
quadtree: using dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
miili committed Nov 12, 2023
1 parent ecb7a1e commit 182b471
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lassie/models/semblance.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def maxima_node_idx(self, nparallel: int = 6) -> np.ndarray:
if self._node_idx_max is None:
self._node_idx_max = await asyncio.to_thread(
parstack.argmax,
self.semblance.astype(np.float64),
np.ascontiguousarray(self.semblance),
nparallel=nparallel,
)
return self._node_idx_max
Expand Down
20 changes: 10 additions & 10 deletions lassie/octree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import struct
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property
from hashlib import sha1
from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence
Expand Down Expand Up @@ -57,19 +58,20 @@ class NodeSplitError(Exception):
...


class Node(BaseModel):
@dataclass(slots=True)
class Node:
east: float
north: float
depth: float
size: float
semblance: float = 0.0

tree: Octree | None = Field(None, exclude=True)
children: tuple[Node, ...] = Field(default=(), exclude=True)
tree: Octree | None = None
children: tuple[Node, ...] = ()

_hash: bytes | None = PrivateAttr(None)
_children_cached: tuple[Node, ...] = PrivateAttr(())
_location: Location | None = PrivateAttr(None)
_hash: bytes | None = None
_children_cached: tuple[Node, ...] = ()
_location: Location | None = None

def split(self) -> tuple[Node, ...]:
if not self.tree:
Expand All @@ -82,7 +84,7 @@ def split(self) -> tuple[Node, ...]:
half_size = self.size / 2

self._children_cached = tuple(
Node.model_construct(
Node(
east=self.east + east * half_size / 2,
north=self.north + north * half_size / 2,
depth=self.depth + depth * half_size / 2,
Expand Down Expand Up @@ -287,9 +289,7 @@ def _get_root_nodes(self, length: float) -> list[Node]:
depth_nodes = np.arange(ext_depth // ln) * ln + ln / 2 + self.depth_bounds[0]

return [
Node.model_construct(
east=east, north=north, depth=depth, size=ln, tree=self
)
Node(east=east, north=north, depth=depth, size=ln, tree=self)
for east in east_nodes
for north in north_nodes
for depth in depth_nodes
Expand Down

0 comments on commit 182b471

Please sign in to comment.