
Update burn dependency from 0.16 to 0.20#1

Open
mike-marcacci wants to merge 6 commits into TuringWorks:main from eidola-ai:burn-0-20

Conversation


@mike-marcacci mike-marcacci commented Feb 12, 2026

Migrates burn-mlx across 4 major burn versions (0.17-0.20) with all required API changes: Device trait impl, Slice replacing Range, ExecutionError return types, QuantScheme, renamed scatter/select ops, ceil_mode on pooling, and new required methods (trig, cumulative, bitwise, unfold, cross, cast, matmul for ints).


Note that these changes were made entirely by Claude Code. I wanted to use burn-mlx to help me differentiate between performance issues caused directly by my qwen3-burn implementation and the inherent limitations of WGPU/CubeCL.

It ended up making a massive difference in my implementation, so I figured I'd open a PR.

mike-marcacci and others added 6 commits February 12, 2026 12:25
Migrates burn-mlx across 4 major burn versions (0.17-0.20) with all
required API changes: Device trait impl, Slice replacing Range<usize>,
ExecutionError return types, QuantScheme, renamed scatter/select ops,
ceil_mode on pooling, and new required methods (trig, cumulative,
bitwise, unfold, cross, cast, matmul for ints).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The default Backend::sync() is a no-op, causing callers (e.g. timing
code) to get incorrect results because GPU work hasn't finished yet.
This calls mlx_synchronize on the default stream to block until all
queued operations complete.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The backend was hardcoded to FloatElem = f32, forcing all computation
through f32 even when models use f16 weights. This caused 2x memory
bandwidth overhead on Apple Silicon which natively supports f16.

Mlx<F: FloatMlxElement = f32> is now generic like Burn's Wgpu backend.
Adds MlxHalf and MlxBf16 type aliases for convenient half-precision use.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
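The generic-backend pattern this commit describes can be sketched in plain Rust. Names like `FloatMlxElement`, `MlxHalf`, and `MlxBf16` follow the commit message, but the bodies below are illustrative stand-ins, not the crate's actual code:

```rust
use std::marker::PhantomData;

// Sketch of the commit's pattern: a zero-sized backend type generic over
// its float element, with a default of f32 (matching the old hardcoded
// behavior). The trait and aliases mirror the commit message; the dtype
// strings are stand-ins for real element metadata.
trait FloatMlxElement {
    const DTYPE: &'static str;
}

impl FloatMlxElement for f32 {
    const DTYPE: &'static str = "float32";
}

// Stand-ins for half-precision element types (the real backend would use
// concrete f16/bf16 element types from burn / mlx_rs).
struct F16;
struct Bf16;
impl FloatMlxElement for F16 {
    const DTYPE: &'static str = "float16";
}
impl FloatMlxElement for Bf16 {
    const DTYPE: &'static str = "bfloat16";
}

// The backend carries its float element only at the type level.
struct Mlx<F: FloatMlxElement = f32>(PhantomData<F>);

// Convenience aliases, as named in the commit message.
type MlxHalf = Mlx<F16>;
type MlxBf16 = Mlx<Bf16>;

impl<F: FloatMlxElement> Mlx<F> {
    fn float_dtype() -> &'static str {
        F::DTYPE
    }
}

fn main() {
    println!("{}", Mlx::<f32>::float_dtype()); // float32
    println!("{}", MlxHalf::float_dtype());    // float16
    println!("{}", MlxBf16::float_dtype());    // bfloat16
}
```

This mirrors how Burn's Wgpu backend is parameterized: the element type is purely compile-time, so switching to f16 weights is a type-alias change rather than a runtime configuration.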
Replace the placeholder QTensorOps implementation (which stored quantized
data as float and ignored quantization entirely) with real MLX-backed
quantization using mlx_rs::ops::quantize, dequantize, and quantized_matmul.

MlxQuantizedTensorPrimitive now stores MLX's native packed uint quantized
arrays alongside scales, biases, and quantization metadata. q_from_data
properly unpacks Burn's QuantizedBytes format, dequantizes on CPU, and
re-quantizes into MLX's native format. q_matmul uses fused quantized_matmul
for the common float×quantized inference path.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
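To make "packed uint quantized arrays" concrete, here is a minimal sketch of 4-bit affine quantization with eight nibbles packed per u32. The group size, bit width, and layout here are assumptions for illustration only; in the actual backend this work is done by mlx_rs::ops::quantize and dequantize:

```rust
// Illustrative 4-bit affine quantization of a group of floats: values are
// mapped to levels 0..=15 via a per-group scale and bias, then packed
// eight nibbles per u32. Layout details are assumptions, not MLX's format.
fn quantize_group_4bit(values: &[f32]) -> (Vec<u32>, f32, f32) {
    let min = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let scale = (max - min) / 15.0; // 4 bits -> 16 levels
    let bias = min;
    let mut packed = vec![0u32; (values.len() + 7) / 8];
    for (i, &v) in values.iter().enumerate() {
        let q = (((v - bias) / scale).round() as u32).min(15);
        packed[i / 8] |= q << ((i % 8) * 4);
    }
    (packed, scale, bias)
}

fn dequantize_4bit(packed: &[u32], n: usize, scale: f32, bias: f32) -> Vec<f32> {
    (0..n)
        .map(|i| {
            let q = (packed[i / 8] >> ((i % 8) * 4)) & 0xF;
            q as f32 * scale + bias
        })
        .collect()
}

fn main() {
    let xs: Vec<f32> = (0..16).map(|i| i as f32).collect();
    let (packed, scale, bias) = quantize_group_4bit(&xs);
    let back = dequantize_4bit(&packed, xs.len(), scale, bias);
    // Round-trip error is bounded by half a quantization step.
    assert!(xs
        .iter()
        .zip(&back)
        .all(|(a, b)| (a - b).abs() <= scale / 2.0 + 1e-6));
    println!("packed {} floats into {} u32 words", xs.len(), packed.len());
}
```

The q_from_data path in the commit does the analogous unpack/repack dance between Burn's QuantizedBytes layout and MLX's native one.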
…requant cycles

nn::Linear::forward calls weight.unsqueeze() on every forward pass, which
triggers q_reshape → dequantize → reshape → re-quantize for every quantized
matmul. For a 36-layer transformer this is ~180 unnecessary dequant+requant
cycles per token, completely defeating the performance benefit of quantization.

Add fast paths that update only the logical shape metadata when the last 2
dimensions (the actual matrix shape) are unchanged:
- q_reshape: [M, N] → [1, M, N] unsqueezes skip dequant
- q_swap_dims: swaps in batch/prefix dims skip dequant
- q_expand: size-1 prefix expansions skip dequant

All three fall back to the original dequant→op→requant path when the
operation actually touches the matrix dimensions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
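The fast-path test described above boils down to a shape comparison: skip the dequant/requant cycle exactly when the trailing two dimensions (the actual matrix shape) are unchanged. A minimal sketch, with an illustrative function name:

```rust
// Returns true when a reshape/expand only touches batch/prefix dims, i.e.
// the last two dimensions -- the actual matrix shape -- are identical, so
// the quantized payload can be reused and only logical shape metadata
// needs updating. Falls back (returns false) otherwise.
fn matrix_dims_unchanged(old_shape: &[usize], new_shape: &[usize]) -> bool {
    if old_shape.len() < 2 || new_shape.len() < 2 {
        return false;
    }
    old_shape[old_shape.len() - 2..] == new_shape[new_shape.len() - 2..]
}

fn main() {
    // [M, N] -> [1, M, N] unsqueeze: fast path applies.
    assert!(matrix_dims_unchanged(&[64, 128], &[1, 64, 128]));
    // Transpose-like reshape touching the matrix dims: must dequantize.
    assert!(!matrix_dims_unchanged(&[64, 128], &[128, 64]));
    println!("fast-path checks passed");
}
```

Applied per forward pass, this is the check that turns the ~180 dequant+requant cycles per token into metadata updates.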
MLX's dequantize and quantized_matmul ops may return f32 arrays
regardless of the backend's configured float precision. When using
Mlx<f16>, this causes dtype mismatches in downstream operations.
Cast results via F::cast_array() to ensure consistency.
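The shape of that fix is a cast-if-needed step after any op whose output dtype may not match the backend's configured precision. A toy sketch with stand-in types (the real code goes through F::cast_array() on MLX arrays):

```rust
// Stand-in dtype and array types to illustrate the normalization step;
// the real backend operates on mlx_rs arrays.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Dtype {
    F32,
    F16,
}

struct Array {
    dtype: Dtype,
}

// Cast only when the dtype actually differs, so the common case is free.
fn cast_if_needed(arr: Array, target: Dtype) -> Array {
    if arr.dtype == target {
        arr // already consistent: no conversion
    } else {
        Array { dtype: target } // real code would convert the buffer
    }
}

fn main() {
    // dequantize may hand back f32 even on an Mlx<f16> backend...
    let dequantized = Array { dtype: Dtype::F32 };
    // ...so the result is cast to the backend's configured dtype.
    let out = cast_if_needed(dequantized, Dtype::F16);
    assert_eq!(out.dtype, Dtype::F16);
    println!("dtype normalized");
}
```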