From d04799a6247b58ed5c95da5b2df46a3218029996 Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Thu, 12 Feb 2026 12:25:30 -0800 Subject: [PATCH 1/6] Update burn dependency from 0.16 to 0.20 Migrates burn-mlx across 4 major burn versions (0.17-0.20) with all required API changes: Device trait impl, Slice replacing Range, ExecutionError return types, QuantScheme, renamed scatter/select ops, ceil_mode on pooling, and new required methods (trig, cumulative, bitwise, unfold, cross, cast, matmul for ints). Co-Authored-By: Claude Opus 4.6 --- Cargo.toml | 6 +- src/backend.rs | 58 +++++------ src/device.rs | 17 +++- src/lib.rs | 7 +- src/ops/bool_ops.rs | 154 ++++++++++++++++++++++++++--- src/ops/float_ops.rs | 203 +++++++++++++++++++++++++++++++++---- src/ops/int_ops.rs | 225 ++++++++++++++++++++++++++++++++++++++---- src/ops/module_ops.rs | 119 +++------------------- src/ops/other_ops.rs | 26 ++--- 9 files changed, 599 insertions(+), 216 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2fdc21e..0454a72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,13 +26,13 @@ fusion = [] autotune = [] [dependencies] -burn = { version = "0.16", default-features = false, features = ["std"] } -burn-tensor = "0.16" +burn = { version = "0.20", default-features = false, features = ["std"] } +burn-tensor = "0.20" mlx-rs = { package = "mlx-rs-burn", version = "0.25.5" } derive-new = "0.7" half = { version = "2.4", features = ["num-traits"] } num-traits = "0.2" [dev-dependencies] -burn = { version = "0.16", features = ["train", "dataset"] } +burn = { version = "0.20", features = ["train", "dataset"] } serial_test = "3.2" diff --git a/src/backend.rs b/src/backend.rs index 1ff2377..10a0063 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,7 +1,8 @@ //! MLX Backend implementation for Burn. 
use burn_tensor::backend::Backend; -use burn_tensor::TensorMetadata; +use burn_tensor::{DType, TensorMetadata}; +use burn_tensor::quantization::QuantScheme; use mlx_rs::Array; use std::sync::atomic::{AtomicU64, Ordering}; @@ -43,17 +44,17 @@ unsafe impl Send for MlxTensorPrimitive {} unsafe impl Sync for MlxTensorPrimitive {} impl TensorMetadata for MlxTensorPrimitive { - fn dtype(&self) -> burn_tensor::DType { + fn dtype(&self) -> DType { // Map MLX dtype to Burn dtype match self.array.dtype() { - mlx_rs::Dtype::Float32 => burn_tensor::DType::F32, - mlx_rs::Dtype::Float16 => burn_tensor::DType::F16, - mlx_rs::Dtype::Bfloat16 => burn_tensor::DType::BF16, - mlx_rs::Dtype::Float64 => burn_tensor::DType::F64, - mlx_rs::Dtype::Int32 => burn_tensor::DType::I32, - mlx_rs::Dtype::Int64 => burn_tensor::DType::I64, - mlx_rs::Dtype::Bool => burn_tensor::DType::Bool, - _ => burn_tensor::DType::F32, // Default fallback + mlx_rs::Dtype::Float32 => DType::F32, + mlx_rs::Dtype::Float16 => DType::F16, + mlx_rs::Dtype::Bfloat16 => DType::BF16, + mlx_rs::Dtype::Float64 => DType::F64, + mlx_rs::Dtype::Int32 => DType::I32, + mlx_rs::Dtype::Int64 => DType::I64, + mlx_rs::Dtype::Bool => DType::Bool, + _ => DType::F32, // Default fallback } } @@ -68,14 +69,7 @@ pub struct MlxQuantizedTensorPrimitive { /// The underlying tensor (stored as float for now). pub tensor: MlxTensorPrimitive, /// Quantization scheme. - pub scheme: QuantizationScheme, -} - -/// Quantization scheme. 
-#[derive(Debug, Clone, Copy, Default)] -pub enum QuantizationScheme { - #[default] - None, + pub scheme: QuantScheme, } // SAFETY: Same as MlxTensorPrimitive @@ -83,7 +77,7 @@ unsafe impl Send for MlxQuantizedTensorPrimitive {} unsafe impl Sync for MlxQuantizedTensorPrimitive {} impl TensorMetadata for MlxQuantizedTensorPrimitive { - fn dtype(&self) -> burn_tensor::DType { + fn dtype(&self) -> DType { self.tensor.dtype() } @@ -93,13 +87,8 @@ impl TensorMetadata for MlxQuantizedTensorPrimitive { } impl burn_tensor::quantization::QTensorPrimitive for MlxQuantizedTensorPrimitive { - fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme { - // Return a reference to a static scheme - static SYMMETRIC: burn_tensor::quantization::QuantizationScheme = - burn_tensor::quantization::QuantizationScheme::PerTensorSymmetric( - burn_tensor::quantization::QuantizationType::QInt8, - ); - &SYMMETRIC + fn scheme(&self) -> &QuantScheme { + &self.scheme } } @@ -120,23 +109,22 @@ impl Backend for Mlx { type BoolElem = bool; type QuantizedTensorPrimitive = MlxQuantizedTensorPrimitive; - type QuantizedEncoding = i8; - fn name() -> String { + fn name(_device: &Self::Device) -> String { "mlx".to_string() } - fn seed(seed: u64) { + fn seed(_device: &Self::Device, seed: u64) { SEED.store(seed, Ordering::SeqCst); // MLX uses its own seeding mechanism - mlx_rs::random::seed(seed); + let _ = mlx_rs::random::seed(seed); } - fn sync(device: &Self::Device) { - // MLX is lazy-evaluated; sync forces evaluation - // This is a no-op in MLX as synchronization happens implicitly - // when reading tensor values - let _ = device; + fn supports_dtype(_device: &Self::Device, dtype: DType) -> bool { + matches!( + dtype, + DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::I32 | DType::I64 | DType::Bool + ) } } diff --git a/src/device.rs b/src/device.rs index 0918c34..6f89029 100644 --- a/src/device.rs +++ b/src/device.rs @@ -32,11 +32,24 @@ impl fmt::Display for MlxDevice { } } -impl 
DeviceOps for MlxDevice { - fn id(&self) -> burn_tensor::backend::DeviceId { +impl burn_tensor::backend::Device for MlxDevice { + fn from_id(device_id: burn_tensor::backend::DeviceId) -> Self { + match device_id.type_id { + 0 => MlxDevice::Cpu, + _ => MlxDevice::Gpu, + } + } + + fn to_id(&self) -> burn_tensor::backend::DeviceId { match self { MlxDevice::Cpu => burn_tensor::backend::DeviceId::new(0, 0), MlxDevice::Gpu => burn_tensor::backend::DeviceId::new(1, 0), } } + + fn device_count(_type_id: u16) -> usize { + 1 + } } + +impl DeviceOps for MlxDevice {} diff --git a/src/lib.rs b/src/lib.rs index 180b3ab..fc435f8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,7 +65,7 @@ pub mod mlx { #[cfg(test)] mod tests { use super::*; - use burn_tensor::{backend::Backend, Tensor, TensorData, Shape}; + use burn_tensor::{Tensor, TensorData, Shape}; #[test] fn test_device_creation() { @@ -169,6 +169,7 @@ mod tests { [2, 2], [0, 0], true, + false, ); let shape = pooled.shape(); @@ -195,6 +196,7 @@ mod tests { [2, 2], [0, 0], [1, 1], + false, ); let shape = pooled.shape(); @@ -221,6 +223,7 @@ mod tests { [2, 2], [0, 0], [1, 1], + false, ); let output_shape = result.output.shape(); @@ -250,6 +253,7 @@ mod tests { 2, 0, true, + false, ); let shape = pooled.shape(); @@ -276,6 +280,7 @@ mod tests { 2, 0, 1, + false, ); let shape = pooled.shape(); diff --git a/src/ops/bool_ops.rs b/src/ops/bool_ops.rs index a1fab31..f058db7 100644 --- a/src/ops/bool_ops.rs +++ b/src/ops/bool_ops.rs @@ -1,8 +1,8 @@ //! Boolean tensor operations for MLX backend. 
-use burn_tensor::{ops::BoolTensorOps, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::BoolTensorOps, Shape, Slice, TensorData}; use mlx_rs::Array; -use std::ops::Range; +use mlx_rs::ops::indexing::{take_axis, take_along_axis}; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; @@ -19,19 +19,18 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } - async fn bool_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn bool_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + Ok(TensorData::new(data, shape)) } - fn bool_device(tensor: &MlxTensorPrimitive) -> MlxDevice { + fn bool_device(_tensor: &MlxTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } - fn bool_to_device(tensor: MlxTensorPrimitive, device: &MlxDevice) -> MlxTensorPrimitive { - let _ = device; + fn bool_to_device(tensor: MlxTensorPrimitive, _device: &MlxDevice) -> MlxTensorPrimitive { tensor } @@ -45,23 +44,58 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_zeros(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + Self::bool_empty(shape, device) + } + + fn bool_ones(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + let mlx_device = device.to_mlx_device(); + mlx_rs::Device::set_default(&mlx_device); + + let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); + let array = Array::ones::(&shape_i32).expect("Failed to create ones bool array"); + + MlxTensorPrimitive::new(array) + } + fn bool_reshape(tensor: MlxTensorPrimitive, shape: Shape) -> MlxTensorPrimitive { let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); let array = tensor.array.reshape(&shape_i32).expect("Failed to reshape"); MlxTensorPrimitive::new(array) } - fn bool_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> 
MlxTensorPrimitive { - // Placeholder - need proper slice implementation - tensor + fn bool_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); + let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) + .expect("Failed to slice"); + MlxTensorPrimitive::new(array) } fn bool_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - tensor + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); + let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) + .expect("Failed to slice_assign"); + MlxTensorPrimitive::new(array) } fn bool_into_int(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { @@ -79,6 +113,16 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_and(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::logical_and(&lhs.array, &rhs.array).expect("Failed to logical_and"); + MlxTensorPrimitive::new(array) + } + + fn bool_or(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::logical_or(&lhs.array, &rhs.array).expect("Failed to logical_or"); + MlxTensorPrimitive::new(array) + } + fn bool_swap_dims(tensor: MlxTensorPrimitive, dim1: usize, dim2: usize) -> MlxTensorPrimitive { let 
ndim = tensor.shape().len(); let mut axes: Vec = (0..ndim as i32).collect(); @@ -94,8 +138,10 @@ impl BoolTensorOps for Mlx { } fn bool_flip(tensor: MlxTensorPrimitive, axes: &[usize]) -> MlxTensorPrimitive { - // Placeholder - MLX doesn't have direct flip - tensor + let axes_i32: Vec = axes.iter().map(|&a| a as i32).collect(); + let array = mlx_rs::ops::flip(&tensor.array, &axes_i32[..]) + .expect("Failed to flip"); + MlxTensorPrimitive::new(array) } fn bool_expand(tensor: MlxTensorPrimitive, shape: Shape) -> MlxTensorPrimitive { @@ -109,6 +155,12 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_equal_elem(lhs: MlxTensorPrimitive, rhs: bool) -> MlxTensorPrimitive { + let scalar = Array::from_slice(&[rhs], &[1]); + let array = mlx_rs::ops::eq(&lhs.array, &scalar).expect("Failed to equal_elem"); + MlxTensorPrimitive::new(array) + } + fn bool_any(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = mlx_rs::ops::any(&tensor.array, false).expect("Failed to any"); MlxTensorPrimitive::new(array) @@ -129,9 +181,52 @@ impl BoolTensorOps for Mlx { MlxTensorPrimitive::new(array) } + fn bool_mask_where( + tensor: MlxTensorPrimitive, + mask: MlxTensorPrimitive, + value: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = mlx_rs::ops::r#where(&mask.array, &value.array, &tensor.array) + .expect("Failed to mask_where"); + MlxTensorPrimitive::new(array) + } + + fn bool_mask_fill( + tensor: MlxTensorPrimitive, + mask: MlxTensorPrimitive, + value: bool, + ) -> MlxTensorPrimitive { + let fill_val = Array::from_slice(&[value], &[1]); + let fill_broadcast = mlx_rs::ops::broadcast_to(&fill_val, tensor.array.shape()) + .expect("Failed to broadcast"); + let array = mlx_rs::ops::r#where(&mask.array, &fill_broadcast, &tensor.array) + .expect("Failed to mask_fill"); + MlxTensorPrimitive::new(array) + } + + fn bool_gather( + dim: usize, + tensor: MlxTensorPrimitive, + indices: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = 
take_along_axis(&tensor.array, &indices.array, dim as i32) + .expect("Failed to gather"); + MlxTensorPrimitive::new(array) + } + + fn bool_scatter_or( + dim: usize, + tensor: MlxTensorPrimitive, + indices: MlxTensorPrimitive, + value: MlxTensorPrimitive, + ) -> MlxTensorPrimitive { + let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) + .expect("Failed to scatter_or"); + MlxTensorPrimitive::new(array) + } + async fn bool_argwhere(_tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { // MLX argwhere may not be available in mlx-rs bindings - // Placeholder: return empty tensor let empty = mlx_rs::Array::zeros::(&[0, 1]).expect("Failed to create empty array"); MlxTensorPrimitive::new(empty) } @@ -140,4 +235,33 @@ impl BoolTensorOps for Mlx { let array = mlx_rs::ops::repeat_axis::(tensor.array, dim as i32, times as i32).expect("Failed to repeat"); MlxTensorPrimitive::new(array) } + + fn bool_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); + let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); + + let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } } diff --git a/src/ops/float_ops.rs b/src/ops/float_ops.rs index 4f632be..45866ed 100644 --- a/src/ops/float_ops.rs +++ b/src/ops/float_ops.rs @@ -1,9 +1,8 @@ //! Float tensor operations for MLX backend. 
-use burn_tensor::{ops::FloatTensorOps, Distribution, FloatDType, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::FloatTensorOps, Distribution, FloatDType, Shape, Slice, TensorData}; use mlx_rs::Array; use mlx_rs::ops::indexing::{take_axis, take_along_axis}; -use std::ops::Range; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; @@ -61,11 +60,11 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - async fn float_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn float_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + Ok(TensorData::new(data, shape)) } fn float_device(tensor: &MlxTensorPrimitive) -> MlxDevice { @@ -78,7 +77,7 @@ impl FloatTensorOps for Mlx { tensor } - fn float_empty(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn float_empty(shape: Shape, device: &MlxDevice, _dtype: FloatDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); @@ -203,15 +202,14 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_scatter( + fn float_scatter_add( dim: usize, tensor: MlxTensorPrimitive, indices: MlxTensorPrimitive, value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Use put_along_axis for scatter operation let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to scatter"); + .expect("Failed to scatter_add"); MlxTensorPrimitive::new(array) } @@ -225,21 +223,27 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_select_assign( + fn float_select_add( tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Use put_along_axis for select_assign operation let array = 
tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to select_assign"); + .expect("Failed to select_add"); MlxTensorPrimitive::new(array) } - fn float_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + fn float_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) .expect("Failed to slice"); MlxTensorPrimitive::new(array) @@ -247,11 +251,18 @@ impl FloatTensorOps for Mlx { fn float_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) .expect("Failed to slice_assign"); MlxTensorPrimitive::new(array) @@ -389,6 +400,10 @@ impl FloatTensorOps for Mlx { } fn float_powf_scalar(tensor: MlxTensorPrimitive, value: f32) -> MlxTensorPrimitive { + Self::float_powf_scalar_impl(tensor, value) + } + + fn 
float_powf_scalar_impl(tensor: MlxTensorPrimitive, value: f32) -> MlxTensorPrimitive { let scalar = Array::from_f32(value); let array = mlx_rs::ops::power(&tensor.array, &scalar).expect("Failed to powf_scalar"); MlxTensorPrimitive::new(array) @@ -546,7 +561,6 @@ impl FloatTensorOps for Mlx { fn float_sort(tensor: MlxTensorPrimitive, dim: usize, _descending: bool) -> MlxTensorPrimitive { let sorted = mlx_rs::ops::sort_axis(&tensor.array, dim as i32).expect("Failed to sort"); - // Note: MLX sort is ascending only; descending would need flip MlxTensorPrimitive::new(sorted) } @@ -591,4 +605,157 @@ impl FloatTensorOps for Mlx { let array = mlx_rs::ops::ceil(&tensor.array).expect("Failed to ceil"); MlxTensorPrimitive::new(array) } + + fn float_trunc(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + // trunc(x) = sign(x) * floor(abs(x)) + let abs_val = mlx_rs::ops::abs(&tensor.array).expect("abs"); + let floored = mlx_rs::ops::floor(&abs_val).expect("floor"); + let sign_val = mlx_rs::ops::sign(&tensor.array).expect("sign"); + let array = mlx_rs::ops::multiply(&sign_val, &floored).expect("multiply"); + MlxTensorPrimitive::new(array) + } + + fn float_tan(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::tan(&tensor.array).expect("Failed to tan"); + MlxTensorPrimitive::new(array) + } + + fn float_cosh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cosh(&tensor.array).expect("Failed to cosh"); + MlxTensorPrimitive::new(array) + } + + fn float_sinh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::sinh(&tensor.array).expect("Failed to sinh"); + MlxTensorPrimitive::new(array) + } + + fn float_acos(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::acos(&tensor.array).expect("Failed to acos"); + MlxTensorPrimitive::new(array) + } + + fn float_acosh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = 
mlx_rs::ops::acosh(&tensor.array).expect("Failed to acosh"); + MlxTensorPrimitive::new(array) + } + + fn float_asin(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::asin(&tensor.array).expect("Failed to asin"); + MlxTensorPrimitive::new(array) + } + + fn float_asinh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::asinh(&tensor.array).expect("Failed to asinh"); + MlxTensorPrimitive::new(array) + } + + fn float_atan(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::atan(&tensor.array).expect("Failed to atan"); + MlxTensorPrimitive::new(array) + } + + fn float_atanh(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::atanh(&tensor.array).expect("Failed to atanh"); + MlxTensorPrimitive::new(array) + } + + fn float_atan2(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + let array = mlx_rs::ops::atan2(&lhs.array, &rhs.array).expect("Failed to atan2"); + MlxTensorPrimitive::new(array) + } + + fn float_cross( + lhs: MlxTensorPrimitive, + rhs: MlxTensorPrimitive, + dim: usize, + ) -> MlxTensorPrimitive { + // Cross product implementation: a x b + // For vectors of length 3 along the given dimension: + // result[0] = a[1]*b[2] - a[2]*b[1] + // result[1] = a[2]*b[0] - a[0]*b[2] + // result[2] = a[0]*b[1] - a[1]*b[0] + let dim_i32 = dim as i32; + + let a0 = take_axis(&lhs.array, &Array::from_int(0), dim_i32).expect("take"); + let a1 = take_axis(&lhs.array, &Array::from_int(1), dim_i32).expect("take"); + let a2 = take_axis(&lhs.array, &Array::from_int(2), dim_i32).expect("take"); + + let b0 = take_axis(&rhs.array, &Array::from_int(0), dim_i32).expect("take"); + let b1 = take_axis(&rhs.array, &Array::from_int(1), dim_i32).expect("take"); + let b2 = take_axis(&rhs.array, &Array::from_int(2), dim_i32).expect("take"); + + let r0 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a1, &b2).expect("mul"), + &mlx_rs::ops::multiply(&a2, 
&b1).expect("mul"), + ).expect("sub"); + let r1 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a2, &b0).expect("mul"), + &mlx_rs::ops::multiply(&a0, &b2).expect("mul"), + ).expect("sub"); + let r2 = mlx_rs::ops::subtract( + &mlx_rs::ops::multiply(&a0, &b1).expect("mul"), + &mlx_rs::ops::multiply(&a1, &b0).expect("mul"), + ).expect("sub"); + + // Stack along the target dimension + let array = mlx_rs::ops::stack_axis(&[&r0, &r1, &r2], dim_i32).expect("stack"); + MlxTensorPrimitive::new(array) + } + + fn float_cumsum(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumsum(&tensor.array, dim as i32, None, None) + .expect("Failed to cumsum"); + MlxTensorPrimitive::new(array) + } + + fn float_cumprod(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumprod(&tensor.array, dim as i32, None, None) + .expect("Failed to cumprod"); + MlxTensorPrimitive::new(array) + } + + fn float_cummin(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummin(&tensor.array, dim as i32, None, None) + .expect("Failed to cummin"); + MlxTensorPrimitive::new(array) + } + + fn float_cummax(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummax(&tensor.array, dim as i32, None, None) + .expect("Failed to cummax"); + MlxTensorPrimitive::new(array) + } + + fn float_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + // unfold extracts sliding windows of `size` along `dim` with `step` + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + // Create indices for each window + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as 
i32]); + let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); + + // Reshape to insert the window dimension + let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } } diff --git a/src/ops/int_ops.rs b/src/ops/int_ops.rs index c2f85e9..1d0c215 100644 --- a/src/ops/int_ops.rs +++ b/src/ops/int_ops.rs @@ -1,9 +1,8 @@ //! Integer tensor operations for MLX backend. -use burn_tensor::{ops::IntTensorOps, Distribution, Shape, TensorData}; +use burn_tensor::{backend::ExecutionError, ops::IntTensorOps, Distribution, IntDType, Shape, Slice, TensorData}; use mlx_rs::Array; use mlx_rs::ops::indexing::{argmax_axis, argmin_axis, take_axis, take_along_axis}; -use std::ops::Range; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; @@ -20,14 +19,14 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - async fn int_into_data(tensor: MlxTensorPrimitive) -> TensorData { + async fn int_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); let data: Vec = tensor.array.as_slice().to_vec(); - TensorData::new(data, shape) + Ok(TensorData::new(data, shape)) } - fn int_device(tensor: &MlxTensorPrimitive) -> MlxDevice { + fn int_device(_tensor: &MlxTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } @@ -36,7 +35,7 @@ impl IntTensorOps for Mlx { tensor } - fn int_empty(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn int_empty(shape: Shape, device: &MlxDevice, _dtype: IntDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); @@ -44,11 +43,11 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_zeros(shape: 
Shape, device: &MlxDevice) -> MlxTensorPrimitive { - Self::int_empty(shape, device) + fn int_zeros(shape: Shape, device: &MlxDevice, dtype: IntDType) -> MlxTensorPrimitive { + Self::int_empty(shape, device, dtype) } - fn int_ones(shape: Shape, device: &MlxDevice) -> MlxTensorPrimitive { + fn int_ones(shape: Shape, device: &MlxDevice, _dtype: IntDType) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); @@ -170,9 +169,16 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_slice(tensor: MlxTensorPrimitive, ranges: &[Range]) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + fn int_slice(tensor: MlxTensorPrimitive, slices: &[Slice]) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice(&tensor.array, &starts, &stops, None) .expect("Failed to slice"); MlxTensorPrimitive::new(array) @@ -180,11 +186,18 @@ impl IntTensorOps for Mlx { fn int_slice_assign( tensor: MlxTensorPrimitive, - ranges: &[Range], + slices: &[Slice], value: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - let starts: Vec = ranges.iter().map(|r| r.start as i32).collect(); - let stops: Vec = ranges.iter().map(|r| r.end as i32).collect(); + let shape = tensor.shape().to_vec(); + let starts: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = s.to_range(*shape.get(i).unwrap_or(&0)); + range.start as i32 + }).collect(); + let stops: Vec = slices.iter().enumerate().map(|(i, s)| { + let range = 
s.to_range(*shape.get(i).unwrap_or(&0)); + range.end as i32 + }).collect(); let array = mlx_rs::ops::slice_update(&tensor.array, &value.array, &starts, &stops, None) .expect("Failed to slice_assign"); MlxTensorPrimitive::new(array) @@ -215,10 +228,9 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_scatter(dim: usize, tensor: MlxTensorPrimitive, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Use put_along_axis for scatter operation + fn int_scatter_add(dim: usize, tensor: MlxTensorPrimitive, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to scatter"); + .expect("Failed to scatter_add"); MlxTensorPrimitive::new(array) } @@ -227,10 +239,9 @@ impl IntTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn int_select_assign(tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Use put_along_axis for select_assign operation + fn int_select_add(tensor: MlxTensorPrimitive, dim: usize, indices: MlxTensorPrimitive, value: MlxTensorPrimitive) -> MlxTensorPrimitive { let array = tensor.array.put_along_axis(&indices.array, &value.array, dim as i32) - .expect("Failed to select_assign"); + .expect("Failed to select_add"); MlxTensorPrimitive::new(array) } @@ -407,4 +418,176 @@ impl IntTensorOps for Mlx { let array = mlx_rs::ops::all_axis(&tensor.array, dim as i32, true).expect("Failed to all_dim"); MlxTensorPrimitive::new(array) } + + fn int_matmul(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + // MLX matmul requires float, so cast to f32, matmul, then cast back + let lhs_f = lhs.array.as_type::().expect("cast lhs"); + let rhs_f = rhs.array.as_type::().expect("cast rhs"); + let result = lhs_f.matmul(&rhs_f).expect("matmul"); + let array = result.as_type::().expect("cast back"); + 
MlxTensorPrimitive::new(array) + } + + fn int_cast(tensor: MlxTensorPrimitive, dtype: IntDType) -> MlxTensorPrimitive { + let array = match dtype { + IntDType::I32 => tensor.array.as_type::().expect("cast to i32"), + IntDType::I64 => tensor.array.as_type::().expect("cast to i64"), + IntDType::I16 => tensor.array.as_type::().expect("cast to i16"), + IntDType::I8 => tensor.array.as_type::().expect("cast to i8"), + _ => tensor.array, + }; + MlxTensorPrimitive::new(array) + } + + fn int_cumsum(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumsum(&tensor.array, dim as i32, None, None) + .expect("Failed to cumsum"); + MlxTensorPrimitive::new(array) + } + + fn int_cumprod(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cumprod(&tensor.array, dim as i32, None, None) + .expect("Failed to cumprod"); + MlxTensorPrimitive::new(array) + } + + fn int_cummin(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummin(&tensor.array, dim as i32, None, None) + .expect("Failed to cummin"); + MlxTensorPrimitive::new(array) + } + + fn int_cummax(tensor: MlxTensorPrimitive, dim: usize) -> MlxTensorPrimitive { + let array = mlx_rs::ops::cummax(&tensor.array, dim as i32, None, None) + .expect("Failed to cummax"); + MlxTensorPrimitive::new(array) + } + + fn int_unfold( + tensor: MlxTensorPrimitive, + dim: usize, + size: usize, + step: usize, + ) -> MlxTensorPrimitive { + let shape = tensor.shape().to_vec(); + let dim_size = shape[dim]; + let num_windows = (dim_size - size) / step + 1; + + let mut window_indices = Vec::new(); + for w in 0..num_windows { + let start = w * step; + for i in 0..size { + window_indices.push((start + i) as i32); + } + } + + let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); + let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); + + let mut new_shape: Vec = shape.iter().map(|&s| s as 
i32).collect(); + new_shape[dim] = num_windows as i32; + new_shape.push(size as i32); + let array = gathered.reshape(&new_shape).expect("reshape"); + + MlxTensorPrimitive::new(array) + } + + // Bitwise operations - implemented in software as mlx-rs doesn't expose bitwise ops + fn bitwise_and(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a & b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_and_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a & rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_or(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a | b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_or_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a | rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + 
fn bitwise_xor(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a ^ b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_xor_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a ^ rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_not(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { + tensor.array.eval().expect("eval"); + let data: Vec = tensor.array.as_slice().to_vec(); + let result: Vec = data.iter().map(|a| !a).collect(); + let shape: Vec = tensor.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_left_shift(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a << b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_left_shift_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a << rhs).collect(); + let shape: Vec = 
lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_right_shift(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + rhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let rhs_data: Vec = rhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().zip(rhs_data.iter()).map(|(a, b)| a >> b).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } + + fn bitwise_right_shift_scalar(lhs: MlxTensorPrimitive, rhs: i32) -> MlxTensorPrimitive { + lhs.array.eval().expect("eval"); + let lhs_data: Vec = lhs.array.as_slice().to_vec(); + let result: Vec = lhs_data.iter().map(|a| a >> rhs).collect(); + let shape: Vec = lhs.shape().iter().map(|&s| s as i32).collect(); + MlxTensorPrimitive::new(Array::from_slice(&result, &shape)) + } } diff --git a/src/ops/module_ops.rs b/src/ops/module_ops.rs index ce6856a..e7add87 100644 --- a/src/ops/module_ops.rs +++ b/src/ops/module_ops.rs @@ -184,17 +184,6 @@ fn max_pool2d_with_indices_impl( let local_indices = mlx_rs::ops::indexing::argmax_axis(&reshaped, 3, None).expect("argmax"); // Convert local indices (within kernel) to flat indices into padded NHWC input - // For each output position (n, oh, ow, c), the local_idx tells us which element - // in the kH*kW kernel was the max. - // - // The actual position in the padded input (NHWC layout) is: - // n * (H * W * C) + (oh * stride[0] + local_h) * (W * C) + (ow * stride[1] + local_w) * C + c - // where local_h = local_idx / kW, local_w = local_idx % kW - // - // We need to compute this index for the backward pass. 
- - // Create coordinate arrays for output positions - // Shape of output/indices: [N, out_H, out_W, C] let out_h_size = out_h as usize; let out_w_size = out_w as usize; let n_size = n as usize; @@ -203,22 +192,18 @@ fn max_pool2d_with_indices_impl( let w_size = w as usize; // Create index arrays for n, oh, ow, c dimensions - // n_idx: [N, 1, 1, 1] broadcast to [N, out_H, out_W, C] let n_range: Vec = (0..n_size as i32).collect(); let n_idx = Array::from_slice(&n_range, &[n_size as i32]) .reshape(&[n, 1, 1, 1]).expect("reshape"); - // oh_idx: [1, out_H, 1, 1] let oh_range: Vec = (0..out_h_size as i32).collect(); let oh_idx = Array::from_slice(&oh_range, &[out_h_size as i32]) .reshape(&[1, out_h, 1, 1]).expect("reshape"); - // ow_idx: [1, 1, out_W, 1] let ow_range: Vec = (0..out_w_size as i32).collect(); let ow_idx = Array::from_slice(&ow_range, &[out_w_size as i32]) .reshape(&[1, 1, out_w, 1]).expect("reshape"); - // c_idx: [1, 1, 1, C] let c_range: Vec = (0..c_size as i32).collect(); let c_idx = Array::from_slice(&c_range, &[c_size as i32]) .reshape(&[1, 1, 1, c]).expect("reshape"); @@ -232,13 +217,11 @@ fn max_pool2d_with_indices_impl( let sh_arr = Array::from_int(stride[0] as i32); let sw_arr = Array::from_int(stride[1] as i32); - // actual_h = oh * stride[0] + local_h let actual_h = mlx_rs::ops::add( &mlx_rs::ops::multiply(&oh_idx, &sh_arr).expect("mul"), &local_h ).expect("add"); - // actual_w = ow * stride[1] + local_w let actual_w = mlx_rs::ops::add( &mlx_rs::ops::multiply(&ow_idx, &sw_arr).expect("mul"), &local_w @@ -293,7 +276,6 @@ impl ModuleOps for Mlx { // Add bias if provided if let Some(b) = bias { - // Reshape bias from [C_out] to [1, C_out, 1] let b_shape = b.shape(); let b_reshaped = b.array.reshape(&[1, b_shape[0] as i32, 1]).expect("reshape bias"); output = mlx_rs::ops::add(&output, &b_reshaped).expect("add bias"); @@ -359,7 +341,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<1>, ) -> MlxTensorPrimitive { - // 
conv_transpose1d is complex in MLX - placeholder x } @@ -369,7 +350,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<2>, ) -> MlxTensorPrimitive { - // conv_transpose2d is complex in MLX - placeholder x } @@ -379,7 +359,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: ConvTransposeOptions<3>, ) -> MlxTensorPrimitive { - // Placeholder x } @@ -391,7 +370,6 @@ impl ModuleOps for Mlx { _bias: Option, _options: DeformConvOptions<2>, ) -> MlxTensorPrimitive { - // Deformable convolution is not supported in MLX - placeholder let shape = [1i32, 1, 1, 1]; let array = Array::zeros::(&shape).expect("zeros"); MlxTensorPrimitive::new(array) @@ -406,7 +384,6 @@ impl ModuleOps for Mlx { _out_grad: MlxTensorPrimitive, _options: DeformConvOptions<2>, ) -> DeformConv2dBackward { - // Placeholder let shape = [1i32, 1, 1, 1]; let zeros = MlxTensorPrimitive::new(Array::zeros::(&shape).expect("zeros")); DeformConv2dBackward::new( @@ -424,16 +401,13 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, _count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { // Burn uses NCL format, MLX uses NLC format - // Transpose from [N, C, L] to [N, L, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 1]).expect("transpose"); - // Apply padding if needed let x_padded = if padding > 0 { let pad = padding as i32; - // Pad only the L dimension (axis 1 in NLC format) - // PadWidth for [N, L, C]: [(0,0), (pad,pad), (0,0)] mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad, pad), (0, 0)], @@ -444,14 +418,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + mean_axes let pooled = pool1d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.mean_axes(axes, None) }); - // Transpose back from [N, L, C] to [N, C, L] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 2, 1]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -461,17 +432,14 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], 
_count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - // Pad H and W dimensions (axes 1 and 2 in NHWC format) - // PadWidth for [N, H, W, C]: [(0,0), (pad_h,pad_h), (pad_w,pad_w), (0,0)] mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -482,14 +450,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + mean_axes let pooled = pool2d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.mean_axes(axes, None) }); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -500,8 +465,8 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _count_include_pad: bool, + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCHW format let input_shape = x.shape(); let n = input_shape[0]; let c = input_shape[1]; @@ -515,24 +480,19 @@ impl ModuleOps for Mlx { let pad_h = padding[0]; let pad_w = padding[1]; - // Padded input dimensions let h_padded = h + 2 * pad_h; let w_padded = w + 2 * pad_w; - // Output dimensions let out_h = (h_padded - kh) / sh + 1; let out_w = (w_padded - kw) / sw + 1; let pool_size = (kh * kw) as f32; - // Transpose grad from NCHW to NHWC for processing let grad_nhwc = mlx_rs::ops::transpose_axes(&grad.array, &[0, 2, 3, 1]).expect("transpose"); - // Scale gradient by 1/pool_size let scale = Array::from_f32(1.0 / pool_size); let grad_scaled = mlx_rs::ops::multiply(&grad_nhwc, &scale).expect("multiply"); - // Create zeros for padded input gradient (NHWC format) let grad_input_padded = Array::zeros::(&[ n as i32, 
h_padded as i32, @@ -540,22 +500,7 @@ impl ModuleOps for Mlx { c as i32, ]).expect("zeros"); - // For avg pooling backward, each output gradient contributes equally to all - // input positions in its window. We use scatter_add to accumulate gradients. - // - // For each output position (oh, ow), the window covers: - // h_start = oh * stride[0] - // w_start = ow * stride[1] - // positions: (h_start..h_start+kH, w_start..w_start+kW) - - // Create flat indices for all input positions that receive gradients - // We need to iterate over all output positions and all kernel positions - - // Build index arrays - // For each (oh, ow, kh_off, kw_off), compute flat index into padded input - let mut all_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); - let mut all_n_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); let mut update_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); for ni in 0..n { @@ -568,14 +513,11 @@ impl ModuleOps for Mlx { let hi = h_start + khi; let wi = w_start + kwi; for ci in 0..c { - // Flat index in NHWC layout let flat_idx = (ni * h_padded * w_padded * c + hi * w_padded * c + wi * c + ci) as i32; all_indices.push(flat_idx); - all_n_indices.push(ni as i32); - // Index into the flat grad_scaled array let grad_idx = ni * out_h * out_w * c + ohi * out_w * c + owi * c @@ -588,7 +530,6 @@ impl ModuleOps for Mlx { } } - // Flatten the scaled gradient and gather the values we need let grad_flat = grad_scaled.flatten(None, None).expect("flatten"); let update_idx_arr = Array::from_slice( &update_indices.iter().map(|&x| x as i32).collect::>(), @@ -596,11 +537,9 @@ impl ModuleOps for Mlx { ); let updates = take_axis(&grad_flat, &update_idx_arr, 0).expect("take"); - // Flatten the input gradient and use scatter_add let grad_input_flat = grad_input_padded.flatten(None, None).expect("flatten"); let indices_arr = Array::from_slice(&all_indices, &[all_indices.len() as i32]); - // Use scatter_add: add updates to 
grad_input_flat at indices let result_flat = mlx_rs::ops::scatter_add( &grad_input_flat, &[&indices_arr], @@ -608,7 +547,6 @@ impl ModuleOps for Mlx { &[0], ).expect("scatter_add"); - // Reshape back to NHWC let result_nhwc = result_flat.reshape(&[ n as i32, h_padded as i32, @@ -616,7 +554,6 @@ impl ModuleOps for Mlx { c as i32, ]).expect("reshape"); - // Remove padding if present let result_unpadded = if pad_h > 0 || pad_w > 0 { mlx_rs::ops::slice( &result_nhwc, @@ -628,9 +565,7 @@ impl ModuleOps for Mlx { result_nhwc }; - // Transpose back from NHWC to NCHW let output = mlx_rs::ops::transpose_axes(&result_unpadded, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -640,12 +575,10 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, _dilation: usize, + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCL format, MLX uses NLC format - // Transpose from [N, C, L] to [N, L, C] let x_nlc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding > 0 { let pad = padding as i32; let neg_inf = Array::from_f32(f32::NEG_INFINITY); @@ -659,14 +592,11 @@ impl ModuleOps for Mlx { x_nlc }; - // Apply pooling using as_strided + max_axes let pooled = pool1d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.max_axes(axes, None) }); - // Transpose back from [N, L, C] to [N, C, L] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 2, 1]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -676,12 +606,10 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], + _ceil_mode: bool, ) -> MlxTensorPrimitive { - // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding[0] > 0 || padding[1] > 0 { 
let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; @@ -696,14 +624,11 @@ impl ModuleOps for Mlx { x_nhwc }; - // Apply pooling using as_strided + max_axes let pooled = pool2d_strided(&x_padded, kernel_size, stride, |arr, axes| { arr.max_axes(axes, None) }); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&pooled, &[0, 3, 1, 2]).expect("transpose"); - MlxTensorPrimitive::new(output) } @@ -713,9 +638,9 @@ impl ModuleOps for Mlx { stride: usize, padding: usize, dilation: usize, + _ceil_mode: bool, ) -> MaxPool1dWithIndices { - let output = Self::max_pool1d(x, kernel_size, stride, padding, dilation); - // Create dummy indices (placeholder) + let output = Self::max_pool1d(x, kernel_size, stride, padding, dilation, false); let indices = MlxTensorPrimitive::new( Array::zeros::(&output.array.shape().iter().map(|&s| s as i32).collect::>()) .expect("zeros") @@ -729,12 +654,10 @@ impl ModuleOps for Mlx { stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], + _ceil_mode: bool, ) -> MaxPool2dWithIndices { - // Burn uses NCHW format, MLX uses NHWC format - // Transpose from [N, C, H, W] to [N, H, W, C] let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); - // Apply padding if needed (use -inf for max pooling) let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; @@ -749,10 +672,8 @@ impl ModuleOps for Mlx { x_nhwc }; - // Get max values and indices let (output_nhwc, indices_nhwc) = max_pool2d_with_indices_impl(&x_padded, kernel_size, stride); - // Transpose back from [N, H, W, C] to [N, C, H, W] let output = mlx_rs::ops::transpose_axes(&output_nhwc, &[0, 3, 1, 2]).expect("transpose"); let indices = mlx_rs::ops::transpose_axes(&indices_nhwc, &[0, 3, 1, 2]).expect("transpose"); @@ -768,12 +689,10 @@ impl ModuleOps for Mlx { _stride: [usize; 2], padding: [usize; 2], _dilation: [usize; 2], + _ceil_mode: bool, output_grad: 
MlxTensorPrimitive, indices: MlxTensorPrimitive, ) -> MaxPool2dBackward { - // The indices contain flat indices into the padded NHWC input tensor. - // We need to scatter the gradients to those positions. - let input_shape = x.shape(); let n = input_shape[0]; let c = input_shape[1]; @@ -783,23 +702,18 @@ impl ModuleOps for Mlx { let pad_h = padding[0]; let pad_w = padding[1]; - // Padded dimensions let h_padded = h + 2 * pad_h; let w_padded = w + 2 * pad_w; - // Create zeros for padded input gradient (NHWC flattened) let total_size = n * h_padded * w_padded * c; let grad_input_flat = Array::zeros::(&[total_size as i32]).expect("zeros"); - // Transpose grad and indices from NCHW to NHWC to match index computation let grad_nhwc = mlx_rs::ops::transpose_axes(&output_grad.array, &[0, 2, 3, 1]).expect("transpose"); let indices_nhwc = mlx_rs::ops::transpose_axes(&indices.array, &[0, 2, 3, 1]).expect("transpose"); - // Flatten both let grad_flat = grad_nhwc.flatten(None, None).expect("flatten"); let indices_flat = indices_nhwc.flatten(None, None).expect("flatten"); - // Scatter the gradients to the positions indicated by indices let result_flat = mlx_rs::ops::scatter_add( &grad_input_flat, &[&indices_flat], @@ -807,7 +721,6 @@ impl ModuleOps for Mlx { &[0], ).expect("scatter_add"); - // Reshape to NHWC let result_nhwc = result_flat.reshape(&[ n as i32, h_padded as i32, @@ -815,7 +728,6 @@ impl ModuleOps for Mlx { c as i32, ]).expect("reshape"); - // Remove padding if present let result_unpadded = if pad_h > 0 || pad_w > 0 { mlx_rs::ops::slice( &result_nhwc, @@ -827,18 +739,15 @@ impl ModuleOps for Mlx { result_nhwc }; - // Transpose back from NHWC to NCHW let output = mlx_rs::ops::transpose_axes(&result_unpadded, &[0, 3, 1, 2]).expect("transpose"); - MaxPool2dBackward::new(MlxTensorPrimitive::new(output)) } fn adaptive_avg_pool1d(x: MlxTensorPrimitive, output_size: usize) -> MlxTensorPrimitive { - // Calculate kernel_size and stride to achieve output_size let input_size = 
x.shape()[2]; let stride = input_size / output_size; let kernel_size = input_size - (output_size - 1) * stride; - Self::avg_pool1d(x, kernel_size, stride, 0, true) + Self::avg_pool1d(x, kernel_size, stride, 0, true, false) } fn adaptive_avg_pool2d(x: MlxTensorPrimitive, output_size: [usize; 2]) -> MlxTensorPrimitive { @@ -851,14 +760,13 @@ impl ModuleOps for Mlx { let kernel_h = input_h - (output_size[0] - 1) * stride_h; let kernel_w = input_w - (output_size[1] - 1) * stride_w; - Self::avg_pool2d(x, [kernel_h, kernel_w], [stride_h, stride_w], [0, 0], true) + Self::avg_pool2d(x, [kernel_h, kernel_w], [stride_h, stride_w], [0, 0], true, false) } fn adaptive_avg_pool2d_backward( x: MlxTensorPrimitive, _grad: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Placeholder: return zeros with input shape let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); let output = Array::zeros::(&shape).expect("zeros"); MlxTensorPrimitive::new(output) @@ -869,7 +777,6 @@ impl ModuleOps for Mlx { _output_size: [usize; 2], _options: InterpolateOptions, ) -> MlxTensorPrimitive { - // MLX doesn't have direct interpolate - placeholder x } @@ -879,7 +786,6 @@ impl ModuleOps for Mlx { _output_size: [usize; 2], _options: InterpolateOptions, ) -> MlxTensorPrimitive { - // Placeholder: return zeros with input shape let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); let output = Array::zeros::(&shape).expect("zeros"); MlxTensorPrimitive::new(output) @@ -889,7 +795,6 @@ impl ModuleOps for Mlx { weights: MlxTensorPrimitive, indices: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Embedding lookup - gather rows from weights based on indices let array = take_axis(&weights.array, &indices.array, 0) .expect("embedding"); MlxTensorPrimitive::new(array) @@ -900,8 +805,6 @@ impl ModuleOps for Mlx { _output_grad: MlxTensorPrimitive, _indices: MlxTensorPrimitive, ) -> MlxTensorPrimitive { - // Scatter gradients back to weights - // Placeholder - proper implementation needed 
weights } } diff --git a/src/ops/other_ops.rs b/src/ops/other_ops.rs index eb61126..1d57170 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -1,9 +1,10 @@ //! Additional ops implementations for MLX backend. use burn_tensor::{ + backend::ExecutionError, ops::{ActivationOps, QTensorOps, TransactionOps}, - quantization::QuantizationScheme, - Shape, TensorData, + quantization::QuantScheme, + Shape, Slice, TensorData, }; use mlx_rs::Array; @@ -64,8 +65,7 @@ impl ActivationOps for Mlx { MlxTensorPrimitive::new(array) } - fn gelu_backward(x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { - // Backward pass for GELU - placeholder + fn gelu_backward(_x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { grad } @@ -87,18 +87,18 @@ impl QTensorOps for Mlx { ); MlxQuantizedTensorPrimitive { tensor, - scheme: crate::backend::QuantizationScheme::None, + scheme: QuantScheme::default(), } } fn quantize( tensor: MlxTensorPrimitive, - scheme: &QuantizationScheme, - qparams: burn_tensor::quantization::QuantizationParametersPrimitive, + _scheme: &QuantScheme, + _qparams: burn_tensor::quantization::QuantizationParametersPrimitive, ) -> MlxQuantizedTensorPrimitive { MlxQuantizedTensorPrimitive { tensor, - scheme: crate::backend::QuantizationScheme::None, + scheme: QuantScheme::default(), } } @@ -106,13 +106,13 @@ impl QTensorOps for Mlx { tensor.tensor } - fn q_device(tensor: &MlxQuantizedTensorPrimitive) -> MlxDevice { + fn q_device(_tensor: &MlxQuantizedTensorPrimitive) -> MlxDevice { MlxDevice::Gpu } fn q_to_device( tensor: MlxQuantizedTensorPrimitive, - device: &MlxDevice, + _device: &MlxDevice, ) -> MlxQuantizedTensorPrimitive { tensor } @@ -128,7 +128,7 @@ impl QTensorOps for Mlx { } } - async fn q_into_data(tensor: MlxQuantizedTensorPrimitive) -> TensorData { + async fn q_into_data(tensor: MlxQuantizedTensorPrimitive) -> Result { >::float_into_data(tensor.tensor).await } @@ -194,11 +194,11 @@ impl QTensorOps for Mlx { fn 
q_slice( tensor: MlxQuantizedTensorPrimitive, - ranges: &[std::ops::Range], + slices: &[Slice], ) -> MlxQuantizedTensorPrimitive { let sliced = >::float_slice( tensor.tensor, - ranges, + slices, ); MlxQuantizedTensorPrimitive { tensor: sliced, From 705507d0f211d894d20ea3f729f759b71249447b Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Mon, 16 Feb 2026 16:47:28 -0800 Subject: [PATCH 2/6] Implement Backend::sync() to synchronize MLX GPU work The default Backend::sync() is a no-op, causing callers (e.g. timing code) to get incorrect results because GPU work hasn't finished yet. This calls mlx_synchronize on the default stream to block until all queued operations complete. Co-Authored-By: Claude Opus 4.6 --- Cargo.toml | 1 + src/backend.rs | 14 +++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0454a72..30eb044 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ autotune = [] burn = { version = "0.20", default-features = false, features = ["std"] } burn-tensor = "0.20" mlx-rs = { package = "mlx-rs-burn", version = "0.25.5" } +mlx-sys = { package = "mlx-sys-burn", version = "0.2" } derive-new = "0.7" half = { version = "2.4", features = ["num-traits"] } num-traits = "0.2" diff --git a/src/backend.rs b/src/backend.rs index 10a0063..d05deb1 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,6 +1,6 @@ //! MLX Backend implementation for Burn. 
-use burn_tensor::backend::Backend; +use burn_tensor::backend::{Backend, ExecutionError}; use burn_tensor::{DType, TensorMetadata}; use burn_tensor::quantization::QuantScheme; use mlx_rs::Array; @@ -126,6 +126,18 @@ impl Backend for Mlx { DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::I32 | DType::I64 | DType::Bool ) } + + fn sync(_device: &Self::Device) -> Result<(), ExecutionError> { + let stream = mlx_rs::Stream::default(); + let status = unsafe { mlx_sys::mlx_synchronize(stream.as_ptr()) }; + if status == 0 { + Ok(()) + } else { + Err(ExecutionError::WithContext { + reason: "MLX stream synchronization failed".into(), + }) + } + } } /// Get the current seed value. From a3c1cca3a7005252bffff4cca6c1e81a323dc29c Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Mon, 16 Feb 2026 17:46:35 -0800 Subject: [PATCH 3/6] Make Mlx backend generic over float precision (f32, f16, bf16, f64) The backend was hardcoded to FloatElem = f32, forcing all computation through f32 even when models use f16 weights. This caused 2x memory bandwidth overhead on Apple Silicon which natively supports f16. Mlx is now generic like Burn's Wgpu backend. Adds MlxHalf and MlxBf16 type aliases for convenient half-precision use. 
Co-Authored-By: Claude Opus 4.6 --- src/backend.rs | 49 ++++++++++++++++++++--- src/element.rs | 84 ++++++++++++++++++++++++++++++++++++++- src/lib.rs | 18 ++++++--- src/ops/bool_ops.rs | 5 ++- src/ops/float_ops.rs | 92 +++++++++++++++++++++---------------------- src/ops/int_ops.rs | 11 +++--- src/ops/module_ops.rs | 33 ++++++++-------- src/ops/other_ops.rs | 35 ++++++++-------- 8 files changed, 228 insertions(+), 99 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index d05deb1..282ff5c 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -4,9 +4,11 @@ use burn_tensor::backend::{Backend, ExecutionError}; use burn_tensor::{DType, TensorMetadata}; use burn_tensor::quantization::QuantScheme; use mlx_rs::Array; +use std::marker::PhantomData; use std::sync::atomic::{AtomicU64, Ordering}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; // Global seed for random number generation static SEED: AtomicU64 = AtomicU64::new(0); @@ -92,15 +94,52 @@ impl burn_tensor::quantization::QTensorPrimitive for MlxQuantizedTensorPrimitive } } -/// MLX Backend for Burn. -#[derive(Debug, Default, Clone, Copy)] -pub struct Mlx; +/// MLX Backend for Burn, generic over float precision. +/// +/// The default float type is `f32`. Use `Mlx` (or the `MlxHalf` alias) +/// for half-precision inference, which halves memory bandwidth and leverages +/// Apple Silicon's native f16 support. 
+/// +/// # Examples +/// +/// ```ignore +/// use burn_mlx::{Mlx, MlxHalf}; +/// +/// // f32 backend (default, same as before) +/// type Backend32 = Mlx; +/// +/// // f16 backend for faster inference +/// type Backend16 = MlxHalf; +/// ``` +pub struct Mlx { + _phantom: PhantomData, +} + +impl std::fmt::Debug for Mlx { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Mlx").finish() + } +} + +impl Default for Mlx { + fn default() -> Self { + Self { _phantom: PhantomData } + } +} + +impl Clone for Mlx { + fn clone(&self) -> Self { + *self + } +} + +impl Copy for Mlx {} -impl Backend for Mlx { +impl Backend for Mlx { type Device = MlxDevice; type FloatTensorPrimitive = MlxTensorPrimitive; - type FloatElem = f32; + type FloatElem = F; type IntTensorPrimitive = MlxTensorPrimitive; type IntElem = i32; diff --git a/src/element.rs b/src/element.rs index 27e920b..6977924 100644 --- a/src/element.rs +++ b/src/element.rs @@ -2,7 +2,8 @@ use burn_tensor::{DType, Element}; use half::{bf16, f16}; -use mlx_rs::Dtype; +use mlx_rs::{Array, Dtype}; +use num_traits::{Float, FromPrimitive}; /// Trait for elements that can be used with MLX. pub trait MlxElement: Element + Clone + Send + Sync + 'static { @@ -130,6 +131,87 @@ impl MlxElement for bool { } } +/// Trait for float elements that can be used as the primary float type in the MLX backend. +/// +/// This enables `Mlx` to be generic over the float precision (f32, f16, bf16, f64). +pub trait FloatMlxElement: MlxElement + Float + FromPrimitive { + /// Create a scalar MLX array from this value. + fn scalar_array(value: Self) -> Array; + + /// Create a scalar MLX array from an f64 constant. + fn f64_scalar_array(value: f64) -> Array { + Self::scalar_array(Self::from_f64(value).unwrap()) + } + + /// Create an MLX array from a slice of elements. + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array; + + /// Create a zeros array in this element's dtype. 
+ fn zeros_array(shape: &[i32]) -> Array; + + /// Create a ones array in this element's dtype. + fn ones_array(shape: &[i32]) -> Array; + + /// Read an MLX array's data as a vector of this element type. + fn array_to_vec(array: &Array) -> Vec; + + /// Cast an MLX array to this element's dtype. + fn cast_array(array: &Array) -> Array; +} + +impl FloatMlxElement for f32 { + fn scalar_array(value: Self) -> Array { Array::from_f32(value) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for f16 { + fn scalar_array(value: Self) -> Array { Array::from_slice(&[value], &[1]) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for bf16 { + fn scalar_array(value: Self) -> Array { Array::from_slice(&[value], &[1]) } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { Array::from_slice(data, shape) } + fn zeros_array(shape: &[i32]) -> Array { Array::zeros::(shape).expect("zeros") } + fn ones_array(shape: &[i32]) -> Array { Array::ones::(shape).expect("ones") } + fn array_to_vec(array: &Array) -> Vec { array.as_slice::().to_vec() } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + +impl FloatMlxElement for f64 { + fn scalar_array(value: Self) -> Array { + // f64 doesn't 
implement FromSliceElement in mlx-rs, route through f32 + let arr = Array::from_f32(value as f32); + arr.as_type::().expect("cast to f64") + } + fn array_from_slice(data: &[Self], shape: &[i32]) -> Array { + let f32_data: Vec = data.iter().map(|&v| v as f32).collect(); + let arr = Array::from_slice(&f32_data, shape); + arr.as_type::().expect("cast to f64") + } + fn zeros_array(shape: &[i32]) -> Array { + let arr = Array::zeros::(shape).expect("zeros"); + arr.as_type::().expect("cast to f64") + } + fn ones_array(shape: &[i32]) -> Array { + let arr = Array::ones::(shape).expect("ones"); + arr.as_type::().expect("cast to f64") + } + fn array_to_vec(array: &Array) -> Vec { + let arr = array.as_type::().expect("cast to f32"); + arr.as_slice::().iter().map(|&v| v as f64).collect() + } + fn cast_array(array: &Array) -> Array { array.as_type::().expect("cast") } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index fc435f8..0941819 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,9 +54,15 @@ mod ops; // Public exports pub use backend::{Mlx, MlxTensorPrimitive, MlxQuantizedTensorPrimitive}; pub use device::MlxDevice; -pub use element::MlxElement; +pub use element::{MlxElement, FloatMlxElement}; pub use tensor::MlxTensor; +/// Half-precision (f16) MLX backend for faster inference on Apple Silicon. +pub type MlxHalf = Mlx; + +/// BFloat16 MLX backend. +pub type MlxBf16 = Mlx; + /// Re-export mlx-rs types for advanced usage. 
pub mod mlx { pub use mlx_rs::*; @@ -163,7 +169,7 @@ mod tests { ); // Apply avg_pool2d with kernel_size=2, stride=2 - let pooled = Mlx::avg_pool2d( + let pooled = Mlx::::avg_pool2d( x.into_primitive().tensor(), [2, 2], [2, 2], @@ -190,7 +196,7 @@ mod tests { ); // Apply max_pool2d with kernel_size=2, stride=2 - let pooled = Mlx::max_pool2d( + let pooled = Mlx::::max_pool2d( x.into_primitive().tensor(), [2, 2], [2, 2], @@ -217,7 +223,7 @@ mod tests { ); // Apply max_pool2d_with_indices with kernel_size=2, stride=2 - let result = Mlx::max_pool2d_with_indices( + let result = Mlx::::max_pool2d_with_indices( x.into_primitive().tensor(), [2, 2], [2, 2], @@ -247,7 +253,7 @@ mod tests { ); // Apply avg_pool1d with kernel_size=2, stride=2 - let pooled = Mlx::avg_pool1d( + let pooled = Mlx::::avg_pool1d( x.into_primitive().tensor(), 2, 2, @@ -274,7 +280,7 @@ mod tests { ); // Apply max_pool1d with kernel_size=2, stride=2 - let pooled = Mlx::max_pool1d( + let pooled = Mlx::::max_pool1d( x.into_primitive().tensor(), 2, 2, diff --git a/src/ops/bool_ops.rs b/src/ops/bool_ops.rs index f058db7..fb362ab 100644 --- a/src/ops/bool_ops.rs +++ b/src/ops/bool_ops.rs @@ -6,8 +6,9 @@ use mlx_rs::ops::indexing::{take_axis, take_along_axis}; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl BoolTensorOps for Mlx { +impl BoolTensorOps for Mlx { fn bool_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); @@ -104,7 +105,7 @@ impl BoolTensorOps for Mlx { } fn bool_into_float(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let array = tensor.array.as_type::().expect("Failed to cast to float"); + let array = F::cast_array(&tensor.array); MlxTensorPrimitive::new(array) } diff --git a/src/ops/float_ops.rs b/src/ops/float_ops.rs index 45866ed..cbb0aba 100644 --- a/src/ops/float_ops.rs +++ b/src/ops/float_ops.rs @@ -1,20 
+1,22 @@ //! Float tensor operations for MLX backend. use burn_tensor::{backend::ExecutionError, ops::FloatTensorOps, Distribution, FloatDType, Shape, Slice, TensorData}; +use half::{bf16, f16}; use mlx_rs::Array; use mlx_rs::ops::indexing::{take_axis, take_along_axis}; use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl FloatTensorOps for Mlx { +impl FloatTensorOps for Mlx { fn float_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); let shape: Vec = data.shape.iter().map(|&s| s as i32).collect(); - let values: Vec = data.to_vec().expect("Failed to convert data to f32 vec"); - let array = Array::from_slice(&values, &shape); + let values: Vec = data.to_vec().expect("Failed to convert data to vec"); + let array = F::array_from_slice(&values, &shape); MlxTensorPrimitive::new(array) } @@ -29,7 +31,9 @@ impl FloatTensorOps for Mlx { let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); - let array = match distribution { + // Generate random values in f32 (widest support in mlx-rs random API), + // then cast to the target float type F. 
+ let array_f32 = match distribution { Distribution::Default => { mlx_rs::random::uniform::(0.0, 1.0, &shape_i32, None) .expect("Failed to create uniform random array") @@ -57,13 +61,14 @@ impl FloatTensorOps for Mlx { } }; + let array = F::cast_array(&array_f32); MlxTensorPrimitive::new(array) } async fn float_into_data(tensor: MlxTensorPrimitive) -> Result { tensor.array.eval().expect("Failed to evaluate tensor"); let shape = tensor.shape().to_vec(); - let data: Vec = tensor.array.as_slice().to_vec(); + let data: Vec = F::array_to_vec(&tensor.array); Ok(TensorData::new(data, shape)) } @@ -82,7 +87,7 @@ impl FloatTensorOps for Mlx { mlx_rs::Device::set_default(&mlx_device); let shape_i32: Vec = shape.dims.iter().map(|&s| s as i32).collect(); - let array = Array::zeros::(&shape_i32).expect("Failed to create empty array"); + let array = F::zeros_array(&shape_i32); MlxTensorPrimitive::new(array) } @@ -92,8 +97,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_add_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_add_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::add(&lhs.array, &scalar).expect("Failed to add scalar"); MlxTensorPrimitive::new(array) } @@ -103,8 +108,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_sub_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_sub_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::subtract(&lhs.array, &scalar).expect("Failed to subtract scalar"); MlxTensorPrimitive::new(array) } @@ -114,8 +119,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_mul_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn 
float_mul_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::multiply(&lhs.array, &scalar).expect("Failed to multiply scalar"); MlxTensorPrimitive::new(array) } @@ -125,8 +130,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_div_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_div_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::divide(&lhs.array, &scalar).expect("Failed to divide scalar"); MlxTensorPrimitive::new(array) } @@ -136,8 +141,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_remainder_scalar(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_remainder_scalar(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::remainder(&lhs.array, &scalar).expect("Failed to remainder scalar"); MlxTensorPrimitive::new(array) @@ -154,7 +159,7 @@ impl FloatTensorOps for Mlx { } fn float_recip(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let one = Array::from_f32(1.0); + let one = F::f64_scalar_array(1.0); let array = mlx_rs::ops::divide(&one, &tensor.array).expect("Failed to recip"); MlxTensorPrimitive::new(array) } @@ -281,9 +286,9 @@ impl FloatTensorOps for Mlx { fn float_mask_fill( tensor: MlxTensorPrimitive, mask: MlxTensorPrimitive, - value: f32, + value: F, ) -> MlxTensorPrimitive { - let fill_val = Array::from_f32(value); + let fill_val = F::scalar_array(value); let fill_broadcast = mlx_rs::ops::broadcast_to(&fill_val, tensor.array.shape()).expect("Failed to broadcast"); let array = mlx_rs::ops::r#where(&mask.array, &fill_broadcast, &tensor.array) @@ -296,8 +301,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_equal_elem(lhs: 
MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_equal_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::eq(&lhs.array, &scalar).expect("Failed to equal_elem"); MlxTensorPrimitive::new(array) } @@ -307,8 +312,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_greater_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_greater_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::gt(&lhs.array, &scalar).expect("Failed to greater_elem"); MlxTensorPrimitive::new(array) } @@ -318,8 +323,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_greater_equal_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_greater_equal_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::ge(&lhs.array, &scalar).expect("Failed to greater_equal_elem"); MlxTensorPrimitive::new(array) } @@ -329,8 +334,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_lower_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_lower_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = mlx_rs::ops::lt(&lhs.array, &scalar).expect("Failed to lower_elem"); MlxTensorPrimitive::new(array) } @@ -340,8 +345,8 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_lower_equal_elem(lhs: MlxTensorPrimitive, rhs: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(rhs); + fn float_lower_equal_elem(lhs: MlxTensorPrimitive, rhs: F) -> MlxTensorPrimitive { + let scalar = F::scalar_array(rhs); let array = 
mlx_rs::ops::le(&lhs.array, &scalar).expect("Failed to lower_equal_elem"); MlxTensorPrimitive::new(array) } @@ -404,7 +409,7 @@ impl FloatTensorOps for Mlx { } fn float_powf_scalar_impl(tensor: MlxTensorPrimitive, value: f32) -> MlxTensorPrimitive { - let scalar = Array::from_f32(value); + let scalar = F::f64_scalar_array(value as f64); let array = mlx_rs::ops::power(&tensor.array, &scalar).expect("Failed to powf_scalar"); MlxTensorPrimitive::new(array) } @@ -506,21 +511,21 @@ impl FloatTensorOps for Mlx { MlxTensorPrimitive::new(array) } - fn float_clamp(tensor: MlxTensorPrimitive, min: f32, max: f32) -> MlxTensorPrimitive { - let min_arr = Array::from_f32(min); - let max_arr = Array::from_f32(max); + fn float_clamp(tensor: MlxTensorPrimitive, min: F, max: F) -> MlxTensorPrimitive { + let min_arr = F::scalar_array(min); + let max_arr = F::scalar_array(max); let array = mlx_rs::ops::clip(&tensor.array, (&min_arr, &max_arr)).expect("Failed to clamp"); MlxTensorPrimitive::new(array) } - fn float_clamp_min(tensor: MlxTensorPrimitive, min: f32) -> MlxTensorPrimitive { - let min_arr = Array::from_f32(min); + fn float_clamp_min(tensor: MlxTensorPrimitive, min: F) -> MlxTensorPrimitive { + let min_arr = F::scalar_array(min); let array = mlx_rs::ops::maximum(&tensor.array, &min_arr).expect("Failed to clamp_min"); MlxTensorPrimitive::new(array) } - fn float_clamp_max(tensor: MlxTensorPrimitive, max: f32) -> MlxTensorPrimitive { - let max_arr = Array::from_f32(max); + fn float_clamp_max(tensor: MlxTensorPrimitive, max: F) -> MlxTensorPrimitive { + let max_arr = F::scalar_array(max); let array = mlx_rs::ops::minimum(&tensor.array, &max_arr).expect("Failed to clamp_max"); MlxTensorPrimitive::new(array) } @@ -584,9 +589,11 @@ impl FloatTensorOps for Mlx { fn float_cast(tensor: MlxTensorPrimitive, dtype: FloatDType) -> MlxTensorPrimitive { let array = match dtype { + FloatDType::F16 => tensor.array.as_type::().expect("cast to f16"), + FloatDType::BF16 => 
tensor.array.as_type::().expect("cast to bf16"), FloatDType::F32 => tensor.array.as_type::().expect("cast to f32"), FloatDType::F64 => tensor.array.as_type::().expect("cast to f64"), - _ => tensor.array, // Keep as-is for unsupported types + _ => tensor.array, }; MlxTensorPrimitive::new(array) } @@ -670,11 +677,6 @@ impl FloatTensorOps for Mlx { rhs: MlxTensorPrimitive, dim: usize, ) -> MlxTensorPrimitive { - // Cross product implementation: a x b - // For vectors of length 3 along the given dimension: - // result[0] = a[1]*b[2] - a[2]*b[1] - // result[1] = a[2]*b[0] - a[0]*b[2] - // result[2] = a[0]*b[1] - a[1]*b[0] let dim_i32 = dim as i32; let a0 = take_axis(&lhs.array, &Array::from_int(0), dim_i32).expect("take"); @@ -698,7 +700,6 @@ impl FloatTensorOps for Mlx { &mlx_rs::ops::multiply(&a1, &b0).expect("mul"), ).expect("sub"); - // Stack along the target dimension let array = mlx_rs::ops::stack_axis(&[&r0, &r1, &r2], dim_i32).expect("stack"); MlxTensorPrimitive::new(array) } @@ -733,12 +734,10 @@ impl FloatTensorOps for Mlx { size: usize, step: usize, ) -> MlxTensorPrimitive { - // unfold extracts sliding windows of `size` along `dim` with `step` let shape = tensor.shape().to_vec(); let dim_size = shape[dim]; let num_windows = (dim_size - size) / step + 1; - // Create indices for each window let mut window_indices = Vec::new(); for w in 0..num_windows { let start = w * step; @@ -750,7 +749,6 @@ impl FloatTensorOps for Mlx { let indices = Array::from_slice(&window_indices, &[(num_windows * size) as i32]); let gathered = take_axis(&tensor.array, &indices, dim as i32).expect("take"); - // Reshape to insert the window dimension let mut new_shape: Vec = shape.iter().map(|&s| s as i32).collect(); new_shape[dim] = num_windows as i32; new_shape.push(size as i32); diff --git a/src/ops/int_ops.rs b/src/ops/int_ops.rs index 1d0c215..5fe1f25 100644 --- a/src/ops/int_ops.rs +++ b/src/ops/int_ops.rs @@ -6,8 +6,9 @@ use mlx_rs::ops::indexing::{argmax_axis, argmin_axis, 
take_axis, take_along_axis use crate::backend::{Mlx, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; -impl IntTensorOps for Mlx { +impl IntTensorOps for Mlx { fn int_from_data(data: TensorData, device: &MlxDevice) -> MlxTensorPrimitive { let mlx_device = device.to_mlx_device(); mlx_rs::Device::set_default(&mlx_device); @@ -368,7 +369,7 @@ impl IntTensorOps for Mlx { } fn int_into_float(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let array = tensor.array.as_type::().expect("Failed to cast to float"); + let array = F::cast_array(&tensor.array); MlxTensorPrimitive::new(array) } @@ -420,9 +421,9 @@ impl IntTensorOps for Mlx { } fn int_matmul(lhs: MlxTensorPrimitive, rhs: MlxTensorPrimitive) -> MlxTensorPrimitive { - // MLX matmul requires float, so cast to f32, matmul, then cast back - let lhs_f = lhs.array.as_type::().expect("cast lhs"); - let rhs_f = rhs.array.as_type::().expect("cast rhs"); + // MLX matmul requires float, so cast to backend float type, matmul, then cast back + let lhs_f = F::cast_array(&lhs.array); + let rhs_f = F::cast_array(&rhs.array); let result = lhs_f.matmul(&rhs_f).expect("matmul"); let array = result.as_type::().expect("cast back"); MlxTensorPrimitive::new(array) diff --git a/src/ops/module_ops.rs b/src/ops/module_ops.rs index e7add87..3449945 100644 --- a/src/ops/module_ops.rs +++ b/src/ops/module_ops.rs @@ -8,6 +8,7 @@ use mlx_rs::Array; use mlx_rs::ops::indexing::take_axis; use crate::backend::{Mlx, MlxTensorPrimitive}; +use crate::element::FloatMlxElement; /// Helper function to compute pooling using as_strided approach. /// This follows the pattern from mlx-rs nn/pooling.rs. 
@@ -246,7 +247,7 @@ fn max_pool2d_with_indices_impl( (output, flat_indices) } -impl ModuleOps for Mlx { +impl ModuleOps for Mlx { fn conv1d( x: MlxTensorPrimitive, weight: MlxTensorPrimitive, @@ -371,7 +372,7 @@ impl ModuleOps for Mlx { _options: DeformConvOptions<2>, ) -> MlxTensorPrimitive { let shape = [1i32, 1, 1, 1]; - let array = Array::zeros::(&shape).expect("zeros"); + let array = F::zeros_array(&shape); MlxTensorPrimitive::new(array) } @@ -383,9 +384,9 @@ impl ModuleOps for Mlx { _bias: Option, _out_grad: MlxTensorPrimitive, _options: DeformConvOptions<2>, - ) -> DeformConv2dBackward { + ) -> DeformConv2dBackward> { let shape = [1i32, 1, 1, 1]; - let zeros = MlxTensorPrimitive::new(Array::zeros::(&shape).expect("zeros")); + let zeros = MlxTensorPrimitive::new(F::zeros_array(&shape)); DeformConv2dBackward::new( zeros.clone(), zeros.clone(), @@ -490,15 +491,15 @@ impl ModuleOps for Mlx { let grad_nhwc = mlx_rs::ops::transpose_axes(&grad.array, &[0, 2, 3, 1]).expect("transpose"); - let scale = Array::from_f32(1.0 / pool_size); + let scale = F::f64_scalar_array(1.0 / pool_size as f64); let grad_scaled = mlx_rs::ops::multiply(&grad_nhwc, &scale).expect("multiply"); - let grad_input_padded = Array::zeros::(&[ + let grad_input_padded = F::zeros_array(&[ n as i32, h_padded as i32, w_padded as i32, c as i32, - ]).expect("zeros"); + ]); let mut all_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); let mut update_indices: Vec = Vec::with_capacity(n * out_h * out_w * kh * kw * c); @@ -581,7 +582,7 @@ impl ModuleOps for Mlx { let x_padded = if padding > 0 { let pad = padding as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nlc, &[(0, 0), (pad, pad), (0, 0)], @@ -613,7 +614,7 @@ impl ModuleOps for Mlx { let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + 
let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -639,7 +640,7 @@ impl ModuleOps for Mlx { padding: usize, dilation: usize, _ceil_mode: bool, - ) -> MaxPool1dWithIndices { + ) -> MaxPool1dWithIndices> { let output = Self::max_pool1d(x, kernel_size, stride, padding, dilation, false); let indices = MlxTensorPrimitive::new( Array::zeros::(&output.array.shape().iter().map(|&s| s as i32).collect::>()) @@ -655,13 +656,13 @@ impl ModuleOps for Mlx { padding: [usize; 2], _dilation: [usize; 2], _ceil_mode: bool, - ) -> MaxPool2dWithIndices { + ) -> MaxPool2dWithIndices> { let x_nhwc = mlx_rs::ops::transpose_axes(&x.array, &[0, 2, 3, 1]).expect("transpose"); let x_padded = if padding[0] > 0 || padding[1] > 0 { let pad_h = padding[0] as i32; let pad_w = padding[1] as i32; - let neg_inf = Array::from_f32(f32::NEG_INFINITY); + let neg_inf = F::scalar_array(F::neg_infinity()); mlx_rs::ops::pad( &x_nhwc, &[(0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)], @@ -692,7 +693,7 @@ impl ModuleOps for Mlx { _ceil_mode: bool, output_grad: MlxTensorPrimitive, indices: MlxTensorPrimitive, - ) -> MaxPool2dBackward { + ) -> MaxPool2dBackward> { let input_shape = x.shape(); let n = input_shape[0]; let c = input_shape[1]; @@ -706,7 +707,7 @@ impl ModuleOps for Mlx { let w_padded = w + 2 * pad_w; let total_size = n * h_padded * w_padded * c; - let grad_input_flat = Array::zeros::(&[total_size as i32]).expect("zeros"); + let grad_input_flat = F::zeros_array(&[total_size as i32]); let grad_nhwc = mlx_rs::ops::transpose_axes(&output_grad.array, &[0, 2, 3, 1]).expect("transpose"); let indices_nhwc = mlx_rs::ops::transpose_axes(&indices.array, &[0, 2, 3, 1]).expect("transpose"); @@ -768,7 +769,7 @@ impl ModuleOps for Mlx { _grad: MlxTensorPrimitive, ) -> MlxTensorPrimitive { let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); - let output = Array::zeros::(&shape).expect("zeros"); + let output = 
F::zeros_array(&shape); MlxTensorPrimitive::new(output) } @@ -787,7 +788,7 @@ impl ModuleOps for Mlx { _options: InterpolateOptions, ) -> MlxTensorPrimitive { let shape: Vec = x.shape().iter().map(|&s| s as i32).collect(); - let output = Array::zeros::(&shape).expect("zeros"); + let output = F::zeros_array(&shape); MlxTensorPrimitive::new(output) } diff --git a/src/ops/other_ops.rs b/src/ops/other_ops.rs index 1d57170..b60be45 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -6,15 +6,15 @@ use burn_tensor::{ quantization::QuantScheme, Shape, Slice, TensorData, }; -use mlx_rs::Array; use crate::backend::{Mlx, MlxQuantizedTensorPrimitive, MlxTensorPrimitive}; use crate::device::MlxDevice; +use crate::element::FloatMlxElement; // ActivationOps - most methods have default implementations -impl ActivationOps for Mlx { +impl ActivationOps for Mlx { fn relu(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let array = mlx_rs::ops::maximum(&tensor.array, &zero).expect("relu"); MlxTensorPrimitive::new(array) } @@ -27,25 +27,26 @@ impl ActivationOps for Mlx { fn gelu(tensor: MlxTensorPrimitive) -> MlxTensorPrimitive { // GELU(x) = x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) // Simplified: x * sigmoid(1.702 * x) - let coef = Array::from_f32(1.702); + let coef = F::f64_scalar_array(1.702); let scaled = mlx_rs::ops::multiply(&tensor.array, &coef).expect("multiply"); let sigmoid = mlx_rs::ops::sigmoid(&scaled).expect("sigmoid"); let array = mlx_rs::ops::multiply(&tensor.array, &sigmoid).expect("multiply"); MlxTensorPrimitive::new(array) } - fn leaky_relu(tensor: MlxTensorPrimitive, negative_slope: f32) -> MlxTensorPrimitive { - let array = mlx_rs::nn::leaky_relu(&tensor.array, negative_slope).expect("leaky_relu"); + fn leaky_relu(tensor: MlxTensorPrimitive, negative_slope: F) -> MlxTensorPrimitive { + let slope_f32 = num_traits::ToPrimitive::to_f32(&negative_slope).unwrap(); + 
let array = mlx_rs::nn::leaky_relu(&tensor.array, slope_f32).expect("leaky_relu"); MlxTensorPrimitive::new(array) } - fn hard_sigmoid(tensor: MlxTensorPrimitive, alpha: f32, beta: f32) -> MlxTensorPrimitive { - let alpha_arr = Array::from_f32(alpha); - let beta_arr = Array::from_f32(beta); + fn hard_sigmoid(tensor: MlxTensorPrimitive, alpha: F, beta: F) -> MlxTensorPrimitive { + let alpha_arr = F::scalar_array(alpha); + let beta_arr = F::scalar_array(beta); let scaled = mlx_rs::ops::multiply(&tensor.array, &alpha_arr).expect("multiply"); let shifted = mlx_rs::ops::add(&scaled, &beta_arr).expect("add"); - let zero = Array::from_f32(0.0); - let one = Array::from_f32(1.0); + let zero = F::f64_scalar_array(0.0); + let one = F::f64_scalar_array(1.0); let array = mlx_rs::ops::clip(&shifted, (&zero, &one)).expect("clip"); MlxTensorPrimitive::new(array) } @@ -57,7 +58,7 @@ impl ActivationOps for Mlx { } fn prelu(tensor: MlxTensorPrimitive, alpha: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let pos = mlx_rs::ops::maximum(&tensor.array, &zero).expect("max"); let neg = mlx_rs::ops::minimum(&tensor.array, &zero).expect("min"); let scaled_neg = mlx_rs::ops::multiply(&alpha.array, &neg).expect("multiply"); @@ -70,19 +71,19 @@ impl ActivationOps for Mlx { } fn relu_backward(x: MlxTensorPrimitive, grad: MlxTensorPrimitive) -> MlxTensorPrimitive { - let zero = Array::from_f32(0.0); + let zero = F::f64_scalar_array(0.0); let mask = mlx_rs::ops::gt(&x.array, &zero).expect("greater"); - let mask_float = mask.as_type::().expect("cast"); + let mask_float = F::cast_array(&mask); let array = mlx_rs::ops::multiply(&grad.array, &mask_float).expect("multiply"); MlxTensorPrimitive::new(array) } } // QTensorOps - Quantization operations (placeholder) -impl QTensorOps for Mlx { +impl QTensorOps for Mlx { fn q_from_data(data: TensorData, device: &MlxDevice) -> MlxQuantizedTensorPrimitive { let tensor = >::float_from_data( 
- data.convert::(), + data.convert::(), device, ); MlxQuantizedTensorPrimitive { @@ -222,4 +223,4 @@ impl QTensorOps for Mlx { } // TransactionOps - transaction batching (default impl) -impl TransactionOps for Mlx {} +impl TransactionOps for Mlx {} From 4c541fcf10bc13a94605ec53fcbe4b2c670debab Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Mon, 23 Feb 2026 12:47:42 -0800 Subject: [PATCH 4/6] Implement native MLX quantization support via mlx_rs ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the placeholder QTensorOps implementation (which stored quantized data as float and ignored quantization entirely) with real MLX-backed quantization using mlx_rs::ops::quantize, dequantize, and quantized_matmul. MlxQuantizedTensorPrimitive now stores MLX's native packed uint quantized arrays alongside scales, biases, and quantization metadata. q_from_data properly unpacks Burn's QuantizedBytes format, dequantizes on CPU, and re-quantizes into MLX's native format. q_matmul uses fused quantized_matmul for the common float×quantized inference path. Co-Authored-By: Claude Opus 4.6 --- src/backend.rs | 24 +++-- src/ops/other_ops.rs | 234 +++++++++++++++++++++++++++++-------------- 2 files changed, 177 insertions(+), 81 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index 282ff5c..7c252bf 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -65,26 +65,36 @@ impl TensorMetadata for MlxTensorPrimitive { } } -/// Quantized tensor primitive (placeholder for future implementation). +/// Quantized tensor primitive storing MLX's native quantized representation. #[derive(Debug, Clone)] pub struct MlxQuantizedTensorPrimitive { - /// The underlying tensor (stored as float for now). - pub tensor: MlxTensorPrimitive, - /// Quantization scheme. + /// Quantized weight values (MLX's packed uint format). + pub quantized: Array, + /// Per-group scale factors. + pub scales: Array, + /// Per-group zero-point biases. 
+ pub biases: Array, + /// Logical tensor shape (e.g. [in_features, out_features]). + pub shape: Vec, + /// MLX group size (e.g. 32 or 64). + pub group_size: i32, + /// Bit width (4 or 8). + pub bits: i32, + /// Burn quantization scheme (for round-tripping back to Burn format). pub scheme: QuantScheme, } -// SAFETY: Same as MlxTensorPrimitive +// SAFETY: Same as MlxTensorPrimitive — MLX uses internal synchronization. unsafe impl Send for MlxQuantizedTensorPrimitive {} unsafe impl Sync for MlxQuantizedTensorPrimitive {} impl TensorMetadata for MlxQuantizedTensorPrimitive { fn dtype(&self) -> DType { - self.tensor.dtype() + DType::QFloat(self.scheme) } fn shape(&self) -> burn_tensor::Shape { - burn_tensor::Shape::from(self.tensor.shape.clone()) + burn_tensor::Shape::from(self.shape.clone()) } } diff --git a/src/ops/other_ops.rs b/src/ops/other_ops.rs index b60be45..638cc7c 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -2,10 +2,13 @@ use burn_tensor::{ backend::ExecutionError, - ops::{ActivationOps, QTensorOps, TransactionOps}, - quantization::QuantScheme, - Shape, Slice, TensorData, + ops::{ActivationOps, FloatTensorOps, QTensorOps, TransactionOps}, + quantization::{ + QuantLevel, QuantScheme, QuantValue, QuantizedBytes, + }, + DType, Shape, Slice, TensorData, TensorPrimitive, }; +use mlx_rs::Array; use crate::backend::{Mlx, MlxQuantizedTensorPrimitive, MlxTensorPrimitive}; use crate::device::MlxDevice; @@ -79,32 +82,144 @@ impl ActivationOps for Mlx { } } -// QTensorOps - Quantization operations (placeholder) +/// Helper: extract MLX quantization parameters (bits, group_size) from a Burn QuantScheme. 
+fn scheme_to_mlx_params(scheme: &QuantScheme, num_elements: usize) -> (i32, i32) { + let bits: i32 = match scheme.value { + QuantValue::Q8S | QuantValue::Q8F => 8, + QuantValue::Q4S | QuantValue::Q4F => 4, + _ => panic!("Unsupported quantization value: {:?}", scheme.value), + }; + let group_size: i32 = match scheme.level { + QuantLevel::Block(bs) => bs.as_slice()[0] as i32, + QuantLevel::Tensor => num_elements as i32, + }; + (bits, group_size) +} + +// QTensorOps - Native MLX quantization impl QTensorOps for Mlx { - fn q_from_data(data: TensorData, device: &MlxDevice) -> MlxQuantizedTensorPrimitive { - let tensor = >::float_from_data( - data.convert::(), - device, - ); + fn q_from_data(data: TensorData, _device: &MlxDevice) -> MlxQuantizedTensorPrimitive { + // Extract the quantization scheme from the data's dtype + let scheme = match data.dtype { + DType::QFloat(scheme) => scheme, + other => panic!("q_from_data called with non-quantized dtype: {:?}", other), + }; + + let num_elements: usize = data.shape.iter().product(); + let (bits, group_size) = scheme_to_mlx_params(&scheme, num_elements); + + // Unpack Burn's quantized bytes into i8 values + scale parameters + let qb = QuantizedBytes { + bytes: data.bytes, + scheme, + num_elements, + }; + let (i8_values, qparams) = qb.into_vec_i8(); + let scales = qparams.scales; + + // Dequantize on CPU: symmetric quantization means float_val = i8_val * scale + let mut float_vals = vec![0.0f32; i8_values.len()]; + for (block_idx, chunk) in i8_values.chunks(group_size as usize).enumerate() { + let scale = scales[block_idx]; + for (i, &val) in chunk.iter().enumerate() { + float_vals[block_idx * group_size as usize + i] = val as f32 * scale; + } + } + + // Create a float MLX Array from the dequantized values + let shape_i32: Vec = data.shape.iter().map(|&s| s as i32).collect(); + let float_array = Array::from_slice(&float_vals, &shape_i32); + + // Re-quantize into MLX's native format + let (quantized, mlx_scales, mlx_biases) = + 
mlx_rs::ops::quantize(&float_array, group_size, bits) + .expect("MLX quantize failed"); + MlxQuantizedTensorPrimitive { - tensor, - scheme: QuantScheme::default(), + quantized, + scales: mlx_scales, + biases: mlx_biases, + shape: data.shape, + group_size, + bits, + scheme, } } fn quantize( tensor: MlxTensorPrimitive, - _scheme: &QuantScheme, + scheme: &QuantScheme, _qparams: burn_tensor::quantization::QuantizationParametersPrimitive, ) -> MlxQuantizedTensorPrimitive { + let num_elements: usize = tensor.shape.iter().product(); + let (bits, group_size) = scheme_to_mlx_params(scheme, num_elements); + let shape = tensor.shape.clone(); + + let (quantized, scales, biases) = + mlx_rs::ops::quantize(&tensor.array, group_size, bits) + .expect("MLX quantize failed"); + MlxQuantizedTensorPrimitive { - tensor, - scheme: QuantScheme::default(), + quantized, + scales, + biases, + shape, + group_size, + bits, + scheme: *scheme, } } fn dequantize(tensor: MlxQuantizedTensorPrimitive) -> MlxTensorPrimitive { - tensor.tensor + let array = mlx_rs::ops::dequantize( + &tensor.quantized, + &tensor.scales, + &tensor.biases, + tensor.group_size, + tensor.bits, + ) + .expect("MLX dequantize failed"); + MlxTensorPrimitive::new(array) + } + + fn q_matmul(lhs: TensorPrimitive, rhs: TensorPrimitive) -> TensorPrimitive { + match (lhs, rhs) { + // float x quantized — the common case for inference (activation x weight) + (TensorPrimitive::Float(lhs_f), TensorPrimitive::QFloat(rhs_q)) => { + let result = mlx_rs::ops::quantized_matmul( + &lhs_f.array, + &rhs_q.quantized, + &rhs_q.scales, + &rhs_q.biases, + false, // no transpose — Burn stores weights as [in, out] + rhs_q.group_size, + rhs_q.bits, + ) + .expect("MLX quantized_matmul failed"); + TensorPrimitive::Float(MlxTensorPrimitive::new(result)) + } + // quantized x float — dequantize LHS + (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::Float(rhs_f)) => { + let lhs_f = Self::dequantize(lhs_q); + TensorPrimitive::Float( + >::float_matmul(lhs_f, 
rhs_f), + ) + } + // both quantized — dequantize both + (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::QFloat(rhs_q)) => { + let lhs_f = Self::dequantize(lhs_q); + let rhs_f = Self::dequantize(rhs_q); + TensorPrimitive::Float( + >::float_matmul(lhs_f, rhs_f), + ) + } + // both float — standard matmul + (TensorPrimitive::Float(lhs_f), TensorPrimitive::Float(rhs_f)) => { + TensorPrimitive::Float( + >::float_matmul(lhs_f, rhs_f), + ) + } + } } fn q_device(_tensor: &MlxQuantizedTensorPrimitive) -> MlxDevice { @@ -119,18 +234,15 @@ impl QTensorOps for Mlx { } fn q_reshape(tensor: MlxQuantizedTensorPrimitive, shape: Shape) -> MlxQuantizedTensorPrimitive { - let reshaped = >::float_reshape( - tensor.tensor, - shape, - ); - MlxQuantizedTensorPrimitive { - tensor: reshaped, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let reshaped = >::float_reshape(float_tensor, shape); + Self::quantize_dynamic(reshaped, &scheme) } async fn q_into_data(tensor: MlxQuantizedTensorPrimitive) -> Result { - >::float_into_data(tensor.tensor).await + let float_tensor = Self::dequantize(tensor); + >::float_into_data(float_tensor).await } fn q_swap_dims( @@ -138,43 +250,30 @@ impl QTensorOps for Mlx { dim1: usize, dim2: usize, ) -> MlxQuantizedTensorPrimitive { - let swapped = >::float_swap_dims( - tensor.tensor, - dim1, - dim2, - ); - MlxQuantizedTensorPrimitive { - tensor: swapped, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let swapped = >::float_swap_dims(float_tensor, dim1, dim2); + Self::quantize_dynamic(swapped, &scheme) } fn q_permute( tensor: MlxQuantizedTensorPrimitive, axes: &[usize], ) -> MlxQuantizedTensorPrimitive { - let permuted = >::float_permute( - tensor.tensor, - axes, - ); - MlxQuantizedTensorPrimitive { - tensor: permuted, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let permuted = 
>::float_permute(float_tensor, axes); + Self::quantize_dynamic(permuted, &scheme) } fn q_flip( tensor: MlxQuantizedTensorPrimitive, axes: &[usize], ) -> MlxQuantizedTensorPrimitive { - let flipped = >::float_flip( - tensor.tensor, - axes, - ); - MlxQuantizedTensorPrimitive { - tensor: flipped, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let flipped = >::float_flip(float_tensor, axes); + Self::quantize_dynamic(flipped, &scheme) } fn q_select( @@ -182,43 +281,30 @@ impl QTensorOps for Mlx { dim: usize, indices: MlxTensorPrimitive, ) -> MlxQuantizedTensorPrimitive { - let selected = >::float_select( - tensor.tensor, - dim, - indices, - ); - MlxQuantizedTensorPrimitive { - tensor: selected, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let selected = >::float_select(float_tensor, dim, indices); + Self::quantize_dynamic(selected, &scheme) } fn q_slice( tensor: MlxQuantizedTensorPrimitive, slices: &[Slice], ) -> MlxQuantizedTensorPrimitive { - let sliced = >::float_slice( - tensor.tensor, - slices, - ); - MlxQuantizedTensorPrimitive { - tensor: sliced, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let sliced = >::float_slice(float_tensor, slices); + Self::quantize_dynamic(sliced, &scheme) } fn q_expand( tensor: MlxQuantizedTensorPrimitive, shape: Shape, ) -> MlxQuantizedTensorPrimitive { - let expanded = >::float_expand( - tensor.tensor, - shape, - ); - MlxQuantizedTensorPrimitive { - tensor: expanded, - scheme: tensor.scheme, - } + let scheme = tensor.scheme; + let float_tensor = Self::dequantize(tensor); + let expanded = >::float_expand(float_tensor, shape); + Self::quantize_dynamic(expanded, &scheme) } } From c3e19bf7df4cb283ae0eda4da2263e437e80abe4 Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Mon, 23 Feb 2026 14:02:24 -0800 Subject: [PATCH 5/6] Optimize 
q_reshape/q_swap_dims/q_expand to avoid unnecessary dequant/requant cycles MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit nn::Linear::forward calls weight.unsqueeze() on every forward pass, which triggers q_reshape → dequantize → reshape → re-quantize for every quantized matmul. For a 36-layer transformer this is ~180 unnecessary dequant+requant cycles per token, completely defeating the performance benefit of quantization. Add fast paths that update only the logical shape metadata when the last 2 dimensions (the actual matrix shape) are unchanged: - q_reshape: [M, N] → [1, M, N] unsqueezes skip dequant - q_swap_dims: swaps in batch/prefix dims skip dequant - q_expand: size-1 prefix expansions skip dequant All three fall back to the original dequant→op→requant path when the operation actually touches the matrix dimensions. Co-Authored-By: Claude Opus 4.6 --- src/ops/other_ops.rs | 55 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/src/ops/other_ops.rs b/src/ops/other_ops.rs index 638cc7c..4d141ad 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -234,6 +234,26 @@ impl QTensorOps for Mlx { } fn q_reshape(tensor: MlxQuantizedTensorPrimitive, shape: Shape) -> MlxQuantizedTensorPrimitive { + let new_dims: Vec = shape.dims.to_vec(); + + // Fast path: if the last 2 dimensions are unchanged, just update the + // logical shape. The underlying quantized/scales/biases arrays remain + // valid since they represent the same 2D weight matrix. + // This avoids expensive dequant→reshape→requant for trivial unsqueezes + // (e.g. [M, N] → [1, M, N]) triggered by nn::Linear::forward. 
+ let old = &tensor.shape; + if old.len() >= 2 + && new_dims.len() >= 2 + && old[old.len() - 2] == new_dims[new_dims.len() - 2] + && old[old.len() - 1] == new_dims[new_dims.len() - 1] + { + return MlxQuantizedTensorPrimitive { + shape: new_dims, + ..tensor + }; + } + + // Fallback: actual data reshape requires dequant → reshape → requant let scheme = tensor.scheme; let float_tensor = Self::dequantize(tensor); let reshaped = >::float_reshape(float_tensor, shape); @@ -250,6 +270,20 @@ impl QTensorOps for Mlx { dim1: usize, dim2: usize, ) -> MlxQuantizedTensorPrimitive { + let ndim = tensor.shape.len(); + + // Fast path: swapping within batch/prefix dims (not the last 2) just + // updates the logical shape — the underlying quantized arrays are + // unaffected. Also handles the trivial dim1 == dim2 case. + if ndim >= 2 && dim1 < ndim - 2 && dim2 < ndim - 2 { + let mut new_shape = tensor.shape.clone(); + new_shape.swap(dim1, dim2); + return MlxQuantizedTensorPrimitive { + shape: new_shape, + ..tensor + }; + } + let scheme = tensor.scheme; let float_tensor = Self::dequantize(tensor); let swapped = >::float_swap_dims(float_tensor, dim1, dim2); @@ -301,6 +335,27 @@ impl QTensorOps for Mlx { tensor: MlxQuantizedTensorPrimitive, shape: Shape, ) -> MlxQuantizedTensorPrimitive { + let new_dims: Vec = shape.dims.to_vec(); + let old = &tensor.shape; + + // Fast path: if the last 2 dimensions are unchanged and any new prefix + // dims are size-1 (broadcast markers), the underlying quantized arrays + // don't need modification. 
+ if old.len() >= 2 + && new_dims.len() >= 2 + && old[old.len() - 2] == new_dims[new_dims.len() - 2] + && old[old.len() - 1] == new_dims[new_dims.len() - 1] + { + // Check that any extra leading dims are size 1 + let extra = new_dims.len().saturating_sub(old.len()); + if new_dims[..extra].iter().all(|&d| d == 1) { + return MlxQuantizedTensorPrimitive { + shape: new_dims, + ..tensor + }; + } + } + let scheme = tensor.scheme; let float_tensor = Self::dequantize(tensor); let expanded = >::float_expand(float_tensor, shape); From beb4a1be231118ebfeeb104ee6af1a66ff01b4a4 Mon Sep 17 00:00:00 2001 From: Mike Marcacci Date: Mon, 23 Feb 2026 14:46:17 -0800 Subject: [PATCH 6/6] Cast dequantize and quantized_matmul results to backend float type MLX's dequantize and quantized_matmul ops may return f32 arrays regardless of the backend's configured float precision. When using Mlx with a non-f32 float element type, this causes dtype mismatches in downstream operations. Cast results via F::cast_array() to ensure consistency. --- src/ops/other_ops.rs | 40 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 25 deletions(-) diff --git a/src/ops/other_ops.rs b/src/ops/other_ops.rs index 4d141ad..edeceb6 100644 --- a/src/ops/other_ops.rs +++ b/src/ops/other_ops.rs @@ -3,9 +3,7 @@ use burn_tensor::{ backend::ExecutionError, ops::{ActivationOps, FloatTensorOps, QTensorOps, TransactionOps}, - quantization::{ - QuantLevel, QuantScheme, QuantValue, QuantizedBytes, - }, + quantization::{QuantLevel, QuantScheme, QuantValue, QuantizedBytes}, DType, Shape, Slice, TensorData, TensorPrimitive, }; use mlx_rs::Array; @@ -132,8 +130,7 @@ impl QTensorOps for Mlx { // Re-quantize into MLX's native format let (quantized, mlx_scales, mlx_biases) = - mlx_rs::ops::quantize(&float_array, group_size, bits) - .expect("MLX quantize failed"); + mlx_rs::ops::quantize(&float_array, group_size, bits).expect("MLX quantize failed"); MlxQuantizedTensorPrimitive { quantized, @@ -156,8 +153,7 @@ impl QTensorOps for Mlx { let
shape = tensor.shape.clone(); let (quantized, scales, biases) = - mlx_rs::ops::quantize(&tensor.array, group_size, bits) - .expect("MLX quantize failed"); + mlx_rs::ops::quantize(&tensor.array, group_size, bits).expect("MLX quantize failed"); MlxQuantizedTensorPrimitive { quantized, @@ -179,6 +175,8 @@ impl QTensorOps for Mlx { tensor.bits, ) .expect("MLX dequantize failed"); + // MLX dequantize may return f32 regardless of F; cast to backend's float type. + let array = F::cast_array(&array); MlxTensorPrimitive::new(array) } @@ -196,28 +194,24 @@ impl QTensorOps for Mlx { rhs_q.bits, ) .expect("MLX quantized_matmul failed"); + // MLX quantized_matmul may return f32; cast to backend's float type. + let result = F::cast_array(&result); TensorPrimitive::Float(MlxTensorPrimitive::new(result)) } // quantized x float — dequantize LHS (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::Float(rhs_f)) => { let lhs_f = Self::dequantize(lhs_q); - TensorPrimitive::Float( - >::float_matmul(lhs_f, rhs_f), - ) + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) } // both quantized — dequantize both (TensorPrimitive::QFloat(lhs_q), TensorPrimitive::QFloat(rhs_q)) => { let lhs_f = Self::dequantize(lhs_q); let rhs_f = Self::dequantize(rhs_q); - TensorPrimitive::Float( - >::float_matmul(lhs_f, rhs_f), - ) + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) } // both float — standard matmul (TensorPrimitive::Float(lhs_f), TensorPrimitive::Float(rhs_f)) => { - TensorPrimitive::Float( - >::float_matmul(lhs_f, rhs_f), - ) + TensorPrimitive::Float(>::float_matmul(lhs_f, rhs_f)) } } } @@ -260,7 +254,9 @@ impl QTensorOps for Mlx { Self::quantize_dynamic(reshaped, &scheme) } - async fn q_into_data(tensor: MlxQuantizedTensorPrimitive) -> Result { + async fn q_into_data( + tensor: MlxQuantizedTensorPrimitive, + ) -> Result { let float_tensor = Self::dequantize(tensor); >::float_into_data(float_tensor).await } @@ -300,10 +296,7 @@ impl QTensorOps for Mlx { Self::quantize_dynamic(permuted, 
&scheme) } - fn q_flip( - tensor: MlxQuantizedTensorPrimitive, - axes: &[usize], - ) -> MlxQuantizedTensorPrimitive { + fn q_flip(tensor: MlxQuantizedTensorPrimitive, axes: &[usize]) -> MlxQuantizedTensorPrimitive { let scheme = tensor.scheme; let float_tensor = Self::dequantize(tensor); let flipped = >::float_flip(float_tensor, axes); @@ -331,10 +324,7 @@ impl QTensorOps for Mlx { Self::quantize_dynamic(sliced, &scheme) } - fn q_expand( - tensor: MlxQuantizedTensorPrimitive, - shape: Shape, - ) -> MlxQuantizedTensorPrimitive { + fn q_expand(tensor: MlxQuantizedTensorPrimitive, shape: Shape) -> MlxQuantizedTensorPrimitive { let new_dims: Vec = shape.dims.to_vec(); let old = &tensor.shape;