8 changes: 8 additions & 0 deletions Ironwood/configs/hbm/hbm_multiple_devices.yaml
@@ -0,0 +1,8 @@
benchmarks:
- benchmark_name: "multiple_device_hbm_copy"
benchmark_sweep_params:
- {num_elements_range: {start: 1048576, end: 4294967296, multiplier: 2}, dtype: "bfloat16", num_runs: 1}
trace_dir: "../microbenchmarks/hbm"
csv_path: "../microbenchmarks/hbm"
xlml_metrics_dir: "../microbenchmarks/hbm"
xla_dump_dir: "../microbenchmarks/hbm/hlo_graphs"
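
For reference, the num_elements_range entry above presumably expands into a geometric sweep of copy sizes, multiplying start by multiplier until end is exceeded. A minimal sketch of that assumed expansion (the helper name expand_num_elements_range is hypothetical; the actual expansion logic lives in the benchmark runner, not in this diff):

def expand_num_elements_range(start: int, end: int, multiplier: int) -> list[int]:
    """Expands a geometric range spec into concrete num_elements values."""
    values = []
    n = start
    while n <= end:
        values.append(n)
        n *= multiplier
    return values

# 2**20 through 2**32 elements, doubling each step: 13 sweep points.
print(expand_num_elements_range(1048576, 4294967296, 2))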
15 changes: 15 additions & 0 deletions Ironwood/configs/training/gemm_multiple_devices.yaml
@@ -0,0 +1,15 @@
benchmarks:
- benchmark_name: "gemm_multiple_devices"
trace_dir: "../microbenchmarks/gemm_multiple_devices_bf16"
csv_path: "../microbenchmarks/gemm_multiple_devices_bf16"
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_bf16"
xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_bf16/hlo_graphs"
benchmark_sweep_params:
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'}
- benchmark_name: "gemm_multiple_devices"
trace_dir: "../microbenchmarks/gemm_multiple_devices_fp8"
csv_path: "../microbenchmarks/gemm_multiple_devices_fp8"
xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_fp8"
xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_fp8/hlo_graphs"
benchmark_sweep_params:
- {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'}
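
Both sweep entries use the same problem shape and differ only in dtype. As a quick sanity check (not part of the config), the per-matmul work implied by the 2 * m * k * n formula used in gemm_multiple_devices_calculate_metrics below works out as follows:

# Arithmetic check of the FLOPs implied by the configured GEMM shape.
m, k, n = 16384, 18432, 16384
total_flops = 2 * m * k * n  # each multiply-accumulate counts as 2 FLOPs
print(f"{total_flops:.3e}")  # ~9.896e+12, i.e. roughly 9.9 TFLOPs per matmul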
97 changes: 97 additions & 0 deletions Ironwood/src/benchmark_gemm.py
@@ -540,3 +540,100 @@ def gemm_accum_calculate_metrics(
total_flops_all_devices,
PEAK_FLOPS_PER_DEVICE,
)

def gemm_multiple_devices(
m: int,
k: int,
n: int,
dtype: jnp.dtype = jax.numpy.float8_e4m3fn,
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks the OUT<M, N>:BF16 = IN0<M, K> dtype x IN1<N, K>:dtype. Accumulation is FP32. Current supported dtype: float8_e4m3fn, bfloat16."""

def f(x, y):
with jax.named_scope(MARKER):
acc = jax.numpy.einsum(
"ij,jk->ik", x, y, preferred_element_type=jnp.float32
)
return acc.astype(jnp.bfloat16)
SHARDING_STRATEGY_MULTI_DEVICES = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M

mesh = create_mesh(SHARDING_STRATEGY_MULTI_DEVICES)
lhs_sharding = get_lhs_named_shading(mesh, SHARDING_STRATEGY_MULTI_DEVICES)
rhs_sharding = get_rhs_named_shading(mesh, SHARDING_STRATEGY_MULTI_DEVICES)
out_sharding = get_out_sharding(SHARDING_STRATEGY_MULTI_DEVICES)

jit_sharded_f = jax.jit(
shard_map(
f,
mesh,
in_specs=(lhs_sharding.spec, rhs_sharding.spec),
out_specs=out_sharding,
check_rep=False,
)
)

lhs_shape = (m, k)
rhs_shape = (k, n)

lhs_dtype = dtype
rhs_dtype = dtype

key = jax.random.key(SEED)

def data_generator():
"""Creates new random data on host and puts it on device."""
nonlocal key # Use and update the outer 'key'
key, key_lhs, key_rhs = jax.random.split(key, 3)

# Create random data on host
lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(lhs_dtype)
rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(rhs_dtype)

# Put on device (HBM)
lhs_device = jax.device_put(lhs_host, lhs_sharding)
rhs_device = jax.device_put(rhs_host, rhs_sharding)

return (lhs_device, rhs_device)

    # Run the benchmark.
    print(f"Running gemm_multiple_devices benchmark ({num_runs} runs)")
    dtype_str = "fp8" if dtype == jnp.float8_e4m3fn else "bf16"
time_ms_list = multiple_iteration_timeit_from_trace(
jit_sharded_f,
data_generator,
matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
tries=num_runs,
task="gemm_multiple_run",
trace_dir=trace_dir,
)
return {
"time_ms_list": time_ms_list,
}


def gemm_multiple_devices_calculate_metrics(
m: int,
k: int,
n: int,
dtype: jnp.dtype,
time_ms_list: list[float],
) -> Dict[str, Any]:
    """Calculates the metrics for the gemm_multiple_devices benchmark."""
    # Calculate FLOPs.
SHARDING_STRATEGY_MULTI_DEVICES = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M
total_flops = 2 * m * k * n # Total floating-point operations
total_flops, total_flops_all_devices = handle_based_on_sharding(
total_flops, SHARDING_STRATEGY_MULTI_DEVICES
)
    # PEAK_FLOPS_PER_DEVICE is the FP8 peak; the BF16 peak is half of it.
    peak_flops = (
        PEAK_FLOPS_PER_DEVICE
        if dtype == jnp.float8_e4m3fn
        else PEAK_FLOPS_PER_DEVICE / 2
    )
return unified_flops_metrics(
m,
n,
k,
time_ms_list,
total_flops,
total_flops_all_devices,
peak_flops,
)
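
A hypothetical standalone driver for the new GEMM benchmark is sketched below; in the repo the run_benchmark.py harness dispatches to these functions from the YAML config, so this is only an illustration of the call sequence and return values.

import jax.numpy as jnp

from benchmark_gemm import (
    gemm_multiple_devices,
    gemm_multiple_devices_calculate_metrics,
)

results = gemm_multiple_devices(
    m=16384, k=18432, n=16384, dtype=jnp.bfloat16, num_runs=10, trace_dir=None
)
metrics = gemm_multiple_devices_calculate_metrics(
    m=16384,
    k=18432,
    n=16384,
    dtype=jnp.bfloat16,
    time_ms_list=results["time_ms_list"],
)
print(metrics)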
76 changes: 76 additions & 0 deletions Ironwood/src/benchmark_hbm.py
@@ -6,11 +6,15 @@
from benchmark_utils import (
MetricsStatistics,
multiple_iteration_timeit_from_trace,
ShardingStrategy,
create_mesh,
)
from common import MARKER
import jax
import jax.numpy as jnp

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

SEED = 0
os.environ["LIBTPU_INIT_ARGS"] = (
@@ -102,3 +106,75 @@ def single_device_hbm_copy_calculate_metrics(
metrics.update(statistics.serialize_statistics())
metrics = {key: value for key, value in metrics.items() if value is not None}
return metadata, metrics

def multiple_devices_hbm_copy(
num_elements: int,
dtype: jnp.dtype,
num_runs: int = 1,
trace_dir: str = None,
) -> Dict[str, Any]:
"""Benchmarks HBM with copy(read and write) on multiple devices."""

SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING
def f(a):
with jax.named_scope(MARKER):
return a.copy()

mesh = create_mesh(SHARDING_STRATEGY)
sharding = NamedSharding(mesh, P(None,))

    a = jax.random.normal(
        jax.random.key(SEED), (num_elements,), out_sharding=sharding
    ).astype(dtype)
    print(f"Input array: shape={a.shape}, dtype={a.dtype}")
jitted_f = jax.jit(f)
    # Warm up: run once to trigger compilation before timing.
output = jitted_f(a)
jax.block_until_ready(output)

# Run the benchmark
time_ms_list = multiple_iteration_timeit_from_trace(
compute_func=jitted_f,
data_generator=lambda: (a,),
matrix_dim=f"{num_elements}",
tries=num_runs,
task="copy",
trace_dir=trace_dir,
)
return {"time_ms_list": time_ms_list}

def multiple_devices_hbm_copy_calculate_metrics(
num_elements: int, dtype: jnp.dtype, time_ms_list: list
) -> Dict[str, Any]:
"""Calculates the metrics for the multiple devices hbm copy benchmark."""
# Build dictionary of all the parameters in the function
params = locals().items()
metadata = get_metrics_helper(params)
metrics = {}

# Calculate throughput.
tensor_size_bytes = num_elements * dtype.dtype.itemsize

    # A copy reads and writes the tensor, so 2x its size moves through HBM.
    tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9
time_statistics = MetricsStatistics(
metrics_list=time_ms_list, metrics_name="time_ms"
)
time_s_list = [time_ms / 10**3 for time_ms in time_ms_list]
bw_gbyte_sec_list = [tensor_size_gbytes / time_s for time_s in time_s_list]
statistics = MetricsStatistics(
metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec"
)
print(
f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):"
f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s"
)
print()
# Gather the metrics to report.
metadata.update(
{
"tensor_size_gbytes": tensor_size_gbytes,
}
)
metrics.update(time_statistics.serialize_statistics())
metrics.update(statistics.serialize_statistics())
metrics = {key: value for key, value in metrics.items() if value is not None}
return metadata, metrics
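
To sanity-check the bandwidth formula above, here is the same arithmetic for one sweep point from the config, with a purely hypothetical copy time (the 2x factor is the read plus the write of the copied tensor):

num_elements = 2**30                                   # one sweep point from the config
itemsize = 2                                           # bytes per bfloat16 element
tensor_size_bytes = num_elements * itemsize
tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9   # copy = read + write
time_ms = 2.0                                          # hypothetical measured copy time
bw_gbyte_sec = tensor_size_gbytes / (time_ms / 10**3)
print(f"{bw_gbyte_sec:.1f} GB/s")                      # ~2147.5 GB/s for these numbers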
2 changes: 2 additions & 0 deletions Ironwood/src/run_benchmark.py
@@ -54,6 +54,7 @@
}
HBM_BENCHMARK_MAP = {
"single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy",
"multiple_device_hbm_copy": "benchmark_hbm.multiple_devices_hbm_copy",
}
COMPUTE_BENCHMARK_MAP = {
"gemm_simple": "benchmark_gemm.gemm_simple",
@@ -62,6 +63,7 @@
"gemm_throttling": "benchmark_gemm_throttling.gemm_throttling",
"gemm": "benchmark_gemm.gemm",
"gemm_accum": "benchmark_gemm.gemm_accum",
"gemm_multiple_devices": "benchmark_gemm.gemm_multiple_devices",
"quantization": "benchmark_compute.quantization",
"transpose_quantization": "benchmark_compute.transpose_quantization",
"quantization_static_scaling": (
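
The map values are "module.function" strings, so the harness presumably resolves a benchmark_name from the YAML by importing the module and looking up the attribute. A hypothetical sketch of that resolution (the real dispatch code in run_benchmark.py is not shown in this diff):

import importlib

def resolve_benchmark(benchmark_map: dict[str, str], name: str):
    """Resolves a 'module.function' entry from one of the benchmark maps."""
    module_name, func_name = benchmark_map[name].rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, func_name)

# e.g. resolve_benchmark(HBM_BENCHMARK_MAP, "multiple_device_hbm_copy")
# would return benchmark_hbm.multiple_devices_hbm_copy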