Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/pydsl/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,12 @@ def visit_BoolOp(self, node: ast.BoolOp) -> SubtreeOut:

def reducer(a, b):
return a.op_and(b)

case ast.Or():

def reducer(a, b):
return a.op_or(b)

case _:
raise SyntaxError(
f"{type(node.op)} is not supported as a boolean operator"
Expand Down Expand Up @@ -644,9 +646,9 @@ def visit_For(self, node: ast.For) -> SubtreeOut:
iterator = node.iter

# we will not accept any other way to pass in an iterator for now
assert (
type(iterator) is ast.Call
), "iterator of the for loop must be a Call for now"
assert type(iterator) is ast.Call, (
"iterator of the for loop must be a Call for now"
)

name = iterator.func.id
iterator = self.scope_stack.resolve_as_protocol(name, HandlesFor)
Expand Down
6 changes: 3 additions & 3 deletions src/pydsl/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _flags(self: Self) -> list[str]:
*transform_passes,
"-eliminate-empty-tensors",
"-empty-tensor-to-alloc-tensor",
f"-one-shot-bufferize='{" ".join(one_shot_bufferize)}'",
f"-one-shot-bufferize='{' '.join(one_shot_bufferize)}'",
"-canonicalize",
"-buffer-deallocation",
"-convert-bufferization-to-memref",
Expand Down Expand Up @@ -798,7 +798,7 @@ def call_function(self, fname: str, *args, **kwargs) -> Any:
if not len(sig.parameters) == len(args) + len(kwargs):
raise TypeError(
f"{f.name} takes {len(sig.parameters)} "
f"argument{"s" if len(sig.parameters) > 1 else ""} "
f"argument{'s' if len(sig.parameters) > 1 else ''} "
f"but {len(args) + len(kwargs)} were given"
)

Expand Down Expand Up @@ -1088,7 +1088,7 @@ def call_function(self, fname: str, *args) -> Any:
if not len(sig.parameters) == len(args):
raise TypeError(
f"{f.name} takes {len(sig.parameters)} positional "
f"argument{"s" if len(sig.parameters) > 1 else ""} "
f"argument{'s' if len(sig.parameters) > 1 else ''} "
f"but {len(args)} were given"
)

Expand Down
18 changes: 9 additions & 9 deletions src/pydsl/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ def signature_to_class(
argst = tuple([p.annotation for p in lower_sig.parameters.values()])
rett = lower_sig.return_annotation

assert (
rett is not Signature.empty
), "signature contains Signature.empty after lowered"
assert rett is not Signature.empty, (
"signature contains Signature.empty after lowered"
)

return cls.class_factory(argst, rett)

Expand Down Expand Up @@ -648,9 +648,9 @@ def init_val(self) -> mlir.Operation:
*self._lower_typing(),
)

assert (
self.visibility == Visibility.PRIVATE
), "TransformSequence must be private by nature"
assert self.visibility == Visibility.PRIVATE, (
"TransformSequence must be private by nature"
)

return val

Expand Down Expand Up @@ -821,9 +821,9 @@ def on_Call(
arguments to a SubtreeOut, to be passed into __call__.
"""
fn: InlineFunction = attr_chain[-1]
assert isinstance(
fn, InlineFunction
), "InlineFunction on_Call called on not an InlineFunction"
assert isinstance(fn, InlineFunction), (
"InlineFunction on_Call called on not an InlineFunction"
)

args = tuple(visitor.visit(arg) for arg in (*prefix_args, *node.args))
kwargs = {kw.arg: visitor.visit(kw.value) for kw in node.keywords}
Expand Down
6 changes: 3 additions & 3 deletions src/pydsl/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ def payload(
elif issubclass(t, Int) and t.sign == Sign.UNSIGNED:
return fn_unsigned, TypeFn.cast_unsigned
else:
assert (
False
), "Already checked type of t, this should be uncreachable"
assert False, (
"Already checked type of t, this should be uncreachable"
)

return payload

Expand Down
16 changes: 8 additions & 8 deletions src/pydsl/macro.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,33 +286,33 @@ def parse_args(
# Python doesn't have functor abstraction, so we'll have to do
# functor mapping for each of the common Python data type
match annorigin:
case builtins.list if (
ArgCompiler.is_type_ArgCompiler(annargs[0])
case builtins.list if ArgCompiler.is_type_ArgCompiler(
annargs[0]
):
argcomp = ArgCompiler.from_type(annargs[0])
binding[name] = [
argcomp.compile(visitor, i) for i in binding[name]
]

case builtins.tuple if (
ArgCompiler.is_type_ArgCompiler(annargs[0])
case builtins.tuple if ArgCompiler.is_type_ArgCompiler(
annargs[0]
):
argcomp = ArgCompiler.from_type(annargs[0])
binding[name] = tuple([
argcomp.compile(visitor, i) for i in binding[name]
])

case builtins.dict if (
ArgCompiler.is_type_ArgCompiler(annargs[0])
case builtins.dict if ArgCompiler.is_type_ArgCompiler(
annargs[0]
):
argcomp = ArgCompiler.from_type(annargs[0])
binding[name] = {
k: argcomp.compile(visitor, v)
for (k, v) in binding[name].items()
}

case builtins.set if (
ArgCompiler.is_type_ArgCompiler(annargs[0])
case builtins.set if ArgCompiler.is_type_ArgCompiler(
annargs[0]
):
argcomp = ArgCompiler.from_type(annargs[0])
binding[name] = {
Expand Down
24 changes: 10 additions & 14 deletions src/pydsl/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def __iter__(self):
)

def rank(self) -> int:
assert len(self.shape) == len(
self.strides
), "rank of a RankedMemRefDescriptor is inconsistent!"
assert len(self.shape) == len(self.strides), (
"rank of a RankedMemRefDescriptor is inconsistent!"
)

return len(self.shape)

Expand Down Expand Up @@ -546,9 +546,9 @@ def __init__(self, rep: OpView | Value) -> None:
lower_single(self.element_type) == rep.type.element_type,
]):
raise TypeError(
f"expected shape {"x".join([str(sh) for sh in self.shape])}"
f"expected shape {'x'.join([str(sh) for sh in self.shape])}"
f"x{lower_single(self.element_type)}, got representation with shape "
f"{"x".join([str(sh) for sh in rep.type.shape])}"
f"{'x'.join([str(sh) for sh in rep.type.shape])}"
f"x{rep.type.element_type}"
)

Expand Down Expand Up @@ -1028,7 +1028,7 @@ def calc_shape(memref_shape: tuple, assoc: list[list[int]]):
res = 1
for i in group:
dim = memref_shape[i]
if (dim == DYNAMIC or res == DYNAMIC):
if dim == DYNAMIC or res == DYNAMIC:
res = DYNAMIC
else:
res *= dim
Expand All @@ -1038,20 +1038,16 @@ def calc_shape(memref_shape: tuple, assoc: list[list[int]]):


@CallMacro.generate()
def collapse_shape(
visitor: ToMLIRBase,
mem: Compiled,
assoc: Evaluated
):
def collapse_shape(visitor: ToMLIRBase, mem: Compiled, assoc: Evaluated):
shpe = calc_shape(mem.shape, assoc)
result_type = MemRef[mem.element_type, *shpe]
return result_type(
memref.CollapseShapeOp(
lower_single(result_type),
lower_single(mem),
assoc
lower_single(result_type), lower_single(mem), assoc
)
)


def split_static_dynamic_dims(
shape: Iterable[Number | SupportsIndex],
) -> tuple[list[int], list[Index]]:
Expand Down
3 changes: 1 addition & 2 deletions src/pydsl/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,7 @@ def globals(self) -> dict[str, Any]:
"""
if self.global_scope is None:
raise AssertionError(
"attempted to call globals() on a stack without a global "
"scope"
"attempted to call globals() on a stack without a global scope"
)
return self.global_scope.f_locals

Expand Down
4 changes: 2 additions & 2 deletions src/pydsl/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def __init__(self, rep: OpView | Value) -> None:
lower_single(self.element_type) == rep.type.element_type,
]):
raise TypeError(
f"expected shape {"x".join([str(sh) for sh in self.shape])}"
f"expected shape {'x'.join([str(sh) for sh in self.shape])}"
f"x{lower_single(self.element_type)}, got OpView with shape "
f"{"x".join([str(sh) for sh in rep.type.shape])}"
f"{'x'.join([str(sh) for sh in rep.type.shape])}"
f"x{rep.type.element_type}"
)

Expand Down
6 changes: 3 additions & 3 deletions src/pydsl/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,9 +852,9 @@ class F64(Float, width=64, mlir_type=F64Type):
# the target the current machine this runs on
def get_index_width() -> int:
s = log2(sys.maxsize + 1) + 1
assert (
s.is_integer()
), "the compiler cannot determine the index size of the current "
assert s.is_integer(), (
"the compiler cannot determine the index size of the current "
)
f"system. sys.maxsize yielded {sys.maxsize}"

return int(s)
Expand Down
4 changes: 2 additions & 2 deletions src/pydsl/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ def __init__(self, rep: OpView | Value) -> None:
lower_single(self.element_type) == rep.type.element_type,
]):
raise TypeError(
f"expected shape {"x".join([str(sh) for sh in self.shape])}"
f"expected shape {'x'.join([str(sh) for sh in self.shape])}"
f"x{lower_single(self.element_type)}, got OpView with shape "
f"{"x".join([str(sh) for sh in rep.type.shape])}"
f"{'x'.join([str(sh) for sh in rep.type.shape])}"
f"x{rep.type.element_type}"
)

Expand Down
17 changes: 15 additions & 2 deletions tests/e2e/test_memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
from pydsl.gpu import GPU_AddrSpace
import pydsl.linalg as linalg
import pydsl.memref as memref
from pydsl.memref import alloc, alloca, dealloc, DYNAMIC, MemRef, MemRefFactory, collapse_shape
from pydsl.memref import (
alloc,
alloca,
dealloc,
DYNAMIC,
MemRef,
MemRefFactory,
collapse_shape,
)
from pydsl.type import Bool, F32, F64, Index, SInt16, Tuple, UInt32
from helper import compilation_failed_from, failed_from, multi_arange, run

Expand Down Expand Up @@ -182,15 +190,18 @@ def f() -> MemRef[F64, 4, 6]:

def test_alloc_bad_align():
with compilation_failed_from(TypeError):

@compile()
def f():
alloc((4, 6), F64, alignment="xyz")

with compilation_failed_from(ValueError):

@compile()
def f():
alloc((4, 6), F64, alignment=-123)


def test_slice_memory_space():
"""
We can't use GPU address spaces on CPU, so just test if it compiles.
Expand Down Expand Up @@ -488,6 +499,7 @@ def my_func(a: MemRef[F32, 1, 3]) -> MemRef[F32, 3]:
n1 = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
assert all([a == b for a, b in zip(my_func(n1), [1.0, 2.0, 3.0])])


def test_cast_basic():
@compile()
def f(
Expand Down Expand Up @@ -540,7 +552,8 @@ def f2(m1: MemRef[F64, DYNAMIC, 4]):
@compile()
def f3(m1: MemRef1):
m1.cast(strides=(DYNAMIC, 16))



def test_copy_basic():
@compile()
def f(m1: MemRef[SInt16, 10, DYNAMIC], m2: MemRef[SInt16, 10, 10]):
Expand Down
12 changes: 6 additions & 6 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def failed_from(error_type: type[Exception]):
try:
yield
except Exception as e:
assert isinstance(
e, error_type
), f"expected {error_type}, but instead got {type(e)}"
assert isinstance(e, error_type), (
f"expected {error_type}, but instead got {type(e)}"
)
else:
assert False, f"expected {error_type}, but no error was raised"

Expand All @@ -30,9 +30,9 @@ def compilation_failed_from(error_type: type[Exception]):
try:
yield
except Exception as e:
assert isinstance(
e, CompilationError
), f"expected CompilationError, but instead got {type(e)}"
assert isinstance(e, CompilationError), (
f"expected CompilationError, but instead got {type(e)}"
)

assert isinstance(e.exception, error_type), (
f"CompilationError is caused by {type(e.exception)}: "
Expand Down
20 changes: 10 additions & 10 deletions tests/polybench/benchmarks/deriche.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ def deriche(
y1: Memref2DF32,
y2: Memref2DF32,
) -> None:
xm = alloca(MemRef[F32, 1])
tm = alloca(MemRef[F32, 1])
ym1 = alloca(MemRef[F32, 1])
ym2 = alloca(MemRef[F32, 1])
xp1 = alloca(MemRef[F32, 1])
xp2 = alloca(MemRef[F32, 1])
tp1 = alloca(MemRef[F32, 1])
tp2 = alloca(MemRef[F32, 1])
yp1 = alloca(MemRef[F32, 1])
yp2 = alloca(MemRef[F32, 1])
xm = alloca((1,), F32)
tm = alloca((1,), F32)
ym1 = alloca((1,), F32)
ym2 = alloca((1,), F32)
xp1 = alloca((1,), F32)
xp2 = alloca((1,), F32)
tp1 = alloca((1,), F32)
tp2 = alloca((1,), F32)
yp1 = alloca((1,), F32)
yp2 = alloca((1,), F32)

cst1: F32 = 1.0
cst2: F32 = 2.0
Expand Down
2 changes: 1 addition & 1 deletion tests/polybench/benchmarks/gramschmidt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def gramschmidt(
Q: MemrefF32MN,
) -> None:
b: F32 = 0.0
temp = alloca(MemRef[F32, 1])
temp = alloca((1,), F32)
for k in arange(n):
temp[0] = b
for i in arange(m):
Expand Down
2 changes: 1 addition & 1 deletion tests/polybench/benchmarks/ludcmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def ludcmp(
x: MemrefF321D,
y: MemrefF321D,
) -> None:
w = alloca(MemRef[F32, 1])
w = alloca((1,), F32)
for i in arange(n):
for j in arange(i):
w[0] = A[i, j]
Expand Down
2 changes: 1 addition & 1 deletion tests/polybench/benchmarks/symm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def symm(
A: MemrefMMF32,
B: MemrefMNF32,
) -> None:
temp_arr = alloca(MemRef[F32, 1])
temp_arr = alloca((1,), F32)
for i in arange(m):
for j in arange(n):
temp_arr[0] = F32(0.0)
Expand Down