Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

## 2025-03-09 - AST Parsing Performance Bottleneck with `ast.walk`
**Learning:** Nesting `ast.walk` loops to traverse AST subtrees for symbol operations (such as call extraction) creates a worst-case O(N^2) performance bottleneck: `ast.walk` cannot target specific node types, so every inner loop must queue and yield an entire subtree that the outer loop also traverses. It also incurs high generator overhead, even for simpler operations such as complexity computation.
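**Action:** Always prefer `ast.NodeVisitor` over `ast.walk` for AST traversals, especially when processing nested structures or specific node types, because it completes the work in a single O(N) pass with minimal overhead. Each `visit_*` handler must still call `self.generic_visit(node)`; otherwise the children of that node are never visited.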
113 changes: 76 additions & 37 deletions src/scanner/ast_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
import ast
import hashlib
import logging
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Set
from typing import List, Optional

logger = logging.getLogger("xmem.scanner.parser")

Expand Down Expand Up @@ -173,22 +171,26 @@ def parse_file(self, file_path: str, content: str) -> ParsedFile:
def _extract_imports(self, tree: ast.Module) -> List[ParsedImport]:
imports: List[ParsedImport] = []

for node in ast.walk(tree):
if isinstance(node, ast.Import):
class ImportVisitor(ast.NodeVisitor):
def visit_Import(self, node):
for alias in node.names:
imports.append(ParsedImport(
module=alias.name,
alias=alias.asname,
))
elif isinstance(node, ast.ImportFrom):
self.generic_visit(node)

def visit_ImportFrom(self, node):
module = node.module or ""
names = [a.name for a in node.names]
imports.append(ParsedImport(
module=module,
names=names,
is_relative=node.level > 0,
))
self.generic_visit(node)

ImportVisitor().visit(tree)
return imports

# -- Symbols -----------------------------------------------------------
Expand Down Expand Up @@ -317,25 +319,36 @@ def _extract_calls(
) -> List[ParsedCall]:
"""Extract function calls from within each symbol's AST subtree."""
calls: List[ParsedCall] = []
known_names = {s.name for s in symbols}

for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue

caller = node.name
for child in ast.walk(node):
if not isinstance(child, ast.Call):
continue

callee = self._call_to_name(child)
if callee and callee != caller:
calls.append(ParsedCall(
caller_name=caller,
callee_name=callee,
is_direct=True,
))
parser = self

class CallVisitor(ast.NodeVisitor):
    """Record calls made inside function bodies in a single AST pass.

    Each call is attributed to the innermost enclosing ``def`` /
    ``async def``. Calls found outside any function are ignored, and so
    are calls whose resolved name equals the enclosing function's own
    name. Results are appended to the ``calls`` list of the enclosing
    scope; callee names are resolved through ``parser._call_to_name``.
    """

    def __init__(self):
        # Name of the function whose body is currently being visited;
        # None while outside every function definition.
        self.current_caller = None

    def _visit_function(self, node):
        """Visit a function body with ``node.name`` as the active caller."""
        enclosing = self.current_caller
        self.current_caller = node.name
        self.generic_visit(node)
        # Restore the outer function so later siblings are attributed to it.
        self.current_caller = enclosing

    # Sync and async definitions are handled identically.
    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function

    def visit_Call(self, node):
        """Record the call if it happens inside a function, then descend."""
        caller = self.current_caller
        if caller:
            callee = parser._call_to_name(node)
            if callee and callee != caller:
                calls.append(ParsedCall(
                    caller_name=caller,
                    callee_name=callee,
                    is_direct=True,
                ))
        # Keep walking: arguments may themselves contain calls.
        self.generic_visit(node)

CallVisitor().visit(tree)
return calls

def _call_to_name(self, node: ast.Call) -> Optional[str]:
Expand Down Expand Up @@ -400,19 +413,45 @@ def _decorator_to_str(self, node: ast.expr) -> str:

def _compute_complexity(self, node: ast.AST) -> int:
"""Compute cyclomatic complexity from AST nodes. No LLM needed."""
complexity = 1
for child in ast.walk(node):
if isinstance(child, (ast.If, ast.IfExp)):
complexity += 1
elif isinstance(child, (ast.For, ast.AsyncFor, ast.While)):
complexity += 1
elif isinstance(child, ast.ExceptHandler):
complexity += 1
elif isinstance(child, ast.BoolOp):
complexity += len(child.values) - 1
elif isinstance(child, ast.Assert):
complexity += 1
return complexity
class ComplexityVisitor(ast.NodeVisitor):
    """Accumulate a cyclomatic-complexity score while visiting an AST.

    The score starts at 1. It grows by one for every ``if`` statement,
    conditional expression, ``for`` / ``async for`` / ``while`` loop,
    ``except`` handler and ``assert`` statement. A boolean operation adds
    one per operand beyond the first. Read the result from ``complexity``
    after calling ``visit``.
    """

    def __init__(self):
        self.complexity = 1

    def _count_branch(self, n):
        """Count one decision point and keep descending into ``n``."""
        self.complexity += 1
        self.generic_visit(n)

    # Every node type below contributes exactly one decision point, so
    # they all share a single handler.
    visit_If = _count_branch
    visit_IfExp = _count_branch
    visit_For = _count_branch
    visit_AsyncFor = _count_branch
    visit_While = _count_branch
    visit_ExceptHandler = _count_branch
    visit_Assert = _count_branch

    def visit_BoolOp(self, n):
        """Count one decision point per operand after the first."""
        self.complexity += len(n.values) - 1
        self.generic_visit(n)

visitor = ComplexityVisitor()
visitor.visit(node)
return visitor.complexity


# ---------------------------------------------------------------------------
Expand Down