Skip to content

Commit

Permalink
noconvert pyarray arguments
Browse files Browse the repository at this point in the history
Signed-off-by: koubaa <koubaa@github.com>
  • Loading branch information
koubaa committed Dec 10, 2024
1 parent 977f07f commit 8e0289b
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions python/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,39 +563,42 @@ PYBIND11_MODULE(kp, m)
const std::vector<std::shared_ptr<kp::Memory>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup,
const std::vector<float>& spec_consts,
const std::vector<float>& push_consts) {
const py::list& spec_consts,
const py::list& push_consts) {
std::vector<uint32_t> spirvVec = pv::to_vector(spirv);
KP_LOG_DEBUG("Kompute Python Manager creating Algorithm.");
auto pushConstsVec = push_consts.cast<std::vector<float>>();
auto specConstsVec = spec_consts.cast<std::vector<float>>();
return self.algorithm(
tensors, spirvVec, workgroup, spec_consts, push_consts);
tensors, spirvVec, workgroup, specConstsVec, pushConstsVec);
},
DOC(kp, Manager, algorithm),
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
py::arg("spec_consts") = py::list(),
py::arg("push_consts") = py::list())
.def(
"algorithm",
[](kp::Manager& self,
const std::vector<std::shared_ptr<kp::Memory>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup,
const py::array& spec_consts,
const std::vector<float>& push_consts) {
const py::list& push_consts) {
KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with "
"spec consts data size {} dtype {}",
spec_consts.size(),
std::string(py::str(spec_consts.dtype())));
auto pushConstsVec = push_consts.cast<std::vector<float>>();
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
return pv::algorithm<float>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<float>(self, tensors, spirv, workgroup, spec_consts, pushConstsVec);
} else if (spec_consts.dtype().is(py::dtype::of<std::int32_t>())) {
return pv::algorithm<int32_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<int32_t>(self, tensors, spirv, workgroup, spec_consts, pushConstsVec);
} else if (spec_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
return pv::algorithm<uint32_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<uint32_t>(self, tensors, spirv, workgroup, spec_consts, pushConstsVec);
} else if (spec_consts.dtype().is(py::dtype::of<std::double_t>())) {
return pv::algorithm<double_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<double_t>(self, tensors, spirv, workgroup, spec_consts, pushConstsVec);
}
// If reach then no valid dtype supported
throw std::runtime_error("Kompute Python no valid dtype supported");
Expand All @@ -604,28 +607,29 @@ PYBIND11_MODULE(kp, m)
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
py::arg("spec_consts").noconvert(true) = py::array(),
py::arg("push_consts") = py::list())
.def(
"algorithm",
[](kp::Manager& self,
const std::vector<std::shared_ptr<kp::Memory>>& tensors,
const py::bytes& spirv,
const kp::Workgroup& workgroup,
const std::vector<float>& spec_consts,
const py::list& spec_consts,
const py::array& push_consts) {
KP_LOG_DEBUG("Kompute Python Manager creating Algorithm_T with "
"push consts data size {} dtype {}",
push_consts.size(),
std::string(py::str(push_consts.dtype())));
auto specConstsVec = spec_consts.cast<std::vector<float>>();
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
return pv::algorithm<float>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<float>(self, tensors, spirv, workgroup, specConstsVec, push_consts);
} else if (push_consts.dtype().is(py::dtype::of<std::int32_t>())) {
return pv::algorithm<int32_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<int32_t>(self, tensors, spirv, workgroup, specConstsVec, push_consts);
} else if (push_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
return pv::algorithm<uint32_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<uint32_t>(self, tensors, spirv, workgroup, specConstsVec, push_consts);
} else if (push_consts.dtype().is(py::dtype::of<std::double_t>())) {
return pv::algorithm<double_t>(self, tensors, spirv, workgroup, spec_consts, push_consts);
return pv::algorithm<double_t>(self, tensors, spirv, workgroup, specConstsVec, push_consts);
}
// If reach then no valid dtype supported
throw std::runtime_error("Kompute Python no valid dtype supported");
Expand All @@ -634,8 +638,8 @@ PYBIND11_MODULE(kp, m)
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
py::arg("spec_consts") = py::list(),
py::arg("push_consts").noconvert(true) = py::array())
.def(
"algorithm",
[np](kp::Manager& self,
Expand Down Expand Up @@ -703,8 +707,8 @@ PYBIND11_MODULE(kp, m)
py::arg("tensors"),
py::arg("spirv"),
py::arg("workgroup") = kp::Workgroup(),
py::arg("spec_consts") = std::vector<float>(),
py::arg("push_consts") = std::vector<float>())
py::arg("spec_consts").noconvert(true) = py::array(),
py::arg("push_consts").noconvert(true) = py::array())
.def(
"list_devices",
[](kp::Manager& self) {
Expand Down

0 comments on commit 8e0289b

Please sign in to comment.