Support explicit batch dimensions for gather/scatter in HLO evaluator. #19400
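Summary (inferred from the diff below): this change teaches the HLO evaluator to honor explicit batch dimensions on gather and scatter. For gather, OutputBatchIndexToInputIndex now precomputes a map from each operand_batching_dims entry to the output dimension paired with the corresponding start_indices_batching_dims entry (via GetStartIndicesDimToOutputDimForExplicitBatchingDims from //xla/service:gather_scatter_utils), and a new PropagateExplicitBatchDimsToInputIndex step copies those output-index components straight into the operand index; UpdateScatterIndexToInputIndex does the same for scatter using input_batching_dims and scatter_indices_batching_dims. Batching dimensions are skipped alongside collapsed/inserted window dimensions when building offset mappings (IsCollapsedOrBatchingDim), which also lets OutputOffsetIndexToInputIndex and UpdateWindowIndexToInputIndex index offset_dims()/update_window_dims() directly and drop their output-shape parameters. New evaluator tests cover both ops.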

Open · wants to merge 1 commit into base: main
xla/hlo/evaluator/BUILD (2 changes: 1 addition & 1 deletion)
@@ -58,6 +58,7 @@ cc_library(
         "//xla/service:call_graph",
         "//xla/service:compilation_environments",
         "//xla/service:dynamic_dimension_inference",
+        "//xla/service:gather_scatter_utils",
         "//xla/service:hlo_module_config",
         "//xla/service:logical_buffer",
         "//xla/service:pattern_matcher",
@@ -148,7 +149,6 @@ xla_cc_test(
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/base:endian",
         "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
xla/hlo/evaluator/hlo_evaluator.cc (95 changes: 63 additions & 32 deletions)
@@ -70,6 +70,7 @@ limitations under the License.
 #include "xla/service/call_graph.h"
 #include "xla/service/compilation_environments.h"
 #include "xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "xla/service/gather_scatter_utils.h"
 #include "xla/service/hlo_module_config.h"
 #include "xla/service/logical_buffer.h"
 #include "xla/service/pattern_matcher.h"
@@ -2386,6 +2387,18 @@ class OutputBatchIndexToInputIndex {
     int64_t index_vector_size =
         start_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
     index_vector_.resize(index_vector_size);
+
+    absl::flat_hash_map<int64_t, int64_t> start_indices_dims_to_output_dims =
+        GetStartIndicesDimToOutputDimForExplicitBatchingDims(
+            dim_numbers_.start_indices_batching_dims(),
+            dim_numbers_.index_vector_dim(), dim_numbers_.offset_dims(),
+            start_indices_.shape().rank(), output_shape.rank());
+    for (int64_t i = 0; i < dim_numbers->operand_batching_dims().size(); ++i) {
+      int64_t operand_dim = dim_numbers->operand_batching_dims(i);
+      int64_t start_indices_dim = dim_numbers->start_indices_batching_dims(i);
+      int64_t output_dim = start_indices_dims_to_output_dims[start_indices_dim];
+      explicit_batch_dims_operand_dim_to_output_dim_[operand_dim] = output_dim;
+    }
   }

   // Returns the contribution of start_indices to the input index corresponding
@@ -2407,6 +2420,7 @@ class OutputBatchIndexToInputIndex {
     PropagateOutputIndexGatherDimsToIndexVectorIndex(output_index);
     TF_RETURN_IF_ERROR(FetchIndexVector());
     PropagateIndexVectorToInputIndex();
+    PropagateExplicitBatchDimsToInputIndex(output_index);
     return absl::Span<const int64_t>(input_index_);
   }

@@ -2456,6 +2470,14 @@ class OutputBatchIndexToInputIndex {
     }
   }

+  void PropagateExplicitBatchDimsToInputIndex(
+      absl::Span<const int64_t> output_index) {
+    for (const auto& [operand_dim, output_dim] :
+         explicit_batch_dims_operand_dim_to_output_dim_) {
+      input_index_[operand_dim] = output_index[output_dim];
+    }
+  }
+
   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i of
   // the input index from the index vector. See
   // PropagateIndexVectorToInputIndex.
@@ -2476,6 +2498,9 @@ class OutputBatchIndexToInputIndex {
   // this vector.
   std::vector<int64_t> input_index_;

+  absl::flat_hash_map<int64_t, int64_t>
+      explicit_batch_dims_operand_dim_to_output_dim_;
+
   const GatherDimensionNumbers& dim_numbers_;
   const Literal& start_indices_;
 };
@@ -2488,25 +2513,16 @@ class OutputOffsetIndexToInputIndex {
   // The constructor does some setup work that is amortized across all
   // iterations.
   explicit OutputOffsetIndexToInputIndex(
-      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape,
-      const Shape& output_shape) {
-    std::vector<int64_t> window_index_to_output_index;
-    int64_t output_index_count = 0;
-    for (int64_t i = 0; i < output_shape.dimensions_size(); i++) {
-      if (absl::c_binary_search(dim_numbers.offset_dims(), i)) {
-        window_index_to_output_index.push_back(output_index_count++);
-      } else {
-        output_index_count++;
-      }
-    }
-
+      const GatherDimensionNumbers& dim_numbers, const Shape& input_shape) {
+    CHECK(absl::c_is_sorted(dim_numbers.offset_dims()));
     int64_t window_dim_count = 0;
     for (int64_t i = 0; i < input_shape.dimensions_size(); i++) {
-      if (absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
+      if (IsCollapsedOrBatchingDim(dim_numbers.collapsed_slice_dims(),
+                                   dim_numbers.operand_batching_dims(), i)) {
         input_dim_value_to_output_index_.push_back(-1);
       } else {
         input_dim_value_to_output_index_.push_back(
-            window_index_to_output_index[window_dim_count++]);
+            dim_numbers.offset_dims()[window_dim_count++]);
       }
     }

@@ -2617,8 +2633,7 @@ absl::Status HloEvaluator::HandleGather(const HloInstruction* gather) {
       &gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
       /*output_shape=*/shape, &start_indices);
   OutputOffsetIndexToInputIndex output_offset_index_to_input_index(
-      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape(),
-      /*output_shape=*/shape);
+      gather->gather_dimension_numbers(), /*input_shape=*/operand.shape());

   const Shape& operand_shape = operand.shape();
   if (ShapeUtil::IsZeroElementArray(operand_shape)) {
@@ -2791,6 +2806,20 @@ class UpdateScatterIndexToInputIndex {
     int64_t index_vector_size =
         scatter_indices_.shape().dimensions(dim_numbers_.index_vector_dim());
     index_vector_.resize(index_vector_size);
+
+    absl::flat_hash_map<int64_t, int64_t> scatter_indices_dims_to_update_dims =
+        GetStartIndicesDimToOutputDimForExplicitBatchingDims(
+            dim_numbers_.scatter_indices_batching_dims(),
+            dim_numbers_.index_vector_dim(), dim_numbers_.update_window_dims(),
+            scatter_indices_.shape().rank(), updates_rank);
+    for (int64_t i = 0; i < dim_numbers.input_batching_dims().size(); ++i) {
+      int64_t input_dim = dim_numbers.input_batching_dims(i);
+      int64_t scatter_indices_dim =
+          dim_numbers.scatter_indices_batching_dims(i);
+      int64_t update_dim =
+          scatter_indices_dims_to_update_dims[scatter_indices_dim];
+      explicit_batch_dims_input_dim_to_update_dim_[input_dim] = update_dim;
+    }
   }

   // Returns the contribution of scatter_indices to the input index
@@ -2812,6 +2841,7 @@ class UpdateScatterIndexToInputIndex {
     PropagateUpdateIndexScatterDimsToIndexVectorIndex(update_index);
     TF_RETURN_IF_ERROR(FetchIndexVector());
     PropagateIndexVectorToInputIndex();
+    PropagateExplicitBatchDimsToInputIndex(update_index);
     return absl::Span<const int64_t>(input_index_);
   }

@@ -2860,6 +2890,14 @@ class UpdateScatterIndexToInputIndex {
     }
   }

+  void PropagateExplicitBatchDimsToInputIndex(
+      absl::Span<const int64_t> update_index) {
+    for (const auto& [input_dim, update_dim] :
+         explicit_batch_dims_input_dim_to_update_dim_) {
+      input_index_[input_dim] = update_index[update_dim];
+    }
+  }
+
   // input_dim_value_to_index_vector_[i] tells us how to compute dimension i
   // of the input index from the index vector. See
   // PropagateIndexVectorToInputIndex.
@@ -2880,6 +2918,9 @@ class UpdateScatterIndexToInputIndex {
   // into this vector.
   std::vector<int64_t> input_index_;

+  absl::flat_hash_map<int64_t, int64_t>
+      explicit_batch_dims_input_dim_to_update_dim_;
+
   const ScatterDimensionNumbers& dim_numbers_;
   const Literal& scatter_indices_;
 };
@@ -2896,25 +2937,16 @@ class UpdateWindowIndexToInputIndex {
   // The constructor does some setup work that is amortized across all
   // iterations.
   explicit UpdateWindowIndexToInputIndex(
-      const ScatterDimensionNumbers& dim_numbers, int64_t input_rank,
-      int64_t update_rank) {
-    std::vector<int64_t> window_index_to_update_index;
-    int64_t update_index_count = 0;
-    for (int64_t i = 0; i < update_rank; i++) {
-      if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
-        window_index_to_update_index.push_back(update_index_count++);
-      } else {
-        update_index_count++;
-      }
-    }
-
+      const ScatterDimensionNumbers& dim_numbers, int64_t input_rank) {
+    CHECK(absl::c_is_sorted(dim_numbers.update_window_dims()));
     int64_t window_dim_count = 0;
     for (int64_t i = 0; i < input_rank; i++) {
-      if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+      if (IsCollapsedOrBatchingDim(dim_numbers.inserted_window_dims(),
+                                   dim_numbers.input_batching_dims(), i)) {
         input_dim_value_to_update_index_.push_back(-1);
       } else {
         input_dim_value_to_update_index_.push_back(
-            window_index_to_update_index[window_dim_count++]);
+            dim_numbers.update_window_dims()[window_dim_count++]);
       }
     }

@@ -3004,8 +3036,7 @@ absl::Status HloEvaluator::HandleScatter(const HloInstruction* hlo) {
       /*input_rank=*/operand_dims.size(), updates_dims.size(),
       &scatter_indices);
   UpdateWindowIndexToInputIndex update_window_index_to_input_index(
-      scatter->scatter_dimension_numbers(),
-      /*input_rank=*/operand_dims.size(), updates_dims.size());
+      scatter->scatter_dimension_numbers(), /*input_rank=*/operand_dims.size());

   // Initialize the result with the operand. This makes it easier to handle
   // the updates even when the indices are repeated.
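The constructors above lean on GetStartIndicesDimToOutputDimForExplicitBatchingDims, whose implementation lives in gather_scatter_utils and is not shown in this diff. As a mental model only, here is a hypothetical standalone sketch of the contract the call sites appear to assume; the name, signature, and body are illustrative, not the real utility. The assumed rule: each start-indices dimension other than index_vector_dim pairs up, in order, with the next output dimension that is not an offset dimension, and the returned map is restricted to the explicit batching dims.

#include <cstdint>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"

// Hypothetical sketch (assumed semantics; not the real XLA helper).
absl::flat_hash_map<int64_t, int64_t> StartIndicesDimToOutputDimSketch(
    absl::Span<const int64_t> start_indices_batching_dims,
    int64_t index_vector_dim, absl::Span<const int64_t> offset_dims,
    int64_t start_indices_rank, int64_t output_rank) {
  absl::flat_hash_map<int64_t, int64_t> result;
  int64_t output_dim = 0;
  for (int64_t indices_dim = 0; indices_dim < start_indices_rank;
       ++indices_dim) {
    if (indices_dim == index_vector_dim) continue;  // not a batch dim
    // Skip output dims that hold slice offsets; the rest are batch dims.
    while (output_dim < output_rank &&
           absl::c_linear_search(offset_dims, output_dim)) {
      ++output_dim;
    }
    if (absl::c_linear_search(start_indices_batching_dims, indices_dim)) {
      result[indices_dim] = output_dim;
    }
    ++output_dim;
  }
  return result;
}

On the gather test below, this reading maps start_indices dim 0 to output dim 0 and dim 1 to output dim 1; on the scatter test it maps scatter_indices dim 0 to update dim 0 and dim 2 to update dim 3, matching the expected results.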
xla/hlo/evaluator/hlo_evaluator_test.cc (95 changes: 95 additions & 0 deletions)
@@ -3676,6 +3676,40 @@ ENTRY main {
       LiteralUtil::CreateR2<int32_t>({{0, 1}, {2, 1}}), result));
 }

+TEST_F(HloEvaluatorTest, EvaluateGather_ExplicitBatchDims) {
+  const std::string hlo_text = R"(
+HloModule gather
+ENTRY main {
+  operand = s32[3,2,1,3] parameter(0)
+  indices = s32[3,2] parameter(1)
+  ROOT gather = s32[3,2,3] gather(operand, indices),
+      offset_dims={2},
+      collapsed_slice_dims={2},
+      start_index_map={0},
+      index_vector_dim=2,
+      slice_sizes={3,1,1,1},
+      operand_batching_dims={1,3},
+      start_indices_batching_dims={1,0}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+
+  Literal operand =
+      LiteralUtil::CreateR4<int32_t>({{{{1, 2, 3}}, {{4, 5, 6}}},
+                                      {{{7, 8, 9}}, {{10, 11, 12}}},
+                                      {{{13, 14, 15}}, {{16, 17, 18}}}});
+  Literal start_indices =
+      LiteralUtil::CreateR2<int32_t>({{1, 0}, {0, 1}, {1, 0}});
+  Literal expected_result =
+      LiteralUtil::CreateR3<int32_t>({{{1, 7, 13}, {4, 10, 16}},
+                                      {{2, 8, 14}, {5, 11, 17}},
+                                      {{3, 9, 15}, {6, 12, 18}}});
+
+  TF_ASSERT_OK_AND_ASSIGN(Literal result, Evaluate({&operand, &start_indices}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected_result, result));
+}
+
 TEST_F(HloEvaluatorTest, EvaluateScatter_TensorFlowScatterV1_Update) {
   const char* hlo_text = R"(
 HloModule TensorFlowScatterV1
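As a cross-check (not part of the PR), the expected literal in EvaluateGather_ExplicitBatchDims above can be reproduced by a standalone loop. The batching pairs make operand dim 1 track output dim 1 and operand dim 3 track output dim 0; operand dim 2 is collapsed; and because slice_sizes={3,1,1,1} spans all of operand dim 0, the start index is clamped to 0 and output dim 2 walks that dimension directly.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Standalone replay of the gather test's index arithmetic (no XLA needed).
int main() {
  int32_t operand[3][2][1][3] = {{{{1, 2, 3}}, {{4, 5, 6}}},
                                 {{{7, 8, 9}}, {{10, 11, 12}}},
                                 {{{13, 14, 15}}, {{16, 17, 18}}}};
  int32_t indices[3][2] = {{1, 0}, {0, 1}, {1, 0}};
  for (int i = 0; i < 3; ++i) {      // output dim 0: batch, pairs with operand dim 3
    for (int j = 0; j < 2; ++j) {    // output dim 1: batch, pairs with operand dim 1
      for (int k = 0; k < 3; ++k) {  // output dim 2: offset along operand dim 0
        // start_index_map={0}: indices[i][j] starts the slice on operand dim 0,
        // but slice_sizes[0] == 3 covers the whole dimension, so the start is
        // clamped to 3 - 3 = 0 and the index value cannot shift anything.
        int32_t start = std::clamp<int32_t>(indices[i][j], 0, 3 - 3);
        // collapsed_slice_dims={2}: operand dim 2 is pinned to 0.
        std::printf("output[%d][%d][%d] = %d\n", i, j, k,
                    operand[start + k][j][0][i]);
      }
    }
  }
  return 0;  // prints {{{1,7,13},{4,10,16}},{{2,8,14},{5,11,17}},{{3,9,15},{6,12,18}}}
}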
@@ -4287,6 +4321,67 @@ ENTRY main {
   EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
 }

+TEST_F(HloEvaluatorTest, EvaluateScatter_ExplicitBatchDims) {
+  const char* hlo_text = R"(
+HloModule ScatterExplicitBatchDims
+add_s32 {
+  x = s32[] parameter(0)
+  y = s32[] parameter(1)
+  ROOT s = s32[] add(x,y)
+}
+ENTRY main {
+  indices = s32[2,3,5] parameter(0)
+  update = s32[2,3,2,5] parameter(1)
+  z = s32[] constant(0)
+  input = s32[5,3,2,2] broadcast(z), dimensions={}
+  ROOT s = s32[5,3,2,2] scatter(input, indices, update),
+      update_window_dims={2},
+      inserted_window_dims={1},
+      scatter_dims_to_operand_dims={1},
+      index_vector_dim=3,
+      input_batching_dims={0,3},
+      scatter_indices_batching_dims={2,0},
+      to_apply=add_s32
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(m_, ParseAndReturnVerifiedModule(hlo_text));
+  auto indices = std::make_unique<Literal>(
+      ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 3, 5}, {2, 1, 0}));
+  indices
+      ->Populate<int>([](absl::Span<const int64_t> indices) {
+        return static_cast<int>((indices[1] + 1) % 3);
+      })
+      .IgnoreError();
+  auto updates = std::make_unique<Literal>(
+      ShapeUtil::MakeShapeWithDenseLayout(S32, {2, 3, 2, 5}, {3, 2, 1, 0}));
+  updates
+      ->Populate<int>([](absl::Span<const int64_t> indices) {
+        return static_cast<int>(indices[0] * 1000 + indices[1] * 100 +
+                                indices[2] * 10 + indices[3]);
+      })
+      .IgnoreError();
+  Literal expected =
+      LiteralUtil::CreateR4<int32_t>({{{{200, 1200}, {210, 1210}},
+                                       {{0, 1000}, {10, 1010}},
+                                       {{100, 1100}, {110, 1110}}},
+                                      {{{201, 1201}, {211, 1211}},
+                                       {{1, 1001}, {11, 1011}},
+                                       {{101, 1101}, {111, 1111}}},
+                                      {{{202, 1202}, {212, 1212}},
+                                       {{2, 1002}, {12, 1012}},
+                                       {{102, 1102}, {112, 1112}}},
+                                      {{{203, 1203}, {213, 1213}},
+                                       {{3, 1003}, {13, 1013}},
+                                       {{103, 1103}, {113, 1113}}},
+                                      {{{204, 1204}, {214, 1214}},
+                                       {{4, 1004}, {14, 1014}},
+                                       {{104, 1104}, {114, 1114}}}});
+  TF_ASSERT_OK_AND_ASSIGN(Literal result,
+                          Evaluate({indices.get(), updates.get()}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
 // Verifies that HloEvaluator evaluates a HLO instruction that performs
 // element-wise comparison with 2 bfloat16 operands.
 TEST_F(HloEvaluatorTest, DoesCompareBF16) {
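The scatter test admits the same kind of standalone cross-check (again, not part of the PR). Each update element updates[a][b][w][c] lands at result[c][(b+1)%3][w][a]: result dim 0 follows scatter_indices dim 2 (c) and result dim 3 follows dim 0 (a) through the batching pairs, result dim 1 takes the scattered index value (b+1)%3, and result dim 2 is the update window dimension w. Because the input is a zero broadcast and to_apply is add, each slot simply accumulates its updates.

#include <cstdint>
#include <cstdio>

// Standalone replay of the scatter test's index arithmetic (no XLA needed).
int main() {
  int32_t result[5][3][2][2] = {};  // input is a zero broadcast
  for (int a = 0; a < 2; ++a)
    for (int b = 0; b < 3; ++b)
      for (int w = 0; w < 2; ++w)
        for (int c = 0; c < 5; ++c) {
          int32_t index = (b + 1) % 3;  // indices[a][b][c], per Populate above
          int32_t update = a * 1000 + b * 100 + w * 10 + c;
          result[c][index][w][a] += update;  // to_apply=add_s32
        }
  // E.g. result[0][0][0][0] == 200: only b == 2 scatters into row 0 there.
  std::printf("result[0][0][0][0] = %d\n", result[0][0][0][0]);
  return 0;
}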