diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..bbc3640 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -26,6 +26,7 @@ from __future__ import annotations import logging +import asyncio from typing import Any, Callable, Dict, List, Optional from langchain_core.language_models import BaseChatModel @@ -589,14 +590,17 @@ async def _search_symbols( ) -> List[SourceRecord]: if not repo: logger.warning("search_symbols called without repo — searching all repos") - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + tasks = [ + self._search_namespace( namespace=symbols_namespace(self.org_id, r), query=query, domain="symbol", top_k=top_k, - )) + ) + for r in self.repos + ] + results_list = await asyncio.gather(*tasks) + results = [record for sublist in results_list for record in sublist] return results[:top_k] return await self._search_namespace( @@ -612,14 +616,17 @@ async def _search_files( self, query: str, repo: str, top_k: int = 10, ) -> List[SourceRecord]: if not repo: - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + tasks = [ + self._search_namespace( namespace=files_namespace(self.org_id, r), query=query, domain="file", top_k=top_k, - )) + ) + for r in self.repos + ] + results_list = await asyncio.gather(*tasks) + results = [record for sublist in results_list for record in sublist] return results[:top_k] return await self._search_namespace(