Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
arnaudon committed Nov 19, 2024
1 parent b886749 commit 0d31987
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 22 deletions.
5 changes: 3 additions & 2 deletions examples/synthesize_single_neuron_y_direction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
This example shows how to synthesize a single cell with simple parameters.
"""

import numpy as np
from pathlib import Path
import json
from pathlib import Path

import numpy as np

import neurots

Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/algorithms/abstractgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AbstractAlgo:

def __init__(self, input_data, params, start_point, context):
"""The TreeGrower Algorithm initialization."""
self.context = context if context is not None else {}
self.context = context
self.input_data = copy.deepcopy(input_data)
self.params = copy.deepcopy(params)
self.start_point = start_point
Expand Down
12 changes: 2 additions & 10 deletions neurots/generate/algorithms/basicgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from neurots.generate.algorithms.abstractgrower import AbstractAlgo
from neurots.generate.algorithms.common import bif_methods
from neurots.generate.algorithms.common import section_data
from neurots.morphmath import rotation

logger = logging.getLogger(__name__)

Expand All @@ -26,14 +25,7 @@ class TrunkAlgo(AbstractAlgo):
context (Any): An object containing contextual information.
"""

def __init__(
self,
input_data,
params,
start_point,
context=None,
**_,
):
def __init__(self, input_data, params, start_point, context=None, **_):
"""Constructor of the TrunkAlgo class."""
super().__init__(input_data, params, start_point, context)
self.bif_method = bif_methods[params["branching_method"]]
Expand All @@ -60,7 +52,7 @@ def bifurcate(self, current_section):
Returns:
tuple[dict, dict]: Two dictionaries containing the two children sections data.
"""
dir1, dir2 = self.bif_method(pia_rotation=self.pia_rotation)
dir1, dir2 = self.bif_method(pia_rotation=self.context.get("y_rotation"))
first_point = np.array(current_section.last_point)
stop = current_section.stop_criteria

Expand Down
3 changes: 1 addition & 2 deletions neurots/generate/algorithms/tmdgrower.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from neurots.generate.algorithms.common import TMDStop
from neurots.generate.algorithms.common import bif_methods
from neurots.generate.algorithms.common import section_data
from neurots.morphmath import rotation
from neurots.morphmath import sample
from neurots.morphmath.utils import norm

Expand Down Expand Up @@ -253,7 +252,7 @@ def bifurcate(self, current_section):

if current_section.process == "major":
dir1, dir2 = bif_methods["directional"](
current_section.direction, angles=ang, y_rotation=self.context.get('y_rotation')
current_section.direction, angles=ang, y_rotation=self.context.get("y_rotation")
)

if not self._found_last_bif:
Expand Down
13 changes: 8 additions & 5 deletions neurots/generate/grower.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@
from numpy.random import RandomState
from numpy.random import SeedSequence

from neurots.morphmath import rotation
from neurots.utils import Y_DIRECTION
from neurots.generate import diametrizer
from neurots.generate import orientations as _oris
from neurots.generate.orientations import OrientationManager
from neurots.generate.orientations import check_3d_angles
from neurots.generate.soma import Soma
from neurots.generate.soma import SomaGrower
from neurots.generate.tree import TreeGrower
from neurots.morphmath import rotation
from neurots.morphmath import sample
from neurots.morphmath.utils import normalize_vectors
from neurots.preprocess import preprocess_inputs
from neurots.utils import Y_DIRECTION
from neurots.utils import NeuroTSError
from neurots.utils import convert_from_legacy_neurite_type
from neurots.utils import point_to_section_segment
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
):
"""Constructor of the NeuronGrower class."""
self.neuron = Morphology()
self.context = self.process_context(context)
self.context = self._process_context(context)
if rng_or_seed is None or isinstance(
rng_or_seed, (int, np.integer, SeedSequence, BitGenerator)
):
Expand Down Expand Up @@ -137,10 +137,12 @@ def __init__(

self._trunk_orientations_class = trunk_orientations_class

def process_context(self, context):
"""Process the context if it is a dictionary."""
def _process_context(self, context):
"""Apply some required processing to the context dictionary."""
if context is None:
context = {}

# we ofen need to use the y_direction as a rotation, so we save to it once here
if "y_direction" in context:
context["y_rotation"] = rotation.rotation_matrix_from_vectors(
Y_DIRECTION, context["y_direction"]
Expand Down Expand Up @@ -385,6 +387,7 @@ def _3d_angles_grow_trunks(self):
)
for neurite_type in self.input_parameters["grow_types"]:
orientations = trunk_orientations_manager.compute_tree_type_orientations(neurite_type)

for p in self.soma_grower.add_points_from_orientations(orientations):
self.active_neurites.append(
TreeGrower(
Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/orientations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, soma, parameters, distributions, context, rng):
self._soma = soma
self._parameters = parameters
self._distributions = distributions
self._context = context if context is not None else {}
self._context = context
self._rng = rng

self._orientations = {}
Expand Down
2 changes: 1 addition & 1 deletion neurots/generate/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
self.stop_criteria = stop_criteria
self.process = process
self.latest_directions = deque(maxlen=MEMORY)
self.context = context if context is not None else {}
self.context = context
self._rng = random_generator
self.step_size_distribution = step_size_distribution
self.pathlength = 0 if parent is None else pathlength
Expand Down
1 change: 1 addition & 0 deletions neurots/generate/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def __init__(
def _initialize_algorithm(self):
"""Initialization steps for TreeGrower."""
grow_meth = growth_algorithms[self.params["growth_method"]]

growth_algo = grow_meth(
input_data=self.distr,
params=self.params,
Expand Down

0 comments on commit 0d31987

Please sign in to comment.