Skip to content

Commit

Permalink
Enable AWS S3 in readers.file
Browse files Browse the repository at this point in the history
Signed-off-by: Joaquin Anton <janton@nvidia.com>
  • Loading branch information
jantonguirao committed Apr 5, 2024
1 parent 4241f6e commit 8f3d01d
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
10 changes: 6 additions & 4 deletions dali/operators/reader/loader/file_label_loader.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,11 +50,13 @@ void FileLabelLoaderBase<checkpointing_supported>::ReadSample(ImageLabelWrapper
return;
}

auto current_image = FileStream::Open(filesystem::join_path(file_root_, image_pair.first),
read_ahead_, !copy_read_data_);
auto uri = filesystem::join_path(file_root_, image_pair.first);
bool is_s3 = uri.rfind("s3://", 0) != std::string::npos;
bool use_mmap = !copy_read_data_ && !is_s3;
auto current_image = FileStream::Open(uri, read_ahead_, use_mmap);
Index image_size = current_image->Size();

if (copy_read_data_) {
if (!use_mmap) {
if (image_label.image.shares_data()) {
image_label.image.Reset();
}
Expand Down
5 changes: 2 additions & 3 deletions dali/util/file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ namespace dali {

std::unique_ptr<FileStream> FileStream::Open(const std::string& uri, bool read_ahead, bool use_mmap,
bool use_odirect) {
std::string processed_uri;

#if AWSSDK_ENABLED
if (uri.rfind("s3://", 0) == 0) {
return std::unique_ptr<FileStream>(
new S3FileStream(S3ClientManager::Instance().client(), processed_uri));
new S3FileStream(S3ClientManager::Instance().client(), uri));
}
#endif

std::string processed_uri;
if (uri.find("file://") == 0) {
processed_uri = uri.substr(std::string("file://").size());
} else {
Expand Down
8 changes: 4 additions & 4 deletions dali/util/s3_client_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ struct S3ClientManager {
private:
S3ClientManager() {
Aws::InitAPI(options_);
clientConfig_.region = std::getenv("DALI_AWS_REGION");
clientConfig_.endpointOverride = std::getenv("DALI_AWS_ENDPOINT");
client_ = std::make_unique<Aws::S3::S3Client>(clientConfig_);
Aws::Client::ClientConfiguration clientConfig;
clientConfig.region = std::getenv("DALI_AWS_REGION");
clientConfig.endpointOverride = std::getenv("DALI_AWS_ENDPOINT");
client_ = std::make_unique<Aws::S3::S3Client>(clientConfig);
}

~S3ClientManager() {
Aws::ShutdownAPI(options_);
}

Aws::SDKOptions options_;
Aws::Client::ClientConfiguration clientConfig_;
std::unique_ptr<Aws::S3::S3Client> client_;
};

Expand Down

0 comments on commit 8f3d01d

Please sign in to comment.