Skip to content

Commit

Permalink
Merge pull request #35 from oberbichler/feature/variables
Browse files Browse the repository at this point in the history
Add `hj::variables`
  • Loading branch information
oberbichler authored Mar 26, 2021
2 parents 63f23cf + 4fa3f89 commit ed2b02e
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 4 deletions.
26 changes: 25 additions & 1 deletion include/hyperjet.h
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class DDScalar {
}
}

static std::vector<Type> variables(std::vector<Scalar> values)
static std::vector<Type> variables(const std::vector<Scalar>& values)
{
const index s = length(values);

Expand All @@ -412,6 +412,30 @@ class DDScalar {
return vars;
}

template <index T>
static std::conditional_t<TSize == Dynamic, std::vector<Type>, std::array<Type, T>> variables(const std::array<Scalar, T>& values)
{
if constexpr (!is_dynamic()) {
static_assert(T == TSize);
}

const index s = length(values);

if constexpr (is_dynamic()) {
std::vector<Type> vars(s);
for (index i = 0; i < s; i++) {
vars[i] = variable(i, values[i], s);
}
return vars;
} else {
std::array<Type, TSize> vars;
for (index i = 0; i < s; i++) {
vars[i] = variable(i, values[i], s);
}
return vars;
}
}

Scalar& f()
{
return m_data[0];
Expand Down
136 changes: 133 additions & 3 deletions src/python_module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ void register_ddscalar(pybind11::module& m, const std::string& name)
.def_static("empty", py::overload_cast<>(&Type::empty))
.def_static("empty", py::overload_cast<hj::index>(&Type::empty), "size"_a)
.def_static("variable", py::overload_cast<hj::index, double, hj::index>(&Type::variable), "i"_a, "f"_a, "size"_a)
.def_static("variables", &Type::variables, "values"_a)
.def_static("zero", py::overload_cast<>(&Type::zero))
.def_static("zero", py::overload_cast<hj::index>(&Type::zero), "size"_a)
// methods
Expand Down Expand Up @@ -159,19 +158,25 @@ void register_ddscalar(pybind11::module& m, const std::string& name)
// methods
.def("resize", &Type::resize, "size"_a)
.def("pad_right", &Type::pad_right, "new_size"_a)
.def("pad_left", &Type::pad_left, "new_size"_a);
.def("pad_left", &Type::pad_left, "new_size"_a)
// static methods
.def_static("variables", [](const std::vector<TScalar>& values) { return Type::variables(values); }, "values"_a);
} else {
py_class
// constructor
.def(py::init(py::overload_cast<TScalar>(&Type::constant)), "f"_a=0)
// static methods
.def_static("variable", py::overload_cast<hj::index, double>(&Type::variable), "i"_a, "f"_a);
.def_static("variable", py::overload_cast<hj::index, double>(&Type::variable), "i"_a, "f"_a)
.def_static("variables", [](const std::array<TScalar, TSize>& values) { return Type::template variables<TSize>(values); }, "values"_a);
}
}

PYBIND11_MODULE(hyperjet, m)
{
using namespace pybind11::literals;

namespace py = pybind11;
namespace hj = hyperjet;

m.doc() = "HyperJet by Thomas Oberbichler";
m.attr("__author__") = "Thomas Oberbichler";
Expand Down Expand Up @@ -225,5 +230,130 @@ PYBIND11_MODULE(hyperjet, m)
m.attr("f") = py::eval("np.vectorize(lambda v: v.f if hasattr(v, 'f') else v)", global);
m.attr("d") = py::eval("np.vectorize(lambda v: v.g if hasattr(v, 'g') else np.zeros((0)), signature='()->(n)')", global);
m.attr("dd") = py::eval("np.vectorize(lambda v: v.hm() if hasattr(v, 'hm') else np.zeros((0, 0)), signature='()->(n,m)')", global);

m.def("variables", [](const std::vector<double>& values, const hj::index order) {
if (order < 0 || 2 < order) {
throw std::runtime_error("Invalid order");
}

py::list results;

const auto extend = results.attr("extend");

switch (order) {
case 0:
extend(values);
break;
case 1:
switch (hj::length(values)) {
case 0:
break;
case 1:
extend(hj::DDScalar<1, double, 1>::variables(values));
break;
case 2:
extend(hj::DDScalar<1, double, 2>::variables(values));
break;
case 3:
extend(hj::DDScalar<1, double, 3>::variables(values));
break;
case 4:
extend(hj::DDScalar<1, double, 4>::variables(values));
break;
case 5:
extend(hj::DDScalar<1, double, 5>::variables(values));
break;
case 6:
extend(hj::DDScalar<1, double, 6>::variables(values));
break;
case 7:
extend(hj::DDScalar<1, double, 7>::variables(values));
break;
case 8:
extend(hj::DDScalar<1, double, 8>::variables(values));
break;
case 9:
extend(hj::DDScalar<1, double, 9>::variables(values));
break;
case 10:
extend(hj::DDScalar<1, double, 10>::variables(values));
break;
case 11:
extend(hj::DDScalar<1, double, 11>::variables(values));
break;
case 12:
extend(hj::DDScalar<1, double, 12>::variables(values));
break;
case 13:
extend(hj::DDScalar<1, double, 13>::variables(values));
break;
case 14:
extend(hj::DDScalar<1, double, 14>::variables(values));
break;
case 15:
extend(hj::DDScalar<1, double, 15>::variables(values));
break;
default:
extend(hj::DDScalar<1, double, -1>::variables(values));
break;
}
break;
case 2:
switch (hj::length(values)) {
case 0:
break;
case 1:
extend(hj::DDScalar<2, double, 1>::variables(values));
break;
case 2:
extend(hj::DDScalar<2, double, 2>::variables(values));
break;
case 3:
extend(hj::DDScalar<2, double, 3>::variables(values));
break;
case 4:
extend(hj::DDScalar<2, double, 4>::variables(values));
break;
case 5:
extend(hj::DDScalar<2, double, 5>::variables(values));
break;
case 6:
extend(hj::DDScalar<2, double, 6>::variables(values));
break;
case 7:
extend(hj::DDScalar<2, double, 7>::variables(values));
break;
case 8:
extend(hj::DDScalar<2, double, 8>::variables(values));
break;
case 9:
extend(hj::DDScalar<2, double, 9>::variables(values));
break;
case 10:
extend(hj::DDScalar<2, double, 10>::variables(values));
break;
case 11:
extend(hj::DDScalar<2, double, 11>::variables(values));
break;
case 12:
extend(hj::DDScalar<2, double, 12>::variables(values));
break;
case 13:
extend(hj::DDScalar<2, double, 13>::variables(values));
break;
case 14:
extend(hj::DDScalar<2, double, 14>::variables(values));
break;
case 15:
extend(hj::DDScalar<2, double, 15>::variables(values));
break;
default:
extend(hj::DDScalar<2, double, -1>::variables(values));
break;
}
break;
}
return results;
}, "values"_a, "order"_a=2);
}
}
30 changes: 30 additions & 0 deletions tests/test_DDScalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,3 +794,33 @@ def test_dd_of_scalar():
assert_equal(hj.dd(np.array([1, 2, 3, 4])), np.empty((4, 0, 0)))

assert_equal(hj.dd(np.array([[1, 2], [3, 4]])), np.empty((2, 2, 0, 0)))


def test_generate_variables():
small = [1, 2, 3]
large = [i + 1 for i in range(20)]

variables = hj.variables(small, order=0)

assert_equal(variables, small)
assert_equal(type(variables[0]), float)

variables = hj.variables(small, order=1)

assert_equal(hj.f(variables), small)
assert_equal(type(variables[0]), hj.D3Scalar)

variables = hj.variables(large, order=1)

assert_equal(hj.f(variables), large)
assert_equal(type(variables[0]), hj.DScalar)

variables = hj.variables(small, order=2)

assert_equal(hj.f(variables), small)
assert_equal(type(variables[0]), hj.DD3Scalar)

variables = hj.variables(large, order=2)

assert_equal(hj.f(variables), large)
assert_equal(type(variables[0]), hj.DDScalar)

0 comments on commit ed2b02e

Please sign in to comment.