Skip to content

Commit

Permalink
Add GPU topology proto python target.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697028309
  • Loading branch information
changhuilin authored and Google-ML-Automation committed Nov 16, 2024
1 parent 31947e4 commit 392ea25
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 0 deletions.
7 changes: 7 additions & 0 deletions xla/pjrt/gpu/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@rules_python//python:proto.bzl", "py_proto_library")
load("@tsl//tsl/platform:build_config.bzl", "tf_proto_library")
load("@tsl//tsl/platform:rules_cc.bzl", "cc_library")
load("//xla:xla.bzl", "xla_cc_test")
Expand Down Expand Up @@ -250,6 +251,12 @@ tf_proto_library(
visibility = ["//visibility:public"],
)

py_proto_library(
name = "gpu_topology_proto_py_pb2",
visibility = ["//visibility:public"],
deps = [":gpu_topology_proto"],
)

cc_library(
name = "gpu_topology",
srcs = ["gpu_topology.cc"],
Expand Down
2 changes: 2 additions & 0 deletions xla/python/ifrt/topology.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class Topology : public llvm::RTTIExtends<Topology, llvm::RTTIRoot> {

virtual PjRtPlatformId platform_id() const = 0;

virtual bool is_subslice_topology() const = 0;

// Returns an unordered list of descriptions for all devices in this topology.
// TODO(phawkins): consider introducing an IFRT-specific API here instead of
// delegating to PJRT.
Expand Down
4 changes: 4 additions & 0 deletions xla/python/pjrt_ifrt/pjrt_topology.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ PjRtPlatformId PjRtTopology::platform_id() const {
return description_->platform_id();
}

bool PjRtTopology::is_subslice_topology() const {
return description_->is_subslice_topology();
}

std::vector<std::unique_ptr<const PjRtDeviceDescription>>
PjRtTopology::DeviceDescriptions() const {
return description_->DeviceDescriptions();
Expand Down
1 change: 1 addition & 0 deletions xla/python/pjrt_ifrt/pjrt_topology.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class PjRtTopology final : public llvm::RTTIExtends<PjRtTopology, Topology> {
absl::string_view platform_name() const override;
absl::string_view platform_version() const override;
PjRtPlatformId platform_id() const override;
bool is_subslice_topology() const override;

std::vector<std::unique_ptr<const PjRtDeviceDescription>> DeviceDescriptions()
const override;
Expand Down
7 changes: 7 additions & 0 deletions xla/python/xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,13 @@ NB_MODULE(xla_extension, m_nb) {
.def_prop_ro(
"platform_version",
[](ifrt::Topology& topology) { return topology.platform_version(); })
.def_prop_ro(
"platform_id",
[](ifrt::Topology& topology) { return topology.platform_id(); })
.def_prop_ro("is_subslice_topology",
[](ifrt::Topology& topology) {
return topology.is_subslice_topology();
})
.def("serialize",
[](ifrt::Topology& topology) -> nb::bytes {
std::string serialized = ValueOrThrow(topology.Serialize());
Expand Down
2 changes: 2 additions & 0 deletions xla/python/xla_extension/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,8 @@ class Executable:
class DeviceTopology:
platform: str
platform_version: str
platform_id: int
is_subslice_topology: bool
def _make_compile_only_devices(self) -> List[Device]: ...
def serialize(self) -> bytes: ...
def __getattr__(self, name: str) -> Any: ...
Expand Down

0 comments on commit 392ea25

Please sign in to comment.