diff --git a/include/hyperjet.h b/include/hyperjet.h index 309c7640..92f1c6f0 100644 --- a/include/hyperjet.h +++ b/include/hyperjet.h @@ -480,19 +480,40 @@ class DDScalar { return Eigen::Map>(ptr() + 1, size()); } - Eigen::Matrix hm() const + Eigen::Matrix hm(const std::string mode) const { Eigen::Matrix result(size(), size()); + hm(mode, result); + + return result; + } + + void hm(const std::string mode, Eigen::Ref> out) const + { index it = 0; for (index i = 0; i < size(); i++) { for (index j = i; j < size(); j++) { - result(i, j) = h(it++); + out(i, j) = h(it++); } } - return result; + if (mode == "zeros") { + for (index i = 0; i < size(); i++) { + for (index j = 0; j < i; j++) { + out(i, j) = 0; + } + } + } else if (mode == "full") { + for (index i = 0; i < size(); i++) { + for (index j = 0; j < i; j++) { + out(i, j) = out(j, i); + } + } + } else { + throw std::runtime_error("Invalid value for 'mode'"); + } } void set_hm(Eigen::Ref> value) diff --git a/src/python_module.cpp b/src/python_module.cpp index 7ad6c2a7..eec8d123 100644 --- a/src/python_module.cpp +++ b/src/python_module.cpp @@ -49,7 +49,7 @@ void register_ddscalar(pybind11::module& m, const std::string& name) .def("abs", &Type::abs) .def("h", py::overload_cast(&Type::h), "row"_a, "col"_a) .def("set_h", py::overload_cast(&Type::set_h), "row"_a, "col"_a, "value"_a) - .def("hm", &Type::hm) + .def("hm", py::overload_cast(&Type::hm, py::const_), "mode"_a="full") .def("set_hm", &Type::set_hm, "value"_a) // methods: arithmetic operations .def("reciprocal", &Type::reciprocal) diff --git a/tests/test_DDScalar.py b/tests/test_DDScalar.py index 888d1af4..88eb6d60 100644 --- a/tests/test_DDScalar.py +++ b/tests/test_DDScalar.py @@ -271,13 +271,17 @@ def test_ndarray(ctx): assert_allclose(u.g, [2, 3]) - assert_equal(np.triu(u.hm()), [[4, 5], [0, 6]]) + assert_equal(u.hm(), [[4, 5], [5, 6]]) + assert_equal(u.hm(mode='full'), [[4, 5], [5, 6]]) + assert_equal(u.hm(mode='zeros'), [[4, 5], [0, 6]]) u.g[:] = [5, 4] assert_allclose(u.g, [5, 4]) u.set_hm([[3, 2], [0, 1]]) - assert_equal(np.triu(u.hm()), [[3, 2], [0, 1]]) + assert_equal(u.hm(), [[3, 2], [2, 1]]) + assert_equal(u.hm(mode='full'), [[3, 2], [2, 1]]) + assert_equal(u.hm(mode='zeros'), [[3, 2], [0, 1]]) assert_allclose(u.data, [1, 5, 4, 3, 2, 1])