Skip to content

Commit

Permalink
Better encapsulation of HloModuleConfig's fields through setters and …
Browse files Browse the repository at this point in the history
…returning references instead of pointers.

PiperOrigin-RevId: 693857053
  • Loading branch information
toli-y authored and Google-ML-Automation committed Nov 15, 2024
1 parent 4d5f691 commit 1397371
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 19 deletions.
13 changes: 13 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1767,11 +1767,19 @@ cc_library(
hdrs = ["hlo_module_util.h"],
deps = [
":compiler",
":computation_layout",
":hlo_module_config",
"//xla:debug_options_flags",
"//xla:shape_layout",
"//xla:shape_util",
"//xla:util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
],
)

Expand Down Expand Up @@ -4311,12 +4319,16 @@ cc_library(
":hlo_proto_cc",
"//xla:debug_options_flags",
"//xla:shape_layout",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:protobuf",
"@tsl//tsl/platform:statusor",
],
Expand All @@ -4330,6 +4342,7 @@ xla_cc_test(
"//xla:xla_proto_cc",
"//xla/tests:test_utils",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Expand Down
2 changes: 1 addition & 1 deletion xla/service/dump_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ TEST(DumpTest, DumpProtobufToFileWhenDisabled) {
TEST(DumpTest, DumpFdoProfileToFileWhenEnabled) {
std::string fdo_profile = "fdo_profile";
HloModuleConfig config;
*config.mutable_fdo_profile() = fdo_profile;
config.set_fdo_profile(fdo_profile);
DebugOptions options = config.debug_options();
auto env = tsl::Env::Default();
std::string dump_dir;
Expand Down
2 changes: 2 additions & 0 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3091,9 +3091,11 @@ xla_cc_test(
"//xla/tests:hlo_test_base",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:statusor",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/gpu_hlo_schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class GpuHloScheduleTest : public HloTestBase {
debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker(
enable_gpu_async_tracker);
config.set_debug_options(debug_options);
*config.mutable_fdo_profile() = fdo_profile;
config.set_fdo_profile(fdo_profile);
return config;
}

Expand Down
6 changes: 5 additions & 1 deletion xla/service/gpu/gpu_latency_hiding_scheduler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@ limitations under the License.

#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"

#include <cstdint>
#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/gpu_hlo_schedule.h"
Expand Down Expand Up @@ -78,7 +82,7 @@ class GpuLatencyHidingSchedulerBaseTest : public HloTestBase {
debug_options.set_xla_gpu_enable_latency_hiding_scheduler(true);
debug_options.set_xla_gpu_lhs_enable_gpu_async_tracker(true);
config.set_debug_options(debug_options);
*config.mutable_fdo_profile() = fdo_profile;
config.set_fdo_profile(fdo_profile);
return config;
}
};
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3310,10 +3310,10 @@ xla_cc_test(
deps = [
":pgle_accuracy_checker",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/service:latency_hiding_scheduler",
"//xla/service:profile_guided_latency_estimator",
"//xla/service/gpu:gpu_latency_hiding_scheduler",
"//xla/tests:hlo_test_base",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
Expand Down
8 changes: 4 additions & 4 deletions xla/service/gpu/transforms/pgle_accuracy_checker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/service/gpu/gpu_latency_hiding_scheduler.h"
#include "xla/service/latency_hiding_scheduler.h"
#include "xla/service/profile_guided_latency_estimator.h"
#include "xla/tests/hlo_test_base.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status_matchers.h"
#include "tsl/platform/statusor.h"

namespace xla::gpu {
namespace {

using PGLEAccuracyCheckerTest = HloTestBase;
using PGLEAccuracyCheckerTest = HloHardwareIndependentTestBase;
using ::tensorflow::profiler::ProfiledInstructionsProto;
using ::tsl::protobuf::TextFormat;
using ::tsl::testing::StatusIs;
Expand Down Expand Up @@ -95,7 +95,7 @@ TEST_F(PGLEAccuracyCheckerTest,

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));
*module->mutable_config().mutable_fdo_profile() = kProfileString;
module->mutable_config().set_fdo_profile(kProfileString);

auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile);
PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator);
Expand Down Expand Up @@ -147,7 +147,7 @@ TEST_F(PGLEAccuracyCheckerTest,

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloString));
*module->mutable_config().mutable_fdo_profile() = kProfileString;
module->mutable_config().set_fdo_profile(kProfileString);
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_pgle_accuracy_checker(
Expand Down
8 changes: 6 additions & 2 deletions xla/service/hlo_module_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,15 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_layout.h"
#include "xla/xla.pb.h"
#include "tsl/platform/statusor.h"
Expand Down Expand Up @@ -217,7 +221,7 @@ static void AssignStructFusionConfig(HloModuleConfig& config,
}
module_config.push_back(std::move(temp));
}
*config.mutable_fusion_config() = std::move(module_config);
config.set_fusion_config(std::move(module_config));
}

static void AssignStructDotConfig(HloModuleConfig& config,
Expand Down Expand Up @@ -259,7 +263,7 @@ static void AssignStructPhaseOrderingConfig(HloModuleConfig& config,
}
module_config.push_back(std::move(temp));
}
*config.mutable_phase_ordering_config() = std::move(module_config);
config.set_phase_ordering_config(std::move(module_config));
}

HloModuleConfigProto HloModuleConfig::ToProto() const {
Expand Down
24 changes: 19 additions & 5 deletions xla/service/hlo_module_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
Expand All @@ -25,11 +26,15 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
#include "xla/service/computation_layout.h"
#include "xla/service/computation_placer.h"
#include "xla/service/hlo.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"
Expand Down Expand Up @@ -325,8 +330,11 @@ class HloModuleConfig {
const std::vector<std::vector<bool>>& fusion_config() const {
return fusion_config_;
}
std::vector<std::vector<bool>>* mutable_fusion_config() {
return &fusion_config_;
void set_fusion_config(std::vector<std::vector<bool>> fusion_config) {
fusion_config_ = std::move(fusion_config);
}
std::vector<std::vector<bool>>& mutable_fusion_config() {
return fusion_config_;
}

const absl::flat_hash_map<std::string, std::vector<int64_t>>& dot_config()
Expand All @@ -347,8 +355,12 @@ class HloModuleConfig {
const std::vector<std::vector<bool>>& phase_ordering_config() const {
return phase_ordering_config_;
}
std::vector<std::vector<bool>>* mutable_phase_ordering_config() {
return &phase_ordering_config_;
void set_phase_ordering_config(
std::vector<std::vector<bool>> phase_ordering_config) {
phase_ordering_config_ = std::move(phase_ordering_config);
}
std::vector<std::vector<bool>>& mutable_phase_ordering_config() {
return phase_ordering_config_;
}

int phase_index() const { return phase_index_; }
Expand Down Expand Up @@ -398,7 +410,9 @@ class HloModuleConfig {
}

absl::string_view fdo_profile() const { return fdo_profile_; }
std::string* mutable_fdo_profile() { return &fdo_profile_; }
void set_fdo_profile(absl::string_view fdo_profile) {
fdo_profile_ = fdo_profile;
}

int64_t device_memory_size() const { return device_memory_size_; }
void set_device_memory_size(int64_t device_memory_size) {
Expand Down
1 change: 1 addition & 0 deletions xla/service/hlo_module_config_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "xla/tests/test_utils.h"
#include "xla/xla.pb.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
Expand Down
13 changes: 11 additions & 2 deletions xla/service/hlo_module_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,27 @@ limitations under the License.

#include "xla/service/hlo_module_util.h"

#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/debug_options_flags.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/compiler.h"
#include "xla/service/computation_layout.h"
#include "xla/service/hlo_module_config.h"
#include "xla/shape.h"
#include "xla/shape_layout.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"

namespace xla {

Expand Down Expand Up @@ -134,7 +143,7 @@ absl::StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
}
config->set_alias_passthrough_params(
execution_options->alias_passthrough_params());
*config->mutable_fdo_profile() = execution_options->fdo_profile();
config->set_fdo_profile(execution_options->fdo_profile());
config->set_device_memory_size(execution_options->device_memory_size());
config->set_use_shardy_partitioner(
execution_options->use_shardy_partitioner());
Expand All @@ -154,7 +163,7 @@ absl::StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
FusionConfigCollection::kOff) {
config->set_fusion_config_collection(
aot_options->fusion_config_collection());
*config->mutable_fusion_config() = aot_options->fusion_config();
config->set_fusion_config(aot_options->fusion_config());
}
}

Expand Down
3 changes: 1 addition & 2 deletions xla/service/instruction_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,8 +700,7 @@ absl::StatusOr<bool> InstructionFusion::Run(
VLOG(1) << "There are " << fused_count << " fused bits that cause "
<< fuse_count << " fusion actions.";
}
*module->mutable_config().mutable_fusion_config() =
std::move(fusion_config);
module->mutable_config().set_fusion_config(std::move(fusion_config));
}

VLOG(1) << "Fusion count: " << fuse_count;
Expand Down

0 comments on commit 1397371

Please sign in to comment.