diff --git a/Cargo.toml b/Cargo.toml index 2fdc21e..30eb044 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,13 +26,14 @@ fusion = [] autotune = [] [dependencies] -burn = { version = "0.16", default-features = false, features = ["std"] } -burn-tensor = "0.16" +burn = { version = "0.20", default-features = false, features = ["std"] } +burn-tensor = "0.20" mlx-rs = { package = "mlx-rs-burn", version = "0.25.5" } +mlx-sys = { package = "mlx-sys-burn", version = "0.2" } derive-new = "0.7" half = { version = "2.4", features = ["num-traits"] } num-traits = "0.2" [dev-dependencies] -burn = { version = "0.16", features = ["train", "dataset"] } +burn = { version = "0.20", features = ["train", "dataset"] } serial_test = "3.2" diff --git a/src/backend.rs b/src/backend.rs index 1ff2377..7c252bf 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,11 +1,14 @@ //! MLX Backend implementation for Burn. -use burn_tensor::backend::Backend; -use burn_tensor::TensorMetadata; +use burn_tensor::backend::{Backend, ExecutionError}; +use burn_tensor::{DType, TensorMetadata}; +use burn_tensor::quantization::QuantScheme; use mlx_rs::Array; +use std::marker::PhantomData; use std::sync::atomic::{AtomicU64, Ordering}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; // Global seed for random number generation static SEED: AtomicU64 = AtomicU64::new(0); @@ -43,17 +46,17 @@ unsafe impl Send for MlxTensorPrimitive {} unsafe impl Sync for MlxTensorPrimitive {} impl TensorMetadata for MlxTensorPrimitive { - fn dtype(&self) -> burn_tensor::DType { + fn dtype(&self) -> DType { // Map MLX dtype to Burn dtype match self.array.dtype() { - mlx_rs::Dtype::Float32 => burn_tensor::DType::F32, - mlx_rs::Dtype::Float16 => burn_tensor::DType::F16, - mlx_rs::Dtype::Bfloat16 => burn_tensor::DType::BF16, - mlx_rs::Dtype::Float64 => burn_tensor::DType::F64, - mlx_rs::Dtype::Int32 => burn_tensor::DType::I32, - mlx_rs::Dtype::Int64 => burn_tensor::DType::I64, - mlx_rs::Dtype::Bool => 
burn_tensor::DType::Bool, - _ => burn_tensor::DType::F32, // Default fallback + mlx_rs::Dtype::Float32 => DType::F32, + mlx_rs::Dtype::Float16 => DType::F16, + mlx_rs::Dtype::Bfloat16 => DType::BF16, + mlx_rs::Dtype::Float64 => DType::F64, + mlx_rs::Dtype::Int32 => DType::I32, + mlx_rs::Dtype::Int64 => DType::I64, + mlx_rs::Dtype::Bool => DType::Bool, + _ => DType::F32, // Default fallback } } @@ -62,56 +65,91 @@ impl TensorMetadata for MlxTensorPrimitive { } } -/// Quantized tensor primitive (placeholder for future implementation). +/// Quantized tensor primitive storing MLX's native quantized representation. #[derive(Debug, Clone)] pub struct MlxQuantizedTensorPrimitive { - /// The underlying tensor (stored as float for now). - pub tensor: MlxTensorPrimitive, - /// Quantization scheme. - pub scheme: QuantizationScheme, -} - -/// Quantization scheme. -#[derive(Debug, Clone, Copy, Default)] -pub enum QuantizationScheme { - #[default] - None, + /// Quantized weight values (MLX's packed uint format). + pub quantized: Array, + /// Per-group scale factors. + pub scales: Array, + /// Per-group zero-point biases. + pub biases: Array, + /// Logical tensor shape (e.g. [in_features, out_features]). + pub shape: Vec, + /// MLX group size (e.g. 32 or 64). + pub group_size: i32, + /// Bit width (4 or 8). + pub bits: i32, + /// Burn quantization scheme (for round-tripping back to Burn format). + pub scheme: QuantScheme, } -// SAFETY: Same as MlxTensorPrimitive +// SAFETY: Same as MlxTensorPrimitive — MLX uses internal synchronization. 
unsafe impl Send for MlxQuantizedTensorPrimitive {} unsafe impl Sync for MlxQuantizedTensorPrimitive {} impl TensorMetadata for MlxQuantizedTensorPrimitive { - fn dtype(&self) -> burn_tensor::DType { - self.tensor.dtype() + fn dtype(&self) -> DType { + DType::QFloat(self.scheme) } fn shape(&self) -> burn_tensor::Shape { - burn_tensor::Shape::from(self.tensor.shape.clone()) + burn_tensor::Shape::from(self.shape.clone()) } } impl burn_tensor::quantization::QTensorPrimitive for MlxQuantizedTensorPrimitive { - fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { - // Return a reference to a static scheme - static SYMMETRIC: burn_tensor::quantization::QuantizationScheme = - burn_tensor::quantization::QuantizationScheme::PerTensorSymmetric( - burn_tensor::quantization::QuantizationType::QInt8, - ); - &SYMMETRIC + fn scheme(&self) -> &QuantScheme { + &self.scheme + } +} + +/// MLX Backend for Burn, generic over float precision. +/// +/// The default float type is `f32`. Use `Mlx` (or the `MlxHalf` alias) +/// for half-precision inference, which halves memory bandwidth and leverages +/// Apple Silicon's native f16 support. +/// +/// # Examples +/// +/// ```ignore +/// use burn_mlx::{Mlx, MlxHalf}; +/// +/// // f32 backend (default, same as before) +/// type Backend32 = Mlx; +/// +/// // f16 backend for faster inference +/// type Backend16 = MlxHalf; +/// ``` +pub struct Mlx { + _phantom: PhantomData, +} + +impl std::fmt::Debug for Mlx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Mlx").finish() } } -/// MLX Backend for Burn. 
-#[derive(Debug, Default, Clone, Copy)] -pub struct Mlx; +impl Default for Mlx { + fn default() -> Self { + Self { _phantom: PhantomData } + } +} + +impl Clone for Mlx { + fn clone(&self) -> Self { + *self + } +} -impl Backend for Mlx { +impl Copy for Mlx {} + +impl Backend for Mlx { type Device = MlxDevice; type FloatTensorPrimitive = MlxTensorPrimitive; - type FloatElem = f32; + type FloatElem = F; type IntTensorPrimitive = MlxTensorPrimitive; type IntElem = i32; @@ -120,23 +158,34 @@ impl Backend for Mlx { type BoolElem = bool; type QuantizedTensorPrimitive = MlxQuantizedTensorPrimitive; - type QuantizedEncoding = i8; - fn name() -> String { + fn name(_device: &Self::Device) -> String { "mlx".to_string() } - fn seed(seed: u64) { + fn seed(_device: &Self::Device, seed: u64) { SEED.store(seed, Ordering::SeqCst); // MLX uses its own seeding mechanism - mlx_rs::random::seed(seed); + let _ = mlx_rs::random::seed(seed); + } + + fn supports_dtype(_device: &Self::Device, dtype: DType) -> bool { + matches!( + dtype, + DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::I32 | DType::I64 | DType::Bool + ) } - fn sync(device: &Self::Device) { - // MLX is lazy-evaluated; sync forces evaluation - // This is a no-op in MLX as synchronization happens implicitly - // when reading tensor values - let _ = device; + fn sync(_device: &Self::Device) -> Result<(), ExecutionError> { + let stream = mlx_rs::Stream::default(); + let status = unsafe { mlx_sys::mlx_synchronize(stream.as_ptr()) }; + if status == 0 { + Ok(()) + } else { + Err(ExecutionError::WithContext { + reason: "MLX stream synchronization failed".into(), + }) + } } } diff --git a/src/device.rs b/src/device.rs index 0918c34..6f89029 100644 --- a/src/device.rs +++ b/src/device.rs @@ -32,11 +32,24 @@ impl fmt::Display for MlxDevice { } } -impl DeviceOps for MlxDevice { - fn id(&self) -> burn_tensor::backend::DeviceId { +impl burn_tensor::backend::Device for MlxDevice { + fn from_id(device_id: 
burn_tensor::backend::DeviceId) -> Self { + match device_id.type_id { + 0 => MlxDevice::Cpu, + _ => MlxDevice::Gpu, + } + } + + fn to_id(&self) -> burn_tensor::backend::DeviceId { match self { MlxDevice::Cpu => burn_tensor::backend::DeviceId::new(0, 0), MlxDevice::Gpu => burn_tensor::backend::DeviceId::new(1, 0), } } + + fn device_count(_type_id: u16) -> usize { + 1 + } } + +impl DeviceOps for MlxDevice {} diff --git a/src/element.rs b/src/element.rs index 27e920b..6977924 100644 --- a/src/element.rs +++ b/src/element.rs @@ -2,7 +2,8 @@ use burn_tensor::{DType, Element}; use half::{bf16, f16}; -use mlx_rs::Dtype; +use mlx_rs::{Array, Dtype}; +use num_traits::{Float, FromPrimitive}; /// Trait for elements that can be used with MLX. pub trait MlxElement: Element + Clone + Send + Sync + 'static { @@ -130,6 +131,87 @@ impl MlxElement for bool { } } +/// Trait for float elements that can be used as the primary float type in the MLX backend. +/// +/// This enables `Mlx` to be generic over the float precision (f32, f16, bf16, f64). +pub trait FloatMlxElement: MlxElement + Float + FromPrimitive { + /// Create a scalar MLX array from this value. + fn scalar_array(value: Self) -> Array; + + /// Create a scalar MLX array from an f64 constant. + fn f64_scalar_array(value: f64) -> Array { + Self::scalar_array(Self::from_f64(value).unwrap()) + } + + /// Create an MLX array from a slice of elements. + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array; + + /// Create a zeros array in this element's dtype. + fn zeros_array(shape: &[i32]) -> Array; + + /// Create a ones array in this element's dtype. + fn ones_array(shape: &[i32]) -> Array; + + /// Read an MLX array's data as a vector of this element type. + fn array_to_vec(array: &Array) -> Vec; + + /// Cast an MLX array to this element's dtype. 
+ fn cast_array(array: &Array) -> Array; +} + +impl FloatMlxElement for f32 { + fn scalar_array(value: Self) -> Array { Array::from_f32(value) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for f16 { + fn scalar_array(value: Self) -> Array { Array::from_slice(&[value], &[1]) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for bf16 { + fn scalar_array(value: Self) -> Array { Array::from_slice(&[value], &[1]) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for f64 { + fn scalar_array(value: Self) -> Array { + // f64 doesn't implement FromSliceElement in mlx-rs, route through f32 + let arr = Array::from_f32(value as f32); + arr.as_type::().expect("cast to f64") + } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { + let f32_data: Vec = data.iter().map(|&v| v as f32).collect(); + let arr = 
Array::from_slice(&f32_data, shape); + arr.as_type::().expect("cast to f64") + } + fn zeros_array(shape: &[i32]) -> Array { + let arr = Array::zeros::(shape).expect("zeros"); + arr.as_type::().expect("cast to f64") + } + fn ones_array(shape: &[i32]) -> Array { + let arr = Array::ones::(shape).expect("ones"); + arr.as_type::().expect("cast to f64") + } + fn array_to_vec(array: &Array) -> Vec { + let arr = array.as_type::().expect("cast to f32"); + arr.as_slice::().iter().map(|&v| v as f64).collect() + } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index 180b3ab..0941819 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,9 +54,15 @@ mod ops; // Public exports pub use backend::{Mlx, MlxTensorPrimitive, MlxQuantizedTensorPrimitive}; pub use device::MlxDevice; -pub use element::MlxElement; +pub use element::{MlxElement, FloatMlxElement}; pub use tensor::MlxTensor; +/// Half-precision (f16) MLX backend for faster inference on Apple Silicon. +pub type MlxHalf = Mlx; + +/// BFloat16 MLX backend. +pub type MlxBf16 = Mlx; + /// Re-export mlx-rs types for advanced usage. 
pub mod mlx { pub use mlx_rs::*; @@ -65,7 +71,7 @@ pub mod mlx { #[cfg(test)] mod tests { use super::*; - use burn_tensor::{backend::Backend, Tensor, TensorData, Shape}; + use burn_tensor::{Tensor, TensorData, Shape}; #[test] fn test_device_creation() { @@ -163,12 +169,13 @@ mod tests { ); // Apply avg_pool2d with kernel_size=2, stride=2 - let pooled = Mlx::avg_pool2d( + let pooled = Mlx::::avg_pool2d( x.into_primitive().tensor(), [2, 2], [2, 2], [0, 0], true, + false, ); let shape = pooled.shape(); @@ -189,12 +196,13 @@ mod tests { ); // Apply max_pool2d with kernel_size=2, stride=2 - let pooled = Mlx::max_pool2d( + let pooled = Mlx::::max_pool2d( x.into_primitive().tensor(), [2, 2], [2, 2], [0, 0], [1, 1], + false, ); let shape = pooled.shape(); @@ -215,12 +223,13 @@ mod tests { ); // Apply max_pool2d_with_indices with kernel_size=2, stride=2 - let result = Mlx::max_pool2d_with_indices( + let result = Mlx::::max_pool2d_with_indices( x.into_primitive().tensor(), [2, 2], [2, 2], [0, 0], [1, 1], + false, ); let output_shape = result.output.shape(); @@ -244,12 +253,13 @@ mod tests { ); // Apply avg_pool1d with kernel_size=2, stride=2 - let pooled = Mlx::avg_pool1d( + let pooled = Mlx::::avg_pool1d( x.into_primitive().tensor(), 2, 2, 0, true, + false, ); let shape = pooled.shape(); @@ -270,12 +280,13 @@ mod tests { ); // Apply max_pool1d with kernel_size=2, stride=2 - let pooled = Mlx::max_pool1d( + let pooled = Mlx::::max_pool1d( x.into_primitive().tensor(), 2, 2, 0, 1, + false, ); let shape = pooled.shape(); diff --git a/src/ops/bool_ops.rs b/src/ops/bool_ops.rs index a1fab31..fb362ab 100644 --- a/src/ops/bool_ops.rs +++ b/src/ops/bool_ops.rs @@ -1,13 +1,14 @@ //! Boolean tensor operations for MLX backend. 
-use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::BoolTensorOps, Shape, Slice, TensorData}; use mlx_rs::Array; -use std::ops::Range; +use mlx_rs::ops::indexing::{take_axis, take_along_axis}; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl BoolTensorOps for Mlx { +impl BoolTensorOps for Mlx { fn bool_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); @@ -19,19 +20,18 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } - async fn bool_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn bool_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + Ok(TensorData::new(data, shape)) } - fn bool_device(tensor: &MlxTensorPrimitive) -> MlxDevice { + fn bool_device(_tensor: &MlxTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } - fn bool_to_device(tensor: MlxTensorPrimitive, device: &MlxDevice) -> MlxTensorPrimitive { - let _ = device; + fn bool_to_device(tensor: MlxTensorPrimitive, _device: &MlxDevice) -> MlxTensorPrimitive { tensor } @@ -45,23 +45,58 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_zeros(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + Self::bool_empty(shape, device) + } + + fn bool_ones(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + let mlx_device = device.to_mlx_device(); + mlx_rs::Device::set_default(&mlx_device); + + let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); + let array = Array::ones::(&shape_i32).expect("Failed to create ones bool array"); + + MlxTensorPrimitive::new(array) + } + fn bool_reshape(tensor: MlxTensorPrimitive, shape: Shape) -> 
MlxTensorPrimitive { let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); let array = tensor.array.reshape(&shape_i32).expect("Failed to reshape"); MlxTensorPrimitive::new(array) } - fn bool_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> MlxTensorPrimitive { - // Placeholder - need proper slice implementation - tensor + fn bool_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); + let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) + .expect("Failed to slice"); + MlxTensorPrimitive::new(array) } fn bool_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - tensor + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); + let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) + .expect("Failed to slice_assign"); + MlxTensorPrimitive::new(array) } fn bool_into_int(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { @@ -70,7 +105,7 @@ impl BoolTensorOps for Mlx { } fn bool_into_float(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let array = tensor.array.as_type::().expect("Failed to cast to float"); + let array = F::cast_array(&tensor.array); MlxTensorPrimitive::new(array) } @@ -79,6 +114,16 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn 
bool_and(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::logical_and(&lhs.array, &rhs.array).expect("Failed to logical_and"); + MlxTensorPrimitive::new(array) + } + + fn bool_or(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::logical_or(&lhs.array, &rhs.array).expect("Failed to logical_or"); + MlxTensorPrimitive::new(array) + } + fn bool_swap_dims(tensor: MlxTensorPrimitive, dim1: usize, dim2: usize) -> MlxTensorPrimitive { let ndim = tensor.shape().len(); let mut axes: Vec = (0..ndim as i32).collect(); @@ -94,8 +139,10 @@ impl BoolTensorOps for Mlx { } fn bool_flip(tensor: MlxTensorPrimitive, axes: &[usize]) -> MlxTensorPrimitive { - // Placeholder - MLX doesn't have direct flip - tensor + let axes_i32: Vec = axes.iter().map(|&a| a as i32).collect(); + let array = mlx_rs::ops::flip(&tensor.array, &axes_i32[..]) + .expect("Failed to flip"); + MlxTensorPrimitive::new(array) } fn bool_expand(tensor: MlxTensorPrimitive, shape: Shape) -> MlxTensorPrimitive { @@ -109,6 +156,12 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_equal_elem(lhs: MlxTensorPrimitive, rhs: bool) -> MlxTensorPrimitive { + let scalar = Array::from_slice(&[rhs], &[1]); + let array = mlx_rs::ops::eq(&lhs.array, &scalar).expect("Failed to equal_elem"); + MlxTensorPrimitive::new(array) + } + fn bool_any(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = mlx_rs::ops::any(&tensor.array, false).expect("Failed to any"); MlxTensorPrimitive::new(array) @@ -129,9 +182,52 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_mask_where( + tensor: MlxTensorPrimitive, + mask: MlxTensorPrimitive, + value: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = mlx_rs::ops::r#where(&mask.array, &value.array, &tensor.array) + .expect("Failed to mask_where"); + MlxTensorPrimitive::new(array) + } + + fn bool_mask_fill( + tensor: 
MlxTensorPrimitive, + mask: MlxTensorPrimitive, + value: bool, + ) -> MlxTensorPrimitive { + let fill_val = Array::from_slice(&[value], &[1]); + let fill_broadcast = mlx_rs::ops::broadcast_to(&fill_val, tensor.array.shape()) + .expect("Failed to broadcast"); + let array = mlx_rs::ops::r#where(&mask.array, &fill_broadcast, &tensor.array) + .expect("Failed to mask_fill"); + MlxTensorPrimitive::new(array) + } + + fn bool_gather( + dim: usize, + tensor: MlxTensorPrimitive, + indices: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = take_along_axis(&tensor.array, &indices.array, dim as i32) + .expect("Failed to gather"); + MlxTensorPrimitive::new(array) + } + + fn bool_scatter_or( + dim: usize, + tensor: MlxTensorPrimitive, + indices: MlxTensorPrimitive, + value: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) + .expect("Failed to scatter_or"); + MlxTensorPrimitive::new(array) + } + async fn bool_argwhere(_tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { // MLX argwhere may not be available in mlx-rs bindings - // Placeholder: return empty tensor let empty = mlx_rs::Array::zeros::(&[0, 1]).expect("Failed to create empty array"); MlxTensorPrimitive::new(empty) } @@ -140,4 +236,33 @@ impl BoolTensorOps for Mlx { let array = mlx_rs::ops::repeat_axis::(tensor.array, dim as i32, times as i32).expect("Failed to repeat"); MlxTensorPrimitive::new(array) } + + fn bool_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); + let gathered = take_axis(&tensor.array, 
&indices, dim as i32).expect("take"); + + let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } } diff --git a/src/ops/float_ops.rs b/src/ops/float_ops.rs index 4f632be..cbb0aba 100644 --- a/src/ops/float_ops.rs +++ b/src/ops/float_ops.rs @@ -1,21 +1,22 @@ //! Float tensor operations for MLX backend. -use burn_tensor::{ops::FloatTensorOps, Distribution, FloatDType, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::FloatTensorOps, Distribution, FloatDType, Shape, Slice, TensorData}; +use half::{bf16, f16}; use mlx_rs::Array; use mlx_rs::ops::indexing::{take_axis, take_along_axis}; -use std::ops::Range; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl FloatTensorOps for Mlx { +impl FloatTensorOps for Mlx { fn float_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape: Vec = data.shape.iter().map(|&s| s as i32).collect(); - let values: Vec = data.to_vec().expect("Failed to convert data to f32 vec"); - let array = Array::from_slice(&values, &shape); + let values: Vec = data.to_vec().expect("Failed to convert data to vec"); + let array = F::array_from_slice(&values, &shape); MlxTensorPrimitive::new(array) } @@ -30,7 +31,9 @@ impl FloatTensorOps for Mlx { let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); - let array = match distribution { + // Generate random values in f32 (widest support in mlx-rs random API), + // then cast to the target float type F. 
+ let array_f32 = match distribution { Distribution::Default => { mlx_rs::random::uniform::(0.0, 1.0, &shape_i32, None) .expect("Failed to create uniform random array") @@ -58,14 +61,15 @@ impl FloatTensorOps for Mlx { } }; + let array = F::cast_array(&array_f32); MlxTensorPrimitive::new(array) } - async fn float_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn float_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); - let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + let data: Vec = F::array_to_vec(&tensor.array); + Ok(TensorData::new(data, shape)) } fn float_device(tensor: &MlxTensorPrimitive) -> MlxDevice { @@ -78,12 +82,12 @@ impl FloatTensorOps for Mlx { tensor } - fn float_empty(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn float_empty(shape: Shape, device: &MlxDevice, _dtype: FloatDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); - let array = Array::zeros::(&shape_i32).expect("Failed to create empty array"); + let array = F::zeros_array(&shape_i32); MlxTensorPrimitive::new(array) } @@ -93,8 +97,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_add_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_add_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::add(&lhs.array, &scalar).expect("Failed to add scalar"); MlxTensorPrimitive::new(array) } @@ -104,8 +108,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_sub_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_sub_scalar(lhs: MlxTensorPrimitive, rhs: F) -> 
MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::subtract(&lhs.array, &scalar).expect("Failed to subtract scalar"); MlxTensorPrimitive::new(array) } @@ -115,8 +119,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_mul_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_mul_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::multiply(&lhs.array, &scalar).expect("Failed to multiply scalar"); MlxTensorPrimitive::new(array) } @@ -126,8 +130,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_div_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_div_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::divide(&lhs.array, &scalar).expect("Failed to divide scalar"); MlxTensorPrimitive::new(array) } @@ -137,8 +141,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_remainder_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_remainder_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::remainder(&lhs.array, &scalar).expect("Failed to remainder scalar"); MlxTensorPrimitive::new(array) @@ -155,7 +159,7 @@ impl FloatTensorOps for Mlx { } fn float_recip(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let one = Array::from_f32(1.0); + let one = F::f64_scalar_array(1.0); let array = mlx_rs::ops::divide(&one, &tensor.array).expect("Failed to recip"); MlxTensorPrimitive::new(array) } @@ -203,15 +207,14 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_scatter( + fn float_scatter_add( dim: usize, tensor: MlxTensorPrimitive, indices: 
MlxTensorPrimitive, value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Use put_along_axis for scatter operation let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to scatter"); + .expect("Failed to scatter_add"); MlxTensorPrimitive::new(array) } @@ -225,21 +228,27 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_select_assign( + fn float_select_add( tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Use put_along_axis for select_assign operation let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to select_assign"); + .expect("Failed to select_add"); MlxTensorPrimitive::new(array) } - fn float_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + fn float_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) .expect("Failed to slice"); MlxTensorPrimitive::new(array) @@ -247,11 +256,18 @@ impl FloatTensorOps for Mlx { fn float_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + let shape = tensor.shape().to_vec(); + let starts: Vec = 
slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) .expect("Failed to slice_assign"); MlxTensorPrimitive::new(array) @@ -270,9 +286,9 @@ impl FloatTensorOps for Mlx { fn float_mask_fill( tensor: MlxTensorPrimitive, mask: MlxTensorPrimitive, - value: f32, + value: F, ) -> MlxTensorPrimitive { - let fill_val = Array::from_f32(value); + let fill_val = F::scalar_array(value); let fill_broadcast = mlx_rs::ops::broadcast_to(&fill_val, tensor.array.shape()).expect("Failed to broadcast"); let array = mlx_rs::ops::r#where(&mask.array, &fill_broadcast, &tensor.array) @@ -285,8 +301,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_equal_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_equal_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::eq(&lhs.array, &scalar).expect("Failed to equal_elem"); MlxTensorPrimitive::new(array) } @@ -296,8 +312,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_greater_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_greater_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::gt(&lhs.array, &scalar).expect("Failed to greater_elem"); MlxTensorPrimitive::new(array) } @@ -307,8 +323,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_greater_equal_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_greater_equal_elem(lhs: 
MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); + let array = mlx_rs::ops::ge(&lhs.array, &scalar).expect("Failed to greater_equal_elem"); MlxTensorPrimitive::new(array) } @@ -318,8 +334,8 @@ MlxTensorPrimitive::new(array) } - fn float_lower_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_lower_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::lt(&lhs.array, &scalar).expect("Failed to lower_elem"); MlxTensorPrimitive::new(array) } @@ -329,8 +345,8 @@ MlxTensorPrimitive::new(array) } - fn float_lower_equal_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_lower_equal_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::le(&lhs.array, &scalar).expect("Failed to lower_equal_elem"); MlxTensorPrimitive::new(array) } @@ -389,7 +405,11 @@ } fn float_powf_scalar(tensor: MlxTensorPrimitive, value: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(value); + // NOTE(review): inlined — a trait impl block cannot contain associated + // functions that are not members of the trait (rustc E0407), so the + // former `float_powf_scalar_impl` helper would not compile. + // Route the f32 exponent through f64 so it lands in F's dtype. + let scalar = F::f64_scalar_array(value as f64); let array = mlx_rs::ops::power(&tensor.array, &scalar).expect("Failed to powf_scalar"); MlxTensorPrimitive::new(array) } @@ -491,21 +511,21 @@ MlxTensorPrimitive::new(array) } - fn float_clamp(tensor: MlxTensorPrimitive, min: f32, max: f32) -> MlxTensorPrimitive { - let min_arr = Array::from_f32(min); - let max_arr = Array::from_f32(max); + fn float_clamp(tensor: MlxTensorPrimitive, min: F, max: F) -> MlxTensorPrimitive { + let min_arr = F::scalar_array(min); + let max_arr = F::scalar_array(max); let 
array = mlx_rs::ops::clip(&tensor.array, (&min_arr, &max_arr)).expect("Failed to clamp"); MlxTensorPrimitive::new(array) } - fn float_clamp_min(tensor: MlxTensorPrimitive, min: f32) -> MlxTensorPrimitive { - let min_arr = Array::from_f32(min); + fn float_clamp_min(tensor: MlxTensorPrimitive, min: F) -> MlxTensorPrimitive { + let min_arr = F::scalar_array(min); let array = mlx_rs::ops::maximum(&tensor.array, &min_arr).expect("Failed to clamp_min"); MlxTensorPrimitive::new(array) } - fn float_clamp_max(tensor: MlxTensorPrimitive, max: f32) -> MlxTensorPrimitive { - let max_arr = Array::from_f32(max); + fn float_clamp_max(tensor: MlxTensorPrimitive, max: F) -> MlxTensorPrimitive { + let max_arr = F::scalar_array(max); let array = mlx_rs::ops::minimum(&tensor.array, &max_arr).expect("Failed to clamp_max"); MlxTensorPrimitive::new(array) } @@ -546,7 +566,6 @@ impl FloatTensorOps for Mlx { fn float_sort(tensor: MlxTensorPrimitive, dim: usize, _descending: bool) -> MlxTensorPrimitive { let sorted = mlx_rs::ops::sort_axis(&tensor.array, dim as i32).expect("Failed to sort"); - // Note: MLX sort is ascending only; descending would need flip MlxTensorPrimitive::new(sorted) } @@ -570,9 +589,11 @@ impl FloatTensorOps for Mlx { fn float_cast(tensor: MlxTensorPrimitive, dtype: FloatDType) -> MlxTensorPrimitive { let array = match dtype { + FloatDType::F16 => tensor.array.as_type::().expect("cast to f16"), + FloatDType::BF16 => tensor.array.as_type::().expect("cast to bf16"), FloatDType::F32 => tensor.array.as_type::().expect("cast to f32"), FloatDType::F64 => tensor.array.as_type::().expect("cast to f64"), - _ => tensor.array, // Keep as-is for unsupported types + _ => tensor.array, }; MlxTensorPrimitive::new(array) } @@ -591,4 +612,148 @@ impl FloatTensorOps for Mlx { let array = mlx_rs::ops::ceil(&tensor.array).expect("Failed to ceil"); MlxTensorPrimitive::new(array) } + + fn float_trunc(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + // trunc(x) = sign(x) * floor(abs(x)) 
+ let abs_val = mlx_rs::ops::abs(&tensor.array).expect("abs"); + let floored = mlx_rs::ops::floor(&abs_val).expect("floor"); + let sign_val = mlx_rs::ops::sign(&tensor.array).expect("sign"); + let array = mlx_rs::ops::multiply(&sign_val, &floored).expect("multiply"); + MlxTensorPrimitive::new(array) + } + + fn float_tan(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::tan(&tensor.array).expect("Failed to tan"); + MlxTensorPrimitive::new(array) + } + + fn float_cosh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cosh(&tensor.array).expect("Failed to cosh"); + MlxTensorPrimitive::new(array) + } + + fn float_sinh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::sinh(&tensor.array).expect("Failed to sinh"); + MlxTensorPrimitive::new(array) + } + + fn float_acos(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::acos(&tensor.array).expect("Failed to acos"); + MlxTensorPrimitive::new(array) + } + + fn float_acosh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::acosh(&tensor.array).expect("Failed to acosh"); + MlxTensorPrimitive::new(array) + } + + fn float_asin(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::asin(&tensor.array).expect("Failed to asin"); + MlxTensorPrimitive::new(array) + } + + fn float_asinh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::asinh(&tensor.array).expect("Failed to asinh"); + MlxTensorPrimitive::new(array) + } + + fn float_atan(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::atan(&tensor.array).expect("Failed to atan"); + MlxTensorPrimitive::new(array) + } + + fn float_atanh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::atanh(&tensor.array).expect("Failed to atanh"); + MlxTensorPrimitive::new(array) + } + + fn float_atan2(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) 
-> MlxTensorPrimitive { + let array = mlx_rs::ops::atan2(&lhs.array, &rhs.array).expect("Failed to atan2"); + MlxTensorPrimitive::new(array) + } + + fn float_cross( + lhs: MlxTensorPrimitive, + rhs: MlxTensorPrimitive, + dim: usize, + ) -> MlxTensorPrimitive { + let dim_i32 = dim as i32; + + let a0 = take_axis(&lhs.array, &Array::from_int(0), dim_i32).expect("take"); + let a1 = take_axis(&lhs.array, &Array::from_int(1), dim_i32).expect("take"); + let a2 = take_axis(&lhs.array, &Array::from_int(2), dim_i32).expect("take"); + + let b0 = take_axis(&rhs.array, &Array::from_int(0), dim_i32).expect("take"); + let b1 = take_axis(&rhs.array, &Array::from_int(1), dim_i32).expect("take"); + let b2 = take_axis(&rhs.array, &Array::from_int(2), dim_i32).expect("take"); + + let r0 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a1, &b2).expect("mul"), + &mlx_rs::ops::multiply(&a2, &b1).expect("mul"), + ).expect("sub"); + let r1 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a2, &b0).expect("mul"), + &mlx_rs::ops::multiply(&a0, &b2).expect("mul"), + ).expect("sub"); + let r2 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a0, &b1).expect("mul"), + &mlx_rs::ops::multiply(&a1, &b0).expect("mul"), + ).expect("sub"); + + let array = mlx_rs::ops::stack_axis(&[&r0, &r1, &r2], dim_i32).expect("stack"); + MlxTensorPrimitive::new(array) + } + + fn float_cumsum(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumsum(&tensor.array, dim as i32, None, None) + .expect("Failed to cumsum"); + MlxTensorPrimitive::new(array) + } + + fn float_cumprod(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumprod(&tensor.array, dim as i32, None, None) + .expect("Failed to cumprod"); + MlxTensorPrimitive::new(array) + } + + fn float_cummin(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummin(&tensor.array, dim as i32, None, None) + .expect("Failed to cummin"); + 
MlxTensorPrimitive::new(array) + } + + fn float_cummax(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummax(&tensor.array, dim as i32, None, None) + .expect("Failed to cummax"); + MlxTensorPrimitive::new(array) + } + + fn float_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); + let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); + + let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } } diff --git a/src/ops/int_ops.rs b/src/ops/int_ops.rs index c2f85e9..5fe1f25 100644 --- a/src/ops/int_ops.rs +++ b/src/ops/int_ops.rs @@ -1,14 +1,14 @@ //! Integer tensor operations for MLX backend. 
-use burn_tensor::{ops::IntTensorOps, Distribution, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::IntTensorOps, Distribution, IntDType, Shape, Slice, TensorData}; use mlx_rs::Array; use mlx_rs::ops::indexing::{argmax_axis, argmin_axis, take_axis, take_along_axis}; -use std::ops::Range; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl IntTensorOps for Mlx { +impl IntTensorOps for Mlx { fn int_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); @@ -20,14 +20,14 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - async fn int_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn int_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + Ok(TensorData::new(data, shape)) } - fn int_device(tensor: &MlxTensorPrimitive) -> MlxDevice { + fn int_device(_tensor: &MlxTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } @@ -36,7 +36,7 @@ impl IntTensorOps for Mlx { tensor } - fn int_empty(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn int_empty(shape: Shape, device: &MlxDevice, _dtype: IntDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); @@ -44,11 +44,11 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_zeros(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { - Self::int_empty(shape, device) + fn int_zeros(shape: Shape, device: &MlxDevice, dtype: IntDType) -> MlxTensorPrimitive { + Self::int_empty(shape, device, dtype) } - fn int_ones(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn int_ones(shape: Shape, 
device: &MlxDevice, _dtype: IntDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); @@ -170,9 +170,16 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + fn int_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) .expect("Failed to slice"); MlxTensorPrimitive::new(array) @@ -180,11 +187,18 @@ impl IntTensorOps for Mlx { fn int_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) .expect("Failed to slice_assign"); MlxTensorPrimitive::new(array) @@ -215,10 +229,9 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_scatter(dim: 
usize, tensor: MlxTensorPrimitive, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Use put_along_axis for scatter operation + fn int_scatter_add(dim: usize, tensor: MlxTensorPrimitive, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to scatter"); + .expect("Failed to scatter_add"); MlxTensorPrimitive::new(array) } @@ -227,10 +240,9 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_select_assign(tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Use put_along_axis for select_assign operation + fn int_select_add(tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to select_assign"); + .expect("Failed to select_add"); MlxTensorPrimitive::new(array) } @@ -357,7 +369,7 @@ impl IntTensorOps for Mlx { } fn int_into_float(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let array = tensor.array.as_type::().expect("Failed to cast to float"); + let array = F::cast_array(&tensor.array); MlxTensorPrimitive::new(array) } @@ -407,4 +419,176 @@ impl IntTensorOps for Mlx { let array = mlx_rs::ops::all_axis(&tensor.array, dim as i32, true).expect("Failed to all_dim"); MlxTensorPrimitive::new(array) } + + fn int_matmul(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + // MLX matmul requires float, so cast to backend float type, matmul, then cast back + let lhs_f = F::cast_array(&lhs.array); + let rhs_f = F::cast_array(&rhs.array); + let result = lhs_f.matmul(&rhs_f).expect("matmul"); + let array = result.as_type::().expect("cast back"); + MlxTensorPrimitive::new(array) + } + + fn int_cast(tensor: MlxTensorPrimitive, dtype: 
IntDType) -> MlxTensorPrimitive { + let array = match dtype { + IntDType::I32 => tensor.array.as_type::().expect("cast to i32"), + IntDType::I64 => tensor.array.as_type::().expect("cast to i64"), + IntDType::I16 => tensor.array.as_type::().expect("cast to i16"), + IntDType::I8 => tensor.array.as_type::().expect("cast to i8"), + _ => tensor.array, + }; + MlxTensorPrimitive::new(array) + } + + fn int_cumsum(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumsum(&tensor.array, dim as i32, None, None) + .expect("Failed to cumsum"); + MlxTensorPrimitive::new(array) + } + + fn int_cumprod(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumprod(&tensor.array, dim as i32, None, None) + .expect("Failed to cumprod"); + MlxTensorPrimitive::new(array) + } + + fn int_cummin(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummin(&tensor.array, dim as i32, None, None) + .expect("Failed to cummin"); + MlxTensorPrimitive::new(array) + } + + fn int_cummax(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummax(&tensor.array, dim as i32, None, None) + .expect("Failed to cummax"); + MlxTensorPrimitive::new(array) + } + + fn int_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); + let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); + + let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + 
let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } + + // Bitwise operations - implemented in software as mlx-rs doesn't expose bitwise ops + fn bitwise_and(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a & b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_and_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a & rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_or(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a | b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_or_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a | rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_xor(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive 
{ + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a ^ b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_xor_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a ^ rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_not(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + tensor.array.eval().expect("eval"); + let data: Vec = tensor.array.as_slice().to_vec(); + let result: Vec = data.iter().map(|a| !a).collect(); + let shape: Vec = tensor.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_left_shift(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a << b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_left_shift_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a << rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + 
} + + fn bitwise_right_shift(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a >> b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_right_shift_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a >> rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } } diff --git a/src/ops/module_ops.rs b/src/ops/module_ops.rs index ce6856a..3449945 100644 --- a/src/ops/module_ops.rs +++ b/src/ops/module_ops.rs @@ -8,6 +8,7 @@ use mlx_rs::Array; use mlx_rs::ops::indexing::take_axis; use crate::backend::{Mlx, MlxTensorPrimitive}; +use crate::element::FloatMlxElement; /// Helper function to compute pooling using as_strided approach. /// This follows the pattern from mlx-rs nn/pooling.rs. @@ -184,17 +185,6 @@ fn max_pool2d_with_indices_impl( let local_indices = mlx_rs::ops::indexing::argmax_axis(&reshaped, 3, None).expect("argmax"); // Convert local indices (within kernel) to flat indices into padded NHWC input - // For each output position (n, oh, ow, c), the local_idx tells us which element - // in the kH*kW kernel was the max. - // - // The actual position in the padded input (NHWC layout) is: - // n * (H * W * C) + (oh * stride[0] + local_h) * (W * C) + (ow * stride[1] + local_w) * C + c - // where local_h = local_idx / kW, local_w = local_idx % kW - // - // We need to compute this index for the backward pass. 
- - // Create coordinate arrays for output positions - // Shape of output/indices: [N, out_H, out_W, C] let out_h_size = out_h as usize; let out_w_size = out_w as usize; let n_size = n as usize; @@ -203,22 +193,18 @@ fn max_pool2d_with_indices_impl( let w_size = w as usize; // Create index arrays for n, oh, ow, c dimensions - // n_idx: [N, 1, 1, 1] broadcast to [N, out_H, out_W, C] let n_range: Vec = (0..n_size as i32).collect(); let n_idx = Array::from_slice(&n_range, &[n_size as i32]) .reshape(&[n, 1, 1, 1]).expect("reshape"); - // oh_idx: [1, out_H, 1, 1] let oh_range: Vec = (0..out_h_size as i32).collect(); let oh_idx = Array::from_slice(&oh_range, &[out_h_size as i32]) .reshape(&[1, out_h, 1, 1]).expect("reshape"); - // ow_idx: [1, 1, out_W, 1] let ow_range: Vec = (0..out_w_size as i32).collect(); let ow_idx = Array::from_slice(&ow_range, &[out_w_size as i32]) .reshape(&[1, 1, out_w, 1]).expect("reshape"); - // c_idx: [1, 1, 1, C] let c_range: Vec = (0..c_size as i32).collect(); let c_idx = Array::from_slice(&c_range, &[c_size as i32]) .reshape(&[1, 1, 1, c]).expect("reshape"); @@ -232,13 +218,11 @@ fn max_pool2d_with_indices_impl( let sh_arr = Array::from_int(stride[0] as i32); let sw_arr = Array::from_int(stride[1] as i32); - // actual_h = oh * stride[0] + local_h let actual_h = mlx_rs::ops::add( &mlx_rs::ops::multiply(&oh_idx, &sh_arr).expect("mul"), &local_h ).expect("add"); - // actual_w = ow * stride[1] + local_w let actual_w = mlx_rs::ops::add( &mlx_rs::ops::multiply(&ow_idx, &sw_arr).expect("mul"), &local_w @@ -263,7 +247,7 @@ fn max_pool2d_with_indices_impl( (output, flat_indices) } -impl ModuleOps for Mlx { +impl ModuleOps for Mlx { fn conv1d( x: MlxTensorPrimitive, weight: MlxTensorPrimitive, @@ -293,7 +277,6 @@ impl ModuleOps for Mlx { // Add bias if provided if let Some(b) = bias { - // Reshape bias from [C_out] to [1, C_out, 1] let b_shape = b.shape(); let b_reshaped = b.array.reshape(&[1, b_shape[0] as i32, 1]).expect("reshape bias"); output = 
mlx_rs::ops::add(&output, &b_reshaped).expect("add bias"); @@ -359,7 +342,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<1>, ) -> MlxTensorPrimitive { - // conv_transpose1d is complex in MLX - placeholder x } @@ -369,7 +351,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<2>, ) -> MlxTensorPrimitive { - // conv_transpose2d is complex in MLX - placeholder x } @@ -379,7 +360,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<3>, ) -> MlxTensorPrimitive { - // Placeholder x } @@ -391,9 +371,8 @@ impl ModuleOps for Mlx { _bias: Option, _options: DeformConvOptions<2>, ) -> MlxTensorPrimitive { - // Deformable convolution is not supported in MLX - placeholder let shape = [1i32, 1, 1, 1]; - let array = Array::zeros::(&shape).expect("zeros"); + let array = F::zeros_array(&shape); MlxTensorPrimitive::new(array) } @@ -405,10 +384,9 @@ impl ModuleOps for Mlx { _bias: Option, _out_grad: MlxTensorPrimitive, _options: DeformConvOptions<2>, - ) -> DeformConv2dBackward { - // Placeholder + ) -> DeformConv2dBackward> { let shape = [1i32, 1, 1, 1]; - let zeros = MlxTensorPrimitive::new(Array::zeros::(&shape).expect("zeros")); + let zeros = MlxTensorPrimitive::new(F::zeros_array(&shape)); DeformConv2dBackward::new( zeros.clone(), zeros.clone(), @@ -424,16 +402,13 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, _count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { // Burn uses NCL format, MLX uses NLC format - // Transpose from [N, C, L] to [N, L, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 1]).expect("transpose"); - // Apply padding if needed let x_padded = if padding > 0 { let pad = padding as i32; - // Pad only the L dimension (axis 1 in NLC format) - // PadWidth for [N, L, C]: [(0,0), (pad,pad), (0,0)] mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad, pad), (0, 0)], @@ -444,14 +419,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + 
mean_axes let pooled = pool1d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.mean_axes(axes, None) }); - // Transpose back from [N, L, C] to [N, C, L] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 2, 1]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -461,17 +433,14 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - // Pad H and W dimensions (axes 1 and 2 in NHWC format) - // PadWidth for [N, H, W, C]: [(0,0), (pad_h,pad_h), (pad_w,pad_w), (0,0)] mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -482,14 +451,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + mean_axes let pooled = pool2d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.mean_axes(axes, None) }); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -500,8 +466,8 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCHW format let input_shape = x.shape(); let n = input_shape[0]; let c = input_shape[1]; @@ -515,47 +481,27 @@ impl ModuleOps for Mlx { let pad_h = padding[0]; let pad_w = padding[1]; - // Padded input dimensions let h_padded = h + 2 * pad_h; let w_padded = w + 2 * pad_w; - // Output dimensions let out_h = (h_padded - kh) / sh + 1; let out_w = (w_padded - kw) / sw + 1; let pool_size = (kh * kw) as f32; - // Transpose grad from NCHW to NHWC for processing 
let grad_nhwc = mlx_rs::ops::transpose_axes(&grad.array, &[0, 2, 3, 1]).expect("transpose"); - // Scale gradient by 1/pool_size - let scale = Array::from_f32(1.0 / pool_size); + let scale = F::f64_scalar_array(1.0 / pool_size as f64); let grad_scaled = mlx_rs::ops::multiply(&grad_nhwc, &scale).expect("multiply"); - // Create zeros for padded input gradient (NHWC format) - let grad_input_padded = Array::zeros::(&[ + let grad_input_padded = F::zeros_array(&[ n as i32, h_padded as i32, w_padded as i32, c as i32, - ]).expect("zeros"); - - // For avg pooling backward, each output gradient contributes equally to all - // input positions in its window. We use scatter_add to accumulate gradients. - // - // For each output position (oh, ow), the window covers: - // h_start = oh * stride[0] - // w_start = ow * stride[1] - // positions: (h_start..h_start+kH, w_start..w_start+kW) - - // Create flat indices for all input positions that receive gradients - // We need to iterate over all output positions and all kernel positions - - // Build index arrays - // For each (oh, ow, kh_off, kw_off), compute flat index into padded input + ]); let mut all_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); - let mut all_n_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); let mut update_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); for ni in 0..n { @@ -568,14 +514,11 @@ impl ModuleOps for Mlx { let hi = h_start + khi; let wi = w_start + kwi; for ci in 0..c { - // Flat index in NHWC layout let flat_idx = (ni * h_padded * w_padded * c + hi * w_padded * c + wi * c + ci) as i32; all_indices.push(flat_idx); - all_n_indices.push(ni as i32); - // Index into the flat grad_scaled array let grad_idx = ni * out_h * out_w * c + ohi * out_w * c + owi * c @@ -588,7 +531,6 @@ impl ModuleOps for Mlx { } } - // Flatten the scaled gradient and gather the values we need let grad_flat = grad_scaled.flatten(None, None).expect("flatten"); let 
update_idx_arr = Array::from_slice( &update_indices.iter().map(|&x| x as i32).collect::>(), @@ -596,11 +538,9 @@ impl ModuleOps for Mlx { ); let updates = take_axis(&grad_flat, &update_idx_arr, 0).expect("take"); - // Flatten the input gradient and use scatter_add let grad_input_flat = grad_input_padded.flatten(None, None).expect("flatten"); let indices_arr = Array::from_slice(&all_indices, &[all_indices.len() as i32]); - // Use scatter_add: add updates to grad_input_flat at indices let result_flat = mlx_rs::ops::scatter_add( &grad_input_flat, &[&indices_arr], @@ -608,7 +548,6 @@ impl ModuleOps for Mlx { &[0], ).expect("scatter_add"); - // Reshape back to NHWC let result_nhwc = result_flat.reshape(&[ n as i32, h_padded as i32, @@ -616,7 +555,6 @@ impl ModuleOps for Mlx { c as i32, ]).expect("reshape"); - // Remove padding if present let result_unpadded = if pad_h > 0 || pad_w > 0 { mlx_rs::ops::slice( &result_nhwc, @@ -628,9 +566,7 @@ impl ModuleOps for Mlx { result_nhwc }; - // Transpose back from NHWC to NCHW let output = mlx_rs::ops::transpose_axes(&result_unpadded, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -640,15 +576,13 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, _dilation: usize, + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCL format, MLX uses NLC format - // Transpose from [N, C, L] to [N, L, C] let x_nlc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding > 0 { let pad = padding as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nlc, &[(0, 0), (pad, pad), (0, 0)], @@ -659,14 +593,11 @@ impl ModuleOps for Mlx { x_nlc }; - // Apply pooling using as_strided + max_axes let pooled = pool1d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.max_axes(axes, None) }); - // Transpose back from [N, L, C] to [N, C, 
L] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 2, 1]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -676,16 +607,14 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -696,14 +625,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + max_axes let pooled = pool2d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.max_axes(axes, None) }); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -713,9 +639,9 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, dilation: usize, - ) -> MaxPool1dWithIndices { - let output = Self::max_pool1d(x, kernel_size, stride, padding, dilation); - // Create dummy indices (placeholder) + _ceil_mode: bool, + ) -> MaxPool1dWithIndices> { + let output = Self::max_pool1d(x, kernel_size, stride, padding, dilation, false); let indices = MlxTensorPrimitive::new( Array::zeros::(&output.array.shape().iter().map(|&s| s as i32).collect::>()) .expect("zeros") @@ -729,16 +655,14 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], - ) -> MaxPool2dWithIndices { - // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] + _ceil_mode: bool, + ) -> MaxPool2dWithIndices> { let 
x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -749,10 +673,8 @@ impl ModuleOps for Mlx { x_nhwc }; - // Get max values and indices let (output_nhwc, indices_nhwc) = max_pool2d_with_indices_impl(&x_padded, kernel_size, stride); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&output_nhwc, &[0, 3, 1, 2]).expect("transpose"); let indices = mlx_rs::ops::transpose_axes(&indices_nhwc, &[0, 3, 1, 2]).expect("transpose"); @@ -768,12 +690,10 @@ impl ModuleOps for Mlx { _stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], + _ceil_mode: bool, output_grad: MlxTensorPrimitive, indices: MlxTensorPrimitive, - ) -> MaxPool2dBackward { - // The indices contain flat indices into the padded NHWC input tensor. - // We need to scatter the gradients to those positions. 
- + ) -> MaxPool2dBackward> { let input_shape = x.shape(); let n = input_shape[0]; let c = input_shape[1]; @@ -783,23 +703,18 @@ impl ModuleOps for Mlx { let pad_h = padding[0]; let pad_w = padding[1]; - // Padded dimensions let h_padded = h + 2 * pad_h; let w_padded = w + 2 * pad_w; - // Create zeros for padded input gradient (NHWC flattened) let total_size = n * h_padded * w_padded * c; - let grad_input_flat = Array::zeros::(&[total_size as i32]).expect("zeros"); + let grad_input_flat = F::zeros_array(&[total_size as i32]); - // Transpose grad and indices from NCHW to NHWC to match index computation let grad_nhwc = mlx_rs::ops::transpose_axes(&output_grad.array, &[0, 2, 3, 1]).expect("transpose"); let indices_nhwc = mlx_rs::ops::transpose_axes(&indices.array, &[0, 2, 3, 1]).expect("transpose"); - // Flatten both let grad_flat = grad_nhwc.flatten(None, None).expect("flatten"); let indices_flat = indices_nhwc.flatten(None, None).expect("flatten"); - // Scatter the gradients to the positions indicated by indices let result_flat = mlx_rs::ops::scatter_add( &grad_input_flat, &[&indices_flat], @@ -807,7 +722,6 @@ impl ModuleOps for Mlx { &[0], ).expect("scatter_add"); - // Reshape to NHWC let result_nhwc = result_flat.reshape(&[ n as i32, h_padded as i32, @@ -815,7 +729,6 @@ impl ModuleOps for Mlx { c as i32, ]).expect("reshape"); - // Remove padding if present let result_unpadded = if pad_h > 0 || pad_w > 0 { mlx_rs::ops::slice( &result_nhwc, @@ -827,18 +740,15 @@ impl ModuleOps for Mlx { result_nhwc }; - // Transpose back from NHWC to NCHW let output = mlx_rs::ops::transpose_axes(&result_unpadded, &[0, 3, 1, 2]).expect("transpose"); - MaxPool2dBackward::new(MlxTensorPrimitive::new(output)) } fn adaptive_avg_pool1d(x: MlxTensorPrimitive, output_size: usize) -> MlxTensorPrimitive { - // Calculate kernel_size and stride to achieve output_size let input_size = x.shape()[2]; let stride = input_size / output_size; let kernel_size = input_size - (output_size - 1) * stride; 
- Self::avg_pool1d(x, kernel_size, stride, 0, true) + Self::avg_pool1d(x, kernel_size, stride, 0, true, false) } fn adaptive_avg_pool2d(x: MlxTensorPrimitive, output_size: [usize; 2]) -> MlxTensorPrimitive { @@ -851,16 +761,15 @@ impl ModuleOps for Mlx { let kernel_h = input_h - (output_size[0] - 1) * stride_h; let kernel_w = input_w - (output_size[1] - 1) * stride_w; - Self::avg_pool2d(x, [kernel_h, kernel_w], [stride_h, stride_w], [0, 0], true) + Self::avg_pool2d(x, [kernel_h, kernel_w], [stride_h, stride_w], [0, 0], true, false) } fn adaptive_avg_pool2d_backward( x: MlxTensorPrimitive, _grad: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Placeholder: return zeros with input shape let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); - let output = Array::zeros::(&shape).expect("zeros"); + let output = F::zeros_array(&shape); MlxTensorPrimitive::new(output) } @@ -869,7 +778,6 @@ impl ModuleOps for Mlx { _output_size: [usize; 2], _options: InterpolateOptions, ) -> MlxTensorPrimitive { - // MLX doesn't have direct interpolate - placeholder x } @@ -879,9 +787,8 @@ impl ModuleOps for Mlx { _output_size: [usize; 2], _options: InterpolateOptions, ) -> MlxTensorPrimitive { - // Placeholder: return zeros with input shape let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); - let output = Array::zeros::(&shape).expect("zeros"); + let output = F::zeros_array(&shape); MlxTensorPrimitive::new(output) } @@ -889,7 +796,6 @@ impl ModuleOps for Mlx { weights: MlxTensorPrimitive, indices: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Embedding lookup - gather rows from weights based on indices let array = take_axis(&weights.array, &indices.array, 0) .expect("embedding"); MlxTensorPrimitive::new(array) @@ -900,8 +806,6 @@ impl ModuleOps for Mlx { _output_grad: MlxTensorPrimitive, _indices: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Scatter gradients back to weights - // Placeholder - proper implementation needed weights } } diff --git 
a/src/ops/other_ops.rs b/src/ops/other_ops.rs index eb61126..edeceb6 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -1,19 +1,21 @@ //! Additional ops implementations for MLX backend. use burn_tensor::{ - ops::{ActivationOps, QTensorOps, TransactionOps}, - quantization::QuantizationScheme, - Shape, TensorData, + backend::ExecutionError, + ops::{ActivationOps, FloatTensorOps, QTensorOps, TransactionOps}, + quantization::{QuantLevel, QuantScheme, QuantValue, QuantizedBytes}, + DType, Shape, Slice, TensorData, TensorPrimitive, }; use mlx_rs::Array; use crate::backend::{Mlx, MlxQuantizedTensorPrimitive, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; // ActivationOps - most methods have default implementations -impl ActivationOps for Mlx { +impl ActivationOps for Mlx { fn relu(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let array = mlx_rs::ops::maximum(&tensor.array, &zero).expect("relu"); MlxTensorPrimitive::new(array) } @@ -26,25 +28,26 @@ impl ActivationOps for Mlx { fn gelu(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { // GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) // Simplified: x * sigmoid(1.702 * x) - let coef = Array::from_f32(1.702); + let coef = F::f64_scalar_array(1.702); let scaled = mlx_rs::ops::multiply(&tensor.array, &coef).expect("multiply"); let sigmoid = mlx_rs::ops::sigmoid(&scaled).expect("sigmoid"); let array = mlx_rs::ops::multiply(&tensor.array, &sigmoid).expect("multiply"); MlxTensorPrimitive::new(array) } - fn leaky_relu(tensor: MlxTensorPrimitive, negative_slope: f32) -> MlxTensorPrimitive { - let array = mlx_rs::nn::leaky_relu(&tensor.array, negative_slope).expect("leaky_relu"); + fn leaky_relu(tensor: MlxTensorPrimitive, negative_slope: F) -> MlxTensorPrimitive { + let slope_f32 = num_traits::ToPrimitive::to_f32(&negative_slope).unwrap(); + let array = 
mlx_rs::nn::leaky_relu(&tensor.array, slope_f32).expect("leaky_relu"); MlxTensorPrimitive::new(array) } - fn hard_sigmoid(tensor: MlxTensorPrimitive, alpha: f32, beta: f32) -> MlxTensorPrimitive { - let alpha_arr = Array::from_f32(alpha); - let beta_arr = Array::from_f32(beta); + fn hard_sigmoid(tensor: MlxTensorPrimitive, alpha: F, beta: F) -> MlxTensorPrimitive { + let alpha_arr = F::scalar_array(alpha); + let beta_arr = F::scalar_array(beta); let scaled = mlx_rs::ops::multiply(&tensor.array, &alpha_arr).expect("multiply"); let shifted = mlx_rs::ops::add(&scaled, &beta_arr).expect("add"); - let zero = Array::from_f32(0.0); - let one = Array::from_f32(1.0); + let zero = F::f64_scalar_array(0.0); + let one = F::f64_scalar_array(1.0); let array = mlx_rs::ops::clip(&shifted, (&zero, &one)).expect("clip"); MlxTensorPrimitive::new(array) } @@ -56,7 +59,7 @@ impl ActivationOps for Mlx { } fn prelu(tensor: MlxTensorPrimitive, alpha: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let pos = mlx_rs::ops::maximum(&tensor.array, &zero).expect("max"); let neg = mlx_rs::ops::minimum(&tensor.array, &zero).expect("min"); let scaled_neg = mlx_rs::ops::multiply(&alpha.array, &neg).expect("multiply"); @@ -64,72 +67,198 @@ impl ActivationOps for Mlx { MlxTensorPrimitive::new(array) } - fn gelu_backward(x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Backward pass for GELU - placeholder + fn gelu_backward(_x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { grad } fn relu_backward(x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let mask = mlx_rs::ops::gt(&x.array, &zero).expect("greater"); - let mask_float = mask.as_type::().expect("cast"); + let mask_float = F::cast_array(&mask); let array = mlx_rs::ops::multiply(&grad.array, &mask_float).expect("multiply"); 
MlxTensorPrimitive::new(array) } } -// QTensorOps - Quantization operations (placeholder) -impl QTensorOps for Mlx { - fn q_from_data(data: TensorData, device: &MlxDevice) -> MlxQuantizedTensorPrimitive { - let tensor = >::float_from_data( - data.convert::(), - device, - ); +/// Helper: extract MLX quantization parameters (bits, group_size) from a Burn QuantScheme. +fn scheme_to_mlx_params(scheme: &QuantScheme, num_elements: usize) -> (i32, i32) { + let bits: i32 = match scheme.value { + QuantValue::Q8S | QuantValue::Q8F => 8, + QuantValue::Q4S | QuantValue::Q4F => 4, + _ => panic!("Unsupported quantization value: {:?}", scheme.value), + }; + let group_size: i32 = match scheme.level { + QuantLevel::Block(bs) => bs.as_slice()[0] as i32, + QuantLevel::Tensor => num_elements as i32, + }; + (bits, group_size) +} + +// QTensorOps - Native MLX quantization +impl QTensorOps for Mlx { + fn q_from_data(data: TensorData, _device: &MlxDevice) -> MlxQuantizedTensorPrimitive { + // Extract the quantization scheme from the data's dtype + let scheme = match data.dtype { + DType::QFloat(scheme) => scheme, + other => panic!("q_from_data called with non-quantized dtype: {:?}", other), + }; + + let num_elements: usize = data.shape.iter().product(); + let (bits, group_size) = scheme_to_mlx_params(&scheme, num_elements); + + // Unpack Burn's quantized bytes into i8 values + scale parameters + let qb = QuantizedBytes { + bytes: data.bytes, + scheme, + num_elements, + }; + let (i8_values, qparams) = qb.into_vec_i8(); + let scales = qparams.scales; + + // Dequantize on CPU: symmetric quantization means float_val = i8_val * scale + let mut float_vals = vec![0.0f32; i8_values.len()]; + for (block_idx, chunk) in i8_values.chunks(group_size as usize).enumerate() { + let scale = scales[block_idx]; + for (i, &val) in chunk.iter().enumerate() { + float_vals[block_idx * group_size as usize + i] = val as f32 * scale; + } + } + + // Create a float MLX Array from the dequantized values + let 
shape_i32: Vec = data.shape.iter().map(|&s| s as i32).collect(); + let float_array = Array::from_slice(&float_vals, &shape_i32); + + // Re-quantize into MLX's native format + let (quantized, mlx_scales, mlx_biases) = + mlx_rs::ops::quantize(&float_array, group_size, bits).expect("MLX quantize failed"); + MlxQuantizedTensorPrimitive { - tensor, - scheme: crate::backend::QuantizationScheme::None, + quantized, + scales: mlx_scales, + biases: mlx_biases, + shape: data.shape, + group_size, + bits, + scheme, } } fn quantize( tensor: MlxTensorPrimitive, - scheme: &QuantizationScheme, - qparams: burn_tensor::quantization::QuantizationParametersPrimitive, + scheme: &QuantScheme, + _qparams: burn_tensor::quantization::QuantizationParametersPrimitive, ) -> MlxQuantizedTensorPrimitive { + let num_elements: usize = tensor.shape.iter().product(); + let (bits, group_size) = scheme_to_mlx_params(scheme, num_elements); + let shape = tensor.shape.clone(); + + let (quantized, scales, biases) = + mlx_rs::ops::quantize(&tensor.array, group_size, bits).expect("MLX quantize failed"); + MlxQuantizedTensorPrimitive { - tensor, - scheme: crate::backend::QuantizationScheme::None, + quantized, + scales, + biases, + shape, + group_size, + bits, + scheme: *scheme, } } fn dequantize(tensor: MlxQuantizedTensorPrimitive) -> MlxTensorPrimitive { - tensor.tensor + let array = mlx_rs::ops::dequantize( + &tensor.quantized, + &tensor.scales, + &tensor.biases, + tensor.group_size, + tensor.bits, + ) + .expect("MLX dequantize failed"); + // MLX dequantize may return f32 regardless of F; cast to backend's float type. 
+ let array = F::cast_array(&array); + MlxTensorPrimitive::new(array) + } + + fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { + match (lhs, rhs) { + // float x quantized — the common case for inference (activation x weight) + (TensorPrimitive::Float(lhs_f), TensorPrimitive::QFloat(rhs_q)) => { + let result = mlx_rs::ops::quantized_matmul( + &lhs_f.array, + &rhs_q.quantized, + &rhs_q.scales, + &rhs_q.biases, + false, // no transpose — Burn stores weights as [in, out] + rhs_q.group_size, + rhs_q.bits, + ) + .expect("MLX quantized_matmul failed"); + // MLX quantized_matmul may return f32; cast to backend's float type. + let result = F::cast_array(&result); + TensorPrimitive::Float(MlxTensorPrimitive::new(result)) + } + // quantized x float — dequantize LHS + (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::Float(rhs_f)) => { + let lhs_f = Self::dequantize(lhs_q); + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) + } + // both quantized — dequantize both + (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::QFloat(rhs_q)) => { + let lhs_f = Self::dequantize(lhs_q); + let rhs_f = Self::dequantize(rhs_q); + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) + } + // both float — standard matmul + (TensorPrimitive::Float(lhs_f), TensorPrimitive::Float(rhs_f)) => { + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) + } + } } - fn q_device(tensor: &MlxQuantizedTensorPrimitive) -> MlxDevice { + fn q_device(_tensor: &MlxQuantizedTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } fn q_to_device( tensor: MlxQuantizedTensorPrimitive, - device: &MlxDevice, + _device: &MlxDevice, ) -> MlxQuantizedTensorPrimitive { tensor } fn q_reshape(tensor: MlxQuantizedTensorPrimitive, shape: Shape) -> MlxQuantizedTensorPrimitive { - let reshaped = >::float_reshape( - tensor.tensor, - shape, - ); - MlxQuantizedTensorPrimitive { - tensor: reshaped, - scheme: tensor.scheme, + let new_dims: Vec = shape.dims.to_vec(); + + // Fast path: if the last 2 
dimensions are unchanged, just update the + // logical shape. The underlying quantized/scales/biases arrays remain + // valid since they represent the same 2D weight matrix. + // This avoids expensive dequant→reshape→requant for trivial unsqueezes + // (e.g. [M, N] → [1, M, N]) triggered by nn::Linear::forward. + let old = &tensor.shape; + if old.len() >= 2 + && new_dims.len() >= 2 + && old[old.len() - 2] == new_dims[new_dims.len() - 2] + && old[old.len() - 1] == new_dims[new_dims.len() - 1] + { + return MlxQuantizedTensorPrimitive { + shape: new_dims, + ..tensor + }; } + + // Fallback: actual data reshape requires dequant → reshape → requant + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let reshaped = >::float_reshape(float_tensor, shape); + Self::quantize_dynamic(reshaped, &scheme) } - async fn q_into_data(tensor: MlxQuantizedTensorPrimitive) -> TensorData { - >::float_into_data(tensor.tensor).await + async fn q_into_data( + tensor: MlxQuantizedTensorPrimitive, + ) -> Result { + let float_tensor = Self::dequantize(tensor); + >::float_into_data(float_tensor).await } fn q_swap_dims( @@ -137,43 +266,41 @@ impl QTensorOps for Mlx { dim1: usize, dim2: usize, ) -> MlxQuantizedTensorPrimitive { - let swapped = >::float_swap_dims( - tensor.tensor, - dim1, - dim2, - ); - MlxQuantizedTensorPrimitive { - tensor: swapped, - scheme: tensor.scheme, + let ndim = tensor.shape.len(); + + // Fast path: swapping within batch/prefix dims (not the last 2) just + // updates the logical shape — the underlying quantized arrays are + // unaffected. Also handles the trivial dim1 == dim2 case. 
+ if ndim >= 2 && dim1 < ndim - 2 && dim2 < ndim - 2 { + let mut new_shape = tensor.shape.clone(); + new_shape.swap(dim1, dim2); + return MlxQuantizedTensorPrimitive { + shape: new_shape, + ..tensor + }; } + + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let swapped = >::float_swap_dims(float_tensor, dim1, dim2); + Self::quantize_dynamic(swapped, &scheme) } fn q_permute( tensor: MlxQuantizedTensorPrimitive, axes: &[usize], ) -> MlxQuantizedTensorPrimitive { - let permuted = >::float_permute( - tensor.tensor, - axes, - ); - MlxQuantizedTensorPrimitive { - tensor: permuted, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let permuted = >::float_permute(float_tensor, axes); + Self::quantize_dynamic(permuted, &scheme) } - fn q_flip( - tensor: MlxQuantizedTensorPrimitive, - axes: &[usize], - ) -> MlxQuantizedTensorPrimitive { - let flipped = >::float_flip( - tensor.tensor, - axes, - ); - MlxQuantizedTensorPrimitive { - tensor: flipped, - scheme: tensor.scheme, - } + fn q_flip(tensor: MlxQuantizedTensorPrimitive, axes: &[usize]) -> MlxQuantizedTensorPrimitive { + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let flipped = >::float_flip(float_tensor, axes); + Self::quantize_dynamic(flipped, &scheme) } fn q_select( @@ -181,45 +308,50 @@ impl QTensorOps for Mlx { dim: usize, indices: MlxTensorPrimitive, ) -> MlxQuantizedTensorPrimitive { - let selected = >::float_select( - tensor.tensor, - dim, - indices, - ); - MlxQuantizedTensorPrimitive { - tensor: selected, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let selected = >::float_select(float_tensor, dim, indices); + Self::quantize_dynamic(selected, &scheme) } fn q_slice( tensor: MlxQuantizedTensorPrimitive, - ranges: &[std::ops::Range], + slices: &[Slice], ) -> MlxQuantizedTensorPrimitive { - let sliced = >::float_slice( - tensor.tensor, - 
ranges, - ); - MlxQuantizedTensorPrimitive { - tensor: sliced, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let sliced = >::float_slice(float_tensor, slices); + Self::quantize_dynamic(sliced, &scheme) } - fn q_expand( - tensor: MlxQuantizedTensorPrimitive, - shape: Shape, - ) -> MlxQuantizedTensorPrimitive { - let expanded = >::float_expand( - tensor.tensor, - shape, - ); - MlxQuantizedTensorPrimitive { - tensor: expanded, - scheme: tensor.scheme, + fn q_expand(tensor: MlxQuantizedTensorPrimitive, shape: Shape) -> MlxQuantizedTensorPrimitive { + let new_dims: Vec = shape.dims.to_vec(); + let old = &tensor.shape; + + // Fast path: if the last 2 dimensions are unchanged and any new prefix + // dims are size-1 (broadcast markers), the underlying quantized arrays + // don't need modification. + if old.len() >= 2 + && new_dims.len() >= 2 + && old[old.len() - 2] == new_dims[new_dims.len() - 2] + && old[old.len() - 1] == new_dims[new_dims.len() - 1] + { + // Check that any extra leading dims are size 1 + let extra = new_dims.len().saturating_sub(old.len()); + if new_dims[..extra].iter().all(|&d| d == 1) { + return MlxQuantizedTensorPrimitive { + shape: new_dims, + ..tensor + }; + } } + + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let expanded = >::float_expand(float_tensor, shape); + Self::quantize_dynamic(expanded, &scheme) } } // TransactionOps - transaction batching (default impl) -impl TransactionOps for Mlx {} +impl TransactionOps for Mlx {}