Skip to content

Commit

Permalink
Add python binding for loading bin from memory (#5164)
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Aug 17, 2024
1 parent 4de5369 commit a0c9e77
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
21 changes: 21 additions & 0 deletions python/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ using namespace ncnn;

namespace py = pybind11;

class DataReaderFromMemoryCopy : public DataReaderFromMemory
{
public:
explicit DataReaderFromMemoryCopy(const unsigned char*& mem)
: DataReaderFromMemory(mem)
{
}

virtual size_t reference(size_t size, const void** buf) const
{
return 0;
}
};

struct LayerFactory
{
std::string name;
Expand Down Expand Up @@ -956,6 +970,13 @@ PYBIND11_MODULE(ncnn, m)
#endif // NCNN_STRING
.def("load_param_bin", (int (Net::*)(const char*)) & Net::load_param_bin, py::arg("protopath"))
.def("load_model", (int (Net::*)(const char*)) & Net::load_model, py::arg("modelpath"))
.def(
"load_model_mem", [](Net& net, const char* mem) {
const unsigned char* _mem = (const unsigned char*)mem;
DataReaderFromMemoryCopy dr(_mem);
net.load_model(dr);
},
py::arg("mem"))
#endif // NCNN_STDIO

.def("clear", &Net::clear)
Expand Down
26 changes: 26 additions & 0 deletions python/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,32 @@ def test_net():
assert len(net.blobs()) == 0 and len(net.layers()) == 0


def test_net_mem():
modelbin = bytearray(303940)
modelbin[0:4] = 71,107,48,1
modelbin[180:184] = 71,107,48,1

with ncnn.Net() as net:
ret = net.load_param("tests/test.param")
net.load_model_mem(bytes(modelbin))
assert ret == 0 and len(net.blobs()) == 3 and len(net.layers()) == 3

input_names = net.input_names()
output_names = net.output_names()
assert len(input_names) > 0 and len(output_names) > 0

in_mat = ncnn.Mat((227, 227, 3))

with net.create_extractor() as ex:
ex.input("data", in_mat)
ret, out_mat = ex.extract("output")

assert ret == 0 and out_mat.dims == 1 and out_mat.w == 1

net.clear()
assert len(net.blobs()) == 0 and len(net.layers()) == 0


def test_net_vulkan():
if not hasattr(ncnn, "get_gpu_count"):
return
Expand Down

0 comments on commit a0c9e77

Please sign in to comment.