
Commit

Reverts 1f6bd97
PiperOrigin-RevId: 689845725
Google-ML-Automation committed Oct 25, 2024
1 parent acb4e56 commit 79142a8
Showing 7 changed files with 113 additions and 159 deletions.
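
For orientation: the diff below undoes a refactoring that had moved the per-operand resharding-cost vectors off ShardingStrategy and onto InputShardings; after this revert each strategy owns its resharding costs again. The following is a minimal, hypothetical C++ sketch of the post-revert layout, inferred only from the brace-initializer calls visible in this diff; it is not the actual XLA declaration, which carries more fields and uses the project's own types.

#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins inferred from the initializer lists in this diff.
struct HloSharding {};  // placeholder for xla::HloSharding
using ReshardingCosts = std::vector<std::vector<double>>;  // one cost row per operand

// Post-revert: the strategy again carries the cost of resharding each operand
// into the layout it expects.
struct ShardingStrategy {
  HloSharding output_sharding;
  double compute_cost = 0;
  double communication_cost = 0;
  double memory_cost = 0;
  ReshardingCosts communication_resharding_costs;
  ReshardingCosts memory_resharding_costs;
};

// Post-revert: InputShardings shrinks back to a strategy name plus the operand shardings.
struct InputShardings {
  std::string name;
  std::vector<HloSharding> shardings;
};

// The call-site pattern repeated throughout the diff: build the cost rows, move
// them into the strategy, and register the strategy and its input shardings together.
struct StrategyGroup {
  void AddStrategy(ShardingStrategy strategy, InputShardings input_shardings) {
    strategies.push_back(std::move(strategy));
    shardings.push_back(std::move(input_shardings));
  }
  std::vector<ShardingStrategy> strategies;
  std::vector<InputShardings> shardings;
};

The reverted change had stored the two ReshardingCosts members on InputShardings instead, which is why most hunks below simply move those two arguments from one brace-initializer to the other.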
132 changes: 52 additions & 80 deletions xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -334,6 +334,7 @@ void FollowArrayOrTokenStrategyGroup(
double compute_cost = 0, communication_cost = 0;
double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec);
size_t num_in_nodes = strategy_group.in_nodes.size();
InputShardings input_shardings{name, {num_in_nodes, *output_spec}};
ReshardingCosts communication_resharding_costs;
ReshardingCosts memory_resharding_costs;
for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) {
@@ -343,14 +344,11 @@ void FollowArrayOrTokenStrategyGroup(
memory_resharding_costs.push_back(MemoryReshardingCostVector(
*strategy_group.in_nodes[i], shape, *output_spec, cluster_env));
}
InputShardings input_shardings{name,
{num_in_nodes, *output_spec},
communication_resharding_costs,
memory_resharding_costs};

strategy_group.AddStrategy(
ShardingStrategy(
{*output_spec, compute_cost, communication_cost, memory_cost}),
ShardingStrategy({*output_spec, compute_cost, communication_cost,
memory_cost, communication_resharding_costs,
memory_resharding_costs}),
input_shardings);
}
}
@@ -406,14 +404,12 @@ std::unique_ptr<StrategyGroup> HandlePartialReduce(
GenerateReshardingCostsAndMissingShardingsForAllOperands(
ins, output_spec, strategy_map, cluster_env, call_graph,
input_shardings);
input_shardings.communication_resharding_costs =
std::move(resharding_costs.first);
input_shardings.memory_resharding_costs =
std::move(resharding_costs.second);

child_strategy_group->AddStrategy(
ShardingStrategy({std::move(output_spec), compute_cost,
communication_cost, memory_cost}),
communication_cost, memory_cost,
std::move(resharding_costs.first),
std::move(resharding_costs.second)}),
std::move(input_shardings));
}

@@ -558,11 +554,9 @@ absl::StatusOr<std::unique_ptr<StrategyGroup>> FollowReduceStrategy(
}
}
const ShardingStrategy strategy = ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost});
strategy_group->AddStrategy(strategy, {name,
{input_sharding},
communication_resharding_costs,
memory_resharding_costs});
{output_spec, compute_cost, communication_cost, memory_cost,
communication_resharding_costs, memory_resharding_costs});
strategy_group->AddStrategy(strategy, {name, {input_sharding}});
}
} else {
LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString();
@@ -703,13 +697,11 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape,
}
communication_resharding_costs.push_back({});
memory_resharding_costs.push_back({});
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs = std::move(memory_resharding_costs);
double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec);
strategy_group.AddStrategy(
ShardingStrategy(
{HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
memory_cost, std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
input_shardings);
}

@@ -810,18 +802,15 @@ void AddReplicatedStrategy(
}
}

for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
possible_input_shardings[j].communication_resharding_costs =
std::move(possible_communication_resharding_costs[j]);
possible_input_shardings[j].memory_resharding_costs =
std::move(possible_memory_resharding_costs[j]);
}
for (size_t j = 0; j < possible_input_shardings.size(); ++j) {
double communication_cost = ComputeCommunicationCost(
ins, possible_input_shardings[j], cluster_env);
strategy_group.AddStrategy(
ShardingStrategy({replicated_strategy, replicated_penalty,
communication_cost, memory_cost}),
ShardingStrategy(
{replicated_strategy, replicated_penalty, communication_cost,
memory_cost,
std::move(possible_communication_resharding_costs[j]),
std::move(possible_memory_resharding_costs[j])}),
std::move(possible_input_shardings[j]));
}
} else {
@@ -859,13 +848,11 @@ void AddReplicatedStrategy(
}
}
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy(
{HloSharding::Replicate(), replicated_penalty, 0, memory_cost}),
ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
input_shardings);
}
}
@@ -952,13 +939,11 @@ void EnumerateAll1DPartition(
communication_cost = ComputeSortCommunicationCost(
ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env);
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
input_shardings);
}
}
@@ -1066,12 +1051,10 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape,
}
}

input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs = std::move(memory_resharding_costs);
strategy_group.AddStrategy(
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost, std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
input_shardings);
}

@@ -1119,12 +1102,11 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins,
ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector(
operand_strategy_group, operand_shape, *input_spec, cluster_env)};
strategy_group.AddStrategy(
ShardingStrategy(
{output_spec, compute_cost, communication_cost, memory_cost}),
{name,
{*input_spec},
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)});
ShardingStrategy({output_spec, compute_cost, communication_cost,
memory_cost,
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
{name, {*input_spec}});
}
}
}
@@ -1439,21 +1421,19 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding(
strategy_group.GetStrategies();
}
strategy_group.ClearStrategies();
input_shardings.communication_resharding_costs =
communication_resharding_costs;
input_shardings.memory_resharding_costs = memory_resharding_costs;
strategy_group.AddStrategy(
ShardingStrategy({existing_sharding, 0, 0, memory_cost}),
ShardingStrategy({existing_sharding, 0, 0, memory_cost,
communication_resharding_costs,
memory_resharding_costs}),
input_shardings);
}
// If there is only one option for resharding, and the cost computed for
// that option is kInfinityCost, set the cost to zero. This is okay
// because there is only one option anyway, and having the costs set to
// kInfinityCost is problematic for the solver.
if (strategy_group.GetStrategyInputShardings().size() == 1) {
if (strategy_group.GetStrategies().size() == 1) {
for (auto& operand_communication_resharding_costs :
strategy_group.GetMutableInputShardings(0)
.communication_resharding_costs) {
strategy_group.GetStrategy(0).communication_resharding_costs) {
if (operand_communication_resharding_costs.size() == 1 &&
operand_communication_resharding_costs[0] >= kInfinityCost) {
operand_communication_resharding_costs[0] = 0;
@@ -1581,17 +1561,10 @@ void ScaleCostsWithExecutionCounts(const int64_t execution_count,
ShardingStrategy& strategy = leaf_strategy_group.GetStrategy(sid);
scale_cost(strategy.compute_cost);
scale_cost(strategy.communication_cost);
}
for (int iid = 0;
iid < leaf_strategy_group.GetStrategyInputShardings().size(); ++iid) {
InputShardings& input_shardings =
leaf_strategy_group.GetMutableInputShardings(iid);
for (int i = 0; i < input_shardings.communication_resharding_costs.size();
++i) {
for (int j = 0;
j < input_shardings.communication_resharding_costs[i].size();
for (int i = 0; i < strategy.communication_resharding_costs.size(); ++i) {
for (int j = 0; j < strategy.communication_resharding_costs[i].size();
++j) {
scale_cost(input_shardings.communication_resharding_costs[i][j]);
scale_cost(strategy.communication_resharding_costs[i][j]);
}
}
}
@@ -1703,13 +1676,11 @@ std::unique_ptr<StrategyGroup> HandleManuallyShardedInstruction(
memory_resharding_costs.push_back(zeros);
}
}
input_shardings.communication_resharding_costs =
std::move(communication_resharding_costs);
input_shardings.memory_resharding_costs =
std::move(memory_resharding_costs);
strategy_group->AddStrategy(
ShardingStrategy({HloSharding::Replicate(), 0, 0,
static_cast<double>(ShapeUtil::ByteSizeOf(shape))}),
static_cast<double>(ShapeUtil::ByteSizeOf(shape)),
std::move(communication_resharding_costs),
std::move(memory_resharding_costs)}),
std::move(input_shardings));
} else {
LOG(FATAL) << "Unsupported instruction shape: " << shape.DebugString();
@@ -1757,12 +1728,13 @@ std::unique_ptr<StrategyGroup> CreateReshapeStrategies(
operand_strategy_group, operand->shape(),
operand_strategy.output_sharding, cluster_env);
strategy_group->AddStrategy(
ShardingStrategy(
{*output_sharding, compute_cost, communication_cost, memory_cost}),
{name,
{operand_strategy.output_sharding},
{communication_resharding_costs},
{memory_resharding_costs}});
ShardingStrategy({*output_sharding,
compute_cost,
communication_cost,
memory_cost,
{communication_resharding_costs},
{memory_resharding_costs}}),
{name, {operand_strategy.output_sharding}});
}

if (strategy_group->GetStrategies().empty()) {
25 changes: 11 additions & 14 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc
@@ -163,31 +163,28 @@ EdgeReshardingCostMatrix CostGraph::CreateEdgeCost(
CHECK_LT(src_idx, node_lens_.size());
CHECK_LT(dst_idx, node_lens_.size());
EdgeReshardingCostMatrix edge_cost(node_lens_[src_idx], node_lens_[dst_idx]);
const auto& strategy_input_shardings =
strategy_group->GetStrategyInputShardings();
for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) {
const InputShardings& input_shardings = strategy_input_shardings[iid];
const NodeStrategyIdx k =
strategy_group->GetStrategyIdxForInputShardings(iid);
const auto& strategies = strategy_group->GetStrategies();
for (NodeStrategyIdx k = 0; k < strategies.size(); ++k) {
const ShardingStrategy& strategy = strategies[k];
size_t start_idx = 0;
CHECK_LT(in_node_idx, input_shardings.memory_resharding_costs.size())
CHECK_LT(in_node_idx, strategy.memory_resharding_costs.size())
<< strategy_group->node_idx;
if (input_shardings.memory_resharding_costs[in_node_idx].size() >
if (strategy.memory_resharding_costs[in_node_idx].size() >
node_lens_[src_idx]) {
start_idx = input_shardings.memory_resharding_costs[in_node_idx].size() -
start_idx = strategy.memory_resharding_costs[in_node_idx].size() -
node_lens_[src_idx];
}
for (size_t j = start_idx;
j < input_shardings.memory_resharding_costs[in_node_idx].size(); ++j) {
j < strategy.memory_resharding_costs[in_node_idx].size(); ++j) {
double communication_cost = 0;
double memory_cost = 0;
if (!zero_cost) {
communication_cost =
input_shardings.communication_resharding_costs[in_node_idx][j];
memory_cost = input_shardings.memory_resharding_costs[in_node_idx][j];
strategy.communication_resharding_costs[in_node_idx][j];
memory_cost = strategy.memory_resharding_costs[in_node_idx][j];
}
edge_cost(j - start_idx, k)
.take_max(EdgeReshardingCost(communication_cost, memory_cost));
edge_cost(j - start_idx, k) =
EdgeReshardingCost(communication_cost, memory_cost);
}
}
return edge_cost;
6 changes: 0 additions & 6 deletions xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h
@@ -16,7 +16,6 @@ limitations under the License.
#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_
#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_COST_GRAPH_H_

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <string>
@@ -43,11 +42,6 @@ struct EdgeReshardingCost {
EdgeReshardingCost(double communication_cost_, double memory_cost_)
: communication_cost(communication_cost_), memory_cost(memory_cost_) {}

void take_max(const EdgeReshardingCost& other) {
communication_cost = std::max(communication_cost, other.communication_cost);
memory_cost = std::max(memory_cost, other.memory_cost);
}

EdgeReshardingCost operator+(const EdgeReshardingCost& other) const {
return EdgeReshardingCost(other.communication_cost + communication_cost,
other.memory_cost + memory_cost);
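
The cost-graph changes follow from the same move. With resharding costs back on ShardingStrategy, CreateEdgeCost iterates the destination node's strategies directly and writes each edge-cost cell exactly once, so the element-wise max accumulation (take_max) that reconciled multiple input-sharding entries mapping to the same strategy index is no longer needed and is deleted above. Below is a rough, self-contained sketch of the restored inner loop under simplified stand-in types; the real code also handles the start_idx offset for oversized cost rows and a zero_cost flag.

#include <cstddef>
#include <vector>

// Simplified stand-ins; see the sketch near the top of this page.
struct EdgeReshardingCost {
  double communication_cost = 0;
  double memory_cost = 0;
};

struct ShardingStrategy {
  std::vector<std::vector<double>> communication_resharding_costs;
  std::vector<std::vector<double>> memory_resharding_costs;
};

// One cell per (source strategy j, destination strategy k), assigned directly.
using EdgeCostMatrix = std::vector<std::vector<EdgeReshardingCost>>;

EdgeCostMatrix CreateEdgeCostSketch(const std::vector<ShardingStrategy>& dst_strategies,
                                    std::size_t in_node_idx,
                                    std::size_t num_src_strategies) {
  EdgeCostMatrix edge_cost(num_src_strategies,
                           std::vector<EdgeReshardingCost>(dst_strategies.size()));
  for (std::size_t k = 0; k < dst_strategies.size(); ++k) {
    const ShardingStrategy& strategy = dst_strategies[k];
    // Assumes the communication and memory rows for an operand have equal length.
    const auto& comm = strategy.communication_resharding_costs[in_node_idx];
    const auto& mem = strategy.memory_resharding_costs[in_node_idx];
    for (std::size_t j = 0; j < comm.size() && j < num_src_strategies; ++j) {
      edge_cost[j][k] = EdgeReshardingCost{comm[j], mem[j]};
    }
  }
  return edge_cost;
}

In the real CostGraph the matrix type is EdgeReshardingCostMatrix and the strategies come from strategy_group->GetStrategies(), as in the hunk above.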
@@ -337,11 +337,10 @@ void HandlerBase::AppendNewStrategy(const std::string& name,
strategy_group_->AddStrategy(
ShardingStrategy({output_spec, compute_cost, communication_cost,
static_cast<double>(ByteSizeOfShapeWithSharding(
ins_->shape(), output_spec))}),
{name,
{input_specs.begin(), input_specs.end()},
communication_resharding_costs,
memory_resharding_costs});
ins_->shape(), output_spec)),
communication_resharding_costs,
memory_resharding_costs}),
{name, {input_specs.begin(), input_specs.end()}});
}

// Given lhs and rhs dim maps, infers a sharding for the output by relying