diff --git a/src/pipelines/code_retrieval.py b/src/pipelines/code_retrieval.py index 69e4abc..e004457 100644 --- a/src/pipelines/code_retrieval.py +++ b/src/pipelines/code_retrieval.py @@ -25,6 +25,7 @@ from __future__ import annotations +import asyncio import logging from typing import Any, Callable, Dict, List, Optional @@ -375,11 +376,9 @@ async def run( turn_records: List[SourceRecord] = [] only_read_tools = True - for tc in ai_response.tool_calls: + async def _process_tool_call(tc): tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - t1 = _time.perf_counter() records = await self._execute_tool( tool_name, tool_args, repo=repo, top_k=top_k, @@ -387,6 +386,14 @@ async def run( ) tool_ms = (_time.perf_counter() - t1) * 1000 logger.info(" Tool: %s(%s) → %d results (%.0fms)", tool_name, tool_args, len(records), tool_ms) + return tc, records + + results = await asyncio.gather(*(_process_tool_call(tc) for tc in ai_response.tool_calls)) + + for tc, records in results: + tool_name = tc["name"] + tool_id = tc["id"] + turn_records.extend(records) sources.extend(records) @@ -589,14 +596,19 @@ async def _search_symbols( ) -> List[SourceRecord]: if not repo: logger.warning("search_symbols called without repo — searching all repos") - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + all_results = await asyncio.gather(*( + self._search_namespace( namespace=symbols_namespace(self.org_id, r), query=query, domain="symbol", top_k=top_k, - )) + ) + for r in self.repos + )) + + results = [] + for res in all_results: + results.extend(res) return results[:top_k] return await self._search_namespace( @@ -612,14 +624,19 @@ async def _search_files( self, query: str, repo: str, top_k: int = 10, ) -> List[SourceRecord]: if not repo: - results = [] - for r in self.repos: - results.extend(await self._search_namespace( + all_results = await asyncio.gather(*( + self._search_namespace( namespace=files_namespace(self.org_id, r), query=query, domain="file", top_k=top_k, - )) + ) + for r in self.repos + )) + + results = [] + for res in all_results: + results.extend(res) return results[:top_k] return await self._search_namespace( diff --git a/src/pipelines/retrieval.py b/src/pipelines/retrieval.py index d54cc0d..edd854d 100644 --- a/src/pipelines/retrieval.py +++ b/src/pipelines/retrieval.py @@ -20,6 +20,7 @@ from __future__ import annotations +import asyncio import logging import os from typing import Any, Callable, Dict, List, Optional @@ -177,16 +178,21 @@ async def run( if ai_response.tool_calls: called_tools = set() - for tc in ai_response.tool_calls: + + async def _process_tool_call(tc): tool_name = tc["name"] tool_args = tc["args"] - tool_id = tc["id"] - logger.info(" Tool call: %s(%s)", tool_name, tool_args) - records = await self._execute_tool( tool_name, tool_args, user_id, top_k, ) + return tc, records + + results = await asyncio.gather(*(_process_tool_call(tc) for tc in ai_response.tool_calls)) + + for tc, records in results: + tool_name = tc["name"] + tool_id = tc["id"] sources.extend(records) # Build ToolMessage for the LLM