From 1940abbc54c588b1025cc590ca65552bfbcf5bc5 Mon Sep 17 00:00:00 2001 From: Mengxuan Cai Date: Thu, 19 Feb 2026 17:45:08 -0500 Subject: [PATCH] add support for bufferization dialect's to_trensor and materialize_in_destination. add support for arith dialect's vadd and mul. Update memref.py and frontend.py. --- src/pydsl/arith.py | 67 +++++ src/pydsl/bufferization.py | 74 +++++ src/pydsl/frontend.py | 1 + src/pydsl/memref.py | 561 +++++++++++++++++++++++++++++++++---- tests/e2e/test_arith.py | 40 +++ 5 files changed, 681 insertions(+), 62 deletions(-) create mode 100644 src/pydsl/bufferization.py diff --git a/src/pydsl/arith.py b/src/pydsl/arith.py index 5fc9457..b51d570 100644 --- a/src/pydsl/arith.py +++ b/src/pydsl/arith.py @@ -2,6 +2,7 @@ from pydsl.macro import CallMacro, Compiled from pydsl.protocols import lower_single, SubtreeOut, ToMLIRBase from pydsl.type import Int, Float, Sign +from pydsl.vector import Vector import mlir.dialects.arith as arith @@ -59,3 +60,69 @@ def min(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut: return rett(arith.MinimumFOp(av, bv)) else: raise TypeError(f"cannot take min of {rett.__qualname__}") + + +@CallMacro.generate() +def trunc( + visitor: ToMLIRBase, + a: Compiled, + truncated_type: Compiled, + *, + round_mode: Compiled = None, +) -> SubtreeOut: + a_type = type(a) + out_type = truncated_type + if isinstance(a, Vector): + out_type = Vector.get(a.shape, truncated_type) + a_type = a.element_type + + if truncated_type.width >= a_type.width: + raise TypeError("truncated type must be smaller than called type.") + + if issubclass(a_type, Int): + out = arith.TruncIOp(lower_single(out_type), lower_single(a)) + elif issubclass(a_type, Float): + out = arith.TruncFOp(lower_single(out_type), lower_single(a)) + else: + raise TypeError(f"cannot take trunc of {a_type.__qualname__}") + if round_mode is not None: + out.attributes["round_mode"] = lower_single(round_mode) + return (out_type)(out) + + +@CallMacro.generate() +def vadd(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut: + rett = type(a) + + if not isinstance(a, Vector): + raise TypeError(f"NOT a vector addition operation") + if type(a) != type(b): + raise TypeError(f"VADD type {type(a)} does not match {type(b)}") + + a_type = a.element_type + if issubclass(a_type, Int): + op = arith.addi(lower_single(a), lower_single(b)) + elif issubclass(a_type, Float): + op = arith.addf(lower_single(a), lower_single(b)) + else: + raise TypeError(f"unsupported vector addition type: {a_type}") + return rett(op) + + +@CallMacro.generate() +def vmul(visitor: ToMLIRBase, a: Compiled, b: Compiled) -> SubtreeOut: + rett = type(a) + + if not isinstance(a, Vector): + raise TypeError(f"NOT a vector multiplication operation") + if type(a) != type(b): + raise TypeError(f"VMUL type {type(a)} does not match {type(b)}") + + a_type = a.element_type + if issubclass(a_type, Int): + op = arith.muli(lower_single(a), lower_single(b)) + elif issubclass(a_type, Float): + op = arith.mulf(lower_single(a), lower_single(b)) + else: + raise TypeError(f"unsupported vector multiplication type: {a_type}") + return rett(op) diff --git a/src/pydsl/bufferization.py b/src/pydsl/bufferization.py new file mode 100644 index 0000000..069ac0e --- /dev/null +++ b/src/pydsl/bufferization.py @@ -0,0 +1,74 @@ +from mlir.dialects import bufferization +from pydsl.macro import CallMacro, Compiled +from pydsl.tensor import Tensor +from pydsl.memref import MemRef +from pydsl.protocols import ToMLIRBase, lower_single, SubtreeOut + +TensorFactory = Tensor.class_factory + + +def verify_all_memref(*args): + """ + Checks that all arguments are MemRef. + Raises a TypeError otherwise. + """ + + # Collect argument type names for error messages + arg_type_names = [] + for arg in args: + arg_type_names.append(type(arg).__qualname__) + arg_type_str = ", ".join(arg_type_names) + + # Check that every argument is a MemRef + for arg in args: + if not isinstance(arg, MemRef): + raise TypeError( + "bufferization operation expects arguments of type MemRef, " + f"got {arg_type_str}" + ) + + +def verify_all_tensor(*args): + """ + Checks that all arguments are Tensor. + Raises a TypeError otherwise. + """ + + # Collect argument type names for error messages + arg_type_names = [] + for arg in args: + arg_type_names.append(type(arg).__qualname__) + arg_type_str = ", ".join(arg_type_names) + + # Check that every argument is a Tensor + for arg in args: + if not isinstance(arg, Tensor): + raise TypeError( + "bufferization operation expects arguments of type Tensor, " + f"got {arg_type_str}" + ) + + +@CallMacro.generate() +def to_tensor(visitor: "ToMLIRBase", x: Compiled) -> SubtreeOut: + verify_all_memref(x) + + rep = bufferization.to_tensor( + lower_single(x), restrict=True, writable=True + ) + static_shape = rep.type.shape + t_type = TensorFactory(tuple(static_shape), rep.type.element_type) + + return t_type(rep) + + +@CallMacro.generate() +def materialize_in_destination( + visitor: "ToMLIRBase", x: Compiled, y: Compiled +): + verify_all_tensor(x) + verify_all_memref(y) + bufferization.MaterializeInDestinationOp( + None, lower_single(x), lower_single(y), writable=True + ) + return diff --git a/src/pydsl/frontend.py b/src/pydsl/frontend.py index 31d4a8c..20a30ce 100644 --- a/src/pydsl/frontend.py +++ b/src/pydsl/frontend.py @@ -846,6 +846,7 @@ def get_supported_dialects(self) -> set[Dialect]: Dialect.from_name("transform"), Dialect.from_name("transform.loop"), Dialect.from_name("transform.structured"), + Dialect.from_name("vector"), } diff --git a/src/pydsl/memref.py b/src/pydsl/memref.py index 9cb49b7..8e22c43 100644 --- a/src/pydsl/memref.py +++ b/src/pydsl/memref.py @@ -6,12 +6,12 @@ from ctypes import POINTER, c_void_p from dataclasses import dataclass from enum import Enum -from functools import cache, reduce +from functools import cache from typing import TYPE_CHECKING, Final import mlir.ir as mlir import numpy as np -from mlir.dialects import affine, memref +from mlir.dialects import affine, arith, memref from mlir.ir import ( DenseI64ArrayAttr, MemRefType, @@ -30,6 +30,8 @@ ) from pydsl.type import ( Index, + UInt8, + SInt8, Lowerable, Number, Slice, @@ -171,19 +173,47 @@ def from_CType( ) -def are_dims_compatible(*dims): - """Return True if all dimensions are compatible, considering DYNAMIC.""" - concrete_dims = {d for d in dims if d != DYNAMIC} - return len(concrete_dims) <= 1 +def are_dims_compatible(*dims, allow_broadcast=False): + """ + Return True if all dimensions are compatible. + + Compatibility rules: + - If allow_broadcast=False: + all non-DYNAMIC dims must be equal + - If allow_broadcast=True: + dims are compatible if they can broadcast + (same dim or 1, ignoring DYNAMIC) + """ + # Remove DYNAMIC (always compatible) + concrete_dims = [d for d in dims if d != DYNAMIC] + + if not concrete_dims: # all are DYNAMIC + return True + + if not allow_broadcast: + # Strict equality + return len(set(concrete_dims)) == 1 + # Broadcasting mode: all must be 1 or the maximum dimension + max_dim = max(concrete_dims) + return all(d == 1 or d == max_dim for d in concrete_dims) -def are_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> bool: + +def are_shapes_compatible( + arr1: Iterable[int], arr2: Iterable[int], broadcast: list | None = None +) -> bool: """ Returns whether arr1 and arr2 have the same elements, excluding positions where at least one of the values is DYNAMIC. """ + if broadcast is None: + broadcast = [False for _ in arr1] + else: + broadcast = [i in broadcast for i in range(len(arr1))] + return len(arr1) == len(arr2) and all( - are_dims_compatible(a, b) for a, b in zip(arr1, arr2) + are_dims_compatible(a, b, allow_broadcast=bc) + for a, b, bc in zip(arr1, arr2, broadcast) ) @@ -199,6 +229,24 @@ def assert_shapes_compatible(arr1: Iterable[int], arr2: Iterable[int]) -> None: ) +def get_dynamic_and_static_values(Values: Tuple) -> Tuple[Tuple]: + """ + Checks for static python types (int, float, etc) which will be interpreted as static values in mlir, + and also for dynamic types (any PyDSL types). Then returns a tuple containing the dynamic and static value + tuples (dynamic_list, static_list), with the static value tuple containing DYNAMIC whenever a dynamic + value has been transferred to the dynamic list + """ + dynamic_list = [] + static_list = [] + for value in Values: + if isinstance(value, Lowerable): + dynamic_list.append(lower_single(value)) + static_list.append(DYNAMIC) + else: + static_list.append(value) + return tuple(dynamic_list), tuple(static_list) + + class UsesRMRD: """ A mixin class for adding CType support for classes that eventually lower @@ -509,9 +557,7 @@ def class_factory( get = class_factory @classmethod - def get_fully_dynamic( - cls, element_type, rank: int, memory_space: MemorySpace = None - ): + def get_fully_dynamic(cls, element_type, rank: int): """ Quick alias for returning a MemRef type where shape, offset, and strides are all dynamic. @@ -522,9 +568,145 @@ def get_fully_dynamic( element_type, offset=DYNAMIC, strides=dyn_list, - memory_space=memory_space, + memory_space=cls.memory_space + if hasattr(cls, "memory_space") + else None, + ) + + @classmethod + def calculate_subview_result(cls, low_list, size_list, step_list): + """ + Compute the result MemRef *class* for a subview with the given + low/size/step specifications. low, high, step may contain static + Python ints or dynamic Lowerable objects (e.g. Index). Each argument + may be None, a scalar, or an iterable. A single-element iterable is + broadcast to all dimensions. + """ + rank = len(cls.shape) + + def normalize(arg, default): + # Broadcast scalars, treat None as default for each dimension + if arg is None: + return tuple(default for _ in range(rank)) + if isinstance(arg, Iterable) and not isinstance( + arg, (Lowerable, Number) + ): + lst = tuple(arg) + if len(lst) == 1: + return tuple(lst[0] for _ in range(rank)) + if len(lst) != rank: + raise ValueError( + f"argument length {len(lst)} doesn't match rank {rank}" + ) + return lst + # single value (could be Number or Lowerable) + return tuple(arg for _ in range(rank)) + + # Defaults: low=0, size=original dim, step=1 + lows = normalize(low_list, 0) + sizes = normalize(size_list, None) + steps = normalize(step_list, 1) + + # Helper to decide if a value is dynamic (Lowerable) or not + def is_dynamic_val(v): + return isinstance(v, Lowerable) + + # Compute original (absolute) strides for default-layout memrefs + def compute_default_strides(shape): + strides = [None] * rank + prod = 1 + for i in range(rank - 1, -1, -1): + strides[i] = prod if prod != DYNAMIC else DYNAMIC + dim = shape[i] + if dim == DYNAMIC or prod == DYNAMIC: + prod = DYNAMIC + else: + prod = prod * dim + return tuple(strides) + + orig_strides = None + if cls.strides is not None: + orig_strides = tuple(cls.strides) + else: + orig_strides = compute_default_strides(cls.shape) + + # Compute static sizes (use DYNAMIC sentinel when dynamic) + static_sizes = [] + for i, s in enumerate(sizes): + if is_dynamic_val(s): + static_sizes.append(DYNAMIC) + elif s is None: + # default: whole dimension + static_sizes.append(cls.shape[i]) + else: + static_sizes.append(int(s)) + + # Compute resulting strides: stride' = orig_stride * step (or DYNAMIC) + result_strides = [] + for i, st in enumerate(steps): + if is_dynamic_val(st) or orig_strides[i] == DYNAMIC: + result_strides.append(DYNAMIC) + else: + result_strides.append(orig_strides[i] * int(st)) + + # Compute new offset: base_offset + sum(low_i * orig_stride_i) + base_off = cls.offset + if base_off == DYNAMIC: + new_offset = DYNAMIC + else: + new_offset_val = int(base_off) + offset_dynamic = False + for i, low in enumerate(lows): + if is_dynamic_val(low) or orig_strides[i] == DYNAMIC: + offset_dynamic = True + break + new_offset_val += int(low) * ( + orig_strides[i] if orig_strides[i] != DYNAMIC else 0 + ) + new_offset = DYNAMIC if offset_dynamic else new_offset_val + + # Decide whether to keep default (implicit) layout or emit explicit strided layout. + # Keep implicit default layout only if: + # - original layout was default (cls.strides is None), + # - every step is a compile-time 1, + # - every low is a compile-time 0, + # - and every resulting static size equals the original dimension (i.e., no shape change). + def const_int_or_none(v, default=None): + if is_dynamic_val(v): + return None + if v is None: + return default + try: + return int(v) + except Exception: + return None + + keep_implicit = ( + cls.strides is None + and all(const_int_or_none(st, 1) == 1 for st in steps) + and all(const_int_or_none(lo, 0) == 0 for lo in lows) + and all( + ( + static_sizes[i] != DYNAMIC + and static_sizes[i] == cls.shape[i] + ) + for i in range(rank) + ) ) + final_strides = None if keep_implicit else tuple(result_strides) + + # Build the result MemRef *class* + result_type = MemRef.class_factory( + tuple(static_sizes), + cls.element_type, + offset=(new_offset if new_offset != DYNAMIC else DYNAMIC), + strides=final_strides, + memory_space=cls.memory_space, + ) + + return result_type + def __init__(self, rep: OpView | Value) -> None: mlir_element_type = lower_single(self.element_type) if not any([ @@ -538,22 +720,31 @@ def __init__(self, rep: OpView | Value) -> None: if isinstance(rep, OpView): rep = rep.result - if (rep_type := type(rep.type)) is not MemRefType: + if isinstance(rep, Value): + rep_type = rep.type + element_type = rep_type.element_type + else: + return self.__init__(lower_single(rep)) + + if type(rep_type) is not MemRefType and not isinstance(rep, MemRef): raise TypeError(f"{rep_type} cannot be casted as a MemRef") if not all([ - self.shape == tuple(rep.type.shape), - lower_single(self.element_type) == rep.type.element_type, + all([ + x == y or x == DYNAMIC + for x, y in zip(self.shape, rep_type.shape) + ]), + lower_single(self.element_type) == element_type, ]): raise TypeError( f"expected shape {'x'.join([str(sh) for sh in self.shape])}" f"x{lower_single(self.element_type)}, got representation with shape " - f"{'x'.join([str(sh) for sh in rep.type.shape])}" - f"x{rep.type.element_type}" + f"{'x'.join([str(sh) for sh in rep_type.shape])}" + f"x{element_type}" ) cls_is_strided = self.strides is not None - rep_is_strided = isinstance(rep.type.layout, StridedLayoutAttr) + rep_is_strided = isinstance(rep_type.layout, StridedLayoutAttr) if cls_is_strided and not rep_is_strided: raise TypeError( @@ -566,14 +757,14 @@ def __init__(self, rep: OpView | Value) -> None: ) if cls_is_strided and not all([ - self.offset == rep.type.layout.offset, - self.strides == tuple(rep.type.layout.strides), + self.offset == rep_type.layout.offset, + self.strides == tuple(rep_type.layout.strides), ]): raise TypeError( f"expected layout with offset = {self.offset}," - f"strides = {self.strtides}, got representation with" - f"offset = {rep.type.layout.offset}, strides = " - f"{tuple(rep.type.layout.strides)}" + f"strides = {self.strides}, got representation with" + f"offset = {rep_type.layout.offset}, strides = " + f"{tuple(rep_type.layout.strides)}" ) self.value = rep @@ -615,9 +806,7 @@ def cons_affine_load(am: AffineMapExpr): lo_list, size_list, step_list = slices_to_mlir_format( key_list, self.runtime_shape ) - result_type = self.get_fully_dynamic( - self.element_type, dim, self.memory_space - ) + result_type = self.get_fully_dynamic(self.element_type, dim) dynamic_i64_attr = DenseI64ArrayAttr.get([DYNAMIC] * dim) rep = memref.SubViewOp( result_type.lower_class()[0], @@ -697,6 +886,105 @@ def on_class_getitem( # Equivalent to cls.__class_getitem__(args) return cls[args] + @CallMacro.generate(method_type=MethodType.INSTANCE) + def subview( + visitor: ToMLIRBase, + self: typing.Self, + *, + offsets: Evaluated = None, + sizes: Evaluated = None, + strides: Evaluated = None, + result_type: Evaluated = None, + ) -> typing.Self: + """ + is an explicit version of __getitem__ that allows for specifying the + result type. + """ + result_type = ( + result_type + if result_type is not None + else self.calculate_subview_result(offsets, sizes, strides) + ) + dynamic_offsets, static_offsets = get_dynamic_and_static_values( + offsets if offsets is not None else tuple([0] * self.rank()) + ) + dynamic_sizes, static_sizes = get_dynamic_and_static_values( + sizes if sizes is not None else self.shape + ) + dynamic_strides, static_strides = get_dynamic_and_static_values( + strides if strides is not None else tuple([1] * self.rank()) + ) + rep = memref.SubViewOp( + lower_single(result_type), + lower_single(self), + dynamic_offsets, + dynamic_sizes, + dynamic_strides, + static_offsets, + static_sizes, + static_strides, + ) + return result_type(rep) + + @CallMacro.generate(method_type=MethodType.INSTANCE) + def view( + visitor: ToMLIRBase, + self: typing.Self, + shape: Compiled, + dtype: Evaluated, + *, + byte_offset: Evaluated = None, + strides: Evaluated = None, + ) -> typing.Self: + """ """ + if not isinstance(shape, Tuple): + raise TypeError( + f"shape should be a Tuple, got {type(shape).__qualname__}" + ) + + if not issubclass(self.element_type, UInt8) and not issubclass( + self.element_type, SInt8 + ): + raise TypeError( + f"memref.view should be called on a array of type i8, got {self.element_type}" + ) + + if len(self.shape) != 1: + raise TypeError( + f"memref.view should be called on a 1D array, got {len(self.shape)}" + ) + + shape = shape.as_iterable(visitor) + strides = tuple(strides) if strides is not None else None + static_shape, dynamic_sizes = split_static_dynamic_dims(shape) + + m_type = MemRefFactory( + tuple(static_shape), + dtype, + memory_space=self.memory_space, + strides=strides, + ) + + return m_type( + memref.ViewOp( + lower_single(m_type), + lower_single(self), + lower_single(byte_offset) + if byte_offset is not None + else lower_single(Index(0)), + lower_flatten(dynamic_sizes), + ) + ) + + @CallMacro.generate(method_type=MethodType.INSTANCE) + def explicit_cast( + visitor: ToMLIRBase, + self: typing.Self, + result_type: Evaluated = None, + ) -> typing.Self: + rep = memref.cast(lower_single(result_type), lower_single(self)) + return result_type(rep) + @CallMacro.generate(method_type=MethodType.INSTANCE) def cast( visitor: ToMLIRBase, @@ -705,6 +993,7 @@ def cast( *, offset: Evaluated = None, strides: Evaluated = (-1,), + reinterpret: Evaluated = False, ) -> typing.Self: """ Converts a memref from one type to an equivalent type with a compatible @@ -738,6 +1027,33 @@ def f(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]: # to represent default layout... This is definitely not a great # solution, but it's unclear how to do this better. + if reinterpret: + dynamic_offsets, static_offsets = get_dynamic_and_static_values(( + offset, + )) + dynamic_strides, static_strides = get_dynamic_and_static_values( + strides + ) + dynamic_sizes, static_sizes = get_dynamic_and_static_values(shape) + result_type = self.class_factory( + static_sizes, + self.element_type, + offset=static_offsets[0], + strides=static_strides, + memory_space=self.memory_space, + ) + rep = memref.reinterpret_cast( + lower_single(result_type), + lower_single(self), + dynamic_offsets, + dynamic_sizes, + dynamic_strides, + static_offsets, + static_sizes, + static_strides, + ) + return result_type(rep) + shape = tuple(shape) if shape is not None else self.shape offset = int(offset) if offset is not None else self.offset strides = ( @@ -772,7 +1088,11 @@ def f(m1: MemRef[F32, DYNAMIC, 32, 5]) -> MemRef[F32, 64, 32, DYNAMIC]: assert_shapes_compatible(self.strides, strides) result_type = self.class_factory( - shape, self.element_type, offset=offset, strides=strides + shape, + self.element_type, + offset=offset, + strides=strides, + memory_space=self.memory_space, ) rep = memref.cast(lower_single(result_type), lower_single(self)) return result_type(rep) @@ -837,8 +1157,9 @@ def _alloc_generic( alloc_func: Callable, shape: Compiled, dtype: Evaluated, - memory_space: MemorySpace | None = None, - alignment: int | None = None, + memory_space: MemorySpace | None, + alignment: int | None, + strides: Evaluated = None, ) -> SubtreeOut: """ Does the logic required for alloc/alloca. It was silly having two functions @@ -869,10 +1190,11 @@ def _alloc_generic( raise ValueError(f"alignment must be positive, got {alignment}") shape = shape.as_iterable(visitor) + strides = tuple(strides) if strides is not None else None static_shape, dynamic_sizes = split_static_dynamic_dims(shape) m_type = MemRefFactory( - tuple(static_shape), dtype, memory_space=memory_space + tuple(static_shape), dtype, memory_space=memory_space, strides=strides ) return m_type( @@ -907,9 +1229,10 @@ def alloc( *, memory_space: Evaluated = None, alignment: Evaluated = None, + strides: Evaluated = None, ) -> SubtreeOut: return _alloc_generic( - visitor, memref.alloc, shape, dtype, memory_space, alignment + visitor, memref.alloc, shape, dtype, memory_space, alignment, strides ) @@ -952,6 +1275,67 @@ def copy(visitor: ToMLIRBase, src: Compiled, dst: Compiled) -> None: memref.copy(lower_single(src), lower_single(dst)) +def slice_shape( + key_list: list[Slice | SupportsIndex], shape: tuple[int] +) -> tuple[int]: + """ + Given a list of slices/indices, slice the shape. + """ + + dim = len(shape) + + if len(key_list) > dim: + raise IndexError( + f"number of subscripts {len(key_list)} is greater than number" + f"of dimensions {dim}" + ) + + while len(key_list) < dim: + key_list.append(Slice(None, None, None)) + + new_shape = [] + + def get_const_value(idx: Index, default: int): + """Extracts the integer value from an arith.constant OpResult, or returns default if None.""" + if idx is None: + return default if default != DYNAMIC else None + if isinstance(idx, int): + return idx + op = idx.value.owner + if isinstance(op, mlir.Operation) and isinstance( + op.operation, arith.ConstantOp + ): + return op.attributes["value"].value + else: + return None + + for i in range(dim): + key = key_list[i] + if isinstance(key, SupportsIndex): + new_shape.append(1) + elif isinstance(key, Slice): + lo, hi, step = key.lo, key.hi, key.step + + lo = get_const_value(lo, 0) + hi = get_const_value(hi, shape[i]) + step = get_const_value(step, 1) + if lo is None or hi is None or step is None: + dim = DYNAMIC + else: + if step == 0: + raise ValueError("slice step cannot be zero") + if step > 0: + dim = max(0, (hi - lo + step - 1) // step) + else: + dim = max(0, (lo - hi - step - 1) // (-step)) + new_shape.append(dim) + + else: + raise TypeError(f"{type(key)} cannot be used as a subscript") + + return new_shape + + def slices_to_mlir_format( key_list: list[Slice | SupportsIndex], runtime_shape: RuntimeMemrefShape ) -> tuple[list[Value], list[Value], list[Value]]: @@ -999,6 +1383,55 @@ def slices_to_mlir_format( return (lo_list, size_list, step_list) +def slices_to_mlir_format_dynamic( + key_list: list[Slice | SupportsIndex], runtime_shape: RuntimeMemrefShape +) -> tuple[list[Value], list[Value], list[Value]]: + """ + Given a list of slices/indices, converts the slices to MLIR format and + infers missing bounds/dimensions based on the dimensions of the tensor/memref. + 3 lists will be returned: [offsets], [sizes], [strides], which can be + passed to MLIR functions like tensor.extract_slice and memref.subview. + If key_list is shorter than runtime_shape, assume the entirety of the + remaining dimensions should be included ([:]). + There is currently no bounds checking! + Negative strides or indices are not supported (even though [3:-2:-1] can + be valid in normal Python) and result in undefined behaviour! + """ + + dim = len(runtime_shape) + + if len(key_list) > dim: + raise IndexError( + f"number of subscripts {len(key_list)} is greater than number" + f"of dimensions {dim}" + ) + + while len(key_list) < dim: + key_list.append(Slice(None, None, None)) + + lo_list = [] + size_list = [] + step_list = [] + + for i in range(dim): + key = key_list[i] + if isinstance(key, SupportsIndex): + lo_list.append(lower_single(Index(key))) + size_list.append(lower_single(Index(1))) + step_list.append(lower_single(Index(1))) + elif isinstance(key, Slice): + lo, size, step = key.get_args_static_and_dynamic( + Index(runtime_shape[i]) + ) + lo_list.append(lower_single(lo)) + size_list.append(lower_single(size)) + step_list.append(lower_single(step)) + else: + raise TypeError(f"{type(key)} cannot be used as a subscript") + + return (lo_list, size_list, step_list) + + def subtree_to_slices( visitor: "ToMLIRBase", key: SubtreeOut ) -> list[Slice | SupportsIndex]: @@ -1013,6 +1446,37 @@ def subtree_to_slices( raise TypeError(f"{type(key)} cannot be used as a subscript") +def split_static_dynamic_dims( + shape: Iterable[Number | SupportsIndex], +) -> tuple[list[int], list[Index]]: + """ + Given a shape with both static and dynamic dimensions, returns two lists: + static_shape and dynamic_sizes. static_shape is the same as shape, with all + dynamic dimensions replaced with the constant DYNAMIC. dynamic_sizes is a + list containing only the dynamic sizes, in order. Thus, it is true + that len(static_shape) == len(shape) and len(dynamic_dims) <= len(shape). + Raises a ValueError if the elements of shape are not Number or + SupportsIndex. + """ + static_shape = [] + dynamic_sizes = [] + + for s in shape: + match s: + case Number(): + static_shape.append(int(s.value)) + case SupportsIndex(): + static_shape.append(DYNAMIC) + dynamic_sizes.append(Index(s)) + case _: + raise ValueError( + f"dimension size should have type Number or Index, got " + f"{type(s).__qualname__}" + ) + + return static_shape, dynamic_sizes + + def calc_shape(memref_shape: tuple, assoc: list[list[int]]): # We need to make sure that assoc is valid # grouping of the dimensions 0 to n-1. @@ -1041,39 +1505,12 @@ def calc_shape(memref_shape: tuple, assoc: list[list[int]]): def collapse_shape(visitor: ToMLIRBase, mem: Compiled, assoc: Evaluated): shpe = calc_shape(mem.shape, assoc) result_type = MemRef[mem.element_type, *shpe] + result_type.offset = mem.offset + if mem.strides is not None: + result_type.strides = tuple([mem.strides[i] for i in range(len(shpe))]) + result_type.memory_space = mem.memory_space return result_type( memref.CollapseShapeOp( lower_single(result_type), lower_single(mem), assoc ) ) - - -def split_static_dynamic_dims( - shape: Iterable[Number | SupportsIndex], -) -> tuple[list[int], list[Index]]: - """ - Given a shape with both static and dynamic dimensions, returns two lists: - static_shape and dynamic_sizes. static_shape is the same as shape, with all - dynamic dimensions replaced with the constant DYNAMIC. dynamic_sizes is a - list containing only the dynamic sizes, in order. Thus, it is true - that len(static_shape) == len(shape) and len(dynamic_dims) <= len(shape). - Raises a ValueError if the elements of shape are not Number or - SupportsIndex. - """ - static_shape = [] - dynamic_sizes = [] - - for s in shape: - match s: - case Number(): - static_shape.append(int(s.value)) - case SupportsIndex(): - static_shape.append(DYNAMIC) - dynamic_sizes.append(Index(s)) - case _: - raise ValueError( - f"dimension size should have type Number or Index, got " - f"{type(s).__qualname__}" - ) - - return static_shape, dynamic_sizes diff --git a/tests/e2e/test_arith.py b/tests/e2e/test_arith.py index faedbdf..44109cd 100644 --- a/tests/e2e/test_arith.py +++ b/tests/e2e/test_arith.py @@ -435,6 +435,43 @@ def f(): f() +from pydsl.vector import Vector +import pydsl.arith as arith + +Vector2D = Vector.get((2, 2), UInt64) + + +def test_UInt64_vadd(): + @compile(globals()) + def f(arg0: Vector2D, arg1: Vector2D) -> Vector2D: + res = arith.vadd(arg0, arg1) + return res + + mlir = f.emit_mlir() + assert r"arith.addi" in mlir + + +def test_UInt64_vmul(): + @compile(globals()) + def f(arg0: Vector2D, arg1: Vector2D) -> Vector2D: + res = arith.vmul(arg0, arg1) + return res + + mlir = f.emit_mlir() + assert r"arith.muli" in mlir + + +def test_trunc_F64_to_F32(): + @compile(globals()) + def trunc(arg: F64) -> F32: + return arith.trunc(arg, F32) + + for i in f32_edges: + f32 = trunc(i) + assert isinstance(f32, float) + assert f32_isclose(i, f32) + + if __name__ == "__main__": run(test_val_range_Int8) run(test_ctype_range_UInt8) @@ -471,3 +508,6 @@ def f(): run(test_SInt_unary) run(test_Number_unary) run(test_Number_bool) + run(test_UInt64_vadd) + run(test_UInt64_vmul) + run(test_trunc_F64_to_F32)