Skip to content

Commit

Permalink
FEAT: make it easy to define generator methods in Actor (#82)
Browse files Browse the repository at this point in the history
Co-authored-by: codingl2k1 <codingl2k1@outlook.com>
  • Loading branch information
liunux4odoo and codingl2k1 authored Dec 26, 2023
1 parent c7f3de4 commit a442021
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 1 deletion.
1 change: 1 addition & 0 deletions python/xoscar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
setup_cluster,
wait_actor_pool_recovered,
get_pool_config,
generator,
)
from .backends import allocate_strategy
from .backends.pool import MainActorPoolType
Expand Down
145 changes: 144 additions & 1 deletion python/xoscar/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,24 @@

from __future__ import annotations

import asyncio
import inspect
import threading
import uuid
from collections import defaultdict
from numbers import Number
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from urllib.parse import urlparse

from .aio import AioFileObject
Expand Down Expand Up @@ -271,6 +286,51 @@ def setup_cluster(address_to_resources: Dict[str, Dict[str, Number]]):
get_backend(scheme).get_driver_cls().setup_cluster(address_resources)


T = TypeVar("T")


class IteratorWrapper(Generic[T]):
def __init__(self, uid: str, actor_addr: str, actor_uid: str):
self._uid = uid
self._actor_addr = actor_addr
self._actor_uid = actor_uid
self._actor_ref = None

async def destroy(self):
if self._actor_ref is None:
self._actor_ref = await actor_ref(
address=self._actor_addr, uid=self._actor_uid
)
assert self._actor_ref is not None
return await self._actor_ref.__xoscar_destroy_generator__(self._uid)

def __del__(self):
# It's not a good idea to spawn a new thread and join in __del__,
# but currently it's the only way to GC the generator.
thread = threading.Thread(
target=asyncio.run, args=(self.destroy(),), daemon=True
)
thread.start()
thread.join()

def __aiter__(self):
return self

async def __anext__(self) -> T:
if self._actor_ref is None:
self._actor_ref = await actor_ref(
address=self._actor_addr, uid=self._actor_uid
)
try:
assert self._actor_ref is not None
return await self._actor_ref.__xoscar_next__(self._uid)
except Exception as e:
if "StopIteration" in str(e):
raise StopAsyncIteration
else:
raise


class AsyncActorMixin:
@classmethod
def default_uid(cls):
Expand All @@ -282,6 +342,10 @@ def __new__(cls, *args, **kwargs):
except KeyError:
return super().__new__(cls, *args, **kwargs)

def __init__(self, *args, **kwargs) -> None:
super().__init__()
self._generators: Dict[str, IteratorWrapper] = {}

async def __post_create__(self):
"""
Method called after actor creation
Expand All @@ -305,6 +369,85 @@ async def __on_receive__(self, message: Tuple[Any]):
"""
return await super().__on_receive__(message) # type: ignore

async def __xoscar_next__(self, generator_uid: str) -> Any:
"""
Iter the next of generator.
Parameters
----------
generator_uid: str
The uid of generator
Returns
-------
The next value of generator
"""

def _wrapper(_gen):
try:
return next(_gen)
except StopIteration:
return stop

async def _async_wrapper(_gen):
try:
# anext is only available for Python >= 3.10
return await _gen.__anext__() # noqa: F821
except StopAsyncIteration:
return stop

if gen := self._generators.get(generator_uid):
stop = object()
try:
if inspect.isgenerator(gen):
# to_thread is only available for Python >= 3.9
if hasattr(asyncio, "to_thread"):
r = await asyncio.to_thread(_wrapper, gen)
else:
r = await asyncio.get_event_loop().run_in_executor(
None, _wrapper, gen
)
elif inspect.isasyncgen(gen):
r = await asyncio.create_task(_async_wrapper(gen))
else:
raise Exception(
f"The generator {generator_uid} should be a generator or an async generator, "
f"but a {type(gen)} is got."
)
except Exception as e:
await self.__xoscar_destroy_generator__(generator_uid)
raise e
if r is stop:
await self.__xoscar_destroy_generator__(generator_uid)
raise Exception("StopIteration")
else:
return r
else:
raise RuntimeError(f"no iterator with id: {generator_uid}")

async def __xoscar_destroy_generator__(self, generator_uid: str):
"""
Destroy the generator.
Parameters
----------
generator_uid: str
The uid of generator
"""
return self._generators.pop(generator_uid, None)


def generator(func):
async def wrapper(obj, *args, **kwargs):
gen_uid = uuid.uuid1().hex
obj._generators[gen_uid] = func(obj, *args, **kwargs)
return IteratorWrapper(gen_uid, obj.address, obj.uid)

if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
return wrapper
else:
return func


class Actor(AsyncActorMixin, _Actor):
pass
Expand Down
140 changes: 140 additions & 0 deletions python/xoscar/tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time

import pytest

import xoscar as xo

address = "127.0.0.1:12347"


class WorkerActor(xo.StatelessActor):
@xo.generator
def chat(self):
for x in "hello oscar by sync":
yield x
time.sleep(0.1)

@xo.generator
async def achat(self):
for x in "hello oscar by async":
yield x
await asyncio.sleep(0.1)

@classmethod
def uid(cls):
return "worker"


class SupervisorActor(xo.StatelessActor):
def get_all_generators(self):
return list(self._generators.keys())

@xo.generator
async def chat(self):
worker_actor: xo.ActorRef["WorkerActor"] = await xo.actor_ref(
address=address, uid=WorkerActor.uid()
)
yield "sync"
async for x in await worker_actor.chat(): # this is much confused. I will suggest use async generators only.
yield x

yield "async"
async for x in await worker_actor.achat():
yield x

@xo.generator
async def with_exception(self):
yield 1
raise Exception("intent raise")
yield 2

@classmethod
def uid(cls):
return "supervisor"


async def test_generator():
await xo.create_actor_pool(address, 2)
await xo.create_actor(WorkerActor, address=address, uid=WorkerActor.uid())
superivsor_actor = await xo.create_actor(
SupervisorActor, address=address, uid=SupervisorActor.uid()
)

all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0
output = []
async for x in await superivsor_actor.chat():
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 1
output.append(x)
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0
assert output == [
"sync",
"h",
"e",
"l",
"l",
"o",
" ",
"o",
"s",
"c",
"a",
"r",
" ",
"b",
"y",
" ",
"s",
"y",
"n",
"c",
"async",
"h",
"e",
"l",
"l",
"o",
" ",
"o",
"s",
"c",
"a",
"r",
" ",
"b",
"y",
" ",
"a",
"s",
"y",
"n",
"c",
]

with pytest.raises(Exception, match="intent"):
async for _ in await superivsor_actor.with_exception():
pass
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0

await asyncio.create_task(superivsor_actor.with_exception())
await asyncio.sleep(0)
all_gen = await superivsor_actor.get_all_generators()
assert len(all_gen) == 0

0 comments on commit a442021

Please sign in to comment.