Skip to content

Commit

Permalink
[TASK] return appropriate python types
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreSchnizer committed Sep 20, 2023
1 parent 22f5f2a commit bc2701d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 62 deletions.
78 changes: 28 additions & 50 deletions python/src/elements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,8 @@ template<typename double_type, typename Class>
void field_kick_add_set_methods(py::class_<Class> t_mapper)
{
t_mapper
.def("set_dx", [](Class &kick, const double_type& dx)
{kick.getTransform()->setDx(dx);})
.def("set_dy", [](Class &kick, const double_type& dy)
{kick.getTransform()->setDy(dy);})
.def("set_dx", [](Class &kick, const double_type& dx) { kick.getTransform()->setDx(dx) ; })
.def("set_dy", [](Class &kick, const double_type& dy) { kick.getTransform()->setDy(dy) ; })
/*
* Todo: seems to crash the interpreter
*/
Expand All @@ -158,9 +156,9 @@ template<typename Class>
void field_kick_add_get_methods(py::class_<Class> t_mapper)
{
t_mapper
.def("get_dx", [](Class &kick) { return gpy::to_pyobject(kick.getTransform()->getDx()); })
.def("get_dy", [](Class &kick) { return gpy::to_pyobject(kick.getTransform()->getDy()); })
.def("get_roll", [](Class &kick) { return gpy::to_pyobject(kick.getTransform()->getRoll()); })
.def("get_dx", [](Class &kick) { return gpy::to_pyobject( kick.getTransform()->getDx()); })
.def("get_dy", [](Class &kick) { return gpy::to_pyobject( kick.getTransform()->getDy()); })
.def("get_roll", [](Class &kick) { return gpy::to_pyobject( kick.getTransform()->getRoll()); })
;
}

Expand All @@ -173,42 +171,25 @@ void field_kick_add_methods(py::class_<Class> t_mapper)
field_kick_add_get_methods(t_mapper);

t_mapper
.def("is_thick",
&Class::isThick)
.def("as_thick",
&Class::asThick)
.def("get_number_of_integration_steps",
&Class::getNumberOfIntegrationSteps)
.def("set_number_of_integration_steps",
&Class::setNumberOfIntegrationSteps)
.def("get_integration_method",
&Class::getIntegrationMethod)
.def("get_curvature",
&Class::getCurvature)
.def("set_curvature",
&Class::setCurvature)
.def("assuming_curved_trajectory",
&Class::assumingCurvedTrajectory)
.def("get_bending_angle",
&Class::getBendingAngle)
.def("set_bending_angle",
&Class::setBendingAngle)
.def("set_entrance_angle",
&Class::setEntranceAngle)
.def("get_entrance_angle",
&Class::getEntranceAngle)
.def("set_exit_angle",
&Class::setExitAngle)
.def("get_exit_angle",
&Class::getExitAngle)
.def("get_radiation_delegate",
&Class::getRadiationDelegate)
.def("set_radiation_delegate",
&Class::setRadiationDelegate)
.def("get_field_interpolator",
&Class::getFieldInterpolator)
.def("set_field_interpolator",
&Class::setFieldInterpolator)
.def("is_thick", &Class::isThick )
.def("as_thick", &Class::asThick )
.def("get_number_of_integration_steps", &Class::getNumberOfIntegrationSteps )
.def("set_number_of_integration_steps", &Class::setNumberOfIntegrationSteps )
.def("get_integration_method", &Class::getIntegrationMethod )
.def("assuming_curved_trajectory", &Class::assumingCurvedTrajectory )
.def("set_curvature", &Class::setCurvature )
.def("set_bending_angle", &Class::setBendingAngle )
.def("set_entrance_angle", &Class::setEntranceAngle )
.def("set_exit_angle", &Class::setExitAngle )
.def("get_radiation_delegate", &Class::getRadiationDelegate )
.def("set_radiation_delegate", &Class::setRadiationDelegate )
.def("get_field_interpolator", &Class::getFieldInterpolator )
.def("set_field_interpolator", &Class::setFieldInterpolator )

.def("get_curvature", [](const Class& kick) { return gpy::to_pyobject( kick.getCurvature() ); })
.def("get_bending_angle", [](const Class& kick) { return gpy::to_pyobject( kick.getBendingAngle() ); })
.def("get_entrance_angle", [](const Class& kick) { return gpy::to_pyobject( kick.getEntranceAngle() ); })
.def("get_exit_angle", [](const Class& kick) { return gpy::to_pyobject( kick.getExitAngle() ); })
;
}

Expand Down Expand Up @@ -237,8 +218,9 @@ void classical_magnet_add_methods(py::class_<Class> t_mapper)
// .def("set_multipoles",&Class::setMultipoles)
.def("get_main_multipole_number",
&Class::getMainMultipoleNumber)
.def("get_main_multipole_strength",
&Class::getMainMultipoleStrength)
.def("get_main_multipole_strength", [](Class& inst){
return gpy::to_pyobject(inst.getMainMultipoleStrength());
})
.def("get_main_multipole_strength_component",
&Class::getMainMultipoleStrengthComponent)

Expand Down Expand Up @@ -325,11 +307,7 @@ struct TemplatedClasses
py::class_<QuadK, std::shared_ptr<QuadK>>
(this->m_module, quad_name.c_str(), cm)
.def(py::init<const Config &>());
#if 0
// J.B. 14/07/23: test.
.def("Q_init", &QuadK::Q_init)
#endif
;

std::string sext_name = "Sextupole" + this->m_suffix;
typedef tse::SextupoleTypeWithKnob<C> SextK;
py::class_<SextK, std::shared_ptr<SextK>>
Expand Down
24 changes: 17 additions & 7 deletions python/src/interpolation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ void add_methods_multipoles(py::class_<Class> t_mapper)
{
t_mapper
.def("get_multipole", [](const Class& inst, int n){ return gpy::to_pyobject(inst.getMultipole(n)); })
.def("set_multipole", &Class::setMultipole)
.def("set_multipole", [](Class& inst, int n, typename Types::complex_type coeff){
inst.setMultipole(n, coeff);
})
.def("apply_roll_angle", &Class::applyRollAngle)
.def("__str__", &Class::pstr)
.def("__repr__", &Class::repr)
Expand Down Expand Up @@ -270,19 +272,27 @@ void py_thor_scsi_init_field_interpolation(py::module &m) {
> multipoles_tpsa_base(m, "_MultipolesBaseTpsa", field2dintpvar);
add_methods_multipoles<tsc::TpsaVariantType, tsc::TwoDimensionalMultipolesKnobbed<tsc::TpsaVariantType>>(multipoles_tpsa_base);

py::class_<tsc::TwoDimensionalMultipolesTpsa, std::shared_ptr<tsc::TwoDimensionalMultipolesTpsa>>
multipoles_tpsa(m, "TwoDimensionalMultipolesTpsa", multipoles_tpsa_base //, py::buffer_protocol()
);
multipoles_tpsa
.def("set_multipole", [](tsc::TwoDimensionalMultipolesTpsa& inst, const unsigned int n, gtpsa::ctpsa& obj){
multipoles_tpsa_base
.def("set_multipole", [](tsc::TwoDimensionalMultipolesKnobbed<tsc::TpsaVariantType>& inst, const unsigned int n, gtpsa::ctpsa& obj){
inst.setMultipole(n, obj);
})
.def("set_multipole", [](tsc::TwoDimensionalMultipolesTpsa& inst, const unsigned int n, std::complex<double>& obj){
.def("set_multipole", [](tsc::TwoDimensionalMultipolesKnobbed<tsc::TpsaVariantType>& inst, const unsigned int n, std::complex<double> obj){
inst.setMultipole(n, obj);
})
;
py::class_<tsc::TwoDimensionalMultipolesTpsa, std::shared_ptr<tsc::TwoDimensionalMultipolesTpsa>>
multipoles_tpsa(m, "TwoDimensionalMultipolesTpsa", multipoles_tpsa_base //, py::buffer_protocol()
);
add_methods_multipoles<tsc::TpsaVariantType, tsc::TwoDimensionalMultipolesTpsa>(multipoles_tpsa);
multipoles_tpsa
// why do I need that here again ...
// does it not cast down here from above?
.def("set_multipole", [](tsc::TwoDimensionalMultipolesTpsa& inst, const unsigned int n, gtpsa::ctpsa& obj){
inst.setMultipole(n, obj);
})
.def("set_multipole", [](tsc::TwoDimensionalMultipolesTpsa& inst, const unsigned int n, std::complex<double> obj){
inst.setMultipole(n, obj);
})
.def(py::init<const std::complex<double>, const unsigned int>(), "initalise multipoles",
py::arg("default_value"), py::arg("h_max") = tsc::max_multipole)
// shall one return a list of base objects ?
Expand Down
3 changes: 2 additions & 1 deletion python/tests/knobbable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_knob_quadrupole():
assert nquad.name == quad.name
assert nquad.get_length() == pytest.approx(quad.get_length(), rel=1e-12)
# test of to_object should be avoided
K_check = nquad.get_main_multipole_strength().to_object()
K_check = nquad.get_main_multipole_strength()
assert K_check == pytest.approx(K, rel=1e-12)


Expand Down Expand Up @@ -193,6 +193,7 @@ def test_knobbable_dx_set_from_float():
def test_knobbalbe_muls_set_from_float():
from thor_scsi.lib import TwoDimensionalMultipolesTpsa as Mul2DTpsa
m = Mul2DTpsa(0)
print(type(m))
m.set_multipole(3, 1e-3+1e-2j)


Expand Down
9 changes: 5 additions & 4 deletions python/thor_scsi/utils/knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,22 @@ def make_magnet_strength_knobbable(
):
if multipole_number is None:
multipole_number = magnet.get_main_multipole_number()
k_orig = magnet.get_main_multipole_strength().to_object()
k_orig = magnet.get_main_multipole_strength()
k = gtpsa.ctpsa(desc, po, mapping=named_index)
k.name = magnet.name + "_K"
k.set_knob(k_orig, "K")
# k.set_variable(k_orig, "K")
magnet.get_multipoles().set_multipole(multipole_number, gtpsa.CTpsaOrComplex(k))
print(type(magnet.get_multipoles()))
magnet.get_multipoles().set_multipole(multipole_number, k)


def make_magnet_strength_unknobbable(magnet: tslib.Mpole, multipole_number: int = None):
if multipole_number is None:
multipole_number = magnet.get_main_multipole_number()
mul = complex(
magnet.get_multipoles().get_multipole(multipole_number).to_object().get()
magnet.get_multipoles().get_multipole(multipole_number)
)
magnet.get_multipoles().set_multipole(multipole_number, gtpsa.CTpsaOrComplex(mul))
magnet.get_multipoles().set_multipole(multipole_number, mul)


def make_magnet_knobbable(
Expand Down

0 comments on commit bc2701d

Please sign in to comment.