20 changes: 18 additions & 2 deletions tests/pytorch/test_sanity.py
@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
+@pytest.mark.parametrize("single_param", all_boolean)
 @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
 @pytest.mark.parametrize("num_gemms", [4])
 def test_sanity_grouped_linear(
-    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
+    dtype,
+    bs,
+    model,
+    fp8_recipe,
+    fp8_model_params,
+    use_bias,
+    single_param,
+    num_gemms,
+    empty_split,
 ):
     if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
         pytest.skip("FP8 model parameters are not supported in debug mode.")
@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
     bs = bs * 16
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
 
+    if single_param:
+        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
+
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
@@ -617,7 +629,8 @@
     # Verify that weights are stored in contiguous GroupedTensor storage.
     weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
     if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
-        check_grouped_tensor_pointers(weights, fp8_recipe)
+        if single_param:
+            check_grouped_tensor_pointers(weights, fp8_recipe)
 
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -636,6 +649,9 @@
     loss.backward()
     assert out.shape == (num_tokens, ffn_hidden_size)
 
+    if single_param:
+        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
+
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
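One caveat with the environment-variable handling in the test above: the flag is deleted from os.environ only if the test body runs to completion, so a failure before the final del would leak the flag into later tests. pytest's built-in monkeypatch fixture restores the environment at teardown regardless of outcome. A minimal standalone sketch of that pattern (the test name and assertion here are illustrative, not part of this PR):

import os

import pytest


@pytest.mark.parametrize("single_param", [False, True])
def test_env_flag_is_restored(monkeypatch, single_param):
    # Start from a known state in case the variable is already set.
    monkeypatch.delenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", raising=False)
    if single_param:
        # monkeypatch.setenv is undone automatically at teardown, even if
        # the test body raises, so no manual "del os.environ[...]" is needed.
        monkeypatch.setenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "1")
    expected = "1" if single_param else None
    assert os.environ.get("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS") == expected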
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -6,6 +6,7 @@
 from typing import Union, Optional, Callable, Tuple, List
 from itertools import chain
 import warnings
+import os
 
 import functools
 import torch
@@ -793,7 +794,9 @@ def make_grouped_weights(self, defer_init=False) -> None:
 
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
-        self.make_grouped_weights(defer_init=defer_init)
+        # Grouped tensor weights are an opt-in feature.
+        if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
Review comment (Contributor):

This parsing will raise ValueError if the env var is set to a non-numeric value such as "true", "false", or an empty string. Consider more robust parsing if users might set this variable directly:

Suggested change
-        if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
+        if os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0").lower() in ("1", "true", "yes"):
+            self.make_grouped_weights(defer_init=defer_init)
 
     def set_tensor_parallel_attributes(self, defer_init=False) -> None:
         """Set attributes needed for TP"""