From a7aff7f0ce02919d07c4fce4861f50b57171becd Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Fri, 15 Nov 2024 10:23:03 -0800 Subject: [PATCH] Move IFRT lib to XLA:CPU Public API PiperOrigin-RevId: 696928442 --- xla/python/pjrt_ifrt/BUILD | 30 ++++++++++++------- ...est_lib.cc => pjrt_cpu_client_test_lib.cc} | 10 +++++-- 2 files changed, 27 insertions(+), 13 deletions(-) rename xla/python/pjrt_ifrt/{tfrt_cpu_client_test_lib.cc => pjrt_cpu_client_test_lib.cc} (78%) diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 0a57a950d2df6..c929fe34209d1 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -168,7 +168,7 @@ xla_cc_test( size = "small", srcs = ["xla_sharding_test.cc"], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", ":xla_ifrt", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -267,17 +267,27 @@ cc_library( ) cc_library( - name = "tfrt_cpu_client_test_lib", + name = "pjrt_cpu_client_test_lib", testonly = True, - srcs = ["tfrt_cpu_client_test_lib.cc"], + srcs = ["pjrt_cpu_client_test_lib.cc"], deps = [ ":pjrt_ifrt", - "//xla/pjrt/cpu:cpu_client", + "//xla/pjrt/plugin/xla_cpu:cpu_client_options", + "//xla/pjrt/plugin/xla_cpu:xla_cpu_pjrt_client", + "//xla/python/ifrt", "//xla/python/ifrt:test_util", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:statusor", ], alwayslink = True, ) +# TODO(hyeontaek): Remove this target after migration is done. +alias( + name = "tfrt_cpu_client_test_lib", + actual = ":pjrt_cpu_client_test_lib", +) + cc_library( name = "pjrt_attribute_map_util", srcs = ["pjrt_attribute_map_util.cc"], @@ -347,7 +357,7 @@ xla_cc_test( srcs = ["basic_string_array_test.cc"], deps = [ ":basic_string_array", - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", "//xla:shape_util", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_layout", @@ -375,7 +385,7 @@ xla_cc_test( size = "small", srcs = ["pjrt_array_impl_test_tfrt_cpu.cc"], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", "//xla/python/ifrt:array_impl_test_lib", "//xla/python/ifrt:test_util", "@com_google_absl//absl/strings", @@ -388,7 +398,7 @@ xla_cc_test( size = "small", srcs = [], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", "//xla/python/ifrt:client_impl_test_lib", "@com_google_googletest//:gtest_main", ], @@ -399,7 +409,7 @@ xla_cc_test( size = "small", srcs = ["pjrt_executable_impl_test_tfrt_cpu.cc"], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", ":xla_executable_impl_test_lib", "//xla/python/ifrt:test_util", "@com_google_absl//absl/strings", @@ -412,7 +422,7 @@ xla_cc_test( size = "small", srcs = [], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", "//xla/python/ifrt:tuple_impl_test_lib", "@com_google_googletest//:gtest_main", ], @@ -423,7 +433,7 @@ xla_cc_test( size = "small", srcs = [], deps = [ - ":tfrt_cpu_client_test_lib", + ":pjrt_cpu_client_test_lib", "//xla/python/ifrt:remap_impl_test_lib", "@com_google_googletest//:gtest_main", ], diff --git a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/xla/python/pjrt_ifrt/pjrt_cpu_client_test_lib.cc similarity index 78% rename from xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc rename to xla/python/pjrt_ifrt/pjrt_cpu_client_test_lib.cc index 4fb1ca36e33a5..f782b7fc64428 100644 --- a/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/xla/python/pjrt_ifrt/pjrt_cpu_client_test_lib.cc @@ -16,9 +16,13 @@ limitations under the License. #include #include -#include "xla/pjrt/cpu/cpu_client.h" +#include "absl/status/statusor.h" +#include "xla/pjrt/plugin/xla_cpu/cpu_client_options.h" +#include "xla/pjrt/plugin/xla_cpu/xla_cpu_pjrt_client.h" +#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" namespace xla { namespace ifrt { @@ -27,10 +31,10 @@ namespace { const bool kUnused = (test_util::RegisterClientFactory( []() -> absl::StatusOr> { - CpuClientOptions options; + xla::CpuClientOptions options; options.cpu_device_count = 4; TF_ASSIGN_OR_RETURN(auto pjrt_client, - xla::GetTfrtCpuClient(std::move(options))); + xla::GetXlaPjrtCpuClient(std::move(options))); return std::shared_ptr( PjRtClient::Create(std::move(pjrt_client))); }),