diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..ee05916 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Callable, Dict, List, Optional @@ -589,14 +590,22 @@ async def _search_symbols( ) -> List[SourceRecord]: if not repo: logger.warning("search_symbols called without repo — searching all repos") - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + + async def _search(r: str) -> List[SourceRecord]: + return await self._search_namespace( namespace=symbols_namespace(self.org_id, r), query=query, domain="symbol", top_k=top_k, - )) + ) + + tasks = [_search(r) for r in self.repos] + all_results = await asyncio.gather(*tasks) + + results = [] + for res in all_results: + results.extend(res) + return results[:top_k] return await self._search_namespace( @@ -612,14 +621,21 @@ async def _search_files( self, query: str, repo: str, top_k: int = 10, ) -> List[SourceRecord]: if not repo: - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + async def _search(r: str) -> List[SourceRecord]: + return await self._search_namespace( namespace=files_namespace(self.org_id, r), query=query, domain="file", top_k=top_k, - )) + ) + + tasks = [_search(r) for r in self.repos] + all_results = await asyncio.gather(*tasks) + + results = [] + for res in all_results: + results.extend(res) + return results[:top_k] return await self._search_namespace(