diff --git a/examples/xegpu/matmul.py b/examples/xegpu/matmul.py index 044cbb4..5d8b24a 100644 --- a/examples/xegpu/matmul.py +++ b/examples/xegpu/matmul.py @@ -17,6 +17,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -194,13 +195,13 @@ def payload_module(self) -> ir.Module: def schedule_module( self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None ) -> ir.Module: + assert parameters is not None, "Schedule parameters must be provided" return get_schedule_module( has_bias=self.has_bias, has_relu=self.has_relu, has_convert_c=False, stop_at_stage=stop_at_stage, - nlayers=1, - params={"layer_0": parameters}, + params=[parameters], ) def shared_libs(self) -> list[str]: @@ -212,6 +213,9 @@ def parse_cli(): description="Matrix Multiplication using MLIR", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument( + "--all-knobs", action="store_true", help="Use knobs for all schedule parameters" + ) parser.add_argument( "--sizes", type=int, @@ -333,28 +337,34 @@ def parse_cli(): if __name__ == "__main__": args = parse_cli() + M, N, K = args.sizes + params = { - "wg_m": args.wg_tile[0], - "wg_n": args.wg_tile[1], - "sg_m": args.sg_tile[0], - "sg_n": args.sg_tile[1], - "k": args.k_tile, - "load_a_m": args.load_tile_a[0], - "load_a_k": args.load_tile_a[1], - "load_b_k": args.load_tile_b[0], - "load_b_n": args.load_tile_b[1], - "pf_a_m": args.prefetch_tile_a[0], - "pf_a_k": args.prefetch_tile_a[1], - "pf_b_k": args.prefetch_tile_b[0], - "pf_b_n": args.prefetch_tile_b[1], - "pf_nb": args.nb_prefetch, + "m": M, + "n": N, + "k": K, + "wg_m": None if args.all_knobs else args.wg_tile[0], + "wg_n": None if args.all_knobs else args.wg_tile[1], + "sg_m": None if args.all_knobs else args.sg_tile[0], + "sg_n": 
None if args.all_knobs else args.sg_tile[1], + "k_tile": None if args.all_knobs else args.k_tile, + "load_a_m": None if args.all_knobs else args.load_tile_a[0], + "load_a_k": None if args.all_knobs else args.load_tile_a[1], + "load_b_k": None if args.all_knobs else args.load_tile_b[0], + "load_b_n": None if args.all_knobs else args.load_tile_b[1], + "prefetch_a_m": None if args.all_knobs else args.prefetch_tile_a[0], + "prefetch_a_k": None if args.all_knobs else args.prefetch_tile_a[1], + "prefetch_b_k": None if args.all_knobs else args.prefetch_tile_b[0], + "prefetch_b_n": None if args.all_knobs else args.prefetch_tile_b[1], + "nb_prefetch": args.nb_prefetch, } - M, N, K = args.sizes ab_type = "f16" c_type = "f32" with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + wload = XeGPUMatMul( M=M, N=N, diff --git a/examples/xegpu/mlp.py b/examples/xegpu/mlp.py index 5e8e51c..391aa0b 100644 --- a/examples/xegpu/mlp.py +++ b/examples/xegpu/mlp.py @@ -21,6 +21,7 @@ from mlir import ir from mlir.execution_engine import ExecutionEngine +from lighthouse import dialects as lh_dialects from lighthouse.workload import benchmark from lighthouse.utils.memref import to_ctype as memref_to_ctype from lighthouse.utils.numpy import numpy_to_ctype @@ -65,7 +66,6 @@ def __init__( layer_sizes = [self.input_size] + self.hidden_layer_sizes + [self.output_size] self.weight_shapes = list(zip(layer_sizes[:-1], layer_sizes[1:])) self.matmul_layers = [(self.batch_size, o, i) for i, o in self.weight_shapes] - self.nlayers = len(self.matmul_layers) self.identity_weights = identity_weights self.bias_shapes = [(o,) for o in layer_sizes[1:]] if has_bias else [] @@ -260,7 +260,6 @@ def schedule_module( has_relu=self.has_relu, skip_final_layer_relu=True, stop_at_stage=stop_at_stage, - nlayers=self.nlayers, params=parameters, ) @@ -369,6 +368,8 @@ def parse_cli(): identity_weights = args.check_result with ir.Context(), ir.Location.unknown(): + lh_dialects.register_and_load() + 
wload = XeGPUMLP( batch_size=args.batch_size, input_size=args.input_size, diff --git a/examples/xegpu/parameter_selector.py b/examples/xegpu/parameter_selector.py index 54d2ef7..5de7e74 100644 --- a/examples/xegpu/parameter_selector.py +++ b/examples/xegpu/parameter_selector.py @@ -4,159 +4,186 @@ def get_matmul_parameters(workload): - parameters = {} + parameters = [] for i, shape in enumerate(workload.matmul_layers): if shape not in matmul_param_db: raise ValueError( f"Parameter selector: No parameters found for matmul shape {shape}" ) - parameters[f"layer_{i}"] = matmul_param_db[shape] + parameters.append(matmul_param_db[shape]) return parameters matmul_param_db = { (4096, 4096, 4096): { + "m": 4096, + "n": 4096, + "k": 4096, "wg_m": 256, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 64, + "k_tile": 64, "load_a_m": 32, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 16384, 16384): { + "m": 128, + "n": 16384, + "k": 16384, "wg_m": 128, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (128, 8192, 16384): { + "m": 128, + "n": 8192, + "k": 16384, "wg_m": 64, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 128, + "k_tile": 128, "load_a_m": 16, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 32, - "pf_a_k": 16, - "pf_b_k": 16, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 32, + "prefetch_a_k": 16, + "prefetch_b_k": 16, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 32768, 16384): { + "m": 128, + "n": 32768, + "k": 16384, "wg_m": 128, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, 
"load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 16, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (128, 16384, 32768): { + "m": 128, + "n": 16384, + "k": 32768, "wg_m": 128, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 32, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 32, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (128, 32768, 32768): { + "m": 128, + "n": 32768, + "k": 32768, "wg_m": 128, "wg_n": 256, "sg_m": 32, "sg_n": 32, - "k": 256, + "k_tile": 256, "load_a_m": 8, "load_a_k": 16, "load_b_k": 16, "load_b_n": 16, - "pf_a_m": 16, - "pf_a_k": 32, - "pf_b_k": 32, - "pf_b_n": 32, - "pf_nb": 1, + "prefetch_a_m": 16, + "prefetch_a_k": 32, + "prefetch_b_k": 32, + "prefetch_b_n": 32, + "nb_prefetch": 1, }, (1024, 1024, 8192): { + "m": 1024, + "n": 1024, + "k": 8192, "wg_m": 256, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 8, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (1024, 8192, 1024): { + "m": 1024, + "n": 8192, + "k": 1024, "wg_m": 256, "wg_n": 128, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 16, "load_a_k": 16, "load_b_k": 32, "load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 16, - "pf_b_k": 16, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 16, + "prefetch_b_k": 16, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, (1024, 1024, 1024): { + "m": 1024, + "n": 1024, + "k": 1024, "wg_m": 128, "wg_n": 64, "sg_m": 32, "sg_n": 32, - "k": 32, + "k_tile": 32, "load_a_m": 16, "load_a_k": 16, "load_b_k": 32, 
"load_b_n": 16, - "pf_a_m": 8, - "pf_a_k": 32, - "pf_b_k": 8, - "pf_b_n": 16, - "pf_nb": 1, + "prefetch_a_m": 8, + "prefetch_a_k": 32, + "prefetch_b_k": 8, + "prefetch_b_n": 16, + "nb_prefetch": 1, }, } diff --git a/lighthouse/dialects/__init__.py b/lighthouse/dialects/__init__.py new file mode 100644 index 0000000..450585f --- /dev/null +++ b/lighthouse/dialects/__init__.py @@ -0,0 +1,8 @@ +def register_and_load(): + from . import smt_ext + from . import transform_smt_ext + from . import transform_tune_ext + + smt_ext.register_and_load() + transform_smt_ext.register_and_load() + transform_tune_ext.register_and_load() diff --git a/lighthouse/dialects/smt_ext.py b/lighthouse/dialects/smt_ext.py new file mode 100644 index 0000000..c4c279d --- /dev/null +++ b/lighthouse/dialects/smt_ext.py @@ -0,0 +1,100 @@ +from typing import Callable + +from mlir import ir +from mlir.dialects import smt + +__all__ = ["SMTIntValue", "assert_", "register_and_load"] + + +def register_and_load(context=None): + """Register and load the SMTIntValue caster.""" + + SMTIntValue.register_value_caster() + + +def assert_(predicate: ir.Value[smt.BoolType] | bool, error_message: str = ""): + """Assert normally if a bool else produce an SMT assertion op.""" + + if isinstance(predicate, bool): + assert predicate, error_message + else: + assert_ = smt.assert_(predicate) + if error_message: + assert_.attributes["error"] = ir.StringAttr.get(error_message) + + +def int_to_smt(operand: "int | SMTIntValue") -> "SMTIntValue": + if isinstance(operand, int): + int_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), operand) + return SMTIntValue(smt.int_constant(int_attr)) + return operand + + +def swapped( + f: Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"], +) -> Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"]: + return lambda a, b: f(b, a) + + +class SMTIntValue(ir.Value[smt.IntType]): + """A Value caster for `!smt.int` that supports Pythonic 
arithmetic and comparison operations.""" + + def __init__(self, v): + super().__init__(v) + + def __hash__(self): + return super().__hash__() + + @staticmethod + def register_value_caster(): + if not hasattr(SMTIntValue, "_is_registered"): + ir.register_value_caster(smt.IntType.static_typeid)(SMTIntValue) + setattr(SMTIntValue, "_is_registered", True) + + def __add__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_add([self, int_to_smt(rhs)])) + + def __radd__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_add([int_to_smt(lhs), self])) + + def __sub__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_sub(self, int_to_smt(rhs))) + + def __rsub__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_sub(int_to_smt(lhs), self)) + + def __mul__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mul([self, int_to_smt(rhs)])) + + def __rmul__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mul([int_to_smt(lhs), self])) + + def __floordiv__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_div(self, int_to_smt(rhs))) + + def __rfloordiv__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_div(int_to_smt(lhs), self)) + + def __mod__(self, rhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mod(self, int_to_smt(rhs))) + + def __rmod__(self, lhs: "int | SMTIntValue") -> "SMTIntValue": + return SMTIntValue(smt.int_mod(int_to_smt(lhs), self)) + + def __eq__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.eq([self, int_to_smt(rhs)]) + + def __le__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.le, self, int_to_smt(rhs)) + + def __lt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.lt, self, int_to_smt(rhs)) 
+ + def __ge__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.ge, self, int_to_smt(rhs)) + + def __gt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]: + return smt.int_cmp(smt.IntPredicate.gt, self, int_to_smt(rhs)) + + def __str__(self): + return super().__str__().replace(ir.Value.__name__, SMTIntValue.__name__) diff --git a/lighthouse/dialects/transform_smt_ext.py b/lighthouse/dialects/transform_smt_ext.py new file mode 100644 index 0000000..70a6cbd --- /dev/null +++ b/lighthouse/dialects/transform_smt_ext.py @@ -0,0 +1,234 @@ +from typing import overload, Sequence, Callable + +from mlir import ir +from mlir.dialects import ext, smt, transform + +from lighthouse.tune import trace + +__all__ = [ + "ConstrainParamsOp", + "TransformSMTDialectExtension", + "constrain_params", + "register_and_load", +] + + +def register_and_load(context=None): + """Register and load the TransformSMTDialectExtension and its operations.""" + + TransformSMTDialectExtension.load() + + +class TransformSMTDialectExtension(ext.Dialect, name="transform_smt_ext"): + """A Transform Dialect extension for SMT-related operations.""" + + @classmethod + def load(cls, *args, **kwargs): + super(TransformSMTDialectExtension, cls).load(*args, **kwargs) + + for op in cls.operations: + if hasattr(op, "attach_interfaces"): + op.attach_interfaces() + + +class ConstrainParamsOp( + TransformSMTDialectExtension.Operation, name="constrain_params" +): + """Constrain transform params by SMT ops while also producing new params. + + In effect applies a predicate defined by the SMT ops in the body, which can + reference the parameters as block arguments as !smt.int. The result params + are defined by the !smt.int values yielded from the body. 
+ """ + + results_: Sequence[ext.Result[transform.AnyParamType]] + params: Sequence[ext.Operand[transform.AnyParamType]] + body_: ext.Region + + @property + def body(self): + return self.body_.blocks[0] + + @classmethod + def attach_interfaces(cls, ctx=None): + if not hasattr(cls, "_interfaces_attached"): + cls.ConstrainParamsTransformOpInterfaceModel.attach( + cls.OPERATION_NAME, context=ctx + ) + cls.ConstrainParamsMemoryEffectsOpInterfaceModel.attach( + cls.OPERATION_NAME, context=ctx + ) + setattr(cls, "_interfaces_attached", True) + + class ConstrainParamsTransformOpInterfaceModel(transform.TransformOpInterface): + """TransformOpInterface impl for evaluating the SMT constraints and producing new params.""" + + @staticmethod + def apply( + op: "ConstrainParamsOp", + _rewriter: transform.TransformRewriter, + results: transform.TransformResults, + state: transform.TransformState, + ) -> transform.DiagnosedSilenceableFailure: + # Set up the tracing environment by obtaining the transform params + # and mapping them to Node constants, so that the traced Node + # representation will refer to the params as just constants. + env = dict() + for operand in op.params: + params = state.get_params(operand) + assert len(params) == 1 and isinstance(params[0].value, int) + env[operand] = trace.Constant(params[0].value) + + # Obtained traced representation of the body of the op. + env = trace.trace_tune_and_smt_ops(op.operation, env) + + # Evaluate the predicate that represents the successful execution of the body. + if not env[op].evaluate(env): + return transform.DiagnosedSilenceableFailure.DefiniteFailure + + # If the predicate is satisfied, we can extract the values of the result params + # from the environment and set them as the results of the transformation. 
+ for result in op.results: + res_value = env[result].evaluate(env) + i64 = ir.IntegerType.get_signless(64) + results.set_params(result, [ir.IntegerAttr.get(i64, res_value)]) + + return transform.DiagnosedSilenceableFailure.Success + + @staticmethod + def allow_repeated_handle_operands(_op: "ConstrainParamsOp") -> bool: + return False + + class ConstrainParamsMemoryEffectsOpInterfaceModel(ir.MemoryEffectsOpInterface): + @staticmethod + def get_effects(op: "ConstrainParamsOp", effects): + if op.op_operands: + transform.only_reads_handle(op.op_operands, effects) + transform.produces_handle(op.results, effects) + transform.only_reads_payload(effects) + + +class MixedResultConstrainParamsOp(ConstrainParamsOp): + """ConstrainParamsOp that supports both integer and SMTIntValues as results. + + Used to support `constrain_params` as a decorator on functions that yield a + mix of Python integers and `!smt.int`s (which are either arguments to the + function/block or the result of operations in the body). Upon the body's function + returning, the original ConstrainParamsOp is replaced with this version + that has the same parameters but whose `.results` corresponds to the mix of + integers and SMT values yielded from the body. + """ + + def __init__( + self, + *args, + result_values_or_types: Sequence[int | ir.Type], + **kwargs, + ): + result_types = [ + res for res in result_values_or_types if isinstance(res, ir.Type) + ] + super().__init__(result_types, *args, **kwargs) + op_results = iter(super().results) + self._results = [ + next(op_results) if isinstance(res, ir.Type) else res + for res in result_values_or_types + ] + + @property + def results(self) -> Sequence[int | ext.Result[transform.AnyParamType]]: + return self._results + + +@overload +def constrain_params( + *params: ir.Value | int, loc=None, ip=None +) -> Callable[..., MixedResultConstrainParamsOp]: + """Calls the decorated function with param args converted to !smt.int args. 
+ + The decorated function defines the body of the ConstrainParamsOp and handles + args as `!smt.int` or Python integer. The function should yield a mix of + Python integers and `!smt.int`s (the latter can be either block arguments or + results of operations in the body). The original ConstrainParamsOp created + by the decorator will be replaced with a MixedResultConstrainParamsOp that + has the same parameters but whose results correspond to the mix of integers + and SMT values yielded from the body. + """ + + ... + + +@overload +def constrain_params( + results: Sequence[ir.Type], + params: Sequence[transform.AnyParamType], + loc=None, + ip=None, +) -> ConstrainParamsOp: + """Creates a ConstrainParamsOp where the body is defined by the caller.""" + + ... + + +def constrain_params( + *args, **kwargs +) -> ConstrainParamsOp | Callable[..., MixedResultConstrainParamsOp]: + """Creates a ConstrainParamsOp or a decorator for a function that yields mixed results.""" + + # The second overload: + if len(args) == 0 or not ( + isinstance(args[0], ir.Value) or isinstance(args[0], int) + ): + params = kwargs.get("params") or args[1] + arg_types = [smt.IntType.get()] * len(params) + op = ConstrainParamsOp(*args, **kwargs) + op.body_.blocks.append(*arg_types) + return op + + # The first overload: + def wrapper(func): + # Create a ConstrainParamsOp with just the transform parameters as block arguments. + param_args = [p for p in args if isinstance(p, ir.Value)] + constrain_params = ConstrainParamsOp([], param_args, **kwargs) + constrain_params.body_.blocks.append(*[smt.IntType.get()] * len(param_args)) + + # Call `func` with !smt.int block arguments for corresponding transform params, + # and just normal ints for those passed via `args`. The body of `func` will be + # the body of the op, and it can yield a mix of Python integers and `!smt.int`s. + # A corresponding `smt.yield` will be generated as the terminator. 
+ block_args_iter = iter(constrain_params.body_.blocks[0].arguments) + with ir.InsertionPoint(constrain_params.body): + yielded_results = func( + *( + next(block_args_iter) if isinstance(arg, ir.Value) else arg + for arg in args + ) + ) + if not isinstance(yielded_results, Sequence): + yielded_results = [yielded_results] + smt.yield_(res for res in yielded_results if isinstance(res, ir.Value)) + + # In case no results are returned, the current ConstrainParamsOp is sufficient. + if len(yielded_results) == 0: + return constrain_params + + # Create a new version of the ConstrainParamsOp that has the same + # parameters but whose results correspond to the mix of integers and + # SMT values yielded from the body. + result_values_or_types = [ + transform.AnyParamType.get() if isinstance(res, ir.Value) else res + for res in yielded_results + ] + + mixed_result_op = MixedResultConstrainParamsOp( + params=param_args, + result_values_or_types=result_values_or_types, + **kwargs, + ) + # Move the body of the original op to the version with (mixed) results. + constrain_params.body_.blocks[0].append_to(mixed_result_op.body_) + # Safe to remove as the op doesn't have results, so no users either. + constrain_params.erase() + return mixed_result_op + + return wrapper diff --git a/lighthouse/dialects/transform_tune_ext.py b/lighthouse/dialects/transform_tune_ext.py new file mode 100644 index 0000000..5cecf8e --- /dev/null +++ b/lighthouse/dialects/transform_tune_ext.py @@ -0,0 +1,219 @@ +import inspect +import re +import ast +import math +from dataclasses import dataclass +from typing import Any, Callable, Literal, Optional, Sequence +from functools import wraps +from operator import mod + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import tune as transform_tune + +__all__ = ["KnobValue", "knob"] + + +def register_and_load(context=None): + pass # NB: currently nothing to register or load. 
+ + +def knob( + *args, + result: Optional[ir.Type] = None, + **kwargs, +) -> "KnobValue": + """Create a `transform.tune.knob` op whose result is wrapped in/cast to KnobValue.""" + + options = ir.DictAttr.get() + result = result or transform.AnyParamType.get() + return KnobValue( + transform_tune.KnobOp(result, *args, options=options, **kwargs).result + ) + + +def update_knob_options(knob: transform_tune.KnobOp, key, value): + items = list((namedattr.name, namedattr.attr) for namedattr in knob.options) + if isinstance(value, int): + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + items.append((key, value)) + knob.options = ir.DictAttr.get(dict(items)) + + +class KnobValue(ir.Value): + """Wrapper for KnobOp's result for a pythonic API for specifying the knob's constraints.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def in_(self, options: Sequence): + """Specify that the knob's value must be among the given options.""" + + i64 = ir.IntegerType.get_signless(64) + options_attr = ir.ArrayAttr.get([ir.IntegerAttr.get(i64, v) for v in options]) + + assert ( + isinstance(self.owner.options, ir.DictAttr) and len(self.owner.options) == 0 + ) # Only one constraint supported for now. + self.owner.options = ir.DictAttr.get({"options": options_attr}) + return True + + @staticmethod + def ast_rewrite(in_exprs: bool = False): + """Decorator to allow `in` expressions on KnobOps in the function's body. + + Rewrite the function's AST to replace `in` expressions with calls to `In`, + which in case the LHS is a KnobOp corresponds to calling `KnobOp.in_`. + """ + + def decorator(func: Callable): + @wraps(func) + def wrapper(*args, **kwargs): + # Get the func's textual source and deal with the function being indented. 
+ func_source = inspect.getsource(func) + indent = math.inf + for line in func_source.splitlines(): + indent = min(indent, len(re.match(" *", line).group(0))) + func_source = "\n".join( + line[indent:] for line in func_source.splitlines() + ) + # Obtain the corresponding AST. + func_ast = ast.parse(func_source) + func_def_ast = func_ast.body[0] + + # TODO: in case of multiple decorators, remove just @KnobValue.ast_rewrite + func_def_ast.decorator_list.clear() # Remove the decorator to avoid infinite recursion. + if in_exprs: + # Apply the rewriting of `in` expressions. + func_def_ast.body = [ + InTransformer().visit(stmt) for stmt in func_def_ast.body + ] + ast.fix_missing_locations(func_def_ast) + # Adjust line numbers to match the original source file. + source_file = inspect.getsourcefile(func) or "" + _, start_lineno = inspect.getsourcelines(func) + ast.increment_lineno(func_ast, start_lineno - 1) + # Compile from the AST directly to preserve line number mapping. + mod = compile(func_ast, filename=source_file, mode="exec") + frame = inspect.currentframe() + assert frame and frame.f_back + # Make the original function's globals and locals available to the rewritten function. + temp_globals = frame.f_back.f_globals.copy() + temp_globals |= frame.f_back.f_locals.copy() + temp_locals = frame.f_back.f_locals.copy() + temp_globals["In"] = In + # Make the rewritten function available as a value at the original name. + exec(mod, temp_globals, temp_locals) + # Call the rewritten function and return its result. 
+ return temp_locals[func.__name__](*args, **kwargs) + + return wrapper + + return decorator + + def _set_bound(self, key, combine, value): + assert isinstance(self.owner.options, ir.DictAttr) + existing = self.owner.options[key] if key in self.owner.options else None + update_knob_options( + self.owner.opview, + key, + value if existing is None else combine(existing.value, value), + ) + return True + + def __mod__(self, other): + assert isinstance(other, int) + return KnobExpression(lhs=self, rhs=other, operator=mod) + + def __rmod__(self, other): + assert isinstance(other, int) + return KnobExpression(lhs=other, rhs=self, operator=mod) + + def __lt__(self, other): + assert isinstance(other, int) + return self._set_bound("upper_bound", min, other + 1) + + def __le__(self, other): + assert isinstance(other, int) + return self._set_bound("upper_bound", min, other) + + def __ge__(self, other): + assert isinstance(other, int) + return self._set_bound("lower_bound", max, other) + + def __gt__(self, other): + assert isinstance(other, int) + return self._set_bound("lower_bound", max, other + 1) + + def __eq__(self, other): + assert isinstance(other, int) + assert isinstance(self.owner.options, ir.DictAttr) + assert len(self.owner.options) == 0, "Only one constraint supported for now." + i64 = ir.IntegerType.get_signless(64) + update_knob_options( + self.owner.opview, + "options", + ir.ArrayAttr.get([ir.IntegerAttr.get(i64, other)]), + ) + return True + + +@dataclass +class KnobExpression: + """Helper class to represent expressions on KnobValues that then occur in equalities. + + In order to support (the LHS) in such constraints as `knob("X") % 16 == 0`. + """ + + lhs: KnobValue | int + rhs: KnobValue | int + operator: Literal[mod] + + def __eq__(self, other): + assert other == 0, "Only equality to zero supported for now." 
+ assert self.operator is mod + if isinstance(self.lhs, KnobValue): + assert isinstance(self.lhs.owner.options, ir.DictAttr) + assert isinstance(self.rhs, int) + assert "divisible_by" not in self.lhs.owner.options + update_knob_options(self.lhs.owner.opview, "divisible_by", self.rhs) + elif isinstance(self.rhs, KnobValue): + assert isinstance(self.lhs, int) + assert isinstance(self.rhs.owner.options, ir.DictAttr) + assert "divides" not in self.rhs.owner.options + update_knob_options(self.rhs.owner.opview, "divides", self.lhs) + else: + assert False, "At least one operand must be a KnobValue." + + return True + + +@dataclass +class In: + """Helper for rewriting `in` expressions. + + Only `knob('X') in Y` gets mapped to `knob('X').in_(Y)` everything else + corresponds to just a regular Python `in` expression. + """ + + lhs: Any + rhs: Any + + def __bool__(self): + if isinstance(self.lhs, KnobValue): + return self.lhs.in_(self.rhs) + return self.lhs in self.rhs + + +class InTransformer(ast.NodeTransformer): + """AST transformer to rewrite `in` expressions to calls to `In`.""" + + def visit_Compare(self, node: ast.Compare) -> Any: + self.generic_visit(node) + if len(node.ops) == 1 and isinstance(node.ops[0], ast.In): + return ast.Call( + func=ast.Name(id="In", ctx=ast.Load()), + args=[node.left, node.comparators[0]], + keywords=[], + ) + return node diff --git a/lighthouse/schedule/xegpu/mlp_schedule.py b/lighthouse/schedule/xegpu/mlp_schedule.py index 9020b30..d9c211d 100644 --- a/lighthouse/schedule/xegpu/mlp_schedule.py +++ b/lighthouse/schedule/xegpu/mlp_schedule.py @@ -1,3 +1,5 @@ +from collections import namedtuple + from mlir import ir from mlir.dialects.transform import loop from mlir.dialects.transform import bufferization @@ -10,7 +12,25 @@ canonicalize, match, ) -from typing import Optional + +from lighthouse.dialects import smt_ext, transform_smt_ext as td_smt_ext +from lighthouse.dialects.transform_tune_ext import knob, KnobValue + +# hardware constraints 
+DPAS = namedtuple("DPAS", ["M", "N", "K", "A_TILE", "B_TILE", "C_TILE"])( + 8, 16, 16, (8, 16), (16, 16), (8, 16) +) +PREFETCH_INST_DATA = [8, 16] +NB_WORKITEMS = 16 # workitems in subgroup +LOAD_MAX_ROWS = 32 +LOAD_MAX_COLS = 16 +PFETCH_MIN_ROWS = 8 +PFETCH_MAX_ROWS = 32 +PFETCH_MIN_COLS = 16 +PFETCH_MAX_COLS = 32 +MAX_NB_SG_THREADS = 64 +# heuristics: skip likely suboptimal configurations +MIN_NB_THREADS = 16 class PipelineInterrupt(Exception): @@ -29,20 +49,75 @@ def match_and_split(*args, nhandles=1, **kwargs): return matched_ops -# hardware constraints -DPAS_TILE = [8, 16, 16] -PREFETCH_INST_DATA = [8, 16] -NB_WORKITEMS = 16 # workitems in subgroup +@KnobValue.ast_rewrite(in_exprs=True) +def checked_params_or_knobs( + params: dict[str, int | None], layer_id="" +) -> dict[str, int | KnobValue]: + """Check the parameters for validity and replace `None`s with knobs with asserted ranges.""" + m, n, k = params["m"], params["n"], params["k"] + assert isinstance(m, int) and isinstance(n, int) and isinstance(k, int) + assert m > 0 and n > 0 and k > 0 + wg_m = params["wg_m"] or knob(layer_id + "wg_m") + wg_n = params["wg_n"] or knob(layer_id + "wg_n") + sg_m = params["sg_m"] or knob(layer_id + "sg_m") + sg_n = params["sg_n"] or knob(layer_id + "sg_n") + k_tile = params["k_tile"] or knob(layer_id + "k_tile") + load_a_m = params["load_a_m"] or knob(layer_id + "load_a_m") + load_a_k = params["load_a_k"] or knob(layer_id + "load_a_k") + load_b_k = params["load_b_k"] or knob(layer_id + "load_b_k") + load_b_n = params["load_b_n"] or knob(layer_id + "load_b_n") + prefetch_a_m = params["prefetch_a_m"] or knob(layer_id + "prefetch_a_m") + prefetch_a_k = params["prefetch_a_k"] or knob(layer_id + "prefetch_a_k") + prefetch_b_k = params["prefetch_b_k"] or knob(layer_id + "prefetch_b_k") + prefetch_b_n = params["prefetch_b_n"] or knob(layer_id + "prefetch_b_n") + + # NB: Constraints on knobs will be added as attributes on the KnobOps, while + # constraints on concrete values will 
be checked immediately. + assert min(max(m // 4, 16), 64) <= wg_m <= min(m, 256) + assert m % wg_m == 0 and wg_m % DPAS.M == 0 + assert min(max(n // 4, 16), 64) <= wg_n <= min(n, 256) + assert n % wg_n == 0 and wg_n % DPAS.N == 0 + assert min(max(m // 8, 16), 32) <= sg_m <= min(m, 128) + assert m % sg_m == 0 and sg_m % DPAS.M == 0 + assert min(max(n // 8, 16), 32) <= sg_n <= min(n, 128) + assert n % sg_n == 0 and sg_n % DPAS.N == 0 + assert 16 <= k_tile <= min(k, 256) + assert k % k_tile == 0 and k_tile % DPAS.K == 0 + + LOAD_TILE_SIZES = [8, 16, 32] + assert load_a_m in LOAD_TILE_SIZES and load_a_m % DPAS.M == 0 + assert load_a_k in LOAD_TILE_SIZES and load_a_k % DPAS.K == 0 + assert load_b_k in LOAD_TILE_SIZES and load_b_k % DPAS.K == 0 + assert load_b_n in LOAD_TILE_SIZES and load_b_n % DPAS.N == 0 + assert prefetch_a_m in LOAD_TILE_SIZES + assert prefetch_a_k in LOAD_TILE_SIZES + assert prefetch_b_k in LOAD_TILE_SIZES + assert prefetch_b_n in LOAD_TILE_SIZES + + return { + "wg_m": wg_m, + "wg_n": wg_n, + "sg_m": sg_m, + "sg_n": sg_n, + "k_tile": k_tile, + "load_a_m": load_a_m, + "load_a_k": load_a_k, + "load_b_k": load_b_k, + "load_b_n": load_b_n, + "prefetch_a_m": prefetch_a_m, + "prefetch_a_k": prefetch_a_k, + "prefetch_b_k": prefetch_b_k, + "prefetch_b_n": prefetch_b_n, + } def get_schedule_module( + params: list[dict[str, int | None]], has_bias: bool = False, has_relu: bool = False, has_convert_c: bool = True, skip_final_layer_relu: bool = False, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[dict] = None, ) -> ir.Module: """Generate transform schedule module.""" mod = ir.Module.create() @@ -64,41 +139,43 @@ def get_schedule_module( op_name="builtin.module", deduplicate=True, ) + for i, layer_params in enumerate(params): + layer_params |= checked_params_or_knobs( + layer_params, layer_id=f"layer_{i}_" + ) + xegpu_mlp_transform_schedule( payload_mod, + params=params, has_bias=has_bias, has_relu=has_relu, has_convert_c=has_convert_c, 
skip_final_layer_relu=skip_final_layer_relu, stop_at_stage=stop_at_stage, - nlayers=nlayers, - params=params, ) return mod def xegpu_mlp_transform_schedule( - mod: ir.Value, + mod: ir.Value[transform.AnyOpType], + params: list[dict[str, int | KnobValue]], has_bias: bool = False, has_relu: bool = False, has_convert_c: bool = True, skip_final_layer_relu: bool = False, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[list[dict]] = None, ): """Transform schedule for MLP-like payload.""" try: mod = bundle_xegpu_mlp_schedule( mod, + params=params, has_bias=has_bias, has_relu=has_relu, has_convert_c=has_convert_c, skip_final_layer_relu=skip_final_layer_relu, stop_at_stage=stop_at_stage, - nlayers=nlayers, - params=params, ) mod = bundle_xegpu_to_binary( @@ -112,27 +189,22 @@ def xegpu_mlp_transform_schedule( def bundle_xegpu_mlp_schedule( - mod: ir.Value, + mod: ir.Value[transform.AnyOpType], + params: list[dict[str, int | KnobValue]], has_bias: bool = False, has_relu: bool = False, skip_final_layer_relu: bool = False, has_convert_c: bool = True, stop_at_stage: str = "", - nlayers: int = 1, - params: Optional[list[dict]] = None, -) -> ir.Module: +) -> ir.Value[transform.AnyOpType]: """Schedule for lowering MLP-like payload to xegpu wg level.""" - if params is None: - raise ValueError("Schedule parameters must be provided.") + nlayers = len(params) if stop_at_stage == "initial": raise PipelineInterrupt() anytype = transform.AnyOpType.get() - for i in range(nlayers): - assert f"layer_{i}" in params, f"Missing parameters for 'layer_{i}'" - # wg tiling if has_convert_c: trunc_op = match(mod, ops={"arith.truncf"}) @@ -156,17 +228,14 @@ def bundle_xegpu_mlp_schedule( terminal_ops = list(relu_ops) + [terminal_ops[-1]] # tile each layer separately - for i_layer in range(nlayers): - layer_params = params[f"layer_{i_layer}"] + for terminal_op, layer_params in zip(terminal_ops, params): # tunable parameters: wg level tiling wg_tile = [layer_params["wg_m"], 
layer_params["wg_n"]] - sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] - k_tile = layer_params["k"] + k_tile = layer_params["k_tile"] - terminal = terminal_ops[i_layer] - # FIXME use structured.structured_fuse + # FIXME: use structured.structured_fuse _, wg_loop = structured.FuseOp( - terminal, tile_sizes=wg_tile, use_forall=True + terminal_op, tile_sizes=wg_tile, use_forall=True ).results transform.apply_cse(mod) canonicalize(mod) @@ -241,16 +310,38 @@ def bundle_xegpu_mlp_schedule( # set correct number of gpu threads launch_ops = match_and_split(mod, ops={"gpu.launch"}, nhandles=nlayers) - for i_layer, launch_op in enumerate(launch_ops): - layer_params = params[f"layer_{i_layer}"] + assert len(launch_ops) == nlayers + for launch_op, layer_params in zip(launch_ops, params): # tunable parameters - wg_tile = [layer_params["wg_m"], layer_params["wg_n"]] - sg_tile = [layer_params["sg_m"], layer_params["sg_n"]] - - # derived parameters - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] - # number of threads collapsed to 1d layout - nb_threads = sg_layout[0] * sg_layout[1] * NB_WORKITEMS + wg_m, wg_n = layer_params["wg_m"], layer_params["wg_n"] + sg_m, sg_n = layer_params["sg_m"], layer_params["sg_n"] + + @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + def constrain_wg_sg_and_calc_nb_threads( + WG_M: int | smt_ext.SMTIntValue, + WG_N: int | smt_ext.SMTIntValue, + SG_M: int | smt_ext.SMTIntValue, + SG_N: int | smt_ext.SMTIntValue, + ): + # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values. + smt_ext.assert_(WG_M % SG_M == 0) + smt_ext.assert_(WG_N % SG_N == 0) + + # NB: normal ints in case of concrete values, SMT int values for symbolic values. 
+ sg_m_threads = WG_M // SG_M + sg_n_threads = WG_N // SG_N + sg_threads = sg_m_threads * sg_n_threads + smt_ext.assert_(sg_threads <= MAX_NB_SG_THREADS, "too many SG threads") + if isinstance(sg_threads, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(sg_threads >= MIN_NB_THREADS, "too few SG threads") + + # number of threads collapsed to 1d layout + return sg_threads * NB_WORKITEMS + + nb_threads: int | smt_ext.SMTIntValue = ( + constrain_wg_sg_and_calc_nb_threads.results + ) xegpu.set_gpu_launch_threads(launch_op, threads=[nb_threads, 1, 1]) @@ -278,11 +369,12 @@ def bundle_xegpu_mlp_schedule( if stop_at_stage == "xegpu-initial": raise PipelineInterrupt() - for i_layer, gpu_mod in enumerate(gpu_mod_ops): + assert len(gpu_mod_ops) == nlayers, ( + "Expected one gpu.module per MLP layer after outlining" + ) + for gpu_mod, layer_params in zip(gpu_mod_ops, params): gpu_func = match(gpu_mod, ops={"gpu.func"}) - xegpu_wg_annotation_for_mlp_layer( - gpu_func, params[f"layer_{i_layer}"], has_bias=has_bias - ) + xegpu_wg_annotation_for_mlp_layer(gpu_func, **layer_params, has_bias=has_bias) if stop_at_stage == "xegpu-wg": raise PipelineInterrupt() @@ -290,44 +382,126 @@ def bundle_xegpu_mlp_schedule( return mod -def xegpu_wg_annotation_for_mlp_layer(gpu_func: ir.Value, params: dict, has_bias: bool): +def xegpu_wg_annotation_for_mlp_layer( + gpu_func: ir.Value, + *, + wg_m: int | KnobValue, + wg_n: int | KnobValue, + sg_m: int | KnobValue, + sg_n: int | KnobValue, + k_tile: int | KnobValue, + load_a_m: int | KnobValue, + load_a_k: int | KnobValue, + load_b_k: int | KnobValue, + load_b_n: int | KnobValue, + prefetch_a_m: int | KnobValue, + prefetch_a_k: int | KnobValue, + prefetch_b_k: int | KnobValue, + prefetch_b_n: int | KnobValue, + nb_prefetch: int, + has_bias: bool, + **_catch_all, +): """ Adds prefetching and XeGPU anchor layout annotations for an MLP layer. 
Should be applied after the payload has been converted to XeGPU using the convert-vector-to-xegpu pass. """ + anytype = transform.AnyOpType.get() anyvalue = transform.AnyValueType.get() - dpas_shape_a = [DPAS_TILE[0], DPAS_TILE[2]] - dpas_shape_b = [DPAS_TILE[2], DPAS_TILE[1]] - dpas_shape_c = [DPAS_TILE[0], DPAS_TILE[1]] - - wg_tile = [params["wg_m"], params["wg_n"]] - sg_tile = [params["sg_m"], params["sg_n"]] - k_tile = params["k"] - - sg_layout = [wg_tile[0] // sg_tile[0], wg_tile[1] // sg_tile[1]] - - load_tile_a = [params["load_a_m"], params["load_a_k"]] - load_tile_b = [params["load_b_k"], params["load_b_n"]] - prefetch_tile_a = [params["pf_a_m"], params["pf_a_k"]] - prefetch_tile_b = [params["pf_b_k"], params["pf_b_n"]] - nb_prefetch = params["pf_nb"] - - prefetch_layout_a = [ - wg_tile[0] // prefetch_tile_a[0], - k_tile // prefetch_tile_a[1], - ] - prefetch_layout_b = [ - k_tile // prefetch_tile_b[0], - wg_tile[1] // prefetch_tile_b[1], - ] + # Calculate with SMT ops in case of symbolic values, normal ints in case of concrete values. + @td_smt_ext.constrain_params(wg_m, wg_n, sg_m, sg_n) + def calc_sg_layout(WG_M, WG_N, SG_M, SG_N): + # NB: Constraint on overall num SG threads already dealt with elsewhere. 
+ return WG_M // SG_M, WG_N // SG_N + + sg_layout = calc_sg_layout.results + + load_tile_a = [load_a_m, load_a_k] + load_tile_b = [load_b_k, load_b_n] + prefetch_tile_a = [prefetch_a_m, prefetch_a_k] + prefetch_tile_b = [prefetch_b_k, prefetch_b_n] + + @td_smt_ext.constrain_params( + sg_m, + sg_n, + k_tile, + load_a_m, + load_a_k, + load_b_k, + load_b_n, + prefetch_a_m, + prefetch_a_k, + prefetch_b_k, + prefetch_b_n, + ) + def constrain_and_calculate_load_and_prefetch_params( + SG_M, SG_N, K_TILE, LDA_M, LDA_K, LDB_K, LDB_N, PFA_M, PFA_K, PFB_K, PFB_N + ): + # NB: normal asserts in case of concrete values, SMT assert ops for symbolic values + smt_ext.assert_(SG_M % LDA_M == 0) + smt_ext.assert_(K_TILE % LDA_K == 0) + smt_ext.assert_(K_TILE % LDB_K == 0) + smt_ext.assert_(SG_N % LDB_N == 0) + + smt_ext.assert_(LDA_M <= LOAD_MAX_ROWS) + smt_ext.assert_(LDA_K <= LOAD_MAX_COLS) + smt_ext.assert_(LDB_K <= LOAD_MAX_ROWS) + smt_ext.assert_(LDB_N <= LOAD_MAX_COLS) + + smt_ext.assert_(SG_M % PFA_M == 0) + smt_ext.assert_(K_TILE % PFA_K == 0) + smt_ext.assert_(K_TILE % PFB_K == 0) + smt_ext.assert_(SG_N % PFB_N == 0) + + smt_ext.assert_(PFA_M <= PFETCH_MAX_ROWS) + smt_ext.assert_(PFA_K <= PFETCH_MAX_COLS) + smt_ext.assert_(PFB_K <= PFETCH_MAX_ROWS) + smt_ext.assert_(PFB_N <= PFETCH_MAX_COLS) + smt_ext.assert_(PFA_M >= PFETCH_MIN_ROWS) + smt_ext.assert_(PFA_K >= PFETCH_MIN_COLS) + smt_ext.assert_(PFB_K >= PFETCH_MIN_ROWS) + smt_ext.assert_(PFB_N >= PFETCH_MIN_COLS) + + smt_ext.assert_(LDA_M % DPAS.M == 0) + smt_ext.assert_(LDA_K % DPAS.K == 0) + smt_ext.assert_(LDB_K % DPAS.K == 0) + smt_ext.assert_(LDB_N % DPAS.N == 0) + + nb_load_b_n = LDB_N // DPAS.N + # unsupported VNNI layout, loaded tile can only be row-sliced for vnni + # NOTE this can plausibly be relaxed + smt_ext.assert_(nb_load_b_n <= 1, "invalid load_tile_b_n for VNNI") + + # prefetch A layout + nb_prefetch_a_m = SG_M // PFA_M + nb_prefetch_a_k = K_TILE // PFA_K + nb_prefetch_a = nb_prefetch_a_m * nb_prefetch_a_k 
+ smt_ext.assert_(nb_prefetch_a <= MAX_NB_SG_THREADS) + if isinstance(nb_prefetch_a, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(nb_prefetch_a_m * nb_prefetch_a_k >= MIN_NB_THREADS) + + # prefetch B layout + nb_prefetch_b_k = K_TILE // PFB_K + nb_prefetch_b_n = SG_N // PFB_N + nb_prefetch_b = nb_prefetch_b_k * nb_prefetch_b_n + smt_ext.assert_(nb_prefetch_b <= MAX_NB_SG_THREADS) + if isinstance(nb_prefetch_b, smt_ext.SMTIntValue): + # NB: Constraint only enabled during tuning. + smt_ext.assert_(nb_prefetch_b_k * nb_prefetch_b_n >= MIN_NB_THREADS) + + return nb_prefetch_a_m, nb_prefetch_a_k, nb_prefetch_b_k, nb_prefetch_b_n + + prefetch_layout_a = constrain_and_calculate_load_and_prefetch_params.results[0:2] + prefetch_layout_b = constrain_and_calculate_load_and_prefetch_params.results[2:4] # matmul matrix shapes - sg_tile_a = [sg_tile[0], k_tile] - sg_tile_b = [k_tile, sg_tile[1]] + sg_tile_a = [sg_m, k_tile] + sg_tile_b = [k_tile, sg_n] # add layouts to DPAS op operands k_loop = match(gpu_func, ops={"scf.for"}) @@ -382,7 +556,7 @@ def annotate_ab_load(tile, layout_load, layout_dpas): } # A tile dpas layout layout_dpas_a = layout_load_a.copy() - layout_dpas_a["inst_data"] = dpas_shape_a + layout_dpas_a["inst_data"] = DPAS.A_TILE annotate_ab_load(tile_a, layout_load_a, layout_dpas_a) # B tile load layout @@ -393,14 +567,14 @@ def annotate_ab_load(tile, layout_load, layout_dpas): } # B tile dpas layout layout_dpas_b = layout_load_b.copy() - layout_dpas_b["inst_data"] = dpas_shape_b + layout_dpas_b["inst_data"] = DPAS.B_TILE annotate_ab_load(tile_b, layout_load_b, layout_dpas_b) # C tile layout output_layout = { "sg_layout": sg_layout, - "sg_data": sg_tile, - "inst_data": dpas_shape_c, + "sg_data": [sg_m, sg_n], + "inst_data": DPAS.C_TILE, } # C tile dpas anchor layout xegpu.set_op_layout_attr(dpas_op, index=0, **layout_dpas_a) @@ -429,7 +603,9 @@ def annotate_ab_load(tile, layout_load, layout_dpas): 
transform.apply_cse(gpu_func) -def bundle_xegpu_to_binary(mod, stop_at_stage: str = "") -> ir.Module: +def bundle_xegpu_to_binary( + mod, stop_at_stage: str = "" +) -> ir.Value[transform.AnyOpType]: """Schedule for lowering xegpu wg level to binary.""" # upstream xegpu/xevm pipeline is payload independent. mod = apply_registered_pass( diff --git a/lighthouse/tune/__main__.py b/lighthouse/tune/__main__.py new file mode 100644 index 0000000..de3f916 --- /dev/null +++ b/lighthouse/tune/__main__.py @@ -0,0 +1,92 @@ +import sys +import argparse +from typing import Mapping + +from mlir import ir +from mlir.dialects import transform +from lighthouse.tune import ( + rewrite as lh_tune_rewrite, + trace as lh_tune_trace, + enumerate as lh_tune_enumerate, +) +from lighthouse import dialects as lh_dialects +from lighthouse.utils.types import LazyChainMap + +HEADER = "//" * 40 + "\n// {}\n" + "//" * 40 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("file", type=str, help="Path to the MLIR file to process") + parser.add_argument("--count-only", action="store_true") + parser.add_argument( + "--mode", + choices=["enumerate"], + default="enumerate", + help="Mode of operation", + ) + parser.add_argument( + "-n", type=int, help="Number of concrete schedules to output", default=1 + ) + args = parser.parse_args() + + file = sys.stdin if args.file == "-" else open(args.file, "r") + with ir.Context() as ctx, ir.Location.unknown(): + lh_dialects.register_and_load() + + module = ir.Module.parse(file.read()) + + if args.mode == "enumerate": + # Trace the named_seq, obtaining a DAG of tunable nodes and nodes + # which are functions and predicates dependent on the tunable nodes. 
+ named_seq = module.body.operations[0].opview + assert isinstance(named_seq, transform.NamedSequenceOp) + op_or_value_to_node = lh_tune_trace.trace_tune_and_smt_ops( + named_seq.operation + ) + + # The predicate associated to the overall named_seq is the conjunct + # -ion of each of the predicates for each operation in seq's body. + overall_predicate = op_or_value_to_node[named_seq] + assert isinstance(overall_predicate, lh_tune_trace.Predicate) + tunables = list( + set( + node + for node in op_or_value_to_node.values() + if isinstance(node, lh_tune_trace.NonDeterministic) + ) + ) + + # Start enumerating assignments for the tune.knob and tune.alternatives ops. + count = 0 + for count, node_to_int in zip( + range(1, args.n + 1), + lh_tune_enumerate.all_satisfying_assignments( + tunables, [overall_predicate] + ), + ): + if args.count_only: + if count >= args.n: + break + continue + + print(HEADER.format(f"Config {count}:")) + + i64 = ir.IntegerType.get_signless(64) + + # Map the tuneable ops to the attributes that should assigned to them. + mapping: Mapping[ir.Value | ir.Operation, ir.Attribute] = LazyChainMap( + op_or_value_to_node, + lambda node: ir.IntegerAttr.get(i64, node_to_int[node]), + ) + + # Walk the IR, obtaining and setting the corresponding attr for each tuneable op. 
+ mod_op = lh_tune_rewrite.set_selected(module.operation, mapping) + print(mod_op) + + if count >= args.n: + break + print("// count:", count) + else: + assert False, "Other modes are not yet implemented" diff --git a/lighthouse/tune/enumerate.py b/lighthouse/tune/enumerate.py new file mode 100644 index 0000000..5d62752 --- /dev/null +++ b/lighthouse/tune/enumerate.py @@ -0,0 +1,20 @@ +from itertools import product +from typing import Sequence + +from .trace import NonDeterministic, Predicate + + +def all_satisfying_assignments( + tuneables: Sequence[NonDeterministic], predicates: Sequence[Predicate] +): + """Generate all assignments of values to the tuneables that satisfy the predicates.""" + + for tuneable_values in product( + *(tuneable.possibilities() for tuneable in tuneables) + ): + environment = dict(zip((tunable for tunable in tuneables), tuneable_values)) + for pred in predicates: + if not pred.evaluate(environment): + break + else: + yield environment diff --git a/lighthouse/tune/rewrite.py b/lighthouse/tune/rewrite.py new file mode 100644 index 0000000..c8308a4 --- /dev/null +++ b/lighthouse/tune/rewrite.py @@ -0,0 +1,20 @@ +from typing import Mapping + +from mlir import ir +from mlir.dialects.transform import tune as transform_tune + + +def set_selected(op: ir.Operation, env: Mapping[ir.Value | ir.Operation, ir.Attribute]): + """Walk op's IR and set attrs on transform.tune.* ops per the env mapping.""" + + def set(op: ir.Operation) -> ir.WalkResult: + op = op.opview + match op: + case transform_tune.KnobOp(): + op.attributes["selected"] = env[op.result] + case transform_tune.AlternativesOp(): + op.attributes["selected_region"] = env[op] + return ir.WalkResult.ADVANCE + + op.walk(set, ir.WalkOrder.PRE_ORDER) + return op diff --git a/lighthouse/tune/trace.py b/lighthouse/tune/trace.py new file mode 100644 index 0000000..c7183bc --- /dev/null +++ b/lighthouse/tune/trace.py @@ -0,0 +1,315 @@ +from abc import ABC, abstractmethod +from collections import 
defaultdict +from dataclasses import dataclass +from itertools import islice + +from operator import eq, ge, gt, le, lt, ne, mul, mod, floordiv, add + +from typing import Callable, Generator, Sequence, Optional + +from mlir import ir +from mlir.dialects import transform, smt +from mlir.dialects.transform import tune as transform_tune + +from lighthouse.dialects import transform_smt_ext + + +class Node(ABC): + """Base class for `Node`s which can be evaluated w.r.t. an environment.""" + + @abstractmethod + def evaluate(self, environment: dict["Node", int]) -> int | bool: + raise NotImplementedError + + +@dataclass(frozen=True) +class Constant(Node): + """Trivial base case `Node` which evaluates to a constant irrespective of env. + + Intended to represent `Value`s which are constants in IR.""" + + value: int + + def evaluate(self, environment: dict[Node, int]) -> int: + return self.value + + +@dataclass(frozen=True) +class NonDeterministic(Node, ABC): + """Abstract `Node` which evaluates via looking up its name in the environment.""" + + name: str + + @abstractmethod + def possibilities(self) -> Generator[int, None, None]: + raise NotImplementedError + + def evaluate(self, environment: dict[Node, int]) -> int | bool: + return environment[self] + + +@dataclass(frozen=True) +class Knob(NonDeterministic): + """Base case `Node` which evals per name in env and knows its possible values. 
+ + Intended to represent the `Value` associated with a tuneable knob in IR.""" + + options: Optional[Sequence[int]] = None + lower_bound: Optional[int] = None + upper_bound: Optional[int] = None + divisible_by: Optional[int] = None + divides: Optional[int] = None + + def __post_init__(self): + assert self.options or (None not in (self.lower_bound, self.upper_bound)), ( + "Options attribute not finitely specified" + ) + assert self.divisible_by is None or self.divisible_by > 0, ( + "divisible_by must be positive" + ) + assert self.divides is None or self.divides > 0, "divides must be positive" + + def __repr__(self): + return ( + f"{self.name}<".upper() + + f"{list(self.possibilities())}>".replace(", ", "|")[1:-2] + + ">" + ) + + def possibilities(self) -> Generator[int, None, None]: + if self.options is not None: + if self.divides is not None or self.divisible_by is not None: + yield from filter( + lambda val: (self.divides is None or (self.divides % val == 0)) + and (self.divisible_by is None or (val % self.divisible_by == 0)), + self.options, + ) + else: + yield from self.options + else: + low = self.lower_bound + step = 1 + if self.divisible_by is not None: + low = self.lower_bound + (-self.lower_bound % self.divisible_by) + step = self.divisible_by + for val in range(low, self.upper_bound + 1, step): + if self.divides is None or self.divides % val == 0: + yield val + + +@dataclass(frozen=True) +class Apply(Node): + """Recursive case `Node` which calculates a function from other eval-ables. + + Intended to represent `Value`s in IR dependent on other `Value`s.""" + + operator: Callable[..., int] + args: Sequence[Node] + + def evaluate(self, environment: dict[Node, int]) -> int: + return self.operator(*[arg.evaluate(environment) for arg in self.args]) + + +@dataclass(frozen=True) +class Predicate(Node): + """Recursive case `Node` which applies a predicate to other eval-ables. 
+ + Intended to represent the condition that must hold for execution to be able + to succesfully proceed beyond a specified op.""" + + operator: Callable[..., bool] + args: Sequence[Node] + + def evaluate(self, environment: dict[Node, int]) -> bool: + return self.operator(*[arg.evaluate(environment) for arg in self.args]) + + +@dataclass(frozen=True) +class Alternatives(NonDeterministic): + """Recursive case `Node` which acts as its selected child predicate, per its env. + + Intended to represent selection among a fixed number of regions in IR.""" + + alt_idx_to_pred: Sequence[Predicate] + + def evaluate(self, environment: dict[Node, int]) -> bool: + selected_region_idx = super().evaluate(environment) # evals name to int + return self.alt_idx_to_pred[selected_region_idx].evaluate(environment) + + def possibilities(self) -> Generator[int, None, None]: + yield from range(len(self.alt_idx_to_pred)) + + +@dataclass +class AlternativesResult(Node): + """Recursive case `Node` which selects among child nodes per the env-chosen alternative. 
+ + Intended to represent results associated with a particular region being selected.""" + + alternatives: Alternatives + region_idx_to_result: dict[int, Node] + + def evaluate(self, environment: dict[Node, int]) -> int: + alt_idx = environment[self.alternatives] + return self.region_idx_to_result[alt_idx].evaluate(environment) + + +def trace_smt_op(op: ir.Operation, env: dict) -> dict: + """Add mapping from SMT op to a new Node that represents it to the env.""" + + op = op.opview + match op: + case smt.IntConstantOp(): + env[op.result] = Constant(op.value.value) + + case smt.EqOp(): + assert len(op.operands) == 2 + env[op.result] = Predicate(eq, [env[value] for value in op.operands]) + + case smt.IntAddOp(): + assert len(op.operands) == 2 + env[op.result] = Apply(add, [env[value] for value in op.operands]) + + case smt.IntMulOp(): + assert len(op.operands) == 2 + env[op.result] = Apply(mul, [env[value] for value in op.operands]) + + case smt.IntModOp(): + env[op.result] = Apply(mod, [env[op.lhs], env[op.rhs]]) + + case smt.IntDivOp(): + env[op.result] = Apply(floordiv, [env[op.lhs], env[op.rhs]]) + + case smt.IntCmpOp(): + operator = [lt, le, gt, ge][op.pred.value] + env[op.result] = Predicate(operator, [env[op.lhs], env[op.rhs]]) + + case smt.AssertOp(): + pred = env[op.input] + assert isinstance(pred, Predicate), ( + "SMT assert expected argument to map to a Predicate node" + ) + env[op] = pred + + case _: + assert False, f"Unknown SMT operation: {op}" + + return env + + +def trace_tune_and_smt_ops(op: ir.Operation, env: Optional[dict] = None) -> dict: + """Recursively add mapping from transform(.tune) and SMT ops to representative Nodes to env.""" + + env = env if env is not None else {} # TODO: nested scopes + + op = op.opview + match op: + case transform.ParamConstantOp(): + env[op.result] = Constant(op.value.value) + + case transform.MatchParamCmpIOp(): + operator = [eq, ne, le, lt, ge, gt][op.predicate.value] + env[op] = Predicate(operator, [env[op.param], 
env[op.reference]])
+
+        case transform_tune.KnobOp():
+            kwargs = {}
+
+            # Inspect attrs on KnobOp and convert to args to pass to Knob Node.
+            if op.selected is not None:
+                kwargs["options"] = (op.selected.value,)
+            elif isinstance(op.options, ir.ArrayAttr):
+                kwargs["options"] = tuple(opt.value for opt in op.options)
+            elif isinstance(op.options, ir.DictAttr):
+                if "options" in op.options:
+                    kwargs["options"] = tuple(
+                        opt.value for opt in op.options["options"]
+                    )
+                if "lower_bound" in op.options:
+                    kwargs["lower_bound"] = op.options["lower_bound"].value
+                if "upper_bound" in op.options:
+                    kwargs["upper_bound"] = op.options["upper_bound"].value
+                if "divisible_by" in op.options:
+                    kwargs["divisible_by"] = op.options["divisible_by"].value
+                if "divides" in op.options:
+                    kwargs["divides"] = op.options["divides"].value
+            else:
+                assert False, "Unknown options attribute type"
+
+            env[op.result] = Knob(name=op.name.value, **kwargs)
+
+        case transform_tune.AlternativesOp():
+            # Recursively visit each "alternative" child region, deriving a predicate
+            # for each region and tracking which results are associated to that region.
+            region_idx_to_pred = []
+            result_idx_region_idx_to_node = defaultdict(lambda: dict())
+            for reg_idx, region in enumerate(op.regions):
+                region_preds = []
+                for child in region.blocks[0]:
+                    trace_tune_and_smt_ops(child.operation, env)
+                    if (child_pred := env.get(child)) is not None:
+                        assert isinstance(child_pred, Predicate)
+                        region_preds.append(child_pred)
+                assert isinstance(child, transform.YieldOp)
+
+                region_idx_to_pred.append(
+                    Predicate(lambda *args: all(args), region_preds)
+                )
+                for res_idx, yield_operand in enumerate(child.operands):
+                    result_idx_region_idx_to_node[res_idx][reg_idx] = env[yield_operand]
+
+            # Construct the node, that acts as a predicate, which represents
+            # selecting among the "alternative" regions.
+            env[op] = Alternatives(name=op.name.value, alt_idx_to_pred=region_idx_to_pred)
+
+            # Construct the nodes, which act as functions, that represent the
+            # results corresponding to one of the "alternative" regions being selected.
+            for res_idx, result in enumerate(op.results):
+                env[result] = AlternativesResult(
+                    alternatives=env[op],
+                    region_idx_to_result=result_idx_region_idx_to_node[res_idx],
+                )
+
+        case transform_smt_ext.ConstrainParamsOp():
+            # Map the block args in the op's region to the nodes already
+            # associated to the corresponding arguments on the op itself.
+            for operand, block_arg in zip(op.operands, op.body.arguments):
+                env[block_arg] = env[operand]
+
+            # Recursively trace the child (SMT) ops and construct an overall
+            # predicate representing the block/region successfully terminating.
+            child_predicates = []
+            for child in islice(op.body.operations, len(op.body.operations) - 1):
+                trace_smt_op(child, env)
+                if (child_pred := env.get(child)) is not None:
+                    assert isinstance(child_pred, Predicate)
+                    child_predicates.append(child_pred)
+
+            env[op] = Predicate(lambda *args: all(args), child_predicates)
+
+            assert isinstance(smt_yield := op.body.operations[-1], smt.YieldOp)
+
+            # Map the op's results to the nodes already associated to the
+            # corresponding values yielded by the region/block's terminator.
+            for yield_operand, op_res in zip(smt_yield.operands, op.results):
+                env[op_res] = env[yield_operand]
+
+        case transform.NamedSequenceOp():
+            # Recursively trace the child ops and construct an overall
+            # predicate representing the block/region successfully terminating.
+            child_predicates = []
+            for child in op.body.operations:
+                trace_tune_and_smt_ops(child.operation, env)
+                if (child_pred := env.get(child)) is not None:
+                    assert isinstance(child_pred, Predicate)
+                    child_predicates.append(child_pred)
+
+            env[op] = Predicate(lambda *args: all(args), child_predicates)
+
+        case transform.ApplyPatternsOp():
+            # A transform op with child ops we do skip over.
+ pass + + case _: + assert len(op.regions) == 0, f"Unhandled operation with regions: {op}" + + return env diff --git a/lighthouse/utils/mlir.py b/lighthouse/utils/mlir.py index c7163b1..440d205 100644 --- a/lighthouse/utils/mlir.py +++ b/lighthouse/utils/mlir.py @@ -12,7 +12,10 @@ def get_mlir_library_path(): pkg_path = ir.__file__ if "python_packages" in pkg_path: # looks like a local mlir install - path = os.path.join(pkg_path.split("python_packages")[0], "lib") + build_tools_mlir_dir = pkg_path.split("python_packages")[0] + build_tools_dir = build_tools_mlir_dir.rsplit("mlir")[0] + build_dir = build_tools_dir.rsplit("tools")[0] + path = os.path.join(build_dir, "lib") else: # maybe installed in python path path = os.path.join(os.path.split(pkg_path)[0], "_mlir_libs") diff --git a/lighthouse/utils/types.py b/lighthouse/utils/types.py new file mode 100644 index 0000000..9bb90b5 --- /dev/null +++ b/lighthouse/utils/types.py @@ -0,0 +1,26 @@ +from typing import TypeVar, Generic, Callable + +from collections.abc import Mapping + +K = TypeVar("K") +V = TypeVar("V") +W = TypeVar("W") + + +class LazyChainMap(Mapping, Generic[K, V, W]): + """A mapping that applies a function to the values of an underlying dictionary on access.""" + + def __init__(self, data: dict[K, V], func: Callable[[V], W]): + self._data = data + self._func = func + + def __getitem__(self, key): + # Access the underlying data and apply the transformation + value = self._data[key] + return self._func(value) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data)