Skip to content

Commit

Permalink
Updates the BasicStringArray class to use absl::Cord as the eleme…
Browse files Browse the repository at this point in the history
…nt type.

Before this change, it was using absl::string_view and switching to absl::Cord allows both the IFRT client and its users more flexibility as well as opportunities for optimizations by allowing the strings to be: either included inline, or be readonly views of existing strings (say, in a numpy array or a tensor).

PiperOrigin-RevId: 689556652
  • Loading branch information
Google-ML-Automation committed Oct 24, 2024
1 parent 6c0ce17 commit 31e7e36
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 116 deletions.
3 changes: 2 additions & 1 deletion xla/python/pjrt_ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ cc_library(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
Expand All @@ -357,6 +357,7 @@ xla_cc_test(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
Expand Down
97 changes: 42 additions & 55 deletions xla/python/pjrt_ifrt/basic_string_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/pjrt/pjrt_layout.h"
Expand All @@ -48,7 +47,7 @@ limitations under the License.
// DisassembleIntoSingleDeviceArrays, Reshard, FullyReplicatedShard,
// CopyToHostBuffer and AssembleFromSingleDeviceArrays share a common pattern
// that waits for the source array(s) buffers to become ready and then copies
// the data into a new array's buffer backing store. Factor out the common
// the data into a new array's buffer. Factor out the common
// pattern into a helper function.

namespace xla {
Expand Down Expand Up @@ -104,7 +103,7 @@ absl::StatusOr<tsl::RCReference<BasicStringArray>> BasicStringArray::Create(
auto ready_future = Future<>(ready_promise);

// Buffers when the become ready must be consistent with the sharding. For
// instance, Buffers.size() (the number of per-shard spans of string_views)
// instance, Buffers.size() (the number of per-shard spans of absl::Cords)
// and the devices in the sharding that was used to create an array must
// match. If they do not, the array's ready future and buffers future should
// become ready with an appropriate error status.
Expand Down Expand Up @@ -216,66 +215,62 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays(
// For each single device array we are going to pre-make:
// (1) a Promise-Future pair for passing the buffers,
//
// (2) a Per-shard buffer backing store and the corresponding
// on-done-with-buffer callback.
// (2) a Per-shard data store and the corresponding on-done-with-buffer
// callback.
//
// (3) shape and sharding by disassembing the source array's sharding.
//
// The Futures, the on-done-with-host-buffer callbacks, shapes and shardings
// are used to make the arrays. The promises and the buffer backing stores
// are used to make the arrays. The promises and the per-shard stores
// are passed onto the OnReady callback that populates them when the buffers
// of the source array become ready.
std::vector<Promise<Buffers>> buffer_promises;
buffer_promises.reserve(num_shards);
std::vector<Future<Buffers>> buffer_futures;
buffer_futures.reserve(num_shards);

struct PerShardBufferBackingStore { // Data (strings) for a single shard.
void CopyFrom(absl::Span<const absl::string_view> input_buffer) {
struct PerShardStringStore { // Data (strings) for a single shard.
void CopyFrom(absl::Span<const absl::Cord> input_buffer) {
strings.reserve(input_buffer.size());
string_views.reserve(input_buffer.size());
for (absl::string_view buf : input_buffer) {
strings.push_back(std::string(buf.data(), buf.size()));
string_views.push_back(strings.back());
for (const auto& input_string : input_buffer) {
strings.push_back(input_string);
}
}
std::vector<std::string> strings;
std::vector<absl::string_view> string_views;
std::vector<absl::Cord> strings;
};
std::vector<std::shared_ptr<PerShardBufferBackingStore>>
per_shard_buffer_backing_stores;
per_shard_buffer_backing_stores.reserve(num_shards);

std::vector<std::shared_ptr<PerShardStringStore>> per_shard_strings;
per_shard_strings.reserve(num_shards);
std::vector<OnDoneWithBuffer> on_done_with_buffer_callbacks;
on_done_with_buffer_callbacks.reserve(num_shards);

for (int i = 0; i < num_shards; ++i) {
buffer_promises.push_back(Future<Buffers>::CreatePromise());
buffer_futures.push_back(Future<Buffers>(buffer_promises.back()));

auto backing_store = std::make_shared<PerShardBufferBackingStore>();
per_shard_buffer_backing_stores.push_back(backing_store);
auto current_shard_strings = std::make_shared<PerShardStringStore>();
per_shard_strings.push_back(current_shard_strings);
on_done_with_buffer_callbacks.push_back(
[backing_store = std::move(backing_store)]() {});
[data = std::move(current_shard_strings)]() {});
}

// Copy each of the per-shard data into the its per-shard buffer backing
// store, make a Buffers object and set the corresponding promise.
// When the buffers become ready, copy each of the per-shard data into the
// buffer of the corresponding single-device array.
buffers_.OnReady([buffer_promises = std::move(buffer_promises),
per_shard_buffer_backing_stores =
std::move(per_shard_buffer_backing_stores)](
per_shard_data = std::move(per_shard_strings)](
absl::StatusOr<Buffers> buffers) mutable {
if (!buffers.ok()) {
for (auto& promise : buffer_promises) {
promise.Set(buffers.status());
}
per_shard_buffer_backing_stores.clear();
per_shard_data.clear();
return;
}
auto num_shards = buffers->size();
for (int i = 0; i < num_shards; ++i) {
per_shard_buffer_backing_stores[i]->CopyFrom((*buffers)[i]);
per_shard_data[i]->CopyFrom((*buffers)[i]);
Buffers buffers;
buffers.push_back(per_shard_buffer_backing_stores[i]->string_views);
buffers.push_back(absl::MakeConstSpan(per_shard_data[i]->strings));
buffer_promises[i].Set(std::move(buffers));
}
});
Expand Down Expand Up @@ -325,29 +320,24 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
sharding_->devices()->size()));
}

struct BufferBackingStore {
void AddShardData(absl::Span<const absl::string_view> input_buffer) {
struct StringStore {
void AddShardData(absl::Span<const absl::Cord> input_buffer) {
auto& shard_strings = strings.emplace_back();
shard_strings.reserve(input_buffer.size());

auto& shard_string_views = string_views.emplace_back();
shard_string_views.reserve(input_buffer.size());

for (absl::string_view buf : input_buffer) {
shard_strings.push_back(std::string(buf.data(), buf.size()));
shard_string_views.push_back(shard_strings.back());
for (const auto& input_string : input_buffer) {
shard_strings.push_back(input_string);
}
}
std::vector<std::vector<std::string>> strings;
std::vector<std::vector<absl::string_view>> string_views;
std::vector<std::vector<absl::Cord>> strings;
};

auto backing_store = std::make_shared<BufferBackingStore>();
auto on_done_with_buffer = [backing_store]() {};
auto string_store = std::make_shared<StringStore>();
auto on_done_with_buffer = [string_store]() {};
auto buffers_promise = Future<Buffers>::CreatePromise();
auto buffers_future = Future<Buffers>(buffers_promise);

auto copier = [backing_store = std::move(backing_store),
auto copier = [string_store = std::move(string_store),
buffers_promise = std::move(buffers_promise)](
absl::StatusOr<Buffers> input_buffers) mutable {
if (!input_buffers.ok()) {
Expand All @@ -357,8 +347,8 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
Buffers buffers;
buffers.reserve(input_buffers->size());
for (auto& input_buffer : *input_buffers) {
backing_store->AddShardData(input_buffer);
buffers.push_back(backing_store->string_views.back());
string_store->AddShardData(input_buffer);
buffers.push_back(string_store->strings.back());
}
buffers_promise.Set(std::move(buffers));
};
Expand All @@ -384,25 +374,22 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
if (!sharding_->IsFullyReplicated()) {
return absl::FailedPreconditionError("This array is not fully replicated");
}
struct BufferBackingStore { // Data (strings) for a single shard.
void CopyFrom(absl::Span<const absl::string_view> input_buffer) {
struct StringStore { // Data (strings) for a single shard.
void CopyFrom(absl::Span<const absl::Cord> input_buffer) {
strings.reserve(input_buffer.size());
string_views.reserve(input_buffer.size());
for (absl::string_view buf : input_buffer) {
strings.push_back(std::string(buf.data(), buf.size()));
string_views.push_back(strings.back());
for (const auto& input_strings : input_buffer) {
strings.push_back(input_strings);
}
}
std::vector<std::string> strings;
std::vector<absl::string_view> string_views;
std::vector<absl::Cord> strings;
};

auto backing_store = std::make_shared<BufferBackingStore>();
auto on_done_with_buffer = [backing_store]() {};
auto string_store = std::make_shared<StringStore>();
auto on_done_with_buffer = [string_store]() {};
auto buffers_promise = Future<Buffers>::CreatePromise();
auto buffers_future = Future<Buffers>(buffers_promise);

auto copier = [backing_store = std::move(backing_store),
auto copier = [string_store = std::move(string_store),
buffers_promise = std::move(buffers_promise)](
absl::StatusOr<Buffers> input_buffers) mutable {
if (!input_buffers.ok()) {
Expand All @@ -414,10 +401,10 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::FullyReplicatedShard(
// were run when the source array's buffers became ready would have
// ensured that the input_buffers have at least one shard's worth of data.
auto& input_buffer = (*input_buffers)[0];
backing_store->CopyFrom(input_buffer);
string_store->CopyFrom(input_buffer);

Buffers buffers;
buffers.push_back(backing_store->string_views);
buffers.push_back(string_store->strings);
buffers_promise.Set(std::move(buffers));
};
buffers_.OnReady(std::move(copier));
Expand Down
6 changes: 3 additions & 3 deletions xla/python/pjrt_ifrt/basic_string_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/log/check.h"
#include "absl/strings/string_view.h"
#include "absl/strings/cord.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "llvm/Support/ExtensibleRTTI.h"
Expand Down Expand Up @@ -71,7 +71,7 @@ class BasicStringArray final
: public llvm::RTTIExtends<BasicStringArray, Array> {
public:
// Must be in dense major to minor order.
using Buffer = absl::Span<const absl::string_view>;
using Buffer = absl::Span<const absl::Cord>;

// One Buffer per shard.
static constexpr int kBuffersInlineSize = 1;
Expand All @@ -82,7 +82,7 @@ class BasicStringArray final
using OnDoneWithBuffer = std::function<void()>;

// General array construction. The `buffers` and their elements
// (absl::string_views) must live until the `on_done_with_buffer` is called.
// (absl::Cords) must live until the `on_done_with_buffer` is called.
// The number and order of buffers must match the number and order of devices
// in `sharding`.
static absl::StatusOr<tsl::RCReference<BasicStringArray>> Create(
Expand Down
63 changes: 29 additions & 34 deletions xla/python/pjrt_ifrt/basic_string_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
Expand Down Expand Up @@ -84,21 +85,15 @@ std::pair<BasicStringArray::Buffers, BasicStringArray::OnDoneWithBuffer>
MakeBuffersAndOnDoneWithBuffer(
absl::Span<const absl::string_view> input_strings) {
BasicStringArray::Buffers buffers;
auto string_holder = std::make_shared<std::vector<std::string>>();
string_holder->reserve(input_strings.size());
auto string_view_holder = std::make_shared<std::vector<absl::string_view>>();
string_view_holder->reserve(input_strings.size());
for (const auto str : input_strings) {
string_holder->push_back(std::string(str));
auto strings = std::make_shared<std::vector<absl::Cord>>();
strings->reserve(input_strings.size());
for (const auto input_str : input_strings) {
strings->push_back(absl::Cord(input_str));
}
for (const auto& str : *string_holder) {
string_view_holder->push_back(absl::string_view(str));
}
buffers.push_back(*string_view_holder);
buffers.push_back(*strings);

BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
[string_holder = std::move(string_holder),
string_view_holder = std::move(string_view_holder)]() {};
[strings = std::move(strings)]() {};

return std::make_pair(std::move(buffers), std::move(on_done_with_buffer));
}
Expand Down Expand Up @@ -175,7 +170,7 @@ TEST(BasicStringArrayLayoutTest, Equality) {
TEST(BasicStringArrayTest, CreateSuccess) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
BasicStringArray::Buffers buffers;
buffers.push_back({"abc", "def"});
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});

// This test implicitly tests that the on_done_with_buffer can be a nullptr,
// and that the destruction of the BasicStringArray object completes
Expand All @@ -197,7 +192,7 @@ TEST(BasicStringArrayTest, Destruction) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());

BasicStringArray::Buffers buffers;
buffers.push_back({"abc", "def"});
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});

absl::Notification on_done_with_buffer_called;
BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
Expand Down Expand Up @@ -228,10 +223,10 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) {
ASSERT_GE(devices.size(), 1);

// Make a BasicStringArray::Buffer with two shards.
auto shard0_data = std::make_shared<std::vector<absl::string_view>>();
shard0_data->push_back("abc");
auto shard1_data = std::make_shared<std::vector<absl::string_view>>();
shard1_data->push_back("def");
auto shard0_data = std::make_shared<std::vector<absl::Cord>>();
shard0_data->push_back(absl::Cord("abc"));
auto shard1_data = std::make_shared<std::vector<absl::Cord>>();
shard1_data->push_back(absl::Cord("def"));
BasicStringArray::Buffers buffers;
buffers.push_back(*shard0_data);
buffers.push_back(*shard1_data);
Expand Down Expand Up @@ -260,7 +255,7 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) {
TEST(BasicStringArrayTest, Delete) {
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
BasicStringArray::Buffers buffers;
buffers.push_back({"abc", "def"});
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
absl::Notification on_done_with_buffer_called;
BasicStringArray::OnDoneWithBuffer on_done_with_buffer =
[&on_done_with_buffer_called]() { on_done_with_buffer_called.Notify(); };
Expand Down Expand Up @@ -294,7 +289,7 @@ TEST(GetReadyFutureTest, SuccessCase) {

// Make the buffers future ready asynchronously.
BasicStringArray::Buffers buffers;
buffers.push_back({"abc", "def"});
buffers.push_back({absl::Cord("abc"), absl::Cord("def")});
tsl::Env::Default()->SchedClosure([&]() { promise.Set(buffers); });
TF_EXPECT_OK(ready_future.Await());
}
Expand Down Expand Up @@ -326,11 +321,11 @@ TEST(MakeArrayFromHostBufferTest, SuccessCase) {
std::shared_ptr<const Sharding> sharding =
SingleDeviceSharding::Create(device, MemoryKind());

auto string_views = std::make_shared<std::vector<absl::string_view>>();
string_views->push_back("abc");
string_views->push_back("def");
const void* data = string_views->data();
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
auto strings = std::make_shared<std::vector<absl::Cord>>();
strings->push_back(absl::Cord("abc"));
strings->push_back(absl::Cord("def"));
const void* data = strings->data();
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};

TF_ASSERT_OK(client->MakeArrayFromHostBuffer(
data, DType(DType::kString), shape,
Expand All @@ -345,11 +340,11 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) {
Device* device = client->addressable_devices().at(0);
std::shared_ptr<const Sharding> single_device_sharding =
SingleDeviceSharding::Create(device, MemoryKind());
auto string_views = std::make_shared<std::vector<absl::string_view>>();
string_views->push_back("abc");
string_views->push_back("def");
const void* data = string_views->data();
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
auto strings = std::make_shared<std::vector<absl::Cord>>();
strings->push_back(absl::Cord("abc"));
strings->push_back(absl::Cord("def"));
const void* data = strings->data();
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};

// MakeArrayFromHostBuffer should check and fail if `byte_strides` in not
// nullopt.
Expand Down Expand Up @@ -398,12 +393,12 @@ absl::StatusOr<tsl::RCReference<Array>> MakeSingleDeviceStringTestArray(
std::shared_ptr<const Sharding> sharding =
SingleDeviceSharding::Create(device, MemoryKind());

auto string_views = std::make_shared<std::vector<absl::string_view>>();
auto strings = std::make_shared<std::vector<absl::Cord>>();
for (const auto& content : contents) {
string_views->push_back(content);
strings->push_back(absl::Cord(content));
}
const void* data = string_views->data();
auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {};
const void* data = strings->data();
auto on_done_with_host_buffer = [strings = std::move(strings)]() {};

return client->MakeArrayFromHostBuffer(
data, DType(DType::kString), shape,
Expand Down
Loading

0 comments on commit 31e7e36

Please sign in to comment.