diff --git a/src/dials/algorithms/spot_finding/finder.py b/src/dials/algorithms/spot_finding/finder.py index da777b4dec..32ead2c47a 100644 --- a/src/dials/algorithms/spot_finding/finder.py +++ b/src/dials/algorithms/spot_finding/finder.py @@ -9,7 +9,11 @@ import pickle from typing import Iterable, Tuple +import numpy as np +from scipy.spatial import distance + import libtbx +from dxtbx import flumpy from dxtbx.format.image import ImageBool from dxtbx.imageset import ImageSequence, ImageSet from dxtbx.model import ExperimentList, Scan @@ -790,10 +794,62 @@ def find_spots(self, experiments: ExperimentList) -> flex.reflection_table: # If any of the experiments are ToF experiments, add wavelength data if experiments.contains_tof_experiments(): reflections.add_beam_data(experiments) + reflections = self.filter_nearby_spots(reflections, experiments) # Return the reflections return reflections + def filter_nearby_spots(self, reflections, experiments, threshold=0.015, k=20): + + if "rlp" not in reflections: + reflections.centroid_px_to_mm(experiments) + reflections.map_centroids_to_reciprocal_space(experiments) + + points = reflections["rlp"] + points = flumpy.to_numpy(points) + + # Compute pairwise distances + pairwise_distances = distance.cdist(points, points) + + # Exclude self-distances + np.fill_diagonal(pairwise_distances, np.inf) + + new_points = np.copy(points) + used_indices = set() + + for i, point in enumerate(points): + if i not in used_indices: + # Sort distances for the current point + sorted_indices = np.argsort(pairwise_distances[i]) + k_nearest_indices = sorted_indices[ + 1 : k + 1 + ] # Exclude self, consider k-nearest neighbors + + # Filter distances based on threshold + close_neighbors_indices = [ + idx + for idx in k_nearest_indices + if pairwise_distances[i][idx] < threshold + ] + + if close_neighbors_indices: + close_neighbors_coords = points[close_neighbors_indices] + avg_coord = np.mean(close_neighbors_coords, axis=0) + new_points[i] = avg_coord + + # Mark points as used + for idx in close_neighbors_indices: + used_indices.add(idx) + new_points[idx] = [0, 0, 0] # Set coordinates to (0,0,0) + + new_points = flumpy.vec_from_numpy(new_points) + sel = new_points.norms() > 1e-7 + original_len = len(reflections) + reflections = reflections.select(sel) + new_len = len(reflections) + logger.info(f"Removed {original_len - new_len} spots due to close proximity.") + return reflections + def _get_spot_finding_algorithm(self, imageset): # The input mask