Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# CI workflow: run Pylint on every pull request that touches Python files.
# NOTE(review): indentation was lost in the pasted diff; restored to valid
# GitHub Actions YAML below — content is otherwise unchanged.
name: Pylint Style Check

on:
  pull_request:
    paths:
      - '**.py'

jobs:
  pylint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"
          # NOTE(review): 'pip' caching needs a dependency file (e.g.
          # requirements.txt) to build its cache key — confirm one exists
          # at the repo root, otherwise setup-python fails.
          cache: 'pip'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

      - name: Run Pylint
        run: |
          # Use the repository's .pylintrc rules on all python files
          pylint $(git ls-files '*.py')
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Pre-commit hooks: generic whitespace/YAML checks plus a local pylint hook
# that uses the developer's installed pylint (language: system).
# NOTE(review): indentation was lost in the pasted diff; restored to valid
# pre-commit YAML below — content is otherwise unchanged.
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml

  - repo: local
    hooks:
      - id: pylint
        name: pylint
        entry: pylint
        language: system
        types: [python]
        # Pylint builds a whole-program view, so run it in one process
        # rather than one invocation per file.
        require_serial: true
        # Optional: You can list specific files to exclude here if needed
        # exclude: ^tests/
9 changes: 7 additions & 2 deletions Ironwood/src/benchmark_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def create_mesh(ici_size: int, mesh_shape: str) -> Mesh:


def get_sharding_axis(dim_str: str, mesh: Mesh) -> tuple[str, ...]:
"""Computes sharding axis names from dimension string like '1x4' and mesh."""
"""Computes sharding axis names from dimension string and mesh."""
# Example of a dimension string is '1x4'
dim_tuple = dim_str.split("x")
dim_tuple = tuple(int(dim) for dim in dim_tuple)
sharding_axis = tuple(
Expand Down Expand Up @@ -203,6 +204,7 @@ def psum_benchmark(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
# pylint: disable=unused-argument
"""Benchmarks the psum collective operation.

Args:
Expand Down Expand Up @@ -354,6 +356,7 @@ def psum_scatter_benchmark(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
# pylint: disable=unused-argument
"""Benchmarks the psum_scatter collective operation.

Args:
Expand All @@ -376,7 +379,7 @@ def psum_scatter_benchmark(
"--xla_sc_disable_megacore_partitioning=true",
"--xla_tpu_disable_sparse_core_collective_offload_remover=true",
"--xla_tpu_enable_reduce_scatter_offload_tracing=true",
"--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true",
"--xla_tpu_enable_sparse_core_collective_offload_nd_reduce_scatter=true", # pylint: disable=line-too-long
"--xla_tpu_enable_sparse_core_collective_offload_reduce_scatter=true",
"--xla_tpu_enable_sparse_core_reduce_scatter_v2=true",
"--xla_tpu_use_tc_device_shape_on_sc=true",
Expand Down Expand Up @@ -470,6 +473,7 @@ def all_gather_benchmark(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
# pylint: disable=unused-argument
"""Benchmarks the all_gather collective operation.

Args:
Expand Down Expand Up @@ -586,6 +590,7 @@ def all_to_all_benchmark(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
# pylint: disable=unused-argument
"""Benchmarks the all_to_all collective operation.

Args:
Expand Down
31 changes: 19 additions & 12 deletions Ironwood/src/benchmark_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ def swiglu_fwd(

def f(x):
with jax.named_scope(MARKER):
A, B = jnp.split(x, 2, axis=-1)
A_fp32 = A.astype(jnp.float32)
B_fp32 = B.astype(jnp.float32)
Y_fp32 = jax.nn.silu(A_fp32) * B_fp32
return Y_fp32.astype(jnp.bfloat16)
a, b = jnp.split(x, 2, axis=-1)
a_fp32 = a.astype(jnp.float32)
b_fp32 = b.astype(jnp.float32)
y_fp32 = jax.nn.silu(a_fp32) * b_fp32
return y_fp32.astype(jnp.bfloat16)

mesh = create_mesh(SHARDING_STRATEGY)
x_sharding = get_rowwise_named_shading(mesh, SHARDING_STRATEGY)
Expand Down Expand Up @@ -379,16 +379,17 @@ def swiglu_bwd(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
# pylint: disable=invalid-name
"""
Inverse of swiglu_fwd
"""

def f_fwd(x):
A, B = jnp.split(x, 2, axis=-1)
A_fp32 = A.astype(jnp.float32)
B_fp32 = B.astype(jnp.float32)
Y_fp32 = jax.nn.silu(A_fp32) * B_fp32
return Y_fp32.astype(jnp.bfloat16)
a, b = jnp.split(x, 2, axis=-1)
a_fp32 = a.astype(jnp.float32)
b_fp32 = b.astype(jnp.float32)
y_fp32 = jax.nn.silu(a_fp32) * b_fp32
return y_fp32.astype(jnp.bfloat16)

def f(x: jax.Array, dy: jax.Array) -> jax.Array:
"""
Expand All @@ -397,7 +398,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array:
"""
# Get the VJP "pullback" function
# We ignore the forward result (_y)
_y, pullback_fn = jax.vjp(f_fwd, x)
# pylint: disable=unused-variable,invalid-name
_y, pullback_fn = jax.vjp(
f_fwd, x
)
with jax.named_scope(MARKER):
# Call the pullback function with the upstream gradient
# This IS the backward pass.
Expand Down Expand Up @@ -555,7 +559,10 @@ def f(x: jax.Array, dy: jax.Array) -> jax.Array:
"""
# Get the VJP "pullback" function
# We ignore the forward result (_y)
_y, pullback_fn = jax.vjp(f_fwd, x)
# pylint: disable=unused-variable,invalid-name
_y, pullback_fn = jax.vjp(
f_fwd, x
)
with jax.named_scope(MARKER):
# Call the pullback function with the upstream gradient
# This IS the backward pass.
Expand Down
15 changes: 8 additions & 7 deletions Ironwood/src/benchmark_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def gemm_multiple_run(
) -> Dict[str, Any]:
"""Benchmarks the OUT<M, N>:BF16 = IN0<M, K> dtype x IN1<N, K>:dtype."""

"""Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16."""
# Accumulation is FP32. Current supported dtype: float8_e4m3fn,
# bfloat16.

def f(x, y):
with jax.named_scope(MARKER):
Expand Down Expand Up @@ -170,8 +171,7 @@ def gemm_simple(
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks the OUT<M, N>:BF16 = IN0<M, K>:FP8 x IN1<N, K>:FP8."""

"""Accumulation is FP32."""
# Accumulation is FP32.

def f(x, y):
with jax.named_scope(MARKER):
Expand Down Expand Up @@ -266,8 +266,7 @@ def gemm_simple_with_dtype(
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks the OUT<M, N>:BF16 = IN0<M, K>:FP8 x IN1<N, K>:FP8."""

"""Accumulation is FP32."""
# Accumulation is FP32.

# Convert string dtypes to jnp dtypes
lhs_dtype = str_to_dtype(in_dtype_str)
Expand Down Expand Up @@ -368,7 +367,8 @@ def gemm_simple_with_dtype_calculate_metrics(
def gemm(
m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
) -> Dict[str, Any]:
"""OUT<M, N>:BF16 = matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) * outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
"""OUT<M, N>:BF16 = matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) *
outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""

def f(x, y, scale_m, scale_n):
with jax.named_scope(MARKER):
Expand Down Expand Up @@ -473,7 +473,8 @@ def gemm_accum(
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
"""OUT<M, N>:FP32 += matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) * outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""
"""OUT<M, N>:FP32 += matmul(IN0<M, K>:FP8, IN1<N, K>:FP8) *
outer_product(SF0<M, 1>:FP32 * SF1<1, N>:FP32)."""

def f(out_buffer, x, y, scale_m, scale_n):
with jax.named_scope(MARKER):
Expand Down
6 changes: 2 additions & 4 deletions Ironwood/src/benchmark_gemm_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def gemm_fp8_b128_fp32(
m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
) -> Dict[str, Any]:
"""FP8 GEMM as DeepSeek-stype quantization, block size: 1x128."""

"""Use dynamic scaling factors."""
# Use dynamic scaling factors.

def f(x, y):
with jax.named_scope(MARKER):
Expand Down Expand Up @@ -387,8 +386,7 @@ def gemm_fp8_b128_fp32_static_scaling(
m: int, k: int, n: int, num_runs: int = 1, trace_dir: str = None
) -> Dict[str, Any]:
"""FP8 GEMM as DeepSeek-stype quantization, block size: 1x128."""

"""Use static scaling factors."""
# Use static scaling factors.

def f(x, y):
with jax.named_scope(MARKER):
Expand Down
4 changes: 2 additions & 2 deletions Ironwood/src/benchmark_hbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def single_device_hbm_copy_calculate_metrics(
)
print(
f"Tensor size: {tensor_size_bytes / 1024**2} MB, "
f"time taken (median): {time_statistics.statistics['p50']:.4f} ms, "
f"bandwidth (median): {statistics.statistics['p50']:.3f} GB/s"
f"time taken (median): {time_statistics.statistics["p50"]:.4f} ms, "
f"bandwidth (median): {statistics.statistics["p50"]:.3f} GB/s"
)
print()
# Gather the metrics to report.
Expand Down
9 changes: 6 additions & 3 deletions Ironwood/src/benchmark_host_device.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Benchmarks Host-to-Device and Device-to-Host transfer performance (Simple Baseline)."""
"""
Benchmarks Host-to-Device and Device-to-Host transfer performance
(Simple Baseline).
"""

import time
import os
Expand Down Expand Up @@ -123,8 +126,8 @@ def add_metric(name, ms_list):
]
stats_bw = MetricsStatistics(bw_list, f"{name}_bw (GiB/s)")
print(
f" {name}_bw (GiB/s) median: {stats_bw.statistics['p50']}, "
f"P95: {stats_bw.statistics['p95']}",
f"{name}_bw (GiB/s) median: {stats_bw.statistics["p50"]}, "
f"P95: {stats_bw.statistics["p95"]}",
flush=True,
)
metrics.update(stats_bw.serialize_statistics())
Expand Down
86 changes: 0 additions & 86 deletions Ironwood/src/benchmark_inference_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,89 +347,3 @@ def sigmoid_calculate_metrics(
dtype=dtype.dtype.name,
)


# def get_output_named_shading(mesh, strategy: ShardingStrategy):
# match strategy:
# case ShardingStrategy.NO_SHARDING:
# return NamedSharding(mesh, P(None))
# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
# return NamedSharding(mesh, P("device"))
# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
# return NamedSharding(mesh, P("device"))
# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
# assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
# assert False, f"ShardingStrategy is wrong for this ops: {strategy}"

# def get_out_sharding(strategy: ShardingStrategy):
# match strategy:
# case ShardingStrategy.NO_SHARDING:
# return P(None)
# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M:
# return P("device")
# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_M:
# return P("device")
# case ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_N:
# assert False, f"ShardingStrategy is wrong for this ops: {strategy}"
# case ShardingStrategy.SHARDING_ON_SINGLE_CHIP_WITH_N:
# assert False, f"ShardingStrategy is wrong for this ops: {strategy}"

# def add(m: int, dtype: jnp.dtype, num_runs: int = 1, trace_dir: str = None,
# ) -> Dict[str, Any]:
# """
# Z = X + Y
# """
# def f(x, y):
# with jax.named_scope(MARKER):
# return x + y

# mesh = create_mesh(SHARDING_STRATEGY)
# x_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
# y_sharding = get_output_named_shading(mesh, SHARDING_STRATEGY)
# out_sharding = get_out_sharding(SHARDING_STRATEGY)
# jit_sharded_f = jax.jit(
# shard_map(
# f,
# mesh,
# in_specs=(x_sharding.spec, y_sharding.spec),
# out_specs=out_sharding,
# check_rep=False,
# )
# )
# x_shape = (m)
# y_shape = (m)
# x_dtype = dtype
# y_dtype = dtype

# key = jax.random.key(SEED)

# def data_generator():
# """Creates new random data on host and puts it on device."""
# nonlocal key # Use and update the outer 'key'
# key, k1, k2 = jax.random.split(key, 3)

# x_host = jax.random.normal(k1, x_shape).astype(x_dtype)
# y_host = jax.random.normal(k2, y_shape).astype(y_dtype)

# x_device = jax.device_put(x_host, x_sharding)
# y_device = jax.device_put(y_host, y_sharding)

# return (x_device, y_device)

# time_ms_list = iteration_timeit(
# jit_sharded_f,
# data_generator,
# matrix_dim=f"{m}",
# tries=num_runs,
# task="add",
# trace_dir=trace_dir,
# )
# return {"time_ms_list": time_ms_list}

# def add_calculate_metrics(
# m: int, dtype: jnp.dtype, time_ms_list: list[float]
# ) -> Dict[str, Any]:
# scale = 2 if dtype == jnp.bfloat16 else 1
# total_bytes = scale * 3 * m
# total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name)
16 changes: 10 additions & 6 deletions Ironwood/src/benchmark_send_recv.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def send_recv_benchmark(
dtype: jnp.dtype,
trace_dir: str,
):
"""Runs p2p communication, sending tensor_size_bytes from source to target device."""
# pylint: disable=unused-argument
"""
Runs p2p communication, sending tensor_size_bytes from source to target
device.
"""
device_count = jax.local_device_count()
devices = mesh_utils.create_device_mesh((device_count,))
mesh = jax.sharding.Mesh(devices, "x")
Expand Down Expand Up @@ -120,14 +124,14 @@ def p2p_send(source_id, target_id):
target_recv_sizes,
no_recvs,
)
input = jax.random.normal(
random_input = jax.random.normal(
jax.random.key(0), (1, 8, last_dim), dtype=dtype
)
output = jnp.zeros((1, 8, last_dim), dtype=dtype)

with jax.named_scope(MARKER):
ra2a = jax.lax.ragged_all_to_all(
operand=input,
operand=random_input,
output=output,
input_offsets=input_offsets,
send_sizes=final_send_sizes,
Expand Down Expand Up @@ -158,10 +162,10 @@ def p2p_send(source_id, target_id):


def send_recv_benchmark_calculate_metrics(
source_id: int,
target_id: int,
source_id: int, # pylint: disable=unused-argument
target_id: int, # pylint: disable=unused-argument
num_elements: int,
n_repeats: int,
n_repeats: int, # pylint: disable=unused-argument
dtype: jnp.dtype,
runtime_ms: float,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand Down
Loading
Loading