Skip to content

Commit

Permalink
Add mode attribute to hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
oberbichler committed Mar 8, 2021
1 parent f601141 commit 384fd33
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 6 deletions.
27 changes: 24 additions & 3 deletions include/hyperjet.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,19 +480,40 @@ class DDScalar {
return Eigen::Map<Eigen::Matrix<TScalar, 1, TSize>>(ptr() + 1, size());
}

Eigen::Matrix<TScalar, TSize, TSize> hm() const
Eigen::Matrix<TScalar, TSize, TSize> hm(const std::string mode) const
{
Eigen::Matrix<TScalar, TSize, TSize> result(size(), size());

hm(mode, result);

return result;
}

void hm(const std::string mode, Eigen::Ref<Eigen::Matrix<TScalar, TSize, TSize>> out) const
{
index it = 0;

for (index i = 0; i < size(); i++) {
for (index j = i; j < size(); j++) {
result(i, j) = h(it++);
out(i, j) = h(it++);
}
}

return result;
if (mode == "zeros") {
for (index i = 0; i < size(); i++) {
for (index j = 0; j < i; j++) {
out(i, j) = 0;
}
}
} else if (mode == "full") {
for (index i = 0; i < size(); i++) {
for (index j = 0; j < i; j++) {
out(i, j) = out(j, i);
}
}
} else {
throw std::runtime_error("Invalid value for 'mode'");
}
}

void set_hm(Eigen::Ref<const Eigen::Matrix<TScalar, TSize, TSize>> value)
Expand Down
2 changes: 1 addition & 1 deletion src/python_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void register_ddscalar(pybind11::module& m, const std::string& name)
.def("abs", &Type::abs)
.def("h", py::overload_cast<hj::index, hj::index>(&Type::h), "row"_a, "col"_a)
.def("set_h", py::overload_cast<hj::index, hj::index, TScalar>(&Type::set_h), "row"_a, "col"_a, "value"_a)
.def("hm", &Type::hm)
.def("hm", py::overload_cast<std::string>(&Type::hm, py::const_), "mode"_a="full")
.def("set_hm", &Type::set_hm, "value"_a)
// methods: arithmetic operations
.def("reciprocal", &Type::reciprocal)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_DDScalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,17 @@ def test_ndarray(ctx):

assert_allclose(u.g, [2, 3])

assert_equal(np.triu(u.hm()), [[4, 5], [0, 6]])
assert_equal(u.hm(), [[4, 5], [5, 6]])
assert_equal(u.hm(mode='full'), [[4, 5], [5, 6]])
assert_equal(u.hm(mode='zeros'), [[4, 5], [0, 6]])

u.g[:] = [5, 4]
assert_allclose(u.g, [5, 4])

u.set_hm([[3, 2], [0, 1]])
assert_equal(np.triu(u.hm()), [[3, 2], [0, 1]])
assert_equal(u.hm(), [[3, 2], [2, 1]])
assert_equal(u.hm(mode='full'), [[3, 2], [2, 1]])
assert_equal(u.hm(mode='zeros'), [[3, 2], [0, 1]])

assert_allclose(u.data, [1, 5, 4, 3, 2, 1])

Expand Down

0 comments on commit 384fd33

Please sign in to comment.