From 1e69d921783639f1fc20b5e3bbb232953b5db261 Mon Sep 17 00:00:00 2001 From: "Yu-Hsuan (Amy) Lin" Date: Wed, 4 Feb 2026 10:40:27 +0000 Subject: [PATCH] Correct fp4 tensor size calculation The new utility will use jnp.finfo and jnp.iinfo to determine the accurate bit width of any dtype, ensuring correct bandwidth metrics for current and future sub-byte types (like int4 or float4). --- Ironwood/src/benchmark_collectives.py | 6 ++++-- Ironwood/src/benchmark_hbm.py | 3 ++- Ironwood/src/benchmark_send_recv.py | 7 ++++--- Ironwood/src/benchmark_utils.py | 12 ++++++++++++ 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/Ironwood/src/benchmark_collectives.py b/Ironwood/src/benchmark_collectives.py index 69b4d21..0889038 100644 --- a/Ironwood/src/benchmark_collectives.py +++ b/Ironwood/src/benchmark_collectives.py @@ -11,6 +11,7 @@ from benchmark_utils import MetricsStatistics from benchmark_utils import multiple_iteration_timeit_from_trace from benchmark_utils import ShardingStrategy +from benchmark_utils import get_real_dtype_bytes from common import MARKER import jax from jax import core @@ -72,7 +73,7 @@ def get_metrics_helper( for key, value in params if value is not None and key not in exclude_keys } - metadata["dtype"] = metadata["dtype"].dtype.itemsize + metadata["dtype"] = get_real_dtype_bytes(metadata["dtype"].dtype) return metadata @@ -98,7 +99,8 @@ def unified_ici_collectives_metrics( hlo_first_replica_group = [] input_num_elements = matrix_shape[0] * matrix_shape[1] * matrix_shape[2] - dtype_bytes = dtype.dtype.itemsize + dtype_name = dtype.dtype.name + dtype_bytes = get_real_dtype_bytes(dtype.dtype) if xla_output: xla_output_json = json.loads(xla_output) hlo_input_shape = xla_output_json.get("hlo_input_shape") diff --git a/Ironwood/src/benchmark_hbm.py b/Ironwood/src/benchmark_hbm.py index bb279f4..67f0429 100644 --- a/Ironwood/src/benchmark_hbm.py +++ b/Ironwood/src/benchmark_hbm.py @@ -6,6 +6,7 @@ from benchmark_utils import ( 
MetricsStatistics, multiple_iteration_timeit_from_trace, + get_real_dtype_bytes, ) from common import MARKER import jax @@ -76,7 +77,7 @@ def single_device_hbm_copy_calculate_metrics( metrics = {} # Calculate throughput. - tensor_size_bytes = num_elements * dtype.dtype.itemsize + tensor_size_bytes = num_elements * get_real_dtype_bytes(dtype.dtype) tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9 time_statistics = MetricsStatistics( diff --git a/Ironwood/src/benchmark_send_recv.py b/Ironwood/src/benchmark_send_recv.py index 9095000..c7dd5db 100644 --- a/Ironwood/src/benchmark_send_recv.py +++ b/Ironwood/src/benchmark_send_recv.py @@ -8,6 +8,7 @@ import jax.sharding from benchmark_utils import ( get_trace, + get_real_dtype_bytes, ) from common import MARKER import tempfile @@ -68,7 +69,7 @@ def get_metrics_helper( for key, value in params if value is not None and key not in exclude_keys } - metadata['dtype'] = metadata['dtype'].dtype.itemsize + metadata['dtype'] = get_real_dtype_bytes(metadata['dtype'].dtype) return metadata @@ -84,7 +85,7 @@ def send_recv_benchmark( device_count = jax.local_device_count() devices = mesh_utils.create_device_mesh((device_count,)) mesh = jax.sharding.Mesh(devices, 'x') - item_size = jnp.dtype(dtype).itemsize + item_size = get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_bytes = num_elements * item_size last_dim = tensor_size_bytes // (1 * 8 * item_size) @@ -161,7 +162,7 @@ def send_recv_benchmark_calculate_metrics( metadata = get_metrics_helper(params) metrics = {} - tensor_size_bytes = num_elements * jnp.dtype(dtype).itemsize + tensor_size_bytes = num_elements * get_real_dtype_bytes(jnp.dtype(dtype)) tensor_size_gbytes = tensor_size_bytes / 10**9 metrics['runtime_ms (ms)'] = runtime_ms diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index e28f39e..ccd4f4c 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -28,6 +28,18 @@ import jax.extend from 
def get_real_dtype_bytes(dtype) -> float:
    """Return the storage size of one element of ``dtype``, in bytes.

    ``itemsize`` cannot express less than one byte, so for floating-point
    and integer dtypes the size is derived from the bit width reported by
    ``jnp.finfo`` / ``jnp.iinfo``; a 4-bit dtype therefore yields ``0.5``.

    Args:
        dtype: A dtype object exposing ``itemsize`` (the visible callers pass
            ``<scalar type>.dtype`` or ``jnp.dtype(...)``).

    Returns:
        ``bits / 8`` as a float when ``jnp.finfo`` or ``jnp.iinfo`` accepts
        the dtype; otherwise ``dtype.itemsize`` (complex dtypes and dtypes
        that are neither floating-point nor integer, such as bool).
    """
    # finfo accepts complex dtypes but describes a single real component, so
    # its bit width is only half of the element; itemsize is the right answer
    # there. Objects without a ``kind`` attribute are probed as before.
    if getattr(dtype, "kind", None) != "c":
        for describe in (jnp.finfo, jnp.iinfo):
            try:
                return describe(dtype).bits / 8
            except (TypeError, ValueError):
                # Not a dtype this lookup understands; try the next one.
                continue
    return dtype.itemsize