diff --git a/src/cai/cli.py b/src/cai/cli.py index 8cacf5258..fde1c5be7 100644 --- a/src/cai/cli.py +++ b/src/cai/cli.py @@ -313,6 +313,7 @@ def suppress_aiohttp_warnings(): from cai.sdk.agents.models.openai_chatcompletions import ( get_agent_message_history, get_all_agent_histories, + ContextCompactedError, ) # Import handled where needed to avoid circular imports from cai.sdk.agents.run_to_jsonl import get_session_recorder @@ -442,6 +443,9 @@ def run_cai_cli( agent = starting_agent turn_count = 0 idle_time = 0 + # Holds a user message to replay on the next iteration without prompting + # the user — set by auto-compact so the agent continues its current task. + _post_compact_input: str | None = None console = Console() last_model = os.getenv("CAI_MODEL", "alias1") last_agent_type = os.getenv("CAI_AGENT_TYPE", "one_tool_agent") @@ -482,6 +486,18 @@ def run_cai_cli( print("\n") display_quick_guide(console) + # Notify user if auto-compact is active so they can confirm the vars loaded. + _sc_model_startup = os.getenv("CAI_SUPPORT_MODEL") + _sc_interval_startup = os.getenv("CAI_SUPPORT_INTERVAL") + if _sc_model_startup and _sc_interval_startup: + try: + console.print( + f"[bold cyan]🗜 Auto-compact enabled: every {int(_sc_interval_startup)} LLM responses " + f"using {_sc_model_startup}[/bold cyan]" + ) + except ValueError: + pass + # Function to get the short name of the agent for display def get_agent_short_name(agent): if hasattr(agent, "name"): @@ -690,6 +706,11 @@ def get_agent_short_name(agent): if use_initial_prompt: user_input = initial_prompt use_initial_prompt = False # Only use it once + elif _post_compact_input is not None: + # Auto-compact just ran — replay the last task so the agent + # continues working without waiting for human input. + user_input = _post_compact_input + _post_compact_input = None else: # Get user input with command completion and history user_input = get_user_input( @@ -1479,6 +1500,10 @@ async def process_parallel_responses(): {"role": "assistant", "content": f"{result.final_output}"} ) else: + # Capture user_input before runner calls so ContextCompactedError + # handlers can reference it even on the very first iteration. + _last_user_input = user_input if isinstance(user_input, str) else "" + # Disable streaming by default, unless specifically enabled cai_stream = os.getenv("CAI_STREAM", "false") # Handle empty string or None values @@ -1556,6 +1581,9 @@ async def process_streamed_response(agent, conversation_input): pass raise e + except ContextCompactedError: + # Propagate so the outer try block can handle the restart. + raise except Exception as e: # Clean up on any other exception if stream_iterator is not None: @@ -1583,6 +1611,26 @@ async def process_streamed_response(agent, conversation_input): try: asyncio.run(process_streamed_response(agent, conversation_input)) + except ContextCompactedError: + # Auto-compact fired mid-runner; restart with fresh context. + _base = _last_user_input or "Continue the current task." + _post_compact_input = ( + f"{_base}\n\n" + "IMPORTANT: Your context window was just compacted. " + "Your session memory is already loaded above. " + "Review the 'Exhausted Approaches' section in your memory and " + "DO NOT repeat any technique, command, URL, port scan, or login " + "attempt already listed there. " + "Pick up exactly where you left off using only NEW approaches." + ) + from cai.sdk.agents.simple_agent_manager import AGENT_MANAGER as _AM + _reloaded = _AM.get_active_agent() + if _reloaded is not None: + agent = _reloaded + console.print( + "[bold green]✓ Context window reset — resuming task[/bold green]\n" + ) + continue except OutputGuardrailTripwireTriggered as e: # Display a user-friendly warning instead of crashing (streaming mode) guardrail_name = e.guardrail_result.guardrail.get_name() @@ -1642,6 +1690,26 @@ async def process_streamed_response(agent, conversation_input): # Use non-streamed response try: response = asyncio.run(Runner.run(agent, conversation_input)) + except ContextCompactedError: + # Auto-compact fired mid-runner; restart with fresh context. + _base = _last_user_input or "Continue the current task." + _post_compact_input = ( + f"{_base}\n\n" + "IMPORTANT: Your context window was just compacted. " + "Your session memory is already loaded above. " + "Review the 'Exhausted Approaches' section in your memory and " + "DO NOT repeat any technique, command, URL, port scan, or login " + "attempt already listed there. " + "Pick up exactly where you left off using only NEW approaches." + ) + from cai.sdk.agents.simple_agent_manager import AGENT_MANAGER as _AM + _reloaded = _AM.get_active_agent() + if _reloaded is not None: + agent = _reloaded + console.print( + "[bold green]✓ Context window reset — resuming task[/bold green]\n" + ) + continue except InputGuardrailTripwireTriggered as e: # Display a user-friendly warning for input guardrails reason = "Potential security threat detected in input" @@ -1711,6 +1779,63 @@ async def process_streamed_response(agent, conversation_input): agent.model.message_history[:] = fix_message_list(agent.model.message_history) turn_count += 1 + # Auto-compact: when CAI_SUPPORT_MODEL + CAI_SUPPORT_INTERVAL are both set, + # compact the conversation every N LLM *responses* (assistant messages in + # history) using the support model. Counting assistant messages rather + # than outer-loop turns means agentic sessions — where the agent makes + # many tool-call rounds per single user input — are handled correctly. + _support_model = os.getenv("CAI_SUPPORT_MODEL") + _support_interval_raw = os.getenv("CAI_SUPPORT_INTERVAL") + if _support_model and _support_interval_raw: + try: + _support_interval = int(_support_interval_raw) + if _support_interval > 0: + # Count assistant messages as a proxy for LLM API calls. + _history = getattr(getattr(agent, 'model', None), 'message_history', []) + _llm_call_count = sum( + 1 for m in _history + if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) + == "assistant" + ) + if _llm_call_count > 0: + _calls_until = max(0, _support_interval - _llm_call_count) + if _calls_until > 0: + console.print( + f"[dim cyan] ↻ auto-compact in {_calls_until} LLM response(s) " + f"[{_llm_call_count}/{_support_interval}][/dim cyan]" + ) + if _llm_call_count >= _support_interval: + from cai.repl.commands.compact import COMPACT_COMMAND_INSTANCE + console.print( + f"\n[bold yellow]⟳ Auto-compact: {_llm_call_count} LLM responses " + f"(threshold {_support_interval}) — " + f"summarising with {_support_model}[/bold yellow]" + ) + COMPACT_COMMAND_INSTANCE._perform_compaction( + model_override=_support_model + ) + # Re-sync the local agent reference so the loop continues + # with the freshly reloaded agent (history cleared, memory + # summary already injected into its system prompt). + from cai.sdk.agents.simple_agent_manager import AGENT_MANAGER as _AM + _reloaded = _AM.get_active_agent() + if _reloaded is not None: + agent = _reloaded + # Queue the last user task to be replayed on the next + # iteration so the agent continues without human input. + _post_compact_input = ( + _last_user_input + if _last_user_input.strip() + else "Continue the current task." + ) + console.print( + "[bold green]✓ Memory summary applied to agent system prompt — " + "context window reset — continuing task[/bold green]\n" + ) + except (ValueError, Exception) as _e: + # Always show auto-compact errors so they are never silently lost. + console.print(f"[red]Auto-compact error: {_e}[/red]") + # Stop measuring active time and start measuring idle time again stop_active_timer() start_idle_timer() diff --git a/src/cai/repl/commands/memory.py b/src/cai/repl/commands/memory.py index 9f963d4e6..e0523a36a 100644 --- a/src/cai/repl/commands/memory.py +++ b/src/cai/repl/commands/memory.py @@ -4,6 +4,7 @@ """ from typing import List, Optional, Dict, Any +import inspect import os import asyncio import json @@ -1221,7 +1222,8 @@ async def _ai_summarize_history(self, agent_name: Optional[str] = None) -> Optio 6. **All User Messages**: Complete list of user messages in order 7. **Pending Tasks**: What still needs to be done 8. **Current Work**: What was being worked on when the conversation ended -9. **Optional Next Step**: If there's a clear next action, mention it +9. **Exhausted Approaches — DO NOT RETRY**: Every technique, command, path, or attack vector that was attempted and failed. Format each as a bullet starting with ❌. Be specific (include exact commands, URLs, usernames, ports). This section is CRITICAL — the agent will use it to avoid wasting time on dead ends. +10. **Recommended Next Steps**: Concrete actions NOT yet tried, ordered by likelihood of success. ## Important Guidelines @@ -1232,6 +1234,7 @@ async def _ai_summarize_history(self, agent_name: Optional[str] = None) -> Optio - Maintain technical accuracy - don't paraphrase technical terms - The summary will be used as the primary context for resuming work, so completeness is crucial - When the conversation is resumed, it should feel like a natural continuation +- Section 9 (Exhausted Approaches) is the most important section for offensive/hacking tasks: list every failed attempt so the agent doesn't loop. This session is being continued from a previous conversation that ran out of context. The conversation is summarized below:""" @@ -1254,50 +1257,87 @@ async def _ai_summarize_history(self, agent_name: Optional[str] = None) -> Optio input=f"Please summarize the following conversation:\n\n{conversation_text}", max_turns=1 ) - + if result.final_output: return str(result.final_output) else: return None - + except Exception as e: console.print(f"[red]Error generating summary: {e}[/red]") return None + finally: + # Best-effort: explicitly cleanup the temporary summary/support model + try: + model_inst = getattr(summary_agent, "model", None) + # Some Agent constructions put the Model object directly on `agent.model` + # and some providers expose a cleanup coroutine. + if model_inst is not None and hasattr(model_inst, "cleanup"): + try: + coro = model_inst.cleanup() + if inspect.isawaitable(coro): + await coro + except Exception: + # best-effort cleanup — swallow any errors + pass + except Exception: + pass def _format_history_for_summary(self, history: List[Dict[str, Any]]) -> str: - """Format message history for summarization.""" + """Format message history for summarization. + + Critical design goals: + - Include EVERY tool call with its exact arguments (commands run, URLs visited, + ports scanned) so the summary model can produce an "Exhausted Approaches" list. + - Include enough of each tool result to convey success/failure and key findings. + - Avoid blowing out the summary model's context by capping large outputs. + """ + TOOL_OUTPUT_KEEP = 2000 # chars to preserve from each tool result + MAX_PARTS = 200 # maximum formatted blocks to pass (covers ~100 turns) + formatted_parts = [] - + for msg in history: role = msg.get("role", "unknown") content = msg.get("content", "") - - # Skip empty messages - if not content: - continue - - # Format based on role + if role == "user": - formatted_parts.append(f"USER: {content}") + if content: + formatted_parts.append(f"USER: {content}") + elif role == "assistant": - # Check for tool calls - if "tool_calls" in msg and msg["tool_calls"]: + # -- tool calls: extract args from both dict-style and object-style entries -- + tool_calls = msg.get("tool_calls") or [] + if tool_calls: tool_info = [] - for tc in msg["tool_calls"]: - if hasattr(tc, "function"): - tool_info.append(f"{tc.function.name}({tc.function.arguments})") + for tc in tool_calls: + if isinstance(tc, dict): + fn = tc.get("function", {}) + name = fn.get("name", "?") + args = fn.get("arguments", "") + tool_info.append(f"{name}({args})") + elif hasattr(tc, "function"): + tool_info.append( + f"{tc.function.name}({tc.function.arguments})" + ) if tool_info: - formatted_parts.append(f"ASSISTANT (tools): {', '.join(tool_info)}") + formatted_parts.append( + f"ASSISTANT called tools: {', '.join(tool_info)}" + ) if content: formatted_parts.append(f"ASSISTANT: {content}") + elif role == "tool": - # Include important tool outputs - if len(str(content)) < 500: # Only include short outputs - formatted_parts.append(f"TOOL OUTPUT: {content}") + raw = str(content) if content else "" + if len(raw) <= TOOL_OUTPUT_KEEP: + formatted_parts.append(f"TOOL OUTPUT:\n{raw}") else: - formatted_parts.append(f"TOOL OUTPUT: [Long output truncated]") - - return "\n\n".join(formatted_parts[-50:]) # Limit to last 50 exchanges + head = raw[:TOOL_OUTPUT_KEEP] + formatted_parts.append( + f"TOOL OUTPUT (truncated to {TOOL_OUTPUT_KEEP} chars):\n{head}\n[...truncated]" + ) + + return "\n\n".join(formatted_parts[-MAX_PARTS:]) def _get_current_agent_name(self) -> Optional[str]: """Get the name of the current active agent.""" diff --git a/src/cai/sdk/agents/models/openai_chatcompletions.py b/src/cai/sdk/agents/models/openai_chatcompletions.py index 8931edd63..324b0fdaa 100644 --- a/src/cai/sdk/agents/models/openai_chatcompletions.py +++ b/src/cai/sdk/agents/models/openai_chatcompletions.py @@ -363,6 +363,13 @@ def count_tokens_with_tiktoken(text_or_messages): return 0, 0 +class ContextCompactedError(Exception): + """Raised inside get_response/stream_response when a CAI_SUPPORT_INTERVAL-based + auto-compact fires mid-runner. The outer CLI loop catches this, sets + _post_compact_input, and restarts the runner with a clean context window.""" + pass + + class OpenAIChatCompletionsModel(Model): """OpenAI Chat Completions Model""" @@ -463,6 +470,49 @@ def __del__(self): # Ignore any errors during cleanup pass + async def cleanup(self) -> None: + """Explicitly cleanup underlying clients and free instance registry. + + This is intended to be called when a temporary model instance (for + example the summary/support model) is no longer needed. It will try + to close the HTTP/async client if available, remove the instance + from the legacy `ACTIVE_MODEL_INSTANCES` registry and clear the + in-memory message history so any backing LLM server can free slots + or context. + """ + try: + client = getattr(self, "_client", None) + if client is not None: + aclose = getattr(client, "aclose", None) + if aclose: + try: + res = aclose() + # Await if it's awaitable + if inspect.isawaitable(res): + await res + except Exception: + # Best-effort close + pass + try: + delattr(self, "_client") + except Exception: + pass + except Exception: + pass + + try: + key = (getattr(self, '_display_name', None), getattr(self, 'agent_id', None)) + if key in ACTIVE_MODEL_INSTANCES: + del ACTIVE_MODEL_INSTANCES[key] + except Exception: + pass + + try: + if hasattr(self, 'message_history') and isinstance(self.message_history, list): + self.message_history.clear() + except Exception: + pass + def add_to_message_history(self, msg): """Add a message to this instance's history if it's not a duplicate. @@ -544,20 +594,12 @@ async def get_response( | {"base_url": str(self._get_client().base_url)}, disabled=tracing.is_disabled(), ) as span_generation: - # Prepare the messages for consistent token counting - # IMPORTANT: Include existing message history for context + # Prepare the messages for consistent token counting. + # History is already included in `input` via cli.py's history_context mechanism + # (history_context = agent.model.message_history is passed as conversation_input + # to Runner.run, which then passes it as original_input to get_response). + # Prepending message_history here would double-count every message. converted_messages = [] - - # First, add all existing messages from history - if self.message_history: - for msg in self.message_history: - msg_copy = msg.copy() # Use copy to avoid modifying original - # Remove any existing cache_control to avoid exceeding the 4-block limit - if "cache_control" in msg_copy: - del msg_copy["cache_control"] - converted_messages.append(msg_copy) - - # Then convert and add the new input new_messages = self._converter.items_to_messages(input, model_instance=self) converted_messages.extend(new_messages) @@ -2545,19 +2587,12 @@ async def _fetch_response( # start by re-fetching self.is_ollama self.is_ollama = os.getenv("OLLAMA") is not None and os.getenv("OLLAMA").lower() == "true" - # IMPORTANT: Include existing message history for context + # Build the message list from `input` only. + # History is already included in `input` via cli.py's history_context mechanism: + # cli.py passes history_context (= message_history) as part of conversation_input + # to Runner.run, which passes it as original_input through to _fetch_response. + # Prepending message_history again would send every historical message twice. converted_messages = [] - - # First, add all existing messages from history - if self.message_history: - for msg in self.message_history: - msg_copy = msg.copy() # Use copy to avoid modifying original - # Remove any existing cache_control to avoid exceeding the 4-block limit - if "cache_control" in msg_copy: - del msg_copy["cache_control"] - converted_messages.append(msg_copy) - - # Then convert and add the new input new_messages = self._converter.items_to_messages(input, model_instance=self) converted_messages.extend(new_messages) @@ -3490,7 +3525,89 @@ async def _auto_compact_if_needed(self, estimated_tokens: int, input: str | list # Check if auto-compaction is disabled if os.getenv("CAI_AUTO_COMPACT", "true").lower() == "false": return input, system_instructions, False - + + # --- CAI_SUPPORT_INTERVAL count-based trigger --- + # This fires on EVERY API call (not just at the outer CLI-loop level), so it correctly + # handles agentic sessions where the agent makes many tool calls inside one Runner.run. + _support_model = os.getenv("CAI_SUPPORT_MODEL") + _support_interval_raw = os.getenv("CAI_SUPPORT_INTERVAL") + if _support_model and _support_interval_raw: + try: + _support_interval = int(_support_interval_raw) + if _support_interval > 0: + _asst_count = sum( + 1 for m in self.message_history + if isinstance(m, dict) and m.get("role") == "assistant" + ) + if _asst_count >= _support_interval: + from rich.console import Console as _Console + _console = _Console() + _console.print( + f"\n[bold yellow]⟳ Auto-compact: {_asst_count} LLM responses " + f"(threshold {_support_interval}) — summarising with " + f"{_support_model}[/bold yellow]" + ) + try: + from cai.repl.commands.memory import ( + MEMORY_COMMAND_INSTANCE, + COMPACTED_SUMMARIES, + APPLIED_MEMORY_IDS, + ) + from cai.repl.commands.compact import COMPACT_COMMAND_INSTANCE + _orig_compact = COMPACT_COMMAND_INSTANCE.compact_model + COMPACT_COMMAND_INSTANCE.compact_model = _support_model + try: + _summary = await MEMORY_COMMAND_INSTANCE._ai_summarize_history( + self.agent_name + ) + finally: + COMPACT_COMMAND_INSTANCE.compact_model = _orig_compact + if _summary: + if self.agent_name not in COMPACTED_SUMMARIES: + COMPACTED_SUMMARIES[self.agent_name] = [] + APPLIED_MEMORY_IDS[self.agent_name] = [] + COMPACTED_SUMMARIES[self.agent_name] = [_summary] + self.message_history.clear() + # Re-inject the summary as the first exchange so + # the next Runner turn has full context and won't + # repeat work that was already attempted. + self.message_history.append({ + "role": "user", + "content": ( + "\n" + + _summary + + "\n\n\n" + "This is your memory from the previous context window. " + "Use it to continue your work. " + "Do NOT retry any approach already marked as failed or exhausted." + ), + }) + self.message_history.append({ + "role": "assistant", + "content": ( + "Understood. I have reviewed my previous session memory. " + "I will continue the task using only new approaches " + "and will not repeat anything already attempted." + ), + }) + os.environ["CAI_CONTEXT_USAGE"] = "0.0" + _console.print( + "[bold green]✓ Memory summary applied — " + "context window reset — restarting task[/bold green]\n" + ) + except Exception as _ce: + _console.print(f"[red]Auto-compact error: {_ce}[/red]") + # Always abort the current runner invocation so the outer loop + # can restart with our freshly cleared context. + raise ContextCompactedError( + f"Context compacted after {_asst_count} LLM responses " + f"(threshold {_support_interval})" + ) + except ContextCompactedError: + raise # propagate to the outer runner / CLI loop + except (ValueError, Exception): + pass # malformed interval — ignore silently + max_tokens = self._get_model_max_tokens(str(self.model)) threshold_percent = float(os.getenv("CAI_AUTO_COMPACT_THRESHOLD", "0.8")) threshold = max_tokens * threshold_percent