diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 54c2614bd..97cbe46f9 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -16,4 +16,16 @@ OutlierAwareLinear, Params4bit, StableEmbedding, + Conv1d4bit, + Conv2d4bit, + Conv3d4bit, + Conv1dFP4, + Conv2dFP4, + Conv3dFP4, + Conv1dNF4, + Conv2dNF4, + Conv3dNF4, + Conv1d8bitLt, + Conv2d8bitLt, + Conv3d8bitLt, ) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 54506f41d..84294807a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1174,3 +1174,477 @@ def forward(self, x): w = self.quantize_weight(self.weight, self.outlier_dim) self.weight.data.copy_(w) self.is_quantized = True + +class Conv4bit(nn.Module): + """Base class for 4-bit quantized convolutional layers. + + Weights are stored in a compressed 4-bit format via :class:`Params4bit`. + During the forward pass the weights are dequantized on-the-fly and the + standard PyTorch ``F.conv*d`` functional API is used for computation. + + This is an abstract base class — use :class:`Conv1d4bit`, :class:`Conv2d4bit`, + or :class:`Conv3d4bit` (and their FP4 / NF4 convenience aliases) instead. + + The approach mirrors how :class:`Embedding4bit` handles non-linear layers: + dequantize the packed weights, reshape to the original kernel shape, and + delegate to the highly-optimised cuDNN convolution kernels. + """ + + def __init__(self, *args, **kwargs): + raise NotImplementedError("Conv4bit is an abstract base class. Use Conv1d4bit, Conv2d4bit, or Conv3d4bit.") + + def _save_to_state_dict(self, destination, prefix, keep_vars): + """Save weight and bias, then append quant_state components.""" + if getattr(self.weight, "quant_state", None) is not None and getattr( + self.weight.quant_state, "packing_format_for_cpu", False + ): + self.weight.data, self.weight.quant_state = _convert_weight_packed_for_cpu_inverse( + self.weight.data, self.weight.quant_state + ) + super()._save_to_state_dict(destination, prefix, keep_vars) + if getattr(self.weight, "quant_state", None) is not None: + for k, v in self.weight.quant_state.as_dict(packed=True).items(): + destination[prefix + "weight." + k] = v if keep_vars else v.detach() + + def _setup_4bit(self, compress_statistics, quant_type, quant_storage): + self.compute_dtype = None + self.compute_type_is_set = False + self.quant_state = None + self.quant_storage = quant_storage + self.support_avx512bf16_for_cpu = has_avx512bf16() + + # Remember the original weight shape so we can restore it after dequantization. + # kernel_size is always a tuple from nn.Conv*d; include it explicitly for clarity. + self._weight_shape = ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) + + # Flatten kernel dimensions into a 2-D matrix for the block-wise quantiser. + weight_flat = self.weight.data.reshape(self.out_channels, -1) + self.weight = Params4bit( + weight_flat, + requires_grad=False, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + module=self, + ) + + def set_compute_type(self, x): + if x.dtype in [torch.float32, torch.bfloat16]: + self.compute_dtype = x.dtype + elif x.dtype == torch.float16: + if self.compute_dtype in [None, torch.float32]: + logger.warning( + f"Input type into {self.__class__.__name__} is torch.float16, but " + "bnb_4bit_compute_dtype=torch.float32 (default). " + "This will lead to slow inference or training speed.", + ) + + def _get_dequantized_weight(self): + """Dequantize the packed 4-bit weights back to the original kernel shape.""" + fix_4bit_weight_quant_state_from_module(self) + quant_state = self.weight.quant_state + + if ( + not getattr(quant_state, "packing_format_for_cpu", False) + and self.weight.device.type == "cpu" + and self.support_avx512bf16_for_cpu + and not self.training + ): + self.weight.data, quant_state = _convert_weight_packed_for_cpu( + self.weight.data, quant_state + ) + + dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, quant_state) + return dequantized_weight.reshape(self._weight_shape) + + def _forward_impl(self, x: torch.Tensor, conv_fn): + """Shared forward logic for all Conv*d4bit subclasses.""" + if not self.compute_type_is_set: + self.set_compute_type(x) + self.compute_type_is_set = True + + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + w = self._get_dequantized_weight().to(x.dtype) + bias = None if self.bias is None else self.bias.to(x.dtype) + + out = conv_fn(x, w, bias, self.stride, self.padding, self.dilation, self.groups) + return out.to(inp_dtype) + + +class Conv1d4bit(Conv4bit, nn.Conv1d): + """4-bit quantized 1-D convolution. + + Drop-in replacement for :class:`torch.nn.Conv1d`. Weights are quantized to + 4 bits on ``.to(device)`` and dequantized during each forward pass. + + Example:: + + import bitsandbytes as bnb + + fp_conv = torch.nn.Conv1d(64, 128, 3, padding=1) + q_conv = bnb.nn.Conv1d4bit(64, 128, 3, padding=1) + q_conv.load_state_dict(fp_conv.state_dict(), strict=False) + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 1 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + compute_dtype=None, compress_statistics=True, quant_type="fp4", + quant_storage=torch.uint8, device=None, dtype=None, + ): + nn.Conv1d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_4bit(compress_statistics, quant_type, quant_storage) + if compute_dtype is not None: + self.compute_dtype = compute_dtype + self.compute_type_is_set = True + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv1d) + + +class Conv2d4bit(Conv4bit, nn.Conv2d): + """4-bit quantized 2-D convolution. + + Drop-in replacement for :class:`torch.nn.Conv2d`. Weights are quantized to + 4 bits on ``.to(device)`` and dequantized during each forward pass. + + Example:: + + import bitsandbytes as bnb + + fp_conv = torch.nn.Conv2d(3, 64, 3, padding=1) + q_conv = bnb.nn.Conv2d4bit(3, 64, 3, padding=1, quant_type="nf4") + q_conv.load_state_dict(fp_conv.state_dict(), strict=False) + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 2 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + compute_dtype=None, compress_statistics=True, quant_type="fp4", + quant_storage=torch.uint8, device=None, dtype=None, + ): + nn.Conv2d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_4bit(compress_statistics, quant_type, quant_storage) + if compute_dtype is not None: + self.compute_dtype = compute_dtype + self.compute_type_is_set = True + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv2d) + + +class Conv3d4bit(Conv4bit, nn.Conv3d): + """4-bit quantized 3-D convolution. + + Drop-in replacement for :class:`torch.nn.Conv3d`. Weights are quantized to + 4 bits on ``.to(device)`` and dequantized during each forward pass. + + Example:: + + import bitsandbytes as bnb + + q_conv = bnb.nn.Conv3d4bit(3, 64, 3, padding=1, quant_type="nf4") + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 3 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + compute_dtype=None, compress_statistics=True, quant_type="fp4", + quant_storage=torch.uint8, device=None, dtype=None, + ): + nn.Conv3d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_4bit(compress_statistics, quant_type, quant_storage) + if compute_dtype is not None: + self.compute_dtype = compute_dtype + self.compute_type_is_set = True + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv3d) + + +class Conv1dFP4(Conv1d4bit): + """``Conv1d4bit`` pre-configured with ``quant_type='fp4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "fp4" + super().__init__(*args, **kwargs) + + +class Conv1dNF4(Conv1d4bit): + """``Conv1d4bit`` pre-configured with ``quant_type='nf4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "nf4" + super().__init__(*args, **kwargs) + + +class Conv2dFP4(Conv2d4bit): + """``Conv2d4bit`` pre-configured with ``quant_type='fp4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "fp4" + super().__init__(*args, **kwargs) + + +class Conv2dNF4(Conv2d4bit): + """``Conv2d4bit`` pre-configured with ``quant_type='nf4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "nf4" + super().__init__(*args, **kwargs) + + +class Conv3dFP4(Conv3d4bit): + """``Conv3d4bit`` pre-configured with ``quant_type='fp4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "fp4" + super().__init__(*args, **kwargs) + + +class Conv3dNF4(Conv3d4bit): + """``Conv3d4bit`` pre-configured with ``quant_type='nf4'``.""" + + def __init__(self, *args, **kwargs): + kwargs["quant_type"] = "nf4" + super().__init__(*args, **kwargs) + + +class Conv8bitLt(nn.Module): + """Base class for 8-bit quantized convolutional layers. + + Weights are stored as ``Int8Params`` (row-wise absmax-quantized int8). + During the forward pass the weights are dequantized and the standard + PyTorch ``F.conv*d`` functional API is used for computation. + + This is an abstract base class — use :class:`Conv1d8bitLt`, + :class:`Conv2d8bitLt`, or :class:`Conv3d8bitLt` instead. + """ + + def __init__(self, *args, **kwargs): + raise NotImplementedError("Conv8bitLt is an abstract base class. Use Conv1d8bitLt, Conv2d8bitLt, or Conv3d8bitLt.") + + def _setup_8bit(self, has_fp16_weights): + self.state = bnb.MatmulLtState() + self.state.has_fp16_weights = has_fp16_weights + + # Remember original kernel shape for reshape after dequantization. + self._weight_shape = ( + self.out_channels, + self.in_channels // self.groups, + *self.kernel_size, + ) + + self.weight = Int8Params( + self.weight.data.reshape(self.out_channels, -1), + has_fp16_weights=has_fp16_weights, + requires_grad=has_fp16_weights, + ) + self._register_load_state_dict_pre_hook(maybe_rearrange_weight) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super()._save_to_state_dict(destination, prefix, keep_vars) + scb_name = "SCB" + param_from_weight = getattr(self.weight, scb_name) + param_from_state = getattr(self.state, scb_name) + key_name = prefix + f"{scb_name}" + format_name = prefix + "weight_format" + + if not self.state.has_fp16_weights: + if param_from_weight is not None: + destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() + destination[format_name] = torch.tensor(0, dtype=torch.uint8) + elif param_from_state is not None: + destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + destination[format_name] = torch.tensor(0, dtype=torch.uint8) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, + ): + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, + ) + unexpected_copy = list(unexpected_keys) + for key in unexpected_copy: + input_name = key[len(prefix):] + if input_name == "SCB": + if self.weight.SCB is None: + raise RuntimeError( + "Loading a quantized checkpoint into non-quantized Conv8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()", + ) + input_param = state_dict[key] + self.weight.SCB.copy_(input_param) + if self.state.SCB is not None: + self.state.SCB = self.weight.SCB + unexpected_keys.remove(key) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def to(self, *args, **kwargs): + result = super().to(*args, **kwargs) + device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) + if device is not None: + if result.state.CB is not None: + result.state.CB = result.state.CB.to(device) + if result.state.SCB is not None: + result.state.SCB = result.state.SCB.to(device) + return result + + def _get_dequantized_weight(self, x_dtype): + """Dequantize Int8 weights back to the original kernel shape.""" + if self.state.has_fp16_weights: + return self.weight.data.reshape(self._weight_shape).to(x_dtype) + + if self.weight.CB is not None: + self.init_8bit_state() + + if self.state.CB is None: + w = self.weight.data + else: + # row-wise dequantization: w_fp = CB_int8 * (SCB / 127) + w = self.state.CB.to(x_dtype) * (self.state.SCB.to(x_dtype) / 127.0).unsqueeze(1) + + return w.reshape(self._weight_shape) + + def _forward_impl(self, x: torch.Tensor, conv_fn): + """Shared forward logic for all Conv*d8bitLt subclasses.""" + self.state.is_training = self.training + + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + w = self._get_dequantized_weight(x.dtype) + bias = None if self.bias is None else self.bias.to(x.dtype) + + out = conv_fn(x, w, bias, self.stride, self.padding, self.dilation, self.groups) + + if not self.state.has_fp16_weights and self.state.CB is not None: + self.weight.data = self.state.CB + + return out + + +class Conv1d8bitLt(Conv8bitLt, nn.Conv1d): + """8-bit quantized 1-D convolution (LLM.int8()-style). + + Drop-in replacement for :class:`torch.nn.Conv1d`. + + Example:: + + import bitsandbytes as bnb + + q_conv = bnb.nn.Conv1d8bitLt(64, 128, 3, padding=1, has_fp16_weights=False) + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 1 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + has_fp16_weights=True, device=None, dtype=None, + ): + nn.Conv1d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_8bit(has_fp16_weights) + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv1d) + + +class Conv2d8bitLt(Conv8bitLt, nn.Conv2d): + """8-bit quantized 2-D convolution (LLM.int8()-style). + + Drop-in replacement for :class:`torch.nn.Conv2d`. + + Example:: + + import bitsandbytes as bnb + + q_conv = bnb.nn.Conv2d8bitLt(3, 64, 3, padding=1, has_fp16_weights=False) + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 2 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + has_fp16_weights=True, device=None, dtype=None, + ): + nn.Conv2d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_8bit(has_fp16_weights) + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv2d) + + +class Conv3d8bitLt(Conv8bitLt, nn.Conv3d): + """8-bit quantized 3-D convolution (LLM.int8()-style). + + Drop-in replacement for :class:`torch.nn.Conv3d`. + + Example:: + + import bitsandbytes as bnb + + q_conv = bnb.nn.Conv3d8bitLt(3, 64, 3, padding=1, has_fp16_weights=False) + q_conv = q_conv.to("cuda") # quantization happens here + """ + + _conv_dims = 3 + + def __init__( + self, in_channels, out_channels, kernel_size, stride=1, padding=0, + dilation=1, groups=1, bias=True, padding_mode="zeros", + has_fp16_weights=True, device=None, dtype=None, + ): + nn.Conv3d.__init__( + self, in_channels, out_channels, kernel_size, + stride, padding, dilation, groups, bias, padding_mode, device, dtype, + ) + self._setup_8bit(has_fp16_weights) + + def forward(self, x: torch.Tensor): + return self._forward_impl(x, F.conv3d) + diff --git a/tests/test_conv.py b/tests/test_conv.py new file mode 100644 index 000000000..9108a481f --- /dev/null +++ b/tests/test_conv.py @@ -0,0 +1,451 @@ +"""Tests for quantized Conv layers (Conv*d4bit, Conv*d8bitLt). + +Follows the same patterns used in test_linear4bit.py and test_modules.py. +""" + +import pytest +import torch +import torch.nn as nn + +import bitsandbytes as bnb +from bitsandbytes.nn.modules import ( + Conv1d4bit, + Conv2d4bit, + Conv3d4bit, + Conv1dFP4, + Conv1dNF4, + Conv2dFP4, + Conv2dNF4, + Conv3dFP4, + Conv3dNF4, + Conv1d8bitLt, + Conv2d8bitLt, + Conv3d8bitLt, + Params4bit, + Int8Params, +) +from tests.helpers import ( + TRUE_FALSE, + get_available_devices, + id_formatter, +) + + +# --------------------------------------------------------------------------- +# Fixtures and helpers +# --------------------------------------------------------------------------- + +CONV4BIT_CLASSES_1D = [Conv1d4bit, Conv1dFP4, Conv1dNF4] +CONV4BIT_CLASSES_2D = [Conv2d4bit, Conv2dFP4, Conv2dNF4] +CONV4BIT_CLASSES_3D = [Conv3d4bit, Conv3dFP4, Conv3dNF4] +CONV8BIT_CLASSES = [Conv1d8bitLt, Conv2d8bitLt, Conv3d8bitLt] + + +# --------------------------------------------------------------------------- +# 1. Construction / parameter type tests (no device movement needed) +# --------------------------------------------------------------------------- + + +class TestConv4bitConstruction: + """Verify that Conv*d4bit layers can be constructed and have the right param types.""" + + def test_conv1d4bit_creates_params4bit_weight(self): + layer = Conv1d4bit(16, 32, 3, padding=1) + assert isinstance(layer.weight, Params4bit) + assert layer.weight.quant_type == "fp4" + assert layer._weight_shape == (32, 16, 3) + + def test_conv2d4bit_creates_params4bit_weight(self): + layer = Conv2d4bit(3, 64, 3, padding=1) + assert isinstance(layer.weight, Params4bit) + assert layer._weight_shape == (64, 3, 3, 3) + + def test_conv3d4bit_creates_params4bit_weight(self): + layer = Conv3d4bit(3, 16, (3, 3, 3), padding=1) + assert isinstance(layer.weight, Params4bit) + assert layer._weight_shape == (16, 3, 3, 3, 3) + + @pytest.mark.parametrize("quant_type", ["fp4", "nf4"]) + def test_conv2d4bit_quant_type(self, quant_type): + layer = Conv2d4bit(8, 16, 3, quant_type=quant_type) + assert layer.weight.quant_type == quant_type + + def test_conv1d_fp4_alias(self): + layer = Conv1dFP4(16, 32, 3) + assert isinstance(layer, Conv1d4bit) + assert layer.weight.quant_type == "fp4" + + def test_conv1d_nf4_alias(self): + layer = Conv1dNF4(16, 32, 3) + assert isinstance(layer, Conv1d4bit) + assert layer.weight.quant_type == "nf4" + + def test_conv2d_fp4_alias(self): + layer = Conv2dFP4(3, 16, 3) + assert isinstance(layer, Conv2d4bit) + assert layer.weight.quant_type == "fp4" + + def test_conv2d_nf4_alias(self): + layer = Conv2dNF4(3, 16, 3) + assert isinstance(layer, Conv2d4bit) + assert layer.weight.quant_type == "nf4" + + def test_conv3d_fp4_alias(self): + layer = Conv3dFP4(3, 16, 3) + assert isinstance(layer, Conv3d4bit) + assert layer.weight.quant_type == "fp4" + + def test_conv3d_nf4_alias(self): + layer = Conv3dNF4(3, 16, 3) + assert isinstance(layer, Conv3d4bit) + assert layer.weight.quant_type == "nf4" + + @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) + def test_conv2d4bit_bias(self, bias): + layer = Conv2d4bit(3, 16, 3, bias=bias) + if bias: + assert layer.bias is not None + assert layer.bias.shape == (16,) + else: + assert layer.bias is None + + def test_conv2d4bit_groups(self): + layer = Conv2d4bit(16, 32, 3, groups=4, padding=1) + assert layer.groups == 4 + # weight shape should be (out_ch, in_ch/groups, kH, kW) + assert layer._weight_shape == (32, 4, 3, 3) + + def test_conv2d4bit_asymmetric_kernel(self): + layer = Conv2d4bit(3, 16, (5, 3), padding=(2, 1)) + assert layer._weight_shape == (16, 3, 5, 3) + + def test_conv2d4bit_no_bias_weight_shape(self): + layer = Conv2d4bit(3, 16, 3, bias=False) + # Flatten should be (out_channels, in_channels * kH * kW) + assert layer.weight.data.shape[0] <= 16 * 3 * 3 * 3 # packed, so could be smaller + + def test_compute_dtype_is_stored(self): + layer = Conv2d4bit(3, 16, 3, compute_dtype=torch.bfloat16) + assert layer.compute_dtype == torch.bfloat16 + assert layer.compute_type_is_set is True + + def test_compute_dtype_default_is_none(self): + layer = Conv2d4bit(3, 16, 3) + assert layer.compute_dtype is None + assert layer.compute_type_is_set is False + + +class TestConv8bitLtConstruction: + """Verify that Conv*d8bitLt layers can be constructed and have the right param types.""" + + def test_conv1d8bit_creates_int8params_weight(self): + layer = Conv1d8bitLt(16, 32, 3, padding=1) + assert isinstance(layer.weight, Int8Params) + assert layer._weight_shape == (32, 16, 3) + + def test_conv2d8bit_creates_int8params_weight(self): + layer = Conv2d8bitLt(3, 64, 3, padding=1) + assert isinstance(layer.weight, Int8Params) + assert layer._weight_shape == (64, 3, 3, 3) + + def test_conv3d8bit_creates_int8params_weight(self): + layer = Conv3d8bitLt(3, 16, (3, 3, 3), padding=1) + assert isinstance(layer.weight, Int8Params) + assert layer._weight_shape == (16, 3, 3, 3, 3) + + @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("fp16w")) + def test_conv2d8bit_has_fp16_weights(self, has_fp16_weights): + layer = Conv2d8bitLt(3, 16, 3, has_fp16_weights=has_fp16_weights) + assert layer.state.has_fp16_weights == has_fp16_weights + assert layer.weight.has_fp16_weights == has_fp16_weights + + @pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) + def test_conv2d8bit_bias(self, bias): + layer = Conv2d8bitLt(3, 16, 3, bias=bias) + if bias: + assert layer.bias is not None + else: + assert layer.bias is None + + def test_conv2d8bit_groups(self): + layer = Conv2d8bitLt(16, 32, 3, groups=4, padding=1) + assert layer.groups == 4 + assert layer._weight_shape == (32, 4, 3, 3) + + +# --------------------------------------------------------------------------- +# 2. isinstance / subclass checks +# --------------------------------------------------------------------------- + + +class TestInheritance: + """Check MRO and isinstance relationships.""" + + def test_conv1d4bit_is_conv1d(self): + layer = Conv1d4bit(16, 32, 3) + assert isinstance(layer, nn.Conv1d) + assert isinstance(layer, nn.Module) + + def test_conv2d4bit_is_conv2d(self): + layer = Conv2d4bit(3, 16, 3) + assert isinstance(layer, nn.Conv2d) + + def test_conv3d4bit_is_conv3d(self): + layer = Conv3d4bit(3, 16, 3) + assert isinstance(layer, nn.Conv3d) + + def test_conv1d8bit_is_conv1d(self): + layer = Conv1d8bitLt(16, 32, 3) + assert isinstance(layer, nn.Conv1d) + + def test_conv2d8bit_is_conv2d(self): + layer = Conv2d8bitLt(3, 16, 3) + assert isinstance(layer, nn.Conv2d) + + def test_conv3d8bit_is_conv3d(self): + layer = Conv3d8bitLt(3, 16, 3) + assert isinstance(layer, nn.Conv3d) + + +# --------------------------------------------------------------------------- +# 3. Weight shape bookkeeping +# --------------------------------------------------------------------------- + + +class TestWeightShape: + """Verify _weight_shape is correctly computed for various configurations.""" + + @pytest.mark.parametrize( + "in_ch, out_ch, ks, groups, expected_shape", + [ + (16, 32, 3, 1, (32, 16, 3)), + (16, 32, 5, 1, (32, 16, 5)), + (16, 32, 3, 4, (32, 4, 3)), + (16, 32, 1, 1, (32, 16, 1)), + ], + ) + def test_conv1d4bit_weight_shape(self, in_ch, out_ch, ks, groups, expected_shape): + layer = Conv1d4bit(in_ch, out_ch, ks, groups=groups) + assert layer._weight_shape == expected_shape + + @pytest.mark.parametrize( + "in_ch, out_ch, ks, groups, expected_shape", + [ + (3, 64, 3, 1, (64, 3, 3, 3)), + (16, 32, (5, 3), 1, (32, 16, 5, 3)), + (16, 32, 3, 4, (32, 4, 3, 3)), + (64, 64, 1, 1, (64, 64, 1, 1)), + ], + ) + def test_conv2d4bit_weight_shape(self, in_ch, out_ch, ks, groups, expected_shape): + layer = Conv2d4bit(in_ch, out_ch, ks, groups=groups) + assert layer._weight_shape == expected_shape + + @pytest.mark.parametrize( + "in_ch, out_ch, ks, groups, expected_shape", + [ + (3, 16, 3, 1, (16, 3, 3, 3, 3)), + (16, 32, (3, 5, 7), 1, (32, 16, 3, 5, 7)), + (16, 32, 3, 4, (32, 4, 3, 3, 3)), + ], + ) + def test_conv3d4bit_weight_shape(self, in_ch, out_ch, ks, groups, expected_shape): + layer = Conv3d4bit(in_ch, out_ch, ks, groups=groups) + assert layer._weight_shape == expected_shape + + +# --------------------------------------------------------------------------- +# 4. Conv attributes preserved (stride, padding, dilation) +# --------------------------------------------------------------------------- + + +class TestConvAttributes: + """Ensure standard Conv attributes survive quantisation wrapping.""" + + def test_conv1d4bit_attributes(self): + layer = Conv1d4bit(16, 32, 5, stride=2, padding=2, dilation=1) + assert layer.stride == (2,) + assert layer.padding == (2,) + assert layer.dilation == (1,) + assert layer.kernel_size == (5,) + assert layer.in_channels == 16 + assert layer.out_channels == 32 + + def test_conv2d4bit_attributes(self): + layer = Conv2d4bit(3, 64, (3, 5), stride=(1, 2), padding=(1, 2), dilation=(1, 1)) + assert layer.stride == (1, 2) + assert layer.padding == (1, 2) + assert layer.kernel_size == (3, 5) + + def test_conv2d8bit_attributes(self): + layer = Conv2d8bitLt(3, 64, 3, stride=2, padding=1) + assert layer.stride == (2, 2) + assert layer.padding == (1, 1) + assert layer.kernel_size == (3, 3) + + +# --------------------------------------------------------------------------- +# 5. Forward pass tests (require CUDA / accelerator) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) +class TestConv4bitForward: + """Test forward pass output shapes and approximate numerical correctness.""" + + def test_conv1d4bit_forward_shape(self, device, quant_type, bias): + layer = Conv1d4bit(16, 32, 3, padding=1, bias=bias, quant_type=quant_type) + layer = layer.to(device) + x = torch.randn(2, 16, 20, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 32, 20) + assert out.dtype == x.dtype + + def test_conv2d4bit_forward_shape(self, device, quant_type, bias): + layer = Conv2d4bit(3, 64, 3, padding=1, bias=bias, quant_type=quant_type) + layer = layer.to(device) + x = torch.randn(2, 3, 8, 8, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 64, 8, 8) + + def test_conv2d4bit_forward_groups(self, device, quant_type, bias): + layer = Conv2d4bit(16, 32, 3, padding=1, groups=4, bias=bias, quant_type=quant_type) + layer = layer.to(device) + x = torch.randn(2, 16, 8, 8, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 32, 8, 8) + + def test_conv2d4bit_forward_asymmetric_kernel(self, device, quant_type, bias): + layer = Conv2d4bit(3, 16, (5, 3), padding=(2, 1), bias=bias, quant_type=quant_type) + layer = layer.to(device) + x = torch.randn(1, 3, 12, 12, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (1, 16, 12, 12) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +class TestConv4bitNumericalParity: + """Compare quantized Conv output with full-precision Conv (rough parity).""" + + def test_conv2d4bit_approximate_parity(self, device, quant_type): + """Output of quantized conv should be within ~20% of fp conv for random weights.""" + torch.manual_seed(42) + fp_conv = nn.Conv2d(16, 32, 3, padding=1, bias=True) + + q_conv = Conv2d4bit(16, 32, 3, padding=1, bias=True, quant_type=quant_type) + # Copy the weights before quantisation + q_conv.weight = Params4bit( + data=fp_conv.weight.data.reshape(32, -1).clone(), + quant_type=quant_type, + requires_grad=False, + ) + q_conv.bias = nn.Parameter(fp_conv.bias.data.clone()) + + fp_conv = fp_conv.to(device) + q_conv = q_conv.to(device) + + x = torch.randn(4, 16, 8, 8, device=device, dtype=torch.float32) + fp_out = fp_conv(x) + q_out = q_conv(x) + + assert fp_out.shape == q_out.shape + # 4-bit quantisation introduces error; we check the values are in the same ballpark + rel_err = (fp_out - q_out).abs().mean() / fp_out.abs().mean() + assert rel_err < 0.5, f"Relative error too large: {rel_err:.4f}" + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("bias", TRUE_FALSE, ids=id_formatter("bias")) +class TestConv8bitLtForward: + """Test forward pass for 8-bit conv layers.""" + + def test_conv1d8bit_forward_fp16_weights(self, device, bias): + layer = Conv1d8bitLt(16, 32, 3, padding=1, bias=bias, has_fp16_weights=True) + layer = layer.to(device) + x = torch.randn(2, 16, 20, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 32, 20) + + def test_conv2d8bit_forward_fp16_weights(self, device, bias): + layer = Conv2d8bitLt(3, 64, 3, padding=1, bias=bias, has_fp16_weights=True) + layer = layer.to(device) + x = torch.randn(2, 3, 8, 8, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 64, 8, 8) + + def test_conv2d8bit_forward_int8_weights(self, device, bias): + layer = Conv2d8bitLt(16, 32, 3, padding=1, bias=bias, has_fp16_weights=False) + layer = layer.to(device) + x = torch.randn(2, 16, 8, 8, device=device, dtype=torch.float16) + out = layer(x) + assert out.shape == (2, 32, 8, 8) + + def test_conv2d8bit_forward_groups(self, device, bias): + layer = Conv2d8bitLt(16, 32, 3, padding=1, groups=4, bias=bias, has_fp16_weights=True) + layer = layer.to(device) + x = torch.randn(2, 16, 8, 8, device=device, dtype=torch.float32) + out = layer(x) + assert out.shape == (2, 32, 8, 8) + + +# --------------------------------------------------------------------------- +# 6. Module export checks +# --------------------------------------------------------------------------- + + +class TestModuleExports: + """Verify all new Conv classes are accessible via bnb.nn.*.""" + + @pytest.mark.parametrize( + "name", + [ + "Conv1d4bit", + "Conv2d4bit", + "Conv3d4bit", + "Conv1dFP4", + "Conv2dFP4", + "Conv3dFP4", + "Conv1dNF4", + "Conv2dNF4", + "Conv3dNF4", + "Conv1d8bitLt", + "Conv2d8bitLt", + "Conv3d8bitLt", + ], + ) + def test_accessible_via_bnb_nn(self, name): + assert hasattr(bnb.nn, name), f"bnb.nn.{name} not found" + cls = getattr(bnb.nn, name) + assert isinstance(cls, type), f"bnb.nn.{name} is not a class" + + +# --------------------------------------------------------------------------- +# 7. Eval / Training mode +# --------------------------------------------------------------------------- + + +class TestTrainEvalMode: + """Ensure train/eval modes propagate correctly.""" + + def test_conv2d4bit_eval_mode(self): + layer = Conv2d4bit(3, 16, 3) + layer.eval() + assert not layer.training + + def test_conv2d4bit_train_mode(self): + layer = Conv2d4bit(3, 16, 3) + layer.train() + assert layer.training + + def test_conv2d8bit_eval_mode(self): + layer = Conv2d8bitLt(3, 16, 3) + layer.eval() + assert not layer.training + + def test_conv2d8bit_train_mode(self): + layer = Conv2d8bitLt(3, 16, 3) + layer.train() + assert layer.training