Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 89 additions & 3 deletions iron/operators/gemv/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,28 +166,114 @@ def core_body(A_L3L1_fifo, B_L3L1_fifo, C_L1L3_fifo, matvec):
for col in range(cols)
]

# --- Batch-coalesce (default): one BD per column over all batches. ---
# Replaces the per-batch unroll with a single iterated BD; the stock A_taps/C_taps
# above remain the fallback (and the num_batches==1 path). Access-equivalent to the
# unroll (covered by test_gemv_batched).
# AIE shim BD limits (mlir-aie AIEXDialect.cpp verifyStridesWraps): the two wrap
# dims have a size cap (1023) while one dim is size-uncapped; TAP sizes are
# outermost-first and the verifier reverses them, so [1, num_batches, run_hi,
# run_lo] places num_batches in the uncapped dim and the contiguous run in the two
# wrap dims. Access-equivalent to the unroll (covered by test_gemv_batched).
# The shim also enforces a 4-byte address granularity on every transfer size and
# stride (NOT skipped, even for linear transfers); for bf16 (2 bytes) that means
# the innermost size (run_lo) and the batch stride must be even, so split_run only
# yields an even run_lo and the predicate requires even strides.
MAX_WRAP = 1023
MAX_STRIDE = (1 << 20) - 1 # conservative element-stride bound for the wrap dims
GRAN_ELEMS = 2 # 4-byte shim granularity / 2-byte bf16 element

def split_run(run, lim=MAX_WRAP, gran=GRAN_ELEMS):
"""Factor a contiguous run into (hi, lo), both <= lim and lo a multiple of gran
(the address-granularity-aligned inner size), lo maximal. None if no such
split exists (caller then falls back to the per-batch path)."""
lo_start = (lim // gran) * gran
for lo in range(lo_start, 0, -gran):
if run % lo == 0 and (run // lo) <= lim:
return (run // lo, lo)
return None

A_run, A_bstride = (M // cols) * K, M * K
C_run, C_bstride = (M // cols), M
A_split, C_split = split_run(A_run), split_run(C_run)
coalesce = (
num_batches > 1
and A_bstride <= MAX_STRIDE
and C_bstride <= MAX_STRIDE
and A_bstride % GRAN_ELEMS == 0
and C_bstride % GRAN_ELEMS == 0
and A_split is not None
and C_split is not None
)

def coalesced_tap(L3_ty, col_off, split, bstride):
run_hi, run_lo = split
return TensorAccessPattern(
tensor_dims=L3_ty.__args__[0],
offset=col_off,
sizes=[1, num_batches, run_hi, run_lo],
strides=[0, bstride, run_lo, 1],
)

if coalesce:
# Backpressure replaces the per-batch drain wait, so the A/C ObjectFifos must
# be deep enough (>=2) for the producer not to overrun the consumer.
assert all(f.depth >= 2 for f in A_L3L1_fifos) and all(
f.depth >= 2 for f in C_L1L3_fifos
), "coalesced GEMV needs A/C ObjectFifo depth>=2 (replaces the per-batch wait)"
A_taps_coalesced = [
coalesced_tap(L3_A_ty, col * (M // cols) * K, A_split, A_bstride)
for col in range(cols)
]
C_taps_coalesced = [
coalesced_tap(L3_C_ty, col * (M // cols), C_split, C_bstride)
for col in range(cols)
]

rt = Runtime()
with rt.sequence(L3_A_ty, L3_B_ty, L3_C_ty) as (A, B, C):
rt.start(*workers)
tg_b = rt.task_group()
for col in range(cols):
# Simple linear transfer of B, includes all batches in sequence
rt.fill(B_L3L1_fifos[col].prod(), B, B_tap, task_group=tg_b)
for batch in range(num_batches):
if coalesce:
# One iterated BD per column over all batches; per-batch `wait` dropped
# in favor of ObjectFifo backpressure (asserted above).
tg_ac = rt.task_group()
for col in range(cols):
rt.fill(
A_L3L1_fifos[col].prod(), A, A_taps[col][batch], task_group=tg_ac
A_L3L1_fifos[col].prod(), A, A_taps_coalesced[col], task_group=tg_ac
)
for col in range(cols):
rt.drain(
C_L1L3_fifos[col].cons(),
C,
C_taps[col][batch],
C_taps_coalesced[col],
task_group=tg_ac,
wait=True,
)
rt.finish_task_group(tg_ac)
else:
# Fallback (also the num_batches==1 path): stock per-batch unroll.
for batch in range(num_batches):
tg_ac = rt.task_group()
for col in range(cols):
rt.fill(
A_L3L1_fifos[col].prod(),
A,
A_taps[col][batch],
task_group=tg_ac,
)
for col in range(cols):
rt.drain(
C_L1L3_fifos[col].cons(),
C,
C_taps[col][batch],
task_group=tg_ac,
wait=True,
)
rt.finish_task_group(tg_ac)
rt.finish_task_group(tg_b)

return Program(dev, rt).resolve_program(SequentialPlacer())
24 changes: 24 additions & 0 deletions iron/operators/gemv/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,27 @@ def generate_golden_reference(
"B": B,
"C": C,
}


def generate_golden_reference_batched(M=128, K=128, num_batches=2, seed=42):
"""
Generate golden reference data for a batched GEMV (num_batches independent
matrix-vector products stacked contiguously, matching the GEMV op layout).

Parameters:
M: Number of rows of each matrix A
K: Number of columns of each matrix A (equals vector B length)
num_batches: Number of independent GEMVs
seed: Random seed

Returns:
dict: Contains 'A' (matrices), 'B' (vectors), 'C' (output vectors)
"""
torch.manual_seed(seed)
val_range = 4
A = torch.randn(num_batches, M, K, dtype=torch.bfloat16) * val_range
B = torch.randn(num_batches, K, dtype=torch.bfloat16) * val_range
C = torch.empty(num_batches, M, dtype=torch.bfloat16)
for b in range(num_batches):
C[b] = A[b] @ B[b]
return {"A": A, "B": B, "C": C}
53 changes: 52 additions & 1 deletion iron/operators/gemv/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import aie.utils as aie_utils

from iron.operators.gemv.op import GEMV
from iron.operators.gemv.reference import generate_golden_reference
from iron.operators.gemv.reference import (
generate_golden_reference,
generate_golden_reference_batched,
)
from iron.common.test_utils import run_test


Expand Down Expand Up @@ -69,3 +72,51 @@ def test_gemv(M, K, num_aie_columns, tile_size_input, tile_size_output, aie_cont
print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n")

assert not errors, f"Test failed with errors: {errors}"


def get_batched_params():
max_cols = aie_utils.get_current_device().cols
# (M, K, cols, tsi, tso, num_batches): exercise the coalesced path + fallback.
plist = [
(256, 128, 1, 1, 256, 4), # tiny, coalesced
(256, 128, 8, 1, 32, 100), # large num_batches -> the size-uncapped dim
(448, 64, 8, 1, 56, 192), # multi-dim run split + large num_batches together
(64, 1536, 1, 1, 64, 8), # large K
(1026, 64, 1, 1, 2, 2), # run needs an even (granularity-aligned) split
(1024, 1024, 1, 1, 64, 2), # batch stride > 2**20 -> falls back to per-batch
(512, 64, 8, 4, 64, 32), # attn-style: tile_size_input>1, num_batches=heads
]
out = []
for p in plist:
if p[2] > max_cols:
continue
out.append(pytest.param(*p))
return out


@pytest.mark.parametrize(
"M,K,num_aie_columns,tile_size_input,tile_size_output,num_batches",
get_batched_params(),
)
def test_gemv_batched(
M, K, num_aie_columns, tile_size_input, tile_size_output, num_batches, aie_context
):
golden = generate_golden_reference_batched(M=M, K=K, num_batches=num_batches)
operator = GEMV(
M=M,
K=K,
num_aie_columns=num_aie_columns,
tile_size_input=tile_size_input,
tile_size_output=tile_size_output,
num_batches=num_batches,
context=aie_context,
)
input_buffers = {
"matrix": golden["A"].flatten(),
"vector": golden["B"].flatten(),
}
output_buffers = {"output": golden["C"].flatten()}
errors, *_ = run_test(
operator, input_buffers, output_buffers, rel_tol=0.04, abs_tol=1e-3
)
assert not errors, f"batched GEMV failed: {errors}"