diff --git a/iron/operators/gemv/design.py b/iron/operators/gemv/design.py index 5031ed33..99ad7734 100644 --- a/iron/operators/gemv/design.py +++ b/iron/operators/gemv/design.py @@ -166,6 +166,90 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): for col in range(cols) ] + # --- Batch-coalesce (default): one BD per column over all batches. --- + # Replaces the per-batch unroll with a single iterated BD; the stock A_taps/C_taps + # above remain the fallback (and the num_batches==1 path). Access-equivalent to the + # unroll (covered by test_gemv_batched). + # + # This is NOT a single linear transfer. Within one batch the run is contiguous + # (A_run = (M//cols)*K elements), but the batch stride is the full matrix + # (A_bstride = M*K), so for cols>1 each column gathers its own slice out of every + # batch with a gap in between. Only cols==1 degenerates to bstride==run. So the + # batch dim is a genuine size-uncapped iteration dim and the TAP is required. + # + # The contiguous run is then split into two wrap dims [run_hi, run_lo] ONLY to fit + # the AIE shim's 10-bit (1023) wrap-size cap. TAP sizes are outermost-first and the + # verifier reverses them, so [1, num_batches, run_hi, run_lo] puts num_batches in the + # size-uncapped dim and the contiguous run in the two capped wrap dims. The shim also + # enforces a 4-byte address granularity on every size and stride (not skipped, even + # for linear transfers); for bf16 (2 bytes) that means run_lo and the batch stride + # must be even, so split_run only yields an even run_lo and the predicate requires + # even strides. + # + # FUTURE: this manual split is only needed on the current mlir_aie pin. Once IRON's + # pin moves past Xilinx/mlir-aie #3036 (LinearizeContiguousBDTransfer for the + # iteration dim, on top of #2924 which canonicalizes a contiguous run to linear form + # and bypasses the 1023 cap via the hardware buffer-length register), split_run / + # MAX_WRAP / GRAN_ELEMS can be dropped and the run supplied as one inner dim + # [num_batches, A_run]. The pin is currently frozen at the last pre-#3016 release. + # FIXME: pull these shim BD bounds from the MLIR-AIE target model rather than + # hard-coding them; they live in verifyStridesWraps in + # https://github.com/Xilinx/mlir-aie/blob/main/lib/Dialect/AIEX/IR/AIEXDialect.cpp + MAX_WRAP = 1023 + MAX_STRIDE = (1 << 20) - 1 # conservative element-stride bound for the wrap dims + GRAN_ELEMS = 2 # 4-byte shim granularity / 2-byte bf16 element + + def split_run(run, lim=MAX_WRAP, gran=GRAN_ELEMS): + """Factor a contiguous run into (hi, lo), both <= lim and lo a multiple of gran + (the address-granularity-aligned inner size), lo maximal. None if no such + split exists (caller then falls back to the per-batch path).""" + lo_start = (lim // gran) * gran + for lo in range(lo_start, 0, -gran): + if run % lo == 0 and (run // lo) <= lim: + return (run // lo, lo) + return None + + A_run, A_bstride = (M // cols) * K, M * K + C_run, C_bstride = (M // cols), M + A_split, C_split = split_run(A_run), split_run(C_run) + coalesce = ( + num_batches > 1 + and A_bstride <= MAX_STRIDE + and C_bstride <= MAX_STRIDE + and A_bstride % GRAN_ELEMS == 0 + and C_bstride % GRAN_ELEMS == 0 + and A_split is not None + and C_split is not None + ) + + def coalesced_tap(L3_ty, col_off, split, bstride): + run_hi, run_lo = split + return TensorAccessPattern( + tensor_dims=L3_ty.__args__[0], + offset=col_off, + sizes=[1, num_batches, run_hi, run_lo], + strides=[0, bstride, run_lo, 1], + ) + + if coalesce: + # Dropping the per-batch drain wait lets the single iterated fill BD run ahead of + # the core. ObjectFifo lock backpressure keeps that safe: a producer that gets + # ahead BLOCKS on the buffer lock (worst case a stall, never a corrupting + # overrun). depth>=2 only buys OVERLAP of fill with compute, so it is a + # performance guard here, not a correctness requirement (depth==1 is correct but + # fully serial). + assert all(f.depth >= 2 for f in A_L3L1_fifos) and all( + f.depth >= 2 for f in C_L1L3_fifos + ), "coalesced GEMV wants A/C ObjectFifo depth>=2 for fill/compute overlap" + A_taps_coalesced = [ + coalesced_tap(L3_A_ty, col * (M // cols) * K, A_split, A_bstride) + for col in range(cols) + ] + C_taps_coalesced = [ + coalesced_tap(L3_C_ty, col * (M // cols), C_split, C_bstride) + for col in range(cols) + ] + rt = Runtime() with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C): rt.start(*workers) @@ -173,17 +257,22 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec): for col in range(cols): # Simple linear transfer of B, includes all batches in sequence rt.fill(B_L3L1_fifos[col].prod(), B, B_tap, task_group=tg_b) - for batch in range(num_batches): + # Coalesced: one iterated BD per column covers all batches (num_waits==1, a + # single drain wait for the whole column). Fallback (incl. num_batches==1): the + # stock per-batch unroll (num_waits==num_batches, one wait per batch). The fills + # and drains are otherwise identical; only the TAP and the wait count differ. + num_waits = 1 if coalesce else num_batches + for w in range(num_waits): tg_ac = rt.task_group() for col in range(cols): - rt.fill( - A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac - ) + a_tap = A_taps_coalesced[col] if coalesce else A_taps[col][w] + rt.fill(A_L3L1_fifos[col].prod(), A, a_tap, task_group=tg_ac) for col in range(cols): + c_tap = C_taps_coalesced[col] if coalesce else C_taps[col][w] rt.drain( C_L1L3_fifos[col].cons(), C, - C_taps[col][batch], + c_tap, task_group=tg_ac, wait=True, ) diff --git a/iron/operators/gemv/reference.py b/iron/operators/gemv/reference.py index da2eee40..d1ae694f 100644 --- a/iron/operators/gemv/reference.py +++ b/iron/operators/gemv/reference.py @@ -35,3 +35,27 @@ def generate_golden_reference( "B": B, "C": C, } + + +def generate_golden_reference_batched(M=128, K=128, num_batches=2, seed=42): + """ + Generate golden reference data for a batched GEMV (num_batches independent + matrix-vector products stacked contiguously, matching the GEMV op layout). + + Parameters: + M: Number of rows of each matrix A + K: Number of columns of each matrix A (equals vector B length) + num_batches: Number of independent GEMVs + seed: Random seed + + Returns: + dict: Contains 'A' (matrices), 'B' (vectors), 'C' (output vectors) + """ + torch.manual_seed(seed) + val_range = 4 + A = torch.randn(num_batches, M, K, dtype=torch.bfloat16) * val_range + B = torch.randn(num_batches, K, dtype=torch.bfloat16) * val_range + C = torch.empty(num_batches, M, dtype=torch.bfloat16) + for b in range(num_batches): + C[b] = A[b] @ B[b] + return {"A": A, "B": B, "C": C} diff --git a/iron/operators/gemv/test.py b/iron/operators/gemv/test.py index 2abb42c0..e2981c8c 100755 --- a/iron/operators/gemv/test.py +++ b/iron/operators/gemv/test.py @@ -6,7 +6,10 @@ import aie.utils as aie_utils from iron.operators.gemv.op import GEMV -from iron.operators.gemv.reference import generate_golden_reference +from iron.operators.gemv.reference import ( + generate_golden_reference, + generate_golden_reference_batched, +) from iron.common.test_utils import run_test @@ -69,3 +72,62 @@ def test_gemv(M, K, num_aie_columns, tile_size_input, tile_size_output, aie_cont print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") assert not errors, f"Test failed with errors: {errors}" + + +def get_batched_params(): + max_cols = aie_utils.get_current_device().cols + # (M, K, cols, tsi, tso, num_batches): exercise the coalesced path + fallback. + plist = [ + (256, 128, 1, 1, 256, 4), # tiny, coalesced + (256, 128, 8, 1, 32, 100), # large num_batches -> the size-uncapped dim + (448, 64, 8, 1, 56, 192), # multi-dim run split + large num_batches together + (64, 1536, 1, 1, 64, 8), # large K + (1026, 64, 1, 1, 2, 2), # run needs an even (granularity-aligned) split + (1024, 1024, 1, 1, 64, 2), # batch stride > 2**20 -> falls back to per-batch + (512, 64, 8, 4, 64, 32), # attn-style: tile_size_input>1, num_batches=heads + ] + out = [] + for p in plist: + if p[2] > max_cols: + continue + out.append(pytest.param(*p)) + return out + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", + Throughput=r"Throughput: (?P[\d\.e\+-]+) GFLOP/s", +) +@pytest.mark.parametrize( + "M,K,num_aie_columns,tile_size_input,tile_size_output,num_batches", + get_batched_params(), +) +def test_gemv_batched( + M, K, num_aie_columns, tile_size_input, tile_size_output, num_batches, aie_context +): + golden = generate_golden_reference_batched(M=M, K=K, num_batches=num_batches) + operator = GEMV( + M=M, + K=K, + num_aie_columns=num_aie_columns, + tile_size_input=tile_size_input, + tile_size_output=tile_size_output, + num_batches=num_batches, + context=aie_context, + ) + input_buffers = { + "matrix": golden["A"].flatten(), + "vector": golden["B"].flatten(), + } + output_buffers = {"output": golden["C"].flatten()} + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.04, abs_tol=1e-3 + ) + + print(f"\nLatency: {latency_us:.1f} us") + gflops = (2.0 * M * K * num_batches) / (latency_us * 1e-6) / 1e9 + print(f"Throughput: {gflops:.6e} GFLOP/s") + print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + + assert not errors, f"batched GEMV failed: {errors}"