diff --git a/modules/saxs/include/WeightedProfileFitter.h b/modules/saxs/include/WeightedProfileFitter.h index 60367d7136..e98cb359e0 100644 --- a/modules/saxs/include/WeightedProfileFitter.h +++ b/modules/saxs/include/WeightedProfileFitter.h @@ -46,6 +46,18 @@ class WeightedProfileFitter : public ProfileFitter { Wb_ = W_.asDiagonal() * Wb_; } +// When used from Python, compute_score will return the weights as a second +// return value, rather than modifying a passed-in vector +#ifdef SWIG +%typemap(in, numinputs=0) Vector& weights (Vector temp) { + $1 = &temp; +} +%typemap(argout) Vector& weights { + PyObject *obj = ConvertSequence, Convert>::create_python_object(ValueOrObject>::get(*$1), $descriptor(double*), SWIG_POINTER_OWN); + $result = SWIG_AppendOutput($result, obj); +} +#endif + //! compute a weighted score that minimizes chi /** it is assumed that the q values of the profiles are the same as @@ -56,6 +68,10 @@ class WeightedProfileFitter : public ProfileFitter { double compute_score(const ProfilesTemp& profiles, Vector& weights, bool use_offset = false, bool NNLS = true) const; +#ifdef SWIG +%clear Vector& weights; +#endif + //! fit profiles by optimization of c1/c2 and weights /** it is assumed that the q values of the profiles are the same as diff --git a/modules/saxs/test/test_weighted_profile.py b/modules/saxs/test/test_weighted_profile.py index 8388ee1fbb..901ceac6d7 100644 --- a/modules/saxs/test/test_weighted_profile.py +++ b/modules/saxs/test/test_weighted_profile.py @@ -58,6 +58,19 @@ def test_weighted_profile(self): saxs_score = IMP.saxs.WeightedProfileFitterChi(weighted_profile) profile_list = [resampled_profile1, resampled_profile2] + # Get score for single profile + score, weights = saxs_score.compute_score(profile_list[:1], + False, False) + self.assertAlmostEqual(score, 81.765, delta=0.001) + self.assertEqual(len(weights), 1) + self.assertAlmostEqual(weights[0], 1.0, delta=0.001) + + # Get score for both profiles + score, weights = saxs_score.compute_score(profile_list, False, False) + self.assertAlmostEqual(score, 3.031, delta=0.001) + self.assertEqual(len(weights), 2) + self.assertAlmostEqual(weights[0], 0.297, delta=0.001) + self.assertAlmostEqual(weights[1], 0.703, delta=0.001) wfp = saxs_score.fit_profile(profile_list, 0.95, 1.05, -2.0, 4.0) chi = wfp.get_chi_square()