Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@ fusion = []
autotune = []

[dependencies]
burn = { version = "0.16", default-features = false, features = ["std"] }
burn-tensor = "0.16"
burn = { version = "0.20", default-features = false, features = ["std"] }
burn-tensor = "0.20"
mlx-rs = { package = "mlx-rs-burn", version = "0.25.5" }
mlx-sys = { package = "mlx-sys-burn", version = "0.2" }
derive-new = "0.7"
half = { version = "2.4", features = ["num-traits"] }
num-traits = "0.2"

[dev-dependencies]
burn = { version = "0.16", features = ["train", "dataset"] }
burn = { version = "0.20", features = ["train", "dataset"] }
serial_test = "3.2"
145 changes: 97 additions & 48 deletions src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
//! MLX Backend implementation for Burn.

use burn_tensor::backend::Backend;
use burn_tensor::TensorMetadata;
use burn_tensor::backend::{Backend, ExecutionError};
use burn_tensor::{DType, TensorMetadata};
use burn_tensor::quantization::QuantScheme;
use mlx_rs::Array;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU64, Ordering};

use crate::device::MlxDevice;
use crate::element::FloatMlxElement;

// Global seed for random number generation
static SEED: AtomicU64 = AtomicU64::new(0);
Expand Down Expand Up @@ -43,17 +46,17 @@ unsafe impl Send for MlxTensorPrimitive {}
unsafe impl Sync for MlxTensorPrimitive {}

impl TensorMetadata for MlxTensorPrimitive {
fn dtype(&self) -> burn_tensor::DType {
fn dtype(&self) -> DType {
// Map MLX dtype to Burn dtype
match self.array.dtype() {
mlx_rs::Dtype::Float32 => burn_tensor::DType::F32,
mlx_rs::Dtype::Float16 => burn_tensor::DType::F16,
mlx_rs::Dtype::Bfloat16 => burn_tensor::DType::BF16,
mlx_rs::Dtype::Float64 => burn_tensor::DType::F64,
mlx_rs::Dtype::Int32 => burn_tensor::DType::I32,
mlx_rs::Dtype::Int64 => burn_tensor::DType::I64,
mlx_rs::Dtype::Bool => burn_tensor::DType::Bool,
_ => burn_tensor::DType::F32, // Default fallback
mlx_rs::Dtype::Float32 => DType::F32,
mlx_rs::Dtype::Float16 => DType::F16,
mlx_rs::Dtype::Bfloat16 => DType::BF16,
mlx_rs::Dtype::Float64 => DType::F64,
mlx_rs::Dtype::Int32 => DType::I32,
mlx_rs::Dtype::Int64 => DType::I64,
mlx_rs::Dtype::Bool => DType::Bool,
_ => DType::F32, // Default fallback
}
}

Expand All @@ -62,56 +65,91 @@ impl TensorMetadata for MlxTensorPrimitive {
}
}

/// Quantized tensor primitive (placeholder for future implementation).
/// Quantized tensor primitive storing MLX's native quantized representation.
#[derive(Debug, Clone)]
pub struct MlxQuantizedTensorPrimitive {
/// The underlying tensor (stored as float for now).
pub tensor: MlxTensorPrimitive,
/// Quantization scheme.
pub scheme: QuantizationScheme,
}

/// Quantization scheme.
#[derive(Debug, Clone, Copy, Default)]
pub enum QuantizationScheme {
#[default]
None,
/// Quantized weight values (MLX's packed uint format).
pub quantized: Array,
/// Per-group scale factors.
pub scales: Array,
/// Per-group zero-point biases.
pub biases: Array,
/// Logical tensor shape (e.g. [in_features, out_features]).
pub shape: Vec<usize>,
/// MLX group size (e.g. 32 or 64).
pub group_size: i32,
/// Bit width (4 or 8).
pub bits: i32,
/// Burn quantization scheme (for round-tripping back to Burn format).
pub scheme: QuantScheme,
}

// SAFETY: Same as MlxTensorPrimitive
// SAFETY: Same as MlxTensorPrimitive — MLX uses internal synchronization.
unsafe impl Send for MlxQuantizedTensorPrimitive {}
unsafe impl Sync for MlxQuantizedTensorPrimitive {}

impl TensorMetadata for MlxQuantizedTensorPrimitive {
fn dtype(&self) -> burn_tensor::DType {
self.tensor.dtype()
fn dtype(&self) -> DType {
DType::QFloat(self.scheme)
}

fn shape(&self) -> burn_tensor::Shape {
burn_tensor::Shape::from(self.tensor.shape.clone())
burn_tensor::Shape::from(self.shape.clone())
}
}

impl burn_tensor::quantization::QTensorPrimitive for MlxQuantizedTensorPrimitive {
fn scheme(&self) -> &burn_tensor::quantization::QuantizationScheme {
// Return a reference to a static scheme
static SYMMETRIC: burn_tensor::quantization::QuantizationScheme =
burn_tensor::quantization::QuantizationScheme::PerTensorSymmetric(
burn_tensor::quantization::QuantizationType::QInt8,
);
&SYMMETRIC
fn scheme(&self) -> &QuantScheme {
&self.scheme
}
}

/// MLX Backend for Burn, generic over float precision.
///
/// The default float type is `f32`. Use `Mlx<half::f16>` (or the `MlxHalf` alias)
/// for half-precision inference, which halves memory bandwidth and leverages
/// Apple Silicon's native f16 support.
///
/// # Examples
///
/// ```ignore
/// use burn_mlx::{Mlx, MlxHalf};
///
/// // f32 backend (default, same as before)
/// type Backend32 = Mlx;
///
/// // f16 backend for faster inference
/// type Backend16 = MlxHalf;
/// ```
pub struct Mlx<F: FloatMlxElement = f32> {
    // Zero-sized marker tying the backend type to its float element `F`;
    // no runtime state is stored.
    _phantom: PhantomData<F>,
}

impl<F: FloatMlxElement> std::fmt::Debug for Mlx<F> {
    /// Formats as `Mlx` regardless of the float parameter `F`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "Mlx")
    }
}

/// MLX Backend for Burn.
#[derive(Debug, Default, Clone, Copy)]
pub struct Mlx;
impl<F: FloatMlxElement> Default for Mlx<F> {
fn default() -> Self {
Self { _phantom: PhantomData }
}
}

impl<F: FloatMlxElement> Clone for Mlx<F> {
fn clone(&self) -> Self {
*self
}
}

impl Backend for Mlx {
impl<F: FloatMlxElement> Copy for Mlx<F> {}

impl<F: FloatMlxElement> Backend for Mlx<F> {
type Device = MlxDevice;

type FloatTensorPrimitive = MlxTensorPrimitive;
type FloatElem = f32;
type FloatElem = F;

type IntTensorPrimitive = MlxTensorPrimitive;
type IntElem = i32;
Expand All @@ -120,23 +158,34 @@ impl Backend for Mlx {
type BoolElem = bool;

type QuantizedTensorPrimitive = MlxQuantizedTensorPrimitive;
type QuantizedEncoding = i8;

fn name() -> String {
fn name(_device: &Self::Device) -> String {
"mlx".to_string()
}

fn seed(seed: u64) {
fn seed(_device: &Self::Device, seed: u64) {
SEED.store(seed, Ordering::SeqCst);
// MLX uses its own seeding mechanism
mlx_rs::random::seed(seed);
let _ = mlx_rs::random::seed(seed);
}

fn supports_dtype(_device: &Self::Device, dtype: DType) -> bool {
matches!(
dtype,
DType::F32 | DType::F64 | DType::F16 | DType::BF16 | DType::I32 | DType::I64 | DType::Bool
)
}

fn sync(device: &Self::Device) {
// MLX is lazy-evaluated; sync forces evaluation
// This is a no-op in MLX as synchronization happens implicitly
// when reading tensor values
let _ = device;
fn sync(_device: &Self::Device) -> Result<(), ExecutionError> {
let stream = mlx_rs::Stream::default();
let status = unsafe { mlx_sys::mlx_synchronize(stream.as_ptr()) };
if status == 0 {
Ok(())
} else {
Err(ExecutionError::WithContext {
reason: "MLX stream synchronization failed".into(),
})
}
}
}

Expand Down
17 changes: 15 additions & 2 deletions src/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,24 @@ impl fmt::Display for MlxDevice {
}
}

impl DeviceOps for MlxDevice {
fn id(&self) -> burn_tensor::backend::DeviceId {
impl burn_tensor::backend::Device for MlxDevice {
    /// Reconstructs a device from its id: type id 0 is CPU, anything else GPU.
    fn from_id(device_id: burn_tensor::backend::DeviceId) -> Self {
        if device_id.type_id == 0 {
            MlxDevice::Cpu
        } else {
            MlxDevice::Gpu
        }
    }

    /// Encodes the device as an id: CPU -> (0, 0), GPU -> (1, 0).
    fn to_id(&self) -> burn_tensor::backend::DeviceId {
        let type_id = match self {
            MlxDevice::Cpu => 0,
            MlxDevice::Gpu => 1,
        };
        burn_tensor::backend::DeviceId::new(type_id, 0)
    }

    /// Exactly one device exists per type (a single CPU and a single GPU).
    fn device_count(_type_id: u16) -> usize {
        1
    }
}

impl DeviceOps for MlxDevice {}
84 changes: 83 additions & 1 deletion src/element.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

use burn_tensor::{DType, Element};
use half::{bf16, f16};
use mlx_rs::Dtype;
use mlx_rs::{Array, Dtype};
use num_traits::{Float, FromPrimitive};

/// Trait for elements that can be used with MLX.
pub trait MlxElement: Element + Clone + Send + Sync + 'static {
Expand Down Expand Up @@ -130,6 +131,87 @@ impl MlxElement for bool {
}
}

/// Trait for float elements that can be used as the primary float type in the MLX backend.
///
/// This enables `Mlx<F>` to be generic over the float precision (f32, f16, bf16, f64).
/// Trait for float elements that can be used as the primary float type in the MLX backend.
///
/// This enables `Mlx<F>` to be generic over the float precision (f32, f16, bf16, f64).
pub trait FloatMlxElement: MlxElement + Float + FromPrimitive {
    /// Create a scalar MLX array from this value.
    fn scalar_array(value: Self) -> Array;

    /// Create a scalar MLX array from an f64 constant.
    ///
    /// # Panics
    /// Panics if `value` cannot be converted to `Self` by
    /// [`FromPrimitive::from_f64`] (i.e. the conversion returns `None`).
    fn f64_scalar_array(value: f64) -> Array {
        // `expect` (rather than a bare `unwrap`) states the invariant so a
        // failed conversion produces an actionable panic message.
        let converted = Self::from_f64(value)
            .expect("f64 constant not representable as the backend float element");
        Self::scalar_array(converted)
    }

    /// Create an MLX array from a slice of elements.
    fn array_from_slice(data: &[Self], shape: &[i32]) -> Array;

    /// Create a zeros array in this element's dtype.
    fn zeros_array(shape: &[i32]) -> Array;

    /// Create a ones array in this element's dtype.
    fn ones_array(shape: &[i32]) -> Array;

    /// Read an MLX array's data as a vector of this element type.
    fn array_to_vec(array: &Array) -> Vec<Self>;

    /// Cast an MLX array to this element's dtype.
    fn cast_array(array: &Array) -> Array;
}

/// f32 is MLX's native float type; every operation maps directly.
impl FloatMlxElement for f32 {
    fn scalar_array(value: Self) -> Array {
        Array::from_f32(value)
    }

    fn array_from_slice(data: &[Self], shape: &[i32]) -> Array {
        Array::from_slice(data, shape)
    }

    fn zeros_array(shape: &[i32]) -> Array {
        Array::zeros::<f32>(shape).expect("zeros")
    }

    fn ones_array(shape: &[i32]) -> Array {
        Array::ones::<f32>(shape).expect("ones")
    }

    fn array_to_vec(array: &Array) -> Vec<Self> {
        array.as_slice::<f32>().to_vec()
    }

    fn cast_array(array: &Array) -> Array {
        array.as_type::<f32>().expect("cast")
    }
}

/// IEEE half precision, backed by `half::f16`.
impl FloatMlxElement for f16 {
    fn scalar_array(value: Self) -> Array {
        // Built as a one-element slice. NOTE(review): this yields a rank-1
        // `[1]` array whereas the f32 impl yields a 0-dim scalar — confirm
        // broadcasting makes the two equivalent for all callers.
        Array::from_slice(&[value], &[1])
    }

    fn array_from_slice(data: &[Self], shape: &[i32]) -> Array {
        Array::from_slice(data, shape)
    }

    fn zeros_array(shape: &[i32]) -> Array {
        Array::zeros::<f16>(shape).expect("zeros")
    }

    fn ones_array(shape: &[i32]) -> Array {
        Array::ones::<f16>(shape).expect("ones")
    }

    fn array_to_vec(array: &Array) -> Vec<Self> {
        array.as_slice::<f16>().to_vec()
    }

    fn cast_array(array: &Array) -> Array {
        array.as_type::<f16>().expect("cast")
    }
}

/// bfloat16, backed by `half::bf16`.
impl FloatMlxElement for bf16 {
    fn scalar_array(value: Self) -> Array {
        // Built as a one-element slice. NOTE(review): this yields a rank-1
        // `[1]` array whereas the f32 impl yields a 0-dim scalar — confirm
        // broadcasting makes the two equivalent for all callers.
        Array::from_slice(&[value], &[1])
    }

    fn array_from_slice(data: &[Self], shape: &[i32]) -> Array {
        Array::from_slice(data, shape)
    }

    fn zeros_array(shape: &[i32]) -> Array {
        Array::zeros::<bf16>(shape).expect("zeros")
    }

    fn ones_array(shape: &[i32]) -> Array {
        Array::ones::<bf16>(shape).expect("ones")
    }

    fn array_to_vec(array: &Array) -> Vec<Self> {
        array.as_slice::<bf16>().to_vec()
    }

    fn cast_array(array: &Array) -> Array {
        array.as_type::<bf16>().expect("cast")
    }
}

/// f64 support, emulated via f32.
///
/// mlx-rs exposes no direct f64 slice constructors, so every conversion below
/// routes through f32 and then casts. NOTE(review): the f32 round trip
/// truncates f64 precision — confirm this loss is acceptable for callers
/// relying on `DType::F64`.
impl FloatMlxElement for f64 {
    fn scalar_array(value: Self) -> Array {
        Array::from_f32(value as f32)
            .as_type::<f64>()
            .expect("cast to f64")
    }

    fn array_from_slice(data: &[Self], shape: &[i32]) -> Array {
        let narrowed: Vec<f32> = data.iter().map(|&v| v as f32).collect();
        Array::from_slice(&narrowed, shape)
            .as_type::<f64>()
            .expect("cast to f64")
    }

    fn zeros_array(shape: &[i32]) -> Array {
        Array::zeros::<f32>(shape)
            .expect("zeros")
            .as_type::<f64>()
            .expect("cast to f64")
    }

    fn ones_array(shape: &[i32]) -> Array {
        Array::ones::<f32>(shape)
            .expect("ones")
            .as_type::<f64>()
            .expect("cast to f64")
    }

    fn array_to_vec(array: &Array) -> Vec<Self> {
        let as_f32 = array.as_type::<f32>().expect("cast to f32");
        as_f32.as_slice::<f32>().iter().map(|&v| v as f64).collect()
    }

    fn cast_array(array: &Array) -> Array {
        array.as_type::<f64>().expect("cast")
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading