Skip to content

Commit

Permalink
Move GetStartIndicesDimsToOutputDims to gather_scatter_utils.h
Browse files Browse the repository at this point in the history
We will use this function in both algebraic_simplifier and hlo_evaluator. This change only moves util function without behavior change.

PiperOrigin-RevId: 697042675
  • Loading branch information
ZixuanJiang authored and Google-ML-Automation committed Nov 16, 2024
1 parent 31947e4 commit 3f407ff
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 31 deletions.
1 change: 1 addition & 0 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ cc_library(
"//xla/hlo/ir:hlo_instruction_utils",
"//xla/hlo/pass:hlo_pass",
"//xla/hlo/utils:hlo_sharding_util",
"//xla/service:gather_scatter_utils",
"//xla/service:hlo_cost_analysis",
"//xla/service:hlo_creation_utils",
"//xla/service:hlo_module_config",
Expand Down
37 changes: 6 additions & 31 deletions xla/hlo/transforms/simplifiers/algebraic_simplifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ limitations under the License.
#include "xla/overflow_util.h"
#include "xla/permutation_util.h"
#include "xla/primitive_util.h"
#include "xla/service/gather_scatter_utils.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/service/hlo_creation_utils.h"
#include "xla/service/hlo_module_config.h"
Expand Down Expand Up @@ -4076,35 +4077,6 @@ std::vector<int64_t> GetPaddedDims(const HloInstruction* pad) {
return padded_dims;
}

// Returns a map from start_indices explicit batching dims to their
// corresponding output dims.
absl::flat_hash_map<int64_t, int64_t> GetStartIndicesDimsToOutputDims(
const HloInstruction* gather) {
absl::flat_hash_map<int64_t, int64_t> start_indices_dims_to_output_dims;
const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers();
start_indices_dims_to_output_dims.reserve(
dnums.start_indices_batching_dims_size());

for (int64_t output_dim = 0, start_indices_dim = 0;
output_dim < gather->shape().rank(); ++output_dim) {
if (absl::c_linear_search(dnums.offset_dims(), output_dim)) {
continue;
}
// Output_dim is an implicit or explicit batching dim.
if (start_indices_dim == dnums.index_vector_dim()) {
start_indices_dim++;
}
CHECK_LT(start_indices_dim, gather->operand(1)->shape().rank());
if (absl::c_linear_search(dnums.start_indices_batching_dims(),
start_indices_dim)) {
// Explicit batching dim.
start_indices_dims_to_output_dims[start_indices_dim] = output_dim;
}
++start_indices_dim;
}
return start_indices_dims_to_output_dims;
}

struct GatherOfPadInfo {
bool should_transform;
bool has_padded_batching_dims;
Expand Down Expand Up @@ -4189,9 +4161,12 @@ GatherOfPadInfo CheckPaddedDimsForGatherOfPad(
// Add padded explicit operand batching dims and their corresponding result
// dims to padded_operand_dims_to_output_dims and
// output_dims_to_padded_operand_dims.
const absl::flat_hash_map<int64_t, int64_t>&
const absl::flat_hash_map<int64_t, int64_t>
start_indices_dims_to_output_dims =
GetStartIndicesDimsToOutputDims(gather);
GetStartIndicesDimToOutputDimForExplicitBatchingDims(
dnums.start_indices_batching_dims(), dnums.index_vector_dim(),
dnums.offset_dims(), start_indices->shape().rank(),
gather->shape().rank());
for (int64_t operand_dim : padded_operand_dims) {
if (!absl::c_linear_search(operand_batching_dims, operand_dim)) {
continue;
Expand Down
4 changes: 4 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5839,7 +5839,11 @@ cc_library(
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

Expand Down
35 changes: 35 additions & 0 deletions xla/service/gather_scatter_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@ limitations under the License.

#include "xla/service/gather_scatter_utils.h"

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
Expand All @@ -30,6 +34,7 @@ limitations under the License.
#include "xla/service/hlo_creation_utils.h"
#include "xla/shape.h"
#include "xla/util.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -220,4 +225,34 @@ bool IsCollapsedOrBatchingDim(absl::Span<const int64_t> collapsed_dims,
return absl::c_linear_search(collapsed_dims, dim) ||
absl::c_linear_search(batching_dims, dim);
}

absl::flat_hash_map<int64_t, int64_t>
GetStartIndicesDimToOutputDimForExplicitBatchingDims(
absl::Span<const int64_t> start_indices_batching_dims,
int64_t index_vector_dim, absl::Span<const int64_t> offset_dims,
int64_t start_indices_rank, int64_t output_rank) {
absl::flat_hash_map<int64_t, int64_t>
explicit_batching_dims_start_indices_dim_to_output_dim;
explicit_batching_dims_start_indices_dim_to_output_dim.reserve(
start_indices_batching_dims.size());

for (int64_t output_dim = 0, start_indices_dim = 0; output_dim < output_rank;
++output_dim) {
if (absl::c_linear_search(offset_dims, output_dim)) {
continue;
}
if (start_indices_dim == index_vector_dim) {
start_indices_dim++;
}
CHECK_LT(start_indices_dim, start_indices_rank);
if (absl::c_linear_search(start_indices_batching_dims, start_indices_dim)) {
// Explicit batching dim.
explicit_batching_dims_start_indices_dim_to_output_dim
[start_indices_dim] = output_dim;
}
++start_indices_dim;
}
return explicit_batching_dims_start_indices_dim_to_output_dim;
}

} // namespace xla
14 changes: 14 additions & 0 deletions xla/service/gather_scatter_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@ limitations under the License.
#ifndef XLA_SERVICE_GATHER_SCATTER_UTILS_H_
#define XLA_SERVICE_GATHER_SCATTER_UTILS_H_

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape.h"

namespace xla {

Expand Down Expand Up @@ -66,6 +71,15 @@ absl::StatusOr<HloInstruction*> ExpandIndexVectorIntoOperandSpace(
bool IsCollapsedOrBatchingDim(absl::Span<const int64_t> collapsed_dims,
absl::Span<const int64_t> batching_dims,
int64_t dim);

// Returns a map from start_indices explicit batching dims to their
// corresponding output dims.
absl::flat_hash_map<int64_t, int64_t>
GetStartIndicesDimToOutputDimForExplicitBatchingDims(
absl::Span<const int64_t> start_indices_batching_dims,
int64_t index_vector_dim, absl::Span<const int64_t> offset_dims,
int64_t start_indices_rank, int64_t output_rank);

} // namespace xla

#endif // XLA_SERVICE_GATHER_SCATTER_UTILS_H_

0 comments on commit 3f407ff

Please sign in to comment.