Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 44 additions & 53 deletions src/api/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# --------------------------------------------------------------------

import symtable
from collections.abc import Generator
from collections.abc import Callable, Generator
from typing import Any, NamedTuple

import src.api.check as chk
Expand All @@ -17,28 +17,17 @@
from src.api import errmsg
from src.api.config import OPTIONS
from src.api.constants import CLASS, CONVENTION, SCOPE, TYPE
from src.api.debug import __DEBUG__
from src.api.errmsg import warning_not_used
from src.ast import Ast, NodeVisitor
from src.symbols import sym as symbols
from src.symbols.id_ import ref


class ToVisit(NamedTuple):
"""Used just to signal an object to be
traversed.
"""

obj: symbols.SYMBOL


class GenericVisitor(NodeVisitor):
"""A slightly different visitor, that just traverses an AST, but does not return
a translation of it. Used to examine the AST or do transformations
"""

node_type = ToVisit

@property
def O_LEVEL(self):
return OPTIONS.optimization_level
Expand All @@ -58,36 +47,42 @@ def TYPE(type_):
assert TYPE.is_valid(type_)
return gl.SYMBOL_TABLE.basic_types[type_]

def visit(self, node):
return super().visit(ToVisit(node))

def _visit(self, node: ToVisit):
if node.obj is None:
return None

__DEBUG__(f"Optimizer: Visiting node {node.obj!s}[{node.obj.token}]", 1)
meth = getattr(self, f"visit_{node.obj.token}", self.generic_visit)
return meth(node.obj)

def generic_visit(self, node: Ast) -> Generator[Ast | None, Any, None]:
for i, child in enumerate(node.children):
node.children[i] = yield self.visit(child)

yield node


class UniqueVisitor(GenericVisitor):
def __init__(self):
super().__init__()
self.visited = set()

def _visit(self, node: ToVisit):
if node.obj in self.visited:
return node.obj
def _visit(self, node: Ast):
if node in self.visited:
return node

self.visited.add(node.obj)
self.visited.add(node)
return super()._visit(node)

def filter_inorder(
self,
node,
filter_func: Callable[[Any], bool],
child_selector: Callable[[Ast], bool] = lambda x: True,
) -> Generator[Ast, None, None]:
"""Visit the tree inorder, but only those that return true for filter_func and visiting children which
return true for child_selector.
"""
visited = set()
stack = [node]
while stack:
node = stack.pop()
if node in visited:
continue

visited.add(node)
if filter_func(node):
yield self.visit(node)

if isinstance(node, Ast) and child_selector(node):
stack.extend(node.children[::-1])


class UnreachableCodeVisitor(UniqueVisitor):
"""Visitor to optimize unreachable code (and prune it)."""
Expand All @@ -107,7 +102,7 @@ def visit_FUNCTION(self, node: symbols.ID):
if type_ is not None and type_ == self.TYPE(TYPE.string):
node.body.append(symbols.ASM("\nld hl, 0\n", lineno, node.filename, is_sentinel=True))

yield (yield self.generic_visit(node))
yield self.generic_visit(node)

def visit_BLOCK(self, node):
# Remove CHKBREAK after labels
Expand Down Expand Up @@ -155,7 +150,7 @@ def visit_BLOCK(self, node):
yield self.NOP
return

yield (yield self.generic_visit(node))
yield self.generic_visit(node)


class FunctionGraphVisitor(UniqueVisitor):
Expand All @@ -165,6 +160,7 @@ def _get_calls_from_children(self, node: symtable.Symbol):
return list(self.filter_inorder(node, lambda x: x.token in ("CALL", "FUNCCALL")))

def _set_children_as_accessed(self, node: symbols.SYMBOL):
""" "Traverse only those"""
parent = node.get_parent(symbols.FUNCDECL)
if parent is None: # Global scope?
for symbol in self._get_calls_from_children(node):
Expand Down Expand Up @@ -314,7 +310,7 @@ def visit_FUNCDECL(self, node):
if self.O_LEVEL > 1 and node.params_size == node.locals_size == 0:
node.entry.ref.convention = CONVENTION.fastcall

node.children[1] = yield ToVisit(node.entry)
node.children[1] = yield self.visit(node.entry)
yield node

def visit_LET(self, node):
Expand Down Expand Up @@ -370,19 +366,20 @@ def visit_RETURN(self, node):
might cause infinite recursion.
"""
if len(node.children) == 2:
node.children[1] = yield ToVisit(node.children[1])
node.children[1] = yield self.visit(node.children[1])

yield node

def visit_UNARY(self, node):
if node.operator == "ADDRESS":
yield (yield self.visit_ADDRESS(node))
yield self.visit_ADDRESS(node)
else:
yield (yield self.generic_visit(node))
yield self.generic_visit(node)

def visit_IF(self, node):
expr_ = yield ToVisit(node.children[0])
then_ = yield ToVisit(node.children[1])
else_ = (yield ToVisit(node.children[2])) if len(node.children) == 3 else self.NOP
expr_ = yield self.visit(node.children[0])
then_ = yield self.visit(node.children[1])
else_ = (yield self.visit(node.children[2])) if len(node.children) == 3 else self.NOP

if self.O_LEVEL >= 1:
if chk.is_null(then_, else_):
Expand All @@ -405,6 +402,7 @@ def visit_IF(self, node):

for i in range(len(node.children)):
node.children[i] = (expr_, then_, else_)[i]

yield node

def visit_WHILE(self, node):
Expand All @@ -419,6 +417,7 @@ def visit_WHILE(self, node):

for i, child in enumerate((expr_, body_)):
node.children[i] = child

yield node

def visit_FOR(self, node):
Expand All @@ -433,6 +432,7 @@ def visit_FOR(self, node):
if from_.value > to_.value and step_.value > 0:
yield self.NOP
return

if from_.value < to_.value and step_.value < 0:
yield self.NOP
return
Expand All @@ -446,12 +446,6 @@ def _visit_LABEL(self, node):
else:
yield node

def generic_visit(self, node: Ast):
for i, child in enumerate(node.children):
node.children[i] = yield ToVisit(child)

yield node

def _check_if_any_arg_is_an_array_and_needs_lbound_or_ubound(
self, params: symbols.PARAMLIST, args: symbols.ARGLIST
):
Expand Down Expand Up @@ -502,10 +496,7 @@ class VariableVisitor(GenericVisitor):
def generic_visit(self, node: Ast):
if node not in VariableVisitor._visited:
VariableVisitor._visited.add(node)
for i in range(len(node.children)):
node.children[i] = yield ToVisit(node.children[i])

yield node
yield super().generic_visit(node)

def has_circular_dependency(self, var_dependency: VarDependency) -> bool:
if var_dependency.dependency == VariableVisitor._original_variable:
Expand All @@ -532,7 +523,7 @@ def visit_var(entry):
if entry.token != "VAR":
for child in entry.children:
visit_var(child)
if child.token in ("FUNCTION", "LABEL", "VAR", "VARARRAY"):
if child.token in {"FUNCTION", "LABEL", "VAR", "VARARRAY"}:
result.add(VarDependency(parent=VariableVisitor._parent_variable, dependency=child))
return

Expand Down
11 changes: 9 additions & 2 deletions src/arch/z80/visitor/builtin_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ class BuiltinTranslator(TranslatorVisitor):

REQUIRES = backend.REQUIRES

def __init__(self, backend: backend.Backend, parent_visitor: TranslatorVisitor):
super().__init__(backend)
self.parent_visitor = parent_visitor

def visit(self, node):
return self.parent_visitor.visit(node)

# region STRING Functions
def visit_INKEY(self, node):
self.runtime_call(RuntimeLabel.INKEY, Type.string.size)
Expand Down Expand Up @@ -125,7 +132,7 @@ def visit_SQR(self, node):
# endregion

def visit_LBOUND(self, node):
yield node.operands[1]
yield self.visit(node.operands[1])
self.ic_param(gl.BOUND_TYPE, node.operands[1].t)
entry = node.operands[0]
if entry.scope == SCOPE.global_:
Expand All @@ -141,7 +148,7 @@ def visit_LBOUND(self, node):
self.runtime_call(RuntimeLabel.LBOUND, self.TYPE(gl.BOUND_TYPE).size)

def visit_UBOUND(self, node):
yield node.operands[1]
yield self.visit(node.operands[1])
self.ic_param(gl.BOUND_TYPE, node.operands[1].t)
entry = node.operands[0]
if entry.scope == SCOPE.global_:
Expand Down
4 changes: 2 additions & 2 deletions src/arch/z80/visitor/function_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class FunctionTranslator(Translator):
REQUIRES = backend.REQUIRES

def __init__(self, backend: Backend, function_list: list[symbols.ID]):
super().__init__(backend)
if function_list is None:
function_list = []
super().__init__(backend)

assert isinstance(function_list, list)
assert all(x.token == "FUNCTION" for x in function_list)
Expand Down Expand Up @@ -115,7 +115,7 @@ def visit_FUNCTION(self, node):
self.ic_lvard(local_var.offset, q)

for i in node.ref.body:
yield i
yield self.visit(i)

self.ic_label("%s__leave" % node.mangled)

Expand Down
Loading