Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix worker sample in core #4104

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from typing import Any, NoReturn

from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, RoutedAgent, default_subscription, message_handler
from autogen_core.base import MessageContext, try_get_known_serializers_for_type
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler


@dataclass
Expand Down Expand Up @@ -33,7 +33,6 @@ class ReturnedFeedback:
content: str


@default_subscription
class ReceiveAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("Receive Agent")
Expand All @@ -50,7 +49,6 @@ async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoRet
print(f"Unhandled message: {message}")


@default_subscription
class GreeterAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("Greeter Agent")
Expand All @@ -70,9 +68,13 @@ async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoRet
async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime.start()
for t in [AskToGreet, Greeting, ReturnedGreeting, Feedback, ReturnedFeedback]:
runtime.add_message_serializer(try_get_known_serializers_for_type(t))

await ReceiveAgent.register(runtime, "receiver", ReceiveAgent)
await runtime.add_subscription(DefaultSubscription(agent_type="receiver"))
await GreeterAgent.register(runtime, "greeter", GreeterAgent)
await runtime.add_subscription(DefaultSubscription(agent_type="greeter"))

await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())

Expand Down
26 changes: 12 additions & 14 deletions python/packages/autogen-core/samples/worker/run_worker_rpc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import Any, NoReturn

from autogen_core.application import WorkerAgentRuntime
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
MessageContext,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, RoutedAgent, message_handler
Expand Down Expand Up @@ -39,34 +37,34 @@ async def on_greet(self, message: Greeting, ctx: MessageContext) -> Greeting:
async def on_feedback(self, message: Feedback, ctx: MessageContext) -> None:
print(f"Feedback received: {message.content}")

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")


class GreeterAgent(RoutedAgent):
def __init__(self, receive_agent_id: AgentId) -> None:
def __init__(self, receive_agent_type: str) -> None:
super().__init__("Greeter Agent")
self._receive_agent_id = receive_agent_id
self._receive_agent_id = AgentId(receive_agent_type, self.id.key)

@message_handler
async def on_ask(self, message: AskToGreet, ctx: MessageContext) -> None:
response = await self.send_message(Greeting(f"Hello, {message.content}!"), recipient=self._receive_agent_id)
await self.publish_message(Feedback(f"Feedback: {response.content}"), topic_id=DefaultTopicId())

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> NoReturn: # type: ignore
print(f"Unhandled message: {message}")


async def main() -> None:
runtime = WorkerAgentRuntime(host_address="localhost:50051")
runtime.start()

await runtime.register("receiver", lambda: ReceiveAgent(), lambda: [DefaultSubscription()])
await runtime.register(
await ReceiveAgent.register(
runtime,
"receiver",
lambda: ReceiveAgent(),
)
await runtime.add_subscription(DefaultSubscription(agent_type="receiver"))
await GreeterAgent.register(
runtime,
"greeter",
lambda: GreeterAgent(AgentId("receiver", AgentInstantiationContext.current_agent_id().key)),
lambda: [DefaultSubscription()],
lambda: GreeterAgent("receiver"),
)
await runtime.add_subscription(DefaultSubscription(agent_type="greeter"))
await runtime.publish_message(AskToGreet("Hello World!"), topic_id=DefaultTopicId())

await runtime.stop_when_signal()
Expand Down
Loading