Skip to content

Commit

Permalink
Updates to dials.scale error modelling to handle stills data (dials#2654
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jbeilstenedmands committed May 1, 2024
1 parent a8945b6 commit 9b07803
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 16 deletions.
1 change: 1 addition & 0 deletions newsfragments/2654.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
``dials.scale``: Add filtering options to default basic error model to allow error modelling of stills data
7 changes: 6 additions & 1 deletion src/dials/algorithms/scaling/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,16 +641,21 @@ def targeted_scaling_algorithm(scaler):
scaler.make_ready_for_scaling()
scaler.perform_scaling()

expand_and_do_outlier_rejection(scaler, calc_cov=True)
do_error_analysis(scaler, reselect=True)

if scaler.params.scaling_options.full_matrix and (
scaler.params.scaling_refinery.engine == "SimpleLBFGS"
):
scaler.perform_scaling(
engine=scaler.params.scaling_refinery.full_matrix_engine,
max_iterations=scaler.params.scaling_refinery.full_matrix_max_iterations,
)
else:
scaler.perform_scaling()

expand_and_do_outlier_rejection(scaler, calc_cov=True)
# do_error_analysis(scaler, reselect=False)
do_error_analysis(scaler, reselect=False)

scaler.prepare_reflection_tables_for_output()
return scaler
11 changes: 7 additions & 4 deletions src/dials/algorithms/scaling/error_model/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
logger = logging.getLogger("dials")


def run_error_model_refinement(model, Ih_table):
def run_error_model_refinement(
model, Ih_table, min_partiality=0.4, use_stills_filtering=False
):
"""
Refine an error model for the input data, returning the model.
Expand All @@ -30,7 +32,9 @@ def run_error_model_refinement(model, Ih_table):
RuntimeError: can be raised in LBFGS minimiser.
"""
assert Ih_table.n_work_blocks == 1
model.configure_for_refinement(Ih_table.blocked_data_list[0])
model.configure_for_refinement(
Ih_table.blocked_data_list[0], min_partiality, use_stills_filtering
)
if not model.active_parameters:
logger.info("All error model parameters fixed, skipping refinement")
else:
Expand Down Expand Up @@ -170,7 +174,6 @@ def test_value_convergence(self):
r2 = self.avals[-2]
except IndexError:
return False

if r2 > 0:
return abs((r2 - r1) / r2) < self._avals_tolerance
else:
Expand Down Expand Up @@ -201,7 +204,7 @@ def _refine_component(self, model, target, parameterisation):
def run(self):
"""Refine the model."""
if self.parameters_to_refine == ["a", "b"]:
for n in range(20): # usually converges in around 5 cycles
for n in range(50): # usually converges in around 5 cycles
self._refine_a()
# now update in model
self.avals.append(self.model.components["a"].parameters[0])
Expand Down
78 changes: 74 additions & 4 deletions src/dials/algorithms/scaling/error_model/error_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@
"determine both parameters concurrently. If minimisation=None,"
"the model parameters are fixed to their initial or given values."
.expert_level = 3
stills {
min_Isigma = 2.0
.type=float
.help = "Minimum uncorrected I/sigma for individual reflections used in error model optimisation"
min_multiplicity = 4
.type = int
.help = "Only reflections with at least this multiplicity (after Isigma filtering) are"
"used in error model optimisation."
}
min_Ih = 25.0
.type = float
.help = "Reflections with expected intensity above this value are to be used in error model optimisation."
Expand Down Expand Up @@ -248,7 +257,10 @@ def _create_summation_matrix(self):
n = self.Ih_table.size
self.binning_info["n_reflections"] = n
summation_matrix = sparse.matrix(n, self.n_bins)
# calculate expected intensity value in pixels on scale of each image
Ih = self.Ih_table.Ih_values * self.Ih_table.inverse_scale_factors
if "partiality" in self.Ih_table.Ih_table:
Ih *= self.Ih_table.Ih_table["partiality"].to_numpy()
size_order = flex.sort_permutation(flumpy.from_numpy(Ih), reverse=True)
Imax = Ih.max()
min_Ih = Ih.min()
Expand Down Expand Up @@ -383,7 +395,6 @@ def __init__(self, a=None, b=None, basic_params=None):
see if a user specified fixed value is set. If no fixed values are given
then the model starts with the default parameters a=1.0 b=0.02
"""

self.free_components = []
self.sortedy = None
self.sortedx = None
Expand All @@ -408,14 +419,19 @@ def __init__(self, a=None, b=None, basic_params=None):
if not basic_params.b:
self._active_parameters.append("b")

def configure_for_refinement(self, Ih_table, min_partiality=0.4):
def configure_for_refinement(
self, Ih_table, min_partiality=0.4, use_stills_filtering=False
):
"""
Add data to allow error model refinement.
Raises: ValueError if insufficient reflections left after filtering.
"""
self.filtered_Ih_table = self.filter_unsuitable_reflections(
Ih_table, self.params, min_partiality
Ih_table,
self.params,
min_partiality,
use_stills_filtering,
)
# always want binning info so that can calc for output.
self.binner = ErrorModelBinner(
Expand Down Expand Up @@ -455,8 +471,19 @@ def n_refl(self):
return self.filtered_Ih_table.size

@classmethod
def filter_unsuitable_reflections(cls, Ih_table, error_params, min_partiality):
def filter_unsuitable_reflections(
cls, Ih_table, error_params, min_partiality, use_stills_filtering
):
"""Filter suitable reflections for minimisation."""
if use_stills_filtering:
return filter_unsuitable_reflections_stills(
Ih_table,
error_params.stills.min_multiplicity,
error_params.stills.min_Isigma,
min_partiality=min_partiality,
min_reflections_required=cls.min_reflections_required,
min_Ih=error_params.min_Ih,
)
return filter_unsuitable_reflections(
Ih_table,
min_Ih=error_params.min_Ih,
Expand Down Expand Up @@ -570,6 +597,49 @@ def binned_variances_summary(self):
)


def filter_unsuitable_reflections_stills(
    Ih_table,
    min_multiplicity,
    min_Isigma,
    min_partiality,
    min_reflections_required,
    min_Ih,
):
    """Select the subset of reflections suitable for error model minimisation.

    Applies, in order: a partiality cut (only when a partiality column is
    present), an uncorrected I/sigma cut, an expected-intensity cut, a cut on
    the normalised deviations to discard the long tails, and finally a
    multiplicity cut.

    Raises:
        ValueError: if fewer than min_reflections_required reflections remain.
    """
    # Partiality filter, applied only if the table carries partiality data.
    if "partiality" in Ih_table.Ih_table:
        partiality = Ih_table.Ih_table["partiality"].to_numpy()
        Ih_table = Ih_table.select(partiality > min_partiality)

    # Uncorrected I/sigma filter.
    isigma = Ih_table.intensities / (Ih_table.variances**0.5)
    Ih_table = Ih_table.select(isigma >= min_Isigma)

    # Expected intensity on the scale of each image, partiality-weighted
    # when that information is available.
    expected = Ih_table.Ih_values * Ih_table.inverse_scale_factors
    if "partiality" in Ih_table.Ih_table:
        expected *= Ih_table.Ih_table["partiality"].to_numpy()
    Ih_table = Ih_table.select(expected > min_Ih)

    # Optimise on the central bulk distribution of the data - avoid the few
    # reflections in the long tails.
    multiplicities = Ih_table.calc_nh()
    sigmaprime = calc_sigmaprime([1.0, 0.0], Ih_table)
    deviations = calc_deltahl(Ih_table, multiplicities, sigmaprime)
    Ih_table = Ih_table.select(np.abs(deviations) < 6.0)

    # Multiplicity filter, recomputed after the selections above.
    Ih_table = Ih_table.select(Ih_table.calc_nh() >= min_multiplicity)

    n_remaining = Ih_table.size
    if n_remaining < min_reflections_required:
        raise ValueError(
            "Insufficient reflections (%s < %s) to perform error modelling."
            % (n_remaining, min_reflections_required)
        )
    return Ih_table


def filter_unsuitable_reflections(
Ih_table, min_Ih, min_partiality, min_reflections_required
):
Expand Down
20 changes: 17 additions & 3 deletions src/dials/algorithms/scaling/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,9 @@ def __init__(self, params, experiment, reflection_table, for_multi=False):
self.free_set_selection = flex.bool(self.n_suitable_refl, False)
self._free_Ih_table = None # An array of len n_suitable_refl
self._configure_model_and_datastructures(for_multi=for_multi)
self.is_still = True
if self._experiment.scan and self._experiment.scan.get_oscillation()[1] != 0.0:
self.is_still = False
if self.params.weighting.error_model.error_model:
# reload current error model parameters, or create new null
self.experiment.scaling_model.load_error_model(
Expand Down Expand Up @@ -380,7 +383,10 @@ def perform_error_optimisation(self, update_Ih=True):
Ih_table, _ = self._create_global_Ih_table(anomalous=True, remove_outliers=True)
try:
model = run_error_model_refinement(
self._experiment.scaling_model.error_model, Ih_table
self._experiment.scaling_model.error_model,
Ih_table,
self.params.reflection_selection.min_partiality,
use_stills_filtering=self.is_still,
)
except (ValueError, RuntimeError) as e:
logger.info(e)
Expand Down Expand Up @@ -1500,7 +1506,12 @@ def perform_error_optimisation(self, update_Ih=True):
continue
tables = [s.get_valid_reflections().select(~s.outliers) for s in scalers]
space_group = scalers[0].experiment.crystal.get_space_group()
Ih_table = IhTable(tables, space_group, anomalous=True)
Ih_table = IhTable(
tables,
space_group,
anomalous=True,
additional_cols=["partiality"],
)
if len(minimisation_groups) == 1:
logger.info("Determining a combined error model for all datasets")
else:
Expand All @@ -1509,7 +1520,10 @@ def perform_error_optimisation(self, update_Ih=True):
)
try:
model = run_error_model_refinement(
scalers[0]._experiment.scaling_model.error_model, Ih_table
scalers[0]._experiment.scaling_model.error_model,
Ih_table,
min_partiality=self.params.reflection_selection.min_partiality,
use_stills_filtering=scalers[0].is_still,
)
except (ValueError, RuntimeError) as e:
logger.info(e)
Expand Down
17 changes: 15 additions & 2 deletions src/dials/command_line/refine_error_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
phil_scope = phil.parse(
"""
include scope dials.algorithms.scaling.error_model.error_model.phil_scope
min_partiality = 0.4
.type = float
.help = "Use reflections with at least this partiality in error model optimisation."
intensity_choice = *profile sum combine
.type = choice
.help = "Use profile or summation intensities"
Expand Down Expand Up @@ -101,13 +104,23 @@ def refine_error_model(params, experiments, reflection_tables):
reflection_tables[i] = table
space_group = experiments[0].crystal.get_space_group()
Ih_table = IhTable(
reflection_tables, space_group, additional_cols=["partiality"], anomalous=True
reflection_tables,
space_group,
additional_cols=["partiality"],
anomalous=True,
)

use_stills_filtering = True
for expt in experiments:
if expt.scan and expt.scan.get_oscillation()[1] != 0.0:
use_stills_filtering = False
break
# now do the error model refinement
model = BasicErrorModel(basic_params=params.basic)
try:
model = run_error_model_refinement(model, Ih_table)
model = run_error_model_refinement(
model, Ih_table, params.min_partiality, use_stills_filtering
)
except (ValueError, RuntimeError) as e:
logger.info(e)
else:
Expand Down
9 changes: 7 additions & 2 deletions tests/algorithms/scaling/test_error_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,19 +183,24 @@ def test_error_model_on_simulated_data(
)


def test_errormodel(large_reflection_table, test_sg):
@pytest.mark.parametrize("use_stills_filtering", [True, False])
def test_errormodel(large_reflection_table, test_sg, use_stills_filtering):
"""Test the initialisation and methods of the error model."""

Ih_table = IhTable([large_reflection_table], test_sg, nblocks=1)
block = Ih_table.blocked_data_list[0]
params = generated_param()
params.weighting.error_model.basic.stills.min_multiplicity = 2
params.weighting.error_model.basic.stills.min_Isigma = 0.0
params.weighting.error_model.basic.n_bins = 2
params.weighting.error_model.basic.min_Ih = 1.0
em = BasicErrorModel
em.min_reflections_required = 1
error_model = em(basic_params=params.weighting.error_model.basic)
error_model.min_reflections_required = 1
error_model.configure_for_refinement(block)
error_model.configure_for_refinement(
block, use_stills_filtering=use_stills_filtering
)
assert error_model.binner.summation_matrix[0, 1] == 1
assert error_model.binner.summation_matrix[1, 1] == 1
assert error_model.binner.summation_matrix[2, 0] == 1
Expand Down
8 changes: 8 additions & 0 deletions tests/command_line/test_ssx_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,11 @@ def test_ssx_reduction(dials_data, tmp_path):
)
assert not result.returncode and not result.stderr
assert (tmp_path / "compute_delta_cchalf.html").is_file()

# will not be able to refine the error model due to lack of data, but should exit cleanly rather than fail.
result = subprocess.run(
[shutil.which("dials.refine_error_model"), scale_expts, scale_refls],
cwd=tmp_path,
capture_output=True,
)
assert not result.returncode and not result.stderr

0 comments on commit 9b07803

Please sign in to comment.