Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 2 additions & 33 deletions src/pipelines/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,48 +71,17 @@
from src.schemas.classification import ClassificationResult
from src.schemas.events import EventResult
from src.schemas.image import ImageResult
from src.schemas.judge import JudgeDomain, JudgeResult, OperationType
from src.schemas.judge import JudgeDomain, JudgeResult
from src.schemas.profile import ProfileResult
from src.schemas.summary import SummaryResult
from src.schemas.weaver import WeaverResult
from src.storage.base import BaseVectorStore, SearchResult
from src.storage.pinecone import PineconeVectorStore
from src.utils.embeddings import embed_text

logger = logging.getLogger("xmem.pipelines.ingest")


# ---------------------------------------------------------------------------
# Embedding helper — wraps Google GenAI into a simple callable
# ---------------------------------------------------------------------------

from google import genai
from google.genai import types

_embedding_client: Optional[genai.Client] = None


def get_embedding_client() -> genai.Client:
    """Return the shared GenAI client, building it the first time it is needed.

    The client is cached in the module-level ``_embedding_client``. When
    ``settings.gemini_api_key`` is set it is passed as ``api_key``; otherwise
    ``genai.Client()`` is constructed with no arguments.
    """
    global _embedding_client
    # Fast path: a client has already been created by an earlier call.
    if _embedding_client is not None:
        return _embedding_client

    # An empty key is normalised to None so it is treated like "not set".
    configured_key = settings.gemini_api_key or None
    if configured_key:
        _embedding_client = genai.Client(api_key=configured_key)
    else:
        _embedding_client = genai.Client()
    logger.info("Loaded embedding client for model: %s", settings.embedding_model)
    return _embedding_client


def embed_text(text: str) -> List[float]:
    """Embed one text string and return its vector as a list of floats.

    The request goes to ``settings.embedding_model`` through the shared
    client and asks for ``settings.pinecone_dimension`` output dimensions.

    Raises:
        ValueError: if the response does not hold exactly one embedding
            (raised by the single-element unpacking below).
    """
    embed_content = get_embedding_client().models.embed_content
    response = embed_content(
        model=settings.embedding_model,
        contents=text,
        config=types.EmbedContentConfig(
            output_dimensionality=settings.pinecone_dimension,
        ),
    )
    # One input was sent, so exactly one embedding is expected back.
    (sole_embedding,) = response.embeddings
    return sole_embedding.values


# ---------------------------------------------------------------------------
# LangGraph state (typed dict shared across all nodes)
# ---------------------------------------------------------------------------

Expand Down
3 changes: 1 addition & 2 deletions src/pipelines/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from __future__ import annotations

import logging
import os
from typing import Any, Callable, Dict, List, Optional

from dotenv import load_dotenv
Expand Down Expand Up @@ -75,7 +74,7 @@ class SearchSummary(BaseModel):
# ═══════════════════════════════════════════════════════════════════════════

def _get_embed_fn() -> Callable[[str], List[float]]:
    """Return the ``embed_text`` callable, importing it lazily at call time.

    NOTE(review): both imports below bind the same name, so the one from
    ``src.utils.embeddings`` is what gets returned. This view appears to show
    the removed and the added line of a diff hunk together — confirm which
    one the real file keeps.
    """
    from src.pipelines.ingest import embed_text
    from src.utils.embeddings import embed_text
    return embed_text


Expand Down
2 changes: 1 addition & 1 deletion src/storage/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ async def search_by_text(
Returns:
List[SearchResult] — matched records sorted by similarity.
"""
from src.pipelines.ingest import embed_text
from src.utils.embeddings import embed_text

query_embedding = embed_text(query_text)
return self.search(
Expand Down
32 changes: 32 additions & 0 deletions src/utils/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
from typing import List, Optional

from google import genai
from google.genai import types

from src.config import settings

logger = logging.getLogger("xmem.utils.embeddings")

_embedding_client: Optional[genai.Client] = None


def get_embedding_client() -> genai.Client:
    """Lazily create and return the module-wide GenAI client.

    The first call constructs the client and stores it in
    ``_embedding_client``; later calls return that same object. The API key
    from ``settings.gemini_api_key`` is passed only when it is non-empty.
    """
    global _embedding_client
    if _embedding_client is not None:
        # Already initialised by a previous call.
        return _embedding_client

    # Treat an empty/falsy key the same as no key at all.
    key = settings.gemini_api_key or None
    client_kwargs = {"api_key": key} if key else {}
    _embedding_client = genai.Client(**client_kwargs)
    logger.info("Loaded embedding client for model: %s", settings.embedding_model)
    return _embedding_client


def embed_text(text: str) -> List[float]:
    """Return the embedding vector for a single piece of text.

    Calls ``embed_content`` on the shared client with the model named in
    ``settings.embedding_model`` and an output dimensionality taken from
    ``settings.pinecone_dimension``.

    Raises:
        ValueError: if the response does not contain exactly one embedding
            (the single-element unpacking enforces this).
    """
    models_api = get_embedding_client().models
    response = models_api.embed_content(
        model=settings.embedding_model,
        contents=text,
        config=types.EmbedContentConfig(
            output_dimensionality=settings.pinecone_dimension,
        ),
    )
    # A single input string must yield a single embedding object.
    (only_embedding,) = response.embeddings
    vector = only_embedding.values
    return vector
85 changes: 85 additions & 0 deletions tests/unit/utils/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import unittest
from unittest.mock import MagicMock, patch
import sys
import importlib

class TestEmbeddings(unittest.TestCase):
    """Unit tests for ``src.utils.embeddings`` with GenAI and settings mocked.

    Each test imports a fresh copy of the module while ``google``,
    ``google.genai``, ``google.genai.types`` and ``src.config`` are replaced
    in ``sys.modules`` by mocks, so no real client or configuration is used.
    """

    MODULE_NAME = "src.utils.embeddings"

    def setUp(self):
        # Stand-ins for the google / google.genai / google.genai.types modules.
        self.mock_google = MagicMock()
        self.mock_genai = MagicMock()
        self.mock_types = MagicMock()
        # ``from google import genai`` and ``from google.genai import types``
        # resolve through attribute access on the parent mock, so wire them up.
        self.mock_google.genai = self.mock_genai
        self.mock_genai.types = self.mock_types

        # The class the module under test instantiates as ``genai.Client``.
        self.mock_client_cls = self.mock_genai.Client

        self.mock_settings = MagicMock()
        self.mock_settings.gemini_api_key = "fake_key"
        self.mock_settings.embedding_model = "fake_model"
        self.mock_settings.pinecone_dimension = 123

        self.modules_patcher = patch.dict("sys.modules", {
            "google": self.mock_google,
            "google.genai": self.mock_genai,
            "google.genai.types": self.mock_types,
            "src.config": MagicMock(settings=self.mock_settings),
        })
        self.modules_patcher.start()
        # Register the undo immediately: tearDown is skipped when setUp
        # raises, which would otherwise leave sys.modules patched if the
        # import below fails.
        self.addCleanup(self.modules_patcher.stop)

        # Import a fresh copy instead of reloading in place, so a copy that
        # was imported earlier is not rebound to the mocks; patch.dict puts
        # the previous sys.modules entry back when it is stopped.
        sys.modules.pop(self.MODULE_NAME, None)
        self.module = importlib.import_module(self.MODULE_NAME)
        # Start every test without a cached client.
        self.module._embedding_client = None

    def test_get_embedding_client_initialization(self):
        # When
        client = self.module.get_embedding_client()

        # Then
        self.mock_client_cls.assert_called_once_with(api_key="fake_key")
        self.assertIs(client, self.mock_client_cls.return_value)

        # A second call must reuse the cached client, not build another one.
        client2 = self.module.get_embedding_client()
        self.mock_client_cls.assert_called_once()
        self.assertIs(client, client2)

    def test_get_embedding_client_without_api_key(self):
        # Given an empty key, the client is constructed with no arguments.
        self.mock_settings.gemini_api_key = ""

        # When
        client = self.module.get_embedding_client()

        # Then
        self.mock_client_cls.assert_called_once_with()
        self.assertIs(client, self.mock_client_cls.return_value)

    def test_embed_text(self):
        # Setup
        mock_client_instance = self.mock_client_cls.return_value

        mock_result = MagicMock()
        mock_embedding_obj = MagicMock()
        mock_embedding_obj.values = [0.1, 0.2, 0.3]
        mock_result.embeddings = [mock_embedding_obj]

        mock_client_instance.models.embed_content.return_value = mock_result

        # When
        result = self.module.embed_text("hello world")

        # Then
        self.assertEqual(result, [0.1, 0.2, 0.3])
        mock_client_instance.models.embed_content.assert_called_once()

        _, kwargs = mock_client_instance.models.embed_content.call_args
        self.assertEqual(kwargs["model"], "fake_model")
        self.assertEqual(kwargs["contents"], "hello world")
        # The config must be built from the configured dimension and be the
        # very object handed to embed_content.
        self.mock_types.EmbedContentConfig.assert_called_once_with(
            output_dimensionality=123
        )
        self.assertIs(
            kwargs["config"], self.mock_types.EmbedContentConfig.return_value
        )

if __name__ == "__main__":
unittest.main()