diff --git a/src/pydsl/compiler.py b/src/pydsl/compiler.py index e2d2017..70f8968 100644 --- a/src/pydsl/compiler.py +++ b/src/pydsl/compiler.py @@ -408,10 +408,12 @@ def visit_BoolOp(self, node: ast.BoolOp) -> SubtreeOut: def reducer(a, b): return a.op_and(b) + case ast.Or(): def reducer(a, b): return a.op_or(b) + case _: raise SyntaxError( f"{type(node.op)} is not supported as a boolean operator" @@ -644,9 +646,9 @@ def visit_For(self, node: ast.For) -> SubtreeOut: iterator = node.iter # we will not accept any other way to pass in an iterator for now - assert ( - type(iterator) is ast.Call - ), "iterator of the for loop must be a Call for now" + assert type(iterator) is ast.Call, ( + "iterator of the for loop must be a Call for now" + ) name = iterator.func.id iterator = self.scope_stack.resolve_as_protocol(name, HandlesFor) diff --git a/src/pydsl/frontend.py b/src/pydsl/frontend.py index 44beb57..d3b1743 100644 --- a/src/pydsl/frontend.py +++ b/src/pydsl/frontend.py @@ -446,7 +446,7 @@ def _flags(self: Self) -> list[str]: *transform_passes, "-eliminate-empty-tensors", "-empty-tensor-to-alloc-tensor", - f"-one-shot-bufferize='{" ".join(one_shot_bufferize)}'", + f"-one-shot-bufferize='{' '.join(one_shot_bufferize)}'", "-canonicalize", "-buffer-deallocation", "-convert-bufferization-to-memref", @@ -798,7 +798,7 @@ def call_function(self, fname: str, *args, **kwargs) -> Any: if not len(sig.parameters) == len(args) + len(kwargs): raise TypeError( f"{f.name} takes {len(sig.parameters)} " - f"argument{"s" if len(sig.parameters) > 1 else ""} " + f"argument{'s' if len(sig.parameters) > 1 else ''} " f"but {len(args) + len(kwargs)} were given" ) @@ -1088,7 +1088,7 @@ def call_function(self, fname: str, *args) -> Any: if not len(sig.parameters) == len(args): raise TypeError( f"{f.name} takes {len(sig.parameters)} positional " - f"argument{"s" if len(sig.parameters) > 1 else ""} " + f"argument{'s' if len(sig.parameters) > 1 else ''} " f"but {len(args)} were given" ) diff --git a/src/pydsl/func.py b/src/pydsl/func.py index 7ef1edc..16dff39 100644 --- a/src/pydsl/func.py +++ b/src/pydsl/func.py @@ -486,9 +486,9 @@ def signature_to_class( argst = tuple([p.annotation for p in lower_sig.parameters.values()]) rett = lower_sig.return_annotation - assert ( - rett is not Signature.empty - ), "signature contains Signature.empty after lowered" + assert rett is not Signature.empty, ( + "signature contains Signature.empty after lowered" + ) return cls.class_factory(argst, rett) @@ -648,9 +648,9 @@ def init_val(self) -> mlir.Operation: *self._lower_typing(), ) - assert ( - self.visibility == Visibility.PRIVATE - ), "TransformSequence must be private by nature" + assert self.visibility == Visibility.PRIVATE, ( + "TransformSequence must be private by nature" + ) return val @@ -821,9 +821,9 @@ def on_Call( arguments to a SubtreeOut, to be passed into __call__. """ fn: InlineFunction = attr_chain[-1] - assert isinstance( - fn, InlineFunction - ), "InlineFunction on_Call called on not an InlineFunction" + assert isinstance(fn, InlineFunction), ( + "InlineFunction on_Call called on not an InlineFunction" + ) args = tuple(visitor.visit(arg) for arg in (*prefix_args, *node.args)) kwargs = {kw.arg: visitor.visit(kw.value) for kw in node.keywords} diff --git a/src/pydsl/linalg.py b/src/pydsl/linalg.py index b434a1e..0475402 100644 --- a/src/pydsl/linalg.py +++ b/src/pydsl/linalg.py @@ -141,9 +141,9 @@ def payload( elif issubclass(t, Int) and t.sign == Sign.UNSIGNED: return fn_unsigned, TypeFn.cast_unsigned else: - assert ( - False - ), "Already checked type of t, this should be uncreachable" + assert False, ( + "Already checked type of t, this should be uncreachable" + ) return payload diff --git a/src/pydsl/macro.py b/src/pydsl/macro.py index 6dce195..a811766 100644 --- a/src/pydsl/macro.py +++ b/src/pydsl/macro.py @@ -286,24 +286,24 @@ def parse_args( # Python doesn't have functor abstraction, so we'll have to do # functor mapping for each of the common Python data type match annorigin: - case builtins.list if ( - ArgCompiler.is_type_ArgCompiler(annargs[0]) + case builtins.list if ArgCompiler.is_type_ArgCompiler( + annargs[0] ): argcomp = ArgCompiler.from_type(annargs[0]) binding[name] = [ argcomp.compile(visitor, i) for i in binding[name] ] - case builtins.tuple if ( - ArgCompiler.is_type_ArgCompiler(annargs[0]) + case builtins.tuple if ArgCompiler.is_type_ArgCompiler( + annargs[0] ): argcomp = ArgCompiler.from_type(annargs[0]) binding[name] = tuple([ argcomp.compile(visitor, i) for i in binding[name] ]) - case builtins.dict if ( - ArgCompiler.is_type_ArgCompiler(annargs[0]) + case builtins.dict if ArgCompiler.is_type_ArgCompiler( + annargs[0] ): argcomp = ArgCompiler.from_type(annargs[0]) binding[name] = { @@ -311,8 +311,8 @@ def parse_args( for (k, v) in binding[name].items() } - case builtins.set if ( - ArgCompiler.is_type_ArgCompiler(annargs[0]) + case builtins.set if ArgCompiler.is_type_ArgCompiler( + annargs[0] ): argcomp = ArgCompiler.from_type(annargs[0]) binding[name] = { diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index 928b2cc..9cb49b7 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -135,9 +135,9 @@ def __iter__(self): ) def rank(self) -> int: - assert len(self.shape) == len( - self.strides - ), "rank of a RankedMemRefDescriptor is inconsistent!" + assert len(self.shape) == len(self.strides), ( + "rank of a RankedMemRefDescriptor is inconsistent!" + ) return len(self.shape) @@ -546,9 +546,9 @@ def __init__(self, rep: OpView | Value) -> None: lower_single(self.element_type) == rep.type.element_type, ]): raise TypeError( - f"expected shape {"x".join([str(sh) for sh in self.shape])}" + f"expected shape {'x'.join([str(sh) for sh in self.shape])}" f"x{lower_single(self.element_type)}, got representation with shape " - f"{"x".join([str(sh) for sh in rep.type.shape])}" + f"{'x'.join([str(sh) for sh in rep.type.shape])}" f"x{rep.type.element_type}" ) @@ -1028,7 +1028,7 @@ def calc_shape(memref_shape: tuple, assoc: list[list[int]]): res = 1 for i in group: dim = memref_shape[i] - if (dim == DYNAMIC or res == DYNAMIC): + if dim == DYNAMIC or res == DYNAMIC: res = DYNAMIC else: res *= dim @@ -1038,20 +1038,16 @@ def calc_shape(memref_shape: tuple, assoc: list[list[int]]): @CallMacro.generate() -def collapse_shape( - visitor: ToMLIRBase, - mem: Compiled, - assoc: Evaluated -): +def collapse_shape(visitor: ToMLIRBase, mem: Compiled, assoc: Evaluated): shpe = calc_shape(mem.shape, assoc) result_type = MemRef[mem.element_type, *shpe] return result_type( memref.CollapseShapeOp( - lower_single(result_type), - lower_single(mem), - assoc + lower_single(result_type), lower_single(mem), assoc ) ) + + def split_static_dynamic_dims( shape: Iterable[Number | SupportsIndex], ) -> tuple[list[int], list[Index]]: diff --git a/src/pydsl/scope.py b/src/pydsl/scope.py index d55e526..43a3038 100644 --- a/src/pydsl/scope.py +++ b/src/pydsl/scope.py @@ -70,8 +70,7 @@ def globals(self) -> dict[str, Any]: """ if self.global_scope is None: raise AssertionError( - "attempted to call globals() on a stack without a global " - "scope" + "attempted to call globals() on a stack without a global scope" ) return self.global_scope.f_locals diff --git a/src/pydsl/tensor.py b/src/pydsl/tensor.py index 4236a2b..6854657 100644 --- a/src/pydsl/tensor.py +++ b/src/pydsl/tensor.py @@ -116,9 +116,9 @@ def __init__(self, rep: OpView | Value) -> None: lower_single(self.element_type) == rep.type.element_type, ]): raise TypeError( - f"expected shape {"x".join([str(sh) for sh in self.shape])}" + f"expected shape {'x'.join([str(sh) for sh in self.shape])}" f"x{lower_single(self.element_type)}, got OpView with shape " - f"{"x".join([str(sh) for sh in rep.type.shape])}" + f"{'x'.join([str(sh) for sh in rep.type.shape])}" f"x{rep.type.element_type}" ) diff --git a/src/pydsl/type.py b/src/pydsl/type.py index 0c690c3..a8acb8d 100644 --- a/src/pydsl/type.py +++ b/src/pydsl/type.py @@ -852,9 +852,9 @@ class F64(Float, width=64, mlir_type=F64Type): # the target the current machine this runs on def get_index_width() -> int: s = log2(sys.maxsize + 1) + 1 - assert ( - s.is_integer() - ), "the compiler cannot determine the index size of the current " + assert s.is_integer(), ( + "the compiler cannot determine the index size of the current " + ) f"system. sys.maxsize yielded {sys.maxsize}" return int(s) diff --git a/src/pydsl/vector.py b/src/pydsl/vector.py index 314198e..612b4dd 100644 --- a/src/pydsl/vector.py +++ b/src/pydsl/vector.py @@ -90,9 +90,9 @@ def __init__(self, rep: OpView | Value) -> None: lower_single(self.element_type) == rep.type.element_type, ]): raise TypeError( - f"expected shape {"x".join([str(sh) for sh in self.shape])}" + f"expected shape {'x'.join([str(sh) for sh in self.shape])}" f"x{lower_single(self.element_type)}, got OpView with shape " - f"{"x".join([str(sh) for sh in rep.type.shape])}" + f"{'x'.join([str(sh) for sh in rep.type.shape])}" f"x{rep.type.element_type}" ) diff --git a/tests/e2e/test_memref.py b/tests/e2e/test_memref.py index 7a1f136..5d892df 100644 --- a/tests/e2e/test_memref.py +++ b/tests/e2e/test_memref.py @@ -9,7 +9,15 @@ from pydsl.gpu import GPU_AddrSpace import pydsl.linalg as linalg import pydsl.memref as memref -from pydsl.memref import alloc, alloca, dealloc, DYNAMIC, MemRef, MemRefFactory, collapse_shape +from pydsl.memref import ( + alloc, + alloca, + dealloc, + DYNAMIC, + MemRef, + MemRefFactory, + collapse_shape, +) from pydsl.type import Bool, F32, F64, Index, SInt16, Tuple, UInt32 from helper import compilation_failed_from, failed_from, multi_arange, run @@ -182,15 +190,18 @@ def f() -> MemRef[F64, 4, 6]: def test_alloc_bad_align(): with compilation_failed_from(TypeError): + @compile() def f(): alloc((4, 6), F64, alignment="xyz") with compilation_failed_from(ValueError): + @compile() def f(): alloc((4, 6), F64, alignment=-123) + def test_slice_memory_space(): """ We can't use GPU address spaces on CPU, so just test if it compiles. @@ -488,6 +499,7 @@ def my_func(a: MemRef[F32, 1, 3]) -> MemRef[F32, 3]: n1 = np.array([[1.0, 2.0, 3.0]], dtype=np.float32) assert all([a == b for a, b in zip(my_func(n1), [1.0, 2.0, 3.0])]) + def test_cast_basic(): @compile() def f( @@ -540,7 +552,8 @@ def f2(m1: MemRef[F64, DYNAMIC, 4]): @compile() def f3(m1: MemRef1): m1.cast(strides=(DYNAMIC, 16)) - + + def test_copy_basic(): @compile() def f(m1: MemRef[SInt16, 10, DYNAMIC], m2: MemRef[SInt16, 10, 10]): diff --git a/tests/helper.py b/tests/helper.py index a9f96ea..a05e7f1 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -18,9 +18,9 @@ def failed_from(error_type: type[Exception]): try: yield except Exception as e: - assert isinstance( - e, error_type - ), f"expected {error_type}, but instead got {type(e)}" + assert isinstance(e, error_type), ( + f"expected {error_type}, but instead got {type(e)}" + ) else: assert False, f"expected {error_type}, but no error was raised" @@ -30,9 +30,9 @@ def compilation_failed_from(error_type: type[Exception]): try: yield except Exception as e: - assert isinstance( - e, CompilationError - ), f"expected CompilationError, but instead got {type(e)}" + assert isinstance(e, CompilationError), ( + f"expected CompilationError, but instead got {type(e)}" + ) assert isinstance(e.exception, error_type), ( f"CompilationError is caused by {type(e.exception)}: " diff --git a/tests/polybench/benchmarks/deriche.py b/tests/polybench/benchmarks/deriche.py index 6fb7670..20a1c0f 100644 --- a/tests/polybench/benchmarks/deriche.py +++ b/tests/polybench/benchmarks/deriche.py @@ -27,16 +27,16 @@ def deriche( y1: Memref2DF32, y2: Memref2DF32, ) -> None: - xm = alloca(MemRef[F32, 1]) - tm = alloca(MemRef[F32, 1]) - ym1 = alloca(MemRef[F32, 1]) - ym2 = alloca(MemRef[F32, 1]) - xp1 = alloca(MemRef[F32, 1]) - xp2 = alloca(MemRef[F32, 1]) - tp1 = alloca(MemRef[F32, 1]) - tp2 = alloca(MemRef[F32, 1]) - yp1 = alloca(MemRef[F32, 1]) - yp2 = alloca(MemRef[F32, 1]) + xm = alloca((1,), F32) + tm = alloca((1,), F32) + ym1 = alloca((1,), F32) + ym2 = alloca((1,), F32) + xp1 = alloca((1,), F32) + xp2 = alloca((1,), F32) + tp1 = alloca((1,), F32) + tp2 = alloca((1,), F32) + yp1 = alloca((1,), F32) + yp2 = alloca((1,), F32) cst1: F32 = 1.0 cst2: F32 = 2.0 diff --git a/tests/polybench/benchmarks/gramschmidt.py b/tests/polybench/benchmarks/gramschmidt.py index 4b42501..b1890d7 100644 --- a/tests/polybench/benchmarks/gramschmidt.py +++ b/tests/polybench/benchmarks/gramschmidt.py @@ -26,7 +26,7 @@ def gramschmidt( Q: MemrefF32MN, ) -> None: b: F32 = 0.0 - temp = alloca(MemRef[F32, 1]) + temp = alloca((1,), F32) for k in arange(n): temp[0] = b for i in arange(m): diff --git a/tests/polybench/benchmarks/ludcmp.py b/tests/polybench/benchmarks/ludcmp.py index 0259bcf..1c34141 100644 --- a/tests/polybench/benchmarks/ludcmp.py +++ b/tests/polybench/benchmarks/ludcmp.py @@ -23,7 +23,7 @@ def ludcmp( x: MemrefF321D, y: MemrefF321D, ) -> None: - w = alloca(MemRef[F32, 1]) + w = alloca((1,), F32) for i in arange(n): for j in arange(i): w[0] = A[i, j] diff --git a/tests/polybench/benchmarks/symm.py b/tests/polybench/benchmarks/symm.py index 3de69e2..3b60a98 100644 --- a/tests/polybench/benchmarks/symm.py +++ b/tests/polybench/benchmarks/symm.py @@ -25,7 +25,7 @@ def symm( A: MemrefMMF32, B: MemrefMNF32, ) -> None: - temp_arr = alloca(MemRef[F32, 1]) + temp_arr = alloca((1,), F32) for i in arange(m): for j in arange(n): temp_arr[0] = F32(0.0)