Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions src/pydsl/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pydsl.macro import CallMacro, Compiled
from pydsl.protocols import lower_single, SubtreeOut, ToMLIRBase
from pydsl.type import Int, Float, Sign
from pydsl.vector import Vector

import mlir.dialects.arith as arith

Expand Down Expand Up @@ -59,3 +60,69 @@ def min(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
return rett(arith.MinimumFOp(av, bv))
else:
raise TypeError(f"cannot take min of {rett.__qualname__}")


@CallMacro.generate()
def trunc(
    visitor: ToMLIRBase,
    a: Compiled,
    truncated_type: Compiled,
    *,
    round_mode: Compiled = None,
) -> SubtreeOut:
    """
    Truncate `a` to the narrower `truncated_type`.

    Works on scalars and element-wise on Vectors. Emits arith.trunci for
    integer element types and arith.truncf for float element types. The
    optional keyword-only `round_mode` is attached to the op as the
    `round_mode` attribute when given.

    Raises TypeError if the target type is not strictly narrower, or if
    the element type is neither Int nor Float.
    """
    elem_type = type(a)
    result_type = truncated_type
    if isinstance(a, Vector):
        # Vector truncation: keep the shape, swap the element type.
        result_type = Vector.get(a.shape, truncated_type)
        elem_type = a.element_type

    if not truncated_type.width < elem_type.width:
        raise TypeError("truncated type must be smaller than called type.")

    if issubclass(elem_type, Int):
        op = arith.TruncIOp(lower_single(result_type), lower_single(a))
    elif issubclass(elem_type, Float):
        op = arith.TruncFOp(lower_single(result_type), lower_single(a))
    else:
        raise TypeError(f"cannot take trunc of {elem_type.__qualname__}")

    if round_mode is not None:
        op.attributes["round_mode"] = lower_single(round_mode)

    return result_type(op)


@CallMacro.generate()
def vadd(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
    """
    Element-wise vector addition of `a` and `b`.

    Both operands must be Vector values of exactly the same vector type.
    Dispatches to arith.addi for Int element types and arith.addf for
    Float element types.

    Raises TypeError if `a` is not a Vector, the operand types differ,
    or the element type is unsupported.
    """
    rett = type(a)

    # Fix: was an f-string with no placeholders (ruff F541).
    if not isinstance(a, Vector):
        raise TypeError("NOT a vector addition operation")
    # Identity comparison is the idiomatic exact-type check.
    if type(a) is not type(b):
        raise TypeError(f"VADD type {type(a)} does not match {type(b)}")

    a_type = a.element_type
    if issubclass(a_type, Int):
        op = arith.addi(lower_single(a), lower_single(b))
    elif issubclass(a_type, Float):
        op = arith.addf(lower_single(a), lower_single(b))
    else:
        raise TypeError(f"unsupported vector addition type: {a_type}")
    return rett(op)


@CallMacro.generate()
def vmul(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut:
    """
    Element-wise vector multiplication of `a` and `b`.

    Both operands must be Vector values of exactly the same vector type.
    Dispatches to arith.muli for Int element types and arith.mulf for
    Float element types.

    Raises TypeError if `a` is not a Vector, the operand types differ,
    or the element type is unsupported.
    """
    rett = type(a)

    # Fix: was an f-string with no placeholders (ruff F541).
    if not isinstance(a, Vector):
        raise TypeError("NOT a vector multiplication operation")
    # Identity comparison is the idiomatic exact-type check.
    if type(a) is not type(b):
        raise TypeError(f"VMUL type {type(a)} does not match {type(b)}")

    a_type = a.element_type
    if issubclass(a_type, Int):
        op = arith.muli(lower_single(a), lower_single(b))
    elif issubclass(a_type, Float):
        op = arith.mulf(lower_single(a), lower_single(b))
    else:
        raise TypeError(f"unsupported vector multiplication type: {a_type}")
    return rett(op)
74 changes: 74 additions & 0 deletions src/pydsl/bufferization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from mlir.dialects import bufferization
from pydsl.macro import CallMacro, Compiled
from pydsl.tensor import Tensor
from pydsl.memref import MemRef
from pydsl.protocols import ToMLIRBase, lower_single, SubtreeOut

TensorFactory = Tensor.class_factory


def verify_all_memref(*args):
    """
    Check that every argument is a MemRef.

    Raises:
        TypeError: if any argument is not a MemRef; the message lists the
            qualified type names of all arguments.
    """
    # Fix: the original built the error-message type listing eagerly on
    # every call; build it only when the check actually fails.
    if not all(isinstance(arg, MemRef) for arg in args):
        arg_type_str = ", ".join(type(arg).__qualname__ for arg in args)
        raise TypeError(
            "bufferization operation expects arguments of type MemRef, "
            f"got {arg_type_str}"
        )


def verify_all_tensor(*args):
    """
    Check that every argument is a Tensor.

    Raises:
        TypeError: if any argument is not a Tensor; the message lists the
            qualified type names of all arguments.
    """
    # Fix: the original built the error-message type listing eagerly on
    # every call; build it only when the check actually fails.
    if not all(isinstance(arg, Tensor) for arg in args):
        arg_type_str = ", ".join(type(arg).__qualname__ for arg in args)
        raise TypeError(
            "bufferization operation expects arguments of type Tensor, "
            f"got {arg_type_str}"
        )


@CallMacro.generate()
def to_tensor(visitor: "ToMLIRBase", x: Compiled) -> SubtreeOut:
    """
    Wrap a MemRef value as a tensor via bufferization.to_tensor.

    The op is emitted with restrict=True and writable=True; the returned
    value is wrapped in a Tensor type derived from the MLIR result type's
    static shape and element type.

    Raises TypeError if `x` is not a MemRef.
    """
    verify_all_memref(x)

    tensor_val = bufferization.to_tensor(
        lower_single(x), restrict=True, writable=True
    )
    # Derive the PyDSL tensor type from the op's MLIR result type.
    shape = tuple(tensor_val.type.shape)
    wrapped_type = TensorFactory(shape, tensor_val.type.element_type)
    return wrapped_type(tensor_val)


@CallMacro.generate()
def materialize_in_destination(
    visitor: "ToMLIRBase", x: Compiled, y: Compiled
):
    """
    Materialize tensor `x` into the MemRef destination `y`.

    Emits bufferization.materialize_in_destination purely for its side
    effect (writable=True, no SSA result consumed); returns None.

    Raises TypeError if `x` is not a Tensor or `y` is not a MemRef.
    """
    verify_all_tensor(x)
    verify_all_memref(y)
    bufferization.MaterializeInDestinationOp(
        None, lower_single(x), lower_single(y), writable=True
    )
1 change: 1 addition & 0 deletions src/pydsl/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def get_supported_dialects(self) -> set[Dialect]:
Dialect.from_name("transform"),
Dialect.from_name("transform.loop"),
Dialect.from_name("transform.structured"),
Dialect.from_name("vector"),
}


Expand Down
Loading