diff --git a/examples/workload/example.py b/examples/workload/example.py new file mode 100644 index 0000000..3137dad --- /dev/null +++ b/examples/workload/example.py @@ -0,0 +1,194 @@ +# RUN: %PYTHON %s | FileCheck %s +# CHECK: func.func @payload +# CHECK: PASSED +# CHECK: Throughput: +""" +Workload example: Element-wise sum of two (M, N) float32 arrays on CPU. +""" + +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor +from mlir.dialects import func, linalg, bufferization +from mlir.dialects import transform +from mlir.execution_engine import ExecutionEngine +from contextlib import contextmanager +from functools import cached_property +import ctypes +from typing import Optional +from lighthouse.utils.mlir import ( + apply_registered_pass, + canonicalize, + match, +) +from lighthouse.workload import ( + Workload, + execute, + benchmark, +) + + +class ElementwiseSum(Workload): + """ + Computes element-wise sum of (M, N) float32 arrays on CPU. + + We can construct the input arrays and compute the reference solution in + Python with Numpy. + + We use @cached_property to store the inputs and reference solution in the + object so that they are only computed once. 
+ """ + + def __init__(self, M: int, N: int): + self.M = M + self.N = N + self.dtype = np.float32 + + @cached_property + def _input_arrays(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + print(" * Generating input arrays...") + np.random.seed(2) + A = np.random.rand(self.M, self.N).astype(self.dtype) + B = np.random.rand(self.M, self.N).astype(self.dtype) + C = np.zeros((self.M, self.N), dtype=self.dtype) + return [A, B, C] + + @cached_property + def _reference_solution(self) -> np.ndarray: + print(" * Computing reference solution...") + A, B, _ = self._input_arrays + return A + B + + def _get_input_arrays(self) -> list[ctypes.Structure]: + return [get_ranked_memref_descriptor(a) for a in self._input_arrays] + + @contextmanager + def allocate_inputs(self, execution_engine: ExecutionEngine): + try: + yield self._get_input_arrays() + finally: + # cached numpy arrays are deallocated automatically + pass + + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + C = self._input_arrays[2] + C_ref = self._reference_solution + if verbose > 1: + print("Reference solution:") + print(C_ref) + print("Computed solution:") + print(C) + success = np.allclose(C, C_ref) + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def shared_libs(self) -> list[str]: + return [] + + def get_complexity(self) -> tuple[int, int, int]: + nbytes = np.dtype(self.dtype).itemsize + flop_count = self.M * self.N # one addition per element + memory_reads = 2 * self.M * self.N * nbytes # read A and B + memory_writes = self.M * self.N * nbytes # write C + return (flop_count, memory_reads, memory_writes) + + def payload_module(self) -> ir.Module: + mod = ir.Module.create() + + with ir.InsertionPoint(mod.body): + float32_t = ir.F32Type.get() + shape = (self.M, self.N) + tensor_t = ir.RankedTensorType.get(shape, float32_t) + memref_t = ir.MemRefType.get(shape, float32_t) + fargs = [memref_t, memref_t, 
memref_t] + + @func.func(*fargs, name=self.payload_function_name) + def payload(A, B, C): + a_tensor = bufferization.to_tensor(tensor_t, A, restrict=True) + b_tensor = bufferization.to_tensor(tensor_t, B, restrict=True) + c_tensor = bufferization.to_tensor( + tensor_t, C, restrict=True, writable=True + ) + add = linalg.add(a_tensor, b_tensor, outs=[c_tensor]) + bufferization.materialize_in_destination( + None, add, C, restrict=True, writable=True + ) + + payload.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + return mod + + def schedule_module( + self, stop_at_stage: Optional[str] = None, parameters: Optional[dict] = None + ) -> ir.Module: + schedule_module = ir.Module.create() + schedule_module.operation.attributes["transform.with_named_sequence"] = ( + ir.UnitAttr.get() + ) + with ir.InsertionPoint(schedule_module.body): + named_sequence = transform.named_sequence( + "__transform_main", + [transform.AnyOpType.get()], + [], + arg_attrs=[{"transform.readonly": ir.UnitAttr.get()}], + ) + with ir.InsertionPoint(named_sequence.body): + anytype = transform.AnyOpType.get() + func = match(named_sequence.bodyTarget, ops={"func.func"}) + mod = transform.get_parent_op( + anytype, + func, + op_name="builtin.module", + deduplicate=True, + ) + mod = apply_registered_pass(mod, "one-shot-bufferize") + mod = apply_registered_pass(mod, "convert-linalg-to-loops") + transform.apply_cse(mod) + canonicalize(mod) + + if stop_at_stage == "bufferized": + transform.YieldOp() + return schedule_module + + mod = apply_registered_pass(mod, "convert-scf-to-cf") + mod = apply_registered_pass(mod, "finalize-memref-to-llvm") + mod = apply_registered_pass(mod, "convert-cf-to-llvm") + mod = apply_registered_pass(mod, "convert-arith-to-llvm") + mod = apply_registered_pass(mod, "convert-func-to-llvm") + mod = apply_registered_pass(mod, "reconcile-unrealized-casts") + transform.YieldOp() + + return schedule_module + + +if __name__ == "__main__": + with ir.Context(), 
ir.Location.unknown(): + wload = ElementwiseSum(400, 400) + + print(" Dump kernel ".center(60, "-")) + wload.lower_payload(dump_payload="bufferized", dump_schedule=True) + + print(" Execute 1 ".center(60, "-")) + execute(wload, verbose=2) + + print(" Execute 2 ".center(60, "-")) + execute(wload, verbose=1) + + print(" Benchmark ".center(60, "-")) + times = benchmark(wload) + times *= 1e6 # convert to microseconds + # compute statistics + mean = np.mean(times) + min = np.min(times) + max = np.max(times) + std = np.std(times) + print(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}") + flop_count = wload.get_complexity()[0] + gflops = flop_count / (mean * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2f} GFLOPS") diff --git a/examples/workload/example_mlir.py b/examples/workload/example_mlir.py new file mode 100644 index 0000000..7d3211a --- /dev/null +++ b/examples/workload/example_mlir.py @@ -0,0 +1,222 @@ +# RUN: %PYTHON %s | FileCheck %s +# CHECK: func.func @payload +# CHECK: PASSED +# CHECK: Throughput: +""" +Workload example: Element-wise sum of two (M, N) float32 arrays on CPU. + +In this example, allocation and deallocation of input arrays are done in MLIR. 
+""" + +import numpy as np +from mlir import ir +from mlir.runtime.np_to_memref import ( + ranked_memref_to_numpy, + make_nd_memref_descriptor, + as_ctype, +) +from mlir.dialects import func, linalg, arith, memref +from mlir.execution_engine import ExecutionEngine +import ctypes +from contextlib import contextmanager +from lighthouse.utils import ( + get_packed_arg, + memrefs_to_packed_args, + memref_to_ctype, +) +from example import ElementwiseSum +from lighthouse.workload import ( + execute, + benchmark, +) + + +def emit_host_alloc(suffix: str, element_type: ir.Type, rank: int = 2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + index_t = ir.IndexType.get() + i32_t = ir.IntegerType.get_signless(32) + inputs = rank * (i32_t,) + + @func.func(*inputs, name="host_alloc_" + suffix) + def alloc_func(*shape): + dims = [arith.index_cast(index_t, a) for a in shape] + alloc = memref.alloc(memref_dyn_t, dims, []) + return alloc + + alloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_host_dealloc(suffix: str, element_type: ir.Type, rank: int = 2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + + @func.func(memref_dyn_t, name="host_dealloc_" + suffix) + def dealloc_func(buffer): + memref.dealloc(buffer) + + dealloc_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_fill_constant(suffix: str, value: float, element_type: ir.Type, rank: int = 2): + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + + @func.func(memref_dyn_t, name="host_fill_constant_" + suffix) + def init_func(buffer): + const = arith.constant(element_type, value) + linalg.fill(const, outs=[buffer]) + + init_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def emit_fill_random( + suffix: str, + element_type: ir.Type, + min: float = 0.0, + max: float 
= 1.0, + seed: int = 2, +): + rank = 2 + dyn = ir.ShapedType.get_dynamic_size() + memref_dyn_t = ir.MemRefType.get(rank * (dyn,), element_type) + i32_t = ir.IntegerType.get_signless(32) + f64_t = ir.F64Type.get() + + @func.func(memref_dyn_t, name="host_fill_random_" + suffix) + def init_func(buffer): + min_cst = arith.constant(f64_t, min) + max_cst = arith.constant(f64_t, max) + seed_cst = arith.constant(i32_t, seed) + linalg.fill_rng_2d(min_cst, max_cst, seed_cst, outs=[buffer]) + + init_func.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +class ElementwiseSumMLIRAlloc(ElementwiseSum): + """ + Computes element-wise sum of (M, N) float32 arrays on CPU. + + Extends ElementwiseSum by allocating input arrays in MLIR. + """ + + def __init__(self, M: int, N: int): + super().__init__(M, N) + # keep track of allocated memrefs + self.memrefs = {} + + def _allocate_array( + self, name: str, execution_engine: ExecutionEngine + ) -> ctypes.Structure: + if name in self.memrefs: + return self.memrefs[name] + alloc_func = execution_engine.lookup("host_alloc_f32") + # construct a memref descriptor for the result memref + shape = (self.M, self.N) + mref = make_nd_memref_descriptor(len(shape), as_ctype(self.dtype))() + ptr_mref = memref_to_ctype(mref) + ptr_dims = [ctypes.pointer(ctypes.c_int32(d)) for d in shape] + alloc_func(get_packed_arg([ptr_mref, *ptr_dims])) + self.memrefs[name] = mref + return mref + + def _deallocate_all(self, execution_engine: ExecutionEngine): + for mref in self.memrefs.values(): + dealloc_func = execution_engine.lookup("host_dealloc_f32") + dealloc_func(memrefs_to_packed_args([mref])) + self.memrefs = {} + + def get_input_arrays( + self, execution_engine: ExecutionEngine + ) -> list[ctypes.Structure]: + A = self._allocate_array("A", execution_engine) + B = self._allocate_array("B", execution_engine) + C = self._allocate_array("C", execution_engine) + + # initialize with MLIR + fill_zero_func = 
execution_engine.lookup("host_fill_constant_zero_f32") + fill_random_func = execution_engine.lookup("host_fill_random_f32") + fill_zero_func(memrefs_to_packed_args([C])) + fill_random_func(memrefs_to_packed_args([A])) + fill_random_func(memrefs_to_packed_args([B])) + + return [A, B, C] + + @contextmanager + def allocate_inputs(self, execution_engine: ExecutionEngine): + try: + yield self.get_input_arrays(execution_engine) + finally: + self._deallocate_all(execution_engine) + + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + # compute reference solution with numpy + A = ranked_memref_to_numpy([self.memrefs["A"]]) + B = ranked_memref_to_numpy([self.memrefs["B"]]) + C = ranked_memref_to_numpy([self.memrefs["C"]]) + C_ref = A + B + if verbose > 1: + print("Reference solution:") + print(C_ref) + print("Computed solution:") + print(C) + success = np.allclose(C, C_ref) + + # Alternatively we could have done the verification in MLIR by emitting + # a check function. + # Here we just call the payload function again. + # self._allocate_array("C_ref", execution_engine) + # func = execution_engine.lookup("payload") + # func(memrefs_to_packed_args([ + # self.memrefs["A"], + # self.memrefs["B"], + # self.memrefs["C_ref"], + # ])) + # Check correctness with numpy. 
+ # C = ranked_memref_to_numpy([self.memrefs["C"]]) + # C_ref = ranked_memref_to_numpy([self.memrefs["C_ref"]]) + # success = np.allclose(C, C_ref) + + if verbose: + if success: + print("PASSED") + else: + print("FAILED Result mismatch!") + return success + + def payload_module(self): + mod = super().payload_module() + # extend the payload module with de/alloc/fill functions + with ir.InsertionPoint(mod.body): + float32_t = ir.F32Type.get() + emit_host_alloc("f32", float32_t) + emit_host_dealloc("f32", float32_t) + emit_fill_constant("zero_f32", 0.0, float32_t) + emit_fill_random("f32", float32_t, min=-1.0, max=1.0) + return mod + + +if __name__ == "__main__": + with ir.Context(), ir.Location.unknown(): + wload = ElementwiseSumMLIRAlloc(400, 400) + + print(" Dump kernel ".center(60, "-")) + wload.lower_payload(dump_payload="bufferized", dump_schedule=False) + + print(" Execute ".center(60, "-")) + execute(wload, verbose=2) + + print(" Benchmark ".center(60, "-")) + times = benchmark(wload) + times *= 1e6 # convert to microseconds + # compute statistics + mean = np.mean(times) + min = np.min(times) + max = np.max(times) + std = np.std(times) + print(f"Timings (us): mean={mean:.2f}+/-{std:.2f} min={min:.2f} max={max:.2f}") + flop_count = wload.get_complexity()[0] + gflops = flop_count / (mean * 1e-6) / 1e9 + print(f"Throughput: {gflops:.2f} GFLOPS") diff --git a/lighthouse/utils/mlir.py b/lighthouse/utils/mlir.py new file mode 100644 index 0000000..e900b6a --- /dev/null +++ b/lighthouse/utils/mlir.py @@ -0,0 +1,34 @@ +""" +MLIR utility functions. 
+""" + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import structured +import os + + +def apply_registered_pass(*args, **kwargs): + return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs) + + +def match(*args, **kwargs): + return structured.structured_match(transform.AnyOpType.get(), *args, **kwargs) + + +def canonicalize(op): + with ir.InsertionPoint(transform.apply_patterns(op).patterns): + transform.apply_patterns_canonicalization() + + +def get_mlir_library_path(): + """Return MLIR shared library path.""" + pkg_path = ir.__file__ + if "python_packages" in pkg_path: + # looks like a local mlir install + path = os.path.join(pkg_path.split("python_packages")[0], "lib") + else: + # maybe installed in python path + path = os.path.join(os.path.split(pkg_path)[0], "_mlir_libs") + assert os.path.isdir(path) + return path diff --git a/lighthouse/workload/__init__.py b/lighthouse/workload/__init__.py new file mode 100644 index 0000000..4738604 --- /dev/null +++ b/lighthouse/workload/__init__.py @@ -0,0 +1,4 @@ +from .workload import Workload +from .runner import execute, benchmark + +__all__ = ["Workload", "benchmark", "execute"] diff --git a/lighthouse/workload/runner.py b/lighthouse/workload/runner.py new file mode 100644 index 0000000..ae5e07e --- /dev/null +++ b/lighthouse/workload/runner.py @@ -0,0 +1,165 @@ +""" +Utility functions for running workloads. 
+""" + +import numpy as np +import os +from mlir import ir +from mlir.dialects import func, arith, scf, memref +from mlir.execution_engine import ExecutionEngine +from mlir.runtime.np_to_memref import get_ranked_memref_descriptor +from lighthouse.utils.mlir import get_mlir_library_path +from lighthouse.utils import memrefs_to_packed_args +from lighthouse.workload import Workload +from typing import Optional + + +def get_engine( + payload_module: ir.Module, shared_libs: list[str] = None, opt_level: int = 3 +) -> ExecutionEngine: + lib_dir = get_mlir_library_path() + libs = [] + for so_file in shared_libs or []: + so_path = os.path.join(lib_dir, so_file) + if not os.path.isfile(so_path): + raise ValueError(f"Could not find shared library {so_path}") + libs.append(so_path) + execution_engine = ExecutionEngine( + payload_module, opt_level=opt_level, shared_libs=libs + ) + execution_engine.initialize() + return execution_engine + + +def execute( + workload: Workload, + check_correctness: bool = True, + schedule_parameters: Optional[dict] = None, + verbose: int = 0, +): + # lower payload with schedule + payload_module = workload.lower_payload(schedule_parameters=schedule_parameters) + # get execution engine + engine = get_engine(payload_module, shared_libs=workload.shared_libs()) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + # prepare function arguments + packed_args = memrefs_to_packed_args(inputs) + + # handle to payload function + payload_func = engine.lookup(workload.payload_function_name) + + # call function + payload_func(packed_args) + + if check_correctness: + success = workload.check_correctness( + execution_engine=engine, verbose=verbose + ) + if not success: + raise ValueError("Benchmark verification failed.") + + +def emit_benchmark_function( + payload_module: ir.Module, + payload_function_name: str, + nruns: int, + nwarmup: int, +): + """ + Emit a benchmark function that calls payload function and times it. 
+ + Every function call is timed separately. Returns the times (seconds) in a + memref. + """ + # find original payload function + payload_func = None + for op in payload_module.operation.regions[0].blocks[0]: + if isinstance(op, func.FuncOp) and op.name.value == payload_function_name: + payload_func = op + break + assert payload_func is not None, "Could not find payload function" + payload_arguments = payload_func.type.inputs + + # emit benchmark function that calls payload and times it + with ir.InsertionPoint(payload_module.body): + # define rtclock function + f64_t = ir.F64Type.get() + func.FuncOp("rtclock", ((), (f64_t,)), visibility="private") + # emit benchmark function + time_memref_t = ir.MemRefType.get((nruns,), f64_t) + args = payload_arguments + [time_memref_t] + + @func.func(*args) + def benchmark(*args): + index_t = ir.IndexType.get() + zero = arith.constant(index_t, 0) + one = arith.constant(index_t, 1) + nwarmup_cst = arith.constant(index_t, nwarmup) + for i in scf.for_(zero, nwarmup_cst, one): + # FIXME(upstream): func.call is broken for this use case? 
+ func.CallOp(payload_func, list(args[: len(payload_arguments)])) + scf.yield_(()) + nruns_cst = arith.constant(index_t, nruns) + for i in scf.for_(zero, nruns_cst, one): + tic = func.call((f64_t,), "rtclock", ()) + func.CallOp(payload_func, list(args[: len(payload_arguments)])) + toc = func.call((f64_t,), "rtclock", ()) + time = arith.subf(toc, tic) + memref.store(time, args[-1], [i]) + scf.yield_(()) + + benchmark.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + + +def benchmark( + workload: Workload, + nruns: int = 100, + nwarmup: int = 10, + schedule_parameters: Optional[dict] = None, + check_correctness: bool = True, + verbose: int = 0, +) -> np.ndarray: + # get original payload module + payload_module = workload.payload_module() + + # add benchmark function with timing + emit_benchmark_function( + payload_module, workload.payload_function_name, nruns, nwarmup + ) + + # lower + schedule_module = workload.schedule_module(parameters=schedule_parameters) + schedule_module.body.operations[0].apply(payload_module) + + # get execution engine, rtclock requires mlir_c_runner + libs = workload.shared_libs() + c_runner_lib = "libmlir_c_runner_utils.so" + if c_runner_lib not in libs: + libs.append(c_runner_lib) + engine = get_engine(payload_module, shared_libs=libs) + + with workload.allocate_inputs(execution_engine=engine) as inputs: + if check_correctness: + # call payload once to verify correctness + # prepare function arguments + packed_args = memrefs_to_packed_args(inputs) + + payload_func = engine.lookup(workload.payload_function_name) + payload_func(packed_args) + success = workload.check_correctness( + execution_engine=engine, verbose=verbose + ) + if not success: + raise ValueError("Benchmark verification failed.") + + # allocate buffer for timings and prepare arguments + time_array = np.zeros((nruns,), dtype=np.float64) + time_memref = get_ranked_memref_descriptor(time_array) + packed_args_with_time = memrefs_to_packed_args(inputs + 
[time_memref]) + + # call benchmark function + benchmark_func = engine.lookup("benchmark") + benchmark_func(packed_args_with_time) + + return time_array diff --git a/lighthouse/workload/workload.py b/lighthouse/workload/workload.py new file mode 100644 index 0000000..cc2c4f5 --- /dev/null +++ b/lighthouse/workload/workload.py @@ -0,0 +1,108 @@ +""" +Abstract base class for workloads. + +Defines the expected interface for generic workload execution methods. +""" + +from mlir import ir +from mlir.execution_engine import ExecutionEngine +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Optional + + +class Workload(ABC): + """ + Abstract base class for workloads. + + A workload is defined by a fixed payload function and problem size. + Different realizations of the workload can be obtained by altering the + lowering schedule parameters. + + The MLIR payload function should take input arrays as memrefs and return + nothing. + """ + + payload_function_name: str = "payload" + + @abstractmethod + def shared_libs(self) -> list[str]: + """Return a list of shared libraries required by the execution engine.""" + pass + + @abstractmethod + def payload_module(self) -> ir.Module: + """Generate the MLIR module containing the payload function.""" + pass + + @abstractmethod + def schedule_module( + self, + stop_at_stage: Optional[str] = None, + parameters: Optional[dict] = None, + ) -> ir.Module: + """ + Generate the MLIR module containing the transform schedule. + + The `stop_at_stage` argument can be used to interrupt lowering at + a desired IR level for debugging purposes. + """ + pass + + def lower_payload( + self, + dump_payload: Optional[str] = None, + dump_schedule: bool = False, + schedule_parameters: Optional[dict] = None, + ) -> ir.Module: + """ + Apply transform schedule to the payload module. + + Optionally dumps the payload IR at the desired level and/or the + transform schedule to stdout. 
+ + Returns the lowered payload module. + """ + payload_module = self.payload_module() + schedule_module = self.schedule_module( + stop_at_stage=dump_payload, parameters=schedule_parameters + ) + if not dump_payload or dump_payload != "initial": + # apply schedule on payload module + named_seq = schedule_module.body.operations[0] + named_seq.apply(payload_module) + if dump_payload: + print(payload_module) + if dump_schedule: + print(schedule_module) + return payload_module + + @abstractmethod + @contextmanager + def allocate_inputs(self, execution_engine: ExecutionEngine): + """ + Context manager that allocates and returns payload input buffers. + + Returns the payload input buffers as memrefs that can be directly + passed to the compiled payload function. + + On exit, frees any manually allocated memory (if any). + """ + pass + + @abstractmethod + def check_correctness( + self, execution_engine: ExecutionEngine, verbose: int = 0 + ) -> bool: + """Verify the correctness of the computation.""" + pass + + @abstractmethod + def get_complexity(self) -> tuple[int, int, int]: + """ + Return the computational complexity of the workload. + + Returns a tuple (flop_count, memory_reads, memory_writes). Memory + reads/writes are in bytes. + """ + pass