diff --git a/agentex/src/domain/services/task_message_service.py b/agentex/src/domain/services/task_message_service.py index 4cd23b34..ad75f98c 100644 --- a/agentex/src/domain/services/task_message_service.py +++ b/agentex/src/domain/services/task_message_service.py @@ -3,12 +3,17 @@ from fastapi import Depends +from src.adapters.streams.adapter_redis import DRedisStreamRepository +from src.domain.entities.task_message_updates import ( + StreamTaskMessageFullEntity, +) from src.domain.entities.task_messages import ( TaskMessageContentEntity, TaskMessageEntity, ) from src.domain.repositories.task_message_repository import DTaskMessageRepository from src.utils.logging import make_logger +from src.utils.stream_topics import get_task_event_stream_topic logger = make_logger(__name__) @@ -18,14 +23,52 @@ class TaskMessageService: Service for handling task message operations. """ - def __init__(self, message_repository: DTaskMessageRepository): + def __init__( + self, + message_repository: DTaskMessageRepository, + stream_repository: DRedisStreamRepository, + ): """ Initialize the service with required dependencies. Args: message_repository: Repository for storing and retrieving messages + stream_repository: Repository for publishing stream events """ self.repository = message_repository + self.stream_repository = stream_repository + + async def _publish_message_full_event( + self, + task_message: TaskMessageEntity, + ) -> None: + """ + Publish a StreamTaskMessageFull event to the Redis stream. + + This notifies subscribed clients (UI) about the new/updated message. + + Args: + task_message: The message entity to publish + """ + try: + topic = get_task_event_stream_topic(task_id=task_message.task_id) + event = StreamTaskMessageFullEntity( + type="full", + index=0, + parent_task_message=task_message, + content=task_message.content, + ) + await self.stream_repository.send_data( + topic, + event.model_dump(mode="json"), + ) + logger.info( + f"Published message_full event for message {task_message.id} to topic: {topic}" + ) + except Exception as e: + logger.error( + f"Error publishing message_full event to stream: {e}", exc_info=True + ) async def get_message(self, message_id: str) -> TaskMessageEntity: """ @@ -119,7 +162,12 @@ async def append_message( streaming_status=streaming_status, ) - return await self.repository.create(task_message) + created_message = await self.repository.create(task_message) + + # Publish stream event to notify UI clients + await self._publish_message_full_event(created_message) + + return created_message async def append_messages( self, @@ -155,6 +203,11 @@ async def append_messages( created_messages = await self.repository.batch_create(task_messages) logger.info(f"Created batch of messages: {created_messages}") + + # Publish stream events for each created message to notify UI clients + for created_message in created_messages: + await self._publish_message_full_event(created_message) + return created_messages async def update_message( @@ -183,7 +236,12 @@ async def update_message( task_message.content = content if streaming_status is not None: task_message.streaming_status = streaming_status - return await self.repository.update(task_message) + updated_message = await self.repository.update(task_message) + + # Publish stream event to notify UI clients + await self._publish_message_full_event(updated_message) + + return updated_message return None async def update_messages( @@ -206,7 +264,11 @@ async def update_messages( if task_message and task_message.task_id == task_id: # Update the message field but preserve other fields task_message.content = message - updated_messages.append(await self.repository.update(task_message)) + updated_message = await self.repository.update(task_message) + updated_messages.append(updated_message) + + # Publish stream event to notify UI clients + await self._publish_message_full_event(updated_message) return updated_messages diff --git a/agentex/tests/fixtures/services.py b/agentex/tests/fixtures/services.py index 30b16ca2..94ca86b6 100644 --- a/agentex/tests/fixtures/services.py +++ b/agentex/tests/fixtures/services.py @@ -12,11 +12,14 @@ # ============================================================================= -def create_task_message_service(task_message_repository): +def create_task_message_service(task_message_repository, stream_repository): """Factory function to create TaskMessageService with given repository""" from src.domain.services.task_message_service import TaskMessageService - return TaskMessageService(task_message_repository=task_message_repository) + return TaskMessageService( + message_repository=task_message_repository, + stream_repository=stream_repository, + ) def create_agent_acp_service(http_gateway, agent_repository, agent_api_key_repository): @@ -110,9 +113,9 @@ def mock_environment_variables(): @pytest.fixture -def task_message_service(task_message_repository): +def task_message_service(task_message_repository, redis_stream_repository): """Task message service for unit tests""" - return create_task_message_service(task_message_repository) + return create_task_message_service(task_message_repository, redis_stream_repository) @pytest.fixture diff --git a/agentex/tests/integration/fixtures/integration_client.py b/agentex/tests/integration/fixtures/integration_client.py index ade4bdc4..2ad3c249 100644 --- a/agentex/tests/integration/fixtures/integration_client.py +++ b/agentex/tests/integration/fixtures/integration_client.py @@ -455,7 +455,8 @@ def create_messages_use_case(): from src.domain.services.task_message_service import TaskMessageService task_message_service = TaskMessageService( - message_repository=isolated_repositories["task_message_repository"] + message_repository=isolated_repositories["task_message_repository"], + stream_repository=isolated_repositories["redis_stream_repository"], ) return MessagesUseCase(task_message_service=task_message_service) diff --git a/agentex/tests/unit/services/test_task_message_service.py b/agentex/tests/unit/services/test_task_message_service.py index 50bf29eb..90feb1e6 100644 --- a/agentex/tests/unit/services/test_task_message_service.py +++ b/agentex/tests/unit/services/test_task_message_service.py @@ -20,9 +20,12 @@ def task_message_repository(mongodb_database): @pytest.fixture -def task_message_service(task_message_repository): +def task_message_service(task_message_repository, redis_stream_repository): """Create TaskMessageService instance with real repository""" - return TaskMessageService(message_repository=task_message_repository) + return TaskMessageService( + message_repository=task_message_repository, + stream_repository=redis_stream_repository, + ) @pytest.fixture diff --git a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py index 3f9278d4..c025209b 100644 --- a/agentex/tests/unit/use_cases/test_agents_acp_use_case.py +++ b/agentex/tests/unit/use_cases/test_agents_acp_use_case.py @@ -131,10 +131,11 @@ def task_service( @pytest.fixture -def task_message_service(task_message_repository): +def task_message_service(task_message_repository, redis_stream_repository): """Real TaskMessageService instance""" return TaskMessageService( message_repository=task_message_repository, + stream_repository=redis_stream_repository, )