Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 66 additions & 4 deletions agentex/src/domain/services/task_message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@

from fastapi import Depends

from src.adapters.streams.adapter_redis import DRedisStreamRepository
from src.domain.entities.task_message_updates import (
StreamTaskMessageFullEntity,
)
from src.domain.entities.task_messages import (
TaskMessageContentEntity,
TaskMessageEntity,
)
from src.domain.repositories.task_message_repository import DTaskMessageRepository
from src.utils.logging import make_logger
from src.utils.stream_topics import get_task_event_stream_topic

logger = make_logger(__name__)

Expand All @@ -18,14 +23,52 @@ class TaskMessageService:
Service for handling task message operations.
"""

def __init__(self, message_repository: DTaskMessageRepository):
def __init__(
self,
message_repository: DTaskMessageRepository,
stream_repository: DRedisStreamRepository,
):
"""
Initialize the service with required dependencies.

Args:
message_repository: Repository for storing and retrieving messages
stream_repository: Repository for publishing stream events
"""
self.repository = message_repository
self.stream_repository = stream_repository

async def _publish_message_full_event(
self,
task_message: TaskMessageEntity,
) -> None:
"""
Publish a StreamTaskMessageFull event to the Redis stream.

This notifies subscribed clients (UI) about the new/updated message.

Args:
task_message: The message entity to publish
"""
try:
topic = get_task_event_stream_topic(task_id=task_message.task_id)
event = StreamTaskMessageFullEntity(
type="full",
index=0,
parent_task_message=task_message,
content=task_message.content,
)
await self.stream_repository.send_data(
topic,
event.model_dump(mode="json"),
)
logger.info(
f"Published message_full event for message {task_message.id} to topic: {topic}"
)
except Exception as e:
logger.error(
f"Error publishing message_full event to stream: {e}", exc_info=True
)

async def get_message(self, message_id: str) -> TaskMessageEntity:
"""
Expand Down Expand Up @@ -119,7 +162,12 @@ async def append_message(
streaming_status=streaming_status,
)

return await self.repository.create(task_message)
created_message = await self.repository.create(task_message)

# Publish stream event to notify UI clients
await self._publish_message_full_event(created_message)

return created_message

async def append_messages(
self,
Expand Down Expand Up @@ -155,6 +203,11 @@ async def append_messages(

created_messages = await self.repository.batch_create(task_messages)
logger.info(f"Created batch of messages: {created_messages}")

# Publish stream events for each created message to notify UI clients
for created_message in created_messages:
await self._publish_message_full_event(created_message)

return created_messages

async def update_message(
Expand Down Expand Up @@ -183,7 +236,12 @@ async def update_message(
task_message.content = content
if streaming_status is not None:
task_message.streaming_status = streaming_status
return await self.repository.update(task_message)
updated_message = await self.repository.update(task_message)

# Publish stream event to notify UI clients
await self._publish_message_full_event(updated_message)

return updated_message
return None

async def update_messages(
Expand All @@ -206,7 +264,11 @@ async def update_messages(
if task_message and task_message.task_id == task_id:
# Update the message field but preserve other fields
task_message.content = message
updated_messages.append(await self.repository.update(task_message))
updated_message = await self.repository.update(task_message)
updated_messages.append(updated_message)

# Publish stream event to notify UI clients
await self._publish_message_full_event(updated_message)

return updated_messages

Expand Down
11 changes: 7 additions & 4 deletions agentex/tests/fixtures/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# =============================================================================


def create_task_message_service(task_message_repository):
def create_task_message_service(task_message_repository, stream_repository):
"""Factory function to create TaskMessageService with given repository"""
from src.domain.services.task_message_service import TaskMessageService

return TaskMessageService(task_message_repository=task_message_repository)
return TaskMessageService(
message_repository=task_message_repository,
stream_repository=stream_repository,
)


def create_agent_acp_service(http_gateway, agent_repository, agent_api_key_repository):
Expand Down Expand Up @@ -110,9 +113,9 @@ def mock_environment_variables():


@pytest.fixture
def task_message_service(task_message_repository):
def task_message_service(task_message_repository, redis_stream_repository):
"""Task message service for unit tests"""
return create_task_message_service(task_message_repository)
return create_task_message_service(task_message_repository, redis_stream_repository)


@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion agentex/tests/integration/fixtures/integration_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ def create_messages_use_case():
from src.domain.services.task_message_service import TaskMessageService

task_message_service = TaskMessageService(
message_repository=isolated_repositories["task_message_repository"]
message_repository=isolated_repositories["task_message_repository"],
stream_repository=isolated_repositories["redis_stream_repository"],
)

return MessagesUseCase(task_message_service=task_message_service)
Expand Down
7 changes: 5 additions & 2 deletions agentex/tests/unit/services/test_task_message_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@ def task_message_repository(mongodb_database):


@pytest.fixture
def task_message_service(task_message_repository):
def task_message_service(task_message_repository, redis_stream_repository):
"""Create TaskMessageService instance with real repository"""
return TaskMessageService(message_repository=task_message_repository)
return TaskMessageService(
message_repository=task_message_repository,
stream_repository=redis_stream_repository,
)


@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion agentex/tests/unit/use_cases/test_agents_acp_use_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ def task_service(


@pytest.fixture
def task_message_service(task_message_repository):
def task_message_service(task_message_repository, redis_stream_repository):
"""Real TaskMessageService instance"""
return TaskMessageService(
message_repository=task_message_repository,
stream_repository=redis_stream_repository,
)


Expand Down
Loading