Skip to content

Commit

Permalink
FEAT: generator support normal function (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Jan 4, 2024
1 parent a442021 commit 4d623f4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 24 deletions.
66 changes: 44 additions & 22 deletions python/xoscar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from __future__ import annotations

import asyncio
import functools
import inspect
import logging
import threading
import uuid
from collections import defaultdict
Expand Down Expand Up @@ -44,6 +46,8 @@
from .backends.config import ActorPoolConfig
from .backends.pool import MainActorPoolType

logger = logging.getLogger(__name__)


async def create_actor(
actor_cls: Type, *args, uid=None, address=None, **kwargs
Expand Down Expand Up @@ -295,6 +299,7 @@ def __init__(self, uid: str, actor_addr: str, actor_uid: str):
self._actor_addr = actor_addr
self._actor_uid = actor_uid
self._actor_ref = None
self._gc_destroy = True

async def destroy(self):
if self._actor_ref is None:
Expand All @@ -307,15 +312,24 @@ async def destroy(self):
def __del__(self):
# It's not a good idea to spawn a new thread and join in __del__,
# but currently it's the only way to GC the generator.
thread = threading.Thread(
target=asyncio.run, args=(self.destroy(),), daemon=True
)
thread.start()
thread.join()
# TODO(codingl2k1): This __del__ may hangs if the program is exiting.
if self._gc_destroy:
thread = threading.Thread(
target=asyncio.run, args=(self.destroy(),), daemon=True
)
thread.start()
thread.join()

def __aiter__(self):
return self

def __getstate__(self):
# Transfer gc destroy during serialization.
state = dict(**super().__getstate__())
state["_gc_destroy"] = True
self._gc_destroy = False
return state

async def __anext__(self) -> T:
if self._actor_ref is None:
self._actor_ref = await actor_ref(
Expand Down Expand Up @@ -400,13 +414,7 @@ async def _async_wrapper(_gen):
stop = object()
try:
if inspect.isgenerator(gen):
# to_thread is only available for Python >= 3.9
if hasattr(asyncio, "to_thread"):
r = await asyncio.to_thread(_wrapper, gen)
else:
r = await asyncio.get_event_loop().run_in_executor(
None, _wrapper, gen
)
r = await asyncio.to_thread(_wrapper, gen)
elif inspect.isasyncgen(gen):
r = await asyncio.create_task(_async_wrapper(gen))
else:
Expand All @@ -415,15 +423,20 @@ async def _async_wrapper(_gen):
f"but a {type(gen)} is got."
)
except Exception as e:
logger.exception(
f"Destroy generator {generator_uid} due to an error encountered."
)
await self.__xoscar_destroy_generator__(generator_uid)
del gen # Avoid exception hold generator reference.
raise e
if r is stop:
await self.__xoscar_destroy_generator__(generator_uid)
del gen # Avoid exception hold generator reference.
raise Exception("StopIteration")
else:
return r
else:
raise RuntimeError(f"no iterator with id: {generator_uid}")
raise RuntimeError(f"No iterator with id: {generator_uid}")

async def __xoscar_destroy_generator__(self, generator_uid: str):
"""
Expand All @@ -434,19 +447,28 @@ async def __xoscar_destroy_generator__(self, generator_uid: str):
generator_uid: str
The uid of generator
"""
return self._generators.pop(generator_uid, None)
logger.debug("Destroy generator: %s", generator_uid)
self._generators.pop(generator_uid, None)


def generator(func):
async def wrapper(obj, *args, **kwargs):
gen_uid = uuid.uuid1().hex
obj._generators[gen_uid] = func(obj, *args, **kwargs)
return IteratorWrapper(gen_uid, obj.address, obj.uid)
need_to_thread = not asyncio.iscoroutinefunction(func)

if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
return wrapper
else:
return func
@functools.wraps(func)
async def _wrapper(self, *args, **kwargs):
if need_to_thread:
r = await asyncio.to_thread(func, self, *args, **kwargs)
else:
r = await func(self, *args, **kwargs)
if inspect.isgenerator(r) or inspect.isasyncgen(r):
gen_uid = uuid.uuid1().hex
logger.debug("Create generator: %s", gen_uid)
self._generators[gen_uid] = r
return IteratorWrapper(gen_uid, self.address, self.uid)
else:
return r

return _wrapper


class Actor(AsyncActorMixin, _Actor):
Expand Down
2 changes: 1 addition & 1 deletion python/xoscar/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ cdef class ActorRef:
return create_actor_ref, (self.address, self.uid)

def __getattr__(self, item):
if item.startswith('_'):
if item.startswith('_') and item not in ["__xoscar_next__", "__xoscar_destroy_generator__"]:
return object.__getattribute__(self, item)

try:
Expand Down
40 changes: 39 additions & 1 deletion python/xoscar/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,32 @@ async def with_exception(self):
raise Exception("intent raise")
yield 2

@xo.generator
async def mix_gen(self, v):
if v == 1:
return self._gen()
elif v == 2:
return self._gen2()
else:
return 0

@xo.generator
def mix_gen2(self, v):
if v == 1:
return self._gen()
elif v == 2:
return self._gen2()
else:
return 0

def _gen(self):
for x in range(3):
yield x

async def _gen2(self):
for x in range(3):
yield x

@classmethod
def uid(cls):
return "supervisor"
Expand Down Expand Up @@ -134,7 +160,19 @@ async def test_generator():
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0

await asyncio.create_task(superivsor_actor.with_exception())
r = await superivsor_actor.with_exception()
del r
await asyncio.sleep(0)
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0

for f in [superivsor_actor.mix_gen, superivsor_actor.mix_gen2]:
out = []
async for x in await f(1):
out.append(x)
assert out == [0, 1, 2]
out = []
async for x in await f(2):
out.append(x)
assert out == [0, 1, 2]
assert 0 == await f(0)

0 comments on commit 4d623f4

Please sign in to comment.