diff --git a/python/xoscar/backends/context.py b/python/xoscar/backends/context.py index 5f981b64..23fe3f71 100644 --- a/python/xoscar/backends/context.py +++ b/python/xoscar/backends/context.py @@ -26,7 +26,7 @@ from ..core import ActorRef, BufferRef, FileObjectRef, create_local_actor_ref from ..debug import debug_async_timeout, detect_cycle_send from ..errors import CannotCancelTask -from ..utils import dataslots +from ..utils import dataslots, fix_all_zero_ip from .allocate_strategy import AddressSpecified, AllocateStrategy from .communication import Client, DummyClient, UCXClient from .core import ActorCaller @@ -187,6 +187,7 @@ async def kill_actor(self, actor_ref: ActorRef, force: bool = True): async def actor_ref(self, *args, **kwargs): actor_ref = create_actor_ref(*args, **kwargs) + connect_addr = actor_ref.address local_actor_ref = create_local_actor_ref(actor_ref.address, actor_ref.uid) if local_actor_ref is not None: return local_actor_ref @@ -195,7 +196,10 @@ async def actor_ref(self, *args, **kwargs): ) future = await self._call(actor_ref.address, message, wait=False) result = await self._wait(future, actor_ref.address, message) - return self._process_result_message(result) + res = self._process_result_message(result) + if res.address != connect_addr: + res.address = fix_all_zero_ip(res.address, connect_addr) + return res async def send( self, diff --git a/python/xoscar/tests/test_utils.py b/python/xoscar/tests/test_utils.py index 3735628b..d241d5b5 100644 --- a/python/xoscar/tests/test_utils.py +++ b/python/xoscar/tests/test_utils.py @@ -160,3 +160,25 @@ def test_timer(): time.sleep(0.1) assert timer.duration >= 0.1 + + +def test_fix_all_zero_ip(): + assert utils.is_v4_zero_ip("0.0.0.0:1234") == True + assert utils.is_v4_zero_ip("127.0.0.1:1234") == False + assert utils.is_v6_zero_ip(":::1234") == True + assert utils.is_v6_zero_ip("::FFFF:1234") == False + return utils.is_v6_zero_ip("0000:0000:0000:0000:0000:0000:0000:0000:1234") == True + return utils.is_v6_zero_ip("0:0:0:0:0:0:0:0:1234") == True + return utils.is_v6_zero_ip("0:0:0:0:0:1234") == True + assert utils.is_v6_zero_ip("2001:db8:3333:4444:5555:6666:7777:8888:1234") == False + assert utils.is_v6_zero_ip("127.0.0.1:1234") == False + assert utils.fix_all_zero_ip("127.0.0.1:1234", "127.0.0.1:5678") == "127.0.0.1:1234" + assert utils.fix_all_zero_ip("0.0.0.0:1234", "0.0.0.0:5678") == "0.0.0.0:1234" + assert ( + utils.fix_all_zero_ip("0.0.0.0:1234", "192.168.0.1:5678") == "192.168.0.1:1234" + ) + assert utils.fix_all_zero_ip("127.0.0.1:1234", "0.0.0.0:5678") == "127.0.0.1:1234" + assert ( + utils.fix_all_zero_ip(":::1234", "2001:0db8:0001:0000:0000:0ab9:C0A8:0102:5678") + == "2001:0db8:0001:0000:0000:0ab9:C0A8:0102:1234" + ) diff --git a/python/xoscar/utils.py b/python/xoscar/utils.py index 04e19ec0..a4e98bd1 100644 --- a/python/xoscar/utils.py +++ b/python/xoscar/utils.py @@ -462,3 +462,41 @@ def is_windows(): def is_linux(): return sys.platform.startswith("linux") + + +def is_v4_zero_ip(ip_port_addr: str) -> bool: + return ip_port_addr.startswith("0.0.0.0:") + + +def is_v6_zero_ip(ip_port_addr: str) -> bool: + # tcp6 addr ":::123", ":: means all zero" + arr = ip_port_addr.split(":") + if len(arr) <= 2: # Not tcp6 or udp6 + return False + for part in arr[0:-1]: + if part != "": + if int(part, 16) != 0: + return False + return True + + +def fix_all_zero_ip(remote_addr: str, connect_addr: str) -> str: + """ + Use connect_addr to fix ActorRef.address return by remote server. + When remote server listen on "0.0.0.0:port" or ":::port", it will return ActorRef.address set to listening addr, + it cannot be use by client for the following interaction unless we fix it. + (client will treat 0.0.0.0 as 127.0.0.1) + + NOTE: Server might return a different addr from a pool for load-balance purpose. + """ + if remote_addr == connect_addr: + return remote_addr + if not is_v4_zero_ip(remote_addr) and not is_v6_zero_ip(remote_addr): + # Remote server returns on non-zero ip + return remote_addr + if is_v4_zero_ip(connect_addr) or is_v6_zero_ip(connect_addr): + # Client connect to local server + return remote_addr + remote_port = remote_addr.split(":")[-1] + connect_ip = ":".join(connect_addr.split(":")[0:-1]) # Remote the port + return f"{connect_ip}:{remote_port}"