diff --git a/src/scanner/ast_parser.py b/src/scanner/ast_parser.py index 66e2b07..07367d0 100644 --- a/src/scanner/ast_parser.py +++ b/src/scanner/ast_parser.py @@ -317,24 +317,36 @@ def _extract_calls( ) -> List[ParsedCall]: """Extract function calls from within each symbol's AST subtree.""" calls: List[ParsedCall] = [] - known_names = {s.name for s in symbols} - for node in ast.walk(tree): - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - continue - - caller = node.name - for child in ast.walk(node): - if not isinstance(child, ast.Call): - continue - - callee = self._call_to_name(child) - if callee and callee != caller: - calls.append(ParsedCall( - caller_name=caller, - callee_name=callee, - is_direct=True, - )) + class CallVisitor(ast.NodeVisitor): + def __init__(self, parser): + self.parser = parser + self.caller_stack = [] + + def visit_FunctionDef(self, node): + self.caller_stack.append(node.name) + self.generic_visit(node) + self.caller_stack.pop() + + def visit_AsyncFunctionDef(self, node): + self.caller_stack.append(node.name) + self.generic_visit(node) + self.caller_stack.pop() + + def visit_Call(self, node): + callee = self.parser._call_to_name(node) + if callee: + for caller in self.caller_stack: + if callee != caller: + calls.append(ParsedCall( + caller_name=caller, + callee_name=callee, + is_direct=True, + )) + self.generic_visit(node) + + visitor = CallVisitor(self) + visitor.visit(tree) return calls @@ -400,19 +412,45 @@ def _decorator_to_str(self, node: ast.expr) -> str: def _compute_complexity(self, node: ast.AST) -> int: """Compute cyclomatic complexity from AST nodes. No LLM needed.""" - complexity = 1 - for child in ast.walk(node): - if isinstance(child, (ast.If, ast.IfExp)): - complexity += 1 - elif isinstance(child, (ast.For, ast.AsyncFor, ast.While)): - complexity += 1 - elif isinstance(child, ast.ExceptHandler): - complexity += 1 - elif isinstance(child, ast.BoolOp): - complexity += len(child.values) - 1 - elif isinstance(child, ast.Assert): - complexity += 1 - return complexity + class ComplexityVisitor(ast.NodeVisitor): + def __init__(self): + self.complexity = 1 + + def visit_If(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_IfExp(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_For(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_AsyncFor(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_While(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_ExceptHandler(self, node): + self.complexity += 1 + self.generic_visit(node) + + def visit_BoolOp(self, node): + self.complexity += len(node.values) - 1 + self.generic_visit(node) + + def visit_Assert(self, node): + self.complexity += 1 + self.generic_visit(node) + + visitor = ComplexityVisitor() + visitor.visit(node) + return visitor.complexity # ---------------------------------------------------------------------------