From 7b912bd959736d2ce114693dc8d818c28ce0ac5d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 4 Mar 2026 16:59:19 -0800 Subject: [PATCH 1/5] [tune][xegpu] infrastructure for tuning, applied to XeGPU matmul example --- examples/xegpu/matmul.py | 52 ++-- lighthouse/dialects/__init__.py | 8 + lighthouse/dialects/smt_ext.py | 93 +++++++ lighthouse/dialects/transform_smt_ext.py | 244 +++++++++++++++++ lighthouse/dialects/transform_tune_ext.py | 182 +++++++++++++ lighthouse/schedule/xegpu/mlp_schedule.py | 275 ++++++++++++++----- lighthouse/tune/__main__.py | 92 +++++++ lighthouse/tune/enumerate.py | 20 ++ lighthouse/tune/rewrite.py | 20 ++ lighthouse/tune/trace.py | 315 ++++++++++++++++++++++ lighthouse/utils/types.py | 24 ++ lighthouse/workload/workload.py | 7 +- 12 files changed, 1241 insertions(+), 91 deletions(-) create mode 100644 lighthouse/dialects/__init__.py create mode 100644 lighthouse/dialects/smt_ext.py create mode 100644 lighthouse/dialects/transform_smt_ext.py create mode 100644 lighthouse/dialects/transform_tune_ext.py create mode 100644 lighthouse/tune/__main__.py create mode 100644 lighthouse/tune/enumerate.py create mode 100644 lighthouse/tune/rewrite.py create mode 100644 lighthouse/tune/trace.py create mode 100644 lighthouse/utils/types.py diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 24fd224..50b63ce 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -17,6 +17,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -177,13 +178,13 @@ def payload_module(self) -> ir.Module: def schedule_module( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> ir.Module: + assert parameters is not None, "Schedule parameters must be provided" return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, has_convert_c=False, stop_at_stage=stop_at_stage, - nlayers=1, - params={"layer_0": parameters}, + params=[parameters], ) def shared_libs(self) -> list[str]: @@ -195,6 +196,9 @@ def parse_cli(): description="Matrix Multiplication using MLIR", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument( + "--all-knobs", action="store_true", help="Use knobs for all schedule parameters" + ) parser.add_argument( "--sizes", type=int, @@ -226,28 +230,28 @@ def parse_cli(): "--load-tile-a", type=int, nargs=2, - default=[32, 16], + default=[16, 32], help="Tile size for loading A matrix for DPAS op.", ) parser.add_argument( "--load-tile-b", type=int, nargs=2, - default=[32, 16], + default=[16, 16], help="Tile size for loading B matrix for DPAS op.", ) parser.add_argument( "--prefetch-tile-a", type=int, nargs=2, - default=[8, 32], + default=[16, 32], help="Tile size for cooperative prefetching of subgroup A matrix", ) parser.add_argument( "--prefetch-tile-b", type=int, nargs=2, - default=[8, 16], + default=[16, 16], help="Tile size for cooperative prefetching of subgroup B matrix", ) parser.add_argument( @@ -311,28 +315,34 @@ def parse_cli(): if __name__ == "__main__": args = parse_cli() + M, N, K = args.sizes + params = { - "wg_m": args.wg_tile[0], - "wg_n": args.wg_tile[1], - "sg_m": args.sg_tile[0], - "sg_n": args.sg_tile[1], - "k": args.k_tile, - "load_a_m": args.load_tile_a[0], - "load_a_k": args.load_tile_a[1], - "load_b_k": args.load_tile_b[0], - "load_b_n": args.load_tile_b[1], - "pf_a_m": args.prefetch_tile_a[0], - "pf_a_k": args.prefetch_tile_a[1], - "pf_b_k": args.prefetch_tile_b[0], - "pf_b_n": args.prefetch_tile_b[1], - "pf_nb": args.nb_prefetch, + "m": M, + "n": N, + "k": K, + "wg_m": None if args.all_knobs else args.wg_tile[0], + "wg_n": None if args.all_knobs else args.wg_tile[1], + "sg_m": None if args.all_knobs else args.sg_tile[0], + "sg_n": None if args.all_knobs else args.sg_tile[1], + "k_tile": None if args.all_knobs else args.k_tile, + "load_a_m": None if args.all_knobs else args.load_tile_a[0], + "load_a_k": None if args.all_knobs else args.load_tile_a[1], + "load_b_k": None if args.all_knobs else args.load_tile_b[0], + "load_b_n": None if args.all_knobs else args.load_tile_b[1], + "prefetch_a_m": None if args.all_knobs else args.prefetch_tile_a[0], + "prefetch_a_k": None if args.all_knobs else args.prefetch_tile_a[1], + "prefetch_b_k": None if args.all_knobs else args.prefetch_tile_b[0], + "prefetch_b_n": None if args.all_knobs else args.prefetch_tile_b[1], + "nb_prefetch": args.nb_prefetch, } - M, N, K = args.sizes ab_type = "f16" c_type = "f32" with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMatMul( M=M, N=N, diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py new file mode 100644 index 0000000..450585f --- /dev/null +++ b/lighthouse/dialects/__init__.py @@ -0,0 +1,8 @@ +def register_and_load(): + from . import smt_ext + from . import transform_smt_ext + from . import transform_tune_ext + + smt_ext.register_and_load() + transform_smt_ext.register_and_load() + transform_tune_ext.register_and_load() diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/smt_ext.py new file mode 100644 index 0000000..fdd9930 --- /dev/null +++ b/lighthouse/dialects/smt_ext.py @@ -0,0 +1,93 @@ +from typing import Callable + +from mlir import ir +from mlir.dialects import smt + +__all__ = ["SMTIntValue", "assert_", "register_and_load"] + + +def register_and_load(context=None): + SMTIntValue.register_value_caster() + + +def assert_(predicate: ir.Value[smt.BoolType] | bool): + """Assert normally if a bool else produce an SMT assertion op.""" + if isinstance(predicate, bool): + assert predicate + else: + smt.assert_(predicate) + + +def int_to_smt(operand: "int | SMTIntValue") -> "SMTIntValue": + if isinstance(operand, int): + int_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), operand) + return SMTIntValue(smt.int_constant(int_attr)) + return operand + + +def swapped( + f: Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"], +) -> Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"]: + return lambda a, b: f(b, a) + + +class SMTIntValue(ir.Value[smt.IntType]): + def __init__(self, v): + super().__init__(v) + + def __hash__(self): + return super().__hash__() + + @staticmethod + def register_value_caster(): + if not hasattr(SMTIntValue, "_is_registered"): + ir.register_value_caster(smt.IntType.static_typeid)(SMTIntValue) + setattr(SMTIntValue, "_is_registered", True) + + def __add__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_add([self, int_to_smt(rhs)])) + + def __radd__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_add([int_to_smt(lhs), self])) + + def __sub__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_sub(self, int_to_smt(rhs))) + + def __rsub__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_sub(int_to_smt(lhs), self)) + + def __mul__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mul([self, int_to_smt(rhs)])) + + def __rmul__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mul([int_to_smt(lhs), self])) + + def __floordiv__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_div(self, int_to_smt(rhs))) + + def __rfloordiv__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_div(int_to_smt(lhs), self)) + + def __mod__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mod(self, int_to_smt(rhs))) + + def __rmod__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mod(int_to_smt(lhs), self)) + + def __eq__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.eq([self, int_to_smt(rhs)]) + + def __le__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.le, self, int_to_smt(rhs)) + + def __lt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.lt, self, int_to_smt(rhs)) + + def __ge__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.ge, self, int_to_smt(rhs)) + + def __gt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.gt, self, int_to_smt(rhs)) + + def __str__(self): + return super().__str__().replace(ir.Value.__name__, SMTIntValue.__name__) diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform_smt_ext.py new file mode 100644 index 0000000..4935c34 --- /dev/null +++ b/lighthouse/dialects/transform_smt_ext.py @@ -0,0 +1,244 @@ +from typing import overload, Sequence, Callable + +from mlir import ir +from mlir.dialects import ext, smt, transform + +from lighthouse.tune import trace + +__all__ = [ + "ConstrainParamsOp", + "TransformSMTDialectExtension", + "constrain_params", + "register_and_load", +] + +def register_and_load(context=None): + TransformSMTDialectExtension.load() + + +class TransformSMTDialectExtension(ext.Dialect, name="transform_smt_ext"): + @classmethod + def load(cls, *args, **kwargs): + super(TransformSMTDialectExtension, cls).load(*args, **kwargs) + + for op in cls.operations: + if hasattr(op, "attach_interfaces"): + op.attach_interfaces() + + +class ConstrainParamsOp( + TransformSMTDialectExtension.Operation, name="constrain_params" +): + results_: Sequence[ext.Result[transform.AnyParamType]] + params: Sequence[ext.Operand[transform.AnyParamType]] + body_: ext.Region + + @property + def body(self): + return self.body_.blocks[0] + + @classmethod + def attach_interfaces(cls, ctx=None): + if not hasattr(cls, "_interfaces_attached"): + cls.ConstrainParamsTransformOpInterfaceModel.attach( + cls.OPERATION_NAME, context=ctx + ) + cls.ConstrainParamsMemoryEffectsOpInterfaceModel.attach( + cls.OPERATION_NAME, context=ctx + ) + setattr(cls, "_interfaces_attached", True) + + class ConstrainParamsTransformOpInterfaceModel(transform.TransformOpInterface): + @staticmethod + def apply( + op: "ConstrainParamsOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> transform.DiagnosedSilenceableFailure: + env = dict() + for operand in op.params: + params = state.get_params(operand) + assert len(params) == 1 and isinstance(params[0].value, int) + env[operand] = trace.Constant(params[0].value) + + env = trace.trace_tune_and_smt_ops(op.operation, env) + + if not env[op].evaluate(env): # evaluate the conjoined predicate + return transform.DiagnosedSilenceableFailure.DefiniteFailure + + for result in op.results: + res_value = env[result].evaluate(env) + i64 = ir.IntegerType.get_signless(64) + results.set_params(result, [ir.IntegerAttr.get(i64, res_value)]) + + return transform.DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ConstrainParamsOp") -> bool: + return False + + class ConstrainParamsMemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "ConstrainParamsOp", effects): + if op.op_operands: + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.only_reads_payload(effects) + + +class MixedResultConstrainParamsOp(ConstrainParamsOp): + def __init__( + self, + *args, + result_values_or_types: Sequence[int | ir.Type], + **kwargs, + ): + result_types = [ + res for res in result_values_or_types if isinstance(res, ir.Type) + ] + super().__init__(result_types, *args, **kwargs) + op_results = iter(super().results) + self._results = [ + next(op_results) if isinstance(res, ir.Type) else res + for res in result_values_or_types + ] + + @property + def results(self) -> Sequence[int | ext.Result[transform.AnyParamType]]: + return self._results + + +# class ConstrainParamsOpDecorator(ConstrainParamsOp): +# def __init__( +# self, +# *params: transform.AnyParamType | int, +# results: Sequence[int | ext.Result[transform.AnyParamType]] | None = None, +# **kwargs, +# ): +# transform_params = [p for p in params if isinstance(p, ir.Value)] +# super().__init__([], transform_params, **kwargs) +# block_arg_types = [smt.IntType.get()] * len(transform_params) +# self.body_.blocks.append(*block_arg_types) +# +# self._arguments = [] +# self._results = results +# smt_arguments = iter(self.body.arguments) +# for param in params: +# if isinstance(param, int): +# self._arguments.append(param) +# else: +# self._arguments.append(next(smt_arguments)) +# +# @property +# def results(self) -> Sequence[ext.Result | int]: +# """Returns the yielded results of the decorated function, which are either +# integers or the transform parameters that correspond to the yielded SMT +# int values.""" +# assert self._results is not None, ( +# "Results are not available until the decorated function is called" +# ) +# return self._results +# +# def __call__(self, func): +# with ir.InsertionPoint(self.body): +# yielded_results = func(*self._arguments) +# +# smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) +# +# print(f"{yielded_results=}") +# if len(yielded_results) == 0: +# return self +# +# # In case of yielded results, we need to create a new ConstrainParamsOp with the same parameters and a body that contains the original body of the decorator, but with the yielded results as the results of the new op. We then replace the original op with the new one and return it. +# result_types = [transform.AnyParamType.get()] * sum( +# 1 for res in yielded_results if isinstance(res, ir.Value) +# ) +# with ir.InsertionPoint(self): +# self_with_results = ConstrainParamsOp( +# result_types, self.params, loc=self.location +# ) +# self.body_.blocks[0].append_to(self_with_results.body_) +# # new_block = self_with_results.body_.blocks.append( +# # *orig_block.arguments.types +# # ) +# # arg_mapping = dict(zip(orig_block.arguments, new_block.arguments)) +# # lh_utils_rewrite.move_block(orig_block, new_block, arg_mapping) +# # self.erase() +# +# results = [] +# op_results = iter(self_with_results.results) +# for yielded_result in yielded_results: +# if isinstance(yielded_result, int): +# results.append(yielded_result) +# elif isinstance(yielded_result, ir.Value): +# results.append(next(op_results)) +# else: +# assert False, "Unsupported yielded result type" +# setattr(self_with_results, "_results", results) +# return self_with_results + + +@overload +def constrain_params( + *params: ir.Value | int, loc=None, ip=None +) -> Callable[..., MixedResultConstrainParamsOp]: ... + + +@overload +def constrain_params( + results: Sequence[ir.Type], + params: Sequence[transform.AnyParamType], + arg_types: Sequence[ir.Type], + loc=None, + ip=None, +) -> ConstrainParamsOp: ... + + +def constrain_params( + *args, **kwargs +) -> ConstrainParamsOp | Callable[..., MixedResultConstrainParamsOp]: + # The second overload: + if len(args) == 0 or isinstance(args[0], ir.Type): + arg_types = kwargs.pop("arg_types") + op = ConstrainParamsOp(*args, **kwargs) + op.body_.blocks.append(*arg_types) + return op + + # The first overload: + # return ConstrainParamsOpDecorator(*args, **kwargs) + def wrapper(func): + param_args = [p for p in args if isinstance(p, ir.Value)] + constrain_params = ConstrainParamsOp([], param_args, **kwargs) + constrain_params.body_.blocks.append(*[smt.IntType.get()] * len(param_args)) + + block_args_iter = iter(constrain_params.body_.blocks[0].arguments) + with ir.InsertionPoint(constrain_params.body): + yielded_results = func( + *( + next(block_args_iter) if isinstance(arg, ir.Value) else arg + for arg in args + ) + ) + if not isinstance(yielded_results, Sequence): + yielded_results = [yielded_results] + smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) + + if len(yielded_results) == 0: + return constrain_params + + result_values_or_types = [ + transform.AnyParamType.get() if isinstance(res, ir.Value) else res + for res in yielded_results + ] + + mixed_result_op = MixedResultConstrainParamsOp( + params=param_args, result_values_or_types=result_values_or_types, **kwargs + ) + # Move the body of the original op to the version with (mixed) results. + constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) + # Safe to remove as the op doesn't have results, so no users either. + constrain_params.erase() + return mixed_result_op + + return wrapper diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform_tune_ext.py new file mode 100644 index 0000000..4905036 --- /dev/null +++ b/lighthouse/dialects/transform_tune_ext.py @@ -0,0 +1,182 @@ +import inspect +import re +import ast +import math +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional +from functools import wraps +from operator import mod + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import tune as transform_tune + +__all__ = ["KnobValue", "knob"] + + +def register_and_load(context=None): + pass # NB: currently nothing to register or load. + + +def knob( + *args, + result: Optional[ir.Type] = None, + **kwargs, +) -> "KnobValue": + options = ir.DictAttr.get() + result = result or transform.AnyParamType.get() + return KnobValue( + transform_tune.KnobOp(result, *args, options=options, **kwargs).result + ) + + +def update_knob_options(knob: transform_tune.KnobOp, key, value): + items = list((namedattr.name, namedattr.attr) for namedattr in knob.options) + if isinstance(value, int): + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + items.append((key, value)) + knob.options = ir.DictAttr.get(dict(items)) + + +class KnobValue(ir.Value): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def in_(self, options): + i64 = ir.IntegerType.get_signless(64) + options_attr = ir.ArrayAttr.get([ir.IntegerAttr.get(i64, v) for v in options]) + + assert ( + isinstance(self.owner.options, ir.DictAttr) and len(self.owner.options) == 0 + ) # Only one constraint supported for now. + self.owner.options = ir.DictAttr.get({"options": options_attr}) + return True + + @staticmethod + def ast_rewrite(in_exprs: bool = False): + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + func_source = inspect.getsource(func) + indent = math.inf + for line in func_source.splitlines(): + indent = min(indent, len(re.match(" *", line).group(0))) + func_source = "\n".join(line[indent:] for line in func_source.splitlines()) + func_ast = ast.parse(func_source) + func_def_ast = func_ast.body[0] + + # TODO: carefully remove just the @KnobValue.ast_rewrite decorator in case of multiple decorators. + func_def_ast.decorator_list.clear() # Remove the decorator to avoid infinite recursion. + if in_exprs: + func_def_ast.body = [ + InTransformer().visit(stmt) for stmt in func_def_ast.body + ] + ast.fix_missing_locations(func_def_ast) + mod = compile(ast.unparse(func_ast), filename="", mode="exec") + frame = inspect.currentframe() + assert frame and frame.f_back + temp_globals = frame.f_back.f_globals.copy() + temp_globals |= frame.f_back.f_locals.copy() + temp_locals = frame.f_back.f_locals.copy() + temp_globals["In"] = In + exec(mod, temp_globals, temp_locals) + return temp_locals[func.__name__](*args, **kwargs) + + return wrapper + + return decorator + + def _set_bound(self, key, combine, value): + assert isinstance(self.owner.options, ir.DictAttr) + existing = self.owner.options[key] if key in self.owner.options else None + update_knob_options( + self.owner.opview, + key, + value if existing is None else combine(existing.value, value), + ) + return True + + def __mod__(self, other): + assert isinstance(other, int) + return KnobExpression(lhs=self, rhs=other, operator=mod) + + def __rmod__(self, other): + assert isinstance(other, int) + return KnobExpression(lhs=other, rhs=self, operator=mod) + + def __lt__(self, other): + assert isinstance(other, int) + return self._set_bound("upper_bound", min, other + 1) + + def __le__(self, other): + assert isinstance(other, int) + return self._set_bound("upper_bound", min, other) + + def __ge__(self, other): + assert isinstance(other, int) + return self._set_bound("lower_bound", max, other) + + def __gt__(self, other): + assert isinstance(other, int) + return self._set_bound("lower_bound", max, other + 1) + + def __eq__(self, other): + assert isinstance(other, int) + assert isinstance(self.owner.options, ir.DictAttr) + assert len(self.owner.options) == 0, "Only one constraint supported for now." + i64 = ir.IntegerType.get_signless(64) + update_knob_options( + self.owner.opview, + "options", + ir.ArrayAttr.get([ir.IntegerAttr.get(i64, other)]), + ) + return True + + +@dataclass +class KnobExpression: + lhs: KnobValue | int + rhs: KnobValue | int + operator: Literal[mod] + + def __eq__(self, other): + assert other == 0, "Only equality to zero supported for now." + assert self.operator is mod + i64 = ir.IntegerType.get_signless(64) + if isinstance(self.lhs, KnobValue): + assert isinstance(self.lhs.owner.options, ir.DictAttr) + assert isinstance(self.rhs, int) + assert "divisible_by" not in self.lhs.owner.options + update_knob_options(self.lhs.owner.opview, "divisible_by", self.rhs) + elif isinstance(self.rhs, KnobValue): + assert isinstance(self.lhs, int) + assert isinstance(self.rhs.owner.options, ir.DictAttr) + assert "divides" not in self.rhs.owner.options + update_knob_options(self.rhs.owner.opview, "divides", self.lhs) + else: + assert False, "At least one operand must be a KnobValue." + + return True + + +@dataclass +class In: + lhs: Any + rhs: Any + + def __bool__(self): + if isinstance(self.lhs, KnobValue): + return self.lhs.in_(self.rhs) + return self.lhs in self.rhs + + +class InTransformer(ast.NodeTransformer): + def visit_Compare(self, node: ast.Compare) -> Any: + self.generic_visit(node) + if len(node.ops) == 1 and isinstance(node.ops[0], ast.In): + return ast.Call( + func=ast.Name(id="In", ctx=ast.Load()), + args=[node.left, node.comparators[0]], + keywords=[], + ) + return node diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 92fdd9c..35f63b6 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -1,3 +1,5 @@ +from collections import namedtuple + from mlir import ir from mlir.dialects.transform import loop from mlir.dialects.transform import bufferization @@ -10,7 +12,9 @@ canonicalize, match, ) -from typing import Optional + +from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext +from lighthouse.dialects.transform_tune_ext import knob, KnobValue class PipelineInterrupt(Exception): @@ -30,19 +34,76 @@ def match_and_split(*args, nhandles=1, **kwargs): # hardware constraints -DPAS_TILE = [8, 16, 16] +DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( + 8, 16, 16, (8, 16), (16, 16), (8, 16) +) PREFETCH_INST_DATA = [8, 16] NB_WORKITEMS = 16 # workitems in subgroup +LOAD_TILE_SIZES = [8, 16, 32] + + +@KnobValue.ast_rewrite(in_exprs=True) +def checked_params_or_knobs( + params: dict[str, int | None], layer_id="" +) -> dict[str, int | KnobValue]: + """Check the parameters for validity and replace `None`s with knobs with asserted ranges.""" + m, n, k = params["m"], params["n"], params["k"] + assert isinstance(m, int) and isinstance(n, int) and isinstance(k, int) + assert m > 0 and n > 0 and k > 0 + wg_m = params["wg_m"] or knob(layer_id + "wg_m") + wg_n = params["wg_n"] or knob(layer_id + "wg_n") + sg_m = params["sg_m"] or knob(layer_id + "sg_m") + sg_n = params["sg_n"] or knob(layer_id + "sg_n") + k_tile = params["k_tile"] or knob(layer_id + "k_tile") + load_a_m = params["load_a_m"] or knob(layer_id + "load_a_m") + load_a_k = params["load_a_k"] or knob(layer_id + "load_a_k") + load_b_k = params["load_b_k"] or knob(layer_id + "load_b_k") + load_b_n = params["load_b_n"] or knob(layer_id + "load_b_n") + prefetch_a_m = params["prefetch_a_m"] or knob(layer_id + "prefetch_a_m") + prefetch_a_k = params["prefetch_a_k"] or knob(layer_id + "prefetch_a_k") + prefetch_b_k = params["prefetch_b_k"] or knob(layer_id + "prefetch_b_k") + prefetch_b_n = params["prefetch_b_n"] or knob(layer_id + "prefetch_b_n") + + # NB: Constraints on knobs will be added as attributes on the KnobOps, while + # constraints on concrete values will be checked immediately. + assert 64 <= wg_m <= 256 and m % wg_m == 0 and wg_m % DPAS.M == 0 + assert 64 <= wg_n <= 256 and n % wg_n == 0 and wg_n % DPAS.N == 0 + assert 32 <= sg_m <= 128 and m % sg_m == 0 and sg_m % DPAS.M == 0 + assert 32 <= sg_n <= 128 and n % sg_n == 0 and sg_n % DPAS.N == 0 + assert 16 <= k_tile <= 50 and k % k_tile == 0 and k_tile % DPAS.K == 0 + assert load_a_m in LOAD_TILE_SIZES and load_a_m % DPAS.M == 0 + assert load_a_k in LOAD_TILE_SIZES and load_a_k % DPAS.K == 0 + assert load_b_k in LOAD_TILE_SIZES and load_b_k % DPAS.K == 0 + assert load_b_n in LOAD_TILE_SIZES and load_b_n % DPAS.N == 0 + assert prefetch_a_m in LOAD_TILE_SIZES + assert prefetch_a_k in LOAD_TILE_SIZES + assert prefetch_b_k in LOAD_TILE_SIZES + assert prefetch_b_n in LOAD_TILE_SIZES + + return { + "wg_m": wg_m, + "wg_n": wg_n, + "sg_m": sg_m, + "sg_n": sg_n, + "k_tile": k_tile, + "load_a_m": load_a_m, + "load_a_k": load_a_k, + "load_b_k": load_b_k, + "load_b_n": load_b_n, + "prefetch_a_m": prefetch_a_m, + "prefetch_a_k": prefetch_a_k, + "prefetch_b_k": prefetch_b_k, + "prefetch_b_n": prefetch_b_n, + } def get_schedule_module( + params: list[dict[str, int | None]], has_bias: bool = False, has_relu: bool = False, has_convert_c: bool = True, skip_final_layer_relu: bool = False, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[dict] = None, ) -> ir.Module: """Generate transform schedule module.""" mod = ir.Module.create() @@ -64,41 +125,43 @@ def get_schedule_module( op_name="builtin.module", deduplicate=True, ) + for i, layer_params in enumerate(params): + layer_params |= checked_params_or_knobs( + layer_params, layer_id=f"layer_{i}_" + ) + xegpu_mlp_transform_schedule( payload_mod, + params=params, has_bias=has_bias, has_relu=has_relu, has_convert_c=has_convert_c, skip_final_layer_relu=skip_final_layer_relu, stop_at_stage=stop_at_stage, - nlayers=nlayers, - params=params, ) return mod def xegpu_mlp_transform_schedule( - mod: ir.Value, + mod: ir.Value[transform.AnyOpType], + params: list[dict[str, int | KnobValue]], has_bias: bool = False, has_relu: bool = False, has_convert_c: bool = True, skip_final_layer_relu: bool = False, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[list[dict]] = None, ): """Transform schedule for MLP-like payload.""" try: mod = bundle_xegpu_mlp_schedule( mod, + params=params, has_bias=has_bias, has_relu=has_relu, has_convert_c=has_convert_c, skip_final_layer_relu=skip_final_layer_relu, stop_at_stage=stop_at_stage, - nlayers=nlayers, - params=params, ) mod = bundle_xegpu_to_binary( @@ -112,27 +175,22 @@ def xegpu_mlp_transform_schedule( def bundle_xegpu_mlp_schedule( - mod: ir.Value, + mod: ir.Value[transform.AnyOpType], + params: list[dict[str, int | KnobValue]], has_bias: bool = False, has_relu: bool = False, skip_final_layer_relu: bool = False, has_convert_c: bool = True, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[list[dict]] = None, -) -> ir.Module: +) -> ir.Value[transform.AnyOpType]: """Schedule for lowering MLP-like payload to xegpu wg level.""" - if params is None: - raise ValueError("Schedule parameters must be provided.") + nlayers = len(params) if stop_at_stage == "initial": raise PipelineInterrupt() anytype = transform.AnyOpType.get() - for i in range(nlayers): - assert f"layer_{i}" in params, f"Missing parameters for 'layer_{i}'" - # wg tiling if has_convert_c: trunc_op = match(mod, ops={"arith.truncf"}) @@ -156,17 +214,14 @@ def bundle_xegpu_mlp_schedule( terminal_ops = list(relu_ops) + [terminal_ops[-1]] # tile each layer separately - for i_layer in range(nlayers): - layer_params = params[f"layer_{i_layer}"] + for terminal_op, layer_params in zip(terminal_ops, params): # tunable parameters: wg level tiling wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] - sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] - k_tile = layer_params["k"] + k_tile = layer_params["k_tile"] - terminal = terminal_ops[i_layer] - # FIXME use structured.structured_fuse + # FIXME: use structured.structured_fuse _, wg_loop = structured.FuseOp( - terminal, tile_sizes=wg_tile, use_forall=True + terminal_op, tile_sizes=wg_tile, use_forall=True ).results transform.apply_cse(mod) canonicalize(mod) @@ -241,16 +296,35 @@ def bundle_xegpu_mlp_schedule( # set correct number of gpu threads launch_ops = match_and_split(mod, ops={"gpu.launch"}, nhandles=nlayers) - for i_layer, launch_op in enumerate(launch_ops): - layer_params = params[f"layer_{i_layer}"] + assert len(launch_ops) == nlayers + for launch_op, layer_params in zip(launch_ops, params): # tunable parameters - wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] - sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] - - # derived parameters - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] - # number of threads collapsed to 1d layout - nb_threads = sg_layout[0] * sg_layout[1] * NB_WORKITEMS + wg_m, wg_n = layer_params["wg_m"], layer_params["wg_n"] + sg_m, sg_n = layer_params["sg_m"], layer_params["sg_n"] + + @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + def constrain_wg_sg_and_calc_nb_threads( + WG_M: int | smt_ext.SMTIntValue, + WG_N: int | smt_ext.SMTIntValue, + SG_M: int | smt_ext.SMTIntValue, + SG_N: int | smt_ext.SMTIntValue, + ): + # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values. + smt_ext.assert_(WG_M % SG_M == 0) + smt_ext.assert_(WG_N % SG_N == 0) + + # NB: normal ints in case of concrete values, SMT int values for symbolic values. + sg_m_threads = WG_M // SG_M + sg_n_threads = WG_N // SG_N + sg_threads = sg_m_threads * sg_n_threads + smt_ext.assert_(sg_threads <= 64) + + # number of threads collapsed to 1d layout + return sg_threads * NB_WORKITEMS + + nb_threads: int | smt_ext.SMTIntValue = ( + constrain_wg_sg_and_calc_nb_threads.results + ) xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) @@ -278,11 +352,12 @@ def bundle_xegpu_mlp_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - for i_layer, gpu_mod in enumerate(gpu_mod_ops): + assert ( + len(gpu_mod_ops) == nlayers + ), "Expected one gpu.module per MLP layer after outlining" + for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) - xegpu_wg_annotation_for_mlp_layer( - gpu_func, params[f"layer_{i_layer}"], has_bias=has_bias - ) + xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params, has_bias=has_bias) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() @@ -290,44 +365,106 @@ def bundle_xegpu_mlp_schedule( return mod -def xegpu_wg_annotation_for_mlp_layer(gpu_func: ir.Value, params: dict, has_bias: bool): +def xegpu_wg_annotation_for_mlp_layer( + gpu_func: ir.Value, + *, + wg_m: int | KnobValue, + wg_n: int | KnobValue, + sg_m: int | KnobValue, + sg_n: int | KnobValue, + k_tile: int | KnobValue, + load_a_m: int | KnobValue, + load_a_k: int | KnobValue, + load_b_k: int | KnobValue, + load_b_n: int | KnobValue, + prefetch_a_m: int | KnobValue, + prefetch_a_k: int | KnobValue, + prefetch_b_k: int | KnobValue, + prefetch_b_n: int | KnobValue, + nb_prefetch: int, + has_bias: bool, + **_catch_all, +): """ Adds prefetching and XeGPU anchor layout annotations for an MLP layer. Should be applied after the payload has been converted to XeGPU using the convert-vector-to-xegpu pass. """ + anytype = transform.AnyOpType.get() anyvalue = transform.AnyValueType.get() - dpas_shape_a = [DPAS_TILE[0], DPAS_TILE[2]] - dpas_shape_b = [DPAS_TILE[2], DPAS_TILE[1]] - dpas_shape_c = [DPAS_TILE[0], DPAS_TILE[1]] + # Calculate with SMT ops in case of symbolic values, normal ints in case of concrete values. + @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): + return WG_M // SG_M, WG_N // SG_N + + sg_layout = calc_sg_layout.results + + load_tile_a = [load_a_m, load_a_k] + load_tile_b = [load_b_k, load_b_n] + prefetch_tile_a = [prefetch_a_m, prefetch_a_k] + prefetch_tile_b = [prefetch_b_k, prefetch_b_n] + + @td_smt_ext.constrain_params( + sg_m, + sg_n, + k_tile, + load_a_m, + load_a_k, + load_b_k, + load_b_n, + prefetch_a_m, + prefetch_a_k, + prefetch_b_k, + prefetch_b_n, + ) + def constrain_and_calculate_load_and_prefetch_params( + SG_M, SG_N, K_TILE, LDA_M, LDA_K, LDB_K, LDB_N, PFA_M, PFA_K, PFB_K, PFB_N + ): + # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values + # TODO: Tuomas' comments explaining constraints: + smt_ext.assert_(SG_M % PFA_M == 0) + smt_ext.assert_(SG_M % LDA_M == 0) + + smt_ext.assert_(SG_N % PFB_N == 0) + smt_ext.assert_(SG_N % LDB_N == 0) + smt_ext.assert_(K_TILE % PFA_K == 0) + smt_ext.assert_(K_TILE % PFB_K == 0) + smt_ext.assert_(K_TILE % LDA_K == 0) + smt_ext.assert_(K_TILE % LDB_K == 0) + + smt_ext.assert_(LDA_M * LDA_K >= 16 * 16) + smt_ext.assert_(LDB_K * LDB_N >= 16 * 16) + + smt_ext.assert_(LDA_M <= LDA_K) + smt_ext.assert_(LDB_K <= LDB_N) + smt_ext.assert_(LDB_N == DPAS.N) + + PFA_M_step = SG_M // PFA_M + PFA_K_step = K_TILE // PFA_K + smt_ext.assert_(PFA_M_step * PFA_K_step <= 64) + + PFB_K_step = K_TILE // PFB_K + PFB_N_step = SG_N // PFB_N + smt_ext.assert_(PFB_K_step * PFB_N_step <= 64) - wg_tile = [params["wg_m"], params["wg_n"]] - sg_tile = [params["sg_m"], params["sg_n"]] - k_tile = params["k"] + smt_ext.assert_(PFA_M * PFA_K >= 16 * 16) + smt_ext.assert_(PFA_M >= PFA_K) - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] + smt_ext.assert_(PFB_K * PFB_N >= 16 * 16) + smt_ext.assert_(PFB_K >= PFB_N) + smt_ext.assert_((SG_M // DPAS.M) * (SG_N // DPAS.N) * (K_TILE // DPAS.K) <= 64) - load_tile_a = [params["load_a_m"], params["load_a_k"]] - load_tile_b = [params["load_b_k"], params["load_b_n"]] - prefetch_tile_a = [params["pf_a_m"], params["pf_a_k"]] - prefetch_tile_b = [params["pf_b_k"], params["pf_b_n"]] - nb_prefetch = params["pf_nb"] + return PFA_M_step, PFA_K_step, PFB_K_step, PFB_N_step - prefetch_layout_a = [ - wg_tile[0] // prefetch_tile_a[0], - k_tile // prefetch_tile_a[1], - ] - prefetch_layout_b = [ - k_tile // prefetch_tile_b[0], - wg_tile[1] // prefetch_tile_b[1], - ] + prefetch_layout_a = constrain_and_calculate_load_and_prefetch_params.results[0:2] + prefetch_layout_b = constrain_and_calculate_load_and_prefetch_params.results[2:4] # matmul matrix shapes - sg_tile_a = [sg_tile[0], k_tile] - sg_tile_b = [k_tile, sg_tile[1]] + sg_tile_a = [sg_m, k_tile] + sg_tile_b = [k_tile, sg_n] # add layouts to DPAS op operands k_loop = match(gpu_func, ops={"scf.for"}) @@ -382,7 +519,7 @@ def annotate_ab_load(tile, layout_load, layout_dpas): } # A tile dpas layout layout_dpas_a = layout_load_a.copy() - layout_dpas_a["inst_data"] = dpas_shape_a + layout_dpas_a["inst_data"] = DPAS.A_TILE annotate_ab_load(tile_a, layout_load_a, layout_dpas_a) # B tile load layout @@ -393,14 +530,14 @@ def annotate_ab_load(tile, layout_load, layout_dpas): } # B tile dpas layout layout_dpas_b = layout_load_b.copy() - layout_dpas_b["inst_data"] = dpas_shape_b + layout_dpas_b["inst_data"] = DPAS.B_TILE annotate_ab_load(tile_b, layout_load_b, layout_dpas_b) # C tile layout output_layout = { "sg_layout": sg_layout, - "sg_data": sg_tile, - "inst_data": dpas_shape_c, + "sg_data": [sg_m, sg_n], + "inst_data": DPAS.C_TILE, } # C tile dpas anchor layout xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) @@ -430,7 +567,9 @@ def annotate_ab_load(tile, layout_load, layout_dpas): transform.apply_cse(gpu_func) -def bundle_xegpu_to_binary(mod, stop_at_stage: str = "") -> ir.Module: +def bundle_xegpu_to_binary( + mod, stop_at_stage: str = "" +) -> ir.Value[transform.AnyOpType]: """Schedule for lowering xegpu wg level to binary.""" # upstream xegpu/xevm pipeline is payload independent. mod = apply_registered_pass( diff --git a/lighthouse/tune/__main__.py b/lighthouse/tune/__main__.py new file mode 100644 index 0000000..de3f916 --- /dev/null +++ b/lighthouse/tune/__main__.py @@ -0,0 +1,92 @@ +import sys +import argparse +from typing import Mapping + +from mlir import ir +from mlir.dialects import transform +from lighthouse.tune import ( + rewrite as lh_tune_rewrite, + trace as lh_tune_trace, + enumerate as lh_tune_enumerate, +) +from lighthouse import dialects as lh_dialects +from lighthouse.utils.types import LazyChainMap + +HEADER = "//" * 40 + "\n// {}\n" + "//" * 40 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("file", type=str, help="Path to the MLIR file to process") + parser.add_argument("--count-only", action="store_true") + parser.add_argument( + "--mode", + choices=["enumerate"], + default="enumerate", + help="Mode of operation", + ) + parser.add_argument( + "-n", type=int, help="Number of concrete schedules to output", default=1 + ) + args = parser.parse_args() + + file = sys.stdin if args.file == "-" else open(args.file, "r") + with ir.Context() as ctx, ir.Location.unknown(): + lh_dialects.register_and_load() + + module = ir.Module.parse(file.read()) + + if args.mode == "enumerate": + # Trace the named_seq, obtaining a DAG of tunable nodes and nodes + # which are functions and predicates dependent on the tunable nodes. + named_seq = module.body.operations[0].opview + assert isinstance(named_seq, transform.NamedSequenceOp) + op_or_value_to_node = lh_tune_trace.trace_tune_and_smt_ops( + named_seq.operation + ) + + # The predicate associated to the overall named_seq is the conjunct + # -ion of each of the predicates for each operation in seq's body. + overall_predicate = op_or_value_to_node[named_seq] + assert isinstance(overall_predicate, lh_tune_trace.Predicate) + tunables = list( + set( + node + for node in op_or_value_to_node.values() + if isinstance(node, lh_tune_trace.NonDeterministic) + ) + ) + + # Start enumerating assignments for the tune.knob and tune.alternatives ops. + count = 0 + for count, node_to_int in zip( + range(1, args.n + 1), + lh_tune_enumerate.all_satisfying_assignments( + tunables, [overall_predicate] + ), + ): + if args.count_only: + if count >= args.n: + break + continue + + print(HEADER.format(f"Config {count}:")) + + i64 = ir.IntegerType.get_signless(64) + + # Map the tuneable ops to the attributes that should assigned to them. + mapping: Mapping[ir.Value | ir.Operation, ir.Attribute] = LazyChainMap( + op_or_value_to_node, + lambda node: ir.IntegerAttr.get(i64, node_to_int[node]), + ) + + # Walk the IR, obtaining and setting the corresponding attr for each tuneable op. + mod_op = lh_tune_rewrite.set_selected(module.operation, mapping) + print(mod_op) + + if count >= args.n: + break + print("// count:", count) + else: + assert False, "Other modes are not yet implemented" diff --git a/lighthouse/tune/enumerate.py b/lighthouse/tune/enumerate.py new file mode 100644 index 0000000..5d62752 --- /dev/null +++ b/lighthouse/tune/enumerate.py @@ -0,0 +1,20 @@ +from itertools import product +from typing import Sequence + +from .trace import NonDeterministic, Predicate + + +def all_satisfying_assignments( + tuneables: Sequence[NonDeterministic], predicates: Sequence[Predicate] +): + """Generate all assignments of values to the tuneables that satisfy the predicates.""" + + for tuneable_values in product( + *(tuneable.possibilities() for tuneable in tuneables) + ): + environment = dict(zip((tunable for tunable in tuneables), tuneable_values)) + for pred in predicates: + if not pred.evaluate(environment): + break + else: + yield environment diff --git a/lighthouse/tune/rewrite.py b/lighthouse/tune/rewrite.py new file mode 100644 index 0000000..c8308a4 --- /dev/null +++ b/lighthouse/tune/rewrite.py @@ -0,0 +1,20 @@ +from typing import Mapping + +from mlir import ir +from mlir.dialects.transform import tune as transform_tune + + +def set_selected(op: ir.Operation, env: Mapping[ir.Value | ir.Operation, ir.Attribute]): + """Walk op's IR and set attrs on transform.tune.* ops per the env mapping.""" + + def set(op: ir.Operation) -> ir.WalkResult: + op = op.opview + match op: + case transform_tune.KnobOp(): + op.attributes["selected"] = env[op.result] + case transform_tune.AlternativesOp(): + op.attributes["selected_region"] = env[op] + return ir.WalkResult.ADVANCE + + op.walk(set, ir.WalkOrder.PRE_ORDER) + return op diff --git a/lighthouse/tune/trace.py b/lighthouse/tune/trace.py new file mode 100644 index 0000000..96f57c1 --- /dev/null +++ b/lighthouse/tune/trace.py @@ -0,0 +1,315 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass +from itertools import islice + +from operator import eq, ge, gt, le, lt, ne, mul, mod, floordiv, add + +from typing import Callable, Generator, Sequence, Optional + +from mlir import ir +from mlir.dialects import transform, smt +from mlir.dialects.transform import tune as transform_tune + +from lighthouse.dialects import transform_smt_ext + + +class Node(ABC): + """Base class for `Node`s which can be evaluated w.r.t. an environment.""" + + @abstractmethod + def evaluate(self, environment: dict["Node", int]) -> int | bool: + raise NotImplementedError + + +@dataclass(frozen=True) +class Constant(Node): + """Trivial base case `Node` which evaluates to a constant irrespective of env. + + Intended to represent `Value`s which are constants in IR.""" + + value: int + + def evaluate(self, environment: dict[Node, int]) -> int: + return self.value + + +@dataclass(frozen=True) +class NonDeterministic(Node, ABC): + """Abstract `Node` which evaluates via looking up its name in the environment.""" + + name: str + + @abstractmethod + def possibilities(self) -> Generator[int, None, None]: + raise NotImplementedError + + def evaluate(self, environment: dict[Node, int]) -> int | bool: + return environment[self] + + +@dataclass(frozen=True) +class Knob(NonDeterministic): + """Base case `Node` which evals per name in env and knows its possible values. + + Intended to represent the `Value` associated with a tuneable knob in IR.""" + + options: Optional[Sequence[int]] = None + lower_bound: Optional[int] = None + upper_bound: Optional[int] = None + divisible_by: Optional[int] = None + divides: Optional[int] = None + + def __post_init__(self): + assert self.options or ( + None not in (self.lower_bound, self.upper_bound) + ), "Options attribute not finitely specified" + assert ( + self.divisible_by is None or self.divisible_by > 0 + ), "divisible_by must be positive" + assert self.divides is None or self.divides > 0, "divides must be positive" + + def __repr__(self): + return ( + f"{self.name}<".upper() + + f"{list(self.possibilities())}>".replace(", ", "|")[1:-2] + + ">" + ) + + def possibilities(self) -> Generator[int, None, None]: + if self.options is not None: + if self.divides is not None or self.divisible_by is not None: + yield from filter( + lambda val: (self.divides is None or (self.divides % val == 0)) + and (self.divisible_by is None or (val % self.divisible_by == 0)), + self.options, + ) + else: + yield from self.options + else: + low = self.lower_bound + step = 1 + if self.divisible_by is not None: + low = self.lower_bound + (-self.lower_bound % self.divisible_by) + step = self.divisible_by + for val in range(low, self.upper_bound + 1, step): + if self.divides is None or self.divides % val == 0: + yield val + + +@dataclass(frozen=True) +class Apply(Node): + """Recursive case `Node` which calculates a function from other eval-ables. + + Intended to represent `Value`s in IR dependent on other `Value`s.""" + + operator: Callable[..., int] + args: Sequence[Node] + + def evaluate(self, environment: dict[Node, int]) -> int: + return self.operator(*[arg.evaluate(environment) for arg in self.args]) + + +@dataclass(frozen=True) +class Predicate(Node): + """Recursive case `Node` which applies a predicate to other eval-ables. + + Intended to represent the condition that must hold for execution to be able + to succesfully proceed beyond a specified op.""" + + operator: Callable[..., bool] + args: Sequence[Node] + + def evaluate(self, environment: dict[Node, int]) -> bool: + return self.operator(*[arg.evaluate(environment) for arg in self.args]) + + +@dataclass(frozen=True) +class Alternatives(NonDeterministic): + """Recursive case `Node` which acts as its selected child predicate, per its env. + + Intended to represent selection among a fixed number of regions in IR.""" + + alt_idx_to_pred: Sequence[Predicate] + + def evaluate(self, environment: dict[Node, int]) -> bool: + selected_region_idx = super().evaluate(environment) # evals name to int + return self.alt_idx_to_pred[selected_region_idx].evaluate(environment) + + def possibilities(self) -> Generator[int, None, None]: + yield from range(len(self.alt_idx_to_pred)) + + +@dataclass +class AlternativesResult(Node): + """Recursive case `Node` which selects among child nodes per the env-chosen alternative. + + Intended to represent results associated with a particular region being selected.""" + + alternatives: Alternatives + region_idx_to_result: dict[int, Node] + + def evaluate(self, environment: dict[Node, int]) -> int: + alt_idx = environment[self.alternatives] + return self.region_idx_to_result[alt_idx].evaluate(environment) + + +def trace_smt_op(op: ir.Operation, env: dict) -> dict: + """Add mapping from SMT op to a new Node that represents it to the env.""" + + op = op.opview + match op: + case smt.IntConstantOp(): + env[op.result] = Constant(op.value.value) + + case smt.EqOp(): + assert len(op.operands) == 2 + env[op.result] = Predicate(eq, [env[value] for value in op.operands]) + + case smt.IntAddOp(): + assert len(op.operands) == 2 + env[op.result] = Apply(add, [env[value] for value in op.operands]) + + case smt.IntMulOp(): + assert len(op.operands) == 2 + env[op.result] = Apply(mul, [env[value] for value in op.operands]) + + case smt.IntModOp(): + env[op.result] = Apply(mod, [env[op.lhs], env[op.rhs]]) + + case smt.IntDivOp(): + env[op.result] = Apply(floordiv, [env[op.lhs], env[op.rhs]]) + + case smt.IntCmpOp(): + operator = [lt, le, gt, ge][op.pred.value] + env[op.result] = Predicate(operator, [env[op.lhs], env[op.rhs]]) + + case smt.AssertOp(): + pred = env[op.input] + assert isinstance( + pred, Predicate + ), "SMT assert expected argument to map to a Predicate node" + env[op] = pred + + case _: + assert False, f"Unknown SMT operation: {op}" + + return env + + +def trace_tune_and_smt_ops(op: ir.Operation, env: Optional[dict] = None) -> dict: + """Recursively add mapping from transform(.tune) and SMT ops to representative Nodes to env.""" + + env = env if env is not None else {} # TODO: nested scopes + + op = op.opview + match op: + case transform.ParamConstantOp(): + env[op.result] = Constant(op.value.value) + + case transform.MatchParamCmpIOp(): + operator = [eq, ne, le, lt, ge, gt][op.predicate.value] + env[op] = Predicate(operator, [env[op.param], env[op.reference]]) + + case transform_tune.KnobOp(): + kwargs = {} + + # Inspect attrs on KnobOp and convert to args to pass to Knob Node. + if op.selected is not None: + kwargs["options"] = (op.selected.value,) + elif isinstance(op.options, ir.ArrayAttr): + kwargs["options"] = tuple(opt.value for opt in op.options) + elif isinstance(op.options, ir.DictAttr): + if "options" in op.options: + kwargs["options"] = tuple( + opt.value for opt in op.options["options"] + ) + if "lower_bound" in op.options: + kwargs["lower_bound"] = op.options["lower_bound"].value + if "upper_bound" in op.options: + kwargs["upper_bound"] = op.options["upper_bound"].value + if "divisible_by" in op.options: + kwargs["divisible_by"] = op.options["divisible_by"].value + if "divides" in op.options: + kwargs["divides"] = op.options["divides"].value + else: + assert False, "Unknown options attribute type" + + env[op.result] = Knob(name=op.name.value, **kwargs) + + case transform_tune.AlternativesOp(): + # Recursively visit each "alternative" child region, deriving a predicate + # for each region and track which results are associated to that region. + region_idx_to_pred = [] + result_idx_region_idx_to_node = defaultdict(lambda: dict()) + for reg_idx, region in enumerate(op.regions): + region_preds = [] + for child in region.blocks[0]: + trace_tune_and_smt_ops(child.operation, env) + if (child_pred := env.get(child)) is not None: + assert isinstance(child_pred, Predicate) + region_preds.append(child_pred) + assert isinstance(child, transform.YieldOp) + + region_idx_to_pred[reg_idx] = Predicate( + lambda *args: all(args), region_preds + ) + for res_idx, yield_operand in enumerate(child.operands): + result_idx_region_idx_to_node[res_idx][reg_idx] = env[yield_operand] + + # Construct the node, that acts as a predicate, which represents + # selecting among the "alternative" regions. + env[op] = Alternatives(name=op.name, alt_idx_to_pred=region_idx_to_pred) + + # Construct the nodes, which act as a functions, that represent the + # results corresponding to one of the "alternative" regions being selected. + for res_idx, result in enumerate(op.results): + env[result] = AlternativesResult( + alternatives=env[op], + region_idx_to_result=result_idx_region_idx_to_node[res_idx], + ) + + case transform_smt_ext.ConstrainParamsOp(): + # Map the block args in the op's region to the nodes already + # associated to the corresponding arguments on the op itself. + for operand, block_arg in zip(op.operands, op.body.arguments): + env[block_arg] = env[operand] + + # Recursively trace the child (SMT) ops and construct an overall + # predicate representing the block/region successfully terminating. + child_predicates = [] + for child in islice(op.body.operations, len(op.body.operations) - 1): + trace_smt_op(child, env) + if (child_pred := env.get(child)) is not None: + assert isinstance(child_pred, Predicate) + child_predicates.append(child_pred) + + env[op] = Predicate(lambda *args: all(args), child_predicates) + + assert isinstance(smt_yield := op.body.operations[-1], smt.YieldOp) + + # Map the op's results to the nodes already associated to the + # corresponding values yielded by the region/block's terminator. + for yield_operand, op_res in zip(smt_yield.operands, op.results): + env[op_res] = env[yield_operand] + + case transform.NamedSequenceOp(): + # Recursively trace the child ops and construct an overall + # predicate representing the block/region successfully terminating. + child_predicates = [] + for child in op.body.operations: + trace_tune_and_smt_ops(child.operation, env) + if (child_pred := env.get(child)) is not None: + assert isinstance(child_pred, Predicate) + child_predicates.append(child_pred) + + env[op] = Predicate(lambda *args: all(args), child_predicates) + + case transform.ApplyPatternsOp(): + # A transform op with child ops we do skip over. + pass + + case _: + assert len(op.regions) == 0, f"Unhandled operation with regions: {op}" + + return env diff --git a/lighthouse/utils/types.py b/lighthouse/utils/types.py new file mode 100644 index 0000000..733dc1b --- /dev/null +++ b/lighthouse/utils/types.py @@ -0,0 +1,24 @@ +from typing import TypeVar, Generic, Callable + +from collections.abc import Mapping + +K = TypeVar("K") +V = TypeVar("V") +W = TypeVar("W") + + +class LazyChainMap(Mapping, Generic[K, V, W]): + def __init__(self, data: dict[K, V], func: Callable[[V], W]): + self._data = data + self._func = func + + def __getitem__(self, key): + # Access the underlying data and apply the transformation + value = self._data[key] + return self._func(value) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) diff --git a/lighthouse/workload/workload.py b/lighthouse/workload/workload.py index cc2c4f5..e600958 100644 --- a/lighthouse/workload/workload.py +++ b/lighthouse/workload/workload.py @@ -4,6 +4,8 @@ Defines the expected interface for generic workload execution methods. """ +import sys + from mlir import ir from mlir.execution_engine import ExecutionEngine from abc import ABC, abstractmethod @@ -67,14 +69,15 @@ def lower_payload( schedule_module = self.schedule_module( stop_at_stage=dump_payload, parameters=schedule_parameters ) + if dump_schedule: + print(schedule_module) + sys.exit(0) if not dump_payload or dump_payload != "initial": # apply schedule on payload module named_seq = schedule_module.body.operations[0] named_seq.apply(payload_module) if dump_payload: print(payload_module) - if dump_schedule: - print(schedule_module) return payload_module @abstractmethod From 42b580161718e0d6ba0cd20a0dfa07713d72dc5e Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 9 Mar 2026 16:09:37 -0700 Subject: [PATCH 2/5] Documentation --- lighthouse/dialects/smt_ext.py | 5 + lighthouse/dialects/transform_smt_ext.py | 170 ++++++++++------------ lighthouse/dialects/transform_tune_ext.py | 42 +++++- lighthouse/utils/types.py | 2 + 4 files changed, 124 insertions(+), 95 deletions(-) diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/smt_ext.py index fdd9930..a2ecc43 100644 --- a/lighthouse/dialects/smt_ext.py +++ b/lighthouse/dialects/smt_ext.py @@ -7,11 +7,14 @@ def register_and_load(context=None): + """Register and load the SMTIntValue caster.""" + SMTIntValue.register_value_caster() def assert_(predicate: ir.Value[smt.BoolType] | bool): """Assert normally if a bool else produce an SMT assertion op.""" + if isinstance(predicate, bool): assert predicate else: @@ -32,6 +35,8 @@ def swapped( class SMTIntValue(ir.Value[smt.IntType]): + """A Value caster for `!smt.int` that supports Pythonic arithmetic and comparison operations.""" + def __init__(self, v): super().__init__(v) diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform_smt_ext.py index 4935c34..abf173e 100644 --- a/lighthouse/dialects/transform_smt_ext.py +++ b/lighthouse/dialects/transform_smt_ext.py @@ -12,11 +12,16 @@ "register_and_load", ] + def register_and_load(context=None): + """Register and load the TransformSMTDialectExtension and its operations.""" + TransformSMTDialectExtension.load() class TransformSMTDialectExtension(ext.Dialect, name="transform_smt_ext"): + """A Transform Dialect extension for SMT-related operations.""" + @classmethod def load(cls, *args, **kwargs): super(TransformSMTDialectExtension, cls).load(*args, **kwargs) @@ -29,6 +34,13 @@ def load(cls, *args, **kwargs): class ConstrainParamsOp( TransformSMTDialectExtension.Operation, name="constrain_params" ): + """Constrain transform params by SMT ops while also producing new params. + + In effect applies a predicate defined by the SMT ops in the body, which can + reference the parameters as block arguments as !smt.int. The result params + are defined by the !smt.int values yielded from the body. + """ + results_: Sequence[ext.Result[transform.AnyParamType]] params: Sequence[ext.Operand[transform.AnyParamType]] body_: ext.Region @@ -49,6 +61,8 @@ def attach_interfaces(cls, ctx=None): setattr(cls, "_interfaces_attached", True) class ConstrainParamsTransformOpInterfaceModel(transform.TransformOpInterface): + """TransformOpInterface impl for evaluating the SMT constraints and producing new params.""" + @staticmethod def apply( op: "ConstrainParamsOp", @@ -56,17 +70,24 @@ def apply( results: transform.TransformResults, state: transform.TransformState, ) -> transform.DiagnosedSilenceableFailure: + # Set up the tracing environment by obtaining the transform params + # and mapping them to Node constants, so that the traced Node + # representation will refer to the params as just constants. env = dict() for operand in op.params: params = state.get_params(operand) assert len(params) == 1 and isinstance(params[0].value, int) env[operand] = trace.Constant(params[0].value) + # Obtained traced representation of the body of the op. env = trace.trace_tune_and_smt_ops(op.operation, env) - if not env[op].evaluate(env): # evaluate the conjoined predicate + # Evaluate the predicate that represents the successful execution of the body. + if not env[op].evaluate(env): return transform.DiagnosedSilenceableFailure.DefiniteFailure + # If the predicate is satisfied, we can extract the values of the result params + # from the environment and set them as the results of the transformation. for result in op.results: res_value = env[result].evaluate(env) i64 = ir.IntegerType.get_signless(64) @@ -88,6 +109,16 @@ def get_effects(op: "ConstrainParamsOp", effects): class MixedResultConstrainParamsOp(ConstrainParamsOp): + """ConstrainParamsOp that supports both integer and SMTIntValues as results. + + Used to support `constrain_params` as a decorator on functions that yield a + mix of Python integers and `!smt.int`s (which are either arguments to the + function/block or the result of operations in the body). Upon the body's function + returning, the original ConstrainParamsOp is replaced with this version + that has the same parameters but whose `.results` corresponds to the mix of + integers and SMT values yielded from the body. + """ + def __init__( self, *args, @@ -109,109 +140,62 @@ def results(self) -> Sequence[int | ext.Result[transform.AnyParamType]]: return self._results -# class ConstrainParamsOpDecorator(ConstrainParamsOp): -# def __init__( -# self, -# *params: transform.AnyParamType | int, -# results: Sequence[int | ext.Result[transform.AnyParamType]] | None = None, -# **kwargs, -# ): -# transform_params = [p for p in params if isinstance(p, ir.Value)] -# super().__init__([], transform_params, **kwargs) -# block_arg_types = [smt.IntType.get()] * len(transform_params) -# self.body_.blocks.append(*block_arg_types) -# -# self._arguments = [] -# self._results = results -# smt_arguments = iter(self.body.arguments) -# for param in params: -# if isinstance(param, int): -# self._arguments.append(param) -# else: -# self._arguments.append(next(smt_arguments)) -# -# @property -# def results(self) -> Sequence[ext.Result | int]: -# """Returns the yielded results of the decorated function, which are either -# integers or the transform parameters that correspond to the yielded SMT -# int values.""" -# assert self._results is not None, ( -# "Results are not available until the decorated function is called" -# ) -# return self._results -# -# def __call__(self, func): -# with ir.InsertionPoint(self.body): -# yielded_results = func(*self._arguments) -# -# smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) -# -# print(f"{yielded_results=}") -# if len(yielded_results) == 0: -# return self -# -# # In case of yielded results, we need to create a new ConstrainParamsOp with the same parameters and a body that contains the original body of the decorator, but with the yielded results as the results of the new op. We then replace the original op with the new one and return it. -# result_types = [transform.AnyParamType.get()] * sum( -# 1 for res in yielded_results if isinstance(res, ir.Value) -# ) -# with ir.InsertionPoint(self): -# self_with_results = ConstrainParamsOp( -# result_types, self.params, loc=self.location -# ) -# self.body_.blocks[0].append_to(self_with_results.body_) -# # new_block = self_with_results.body_.blocks.append( -# # *orig_block.arguments.types -# # ) -# # arg_mapping = dict(zip(orig_block.arguments, new_block.arguments)) -# # lh_utils_rewrite.move_block(orig_block, new_block, arg_mapping) -# # self.erase() -# -# results = [] -# op_results = iter(self_with_results.results) -# for yielded_result in yielded_results: -# if isinstance(yielded_result, int): -# results.append(yielded_result) -# elif isinstance(yielded_result, ir.Value): -# results.append(next(op_results)) -# else: -# assert False, "Unsupported yielded result type" -# setattr(self_with_results, "_results", results) -# return self_with_results - - @overload def constrain_params( *params: ir.Value | int, loc=None, ip=None -) -> Callable[..., MixedResultConstrainParamsOp]: ... +) -> Callable[..., MixedResultConstrainParamsOp]: + """Calls the decorated function with param args converted to !smt.int args. + + The decorated function defines the body of the ConstrainParamsOp and handles + args as `!smt.int` or Python integer. The function should yield a mix of + Python integers and `!smt.int`s (the latter can be either block arguments or + results of operations in the body). The original ConstrainParamsOp created + by the decorator will be replaced with a MixedResultConstrainParamsOp that + has the same parameters but whose results correspond to the mix of integers + and SMT values yielded from the body. + """ + + ... @overload def constrain_params( results: Sequence[ir.Type], params: Sequence[transform.AnyParamType], - arg_types: Sequence[ir.Type], loc=None, ip=None, -) -> ConstrainParamsOp: ... +) -> ConstrainParamsOp: + """Creates a ConstrainParamsOp where the body is defined by the caller.""" + + ... def constrain_params( *args, **kwargs ) -> ConstrainParamsOp | Callable[..., MixedResultConstrainParamsOp]: + """Creates a ConstrainParamsOp or a decorator for a function that yields mixed results.""" + # The second overload: - if len(args) == 0 or isinstance(args[0], ir.Type): - arg_types = kwargs.pop("arg_types") + if len(args) == 0 or not ( + isinstance(args[0], ir.Value) or isinstance(args[0], int) + ): + params = kwargs.get("params") or args[1] + arg_types = [smt.IntType.get()] * len(params) op = ConstrainParamsOp(*args, **kwargs) op.body_.blocks.append(*arg_types) return op # The first overload: - # return ConstrainParamsOpDecorator(*args, **kwargs) def wrapper(func): + # Create a ConstrainParamsOp with just the transform parameters as block arguments. param_args = [p for p in args if isinstance(p, ir.Value)] constrain_params = ConstrainParamsOp([], param_args, **kwargs) constrain_params.body_.blocks.append(*[smt.IntType.get()] * len(param_args)) + # Call `func` with !smt.int block arguments for corresponding transform params, + # and just normal ints for those passed via `args`. The body of `func` will be + # the body of the op, and it can yield a mix of Python integers and `!smt.int`s. + # A corresponding `smt.yield` will be generated as the terminator. block_args_iter = iter(constrain_params.body_.blocks[0].arguments) with ir.InsertionPoint(constrain_params.body): yielded_results = func( @@ -224,21 +208,25 @@ def wrapper(func): yielded_results = [yielded_results] smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) - if len(yielded_results) == 0: - return constrain_params + # In case no results are returned, the current ConstrainParamsOp is sufficient. + if len(yielded_results) == 0: + return constrain_params - result_values_or_types = [ - transform.AnyParamType.get() if isinstance(res, ir.Value) else res - for res in yielded_results - ] + # Create a new version of the ConstrainParamsOp that has the same + # parameters but whose results correspond to the mix of integers and + # SMT values yielded from the body. + result_values_or_types = [ + transform.AnyParamType.get() if isinstance(res, ir.Value) else res + for res in yielded_results + ] - mixed_result_op = MixedResultConstrainParamsOp( - params=param_args, result_values_or_types=result_values_or_types, **kwargs - ) - # Move the body of the original op to the version with (mixed) results. - constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) - # Safe to remove as the op doesn't have results, so no users either. - constrain_params.erase() - return mixed_result_op + mixed_result_op = MixedResultConstrainParamsOp( + params=param_args, result_values_or_types=result_values_or_types, **kwargs + ) + # Move the body of the original op to the version with (mixed) results. + constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) + # Safe to remove as the op doesn't have results, so no users either. + constrain_params.erase() + return mixed_result_op return wrapper diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform_tune_ext.py index 4905036..f1ed505 100644 --- a/lighthouse/dialects/transform_tune_ext.py +++ b/lighthouse/dialects/transform_tune_ext.py @@ -3,7 +3,7 @@ import ast import math from dataclasses import dataclass -from typing import Any, Callable, Literal, Optional +from typing import Any, Callable, Literal, Optional, Sequence from functools import wraps from operator import mod @@ -23,6 +23,8 @@ def knob( result: Optional[ir.Type] = None, **kwargs, ) -> "KnobValue": + """Create a `transform.tune.knob` op whose result is wrapped in/cast to KnobValue.""" + options = ir.DictAttr.get() result = result or transform.AnyParamType.get() return KnobValue( @@ -39,10 +41,14 @@ def update_knob_options(knob: transform_tune.KnobOp, key, value): class KnobValue(ir.Value): + """Wrapper for KnobOp's result for a pythonic API for specifying the knob's constraints.""" + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def in_(self, options): + def in_(self, options: Sequence): + """Specify that the knob's value must be among the given options.""" + i64 = ir.IntegerType.get_signless(64) options_attr = ir.ArrayAttr.get([ir.IntegerAttr.get(i64, v) for v in options]) @@ -54,32 +60,47 @@ def in_(self, options): @staticmethod def ast_rewrite(in_exprs: bool = False): + """Decorator to allow `in` expressions on KnobOps in the function's body. + + Rewrite the function's AST to replace `in` expressions with calls to `In`, + which in case the LHS is a KnobOp corresponds to calling `KnobOp.in_`. + """ + def decorator(func: Callable): @wraps(func) def wrapper(*args, **kwargs): + # Get the func's textual source and deal with the function being indented. func_source = inspect.getsource(func) indent = math.inf for line in func_source.splitlines(): indent = min(indent, len(re.match(" *", line).group(0))) - func_source = "\n".join(line[indent:] for line in func_source.splitlines()) + func_source = "\n".join( + line[indent:] for line in func_source.splitlines() + ) + # Obtain the corresponding AST. func_ast = ast.parse(func_source) func_def_ast = func_ast.body[0] - # TODO: carefully remove just the @KnobValue.ast_rewrite decorator in case of multiple decorators. + # TODO: in case of multiple decorators, remove just @KnobValue.ast_rewrite func_def_ast.decorator_list.clear() # Remove the decorator to avoid infinite recursion. if in_exprs: + # Apply the rewriting of `in` expressions. func_def_ast.body = [ InTransformer().visit(stmt) for stmt in func_def_ast.body ] ast.fix_missing_locations(func_def_ast) + # Obtain executable code which still needs an execution environment. mod = compile(ast.unparse(func_ast), filename="", mode="exec") frame = inspect.currentframe() assert frame and frame.f_back + # Make the original function's globals and locals available to the rewritten function. temp_globals = frame.f_back.f_globals.copy() temp_globals |= frame.f_back.f_locals.copy() temp_locals = frame.f_back.f_locals.copy() temp_globals["In"] = In + # Make the rewritten function available as a value at the original name. exec(mod, temp_globals, temp_locals) + # Call the rewritten function and return its result. return temp_locals[func.__name__](*args, **kwargs) return wrapper @@ -135,6 +156,11 @@ def __eq__(self, other): @dataclass class KnobExpression: + """Helper class to represent expressions on KnobValues that then occur in equalities. + + In order to support (the LHS) in such constraints as `knob("X") % 16 == 0`. + """ + lhs: KnobValue | int rhs: KnobValue | int operator: Literal[mod] @@ -161,6 +187,12 @@ def __eq__(self, other): @dataclass class In: + """Helper for rewriting `in` expressions. + + Only `knob('X') in Y` gets mapped to `knob('X').in_(Y)` everything else + corresponds to just a regular Python `in` expression. + """ + lhs: Any rhs: Any @@ -171,6 +203,8 @@ def __bool__(self): class InTransformer(ast.NodeTransformer): + """AST transformer to rewrite `in` expressions to calls to `In`.""" + def visit_Compare(self, node: ast.Compare) -> Any: self.generic_visit(node) if len(node.ops) == 1 and isinstance(node.ops[0], ast.In): diff --git a/lighthouse/utils/types.py b/lighthouse/utils/types.py index 733dc1b..9bb90b5 100644 --- a/lighthouse/utils/types.py +++ b/lighthouse/utils/types.py @@ -8,6 +8,8 @@ class LazyChainMap(Mapping, Generic[K, V, W]): + """A mapping that applies a function to the values of an underlying dictionary on access.""" + def __init__(self, data: dict[K, V], func: Callable[[V], W]): self._data = data self._func = func From 96be18a4f6f665e4d47f705c0a1c29cf55128d2d Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 9 Mar 2026 16:13:44 -0700 Subject: [PATCH 3/5] Formatting --- lighthouse/dialects/transform_smt_ext.py | 4 +++- lighthouse/schedule/xegpu/mlp_schedule.py | 6 +++--- lighthouse/tune/trace.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform_smt_ext.py index abf173e..f144b36 100644 --- a/lighthouse/dialects/transform_smt_ext.py +++ b/lighthouse/dialects/transform_smt_ext.py @@ -221,7 +221,9 @@ def wrapper(func): ] mixed_result_op = MixedResultConstrainParamsOp( - params=param_args, result_values_or_types=result_values_or_types, **kwargs + params=param_args, + result_values_or_types=result_values_or_types, + **kwargs, ) # Move the body of the original op to the version with (mixed) results. constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index a9ceefc..451a0fe 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -352,9 +352,9 @@ def constrain_wg_sg_and_calc_nb_threads( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - assert ( - len(gpu_mod_ops) == nlayers - ), "Expected one gpu.module per MLP layer after outlining" + assert len(gpu_mod_ops) == nlayers, ( + "Expected one gpu.module per MLP layer after outlining" + ) for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params, has_bias=has_bias) diff --git a/lighthouse/tune/trace.py b/lighthouse/tune/trace.py index 96f57c1..c7183bc 100644 --- a/lighthouse/tune/trace.py +++ b/lighthouse/tune/trace.py @@ -61,12 +61,12 @@ class Knob(NonDeterministic): divides: Optional[int] = None def __post_init__(self): - assert self.options or ( - None not in (self.lower_bound, self.upper_bound) - ), "Options attribute not finitely specified" - assert ( - self.divisible_by is None or self.divisible_by > 0 - ), "divisible_by must be positive" + assert self.options or (None not in (self.lower_bound, self.upper_bound)), ( + "Options attribute not finitely specified" + ) + assert self.divisible_by is None or self.divisible_by > 0, ( + "divisible_by must be positive" + ) assert self.divides is None or self.divides > 0, "divides must be positive" def __repr__(self): @@ -186,9 +186,9 @@ def trace_smt_op(op: ir.Operation, env: dict) -> dict: case smt.AssertOp(): pred = env[op.input] - assert isinstance( - pred, Predicate - ), "SMT assert expected argument to map to a Predicate node" + assert isinstance(pred, Predicate), ( + "SMT assert expected argument to map to a Predicate node" + ) env[op] = pred case _: From 6f496bac7d0ee7a294b3019b62d6181e5fe4aead Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 10 Mar 2026 16:50:12 -0700 Subject: [PATCH 4/5] Update to Tuomas's latest params and constraints --- examples/xegpu/matmul.py | 8 +- examples/xegpu/mlp.py | 5 +- examples/xegpu/parameter_selector.py | 139 +++++++++++++--------- lighthouse/dialects/smt_ext.py | 8 +- lighthouse/dialects/transform_smt_ext.py | 44 +++---- lighthouse/dialects/transform_tune_ext.py | 8 +- lighthouse/schedule/xegpu/mlp_schedule.py | 129 +++++++++++++------- lighthouse/utils/mlir.py | 5 +- lighthouse/workload/workload.py | 5 +- 9 files changed, 212 insertions(+), 139 deletions(-) diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index d4131d4..5d8b24a 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -247,28 +247,28 @@ def parse_cli(): "--load-tile-a", type=int, nargs=2, - default=[16, 32], + default=[32, 16], help="Tile size for loading A matrix for DPAS op.", ) parser.add_argument( "--load-tile-b", type=int, nargs=2, - default=[16, 16], + default=[32, 16], help="Tile size for loading B matrix for DPAS op.", ) parser.add_argument( "--prefetch-tile-a", type=int, nargs=2, - default=[16, 32], + default=[8, 32], help="Tile size for cooperative prefetching of subgroup A matrix", ) parser.add_argument( "--prefetch-tile-b", type=int, nargs=2, - default=[16, 16], + default=[8, 32], help="Tile size for cooperative prefetching of subgroup B matrix", ) parser.add_argument( diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 5e8e51c..391aa0b 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -21,6 +21,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -65,7 +66,6 @@ def __init__( layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size] self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes] - self.nlayers = len(self.matmul_layers) self.identity_weights = identity_weights self.bias_shapes = [(o,) for o in layer_sizes[1:]] if has_bias else [] @@ -260,7 +260,6 @@ def schedule_module( has_relu=self.has_relu, skip_final_layer_relu=True, stop_at_stage=stop_at_stage, - nlayers=self.nlayers, params=parameters, ) @@ -369,6 +368,8 @@ def parse_cli(): identity_weights = args.check_result with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMLP( batch_size=args.batch_size, input_size=args.input_size, diff --git a/examples/xegpu/parameter_selector.py b/examples/xegpu/parameter_selector.py index 54d2ef7..5de7e74 100644 --- a/examples/xegpu/parameter_selector.py +++ b/examples/xegpu/parameter_selector.py @@ -4,159 +4,186 @@ def get_matmul_parameters(workload): - parameters = {} + parameters = [] for i, shape in enumerate(workload.matmul_layers): if shape not in matmul_param_db: raise ValueError( f"Parameter selector: No parameters found for matmul shape {shape}" ) - parameters[f"layer_{i}"] = matmul_param_db[shape] + parameters.append(matmul_param_db[shape]) return parameters matmul_param_db = { (4096, 4096, 4096): { + "m": 4096, + "n": 4096, + "k": 4096, "wg_m": 256, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 64, + "k_tile": 64, "load_a_m": 32, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 16384, 16384): { + "m": 128, + "n": 16384, + "k": 16384, "wg_m": 128, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (128, 8192, 16384): { + "m": 128, + "n": 8192, + "k": 16384, "wg_m": 64, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 128, + "k_tile": 128, "load_a_m": 16, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 32, - "pf_a_k": 16, - "pf_b_k": 16, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 32, + "prefetch_a_k": 16, + "prefetch_b_k": 16, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 32768, 16384): { + "m": 128, + "n": 32768, + "k": 16384, "wg_m": 128, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 16, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 16384, 32768): { + "m": 128, + "n": 16384, + "k": 32768, "wg_m": 128, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 32, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 32, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (128, 32768, 32768): { + "m": 128, + "n": 32768, + "k": 32768, "wg_m": 128, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 16, - "pf_a_k": 32, - "pf_b_k": 32, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 32, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (1024, 1024, 8192): { + "m": 1024, + "n": 1024, + "k": 8192, "wg_m": 256, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 8, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (1024, 8192, 1024): { + "m": 1024, + "n": 8192, + "k": 1024, "wg_m": 256, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 16, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 16, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 16, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (1024, 1024, 1024): { + "m": 1024, + "n": 1024, + "k": 1024, "wg_m": 128, "wg_n": 64, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 16, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, } diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/smt_ext.py index a2ecc43..c4c279d 100644 --- a/lighthouse/dialects/smt_ext.py +++ b/lighthouse/dialects/smt_ext.py @@ -12,13 +12,15 @@ def register_and_load(context=None): SMTIntValue.register_value_caster() -def assert_(predicate: ir.Value[smt.BoolType] | bool): +def assert_(predicate: ir.Value[smt.BoolType] | bool, error_message: str = ""): """Assert normally if a bool else produce an SMT assertion op.""" if isinstance(predicate, bool): - assert predicate + assert predicate, error_message else: - smt.assert_(predicate) + assert_ = smt.assert_(predicate) + if error_message: + assert_.attributes["error"] = ir.StringAttr.get(error_message) def int_to_smt(operand: "int | SMTIntValue") -> "SMTIntValue": diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform_smt_ext.py index f144b36..70a6cbd 100644 --- a/lighthouse/dialects/transform_smt_ext.py +++ b/lighthouse/dialects/transform_smt_ext.py @@ -208,27 +208,27 @@ def wrapper(func): yielded_results = [yielded_results] smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) - # In case no results are returned, the current ConstrainParamsOp is sufficient. - if len(yielded_results) == 0: - return constrain_params - - # Create a new version of the ConstrainParamsOp that has the same - # parameters but whose results correspond to the mix of integers and - # SMT values yielded from the body. - result_values_or_types = [ - transform.AnyParamType.get() if isinstance(res, ir.Value) else res - for res in yielded_results - ] - - mixed_result_op = MixedResultConstrainParamsOp( - params=param_args, - result_values_or_types=result_values_or_types, - **kwargs, - ) - # Move the body of the original op to the version with (mixed) results. - constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) - # Safe to remove as the op doesn't have results, so no users either. - constrain_params.erase() - return mixed_result_op + # In case no results are returned, the current ConstrainParamsOp is sufficient. + if len(yielded_results) == 0: + return constrain_params + + # Create a new version of the ConstrainParamsOp that has the same + # parameters but whose results correspond to the mix of integers and + # SMT values yielded from the body. + result_values_or_types = [ + transform.AnyParamType.get() if isinstance(res, ir.Value) else res + for res in yielded_results + ] + + mixed_result_op = MixedResultConstrainParamsOp( + params=param_args, + result_values_or_types=result_values_or_types, + **kwargs, + ) + # Move the body of the original op to the version with (mixed) results. + constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) + # Safe to remove as the op doesn't have results, so no users either. + constrain_params.erase() + return mixed_result_op return wrapper diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform_tune_ext.py index f1ed505..56bb647 100644 --- a/lighthouse/dialects/transform_tune_ext.py +++ b/lighthouse/dialects/transform_tune_ext.py @@ -89,8 +89,12 @@ def wrapper(*args, **kwargs): InTransformer().visit(stmt) for stmt in func_def_ast.body ] ast.fix_missing_locations(func_def_ast) - # Obtain executable code which still needs an execution environment. - mod = compile(ast.unparse(func_ast), filename="", mode="exec") + # Adjust line numbers to match the original source file. + source_file = inspect.getsourcefile(func) or "" + _, start_lineno = inspect.getsourcelines(func) + ast.increment_lineno(func_ast, start_lineno - 1) + # Compile from the AST directly to preserve line number mapping. + mod = compile(func_ast, filename=source_file, mode="exec") frame = inspect.currentframe() assert frame and frame.f_back # Make the original function's globals and locals available to the rewritten function. diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 451a0fe..223ee6e 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -16,6 +16,22 @@ from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext from lighthouse.dialects.transform_tune_ext import knob, KnobValue +# hardware constraints +DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( + 8, 16, 16, (8, 16), (16, 16), (8, 16) +) +PREFETCH_INST_DATA = [8, 16] +NB_WORKITEMS = 16 # workitems in subgroup +LOAD_MAX_ROWS = 32 +LOAD_MAX_COLS = 16 +PFETCH_MIN_ROWS = 8 +PFETCH_MAX_ROWS = 32 +PFETCH_MIN_COLS = 16 +PFETCH_MAX_COLS = 32 +MAX_NB_SG_THREADS = 64 +# heuristics: skip likely suboptimal configurations +MIN_NB_THREADS = 16 + class PipelineInterrupt(Exception): """Exception to signal early termination of the transform schedule.""" @@ -33,15 +49,6 @@ def match_and_split(*args, nhandles=1, **kwargs): return matched_ops -# hardware constraints -DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( - 8, 16, 16, (8, 16), (16, 16), (8, 16) -) -PREFETCH_INST_DATA = [8, 16] -NB_WORKITEMS = 16 # workitems in subgroup -LOAD_TILE_SIZES = [8, 16, 32] - - @KnobValue.ast_rewrite(in_exprs=True) def checked_params_or_knobs( params: dict[str, int | None], layer_id="" @@ -66,11 +73,18 @@ def checked_params_or_knobs( # NB: Constraints on knobs will be added as attributes on the KnobOps, while # constraints on concrete values will be checked immediately. - assert 64 <= wg_m <= 256 and m % wg_m == 0 and wg_m % DPAS.M == 0 - assert 64 <= wg_n <= 256 and n % wg_n == 0 and wg_n % DPAS.N == 0 - assert 32 <= sg_m <= 128 and m % sg_m == 0 and sg_m % DPAS.M == 0 - assert 32 <= sg_n <= 128 and n % sg_n == 0 and sg_n % DPAS.N == 0 - assert 16 <= k_tile <= 50 and k % k_tile == 0 and k_tile % DPAS.K == 0 + assert min(max(m // 4, 16), 64) <= wg_m <= min(m, 256) + assert m % wg_m == 0 and wg_m % DPAS.M == 0 + assert min(max(n // 4, 16), 64) <= wg_n <= min(n, 256) + assert n % wg_n == 0 and wg_n % DPAS.N == 0 + assert min(max(m // 8, 16), 32) <= sg_m <= min(m, 128) + assert m % sg_m == 0 and sg_m % DPAS.M == 0 + assert min(max(n // 8, 16), 32) <= sg_n <= min(n, 128) + assert n % sg_n == 0 and sg_n % DPAS.N == 0 + assert 16 <= k_tile <= min(k, 256) + assert k % k_tile == 0 and k_tile % DPAS.K == 0 + + LOAD_TILE_SIZES = [8, 16, 32] assert load_a_m in LOAD_TILE_SIZES and load_a_m % DPAS.M == 0 assert load_a_k in LOAD_TILE_SIZES and load_a_k % DPAS.K == 0 assert load_b_k in LOAD_TILE_SIZES and load_b_k % DPAS.K == 0 @@ -317,7 +331,10 @@ def constrain_wg_sg_and_calc_nb_threads( sg_m_threads = WG_M // SG_M sg_n_threads = WG_N // SG_N sg_threads = sg_m_threads * sg_n_threads - smt_ext.assert_(sg_threads <= 64) + smt_ext.assert_(sg_threads <= MAX_NB_SG_THREADS, "too many SG threads") + if isinstance(sg_threads, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(sg_threads >= MIN_NB_THREADS, "too few SG threads") # number of threads collapsed to 1d layout return sg_threads * NB_WORKITEMS @@ -352,9 +369,9 @@ def constrain_wg_sg_and_calc_nb_threads( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - assert len(gpu_mod_ops) == nlayers, ( - "Expected one gpu.module per MLP layer after outlining" - ) + assert ( + len(gpu_mod_ops) == nlayers + ), "Expected one gpu.module per MLP layer after outlining" for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params, has_bias=has_bias) @@ -398,6 +415,7 @@ def xegpu_wg_annotation_for_mlp_layer( # Calculate with SMT ops in case of symbolic values, normal ints in case of concrete values. @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): + # NB: Constraint on overall num SG threads already dealt with elsewhere. return WG_M // SG_M, WG_N // SG_N sg_layout = calc_sg_layout.results @@ -424,40 +442,59 @@ def constrain_and_calculate_load_and_prefetch_params( SG_M, SG_N, K_TILE, LDA_M, LDA_K, LDB_K, LDB_N, PFA_M, PFA_K, PFB_K, PFB_N ): # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values - # TODO: Tuomas' comments explaining constraints: - smt_ext.assert_(SG_M % PFA_M == 0) smt_ext.assert_(SG_M % LDA_M == 0) - - smt_ext.assert_(SG_N % PFB_N == 0) - smt_ext.assert_(SG_N % LDB_N == 0) - smt_ext.assert_(K_TILE % PFA_K == 0) - smt_ext.assert_(K_TILE % PFB_K == 0) smt_ext.assert_(K_TILE % LDA_K == 0) smt_ext.assert_(K_TILE % LDB_K == 0) + smt_ext.assert_(SG_N % LDB_N == 0) - smt_ext.assert_(LDA_M * LDA_K >= 16 * 16) - smt_ext.assert_(LDB_K * LDB_N >= 16 * 16) - - smt_ext.assert_(LDA_M <= LDA_K) - smt_ext.assert_(LDB_K <= LDB_N) - smt_ext.assert_(LDB_N == DPAS.N) - - PFA_M_step = SG_M // PFA_M - PFA_K_step = K_TILE // PFA_K - smt_ext.assert_(PFA_M_step * PFA_K_step <= 64) - - PFB_K_step = K_TILE // PFB_K - PFB_N_step = SG_N // PFB_N - smt_ext.assert_(PFB_K_step * PFB_N_step <= 64) - - smt_ext.assert_(PFA_M * PFA_K >= 16 * 16) - smt_ext.assert_(PFA_M >= PFA_K) + smt_ext.assert_(LDA_M <= LOAD_MAX_ROWS) + smt_ext.assert_(LDA_K <= LOAD_MAX_COLS) + smt_ext.assert_(LDB_K <= LOAD_MAX_ROWS) + smt_ext.assert_(LDB_N <= LOAD_MAX_COLS) - smt_ext.assert_(PFB_K * PFB_N >= 16 * 16) - smt_ext.assert_(PFB_K >= PFB_N) - smt_ext.assert_((SG_M // DPAS.M) * (SG_N // DPAS.N) * (K_TILE // DPAS.K) <= 64) + smt_ext.assert_(SG_M % PFA_M == 0) + smt_ext.assert_(K_TILE % PFA_K == 0) + smt_ext.assert_(K_TILE % PFB_K == 0) + smt_ext.assert_(SG_N % PFB_N == 0) - return PFA_M_step, PFA_K_step, PFB_K_step, PFB_N_step + smt_ext.assert_(PFA_M <= PFETCH_MAX_ROWS) + smt_ext.assert_(PFA_K <= PFETCH_MAX_COLS) + smt_ext.assert_(PFB_K <= PFETCH_MAX_ROWS) + smt_ext.assert_(PFB_N <= PFETCH_MAX_COLS) + smt_ext.assert_(PFA_M >= PFETCH_MIN_ROWS) + smt_ext.assert_(PFA_K >= PFETCH_MIN_COLS) + smt_ext.assert_(PFB_K >= PFETCH_MIN_ROWS) + smt_ext.assert_(PFB_N >= PFETCH_MIN_COLS) + + smt_ext.assert_(LDA_M % DPAS.M == 0) + smt_ext.assert_(LDA_K % DPAS.K == 0) + smt_ext.assert_(LDB_K % DPAS.K == 0) + smt_ext.assert_(LDB_N % DPAS.N == 0) + + nb_load_b_n = LDB_N // DPAS.N + # unsupported VNNI layout, loaded tile can only be row-sliced for vnni + # NOTE this can plausibly be relaxed + smt_ext.assert_(nb_load_b_n <= 1, "invalid load_tile_b_n for VNNI") + + # prefetch A layout + nb_prefetch_a_m = SG_M // PFA_M + nb_prefetch_a_k = K_TILE // PFA_K + nb_prefetch_a = nb_prefetch_a_m * nb_prefetch_a_k + smt_ext.assert_(nb_prefetch_a <= MAX_NB_SG_THREADS) + if isinstance(nb_prefetch_a, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(nb_prefetch_a_m * nb_prefetch_a_k >= MIN_NB_THREADS) + + # prefetch B layout + nb_prefetch_b_k = K_TILE // PFB_K + nb_prefetch_b_n = SG_N // PFB_N + nb_prefetch_b = nb_prefetch_b_k * nb_prefetch_b_n + smt_ext.assert_(nb_prefetch_b <= MAX_NB_SG_THREADS) + if isinstance(nb_prefetch_b, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(nb_prefetch_b_k * nb_prefetch_b_n >= MIN_NB_THREADS) + + return nb_prefetch_a_m, nb_prefetch_a_k, nb_prefetch_b_k, nb_prefetch_b_n prefetch_layout_a = constrain_and_calculate_load_and_prefetch_params.results[0:2] prefetch_layout_b = constrain_and_calculate_load_and_prefetch_params.results[2:4] diff --git a/lighthouse/utils/mlir.py b/lighthouse/utils/mlir.py index c7163b1..440d205 100644 --- a/lighthouse/utils/mlir.py +++ b/lighthouse/utils/mlir.py @@ -12,7 +12,10 @@ def get_mlir_library_path(): pkg_path = ir.__file__ if "python_packages" in pkg_path: # looks like a local mlir install - path = os.path.join(pkg_path.split("python_packages")[0], "lib") + build_tools_mlir_dir = pkg_path.split("python_packages")[0] + build_tools_dir = build_tools_mlir_dir.rsplit("mlir")[0] + build_dir = build_tools_dir.rsplit("tools")[0] + path = os.path.join(build_dir, "lib") else: # maybe installed in python path path = os.path.join(os.path.split(pkg_path)[0], "_mlir_libs") diff --git a/lighthouse/workload/workload.py b/lighthouse/workload/workload.py index e600958..69a82cc 100644 --- a/lighthouse/workload/workload.py +++ b/lighthouse/workload/workload.py @@ -69,15 +69,14 @@ def lower_payload( schedule_module = self.schedule_module( stop_at_stage=dump_payload, parameters=schedule_parameters ) - if dump_schedule: - print(schedule_module) - sys.exit(0) if not dump_payload or dump_payload != "initial": # apply schedule on payload module named_seq = schedule_module.body.operations[0] named_seq.apply(payload_module) if dump_payload: print(payload_module) + if dump_schedule: + print(schedule_module) return payload_module @abstractmethod From 6091109f187fadaeee2a2de29fe0f6ec98930397 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 10 Mar 2026 16:54:40 -0700 Subject: [PATCH 5/5] Syntax fixes --- lighthouse/dialects/transform_tune_ext.py | 1 - lighthouse/schedule/xegpu/mlp_schedule.py | 6 +++--- lighthouse/workload/workload.py | 2 -- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform_tune_ext.py index 56bb647..5cecf8e 100644 --- a/lighthouse/dialects/transform_tune_ext.py +++ b/lighthouse/dialects/transform_tune_ext.py @@ -172,7 +172,6 @@ class KnobExpression: def __eq__(self, other): assert other == 0, "Only equality to zero supported for now." assert self.operator is mod - i64 = ir.IntegerType.get_signless(64) if isinstance(self.lhs, KnobValue): assert isinstance(self.lhs.owner.options, ir.DictAttr) assert isinstance(self.rhs, int) diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 223ee6e..d9c211d 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -369,9 +369,9 @@ def constrain_wg_sg_and_calc_nb_threads( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - assert ( - len(gpu_mod_ops) == nlayers - ), "Expected one gpu.module per MLP layer after outlining" + assert len(gpu_mod_ops) == nlayers, ( + "Expected one gpu.module per MLP layer after outlining" + ) for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params, has_bias=has_bias) diff --git a/lighthouse/workload/workload.py b/lighthouse/workload/workload.py index 69a82cc..cc2c4f5 100644 --- a/lighthouse/workload/workload.py +++ b/lighthouse/workload/workload.py @@ -4,8 +4,6 @@ Defines the expected interface for generic workload execution methods. """ -import sys - from mlir import ir from mlir.execution_engine import ExecutionEngine from abc import ABC, abstractmethod