diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 033a6a7ffb..d47bc553b0 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
+@pytest.mark.parametrize("single_param", all_boolean)
 @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
 @pytest.mark.parametrize("num_gemms", [4])
 def test_sanity_grouped_linear(
-    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
+    dtype,
+    bs,
+    model,
+    fp8_recipe,
+    fp8_model_params,
+    use_bias,
+    single_param,
+    num_gemms,
+    empty_split,
 ):
     if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
         pytest.skip("FP8 model parameters are not supported in debug mode.")
@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
         bs = bs * 16
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
 
+    if single_param:
+        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
+
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
@@ -617,7 +629,8 @@ def test_sanity_grouped_linear(
     # Verify that weights are stored in contiguous GroupedTensor storage.
     weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
     if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
-        check_grouped_tensor_pointers(weights, fp8_recipe)
+        if single_param:
+            check_grouped_tensor_pointers(weights, fp8_recipe)
 
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -636,6 +649,9 @@ def test_sanity_grouped_linear(
         loss.backward()
     assert out.shape == (num_tokens, ffn_hidden_size)
 
+    if single_param:
+        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
+
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index b6596bc2e9..2f859e748b 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -6,6 +6,7 @@
 from typing import Union, Optional, Callable, Tuple, List
 from itertools import chain
 import warnings
+import os
 import functools
 
 import torch
@@ -793,7 +794,9 @@ def make_grouped_weights(self, defer_init=False) -> None:
 
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
-        self.make_grouped_weights(defer_init=defer_init)
+        # Grouped tensor weights are an opt-in feature.
+        if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
+            self.make_grouped_weights(defer_init=defer_init)
 
     def set_tensor_parallel_attributes(self, defer_init=False) -> None:
         """Set attributes needed for TP"""