Skip to content

Commit

Permalink
[XLA:MSA] Remove unnecessary Extend() call in memory space assignment…
Browse files Browse the repository at this point in the history
…. This Extend() call would also lead to a memory assignment issue, since it was not accompanied by the necessary chunk commit requests. We also add a VerifyAllocations() function that uses a BufferIntervalTree to check for overlapping Allocations before scheduling the asynchronous copies. This is an extra correctness check on MsaAlgorithm allocations, applied only when options_.verify is enabled in the MSA options (it is disabled by default).

PiperOrigin-RevId: 693110290
  • Loading branch information
mehrdadkhani authored and Google-ML-Automation committed Nov 15, 2024
1 parent 4d5f691 commit 81198bb
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 7 deletions.
1 change: 1 addition & 0 deletions xla/service/memory_space_assignment/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ xla_cc_test(
deps = [
":algorithm",
":allocation",
":allocation_value",
":buffer_interval_comparator",
":cost_analysis",
":memory_space_assignment",
Expand Down
10 changes: 6 additions & 4 deletions xla/service/memory_space_assignment/algorithm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,9 @@ absl::StatusOr<AllocationResult> MsaAlgorithm::AllocateAllocationValues(
definition_time_for_allocation_value.at(&allocation_value_to_update),
RequiresNoCopyAlternateMemAllocation(allocation_value_to_update),
all_use_times, entry.only_extend_existing_allocation);
if (options_.allocation_request_modifier_testing_fn) {
options_.allocation_request_modifier_testing_fn(request);
}
// Bitcasts don't define buffers and don't directly consume buffers.
// Skip allocating buffers for bitcast uses (unless they are the root
// instruction). The uses that feed from bitcasts will be handled
Expand All @@ -2483,6 +2486,9 @@ absl::StatusOr<AllocationResult> MsaAlgorithm::AllocateAllocationValues(
use.hlo_use.instruction ==
use.hlo_use.instruction->parent()->root_instruction()) {
result_mark(AllocateSegment(request), result);
if (options_.allocation_result_modifier_testing_fn) {
options_.allocation_result_modifier_testing_fn(request, result);
}
if (request.require_copy_allocation) {
auto allocation_sequence =
allocation_value_to_update.mutable_allocation_sequence();
Expand Down Expand Up @@ -4378,10 +4384,6 @@ AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) {
*use.instruction, use.operand_number, use.operand_index);
}

if (request.only_extend_existing_allocation &&
!allocation_sequence->empty()) {
allocation_sequence->back()->Extend(request.inclusive_start_time);
}
// There could be a requirement to pin this buffer to default memory either
// because it is a parameter or an output. If the buffer is a parameter, then
// we're allowed to prefetch. If the use expects the output to be in default
Expand Down
34 changes: 34 additions & 0 deletions xla/service/memory_space_assignment/memory_space_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,37 @@ MemorySpaceAssignment::Run(HloModule* module,
alias_analysis);
}

// Verifies that finalized alternate-memory Allocations are pairwise
// non-overlapping in time and offset space. Returns an InternalError naming
// the two conflicting chunks on the first overlap found; OkStatus otherwise.
//
// NOTE(review): the original used CHECK here, which crashes the process even
// though this function returns absl::Status and is called under
// TF_RETURN_IF_ERROR; returning an error keeps the status plumbing meaningful
// and lets callers surface the failure.
absl::Status MemorySpaceAssignment::VerifyAllocations() const {
  BufferIntervalTree interval_tree;
  // Checks that none of the chunks already in the interval tree that overlap
  // the given allocation in time also overlap it in the memory range. On
  // success, adds the allocation's chunk to the interval tree.
  auto add_allocation_and_verify =
      [&](const Allocation* allocation) -> absl::Status {
    // end_time() - 1: the interval tree queries appear to use inclusive
    // endpoints — TODO confirm against BufferIntervalTree's contract.
    for (const HeapSimulator::Chunk& overlapping_chunk :
         interval_tree.ChunksOverlappingInTime(allocation->start_time(),
                                               allocation->end_time() - 1)) {
      if (allocation->chunk().OverlapsWith(overlapping_chunk)) {
        return absl::InternalError(absl::StrCat(
            "Chunks are overlapping at Allocation level (before fixing the "
            "schedule): ",
            allocation->ToString(),
            " overlaps with allocated chunk: ", overlapping_chunk.ToString()));
      }
    }
    interval_tree.Add(allocation->start_time(), allocation->end_time() - 1,
                      allocation->chunk());
    return absl::OkStatus();
  };
  // Verify that all alternate memory allocations are free of overlapping
  // Allocations in time and space, and add them to interval_tree one by one.
  // Default-memory allocations are not range-assigned by MSA, so they are
  // skipped.
  for (const auto& allocation : allocations_) {
    if (allocation->memory_space() == MemorySpace::kAlternate) {
      TF_RETURN_IF_ERROR(add_allocation_and_verify(allocation.get()));
    }
  }
  return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::RunMemorySpaceAssignment(
const HloLiveRange& hlo_live_range,
Expand All @@ -365,6 +396,9 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
}

TF_RETURN_IF_ERROR(Process(hlo_live_range));
if (options_.verify) {
TF_RETURN_IF_ERROR(VerifyAllocations());
}
// DEBUG_LOG_ALLOCATIONS_AT
//
// Uncomment the following to log the alternate memory allocations that MSA
Expand Down
5 changes: 5 additions & 0 deletions xla/service/memory_space_assignment/memory_space_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ class MemorySpaceAssignment {
// Calculates asynchronous copy statistics.
absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;

// Verify that allocations_ are free of overlapping Allocations in time and
// space. This is a post-processing step called after all allocations have
// been finalized, before the async copies get scheduled.
absl::Status VerifyAllocations() const;

// Verify that the memory space assignment is free of overlapping buffers and
// export heap simulator trace to be used by buffer_assignment.
//
Expand Down
168 changes: 166 additions & 2 deletions xla/service/memory_space_assignment/memory_space_assignment_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ limitations under the License.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <optional>
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <tuple>
Expand Down Expand Up @@ -65,6 +66,7 @@ limitations under the License.
#include "xla/service/hlo_value.h"
#include "xla/service/memory_space_assignment/algorithm.h"
#include "xla/service/memory_space_assignment/allocation.h"
#include "xla/service/memory_space_assignment/allocation_value.h"
#include "xla/service/memory_space_assignment/buffer_interval_comparator.h"
#include "xla/service/memory_space_assignment/cost_analysis.h"
#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h"
Expand Down Expand Up @@ -160,7 +162,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase {
Options options;
options.max_size_in_bytes = 128;
options.alignment_in_bytes = 8;
options.verify = true;
options.verify = false;
options.alternate_memory_space = kAlternateMemorySpace;
options.max_outstanding_prefetches = -1;
options.max_outstanding_evictions = -1;
Expand Down Expand Up @@ -1214,6 +1216,168 @@ TEST_F(MemorySpaceAssignmentTest, ConditionalCopyReplacement) {
op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
}

// Exercises the testing-only hooks options_.allocation_request_modifier_-
// testing_fn and options_.allocation_result_modifier_testing_fn in three
// phases: (1) baseline — MSA prefetches p0 at add0; (2) the result modifier
// forces kFailRequiresUncommit for that segment, suppressing the prefetch;
// (3) the request modifier clamps latest_prefetch_time so the copy-done lands
// before negate4.
// NOTE(review): unified the third phase on ASSERT_NE — the original mixed in
// CHECK_NE, which aborts the process instead of reporting a gtest failure.
TEST_F(MemorySpaceAssignmentTest, AllocationRequestAndResultModifierTest) {
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    p0 = f32[2,3]{1,0} parameter(0)
    p1 = f32[2,3]{1,0} parameter(1)
    negate0 = f32[2,3]{1,0} negate(p1)
    negate1 = f32[2,3]{1,0} negate(negate0)
    negate2 = f32[2,3]{1,0} negate(negate1)
    negate3 = f32[2,3]{1,0} negate(negate2)
    negate4 = f32[2,3]{1,0} negate(negate3)
    negate5 = f32[2,3]{1,0} negate(negate4)
    negate6 = f32[2,3]{1,0} negate(negate5)
    negate7 = f32[2,3]{1,0} negate(negate6)
    ROOT add0 = f32[2,3]{1,0} add(p0, negate7)
  }
  )";
  // The baseline behavior is to prefetch p0 at add0.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> baseline_module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Options options = DefaultMemorySpaceOptions();
  AssignMemorySpace(baseline_module.get(), options);
  HloInstruction* add0 = FindInstruction(baseline_module.get(), "add0");
  ASSERT_NE(add0, nullptr);
  HloInstruction* p0 = FindInstruction(baseline_module.get(), "p0");
  ASSERT_NE(p0, nullptr);
  EXPECT_THAT(add0->operand(0),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));

  // We should be able to prevent prefetching p0 at add0 using
  // allocation_result_modifier_testing_fn.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> result_modifier_module,
      ParseAndReturnVerifiedModule(hlo_string));
  // max_retries = 1 so the forced failure is not papered over by a retry.
  options.max_retries = 1;
  options.allocation_request_modifier_testing_fn = nullptr;
  options.allocation_result_modifier_testing_fn =
      [](const AllocationRequest& request, AllocationResult& result) {
        if (request.allocation_value_to_update->defining_instruction()
                    ->name() == "p0" &&
            request.use->hlo_use.instruction->name() == "add0") {
          result = AllocationResult::kFailRequiresUncommit;
        }
      };
  AssignMemorySpace(result_modifier_module.get(), options);
  add0 = FindInstruction(result_modifier_module.get(), "add0");
  ASSERT_NE(add0, nullptr);
  p0 = FindInstruction(result_modifier_module.get(), "p0");
  ASSERT_NE(p0, nullptr);
  // No prefetch: add0 consumes the parameter directly.
  EXPECT_EQ(add0->operand(0), p0);

  // We should be able to enforce an earlier prefetch of p0 at add0 using
  // allocation_request_modifier_testing_fn.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<VerifiedHloModule> request_modifier_module,
      ParseAndReturnVerifiedModule(hlo_string));
  options.max_retries = 1;
  options
      .allocation_request_modifier_testing_fn = [](AllocationRequest& request) {
    if (request.allocation_value_to_update->defining_instruction()->name() ==
            "p0" &&
        request.use->hlo_use.instruction->name() == "add0") {
      // Schedule the copy-done before negate4 (scheduled at 6).
      request.latest_prefetch_time = 6;
    }
  };
  options.allocation_result_modifier_testing_fn = nullptr;
  AssignMemorySpace(request_modifier_module.get(), options);
  add0 = FindInstruction(request_modifier_module.get(), "add0");
  ASSERT_NE(add0, nullptr);
  p0 = FindInstruction(request_modifier_module.get(), "p0");
  ASSERT_NE(p0, nullptr);
  EXPECT_THAT(add0->operand(0),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
  // The copy-done should have been scheduled before negate4.
  HloInstruction* negate4 =
      FindInstruction(request_modifier_module.get(), "negate4");
  ASSERT_NE(negate4, nullptr);
  const HloInstructionSequence& sequence =
      request_modifier_module->schedule().sequence(
          request_modifier_module->entry_computation());
  // Position of an instruction in the final schedule.
  auto find_index = [&](const HloInstruction* instruction) {
    return std::distance(sequence.instructions().begin(),
                         std::find(sequence.instructions().begin(),
                                   sequence.instructions().end(), instruction));
  };

  int negate4_index = find_index(negate4);
  int copy_done_index = find_index(add0->operand(0));
  EXPECT_LT(copy_done_index, negate4_index);
}

// Added for b/376869021, which surfaced when we tried to convert a sync slice
// that had to extend the allocation of its operand in the alternate memory. In
// this test, we expect the slice0 operand (p0_copy) to maintain a valid
// allocation in the alternate memory until it gets transferred by the async
// replacement of slice0. We hence stress-test such validity by delaying the
// allocation of slice0 by 3 steps.
TEST_F(MemorySpaceAssignmentTest, SyncReplacementAllocationExtensionBug) {
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY entry {
    p0 = f32[2,2,3]{2,1,0} parameter(0)
    p1 = f32[4,2,3]{2,1,0} parameter(1)
    p0_copy = f32[2,2,3]{2,1,0} copy(p0)
    negate0 = negate(p1)
    negate1 = negate(negate0)
    negate2 = negate(negate1)
    p0_copy0_negate = negate(p0_copy)
    copy_negate2 = copy(negate2)
    slice0 = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]}
    negate3 = negate(copy_negate2)
    negate4 = negate(negate3)
    negate5 = negate(negate4)
    negate6 = negate(negate5)
    negate7 = negate(negate6)
    neg_slice0 = negate(slice0)
    ROOT tuple = tuple(negate7, neg_slice0)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  Options options = DefaultMemorySpaceOptions();
  // Only the sync slice should be considered for async replacement here.
  options.enable_sync_copy_replacement = false;
  options.enable_sync_slice_replacement = true;
  // Run the extra Allocation-overlap verification (VerifyAllocations) on the
  // finalized allocations; this is the check the bug originally tripped.
  options.verify = true;
  options.is_async_slice_implemented_fn =
      [](const HloInstruction* instruction) { return true; };
  // Small limit so buffers must contend for alternate memory.
  options.max_size_in_bytes = 96;
  // Keep p0_copy's position out of alternate memory so slice0 must move it.
  options.is_position_allowed_in_alternate_mem_fn =
      [](const HloPosition& position) {
        return position.instruction->name() != "p0_copy";
      };
  // Delay the allocation of slice0 by 3 steps to allow copy_negate2 to be
  // allocated in alternate memory.
  options.allocation_request_modifier_testing_fn =
      [](AllocationRequest& request) {
        if (request.only_extend_existing_allocation) {
          request.inclusive_start_time += 3;
          request.end_time += 3;
        }
      };
  // Force copy_negate2 and p0_copy to be assigned first, setting up the
  // contention that triggers the extension path.
  const std::string text_proto = R"pb(
    overrides {
      hlo_position_matcher { instruction_name_regex: "copy_negate2|p0_copy" }
      override_options { assign_first: true }
    })pb";
  TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
                          ParseTextProto<MsaSortOrderOverrides>(text_proto));
  auto preset_assignments = AssignMemorySpaceUsingCostAnalysis(
      module.get(), options,
      /*cost_analysis_options_override=*/std::nullopt,
      /*hlo_cost_options_override=*/std::nullopt,
      /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);
  HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy");
  ASSERT_NE(p0_copy, nullptr);
  HloInstruction* neg_slice0 = FindInstruction(module.get(), "neg_slice0");
  ASSERT_NE(neg_slice0, nullptr);
  // slice0 was replaced by an async slice consuming p0_copy.
  EXPECT_THAT(neg_slice0->operand(0), op::AsyncDone(op::AsyncStart(p0_copy)));
}

TEST_F(MemorySpaceAssignmentTest, AlwaysSpillJitPrefetchTest) {
// The negate chain is long enough for asynchronous copy to be inserted
// between p1 and add.
Expand Down
16 changes: 15 additions & 1 deletion xla/service/memory_space_assignment/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/service/buffer_value.h"
#include "xla/service/heap_simulator/heap_simulator.h"
#include "xla/service/hlo_value.h"
#include "xla/service/memory_space_assignment/allocation_value.h"
#include "xla/service/memory_space_assignment/buffer_interval_comparator.h"
#include "xla/service/memory_space_assignment/cost_analysis.h"
#include "xla/service/memory_space_assignment/memory_space_assignment.pb.h"
Expand Down Expand Up @@ -128,9 +128,23 @@ struct Options {
WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn =
[](HloInstruction*, int64_t, int64_t) {};

// This function can be used to check if an equivalent asynchronous slice
// lowering is implemented for a given synchronous slice instruction.
IsAsyncSliceImplementedFunction is_async_slice_implemented_fn =
[](const HloInstruction*) { return false; };

// Should only be used for testing purposes. This function allows us to
// modify the AllocationResult after the AllocationRequest has been processed
// by AllocateSegment().
std::function<void(const AllocationRequest&, AllocationResult&)>
allocation_result_modifier_testing_fn = nullptr;

// Should only be used for testing purposes. This function allows us to
// modify the AllocationRequest before the AllocationRequest is passed to
// AllocateSegment().
std::function<void(AllocationRequest&)>
allocation_request_modifier_testing_fn = nullptr;

// If true, we will try to reduce scoped allocation buffer size for all
// instructions if their operand/output has been allocated in alternate
// memory.
Expand Down

0 comments on commit 81198bb

Please sign in to comment.