# Patch overview (commentary only; nothing below the first "diff --git" header is altered).
#
# Moves the Gemini embedding helper out of src/pipelines/ingest.py into a new module,
# src/utils/embeddings.py, and repoints the places that imported it.
#
#   src/pipelines/ingest.py
#       Deletes the in-file `_embedding_client`, `get_embedding_client()` and `embed_text()`
#       together with the `google.genai` imports, imports `embed_text` from
#       src.utils.embeddings instead, and drops `OperationType` from the
#       src.schemas.judge import.
#   src/pipelines/retrieval.py
#       Removes `import os`; `_get_embed_fn()` now imports `embed_text` from
#       src.utils.embeddings.
#   src/storage/pinecone.py
#       `search_by_text()` now imports `embed_text` from src.utils.embeddings.
#   src/utils/embeddings.py (new file, 32 lines)
#       Lazily creates one module-level `genai.Client` (passing `api_key` only when
#       `settings.gemini_api_key` is truthy) and exposes `embed_text(text)`, which calls
#       `client.models.embed_content` with `settings.embedding_model` and
#       `output_dimensionality=settings.pinecone_dimension`, returning the `.values`
#       of the single embedding in the response.
#   tests/unit/utils/test_embeddings.py (new file, 85 lines)
#       unittest cases that replace `google`, `google.genai`, `google.genai.types` and
#       `src.config` in `sys.modules` with mocks, then import or reload the module under
#       test and check client creation/caching and the `embed_content` call arguments.
#
# NOTE(review): ingest.py re-imports only `embed_text`; `get_embedding_client` and
#   `_embedding_client` are no longer reachable through src.pipelines.ingest. Callers
#   outside this patch are not visible here - confirm none import them from ingest.
# NOTE(review): `OperationType` is no longer imported in ingest.py; the rest of that file
#   is not part of this patch - confirm it has no remaining uses.
# NOTE(review): the "Loaded embedding client" record is now emitted by the
#   "xmem.utils.embeddings" logger instead of "xmem.pipelines.ingest" - confirm no logging
#   configuration filters on the old name.
# NOTE(review): `embed_text()` unpacks `[embedding_obj] = result.embeddings`, so a response
#   that does not hold exactly one embedding raises from the unpacking itself.
# NOTE(review): in the test's setUp, when src.utils.embeddings is already in sys.modules it
#   is re-executed via `importlib.reload` while the mocks are installed. reload reuses the
#   existing module object, so after tearDown that module presumably still references the
#   mocked genai/types/settings - confirm this cannot leak into other tests.
# NOTE(review): the pinecone.py hunk shows `async def search_by_text(` calling
#   `embed_text(query_text)` directly; `embed_text` is a plain synchronous function, so the
#   coroutine presumably blocks while the embedding request runs - confirm this is intended.
# NOTE(review): the ingest.py hunk appears to remove the `# ---` rule line directly above
#   "# LangGraph state ..." as well - verify the remaining section header looks as intended.
#
diff --git a/src/pipelines/ingest.py b/src/pipelines/ingest.py index c85aa9c..8c407e2 100644 --- a/src/pipelines/ingest.py +++ b/src/pipelines/ingest.py @@ -71,48 +71,17 @@ from src.schemas.classification import ClassificationResult from src.schemas.events import EventResult from src.schemas.image import ImageResult -from src.schemas.judge import JudgeDomain, JudgeResult, OperationType +from src.schemas.judge import JudgeDomain, JudgeResult from src.schemas.profile import ProfileResult from src.schemas.summary import SummaryResult from src.schemas.weaver import WeaverResult from src.storage.base import BaseVectorStore, SearchResult from src.storage.pinecone import PineconeVectorStore +from src.utils.embeddings import embed_text logger = logging.getLogger("xmem.pipelines.ingest") -# --------------------------------------------------------------------------- -# Embedding helper — wraps Google GenAI into a simple callable -# --------------------------------------------------------------------------- - -from google import genai -from google.genai import types - -_embedding_client: Optional[genai.Client] = None - - -def get_embedding_client() -> genai.Client: - global _embedding_client - if _embedding_client is None: - api_key_to_use = settings.gemini_api_key or None - _embedding_client = genai.Client(api_key=api_key_to_use) if api_key_to_use else genai.Client() - logger.info("Loaded embedding client for model: %s", settings.embedding_model) - return _embedding_client - - -def embed_text(text: str) -> List[float]: - """Embed a single text string → list of floats.""" - client = get_embedding_client() - result = client.models.embed_content( - model=settings.embedding_model, - contents=text, - config=types.EmbedContentConfig(output_dimensionality=settings.pinecone_dimension) - ) - [embedding_obj] = result.embeddings - return embedding_obj.values - - -# --------------------------------------------------------------------------- # LangGraph state (typed dict shared across 
all nodes) # --------------------------------------------------------------------------- diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index 8d2711a..894cb7b 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -21,7 +21,6 @@ from __future__ import annotations import logging -import os from typing import Any, Callable, Dict, List, Optional from dotenv import load_dotenv @@ -75,7 +74,7 @@ class SearchSummary(BaseModel): # ═══════════════════════════════════════════════════════════════════════════ def _get_embed_fn() -> Callable[[str], List[float]]: - from src.pipelines.ingest import embed_text + from src.utils.embeddings import embed_text return embed_text diff --git a/src/storage/pinecone.py b/src/storage/pinecone.py index b558e17..1a878aa 100644 --- a/src/storage/pinecone.py +++ b/src/storage/pinecone.py @@ -1029,7 +1029,7 @@ async def search_by_text( Returns: List[SearchResult] — matched records sorted by similarity. """ - from src.pipelines.ingest import embed_text + from src.utils.embeddings import embed_text query_embedding = embed_text(query_text) return self.search( diff --git a/src/utils/embeddings.py b/src/utils/embeddings.py new file mode 100644 index 0000000..d327459 --- /dev/null +++ b/src/utils/embeddings.py @@ -0,0 +1,32 @@ +import logging +from typing import List, Optional + +from google import genai +from google.genai import types + +from src.config import settings + +logger = logging.getLogger("xmem.utils.embeddings") + +_embedding_client: Optional[genai.Client] = None + + +def get_embedding_client() -> genai.Client: + global _embedding_client + if _embedding_client is None: + api_key_to_use = settings.gemini_api_key or None + _embedding_client = genai.Client(api_key=api_key_to_use) if api_key_to_use else genai.Client() + logger.info("Loaded embedding client for model: %s", settings.embedding_model) + return _embedding_client + + +def embed_text(text: str) -> List[float]: + """Embed a single text string 
→ list of floats.""" + client = get_embedding_client() + result = client.models.embed_content( + model=settings.embedding_model, + contents=text, + config=types.EmbedContentConfig(output_dimensionality=settings.pinecone_dimension) + ) + [embedding_obj] = result.embeddings + return embedding_obj.values diff --git a/tests/unit/utils/test_embeddings.py b/tests/unit/utils/test_embeddings.py new file mode 100644 index 0000000..06ff814 --- /dev/null +++ b/tests/unit/utils/test_embeddings.py @@ -0,0 +1,85 @@ +import unittest +from unittest.mock import MagicMock, patch +import sys +import importlib + +class TestEmbeddings(unittest.TestCase): + def setUp(self): + # Create mocks + self.mock_google = MagicMock() + self.mock_genai = MagicMock() + self.mock_types = MagicMock() + # Ensure imports work: import google.genai -> accessing mock_google.genai + self.mock_google.genai = self.mock_genai + self.mock_genai.types = self.mock_types + + # Configure genai.Client + self.mock_client_cls = self.mock_genai.Client + + self.mock_settings = MagicMock() + self.mock_settings.gemini_api_key = "fake_key" + self.mock_settings.embedding_model = "fake_model" + self.mock_settings.pinecone_dimension = 123 + + # Patch sys.modules + self.modules_patcher = patch.dict("sys.modules", { + "google": self.mock_google, + "google.genai": self.mock_genai, + "google.genai.types": self.mock_types, + "src.config": MagicMock(settings=self.mock_settings), + }) + self.modules_patcher.start() + + # Import or reload src.utils.embeddings + if "src.utils.embeddings" in sys.modules: + import src.utils.embeddings + importlib.reload(src.utils.embeddings) + else: + import src.utils.embeddings + + self.module = sys.modules["src.utils.embeddings"] + # Ensure _embedding_client is None + self.module._embedding_client = None + + def tearDown(self): + self.modules_patcher.stop() + + def test_get_embedding_client_initialization(self): + # When + client = self.module.get_embedding_client() + + # Then + 
self.mock_client_cls.assert_called_once_with(api_key="fake_key") + self.assertEqual(client, self.mock_client_cls.return_value) + + # Call again + client2 = self.module.get_embedding_client() + self.mock_client_cls.assert_called_once() # Should be called only once + self.assertEqual(client, client2) + + def test_embed_text(self): + # Setup + mock_client_instance = self.mock_client_cls.return_value + + mock_result = MagicMock() + mock_embedding_obj = MagicMock() + mock_embedding_obj.values = [0.1, 0.2, 0.3] + mock_result.embeddings = [mock_embedding_obj] + + mock_client_instance.models.embed_content.return_value = mock_result + + # When + result = self.module.embed_text("hello world") + + # Then + self.assertEqual(result, [0.1, 0.2, 0.3]) + mock_client_instance.models.embed_content.assert_called_once() + + _, kwargs = mock_client_instance.models.embed_content.call_args + self.assertEqual(kwargs['model'], "fake_model") + self.assertEqual(kwargs['contents'], "hello world") + # Ensure config is passed + self.assertIn('config', kwargs) + +if __name__ == "__main__": + unittest.main()