diff --git a/xla/python/pjrt_ifrt/BUILD b/xla/python/pjrt_ifrt/BUILD index 369f2c00ab0a82..de4447f90260ba 100644 --- a/xla/python/pjrt_ifrt/BUILD +++ b/xla/python/pjrt_ifrt/BUILD @@ -277,6 +277,22 @@ cc_library( alwayslink = True, ) +cc_library( + name = "se_gpu_client_test_lib", + testonly = True, + srcs = ["se_gpu_client_test_lib.cc"], + deps = [ + ":pjrt_ifrt", + "//xla/pjrt/gpu:se_gpu_pjrt_client", + "//xla/python/ifrt", + "//xla/python/ifrt:test_util", + "//xla/service:gpu_plugin", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + cc_library( name = "pjrt_attribute_map_util", srcs = ["pjrt_attribute_map_util.cc"], @@ -393,6 +409,20 @@ xla_cc_test( ], ) +xla_cc_test( + name = "pjrt_client_impl_test_se_gpu", + size = "small", + srcs = [], + tags = [ + "requires-gpu-nvidia:2", + ], + deps = [ + ":se_gpu_client_test_lib", + "//xla/python/ifrt:client_impl_test_lib", + "@com_google_googletest//:gtest_main", + ], +) + xla_cc_test( name = "pjrt_executable_impl_test_tfrt_cpu", size = "small", diff --git a/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc b/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc new file mode 100644 index 00000000000000..d7c4a96c88f3e6 --- /dev/null +++ b/xla/python/pjrt_ifrt/se_gpu_client_test_lib.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/test_util.h" +#include "xla/python/pjrt_ifrt/pjrt_client.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +const bool kUnused = + (test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + TF_ASSIGN_OR_RETURN( + auto pjrt_client, + xla::GetStreamExecutorGpuClient(xla::GpuClientOptions())); + return PjRtClient::Create(std::move(pjrt_client)); + }), + true); + +} // namespace +} // namespace ifrt +} // namespace xla