Skip to content

Commit

Permalink
Move IFRT lib to XLA:CPU Public API
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696928442
  • Loading branch information
changm authored and Google-ML-Automation committed Nov 15, 2024
1 parent 54fad16 commit a7aff7f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
30 changes: 20 additions & 10 deletions xla/python/pjrt_ifrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ xla_cc_test(
size = "small",
srcs = ["xla_sharding_test.cc"],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
":xla_ifrt",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
Expand Down Expand Up @@ -267,17 +267,27 @@ cc_library(
)

cc_library(
name = "tfrt_cpu_client_test_lib",
name = "pjrt_cpu_client_test_lib",
testonly = True,
srcs = ["tfrt_cpu_client_test_lib.cc"],
srcs = ["pjrt_cpu_client_test_lib.cc"],
deps = [
":pjrt_ifrt",
"//xla/pjrt/cpu:cpu_client",
"//xla/pjrt/plugin/xla_cpu:cpu_client_options",
"//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client",
"//xla/python/ifrt",
"//xla/python/ifrt:test_util",
"@com_google_absl//absl/status:statusor",
"@tsl//tsl/platform:statusor",
],
alwayslink = True,
)

# TODO(hyeontaek): Remove this target after migration is done.
alias(
name = "tfrt_cpu_client_test_lib",
actual = ":pjrt_cpu_client_test_lib",
)

cc_library(
name = "pjrt_attribute_map_util",
srcs = ["pjrt_attribute_map_util.cc"],
Expand Down Expand Up @@ -347,7 +357,7 @@ xla_cc_test(
srcs = ["basic_string_array_test.cc"],
deps = [
":basic_string_array",
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
"//xla:shape_util",
"//xla/pjrt:pjrt_future",
"//xla/pjrt:pjrt_layout",
Expand Down Expand Up @@ -375,7 +385,7 @@ xla_cc_test(
size = "small",
srcs = ["pjrt_array_impl_test_tfrt_cpu.cc"],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
"//xla/python/ifrt:array_impl_test_lib",
"//xla/python/ifrt:test_util",
"@com_google_absl//absl/strings",
Expand All @@ -388,7 +398,7 @@ xla_cc_test(
size = "small",
srcs = [],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
"//xla/python/ifrt:client_impl_test_lib",
"@com_google_googletest//:gtest_main",
],
Expand All @@ -399,7 +409,7 @@ xla_cc_test(
size = "small",
srcs = ["pjrt_executable_impl_test_tfrt_cpu.cc"],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
":xla_executable_impl_test_lib",
"//xla/python/ifrt:test_util",
"@com_google_absl//absl/strings",
Expand All @@ -412,7 +422,7 @@ xla_cc_test(
size = "small",
srcs = [],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
"//xla/python/ifrt:tuple_impl_test_lib",
"@com_google_googletest//:gtest_main",
],
Expand All @@ -423,7 +433,7 @@ xla_cc_test(
size = "small",
srcs = [],
deps = [
":tfrt_cpu_client_test_lib",
":pjrt_cpu_client_test_lib",
"//xla/python/ifrt:remap_impl_test_lib",
"@com_google_googletest//:gtest_main",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ limitations under the License.
#include <memory>
#include <utility>

#include "xla/pjrt/cpu/cpu_client.h"
#include "absl/status/statusor.h"
#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h"
#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h"
#include "xla/python/ifrt/client.h"
#include "xla/python/ifrt/test_util.h"
#include "xla/python/pjrt_ifrt/pjrt_client.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace ifrt {
Expand All @@ -27,10 +31,10 @@ namespace {
const bool kUnused =
(test_util::RegisterClientFactory(
[]() -> absl::StatusOr<std::shared_ptr<Client>> {
CpuClientOptions options;
xla::CpuClientOptions options;
options.cpu_device_count = 4;
TF_ASSIGN_OR_RETURN(auto pjrt_client,
xla::GetTfrtCpuClient(std::move(options)));
xla::GetXlaPjrtCpuClient(std::move(options)));
return std::shared_ptr<Client>(
PjRtClient::Create(std::move(pjrt_client)));
}),
Expand Down

0 comments on commit a7aff7f

Please sign in to comment.