diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..42c84d8 --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,4 @@ + +## 2025-03-09 - AST Parsing Performance Bottleneck with `ast.walk` +**Learning:** Using nested `ast.walk` loops to traverse AST subtrees for symbol operations (like call extraction) results in massive O(N^2) performance bottlenecks, as `ast.walk` lacks targeted node visitation and must queue/yield the entire subtree repeatedly. It also incurs high generator overhead for simpler operations like complexity computation. +**Action:** Always prefer `ast.NodeVisitor` over `ast.walk` for AST traversals, especially when processing nested structures or specific node types, as it provides a single O(N) pass with minimal overhead. diff --git a/src/scanner/ast_parser.py b/src/scanner/ast_parser.py index 66e2b07..92d16bd 100644 --- a/src/scanner/ast_parser.py +++ b/src/scanner/ast_parser.py @@ -13,10 +13,8 @@ import ast import hashlib import logging -import textwrap from dataclasses import dataclass, field -from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import List, Optional logger = logging.getLogger("xmem.scanner.parser") @@ -173,14 +171,16 @@ def parse_file(self, file_path: str, content: str) -> ParsedFile: def _extract_imports(self, tree: ast.Module) -> List[ParsedImport]: imports: List[ParsedImport] = [] - for node in ast.walk(tree): - if isinstance(node, ast.Import): + class ImportVisitor(ast.NodeVisitor): + def visit_Import(self, node): for alias in node.names: imports.append(ParsedImport( module=alias.name, alias=alias.asname, )) - elif isinstance(node, ast.ImportFrom): + self.generic_visit(node) + + def visit_ImportFrom(self, node): module = node.module or "" names = [a.name for a in node.names] imports.append(ParsedImport( @@ -188,7 +188,9 @@ def _extract_imports(self, tree: ast.Module) -> List[ParsedImport]: names=names, is_relative=node.level > 0, )) + self.generic_visit(node) + ImportVisitor().visit(tree) return imports # -- Symbols ----------------------------------------------------------- @@ -317,25 +319,36 @@ def _extract_calls( ) -> List[ParsedCall]: """Extract function calls from within each symbol's AST subtree.""" calls: List[ParsedCall] = [] - known_names = {s.name for s in symbols} - - for node in ast.walk(tree): - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - continue - - caller = node.name - for child in ast.walk(node): - if not isinstance(child, ast.Call): - continue - - callee = self._call_to_name(child) - if callee and callee != caller: - calls.append(ParsedCall( - caller_name=caller, - callee_name=callee, - is_direct=True, - )) + parser = self + + class CallVisitor(ast.NodeVisitor): + def __init__(self): + self.current_caller = None + + def visit_FunctionDef(self, node): + prev = self.current_caller + self.current_caller = node.name + self.generic_visit(node) + self.current_caller = prev + + def visit_AsyncFunctionDef(self, node): + prev = self.current_caller + self.current_caller = node.name + self.generic_visit(node) + self.current_caller = prev + + def visit_Call(self, node): + if self.current_caller: + callee = parser._call_to_name(node) + if callee and callee != self.current_caller: + calls.append(ParsedCall( + caller_name=self.current_caller, + callee_name=callee, + is_direct=True, + )) + self.generic_visit(node) + CallVisitor().visit(tree) return calls def _call_to_name(self, node: ast.Call) -> Optional[str]: @@ -400,19 +413,45 @@ def _decorator_to_str(self, node: ast.expr) -> str: def _compute_complexity(self, node: ast.AST) -> int: """Compute cyclomatic complexity from AST nodes. No LLM needed.""" - complexity = 1 - for child in ast.walk(node): - if isinstance(child, (ast.If, ast.IfExp)): - complexity += 1 - elif isinstance(child, (ast.For, ast.AsyncFor, ast.While)): - complexity += 1 - elif isinstance(child, ast.ExceptHandler): - complexity += 1 - elif isinstance(child, ast.BoolOp): - complexity += len(child.values) - 1 - elif isinstance(child, ast.Assert): - complexity += 1 - return complexity + class ComplexityVisitor(ast.NodeVisitor): + def __init__(self): + self.complexity = 1 + + def visit_If(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_IfExp(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_For(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_AsyncFor(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_While(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_ExceptHandler(self, n): + self.complexity += 1 + self.generic_visit(n) + + def visit_BoolOp(self, n): + self.complexity += len(n.values) - 1 + self.generic_visit(n) + + def visit_Assert(self, n): + self.complexity += 1 + self.generic_visit(n) + + visitor = ComplexityVisitor() + visitor.visit(node) + return visitor.complexity # ---------------------------------------------------------------------------