From 6e8181495a99aa31bf27e1bf7b23981876c8f9db Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Thu, 11 May 2023 14:36:35 -0700 Subject: [PATCH] Construct topologies and hook up aot_test for pjrt_c_api. PiperOrigin-RevId: 531310241 --- jax/_src/xla_bridge.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 8355726a8ef8..50f23e5e6a13 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -694,4 +694,11 @@ def using_pjrt_c_api(backend=None): # TODO(parkers): Get rid of this in favor of a generic way to get topologies. def make_pjrt_tpu_topology(topology_name=None, **kwargs): + if xla_extension_version >= 153: + # TODO(b/261484192): Make a system for lazily loading libtpu.so and call + # that inside make_tfrt_tpu_c_api_device_topology. + get_backend() # Properly initialize libtpu.so. + return xla_client.make_tfrt_tpu_c_api_device_topology( + topology_name, **kwargs + ) raise NotImplementedError('make_pjrt_tpu_topology is not implemented')