Update burn dependency from 0.16 to 0.20 #1
Open
mike-marcacci wants to merge 6 commits into TuringWorks:main from
Conversation
Migrates burn-mlx across 4 major burn versions (0.17-0.20) with all required API changes: Device trait impl, Slice replacing Range<usize>, ExecutionError return types, QuantScheme, renamed scatter/select ops, ceil_mode on pooling, and new required methods (trig, cumulative, bitwise, unfold, cross, cast, matmul for ints). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The default Backend::sync() is a no-op, causing callers (e.g. timing code) to get incorrect results because GPU work hasn't finished yet. This calls mlx_synchronize on the default stream to block until all queued operations complete. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
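The fix overrides the trait's default no-op. A minimal pure-Rust sketch of the pattern, with a simulated queue standing in for MLX's GPU stream (the `Backend`/`MlxBackend` shapes here are simplified illustrations, not Burn's or the crate's real signatures):

```rust
// Sketch of the "default no-op sync" problem: the trait provides a default
// sync() that does nothing, so timing code measures only kernel submission.
// A GPU backend must override it to block until queued work completes
// (simulated here by draining a counter of pending operations).

trait Backend {
    /// Default implementation is a no-op.
    fn sync(&mut self) {}
}

struct MlxBackend {
    // Stand-in for MLX's queue of not-yet-finished GPU operations.
    pending_ops: usize,
}

impl Backend for MlxBackend {
    fn sync(&mut self) {
        // The real fix calls mlx_synchronize on the default stream;
        // here we just drain the simulated queue.
        self.pending_ops = 0;
    }
}

fn main() {
    let mut backend = MlxBackend { pending_ops: 42 };
    backend.sync();
    // After sync, no work is outstanding, so timing measurements are valid.
    assert_eq!(backend.pending_ops, 0);
}
```

Without the override, a caller that times `forward(); sync(); elapsed()` would stop the clock while GPU work is still in flight.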
The backend was hardcoded to FloatElem = f32, forcing all computation through f32 even when models use f16 weights. This caused 2x memory bandwidth overhead on Apple Silicon which natively supports f16. Mlx<F: FloatMlxElement = f32> is now generic like Burn's Wgpu backend. Adds MlxHalf and MlxBf16 type aliases for convenient half-precision use. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
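The generic-over-float-element pattern can be sketched in plain Rust. The trait and dtype constants below are simplified stand-ins for illustration, not the crate's actual `FloatMlxElement` definition:

```rust
use std::marker::PhantomData;

// Illustrative sketch of a backend generic over its float element, in the
// style of Mlx<F: FloatMlxElement = f32>. Names are simplified stand-ins.

trait FloatMlxElement {
    const DTYPE: &'static str;
    const BYTES: usize;
}

impl FloatMlxElement for f32 {
    const DTYPE: &'static str = "float32";
    const BYTES: usize = 4;
}

// Stand-in for half precision (the real backend maps to MLX's f16/bf16).
struct F16;
impl FloatMlxElement for F16 {
    const DTYPE: &'static str = "float16";
    const BYTES: usize = 2;
}

// Defaulting F to f32 keeps `Mlx` source-compatible with the old
// hardcoded-f32 backend while allowing Mlx<F16> for half precision.
struct Mlx<F: FloatMlxElement = f32> {
    _marker: PhantomData<F>,
}

impl<F: FloatMlxElement> Mlx<F> {
    fn dtype() -> &'static str { F::DTYPE }
    fn bytes_per_element() -> usize { F::BYTES }
}

// Convenience alias in the spirit of the PR's MlxHalf/MlxBf16.
type MlxHalf = Mlx<F16>;

fn main() {
    assert_eq!(Mlx::<f32>::dtype(), "float32");
    assert_eq!(MlxHalf::dtype(), "float16");
    // Half precision halves memory traffic per element: the 2x bandwidth win.
    assert_eq!(Mlx::<f32>::bytes_per_element() / MlxHalf::bytes_per_element(), 2);
}
```

The default type parameter is the same trick Burn's `Wgpu` backend uses: existing code that names `Mlx` bare keeps compiling, while half-precision users opt in explicitly.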
Replace the placeholder QTensorOps implementation (which stored quantized data as float and ignored quantization entirely) with real MLX-backed quantization using mlx_rs::ops::quantize, dequantize, and quantized_matmul. MlxQuantizedTensorPrimitive now stores MLX's native packed uint quantized arrays alongside scales, biases, and quantization metadata. q_from_data properly unpacks Burn's QuantizedBytes format, dequantizes on CPU, and re-quantizes into MLX's native format. q_matmul uses fused quantized_matmul for the common float×quantized inference path. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
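Group-wise affine quantization of the kind MLX performs can be illustrated in a few lines of plain Rust. This is a conceptual sketch only: the group size, bit width, and (un)packing here are simplified, and the real implementation delegates to `mlx_rs::ops::quantize`/`dequantize`:

```rust
// Illustrative group-wise affine quantization: each group of values is
// mapped to small integers with a per-group scale and bias, conceptually
// like MLX's packed-uint format (packing itself is omitted for clarity).

const GROUP_SIZE: usize = 4;
const BITS: u32 = 4;

fn quantize(values: &[f32]) -> (Vec<u8>, Vec<f32>, Vec<f32>) {
    let levels = (1u32 << BITS) - 1; // 15 representable steps for 4-bit
    let (mut q, mut scales, mut biases) = (Vec::new(), Vec::new(), Vec::new());
    for group in values.chunks(GROUP_SIZE) {
        let min = group.iter().cloned().fold(f32::INFINITY, f32::min);
        let max = group.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        // Scale maps the group's range onto the integer levels; bias offsets it.
        let scale = if max > min { (max - min) / levels as f32 } else { 1.0 };
        scales.push(scale);
        biases.push(min);
        for &v in group {
            q.push((((v - min) / scale).round() as u8).min(levels as u8));
        }
    }
    (q, scales, biases)
}

fn dequantize(q: &[u8], scales: &[f32], biases: &[f32]) -> Vec<f32> {
    q.chunks(GROUP_SIZE)
        .zip(scales.iter().zip(biases))
        .flat_map(|(group, (&s, &b))| group.iter().map(move |&v| v as f32 * s + b))
        .collect()
}

fn main() {
    let x = [0.0, 1.0, 2.0, 3.0, -1.0, 0.5, 1.5, 2.5];
    let (q, s, b) = quantize(&x);
    let y = dequantize(&q, &s, &b);
    // Round-trip error is bounded by the (largest) quantization step.
    let max_step = s.iter().cloned().fold(0.0f32, f32::max);
    for (a, c) in x.iter().zip(&y) {
        assert!((a - c).abs() <= max_step);
    }
}
```

This also shows why the placeholder implementation (storing floats and ignoring quantization) gave no memory benefit: the real win comes from keeping only the packed integers plus per-group scales and biases resident.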
…requant cycles

nn::Linear::forward calls weight.unsqueeze() on every forward pass, which triggers q_reshape → dequantize → reshape → re-quantize for every quantized matmul. For a 36-layer transformer this is ~180 unnecessary dequant+requant cycles per token, completely defeating the performance benefit of quantization.

Add fast paths that update only the logical shape metadata when the last 2 dimensions (the actual matrix shape) are unchanged:

- q_reshape: [M, N] → [1, M, N] unsqueezes skip dequant
- q_swap_dims: swaps in batch/prefix dims skip dequant
- q_expand: size-1 prefix expansions skip dequant

All three fall back to the original dequant → op → requant path when the operation actually touches the matrix dimensions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
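The fast-path test the commit describes reduces to a shape predicate. A standalone sketch (the function name and exact checks are illustrative; the real logic lives inside q_reshape, q_swap_dims, and q_expand):

```rust
// Sketch of the fast-path check: a reshape/expand can skip the
// dequantize → op → re-quantize cycle when only prefix/batch dimensions
// change and the trailing [M, N] matrix shape (the packed data) is intact.

fn matrix_dims_unchanged(old: &[usize], new: &[usize]) -> bool {
    if old.len() < 2 || new.len() < 2 {
        return false;
    }
    // The last two dims are the actual packed matrix: they must match
    // exactly, and total element count must be preserved, so the change
    // is pure logical-shape metadata.
    old[old.len() - 2..] == new[new.len() - 2..]
        && old.iter().product::<usize>() == new.iter().product::<usize>()
}

fn main() {
    // weight.unsqueeze(): [M, N] -> [1, M, N] is metadata-only: fast path.
    assert!(matrix_dims_unchanged(&[768, 768], &[1, 768, 768]));
    // Reshaping the matrix itself must take the slow dequant path.
    assert!(!matrix_dims_unchanged(&[768, 768], &[768 * 768, 1]));
}
```

When the predicate holds, only the stored logical shape is rewritten; the packed uint data, scales, and biases are untouched, which is what removes the ~180 per-token requant cycles.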
MLX's dequantize and quantized_matmul ops may return f32 arrays regardless of the backend's configured float precision. When using Mlx<f16>, this causes dtype mismatches in downstream operations. Cast results via F::cast_array() to ensure consistency.
Migrates burn-mlx across 4 major burn versions (0.17-0.20) with all required API changes: Device trait impl, Slice replacing Range, ExecutionError return types, QuantScheme, renamed scatter/select ops, ceil_mode on pooling, and new required methods (trig, cumulative, bitwise, unfold, cross, cast, matmul for ints).
Note that these changes were made entirely by Claude Code. I wanted to use burn-mlx to help me distinguish performance issues caused directly by my qwen3-burn implementation from inherent limitations of WGPU/CubeCL.
It ended up making a massive difference in my implementation, so I figured I'd open a PR.