diff --git a/Ironwood/configs/hbm/hbm_multiple_devices.yaml b/Ironwood/configs/hbm/hbm_multiple_devices.yaml new file mode 100644 index 00000000..1d5ceae5 --- /dev/null +++ b/Ironwood/configs/hbm/hbm_multiple_devices.yaml @@ -0,0 +1,8 @@ +benchmarks: +- benchmark_name: "multiple_device_hbm_copy" + benchmark_sweep_params: + - {num_elements_range: {start: 1048576, end: 4294967296, multiplier: 2}, dtype: "bfloat16", num_runs: 1} + trace_dir: "../microbenchmarks/hbm" + csv_path: "../microbenchmarks/hbm" + xlml_metrics_dir: "../microbenchmarks/hbm" + xla_dump_dir: "../microbenchmarks/hbm/hlo_graphs" \ No newline at end of file diff --git a/Ironwood/configs/training/gemm_multiple_devices.yaml b/Ironwood/configs/training/gemm_multiple_devices.yaml new file mode 100644 index 00000000..07590ad1 --- /dev/null +++ b/Ironwood/configs/training/gemm_multiple_devices.yaml @@ -0,0 +1,15 @@ +benchmarks: +- benchmark_name: "gemm_multiple_devices" + trace_dir: "../microbenchmarks/gemm_multiple_devices_bf16" + csv_path: "../microbenchmarks/gemm_multiple_devices_bf16" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_bf16" + xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_bf16/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'} +- benchmark_name: "gemm_multiple_devices" + trace_dir: "../microbenchmarks/gemm_multiple_devices_fp8" + csv_path: "../microbenchmarks/gemm_multiple_devices_fp8" + xlml_metrics_dir: "../microbenchmarks/gemm_multiple_devices_fp8" + xla_dump_dir: "../microbenchmarks/gemm_multiple_devices_fp8/hlo_graphs" + benchmark_sweep_params: + - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'} \ No newline at end of file diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c8c27bbe..0b92b83e 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -540,3 +540,100 @@ def gemm_accum_calculate_metrics( total_flops_all_devices, 
# NOTE(review): reconstructed from a collapsed unified diff. These functions
# are registered in run_benchmark.py as:
#   COMPUTE_BENCHMARK_MAP["gemm_multiple_devices"]   -> benchmark_gemm.gemm_multiple_devices
#   HBM_BENCHMARK_MAP["multiple_device_hbm_copy"]    -> benchmark_hbm.multiple_devices_hbm_copy


def gemm_multiple_devices(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jax.numpy.float8_e4m3fn,
    num_runs: int = 1,
    trace_dir: str = None,
) -> Dict[str, Any]:
    """Benchmarks OUT:BF16 = IN0:dtype x IN1:dtype sharded over all devices.

    Accumulation is FP32. Currently supported dtypes: float8_e4m3fn, bfloat16.

    Args:
      m: Rows of the left operand (lhs is (m, k)).
      k: Contraction dimension.
      n: Columns of the right operand (rhs is (k, n)).
      dtype: Element type of both input operands.
      num_runs: Number of timed iterations.
      trace_dir: Optional directory for profiler traces (None disables tracing).

    Returns:
      Dict with key "time_ms_list": per-iteration wall times in milliseconds.
    """

    def f(x, y):
        with jax.named_scope(MARKER):
            # Accumulate in FP32 regardless of input dtype, then cast to BF16.
            acc = jax.numpy.einsum(
                "ij,jk->ik", x, y, preferred_element_type=jnp.float32
            )
        return acc.astype(jnp.bfloat16)

    sharding_strategy = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M

    mesh = create_mesh(sharding_strategy)
    lhs_sharding = get_lhs_named_shading(mesh, sharding_strategy)
    rhs_sharding = get_rhs_named_shading(mesh, sharding_strategy)
    out_sharding = get_out_sharding(sharding_strategy)

    # check_rep=False: the einsum output replication need not be checked by
    # shard_map since each shard produces its own M-slice of the result.
    jit_sharded_f = jax.jit(
        shard_map(
            f,
            mesh,
            in_specs=(lhs_sharding.spec, rhs_sharding.spec),
            out_specs=out_sharding,
            check_rep=False,
        )
    )

    lhs_shape = (m, k)
    rhs_shape = (k, n)
    key = jax.random.key(SEED)

    def data_generator():
        """Creates new random operands on host and puts them on device (HBM)."""
        nonlocal key  # Advance the outer PRNG key so each call differs.
        key, key_lhs, key_rhs = jax.random.split(key, 3)

        lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(dtype)
        rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(dtype)

        lhs_device = jax.device_put(lhs_host, lhs_sharding)
        rhs_device = jax.device_put(rhs_host, rhs_sharding)
        return (lhs_device, rhs_device)

    # Run the benchmark.
    print("Running gemm_multiple_devices benchmark", num_runs)
    dtype_str = "fp8" if dtype == jax.numpy.float8_e4m3fn else "bf16"
    time_ms_list = multiple_iteration_timeit_from_trace(
        jit_sharded_f,
        data_generator,
        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
        tries=num_runs,
        task="gemm_multiple_run",
        trace_dir=trace_dir,
    )
    return {
        "time_ms_list": time_ms_list,
    }


def gemm_multiple_devices_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> Dict[str, Any]:
    """Computes FLOPS/utilization metrics for the multi-device GEMM benchmark.

    Args:
      m, k, n: GEMM dimensions used in the timed run.
      dtype: Input dtype that was benchmarked (fp8 vs. bf16 picks the peak).
      time_ms_list: Per-iteration wall times in milliseconds.

    Returns:
      The metrics dict produced by unified_flops_metrics.
    """
    sharding_strategy = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M
    # One multiply + one add per (m, k, n) triple.
    total_flops = 2 * m * k * n
    total_flops, total_flops_all_devices = handle_based_on_sharding(
        total_flops, sharding_strategy
    )
    # BF16 peak is half the FP8 peak.
    # NOTE(review): assumes PEAK_FLOPS_PER_DEVICE is the FP8 peak — confirm.
    peak_flops = (
        PEAK_FLOPS_PER_DEVICE
        if dtype == jax.numpy.float8_e4m3fn
        else PEAK_FLOPS_PER_DEVICE / 2
    )
    return unified_flops_metrics(
        m,
        n,
        k,
        time_ms_list,
        total_flops,
        total_flops_all_devices,
        peak_flops,
    )


def multiple_devices_hbm_copy(
    num_elements: int,
    dtype: jnp.dtype,
    num_runs: int = 1,
    trace_dir: str = None,
) -> Dict[str, Any]:
    """Benchmarks HBM with copy (read and write) on multiple devices.

    The input array is replicated (NO_SHARDING) across the device mesh, so
    every device performs the same full-size copy.

    Args:
      num_elements: Length of the 1-D array to copy.
      dtype: Element type of the array.
      num_runs: Number of timed iterations.
      trace_dir: Optional directory for profiler traces (None disables tracing).

    Returns:
      Dict with key "time_ms_list": per-iteration wall times in milliseconds.
    """
    sharding_strategy = ShardingStrategy.NO_SHARDING

    def f(a):
        with jax.named_scope(MARKER):
            return a.copy()

    mesh = create_mesh(sharding_strategy)
    sharding = NamedSharding(mesh, P(None))

    a = jax.random.normal(
        jax.random.key(0), (num_elements,), out_sharding=sharding
    ).astype(dtype)
    print(a.shape)
    print(a.dtype)
    jitted_f = jax.jit(f)
    # Warm-up run so compilation time is excluded from the measurements.
    output = jitted_f(a)
    jax.block_until_ready(output)

    # Run the benchmark.
    time_ms_list = multiple_iteration_timeit_from_trace(
        compute_func=jitted_f,
        data_generator=lambda: (a,),
        matrix_dim=f"{num_elements}",
        tries=num_runs,
        task="copy",
        trace_dir=trace_dir,
    )
    return {"time_ms_list": time_ms_list}


def multiple_devices_hbm_copy_calculate_metrics(
    num_elements: int, dtype: jnp.dtype, time_ms_list: list
) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Calculates the metrics for the multiple devices hbm copy benchmark.

    Args:
      num_elements: Length of the copied array.
      dtype: Element type of the copied array.
      time_ms_list: Per-iteration wall times in milliseconds.

    Returns:
      A (metadata, metrics) tuple: benchmark parameters/derived metadata, and
      serialized time/bandwidth statistics with None values dropped.
    """
    # Build dictionary of all the parameters in the function.
    params = locals().items()
    metadata = get_metrics_helper(params)
    metrics = {}

    # Calculate throughput. jnp.dtype(...) normalizes strings, scalar types
    # and dtype instances — more robust than relying on a `.dtype` attribute.
    tensor_size_bytes = num_elements * jnp.dtype(dtype).itemsize
    # A copy reads and writes every byte, hence the factor of two.
    tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9
    time_statistics = MetricsStatistics(
        metrics_list=time_ms_list, metrics_name="time_ms"
    )
    time_s_list = [time_ms / 10**3 for time_ms in time_ms_list]
    bw_gbyte_sec_list = [tensor_size_gbytes / time_s for time_s in time_s_list]
    statistics = MetricsStatistics(
        metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec"
    )
    print(
        f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):"
        f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s"
    )
    print()
    # Gather the metrics to report.
    metadata.update(
        {
            "tensor_size_gbytes": tensor_size_gbytes,
        }
    )
    metrics.update(time_statistics.serialize_statistics())
    metrics.update(statistics.serialize_statistics())
    metrics = {key: value for key, value in metrics.items() if value is not None}
    return metadata, metrics