|
3 | 3 | import json
|
4 | 4 | import logging
|
5 | 5 | import signal
|
| 6 | +import uuid |
6 | 7 | import warnings
|
7 | 8 | from asyncio import Future, Task
|
8 | 9 | from collections import defaultdict
|
|
47 | 48 | )
|
48 | 49 | from ..base._serialization import MessageSerializer, SerializationRegistry
|
49 | 50 | from ..base._type_helpers import ChannelArgumentType
|
50 |
| -from ..components import TypeSubscription |
| 51 | +from ..components import TypePrefixSubscription, TypeSubscription |
51 | 52 | from ._helpers import SubscriptionManager, get_impl
|
52 | 53 | from ._utils import GRPC_IMPORT_ERROR_STR
|
53 | 54 | from .protos import agent_worker_pb2, agent_worker_pb2_grpc
|
@@ -371,11 +372,17 @@ async def publish_message(
|
371 | 372 | *,
|
372 | 373 | sender: AgentId | None = None,
|
373 | 374 | cancellation_token: CancellationToken | None = None,
|
| 375 | + message_id: str | None = None, |
374 | 376 | ) -> None:
|
375 | 377 | if not self._running:
|
376 | 378 | raise ValueError("Runtime must be running when publishing message.")
|
377 | 379 | if self._host_connection is None:
|
378 | 380 | raise RuntimeError("Host connection is not set.")
|
| 381 | + if message_id is None: |
| 382 | + message_id = str(uuid.uuid4()) |
| 383 | + |
| 384 | + # TODO: consume message_id |
| 385 | + |
379 | 386 | message_type = self._serialization_registry.type_name(message)
|
380 | 387 | with self._trace_helper.trace_block(
|
381 | 388 | "create", topic_id, parent=None, extraAttributes={"message_type": message_type}
|
@@ -447,6 +454,7 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
|
447 | 454 | topic_id=None,
|
448 | 455 | is_rpc=True,
|
449 | 456 | cancellation_token=CancellationToken(),
|
| 457 | + message_id=request.request_id, |
450 | 458 | )
|
451 | 459 |
|
452 | 460 | # Call the receiving agent.
|
@@ -530,11 +538,13 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
|
530 | 538 | for agent_id in recipients:
|
531 | 539 | if agent_id == sender:
|
532 | 540 | continue
|
| 541 | + # TODO: consume message_id |
533 | 542 | message_context = MessageContext(
|
534 | 543 | sender=sender,
|
535 | 544 | topic_id=topic_id,
|
536 | 545 | is_rpc=False,
|
537 | 546 | cancellation_token=CancellationToken(),
|
| 547 | + message_id="NOT_DEFINED_TODO_FIX", |
538 | 548 | )
|
539 | 549 | agent = await self._get_agent(agent_id)
|
540 | 550 | with MessageHandlerContext.populate_context(agent.id):
|
@@ -705,27 +715,44 @@ async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = A
|
705 | 715 | async def add_subscription(self, subscription: Subscription) -> None:
|
706 | 716 | if self._host_connection is None:
|
707 | 717 | raise RuntimeError("Host connection is not set.")
|
708 |
| - if not isinstance(subscription, TypeSubscription): |
709 |
| - raise ValueError("Only TypeSubscription is supported.") |
710 |
| - # Add to local subscription manager. |
711 |
| - await self._subscription_manager.add_subscription(subscription) |
712 | 718 |
|
713 | 719 | # Create a future for the subscription response.
|
714 | 720 | future = asyncio.get_event_loop().create_future()
|
715 | 721 | request_id = await self._get_new_request_id()
|
| 722 | + |
| 723 | + match subscription: |
| 724 | + case TypeSubscription(topic_type=topic_type, agent_type=agent_type): |
| 725 | + message = agent_worker_pb2.Message( |
| 726 | + addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest( |
| 727 | + request_id=request_id, |
| 728 | + subscription=agent_worker_pb2.Subscription( |
| 729 | + typeSubscription=agent_worker_pb2.TypeSubscription( |
| 730 | + topic_type=topic_type, agent_type=agent_type |
| 731 | + ) |
| 732 | + ), |
| 733 | + ) |
| 734 | + ) |
| 735 | + case TypePrefixSubscription(topic_type_prefix=topic_type_prefix, agent_type=agent_type): |
| 736 | + message = agent_worker_pb2.Message( |
| 737 | + addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest( |
| 738 | + request_id=request_id, |
| 739 | + subscription=agent_worker_pb2.Subscription( |
| 740 | + typePrefixSubscription=agent_worker_pb2.TypePrefixSubscription( |
| 741 | + topic_type_prefix=topic_type_prefix, agent_type=agent_type |
| 742 | + ) |
| 743 | + ), |
| 744 | + ) |
| 745 | + ) |
| 746 | + case _: |
| 747 | + raise ValueError("Unsupported subscription type.") |
| 748 | + |
| 749 | + # Add the future to the pending requests. |
716 | 750 | self._pending_requests[request_id] = future
|
717 | 751 |
|
| 752 | + # Add to local subscription manager. |
| 753 | + await self._subscription_manager.add_subscription(subscription) |
| 754 | + |
718 | 755 | # Send the subscription to the host.
|
719 |
| - message = agent_worker_pb2.Message( |
720 |
| - addSubscriptionRequest=agent_worker_pb2.AddSubscriptionRequest( |
721 |
| - request_id=request_id, |
722 |
| - subscription=agent_worker_pb2.Subscription( |
723 |
| - typeSubscription=agent_worker_pb2.TypeSubscription( |
724 |
| - topic_type=subscription.topic_type, agent_type=subscription.agent_type |
725 |
| - ) |
726 |
| - ), |
727 |
| - ) |
728 |
| - ) |
729 | 756 | await self._host_connection.send(message)
|
730 | 757 |
|
731 | 758 | # Wait for the subscription response.
|
|
0 commit comments