Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions examples/xegpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mlir import ir
from mlir.execution_engine import ExecutionEngine

from lighthouse import dialects as lh_dialects
from lighthouse.workload import benchmark
from lighthouse.utils.memref import to_ctype as memref_to_ctype
from lighthouse.utils.numpy import numpy_to_ctype
Expand Down Expand Up @@ -194,13 +195,13 @@ def payload_module(self) -> ir.Module:
def schedule_module(
self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None
) -> ir.Module:
assert parameters is not None, "Schedule parameters must be provided"
return get_schedule_module(
has_bias=self.has_bias,
has_relu=self.has_relu,
has_convert_c=False,
stop_at_stage=stop_at_stage,
nlayers=1,
params={"layer_0": parameters},
params=[parameters],
)

def shared_libs(self) -> list[str]:
Expand All @@ -212,6 +213,9 @@ def parse_cli():
description="Matrix Multiplication using MLIR",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--all-knobs", action="store_true", help="Use knobs for all schedule parameters"
)
parser.add_argument(
"--sizes",
type=int,
Expand Down Expand Up @@ -243,28 +247,28 @@ def parse_cli():
"--load-tile-a",
type=int,
nargs=2,
default=[32, 16],
default=[16, 32],
help="Tile size for loading A matrix for DPAS op.",
)
parser.add_argument(
"--load-tile-b",
type=int,
nargs=2,
default=[32, 16],
default=[16, 16],
help="Tile size for loading B matrix for DPAS op.",
)
parser.add_argument(
"--prefetch-tile-a",
type=int,
nargs=2,
default=[8, 32],
default=[16, 32],
help="Tile size for cooperative prefetching of subgroup A matrix",
)
parser.add_argument(
"--prefetch-tile-b",
type=int,
nargs=2,
default=[8, 32],
default=[16, 16],
help="Tile size for cooperative prefetching of subgroup B matrix",
)
parser.add_argument(
Expand Down Expand Up @@ -333,28 +337,34 @@ def parse_cli():
if __name__ == "__main__":
args = parse_cli()

M, N, K = args.sizes

params = {
"wg_m": args.wg_tile[0],
"wg_n": args.wg_tile[1],
"sg_m": args.sg_tile[0],
"sg_n": args.sg_tile[1],
"k": args.k_tile,
"load_a_m": args.load_tile_a[0],
"load_a_k": args.load_tile_a[1],
"load_b_k": args.load_tile_b[0],
"load_b_n": args.load_tile_b[1],
"pf_a_m": args.prefetch_tile_a[0],
"pf_a_k": args.prefetch_tile_a[1],
"pf_b_k": args.prefetch_tile_b[0],
"pf_b_n": args.prefetch_tile_b[1],
"pf_nb": args.nb_prefetch,
"m": M,
"n": N,
"k": K,
"wg_m": None if args.all_knobs else args.wg_tile[0],
"wg_n": None if args.all_knobs else args.wg_tile[1],
"sg_m": None if args.all_knobs else args.sg_tile[0],
"sg_n": None if args.all_knobs else args.sg_tile[1],
"k_tile": None if args.all_knobs else args.k_tile,
"load_a_m": None if args.all_knobs else args.load_tile_a[0],
"load_a_k": None if args.all_knobs else args.load_tile_a[1],
"load_b_k": None if args.all_knobs else args.load_tile_b[0],
"load_b_n": None if args.all_knobs else args.load_tile_b[1],
"prefetch_a_m": None if args.all_knobs else args.prefetch_tile_a[0],
"prefetch_a_k": None if args.all_knobs else args.prefetch_tile_a[1],
"prefetch_b_k": None if args.all_knobs else args.prefetch_tile_b[0],
"prefetch_b_n": None if args.all_knobs else args.prefetch_tile_b[1],
"nb_prefetch": args.nb_prefetch,
}

M, N, K = args.sizes
ab_type = "f16"
c_type = "f32"

with ir.Context(), ir.Location.unknown():
lh_dialects.register_and_load()

wload = XeGPUMatMul(
M=M,
N=N,
Expand Down
8 changes: 8 additions & 0 deletions lighthouse/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def register_and_load():
    """Register and load every dialect extension shipped in this package.

    The sibling extension modules are imported inside the function, and each
    module's own `register_and_load()` is then called in a fixed order:
    `smt_ext`, `transform_smt_ext`, `transform_tune_ext`.
    """
    from . import smt_ext
    from . import transform_smt_ext
    from . import transform_tune_ext

    # All three modules are imported before any of them is registered.
    for extension in (smt_ext, transform_smt_ext, transform_tune_ext):
        extension.register_and_load()
98 changes: 98 additions & 0 deletions lighthouse/dialects/smt_ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import Callable

from mlir import ir
from mlir.dialects import smt

__all__ = ["SMTIntValue", "assert_", "register_and_load"]


def register_and_load(context=None):
    """Install `SMTIntValue` as the registered value caster.

    `context` is accepted but is not used by this function.
    """
    caster_cls = SMTIntValue
    caster_cls.register_value_caster()


def assert_(predicate: "ir.Value[smt.BoolType] | bool"):
    """Check a Python bool eagerly, otherwise produce an SMT assertion op.

    Args:
        predicate: Either a plain `bool`, which is checked immediately, or any
            other value, which is handed to `smt.assert_`.

    Raises:
        AssertionError: If `predicate` is the Python bool `False`.
    """

    if isinstance(predicate, bool):
        # Raise explicitly rather than via an `assert` statement: `assert` is
        # removed under `python -O`, which would silently skip this check
        # while the SMT branch below kept emitting ops.
        if not predicate:
            raise AssertionError("predicate is False")
        return
    smt.assert_(predicate)


def int_to_smt(operand: "int | SMTIntValue") -> "SMTIntValue":
    """Return `operand` as an `SMTIntValue`.

    A Python `int` is turned into an `smt.int_constant` built from a signless
    64-bit integer attribute and wrapped in `SMTIntValue`. Any operand that is
    not an `int` is returned unchanged.
    """
    if not isinstance(operand, int):
        return operand
    i64 = ir.IntegerType.get_signless(64)
    constant = smt.int_constant(ir.IntegerAttr.get(i64, operand))
    return SMTIntValue(constant)


def swapped(
    f: Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"],
) -> Callable[["int | SMTIntValue", "int | SMTIntValue"], "int | SMTIntValue"]:
    """Return a wrapper that calls `f` with its two arguments exchanged."""

    def flipped(a, b):
        return f(b, a)

    return flipped


class SMTIntValue(ir.Value[smt.IntType]):
    """A Value caster for `!smt.int` that supports Pythonic arithmetic and comparison operations.

    Every operator builds the corresponding `smt` op. Operands pass through
    `int_to_smt` first, so a Python `int` may appear on either side of an
    arithmetic operator. Arithmetic results are wrapped in `SMTIntValue`;
    comparisons return whatever the underlying `smt` builder returns, not a
    Python `bool`.
    """

    def __init__(self, v):
        super().__init__(v)

    def __hash__(self):
        # Defining `__eq__` on a class resets `__hash__` to None unless it is
        # defined as well; delegate to the base class to stay hashable.
        return super().__hash__()

    @staticmethod
    def register_value_caster():
        """Register this class as the value caster for `smt.IntType`, at most once."""
        if not hasattr(SMTIntValue, "_is_registered"):
            ir.register_value_caster(smt.IntType.static_typeid)(SMTIntValue)
            SMTIntValue._is_registered = True

    def __add__(self, rhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_add` for `self + rhs`."""
        return SMTIntValue(smt.int_add([self, int_to_smt(rhs)]))

    def __radd__(self, lhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_add` for `lhs + self`."""
        return SMTIntValue(smt.int_add([int_to_smt(lhs), self]))

    def __sub__(self, rhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_sub` for `self - rhs`."""
        return SMTIntValue(smt.int_sub(self, int_to_smt(rhs)))

    def __rsub__(self, lhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_sub` for `lhs - self`."""
        return SMTIntValue(smt.int_sub(int_to_smt(lhs), self))

    def __mul__(self, rhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_mul` for `self * rhs`."""
        return SMTIntValue(smt.int_mul([self, int_to_smt(rhs)]))

    def __rmul__(self, lhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_mul` for `lhs * self`."""
        return SMTIntValue(smt.int_mul([int_to_smt(lhs), self]))

    def __floordiv__(self, rhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_div` for `self // rhs`."""
        return SMTIntValue(smt.int_div(self, int_to_smt(rhs)))

    def __rfloordiv__(self, lhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_div` for `lhs // self`."""
        return SMTIntValue(smt.int_div(int_to_smt(lhs), self))

    def __mod__(self, rhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_mod` for `self % rhs`."""
        return SMTIntValue(smt.int_mod(self, int_to_smt(rhs)))

    def __rmod__(self, lhs: "int | SMTIntValue") -> "SMTIntValue":
        """Build `smt.int_mod` for `lhs % self`."""
        return SMTIntValue(smt.int_mod(int_to_smt(lhs), self))

    def __eq__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.eq` for `self == rhs`; the result is an SMT value, not a `bool`."""
        return smt.eq([self, int_to_smt(rhs)])

    def __ne__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.distinct` for `self != rhs`.

        Without this override Python derives `!=` by calling `__eq__` and
        applying `not` to its result, which would emit an `smt.eq` op and
        return a plain Python `bool` instead of an SMT predicate.
        """
        return smt.distinct([self, int_to_smt(rhs)])

    def __le__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.int_cmp` with predicate `le` for `self <= rhs`."""
        return smt.int_cmp(smt.IntPredicate.le, self, int_to_smt(rhs))

    def __lt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.int_cmp` with predicate `lt` for `self < rhs`."""
        return smt.int_cmp(smt.IntPredicate.lt, self, int_to_smt(rhs))

    def __ge__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.int_cmp` with predicate `ge` for `self >= rhs`."""
        return smt.int_cmp(smt.IntPredicate.ge, self, int_to_smt(rhs))

    def __gt__(self, rhs: "int | SMTIntValue") -> ir.Value[smt.BoolType]:
        """Build `smt.int_cmp` with predicate `gt` for `self > rhs`."""
        return smt.int_cmp(smt.IntPredicate.gt, self, int_to_smt(rhs))

    def __str__(self):
        # Reuse the base-class rendering, substituting this class's name for
        # the base `Value` class name wherever it appears.
        return super().__str__().replace(ir.Value.__name__, SMTIntValue.__name__)
Loading
Loading