Skip to content

Commit

Permalink
Changes to dali/util/numpy
Browse files Browse the repository at this point in the history
- Make `numpy.h` available to users
- Extend `ReadTensor` to also accept a `bool pinned` argument
- See #5337 for details

Signed-off-by: Francesco Versaci <francesco.versaci@gmail.com>
  • Loading branch information
fversaci committed Apr 5, 2024
1 parent 43f2671 commit 1c973d0
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dali/kernels/slice/slice_flip_normalize_gpu_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SliceFlipNormalizeGPUTest : public ::testing::Test {

void LoadTensor(Tensor<CPUBackend> &tensor, const std::string& path_npy) {
auto stream = FileStream::Open(path_npy, false, false);
tensor = ::dali::numpy::ReadTensor(stream.get());
tensor = ::dali::numpy::ReadTensor(stream.get(), true);
}

template <typename T, int ndim>
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/imgcodec/decoder_test_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ inline void AssertClose(const Tensor<CPUBackend> &img, const Tensor<CPUBackend>
}

inline Tensor<CPUBackend> ReadReference(InputStream *src, TensorLayout layout = "HWC") {
auto tensor = numpy::ReadTensor(src);
auto tensor = numpy::ReadTensor(src, true);
tensor.SetLayout(layout);
return tensor;
}
Expand Down
1 change: 1 addition & 0 deletions dali/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ set(DALI_INST_HDRS ${DALI_INST_HDRS}
"${CMAKE_CURRENT_SOURCE_DIR}/ocv.h"
"${CMAKE_CURRENT_SOURCE_DIR}/random_crop_generator.h"
"${CMAKE_CURRENT_SOURCE_DIR}/thread_safe_queue.h"
"${CMAKE_CURRENT_SOURCE_DIR}/numpy.h"
"${CMAKE_CURRENT_SOURCE_DIR}/user_stream.h")

set(DALI_SRCS ${DALI_SRCS}
Expand Down
3 changes: 2 additions & 1 deletion dali/util/numpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,13 @@ size_t HeaderData::nbytes() const {
return type_info ? type_info->size() * size() : 0_uz;
}

Tensor<CPUBackend> ReadTensor(InputStream *src) {
Tensor<CPUBackend> ReadTensor(InputStream *src, bool pinned) {
numpy::HeaderData header;
numpy::ParseHeader(header, src);
src->SeekRead(header.data_offset, SEEK_SET);

Tensor<CPUBackend> data;
data.set_pinned(pinned);
data.Resize(header.shape, header.type());
auto ret = src->Read(static_cast<uint8_t*>(data.raw_mutable_data()), header.nbytes());
DALI_ENFORCE(ret == header.nbytes(), "Failed to read numpy file");
Expand Down
2 changes: 1 addition & 1 deletion dali/util/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ DLL_PUBLIC void FromFortranOrder(SampleView<CPUBackend> output, ConstSampleView<

DLL_PUBLIC void ParseHeaderContents(HeaderData& target, const std::string &header);

DLL_PUBLIC Tensor<CPUBackend> ReadTensor(InputStream *src);
DLL_PUBLIC Tensor<CPUBackend> ReadTensor(InputStream *src, bool pinned);

} // namespace numpy
} // namespace dali
Expand Down

0 comments on commit 1c973d0

Please sign in to comment.