diff --git a/docs/src/python/distributed.rst b/docs/src/python/distributed.rst index 8b48d727e0..464e723ecb 100644 --- a/docs/src/python/distributed.rst +++ b/docs/src/python/distributed.rst @@ -17,6 +17,9 @@ made available. init all_sum all_gather + all_to_all + moe_dispatch_exchange + moe_combine_exchange send recv recv_like diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst index b9544bae51..ec87a75192 100644 --- a/docs/src/python/nn/layers.rst +++ b/docs/src/python/nn/layers.rst @@ -45,6 +45,7 @@ Layers MaxPool2d MaxPool3d Mish + MixtureOfExperts MultiHeadAttention PReLU QuantizedAllToShardedLinear diff --git a/mlx/backend/cpu/distributed.cpp b/mlx/backend/cpu/distributed.cpp index 22dc4b4cc8..490f8ff86c 100644 --- a/mlx/backend/cpu/distributed.cpp +++ b/mlx/backend/cpu/distributed.cpp @@ -1,11 +1,13 @@ // Copyright © 2024 Apple Inc. #include +#include #include "mlx/allocator.h" #include "mlx/backend/cpu/copy.h" #include "mlx/backend/cpu/encoder.h" #include "mlx/distributed/primitives.h" +#include "mlx/types/half_types.h" namespace mlx::core::distributed { @@ -100,4 +102,902 @@ void ReduceScatter::eval_cpu( std::vector& outputs) { throw std::runtime_error("[ReduceScatter] Not implemented yet."); } + +void AllToAll::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 1); + assert(outputs.size() == 1); + auto [in, copied] = ensure_row_contiguous(inputs[0], stream()); + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + distributed::detail::all_to_all(group(), in, outputs[0], stream()); + if (copied) { + auto& enc = cpu::get_command_encoder(stream()); + enc.add_temporary(in); + } +} + +// Helper: make a row-contiguous array that shares the buffer of `parent`, +// starting at byte offset `byte_offset`, with shape `shape`. +// The caller must ensure `parent` stays alive while the returned array is used. 
+static array make_subview( + const array& parent, + const Shape& shape, + size_t byte_offset, + Dtype dtype) { + // Compute strides for row-contiguous layout + Strides strides(shape.size()); + if (!shape.empty()) { + strides.back() = 1; + for (int i = static_cast(shape.size()) - 2; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + } + size_t num_elems = 1; + for (auto s : shape) + num_elems *= s; + array::Flags flags; + flags.contiguous = true; + flags.row_contiguous = true; + flags.col_contiguous = (shape.size() <= 1 || num_elems <= 1); + + // byte_offset in terms of elements + int64_t elem_offset = static_cast(byte_offset / size_of(dtype)); + array view(shape, dtype, nullptr, {}); + view.copy_shared_buffer(parent, strides, flags, num_elems, elem_offset); + return view; +} + +void MoeDispatchExchange::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 2); + assert(outputs.size() == 2); + + auto [tokens_in, tok_copied] = ensure_row_contiguous(inputs[0], stream()); + auto [indices_in, idx_copied] = ensure_row_contiguous(inputs[1], stream()); + + // Allocate outputs before dispatch so callers can depend on their pointers + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + outputs[1].set_data(allocator::malloc(outputs[1].nbytes())); + + int N = tokens_in.shape(0); + int D = tokens_in.shape(1); + int top_k = indices_in.shape(1); + int world_size = group().size(); + int num_experts = num_experts_; + int capacity = capacity_; + Group grp = group(); + size_t elem_size = tokens_in.itemsize(); + Dtype dtype = tokens_in.dtype(); + + // Capture raw pointers; arrays kept alive via add_temporary below + const void* tok_raw = tokens_in.data(); + const int32_t* idx_raw = indices_in.data(); + void* out0_raw = outputs[0].data(); + int32_t* out1_raw = outputs[1].data(); + + auto& enc = cpu::get_command_encoder(stream()); + + // All data access is inside enc.dispatch so it runs after Metal GPU + // command buffers 
for upstream ops have been committed and synchronized. + enc.dispatch([tok_raw, + idx_raw, + out0_raw, + out1_raw, + N, + D, + top_k, + world_size, + num_experts, + capacity, + elem_size, + dtype, + grp]() mutable { + int experts_per_device = num_experts / world_size; + int cap_total = world_size * capacity; // total capacity slots per expert + + // Initialize route_indices to -1 + int32_t* route_ptr = out1_raw; + std::fill(route_ptr, route_ptr + N * top_k, int32_t(-1)); + + // Zero-initialize the output dispatched buffer + // Output shape: [experts_per_device, cap_total, D] + size_t out_nbytes = (size_t)experts_per_device * cap_total * D * elem_size; + std::memset(out0_raw, 0, out_nbytes); + + const auto* tok_bytes = static_cast(tok_raw); + auto* out_bytes = static_cast(out0_raw); + std::vector expert_counts(num_experts, 0); + + // world_size == 1: local-only path (no send/recv) + // New route_idx layout: flat_idx = local_expert * cap_total + dest_rank * + // capacity + pos For world_size=1: dest_rank=0, cap_total=capacity + // => flat_idx = local_expert * capacity + pos (same formula as old for + // ws=1) + if (world_size == 1) { + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int eid = idx_raw[n * top_k + k]; + if (eid < 0 || eid >= num_experts) + continue; + int pos = expert_counts[eid]++; + if (pos < capacity) { + // dest_rank=0, local_expert=eid, cap_total=capacity + int flat_idx = eid * cap_total + 0 * capacity + pos; + route_ptr[n * top_k + k] = flat_idx; + std::memcpy( + out_bytes + flat_idx * D * elem_size, + tok_bytes + n * D * elem_size, + D * elem_size); + } + // else: route stays -1 + } + } + return; + } + + // world_size == 2: v3 variable exchange protocol + if (world_size == 2) { + int my_rank = grp.rank(); + int peer = 1 - my_rank; + + // Packet row layout: [meta32(4B) | payload(D*elem_size) | pad] + // meta32 = (local_expert << 16) | (pos & 0xFFFF) + size_t raw_row = 4 + D * elem_size; + int row_stride = + 
static_cast((raw_row + 15) & ~size_t(15)); // align to 16 + + int max_send = N * top_k; + int recv_cap = experts_per_device * + capacity; // peer can fill at most capacity per expert + + // Allocate packet buffers + size_t send_pkt_bytes = + static_cast(std::max(max_send, 1)) * row_stride; + size_t recv_pkt_bytes = + static_cast(std::max(recv_cap, 1)) * row_stride; + + array send_pkt({static_cast(send_pkt_bytes)}, uint8, nullptr, {}); + send_pkt.set_data(allocator::malloc(send_pkt_bytes)); + auto* send_pkt_ptr = send_pkt.data(); + + array recv_pkt({static_cast(recv_pkt_bytes)}, uint8, nullptr, {}); + recv_pkt.set_data(allocator::malloc(recv_pkt_bytes)); + + // Count exchange arrays + array count_send({1}, int32, nullptr, {}); + count_send.set_data(allocator::malloc(sizeof(int32_t))); + array count_recv({1}, int32, nullptr, {}); + count_recv.set_data(allocator::malloc(sizeof(int32_t))); + + int send_count = 0; + + // Dispatch: k-outer, n-inner deterministic loop + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int eid = idx_raw[n * top_k + k]; + if (eid < 0 || eid >= num_experts) + continue; + int dest_rank = eid / experts_per_device; + int local_expert = eid % experts_per_device; + int pos = expert_counts[eid]++; + if (pos >= capacity) + continue; + + int flat_idx = local_expert * cap_total + dest_rank * capacity + pos; + route_ptr[n * top_k + k] = flat_idx; + + if (dest_rank == my_rank) { + // LOCAL: directly scatter into output + std::memcpy( + out_bytes + flat_idx * D * elem_size, + tok_bytes + n * D * elem_size, + D * elem_size); + } else { + // REMOTE: pack into send packet + uint8_t* row = + send_pkt_ptr + static_cast(send_count) * row_stride; + uint32_t meta = (static_cast(local_expert) << 16) | + (static_cast(pos) & 0xFFFF); + std::memcpy(row, &meta, 4); + std::memcpy(row + 4, tok_bytes + n * D * elem_size, D * elem_size); + send_count++; + } + } + } + + // Exchange packets + auto* raw = grp.raw_group().get(); + int peer_count = 
raw->blocking_exchange_v( + send_pkt, + send_count, + recv_pkt, + recv_cap, + row_stride, + peer, + detail::ExchangeTag::MoeDispatchCount, + detail::ExchangeTag::MoeDispatchPayload, + count_send, + count_recv); + + // Scatter received remote tokens into output + auto* recv_pkt_ptr = recv_pkt.data(); + for (int i = 0; i < peer_count; i++) { + const uint8_t* row = recv_pkt_ptr + static_cast(i) * row_stride; + uint32_t meta; + std::memcpy(&meta, row, 4); + int local_expert = static_cast(meta >> 16); + int slot_pos = static_cast(meta & 0xFFFF); + if (local_expert < 0 || local_expert >= experts_per_device || + slot_pos < 0 || slot_pos >= capacity) { + throw std::runtime_error( + "[MoeDispatchExchange] received out-of-bounds metadata: " + "local_expert=" + + std::to_string(local_expert) + + " slot_pos=" + std::to_string(slot_pos)); + } + int recv_flat_idx = + local_expert * cap_total + peer * capacity + slot_pos; + std::memcpy( + out_bytes + recv_flat_idx * D * elem_size, row + 4, D * elem_size); + } + return; + } + + // world_size > 2: fallback to existing fixed all_to_all + { + int slots_per_device = experts_per_device * capacity; + int total_slots = world_size * slots_per_device; + size_t send_nbytes = (size_t)total_slots * D * elem_size; + + // Allocate send buffer: [total_slots, D] (layout: [W, E, C, D]) + array send_arr(Shape{total_slots, D}, dtype, nullptr, {}); + send_arr.set_data(allocator::malloc(send_nbytes)); + std::memset(send_arr.data(), 0, send_nbytes); + + auto* send_bytes = static_cast(send_arr.data()); + + // Dispatch: k-outer, n-inner for deterministic slot assignment + // Use NEW route_idx layout: flat_idx = local_expert * cap_total + + // dest_rank * capacity + pos The send buffer uses old layout [W, E, C, D] + // for all_to_all compatibility + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int eid = idx_raw[n * top_k + k]; + if (eid < 0 || eid >= num_experts) + continue; + int pos = expert_counts[eid]++; + if (pos < capacity) { 
+ int dest_rank = eid / experts_per_device; + int local_expert = eid % experts_per_device; + // Old send buffer layout for all_to_all: [W, E, C, D] + int send_flat = + dest_rank * slots_per_device + local_expert * capacity + pos; + // New route_idx layout + int new_flat_idx = + local_expert * cap_total + dest_rank * capacity + pos; + route_ptr[n * top_k + k] = new_flat_idx; + std::memcpy( + send_bytes + send_flat * D * elem_size, + tok_bytes + n * D * elem_size, + D * elem_size); + } + // else: route stays -1 + } + } + + // Allocate recv buffer + array recv_arr(Shape{total_slots, D}, dtype, nullptr, {}); + recv_arr.set_data(allocator::malloc(send_nbytes)); + + // All-to-all exchange using blocking API + grp.raw_group()->blocking_all_to_all(send_arr, recv_arr); + + // recv_arr layout: [world_size, experts_per_device, capacity, D] + // output layout: [experts_per_device, world_size * capacity, D] + // out[e, w*capacity+c, d] = recv[w, e, c, d] + const auto* recv_bytes = + static_cast(recv_arr.data()); + + for (int w = 0; w < world_size; w++) { + for (int e = 0; e < experts_per_device; e++) { + for (int c = 0; c < capacity; c++) { + int recv_row = w * slots_per_device + e * capacity + c; + int out_row = e * cap_total + w * capacity + c; + std::memcpy( + out_bytes + out_row * D * elem_size, + recv_bytes + recv_row * D * elem_size, + D * elem_size); + } + } + } + // send_arr and recv_arr go out of scope here; their allocator memory + // is freed via the array destructor. 
+ } + }); + + // Keep input arrays alive until the dispatched lambda has executed + enc.add_temporary(tokens_in); + enc.add_temporary(indices_in); +} + +void MoeCombineExchange::eval_cpu( + const std::vector& inputs, + std::vector& outputs) { + assert(inputs.size() == 4); + assert(outputs.size() == 1); + + // inputs: expert_outputs [E_local, cap_total, D], + // route_indices [N, top_k] int32, + // weights [N, top_k] float32, + // original_tokens [N, D] + auto [expert_out, eo_copied] = ensure_row_contiguous(inputs[0], stream()); + auto [route_idx, ri_copied] = ensure_row_contiguous(inputs[1], stream()); + auto [weights_in, w_copied] = ensure_row_contiguous(inputs[2], stream()); + auto [orig_tok, ot_copied] = ensure_row_contiguous(inputs[3], stream()); + + int experts_per_device = expert_out.shape(0); + int cap_total = expert_out.shape(1); + int D = expert_out.shape(2); + int N = orig_tok.shape(0); + int top_k = route_idx.shape(1); + int world_size = group().size(); + int capacity = capacity_; + Group grp = group(); + size_t elem_size = expert_out.itemsize(); + Dtype dtype = expert_out.dtype(); + + // Allocate output before dispatch + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + + // Capture raw pointers; arrays kept alive via add_temporary below + const void* eo_raw = expert_out.data(); + const int32_t* ri_raw = route_idx.data(); + const float* w_raw = weights_in.data(); + const void* orig_raw = orig_tok.data(); + void* out0_raw = outputs[0].data(); + + auto& enc = cpu::get_command_encoder(stream()); + + // All data access is inside enc.dispatch so it runs after Metal GPU + // command buffers for upstream ops have been committed and synchronized. 
+ enc.dispatch([eo_raw, + ri_raw, + w_raw, + orig_raw, + out0_raw, + experts_per_device, + cap_total, + D, + N, + top_k, + world_size, + capacity, + elem_size, + dtype, + grp]() mutable { + // world_size == 1: local-only path (no send/recv) + // route_idx flat_idx = local_expert * cap_total + 0 * capacity + pos + // = local_expert * capacity + pos (cap_total == + // capacity for ws=1) + // expert_outputs is indexed directly by flat_idx + if (world_size == 1) { + switch (dtype) { + case float32: { + const auto* eo_f = static_cast(eo_raw); + auto* out_f = static_cast(out0_raw); + const auto* orig_f = static_cast(orig_raw); + for (int n = 0; n < N; n++) { + float* dst = out_f + n * D; + std::fill(dst, dst + D, 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + const float* src = eo_f + flat_idx * D; + for (int d = 0; d < D; d++) + dst[d] += w * src[d]; + } + } + if (!has_valid) { + std::memcpy(dst, orig_f + n * D, D * sizeof(float)); + } + } + break; + } + case float16: { + const auto* eo_h = static_cast(eo_raw); + auto* out_h = static_cast(out0_raw); + const auto* orig_h = static_cast(orig_raw); + std::vector accum(D); + for (int n = 0; n < N; n++) { + std::fill(accum.begin(), accum.end(), 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + const float16_t* src = eo_h + flat_idx * D; + for (int d = 0; d < D; d++) { + accum[d] += w * static_cast(src[d]); + } + } + } + float16_t* dst = out_h + n * D; + if (has_valid) { + for (int d = 0; d < D; d++) + dst[d] = float16_t(accum[d]); + } else { + std::memcpy(dst, orig_h + n * D, D * sizeof(float16_t)); + } + } + break; + } + case bfloat16: { + const auto* eo_h = static_cast(eo_raw); + auto* out_h = static_cast(out0_raw); + const auto* orig_h = 
static_cast(orig_raw); + std::vector accum(D); + for (int n = 0; n < N; n++) { + std::fill(accum.begin(), accum.end(), 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + const bfloat16_t* src = eo_h + flat_idx * D; + for (int d = 0; d < D; d++) { + accum[d] += w * static_cast(src[d]); + } + } + } + bfloat16_t* dst = out_h + n * D; + if (has_valid) { + for (int d = 0; d < D; d++) + dst[d] = bfloat16_t(accum[d]); + } else { + std::memcpy(dst, orig_h + n * D, D * sizeof(bfloat16_t)); + } + } + break; + } + default: + throw std::runtime_error( + "[MoeCombineExchange] Unsupported dtype. Use float32, float16, or bfloat16."); + } + return; + } + + // world_size == 2: v3 combine protocol + if (world_size == 2) { + int my_rank = grp.rank(); + int peer = 1 - my_rank; + + const auto* eo_bytes = static_cast(eo_raw); + + // Response row layout: [token_slot32(4B) | payload(D*elem_size) | pad] + size_t raw_resp_row = 4 + D * elem_size; + int resp_stride = static_cast((raw_resp_row + 15) & ~size_t(15)); + + // Request row layout: [token_slot32(4B) | local_expert16(2B) | pos16(2B)] + int req_stride = 8; + + int max_local_routes = N * top_k; // max requests WE send + int max_peer_routes = + experts_per_device * capacity; // max requests PEER can send + + // Allocate request buffers + size_t req_send_bytes = + static_cast(std::max(max_local_routes, 1)) * req_stride; + size_t req_recv_bytes = + static_cast(std::max(max_peer_routes, 1)) * req_stride; + array req_send({static_cast(req_send_bytes)}, uint8, nullptr, {}); + req_send.set_data(allocator::malloc(req_send_bytes)); + array req_recv({static_cast(req_recv_bytes)}, uint8, nullptr, {}); + req_recv.set_data(allocator::malloc(req_recv_bytes)); + + // Allocate response buffers — responses bounded by received requests / + // sent requests + size_t resp_send_bytes = + 
static_cast(std::max(max_peer_routes, 1)) * resp_stride; + size_t resp_recv_bytes = + static_cast(std::max(max_local_routes, 1)) * resp_stride; + array resp_send({static_cast(resp_send_bytes)}, uint8, nullptr, {}); + resp_send.set_data(allocator::malloc(resp_send_bytes)); + array resp_recv({static_cast(resp_recv_bytes)}, uint8, nullptr, {}); + resp_recv.set_data(allocator::malloc(resp_recv_bytes)); + + // Count exchange arrays (reused for both request and response exchanges) + array count_send({1}, int32, nullptr, {}); + count_send.set_data(allocator::malloc(sizeof(int32_t))); + array count_recv({1}, int32, nullptr, {}); + count_recv.set_data(allocator::malloc(sizeof(int32_t))); + + // Lambda for weighted accumulate (handles all dtypes) + auto weighted_add = [&](void* dst_raw, + const void* src_raw, + float w, + int D) { + switch (dtype) { + case float32: { + auto* dst = static_cast(dst_raw); + const auto* src = static_cast(src_raw); + for (int d = 0; d < D; d++) + dst[d] += w * src[d]; + break; + } + case float16: { + auto* dst = static_cast(dst_raw); // accumulate in float32 + const auto* src = static_cast(src_raw); + for (int d = 0; d < D; d++) + dst[d] += w * static_cast(src[d]); + break; + } + case bfloat16: { + auto* dst = static_cast(dst_raw); // accumulate in float32 + const auto* src = static_cast(src_raw); + for (int d = 0; d < D; d++) + dst[d] += w * static_cast(src[d]); + break; + } + default: + throw std::runtime_error("[MoeCombineExchange] Unsupported dtype"); + } + }; + + // Accumulation buffer (always float32 for precision) + std::vector accum(static_cast(N) * D, 0.0f); + std::vector has_valid(N, false); + + int req_send_count = 0; + auto* req_send_ptr = req_send.data(); + + // Step 1: Process all routes, accumulate local, pack remote requests + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx < 0) + continue; + + int remainder = flat_idx % cap_total; + int dest_rank = remainder 
/ capacity; + float w = w_raw[n * top_k + k]; + + if (dest_rank == my_rank) { + // LOCAL: accumulate directly + has_valid[n] = true; + weighted_add( + accum.data() + static_cast(n) * D, + eo_bytes + static_cast(flat_idx) * D * elem_size, + w, + D); + } else { + // REMOTE: pack request + int local_expert_idx = flat_idx / cap_total; + int pos = remainder % capacity; + uint32_t token_slot = static_cast(n * top_k + k); + uint16_t le16 = static_cast(local_expert_idx); + uint16_t pos16 = static_cast(pos); + + uint8_t* row = + req_send_ptr + static_cast(req_send_count) * req_stride; + std::memcpy(row, &token_slot, 4); + std::memcpy(row + 4, &le16, 2); + std::memcpy(row + 6, &pos16, 2); + req_send_count++; + has_valid[n] = true; + } + } + } + + auto* raw = grp.raw_group().get(); + + // Step 2: Exchange requests + int peer_req_count = raw->blocking_exchange_v( + req_send, + req_send_count, + req_recv, + max_peer_routes, + req_stride, + peer, + detail::ExchangeTag::MoeCombineReqCount, + detail::ExchangeTag::MoeCombineReqPayload, + count_send, + count_recv); + + // Step 3: Build responses from received requests + auto* req_recv_ptr = req_recv.data(); + auto* resp_send_ptr = resp_send.data(); + + for (int i = 0; i < peer_req_count; i++) { + const uint8_t* req_row = + req_recv_ptr + static_cast(i) * req_stride; + uint32_t token_slot; + uint16_t le16, pos16; + std::memcpy(&token_slot, req_row, 4); + std::memcpy(&le16, req_row + 4, 2); + std::memcpy(&pos16, req_row + 6, 2); + + int local_expert = static_cast(le16); + int slot_pos = static_cast(pos16); + + if (local_expert < 0 || local_expert >= experts_per_device || + slot_pos < 0 || slot_pos >= capacity) { + throw std::runtime_error( + "[MoeCombineExchange] out-of-bounds request: local_expert=" + + std::to_string(local_expert) + + " pos=" + std::to_string(slot_pos)); + } + + // Lookup: expert_outputs at peer's slot + int eo_flat = local_expert * cap_total + peer * capacity + slot_pos; + + // Pack response: [token_slot | payload] 
+ uint8_t* resp_row = + resp_send_ptr + static_cast(i) * resp_stride; + std::memcpy(resp_row, &token_slot, 4); + std::memcpy( + resp_row + 4, + eo_bytes + static_cast(eo_flat) * D * elem_size, + D * elem_size); + } + + // Step 4: Exchange responses + int peer_res_count = raw->blocking_exchange_v( + resp_send, + peer_req_count, + resp_recv, + max_local_routes, + resp_stride, + peer, + detail::ExchangeTag::MoeCombineResCount, + detail::ExchangeTag::MoeCombineResPayload, + count_send, + count_recv); + + // Step 5: Process responses — accumulate into output + auto* resp_recv_ptr = resp_recv.data(); + for (int i = 0; i < peer_res_count; i++) { + const uint8_t* resp_row = + resp_recv_ptr + static_cast(i) * resp_stride; + uint32_t token_slot; + std::memcpy(&token_slot, resp_row, 4); + + if (token_slot >= static_cast(N * top_k)) { + throw std::runtime_error( + "[MoeCombineExchange] invalid token_slot=" + + std::to_string(token_slot)); + } + + int n = static_cast(token_slot) / top_k; + int k = static_cast(token_slot) % top_k; + float w = w_raw[n * top_k + k]; + + weighted_add( + accum.data() + static_cast(n) * D, resp_row + 4, w, D); + } + + // Step 6: Write output + switch (dtype) { + case float32: { + auto* out_f = static_cast(out0_raw); + const auto* orig_f = static_cast(orig_raw); + for (int n = 0; n < N; n++) { + if (has_valid[n]) { + std::memcpy( + out_f + n * D, + accum.data() + static_cast(n) * D, + D * sizeof(float)); + } else { + std::memcpy(out_f + n * D, orig_f + n * D, D * sizeof(float)); + } + } + break; + } + case float16: { + auto* out_h = static_cast(out0_raw); + const auto* orig_h = static_cast(orig_raw); + for (int n = 0; n < N; n++) { + if (has_valid[n]) { + for (int d = 0; d < D; d++) { + out_h[n * D + d] = + float16_t(accum[static_cast(n) * D + d]); + } + } else { + std::memcpy(out_h + n * D, orig_h + n * D, D * sizeof(float16_t)); + } + } + break; + } + case bfloat16: { + auto* out_h = static_cast(out0_raw); + const auto* orig_h = 
static_cast(orig_raw); + for (int n = 0; n < N; n++) { + if (has_valid[n]) { + for (int d = 0; d < D; d++) { + out_h[n * D + d] = + bfloat16_t(accum[static_cast(n) * D + d]); + } + } else { + std::memcpy( + out_h + n * D, orig_h + n * D, D * sizeof(bfloat16_t)); + } + } + break; + } + default: + throw std::runtime_error("[MoeCombineExchange] Unsupported dtype"); + } + return; + } + + // world_size > 2: fallback to existing all_to_all + { + int total_slots = experts_per_device * cap_total; // E * W * C + size_t total_bytes = (size_t)total_slots * D * elem_size; + int slots_per_device = experts_per_device * capacity; + + // Reverse transpose: [E, W*C, D] -> [W, E, C, D] + // send_arr[w, e, c, d] = expert_out[e, w*capacity+c, d] + array send_arr(Shape{total_slots, D}, dtype, nullptr, {}); + send_arr.set_data(allocator::malloc(total_bytes)); + + auto* send_bytes = static_cast(send_arr.data()); + const auto* eo_bytes = static_cast(eo_raw); + + for (int e = 0; e < experts_per_device; e++) { + for (int w = 0; w < world_size; w++) { + for (int c = 0; c < capacity; c++) { + int eo_row = e * cap_total + w * capacity + c; + int send_row = w * slots_per_device + e * capacity + c; + std::memcpy( + send_bytes + send_row * D * elem_size, + eo_bytes + eo_row * D * elem_size, + D * elem_size); + } + } + } + + array recv_arr(Shape{total_slots, D}, dtype, nullptr, {}); + recv_arr.set_data(allocator::malloc(total_bytes)); + + grp.raw_group()->blocking_all_to_all(send_arr, recv_arr); + + // recv_arr: [total_slots, D] flat, layout: [W, E, C, D] + // route_idx uses NEW layout: flat_idx = local_expert * cap_total + + // dest_rank * cap + pos Map new route_idx -> old recv_row: + // local_expert = flat_idx / cap_total + // dest_rank = (flat_idx % cap_total) / capacity + // pos = (flat_idx % cap_total) % capacity + // recv_row = dest_rank * slots_per_device + local_expert * capacity + // + pos + + switch (dtype) { + case float32: { + const auto* recv_f = static_cast(recv_arr.data()); + auto* 
out_f = static_cast(out0_raw); + const auto* orig_f = static_cast(orig_raw); + + for (int n = 0; n < N; n++) { + float* dst = out_f + n * D; + std::fill(dst, dst + D, 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + int local_expert = flat_idx / cap_total; + int remainder = flat_idx % cap_total; + int dest_rank = remainder / capacity; + int pos = remainder % capacity; + int recv_row = dest_rank * slots_per_device + + local_expert * capacity + pos; + const float* src = recv_f + recv_row * D; + for (int d = 0; d < D; d++) + dst[d] += w * src[d]; + } + } + if (!has_valid) { + std::memcpy(dst, orig_f + n * D, D * sizeof(float)); + } + } + break; + } + case float16: { + const auto* recv_h = + static_cast(recv_arr.data()); + auto* out_h = static_cast(out0_raw); + const auto* orig_h = static_cast(orig_raw); + + std::vector accum(D); + for (int n = 0; n < N; n++) { + std::fill(accum.begin(), accum.end(), 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + int local_expert = flat_idx / cap_total; + int remainder = flat_idx % cap_total; + int dest_rank = remainder / capacity; + int pos = remainder % capacity; + int recv_row = dest_rank * slots_per_device + + local_expert * capacity + pos; + const float16_t* src = recv_h + recv_row * D; + for (int d = 0; d < D; d++) { + accum[d] += w * static_cast(src[d]); + } + } + } + float16_t* dst = out_h + n * D; + if (has_valid) { + for (int d = 0; d < D; d++) { + dst[d] = float16_t(accum[d]); + } + } else { + std::memcpy(dst, orig_h + n * D, D * sizeof(float16_t)); + } + } + break; + } + case bfloat16: { + const auto* recv_h = + static_cast(recv_arr.data()); + auto* out_h = static_cast(out0_raw); + const auto* orig_h = static_cast(orig_raw); + + std::vector 
accum(D); + for (int n = 0; n < N; n++) { + std::fill(accum.begin(), accum.end(), 0.0f); + bool has_valid = false; + for (int k = 0; k < top_k; k++) { + int flat_idx = ri_raw[n * top_k + k]; + if (flat_idx >= 0) { + has_valid = true; + float w = w_raw[n * top_k + k]; + int local_expert = flat_idx / cap_total; + int remainder = flat_idx % cap_total; + int dest_rank = remainder / capacity; + int pos = remainder % capacity; + int recv_row = dest_rank * slots_per_device + + local_expert * capacity + pos; + const bfloat16_t* src = recv_h + recv_row * D; + for (int d = 0; d < D; d++) { + accum[d] += w * static_cast(src[d]); + } + } + } + bfloat16_t* dst = out_h + n * D; + if (has_valid) { + for (int d = 0; d < D; d++) { + dst[d] = bfloat16_t(accum[d]); + } + } else { + std::memcpy(dst, orig_h + n * D, D * sizeof(bfloat16_t)); + } + } + break; + } + default: + throw std::runtime_error( + "[MoeCombineExchange] Unsupported dtype. Use float32, float16, or bfloat16."); + } + // send_arr and recv_arr go out of scope here; their allocator memory + // is freed via the array destructor. + } + }); + + // Keep input arrays alive until the dispatched lambda has executed + enc.add_temporary(expert_out); + enc.add_temporary(route_idx); + enc.add_temporary(weights_in); + enc.add_temporary(orig_tok); +} } // namespace mlx::core::distributed diff --git a/mlx/backend/cuda/distributed.cu b/mlx/backend/cuda/distributed.cu index ac79875789..1be238cdee 100644 --- a/mlx/backend/cuda/distributed.cu +++ b/mlx/backend/cuda/distributed.cu @@ -118,4 +118,22 @@ void ReduceScatter::eval_gpu( throw std::runtime_error("Only sum scatter is supported. 
"); } } + +void AllToAll::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllToAll::eval_gpu] has no CUDA implementation."); +} + +void MoeDispatchExchange::eval_gpu( + const std::vector&, + std::vector&) { + throw std::runtime_error( + "[MoeDispatchExchange::eval_gpu] has no CUDA implementation."); +} + +void MoeCombineExchange::eval_gpu( + const std::vector&, + std::vector&) { + throw std::runtime_error( + "[MoeCombineExchange::eval_gpu] has no CUDA implementation."); +} } // namespace mlx::core::distributed diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 4074e7b1e9..e7c84704f0 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -34,6 +34,7 @@ make_jit_source(indexing/gather_front kernels/indexing/indexing.h) make_jit_source(indexing/gather_axis) make_jit_source(indexing/scatter_axis) make_jit_source(hadamard) +make_jit_source(moe) if(MLX_METAL_JIT) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jit_kernels.cpp) @@ -109,6 +110,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/moe.cpp ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp diff --git a/mlx/backend/metal/distributed.cpp b/mlx/backend/metal/distributed.cpp index 217ee3c946..9b172104ab 100644 --- a/mlx/backend/metal/distributed.cpp +++ b/mlx/backend/metal/distributed.cpp @@ -1,19 +1,29 @@ // Copyright © 2024 Apple Inc. 
+#include #include +#include +#include +#include #include "mlx/allocator.h" #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" +#include "mlx/backend/gpu/eval.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" #include "mlx/distributed/ops.h" #include "mlx/distributed/primitives.h" #include "mlx/fence.h" #include "mlx/scheduler.h" +#include "mlx/types/half_types.h" namespace mlx::core::distributed { +// Forward declare from moe.cpp +MTL::ComputePipelineState* +get_moe_kernel(metal::Device& d, const std::string& base_name, Dtype dtype); + void AllReduce::eval_gpu(const std::vector&, std::vector&) { throw std::runtime_error("[AllReduce::eval_gpu] has no GPU implementation."); } @@ -35,4 +45,893 @@ void ReduceScatter::eval_gpu(const std::vector&, std::vector&) { "[ReduceScatter::eval_gpu] has no GPU implementation."); } +void AllToAll::eval_gpu(const std::vector&, std::vector&) { + throw std::runtime_error("[AllToAll::eval_gpu] has no GPU implementation."); +} + +// --------------------------------------------------------------------------- +// MoeDispatchExchange::eval_gpu +// --------------------------------------------------------------------------- +// +// Architecture: +// CPU = O(N*top_k) routing only (route build, meta decode, exchange) +// GPU = O(N*D) data movement (scatter, gather via Metal kernels) +// +// Flows: +// ws==1: synchronize -> CPU route -> GPU zero-fill + dispatch_local +// ws==2: synchronize -> CPU route -> GPU zero-fill + dispatch_local + +// packet_gather -> mid-sync -> CPU exchange_v -> CPU decode -> +// GPU packet_scatter +// ws>2: CPU fallback (create CPU-stream primitive) +// --------------------------------------------------------------------------- + +void MoeDispatchExchange::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + try { + // Fault injection (test-only) + if (std::getenv("MLX_MOE_EP_FORCE_METAL_ERROR")) { + throw std::runtime_error("[MoE EP] Forced Metal error 
for testing"); + } + + assert(inputs.size() == 2); + assert(outputs.size() == 2); + + auto& s = stream(); + auto& d = metal::device(s.device); + int world_size = group().size(); + + // ws > 2: fallback to CPU eval + if (world_size > 2) { + MoeDispatchExchange cpu_fb( + to_stream(std::monostate{}, Device::cpu), + group(), + num_experts_, + capacity_, + deterministic_, + MoeBackend::Cpu); + cpu_fb.eval_cpu(inputs, outputs); + return; + } + + // Ensure inputs are evaluated on GPU + auto& tokens_in = inputs[0]; + auto& indices_in = inputs[1]; + + int N = tokens_in.shape(0); + int D_val = tokens_in.shape(1); + int top_k = indices_in.shape(1); + int num_experts = num_experts_; + int capacity = capacity_; + int experts_per_device = num_experts / std::max(world_size, 1); + int cap_total = world_size * capacity; + size_t elem_size = tokens_in.itemsize(); + Dtype dtype = tokens_in.dtype(); + + // Allocate output arrays + // outputs[0]: dispatched [E_local, cap_total, D] + // outputs[1]: route_indices [N, top_k] int32 + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + outputs[1].set_data(allocator::malloc(outputs[1].nbytes())); + + // Step 1: gpu::synchronize to flush upstream GPU ops so CPU can read + // expert_indices via UMA + gpu::synchronize(s); + + // CPU: read expert_indices via UMA + const int32_t* idx_ptr = indices_in.data(); + int32_t* route_ptr = outputs[1].data(); + + // Initialize route_indices to -1 (CPU, O(N*top_k)) + std::fill(route_ptr, route_ptr + N * top_k, int32_t(-1)); + + // Zero-fill dispatched output via CPU memset (UMA, after synchronize) + size_t out_nbytes = + static_cast(experts_per_device) * cap_total * D_val * elem_size; + std::memset(outputs[0].data(), 0, out_nbytes); + + // Expert count tracking for routing + std::vector expert_counts(num_experts, 0); + + // ========================================================================= + // ws == 1: local-only path + // 
========================================================================= + if (world_size == 1) { + // Build slot_map and nk_indices for valid assignments + std::vector slot_map_vec; + std::vector nk_indices_vec; + slot_map_vec.reserve(N * top_k); + nk_indices_vec.reserve(N * top_k); + + // k-outer, n-inner for deterministic slot assignment + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int eid = idx_ptr[n * top_k + k]; + if (eid < 0 || eid >= num_experts) + continue; + int pos = expert_counts[eid]++; + if (pos >= capacity) + continue; + + int flat_idx = eid * cap_total + pos; + route_ptr[n * top_k + k] = flat_idx; + + slot_map_vec.push_back(flat_idx); + nk_indices_vec.push_back(n * top_k + k); + } + } + + int valid_count = static_cast(slot_map_vec.size()); + if (valid_count == 0) { + // No valid assignments — output is already zero-filled + return; + } + + // Build GPU metadata buffers (UMA zero-copy) + array slot_map_buf({valid_count}, int32, nullptr, {}); + slot_map_buf.set_data(allocator::malloc(valid_count * sizeof(int32_t))); + std::memcpy( + slot_map_buf.data(), + slot_map_vec.data(), + valid_count * sizeof(int32_t)); + + array nk_indices_buf({valid_count}, int32, nullptr, {}); + nk_indices_buf.set_data(allocator::malloc(valid_count * sizeof(int32_t))); + std::memcpy( + nk_indices_buf.data(), + nk_indices_vec.data(), + valid_count * sizeof(int32_t)); + + // Launch moe_dispatch_local kernel + auto kernel = get_moe_kernel(d, "moe_dispatch_local", dtype); + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_input_array(tokens_in, 0); + enc.set_output_array(outputs[0], 1); + enc.set_input_array(slot_map_buf, 2); + enc.set_input_array(nk_indices_buf, 3); + enc.set_bytes(D_val, 4); + enc.set_bytes(top_k, 5); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, valid_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc.dispatch_threads(grid_dims, group_dims); + + 
// Keep temporaries alive until command buffer commits + d.add_temporary(slot_map_buf, s.index); + d.add_temporary(nk_indices_buf, s.index); + return; + } + + // ========================================================================= + // ws == 2: v3 variable exchange protocol with Metal kernels + // ========================================================================= + { + int my_rank = group().rank(); + int peer = 1 - my_rank; + Group grp = group(); + + // Packet row layout: [header(16B) | payload(D*elem_size) | pad] + size_t raw_row = 16 + static_cast(D_val) * elem_size; + int row_stride = static_cast((raw_row + 15) & ~size_t(15)); + + int max_send = N * top_k; + int recv_cap = experts_per_device * capacity; + + // Build routing metadata on CPU + // local_slot_map[i] = flat_idx for i-th valid local assignment + // local_nk[i] = n*top_k+k for i-th valid local assignment + // remote_tok_idx[j] = n (token index) for j-th remote assignment + // remote_headers[j] = meta32 = (local_expert << 16) | (pos & 0xFFFF) + std::vector local_slot_map; + std::vector local_nk; + std::vector remote_tok_idx; + std::vector remote_headers; + + local_slot_map.reserve(N * top_k); + local_nk.reserve(N * top_k); + remote_tok_idx.reserve(N * top_k); + remote_headers.reserve(N * top_k); + + // k-outer, n-inner for deterministic slot assignment + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int eid = idx_ptr[n * top_k + k]; + if (eid < 0 || eid >= num_experts) + continue; + int dest_rank = eid / experts_per_device; + int local_expert = eid % experts_per_device; + int pos = expert_counts[eid]++; + if (pos >= capacity) + continue; + + int flat_idx = local_expert * cap_total + dest_rank * capacity + pos; + route_ptr[n * top_k + k] = flat_idx; + + if (dest_rank == my_rank) { + // LOCAL + local_slot_map.push_back(flat_idx); + local_nk.push_back(n * top_k + k); + } else { + // REMOTE — will be packed into packet via GPU kernel + remote_tok_idx.push_back(n); + 
uint32_t meta = (static_cast(local_expert) << 16) | + (static_cast(pos) & 0xFFFF); + remote_headers.push_back(meta); + } + } + } + + int local_count = static_cast(local_slot_map.size()); + int send_count = static_cast(remote_tok_idx.size()); + + // --- GPU Phase 1: local scatter via moe_dispatch_local --- + if (local_count > 0) { + array slot_map_buf({local_count}, int32, nullptr, {}); + slot_map_buf.set_data(allocator::malloc(local_count * sizeof(int32_t))); + std::memcpy( + slot_map_buf.data(), + local_slot_map.data(), + local_count * sizeof(int32_t)); + + array nk_buf({local_count}, int32, nullptr, {}); + nk_buf.set_data(allocator::malloc(local_count * sizeof(int32_t))); + std::memcpy( + nk_buf.data(), + local_nk.data(), + local_count * sizeof(int32_t)); + + auto kernel = get_moe_kernel(d, "moe_dispatch_local", dtype); + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_input_array(tokens_in, 0); + enc.set_output_array(outputs[0], 1); + enc.set_input_array(slot_map_buf, 2); + enc.set_input_array(nk_buf, 3); + enc.set_bytes(D_val, 4); + enc.set_bytes(top_k, 5); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, local_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc.dispatch_threads(grid_dims, group_dims); + + d.add_temporary(slot_map_buf, s.index); + d.add_temporary(nk_buf, s.index); + } + + // --- GPU Phase 2: pack remote tokens into packets via moe_packet_gather + // --- Allocate packet buffers + size_t send_pkt_bytes = + static_cast(std::max(send_count, 1)) * row_stride; + size_t recv_pkt_bytes = + static_cast(std::max(recv_cap, 1)) * row_stride; + + array send_pkt({static_cast(send_pkt_bytes)}, uint8, nullptr, {}); + send_pkt.set_data(allocator::malloc(send_pkt_bytes)); + + array recv_pkt({static_cast(recv_pkt_bytes)}, uint8, nullptr, {}); + recv_pkt.set_data(allocator::malloc(recv_pkt_bytes)); + + if (send_count > 0) { + // Build src_idx (token row indices) and headers 
buffers + array src_idx_buf({send_count}, int32, nullptr, {}); + src_idx_buf.set_data(allocator::malloc(send_count * sizeof(int32_t))); + std::memcpy( + src_idx_buf.data(), + remote_tok_idx.data(), + send_count * sizeof(int32_t)); + + array headers_buf({send_count}, uint32, nullptr, {}); + headers_buf.set_data(allocator::malloc(send_count * sizeof(uint32_t))); + std::memcpy( + headers_buf.data(), + remote_headers.data(), + send_count * sizeof(uint32_t)); + + auto pkt_kernel = get_moe_kernel(d, "moe_packet_gather", dtype); + auto& enc2 = d.get_command_encoder(s.index); + enc2.set_compute_pipeline_state(pkt_kernel); + enc2.set_input_array(tokens_in, 0); + enc2.set_output_array(send_pkt, 1); + enc2.set_input_array(src_idx_buf, 2); + enc2.set_input_array(headers_buf, 3); + enc2.set_bytes(D_val, 4); + enc2.set_bytes(send_count, 5); + enc2.set_bytes(row_stride, 6); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, send_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc2.dispatch_threads(grid_dims, group_dims); + + d.add_temporary(src_idx_buf, s.index); + d.add_temporary(headers_buf, s.index); + } + + // --- Mid-sync: flush GPU work so packet data is ready for RDMA --- + gpu::synchronize(s); + + // --- CPU: v3 exchange --- + array count_send({1}, int32, nullptr, {}); + count_send.set_data(allocator::malloc(sizeof(int32_t))); + array count_recv({1}, int32, nullptr, {}); + count_recv.set_data(allocator::malloc(sizeof(int32_t))); + + auto* raw = grp.raw_group().get(); + int peer_count = raw->blocking_exchange_v( + send_pkt, + send_count, + recv_pkt, + recv_cap, + row_stride, + peer, + detail::ExchangeTag::MoeDispatchCount, + detail::ExchangeTag::MoeDispatchPayload, + count_send, + count_recv); + + // --- CPU: decode recv meta -> build recv_flat_idx --- + if (peer_count > 0) { + std::vector recv_flat_idx_vec(peer_count); + auto* recv_pkt_ptr = recv_pkt.data(); + + for (int i = 0; i < peer_count; i++) { + const uint8_t* row = + 
recv_pkt_ptr + static_cast(i) * row_stride; + uint32_t meta; + std::memcpy(&meta, row, 4); + int local_expert = static_cast(meta >> 16); + int slot_pos = static_cast(meta & 0xFFFF); + + if (local_expert < 0 || local_expert >= experts_per_device || + slot_pos < 0 || slot_pos >= capacity) { + throw std::runtime_error( + "[MoeDispatchExchange::eval_gpu] received out-of-bounds " + "metadata: local_expert=" + + std::to_string(local_expert) + + " slot_pos=" + std::to_string(slot_pos)); + } + recv_flat_idx_vec[i] = + local_expert * cap_total + peer * capacity + slot_pos; + } + + // Build GPU buffer for flat indices + array flat_idx_buf({peer_count}, int32, nullptr, {}); + flat_idx_buf.set_data(allocator::malloc(peer_count * sizeof(int32_t))); + std::memcpy( + flat_idx_buf.data(), + recv_flat_idx_vec.data(), + peer_count * sizeof(int32_t)); + + // --- GPU Phase 3: scatter recv packets into dispatched --- + auto scatter_kernel = get_moe_kernel(d, "moe_packet_scatter", dtype); + auto& enc3 = d.get_command_encoder(s.index); + enc3.set_compute_pipeline_state(scatter_kernel); + enc3.set_input_array(recv_pkt, 0); + enc3.set_output_array(outputs[0], 1); + enc3.set_input_array(flat_idx_buf, 2); + enc3.set_bytes(D_val, 3); + enc3.set_bytes(peer_count, 4); + enc3.set_bytes(row_stride, 5); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, peer_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc3.dispatch_threads(grid_dims, group_dims); + + d.add_temporary(flat_idx_buf, s.index); + d.add_temporary(recv_pkt, s.index); + } + // send_pkt, count_send, count_recv will be freed when going out of scope + // (they are no longer referenced by any GPU command after synchronize) + } + + } catch (const std::exception& e) { + const char* fb_env = std::getenv("MLX_MOE_EP_FALLBACK_ON_ERROR"); + bool fallback_enabled = !fb_env || std::string(fb_env) != "0"; + + if (!fallback_enabled) { + throw; // rethrow for debug/CI + } + + // Log warning (once) + static 
std::once_flag warn_flag; + std::call_once(warn_flag, [&]() { + std::cerr << "[MoE EP] Metal eval_gpu failed: " << e.what() + << ". Falling back to CPU.\n"; + }); + + // Flush any partial GPU state before CPU fallback + try { + gpu::synchronize(stream()); + } catch (...) { + } + + // CPU fallback: create new CPU-stream primitive + MoeDispatchExchange cpu_prim( + to_stream(std::monostate{}, Device::cpu), + group(), + num_experts_, + capacity_, + deterministic_, + MoeBackend::Cpu); + cpu_prim.eval_cpu(inputs, outputs); + } +} + +// --------------------------------------------------------------------------- +// MoeCombineExchange::eval_gpu +// --------------------------------------------------------------------------- +// +// Flows: +// ws==1: synchronize -> CPU route -> GPU combine_weighted_sum +// ws==2: synchronize -> CPU route -> CPU exchange_v (requests) -> +// CPU decode -> GPU packet_gather (responses) -> mid-sync -> +// CPU exchange_v (responses) -> CPU decode -> +// GPU build unified_src -> GPU packet_scatter -> +// GPU combine_weighted_sum +// ws>2: CPU fallback +// --------------------------------------------------------------------------- + +void MoeCombineExchange::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + try { + // Fault injection (test-only) + if (std::getenv("MLX_MOE_EP_FORCE_METAL_ERROR")) { + throw std::runtime_error("[MoE EP] Forced Metal error for testing"); + } + + assert(inputs.size() == 4); + assert(outputs.size() == 1); + + auto& s = stream(); + auto& d = metal::device(s.device); + int world_size = group().size(); + + // ws > 2: fallback to CPU eval + if (world_size > 2) { + MoeCombineExchange cpu_fb( + to_stream(std::monostate{}, Device::cpu), + group(), + num_experts_, + capacity_, + deterministic_, + MoeBackend::Cpu); + cpu_fb.eval_cpu(inputs, outputs); + return; + } + + // inputs: expert_outputs [E_local, cap_total, D], + // route_indices [N, top_k] int32, + // weights [N, top_k] float32, + // original_tokens [N, 
D] + auto& expert_out = inputs[0]; + auto& route_idx = inputs[1]; + auto& weights_in = inputs[2]; + auto& orig_tok = inputs[3]; + + int experts_per_device = expert_out.shape(0); + int cap_total = expert_out.shape(1); + int D_val = expert_out.shape(2); + int N = orig_tok.shape(0); + int top_k = route_idx.shape(1); + int capacity = capacity_; + size_t elem_size = expert_out.itemsize(); + Dtype dtype = expert_out.dtype(); + + // Allocate output + outputs[0].set_data(allocator::malloc(outputs[0].nbytes())); + + // Step 1: gpu::synchronize to flush upstream so CPU can read route_indices + gpu::synchronize(s); + + const int32_t* ri_ptr = route_idx.data(); + const float* w_ptr = weights_in.data(); + + // ========================================================================= + // ws == 1: local-only path + // ========================================================================= + if (world_size == 1) { + // Build src_idx for the weighted_sum kernel: + // src_idx[n*top_k+k] = flat_idx into expert_out, or -1 + // For ws=1, expert_out is directly indexed by route_idx flat values + // since data_src == expert_out + + // src_idx is just the route_indices — already in the right format + // We can pass route_idx directly as src_idx + + auto kernel = get_moe_kernel(d, "moe_combine_weighted_sum", dtype); + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(kernel); + enc.set_input_array(expert_out, 0); // data_src + enc.set_output_array(outputs[0], 1); // output + enc.set_input_array(orig_tok, 2); // original + enc.set_input_array(weights_in, 3); // weights + enc.set_input_array(route_idx, 4); // src_idx + enc.set_bytes(D_val, 5); + enc.set_bytes(N, 6); + enc.set_bytes(top_k, 7); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, N, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc.dispatch_threads(grid_dims, group_dims); + + return; + } + + // ========================================================================= + 
// ws == 2: v3 combine protocol with Metal kernels + // ========================================================================= + { + int my_rank = group().rank(); + int peer = 1 - my_rank; + Group grp = group(); + + // Response row layout: [header(16B) | payload(D*elem_size) | pad] + size_t raw_resp_row = 16 + static_cast(D_val) * elem_size; + int resp_stride = static_cast((raw_resp_row + 15) & ~size_t(15)); + + // Request row layout: [token_slot32(4B) | local_expert16(2B) | pos16(2B)] + int req_stride = 8; + + int max_local_routes = N * top_k; + int max_peer_routes = experts_per_device * capacity; + + // --------------- Step 2: CPU route analysis --------------- + // Separate local vs remote routes + + // For local routes: we need src_idx entries pointing into expert_out + // For remote routes: we need to send requests to peer + + // src_idx[n*top_k+k] will eventually be an index into unified_src + // For ws==2, unified_src = expert_out (possibly copied) + scattered + // responses + + // We track: + // - local_entries: (nk_idx, flat_idx_in_expert_out) for local + // accumulation + // - remote_entries: (nk_idx, local_expert, pos) for remote requests + struct LocalEntry { + int nk_idx; + int flat_idx; + }; + struct RemoteEntry { + int nk_idx; + int local_expert; + int pos; + }; + + std::vector local_entries; + std::vector remote_entries; + local_entries.reserve(N * top_k); + remote_entries.reserve(N * top_k); + + for (int k = 0; k < top_k; k++) { + for (int n = 0; n < N; n++) { + int flat_idx = ri_ptr[n * top_k + k]; + if (flat_idx < 0) + continue; + + int remainder = flat_idx % cap_total; + int dest_rank = remainder / capacity; + + if (dest_rank == my_rank) { + local_entries.push_back({n * top_k + k, flat_idx}); + } else { + int local_expert_idx = flat_idx / cap_total; + int pos = remainder % capacity; + remote_entries.push_back({n * top_k + k, local_expert_idx, pos}); + } + } + } + + int req_send_count = static_cast(remote_entries.size()); + + // 
--------------- Step 3: CPU exchange requests --------------- + // Allocate request buffers + size_t req_send_bytes = + static_cast(std::max(req_send_count, 1)) * req_stride; + size_t req_recv_bytes = + static_cast(std::max(max_peer_routes, 1)) * req_stride; + + array req_send({static_cast(req_send_bytes)}, uint8, nullptr, {}); + req_send.set_data(allocator::malloc(req_send_bytes)); + array req_recv({static_cast(req_recv_bytes)}, uint8, nullptr, {}); + req_recv.set_data(allocator::malloc(req_recv_bytes)); + + // Pack requests + auto* req_send_ptr = req_send.data(); + for (int i = 0; i < req_send_count; i++) { + auto& re = remote_entries[i]; + uint32_t token_slot = static_cast(re.nk_idx); + uint16_t le16 = static_cast(re.local_expert); + uint16_t pos16 = static_cast(re.pos); + + uint8_t* row = req_send_ptr + static_cast(i) * req_stride; + std::memcpy(row, &token_slot, 4); + std::memcpy(row + 4, &le16, 2); + std::memcpy(row + 6, &pos16, 2); + } + + // Count exchange arrays + array count_send({1}, int32, nullptr, {}); + count_send.set_data(allocator::malloc(sizeof(int32_t))); + array count_recv({1}, int32, nullptr, {}); + count_recv.set_data(allocator::malloc(sizeof(int32_t))); + + auto* raw = grp.raw_group().get(); + + int peer_req_count = raw->blocking_exchange_v( + req_send, + req_send_count, + req_recv, + max_peer_routes, + req_stride, + peer, + detail::ExchangeTag::MoeCombineReqCount, + detail::ExchangeTag::MoeCombineReqPayload, + count_send, + count_recv); + + // --------------- Step 4: decode peer requests, build response metadata + // --- For each received request, we need to: + // - look up the flat_idx in expert_out + // - record the token_slot for the response header + std::vector eo_flat_idx_vec(peer_req_count); + std::vector resp_headers_vec(peer_req_count); + + auto* req_recv_ptr = req_recv.data(); + for (int i = 0; i < peer_req_count; i++) { + const uint8_t* req_row = + req_recv_ptr + static_cast(i) * req_stride; + uint32_t token_slot; + uint16_t le16, 
pos16; + std::memcpy(&token_slot, req_row, 4); + std::memcpy(&le16, req_row + 4, 2); + std::memcpy(&pos16, req_row + 6, 2); + + int local_expert = static_cast(le16); + int slot_pos = static_cast(pos16); + + if (local_expert < 0 || local_expert >= experts_per_device || + slot_pos < 0 || slot_pos >= capacity) { + throw std::runtime_error( + "[MoeCombineExchange::eval_gpu] out-of-bounds request: " + "local_expert=" + + std::to_string(local_expert) + + " pos=" + std::to_string(slot_pos)); + } + + // expert_out flat index: expert_out[local_expert, + // peer*capacity+slot_pos] + int eo_flat = local_expert * cap_total + peer * capacity + slot_pos; + eo_flat_idx_vec[i] = eo_flat; + resp_headers_vec[i] = token_slot; + } + + // --------------- Step 5: GPU packet_gather (pack responses) + // --------------- + size_t resp_send_bytes = + static_cast(std::max(peer_req_count, 1)) * resp_stride; + size_t resp_recv_bytes = + static_cast(std::max(req_send_count, 1)) * resp_stride; + + array resp_send({static_cast(resp_send_bytes)}, uint8, nullptr, {}); + resp_send.set_data(allocator::malloc(resp_send_bytes)); + array resp_recv({static_cast(resp_recv_bytes)}, uint8, nullptr, {}); + resp_recv.set_data(allocator::malloc(resp_recv_bytes)); + + if (peer_req_count > 0) { + // Build GPU metadata buffers + array eo_flat_idx_buf({peer_req_count}, int32, nullptr, {}); + eo_flat_idx_buf.set_data( + allocator::malloc(peer_req_count * sizeof(int32_t))); + std::memcpy( + eo_flat_idx_buf.data(), + eo_flat_idx_vec.data(), + peer_req_count * sizeof(int32_t)); + + array resp_hdr_buf({peer_req_count}, uint32, nullptr, {}); + resp_hdr_buf.set_data( + allocator::malloc(peer_req_count * sizeof(uint32_t))); + std::memcpy( + resp_hdr_buf.data(), + resp_headers_vec.data(), + peer_req_count * sizeof(uint32_t)); + + auto pkt_kernel = get_moe_kernel(d, "moe_packet_gather", dtype); + auto& enc = d.get_command_encoder(s.index); + enc.set_compute_pipeline_state(pkt_kernel); + enc.set_input_array(expert_out, 0); 
// source + enc.set_output_array(resp_send, 1); // packet + enc.set_input_array(eo_flat_idx_buf, 2); // src_idx + enc.set_input_array(resp_hdr_buf, 3); // headers + enc.set_bytes(D_val, 4); + enc.set_bytes(peer_req_count, 5); + enc.set_bytes(resp_stride, 6); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, peer_req_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc.dispatch_threads(grid_dims, group_dims); + + d.add_temporary(eo_flat_idx_buf, s.index); + d.add_temporary(resp_hdr_buf, s.index); + } + + // --------------- Mid-sync: flush GPU packet_gather --------------- + gpu::synchronize(s); + + // --------------- Step 6: CPU exchange responses --------------- + int peer_res_count = raw->blocking_exchange_v( + resp_send, + peer_req_count, + resp_recv, + req_send_count, + resp_stride, + peer, + detail::ExchangeTag::MoeCombineResCount, + detail::ExchangeTag::MoeCombineResPayload, + count_send, + count_recv); + + // --------------- Step 7: decode recv responses -> build scatter indices + // --- resp_recv contains [token_slot32(in 16B header) | payload] We need + // to: + // 1. Build unified_src workspace that contains all data rows + // (expert_out rows for local + received response rows for remote) + // 2. Build src_idx[N*top_k] mapping (n,k) -> row in unified_src + // 3. Run moe_combine_weighted_sum + + // unified_src layout: + // Rows 0..E_local*cap_total-1 = expert_out (for local lookups) + // Rows E_local*cap_total.. 
= received response payloads + int eo_total_rows = experts_per_device * cap_total; + int unified_total_rows = eo_total_rows + peer_res_count; + + // Allocate unified_src + size_t unified_nbytes = + static_cast(unified_total_rows) * D_val * elem_size; + array unified_src({unified_total_rows, D_val}, dtype, nullptr, {}); + unified_src.set_data(allocator::malloc(unified_nbytes)); + + // Copy expert_out into unified_src base region + // After synchronize, CPU can safely memcpy from expert_out (UMA) + std::memcpy( + unified_src.data(), + expert_out.data(), + static_cast(eo_total_rows) * D_val * elem_size); + + // Build src_idx on CPU + // Initialize all to -1 + std::vector src_idx_vec(N * top_k, -1); + + // Local entries: src_idx -> flat_idx in expert_out = row in unified_src + for (auto& le : local_entries) { + src_idx_vec[le.nk_idx] = le.flat_idx; + } + + // Decode responses and scatter payloads into unified_src + if (peer_res_count > 0) { + auto* resp_recv_ptr = resp_recv.data(); + + // Build a map from token_slot -> response index for the scatter + std::vector resp_flat_idx_vec(peer_res_count); + + for (int i = 0; i < peer_res_count; i++) { + const uint8_t* resp_row = + resp_recv_ptr + static_cast(i) * resp_stride; + uint32_t token_slot; + std::memcpy(&token_slot, resp_row, 4); + + if (token_slot >= static_cast(N * top_k)) { + throw std::runtime_error( + "[MoeCombineExchange::eval_gpu] invalid token_slot=" + + std::to_string(token_slot)); + } + + // This response goes to row eo_total_rows + i in unified_src + int unified_row = eo_total_rows + i; + src_idx_vec[static_cast(token_slot)] = unified_row; + + // Target row in unified_src for packet_scatter + resp_flat_idx_vec[i] = unified_row; + } + + // GPU: scatter received response payloads into unified_src + array resp_flat_buf({peer_res_count}, int32, nullptr, {}); + resp_flat_buf.set_data( + allocator::malloc(peer_res_count * sizeof(int32_t))); + std::memcpy( + resp_flat_buf.data(), + resp_flat_idx_vec.data(), + 
peer_res_count * sizeof(int32_t)); + + auto scatter_kernel = get_moe_kernel(d, "moe_packet_scatter", dtype); + auto& enc_scatter = d.get_command_encoder(s.index); + enc_scatter.set_compute_pipeline_state(scatter_kernel); + enc_scatter.set_input_array(resp_recv, 0); + enc_scatter.set_output_array(unified_src, 1); + enc_scatter.set_input_array(resp_flat_buf, 2); + enc_scatter.set_bytes(D_val, 3); + enc_scatter.set_bytes(peer_res_count, 4); + enc_scatter.set_bytes(resp_stride, 5); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, peer_res_count, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc_scatter.dispatch_threads(grid_dims, group_dims); + + d.add_temporary(resp_flat_buf, s.index); + d.add_temporary(resp_recv, s.index); + } + + // --------------- Step 8: GPU combine_weighted_sum --------------- + // Build src_idx GPU buffer + array src_idx_buf({N * top_k}, int32, nullptr, {}); + src_idx_buf.set_data(allocator::malloc(N * top_k * sizeof(int32_t))); + std::memcpy( + src_idx_buf.data(), + src_idx_vec.data(), + N * top_k * sizeof(int32_t)); + + auto ws_kernel = get_moe_kernel(d, "moe_combine_weighted_sum", dtype); + auto& enc_ws = d.get_command_encoder(s.index); + enc_ws.set_compute_pipeline_state(ws_kernel); + enc_ws.set_input_array(unified_src, 0); // data_src + enc_ws.set_output_array(outputs[0], 1); // output + enc_ws.set_input_array(orig_tok, 2); // original + enc_ws.set_input_array(weights_in, 3); // weights + enc_ws.set_input_array(src_idx_buf, 4); // src_idx + enc_ws.set_bytes(D_val, 5); + enc_ws.set_bytes(N, 6); + enc_ws.set_bytes(top_k, 7); + + int tx = std::min(D_val, 256); + MTL::Size grid_dims = MTL::Size(D_val, N, 1); + MTL::Size group_dims = MTL::Size(tx, 1, 1); + enc_ws.dispatch_threads(grid_dims, group_dims); + + // Keep temporaries alive + d.add_temporary(unified_src, s.index); + d.add_temporary(src_idx_buf, s.index); + } + + } catch (const std::exception& e) { + const char* fb_env = 
std::getenv("MLX_MOE_EP_FALLBACK_ON_ERROR"); + bool fallback_enabled = !fb_env || std::string(fb_env) != "0"; + + if (!fallback_enabled) { + throw; // rethrow for debug/CI + } + + // Log warning (once) + static std::once_flag combine_warn_flag; + std::call_once(combine_warn_flag, [&]() { + std::cerr << "[MoE EP] Metal eval_gpu failed: " << e.what() + << ". Falling back to CPU.\n"; + }); + + // Flush any partial GPU state before CPU fallback + try { + gpu::synchronize(stream()); + } catch (...) { + } + + // CPU fallback: create new CPU-stream primitive + MoeCombineExchange cpu_prim( + to_stream(std::monostate{}, Device::cpu), + group(), + num_experts_, + capacity_, + deterministic_, + MoeBackend::Cpu); + cpu_prim.eval_cpu(inputs, outputs); + } +} + } // namespace mlx::core::distributed diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index bd58a691a9..d567f92b31 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -41,6 +41,11 @@ void eval(array& arr) { debug_set_primitive_buffer_label(command_buffer, arr.primitive()); arr.primitive().eval_gpu(arr.inputs(), outputs); + + // Re-capture command buffer in case eval_gpu performed a mid-sync + // (e.g., distributed primitives that call gpu::synchronize internally). + // If no mid-sync occurred, this returns the same buffer. 
+ command_buffer = d.get_command_buffer(s.index); } std::unordered_set> buffers; for (auto& in : arr.inputs()) { diff --git a/mlx/backend/metal/jit/includes.h b/mlx/backend/metal/jit/includes.h index a6ef0f14af..9e639ca4c7 100644 --- a/mlx/backend/metal/jit/includes.h +++ b/mlx/backend/metal/jit/includes.h @@ -23,6 +23,7 @@ const char* gather_axis(); const char* gather_front(); const char* hadamard(); const char* logsumexp(); +const char* moe(); const char* quantized_utils(); const char* quantized(); const char* fp_quantized(); diff --git a/mlx/backend/metal/kernels/moe.h b/mlx/backend/metal/kernels/moe.h new file mode 100644 index 0000000000..969f240c85 --- /dev/null +++ b/mlx/backend/metal/kernels/moe.h @@ -0,0 +1,203 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +// MoE Expert Parallelism Metal Kernels +// +// These kernels accelerate the scatter/gather operations in +// MoeDispatchExchange and MoeCombineExchange primitives. + +// Kernel 1: moe_dispatch_local +// Scatters LOCAL tokens into the dispatched buffer using a precomputed +// slot_map. slot_map[nk] = flat_idx into dispatched[E_local, cap_total, D], or +// -1 to skip. 
+// +// Grid: (D_ceil, valid_count, 1) where D_ceil = ceil(D / ELEM_PER_THREAD) +// Group: (min(D_ceil, 256), 1, 1) +// +// Note: valid_count = number of (n,k) pairs with slot_map >= 0 +// nk_indices[i] = original n*top_k+k index for the i-th valid entry +template +[[kernel]] void moe_dispatch_local( + const device T* tokens [[buffer(0)]], // [N, D] + device T* dispatched [[buffer(1)]], // [E_local * cap_total * D] flat + const device int* slot_map [[buffer(2)]], // [valid_count] flat_idx values + const device int* nk_indices + [[buffer(3)]], // [valid_count] n*top_k+k originals + constant int& D [[buffer(4)]], + constant int& top_k [[buffer(5)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int i = gid.y; + if (d >= D) + return; + + int flat_idx = slot_map[i]; + int nk = nk_indices[i]; + int n = nk / top_k; + + dispatched[static_cast(flat_idx) * D + d] = + tokens[static_cast(n) * D + d]; +} + +// Kernel 2: moe_dispatch_scatter_remote +// Scatters received remote tokens (from RDMA exchange) into dispatched buffer. +// recv_meta[i] = (local_expert, pos) packed as meta32 +// recv_payload is contiguous [cnt, D] after meta extraction on CPU. +// +// Grid: (D_ceil, cnt, 1) +// Group: (min(D_ceil, 256), 1, 1) +template +[[kernel]] void moe_dispatch_scatter_remote( + const device T* recv_payload [[buffer(0)]], // [cnt, D] + device T* dispatched [[buffer(1)]], // [E_local * cap_total * D] flat + const device int* recv_flat_idx + [[buffer(2)]], // [cnt] precomputed flat indices + constant int& D [[buffer(3)]], + constant int& cnt [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int i = gid.y; + if (d >= D || i >= cnt) + return; + + int flat_idx = recv_flat_idx[i]; + dispatched[static_cast(flat_idx) * D + d] = + recv_payload[static_cast(i) * D + d]; +} + +// Kernel 3: moe_combine_gather_remote +// Gathers expert outputs for peer's requested tokens into a send buffer. 
+// For each request i, lookup expert_out at flat index and copy to send_results. +// +// Grid: (D_ceil, cnt, 1) +// Group: (min(D_ceil, 256), 1, 1) +template +[[kernel]] void moe_combine_gather_remote( + const device T* expert_out [[buffer(0)]], // [E_local * cap_total * D] flat + device T* send_results [[buffer(1)]], // [cnt, D] + const device int* eo_flat_idx + [[buffer(2)]], // [cnt] precomputed flat indices + constant int& D [[buffer(3)]], + constant int& cnt [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int i = gid.y; + if (d >= D || i >= cnt) + return; + + int flat_idx = eo_flat_idx[i]; + send_results[static_cast(i) * D + d] = + expert_out[static_cast(flat_idx) * D + d]; +} + +// Kernel 4: moe_combine_weighted_sum +// Performs weighted accumulation of expert outputs per token. +// For each token n, sums weights[n,k] * src[k_data_idx, d] for k=0..top_k-1. +// Uses float32 accumulation for precision regardless of input dtype. +// +// data_src contains interleaved local and remote results indexed by src_idx. +// src_idx[n * top_k + k] = index into data_src for that (n,k) pair, or -1 to +// skip. 
+// +// Grid: (D_ceil, N, 1) +// Group: (min(D_ceil, 256), 1, 1) +template +[[kernel]] void moe_combine_weighted_sum( + const device T* data_src [[buffer(0)]], // [total_entries, D] + device T* output [[buffer(1)]], // [N, D] + const device T* original [[buffer(2)]], // [N, D] fallback + const device float* weights [[buffer(3)]], // [N, top_k] + const device int* src_idx [[buffer(4)]], // [N * top_k] + constant int& D [[buffer(5)]], + constant int& N [[buffer(6)]], + constant int& top_k [[buffer(7)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int n = gid.y; + if (d >= D || n >= N) + return; + + float accum = 0.0f; + bool has_valid = false; + + for (int k = 0; k < top_k; k++) { + int idx = src_idx[n * top_k + k]; + if (idx >= 0) { + has_valid = true; + float w = weights[n * top_k + k]; + accum += w * static_cast(data_src[static_cast(idx) * D + d]); + } + } + + if (has_valid) { + output[static_cast(n) * D + d] = static_cast(accum); + } else { + output[static_cast(n) * D + d] = + original[static_cast(n) * D + d]; + } +} + +// Kernel 5: moe_packet_gather +// Gathers rows from a source buffer into packet format with 16B headers. +// Each packet row = [header(16B) | payload(D*sizeof(T)) | pad] aligned to +// row_stride. Used for dispatch remote pack and combine response pack. 
+// +// Grid: (D, cnt, 1) +// Group: (min(D, 256), 1, 1) +template +[[kernel]] void moe_packet_gather( + const device T* source [[buffer(0)]], // flat source [rows, D] + device uint8_t* packet [[buffer(1)]], // [cnt, row_stride] + const device int* src_idx [[buffer(2)]], // [cnt] source row indices + const device uint32_t* headers [[buffer(3)]], // [cnt] header values + constant int& D [[buffer(4)]], + constant int& cnt [[buffer(5)]], + constant int& row_stride [[buffer(6)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int i = gid.y; + if (d >= D || i >= cnt) + return; + + long pkt_base = (long)i * row_stride; + + // Write header into 16B-aligned region (first thread per row only) + if (d == 0) { + *reinterpret_cast(packet + pkt_base) = headers[i]; + } + + // Write payload at offset 16 (aligned for vectorized access) + int row = src_idx[i]; + device T* payload = reinterpret_cast(packet + pkt_base + 16); + payload[d] = source[(long)row * D + d]; +} + +// Kernel 6: moe_packet_scatter +// Scatters payload from packet format into a target buffer. +// Each packet row = [header(16B) | payload(D*sizeof(T)) | pad]. +// flat_idx provides the destination row index in the target buffer. 
+// +// Grid: (D, cnt, 1) +// Group: (min(D, 256), 1, 1) +template +[[kernel]] void moe_packet_scatter( + const device uint8_t* packet [[buffer(0)]], // [cnt, row_stride] + device T* target [[buffer(1)]], // flat target buffer + const device int* flat_idx [[buffer(2)]], // [cnt] target row indices + constant int& D [[buffer(3)]], + constant int& cnt [[buffer(4)]], + constant int& row_stride [[buffer(5)]], + uint2 gid [[thread_position_in_grid]]) { + int d = gid.x; + int i = gid.y; + if (d >= D || i >= cnt) + return; + + long pkt_base = (long)i * row_stride; + // Read payload at offset 16 (aligned) + const device T* payload = + reinterpret_cast(packet + pkt_base + 16); + int out_idx = flat_idx[i]; + target[(long)out_idx * D + d] = payload[d]; +} diff --git a/mlx/backend/metal/moe.cpp b/mlx/backend/metal/moe.cpp new file mode 100644 index 0000000000..855fea9eaf --- /dev/null +++ b/mlx/backend/metal/moe.cpp @@ -0,0 +1,129 @@ +// Copyright © 2026 Apple Inc. + +// MoE Expert Parallelism Metal kernel launch helpers. +// +// Provides get_moe_kernel() which JIT-compiles and caches the six MoE +// Metal kernels declared in kernels/moe.h: +// - moe_dispatch_local +// - moe_dispatch_scatter_remote +// - moe_combine_gather_remote +// - moe_combine_weighted_sum +// - moe_packet_gather +// - moe_packet_scatter +// +// The actual eval_gpu dispatch logic lives in distributed.cpp. + +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/jit/includes.h" +#include "mlx/backend/metal/utils.h" + +namespace mlx::core::distributed { + +namespace { + +std::string moe_type_string(Dtype dtype) { + switch (dtype) { + case float32: + return "float"; + case float16: + return "half"; + case bfloat16: + return "bfloat16_t"; + default: + throw std::runtime_error( + "[moe] Unsupported dtype for Metal kernel. 
" + "Expected float32, float16, or bfloat16."); + } +} + +std::string moe_type_suffix(Dtype dtype) { + switch (dtype) { + case float32: + return "float32"; + case float16: + return "float16"; + case bfloat16: + return "bfloat16"; + default: + throw std::runtime_error("[moe] Unsupported dtype."); + } +} + +} // namespace + +MTL::ComputePipelineState* +get_moe_kernel(metal::Device& d, const std::string& base_name, Dtype dtype) { + auto type_str = moe_type_string(dtype); + auto suffix = moe_type_suffix(dtype); + auto kernel_name = base_name + "_" + suffix; + + auto lib = d.get_library(kernel_name, [&]() { + std::string source = metal::utils(); + source += metal::moe(); + source += "\ntemplate [[host_name(\"" + kernel_name + "\")]] "; + source += "[[kernel]] void " + base_name + "<" + type_str + ">("; + + // Explicit template instantiation with named parameters so that + // Metal [[buffer(N)]] / [[thread_position_in_grid]] attributes + // bind to parameters (not types). + if (base_name == "moe_dispatch_local") { + source += "const device " + type_str + "* tokens [[buffer(0)]], "; + source += "device " + type_str + "* dispatched [[buffer(1)]], "; + source += "const device int* slot_map [[buffer(2)]], "; + source += "const device int* nk_indices [[buffer(3)]], "; + source += "constant int& D [[buffer(4)]], "; + source += "constant int& top_k [[buffer(5)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else if (base_name == "moe_dispatch_scatter_remote") { + source += "const device " + type_str + "* recv_payload [[buffer(0)]], "; + source += "device " + type_str + "* dispatched [[buffer(1)]], "; + source += "const device int* recv_flat_idx [[buffer(2)]], "; + source += "constant int& D [[buffer(3)]], "; + source += "constant int& cnt [[buffer(4)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else if (base_name == "moe_combine_gather_remote") { + source += "const device " + type_str + "* expert_out [[buffer(0)]], "; + source += "device 
" + type_str + "* send_results [[buffer(1)]], "; + source += "const device int* eo_flat_idx [[buffer(2)]], "; + source += "constant int& D [[buffer(3)]], "; + source += "constant int& cnt [[buffer(4)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else if (base_name == "moe_combine_weighted_sum") { + source += "const device " + type_str + "* data_src [[buffer(0)]], "; + source += "device " + type_str + "* output [[buffer(1)]], "; + source += "const device " + type_str + "* original [[buffer(2)]], "; + source += "const device float* weights [[buffer(3)]], "; + source += "const device int* src_idx [[buffer(4)]], "; + source += "constant int& D [[buffer(5)]], "; + source += "constant int& N [[buffer(6)]], "; + source += "constant int& top_k [[buffer(7)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else if (base_name == "moe_packet_gather") { + source += "const device " + type_str + "* source [[buffer(0)]], "; + source += "device uint8_t* packet [[buffer(1)]], "; + source += "const device int* src_idx [[buffer(2)]], "; + source += "const device uint32_t* headers [[buffer(3)]], "; + source += "constant int& D [[buffer(4)]], "; + source += "constant int& cnt [[buffer(5)]], "; + source += "constant int& row_stride [[buffer(6)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else if (base_name == "moe_packet_scatter") { + source += "const device uint8_t* packet [[buffer(0)]], "; + source += "device " + type_str + "* target [[buffer(1)]], "; + source += "const device int* flat_idx [[buffer(2)]], "; + source += "constant int& D [[buffer(3)]], "; + source += "constant int& cnt [[buffer(4)]], "; + source += "constant int& row_stride [[buffer(5)]], "; + source += "uint2 gid [[thread_position_in_grid]]);\n"; + } else { + throw std::runtime_error( + "[get_moe_kernel] Unknown kernel base name: " + base_name); + } + + return source; + }); + + return d.get_kernel(kernel_name, lib); +} + +} // namespace mlx::core::distributed 
diff --git a/mlx/backend/no_gpu/primitives.cpp b/mlx/backend/no_gpu/primitives.cpp index 4819ed2724..d7031b8785 100644 --- a/mlx/backend/no_gpu/primitives.cpp +++ b/mlx/backend/no_gpu/primitives.cpp @@ -180,6 +180,9 @@ NO_GPU_MULTI(AllGather) NO_GPU_MULTI(Send) NO_GPU_MULTI(Recv) NO_GPU_MULTI(ReduceScatter) +NO_GPU_MULTI(AllToAll) +NO_GPU_MULTI(MoeDispatchExchange) +NO_GPU_MULTI(MoeCombineExchange) } // namespace distributed } // namespace mlx::core diff --git a/mlx/distributed/CMakeLists.txt b/mlx/distributed/CMakeLists.txt index 807fd0fbdb..618979c162 100644 --- a/mlx/distributed/CMakeLists.txt +++ b/mlx/distributed/CMakeLists.txt @@ -2,7 +2,8 @@ target_sources( mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/distributed_impl.cpp) if(MLX_BUILD_CPU AND NOT WIN32) target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp) diff --git a/mlx/distributed/distributed.cpp b/mlx/distributed/distributed.cpp index 3cde6a263b..60e1119648 100644 --- a/mlx/distributed/distributed.cpp +++ b/mlx/distributed/distributed.cpp @@ -50,6 +50,10 @@ void sum_scatter( group.raw_group()->sum_scatter(input, output, stream); } +void all_to_all(Group group, const array& input, array& output, Stream stream) { + group.raw_group()->all_to_all(input, output, stream); +} + class EmptyGroup : public GroupImpl { public: Stream communication_stream(StreamOrDevice s) override { @@ -98,6 +102,10 @@ class EmptyGroup : public GroupImpl { throw std::runtime_error( "Communication not implemented in an empty distributed group."); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error( + "Communication not implemented in an empty distributed group."); + } }; } // namespace detail diff --git a/mlx/distributed/distributed_impl.cpp b/mlx/distributed/distributed_impl.cpp new file mode 100644 index 
0000000000..d827994227 --- /dev/null +++ b/mlx/distributed/distributed_impl.cpp @@ -0,0 +1,99 @@ +// Copyright © 2024 Apple Inc. + +#include +#include +#include + +#include "mlx/distributed/distributed_impl.h" + +namespace mlx::core::distributed::detail { + +int GroupImpl::blocking_exchange_v( + const array& send_rows_buf, + int send_rows, + array& recv_rows_buf, + int recv_cap_rows, + int row_stride_bytes, + int peer, + ExchangeTag count_tag, + ExchangeTag payload_tag, + array& count_send, + array& count_recv) { + if (send_rows < 0 || recv_cap_rows < 0 || row_stride_bytes <= 0) { + throw std::invalid_argument( + "[blocking_exchange_v] invalid args: send_rows=" + + std::to_string(send_rows) + + " recv_cap_rows=" + std::to_string(recv_cap_rows) + + " row_stride_bytes=" + std::to_string(row_stride_bytes)); + } + + // Overflow check: send_rows * row_stride_bytes + if (send_rows > 0 && + static_cast(row_stride_bytes) > + SIZE_MAX / static_cast(send_rows)) { + throw std::overflow_error( + "[blocking_exchange_v] send_rows * row_stride_bytes overflow"); + } + size_t send_payload_bytes = static_cast(send_rows) * row_stride_bytes; + if (send_payload_bytes > send_rows_buf.nbytes()) { + throw std::out_of_range( + "[blocking_exchange_v] send payload exceeds send buffer"); + } + + // Phase 1: Exchange counts + count_send.data()[0] = static_cast(send_rows); + count_recv.data()[0] = 0; + + blocking_sendrecv( + count_send, + sizeof(int32_t), + count_recv, + sizeof(int32_t), + peer, + count_tag); + + int peer_count = count_recv.data()[0]; + + if (peer_count < 0) { + throw std::runtime_error( + "[blocking_exchange_v] negative peer_count=" + + std::to_string(peer_count)); + } + if (peer_count > recv_cap_rows) { + throw std::out_of_range( + "[blocking_exchange_v] peer_count=" + std::to_string(peer_count) + + " exceeds recv_cap_rows=" + std::to_string(recv_cap_rows)); + } + + // Overflow check: peer_count * row_stride_bytes + if (peer_count > 0 && + static_cast(row_stride_bytes) > + 
SIZE_MAX / static_cast(peer_count)) { + throw std::overflow_error( + "[blocking_exchange_v] peer_count * row_stride_bytes overflow"); + } + size_t recv_payload_bytes = + static_cast(peer_count) * row_stride_bytes; + if (recv_payload_bytes > recv_rows_buf.nbytes()) { + throw std::out_of_range( + "[blocking_exchange_v] peer payload exceeds recv buffer"); + } + + // Early return if no payload + if (send_rows == 0 && peer_count == 0) { + return 0; + } + + // Phase 2: Exchange payload + blocking_sendrecv( + send_rows_buf, + send_payload_bytes, + recv_rows_buf, + recv_payload_bytes, + peer, + payload_tag); + + return peer_count; +} + +} // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/distributed_impl.h b/mlx/distributed/distributed_impl.h index d889587abc..6863ff75cd 100644 --- a/mlx/distributed/distributed_impl.h +++ b/mlx/distributed/distributed_impl.h @@ -2,10 +2,23 @@ #pragma once +#include +#include +#include + #include "mlx/distributed/distributed.h" namespace mlx::core::distributed::detail { +enum class ExchangeTag : uint16_t { + MoeDispatchCount = 100, + MoeDispatchPayload = 101, + MoeCombineReqCount = 110, + MoeCombineReqPayload = 111, + MoeCombineResCount = 120, + MoeCombineResPayload = 121, +}; + /** * Abstract base class of a distributed group implementation. */ @@ -30,6 +43,53 @@ class GroupImpl { virtual void all_min(const array& input, array& output, Stream stream) = 0; virtual void sum_scatter(const array& input, array& output, Stream stream) = 0; + virtual void all_to_all(const array& input, array& output, Stream stream) = 0; + + // Blocking (synchronous) communication — runs directly on the calling + // thread without going through the encoder/stream machinery. 
+ virtual void blocking_send(const array& input, int dst) { + throw std::runtime_error( + "[GroupImpl] blocking_send not supported by this backend"); + } + virtual void blocking_recv(array& output, int src) { + throw std::runtime_error( + "[GroupImpl] blocking_recv not supported by this backend"); + } + virtual void blocking_all_to_all(const array& input, array& output) { + throw std::runtime_error( + "[GroupImpl] blocking_all_to_all not supported by this backend"); + } + + // Tagged bidirectional blocking sendrecv — sends send_nbytes from + // send_buf and receives recv_nbytes into recv_buf in a single call. + // Supports asymmetric sizes. If send_nbytes==0 or recv_nbytes==0 the + // corresponding direction is skipped. + virtual void blocking_sendrecv( + const array& send_buf, + size_t send_nbytes, + array& recv_buf, + size_t recv_nbytes, + int peer, + ExchangeTag tag) { + throw std::runtime_error( + "[GroupImpl] blocking_sendrecv not supported by this backend"); + } + + // Non-virtual concrete helper: variable-length row exchange with a peer. + // Performs two blocking_sendrecv calls: (1) count exchange, (2) payload. + // count_send/count_recv must be pre-allocated int32 arrays of size >= 1. + // Returns: number of rows received from peer (peer_count). + int blocking_exchange_v( + const array& send_rows_buf, + int send_rows, + array& recv_rows_buf, + int recv_cap_rows, + int row_stride_bytes, + int peer, + ExchangeTag count_tag, + ExchangeTag payload_tag, + array& count_send, + array& count_recv); }; /* Define the MLX stream that the communication should happen in. 
*/ @@ -56,4 +116,7 @@ void all_min(Group group, const array& input, array& output, Stream stream); /** Reduce scatter with average operation */ void sum_scatter(Group group, const array& input, array& output, Stream stream); +/** All-to-all exchange */ +void all_to_all(Group group, const array& input, array& output, Stream stream); + } // namespace mlx::core::distributed::detail diff --git a/mlx/distributed/jaccl/mesh.cpp b/mlx/distributed/jaccl/mesh.cpp index c8df4e6745..bf566f4194 100644 --- a/mlx/distributed/jaccl/mesh.cpp +++ b/mlx/distributed/jaccl/mesh.cpp @@ -328,6 +328,363 @@ void MeshGroup::recv(array& out, int src, Stream stream) { }); } +void MeshGroup::blocking_send(const array& input, int dst) { + auto data = input.data(); + int64_t n_bytes = input.nbytes(); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + int in_flight = 0; + int64_t read_offset = 0; + int buff = 0; + while (read_offset < n_bytes && buff < PIPELINE) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, n_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, dst, buff); + buff++; + read_offset += N; + in_flight++; + } + auto deadline_send = + std::chrono::steady_clock::now() + std::chrono::seconds(5); + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[dst].poll(WC_NUM, wc); + if (n < 0) { + throw std::runtime_error( + "[jaccl] blocking_send: poll() returned " + std::to_string(n)); + } + if (n == 0 && std::chrono::steady_clock::now() > deadline_send) { + throw std::runtime_error( + "[jaccl] blocking_send: timeout waiting for CQ completion (in_flight=" + + std::to_string(in_flight) + ")"); + } + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id >> 8) & 0xff; + if (wc[i].status != IBV_WC_SUCCESS) { + throw std::runtime_error( + "[jaccl] blocking_send: WC error status=" + + std::to_string(wc[i].status) + + " wr_id=" + 
std::to_string(wc[i].wr_id)); + } + if (work_type != SEND_WR) { + throw std::runtime_error( + "[jaccl] blocking_send: unexpected work_type=" + + std::to_string(work_type)); + } + in_flight--; + if (read_offset < n_bytes) { + std::copy( + data + read_offset, + data + std::min(read_offset + N, n_bytes), + send_buffer(sz, b).begin()); + send_to(sz, dst, b); + read_offset += N; + in_flight++; + } + } + } +} + +void MeshGroup::blocking_recv(array& out, int src) { + auto data = out.data(); + int64_t n_bytes = out.nbytes(); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE; + auto [sz, N] = buffer_size_from_message(n_bytes); + int in_flight = 0; + int64_t write_offset = 0; + int buff = 0; + while (N * buff < n_bytes && buff < PIPELINE) { + recv_from(sz, src, buff); + in_flight++; + buff++; + } + auto deadline_recv = + std::chrono::steady_clock::now() + std::chrono::seconds(5); + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[src].poll(WC_NUM, wc); + if (n < 0) { + throw std::runtime_error( + "[jaccl] blocking_recv: poll() returned " + std::to_string(n)); + } + if (n == 0 && std::chrono::steady_clock::now() > deadline_recv) { + throw std::runtime_error( + "[jaccl] blocking_recv: timeout waiting for CQ completion (in_flight=" + + std::to_string(in_flight) + ")"); + } + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id >> 8) & 0xff; + if (wc[i].status != IBV_WC_SUCCESS) { + throw std::runtime_error( + "[jaccl] blocking_recv: WC error status=" + + std::to_string(wc[i].status) + + " wr_id=" + std::to_string(wc[i].wr_id)); + } + if (work_type != RECV_WR) { + throw std::runtime_error( + "[jaccl] blocking_recv: unexpected work_type=" + + std::to_string(work_type)); + } + in_flight--; + std::copy( + recv_buffer(sz, b, src).begin(), + recv_buffer(sz, b, src).begin() + + std::min(n_bytes - write_offset, static_cast(N)), + data + write_offset); + write_offset += N; + if (write_offset + (PIPELINE - 1) * N < 
n_bytes) { + recv_from(sz, src, b); + in_flight++; + } + } + } +} + +void MeshGroup::blocking_all_to_all(const array& input, array& output) { + if (size_ != 2) { + throw std::runtime_error( + "[jaccl] blocking_all_to_all currently supports size == 2, got " + + std::to_string(size_) + "."); + } + auto in_ptr = input.data(); + auto out_ptr = output.data(); + if (in_ptr == out_ptr) { + throw std::runtime_error( + "[jaccl] in-place blocking_all_to_all is not supported."); + } + int64_t n_bytes = static_cast(input.nbytes()); + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * 2; + int peer = 1 - rank_; + int64_t per_peer_bytes = n_bytes / size_; + std::memcpy( + out_ptr + rank_ * per_peer_bytes, + in_ptr + rank_ * per_peer_bytes, + per_peer_bytes); + if (per_peer_bytes == 0) + return; + char* send_src = const_cast(in_ptr) + peer * per_peer_bytes; + char* recv_dst = out_ptr + peer * per_peer_bytes; + auto [sz, N] = buffer_size_from_message(per_peer_bytes); + int in_flight = 0; + int64_t read_offset = 0; + int64_t write_offset = 0; + int buff = 0; + while (read_offset < per_peer_bytes && buff < PIPELINE) { + recv_from(sz, peer, buff); + in_flight++; + std::copy( + send_src + read_offset, + send_src + + std::min(read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, peer, buff); + in_flight++; + read_offset += N; + buff++; + } + auto deadline_a2a = + std::chrono::steady_clock::now() + std::chrono::seconds(5); + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[peer].poll(WC_NUM, wc); + if (n < 0) { + throw std::runtime_error( + "[jaccl] blocking_all_to_all: poll() returned " + std::to_string(n)); + } + if (n == 0 && std::chrono::steady_clock::now() > deadline_a2a) { + throw std::runtime_error( + "[jaccl] blocking_all_to_all: timeout waiting for CQ completion (in_flight=" + + std::to_string(in_flight) + ")"); + } + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id 
>> 8) & 0xff; + if (wc[i].status != IBV_WC_SUCCESS) { + throw std::runtime_error( + "[jaccl] blocking_all_to_all: WC error status=" + + std::to_string(wc[i].status) + + " wr_id=" + std::to_string(wc[i].wr_id)); + } + in_flight--; + if (work_type == SEND_WR) { + if (read_offset < per_peer_bytes) { + std::copy( + send_src + read_offset, + send_src + + std::min( + read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, b).begin()); + send_to(sz, peer, b); + in_flight++; + read_offset += N; + } + } else if (work_type == RECV_WR) { + std::copy( + recv_buffer(sz, b, peer).begin(), + recv_buffer(sz, b, peer).begin() + + std::min( + static_cast(N), per_peer_bytes - write_offset), + recv_dst + write_offset); + write_offset += N; + if (write_offset + (PIPELINE - 1) * N < per_peer_bytes) { + recv_from(sz, peer, b); + in_flight++; + } + } + } + } +} + +void MeshGroup::blocking_sendrecv( + const array& send_buf, + size_t send_nbytes, + array& recv_buf, + size_t recv_nbytes, + int peer, + detail::ExchangeTag tag) { + // Skip no-op + if (send_nbytes == 0 && recv_nbytes == 0) + return; + + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * 2; + + // We use the existing send_buffer/recv_buffer infrastructure. + // For the tagged sendrecv, we encode tag in the upper 16 bits of wr_id: + // wr_id = (tag << 48) | (work_type << 16) | (buff << 8) | rank + // But since existing code uses wr_id as: work_type(16-23) | buff(8-15) | + // rank(0-7) and poll() only looks at bits 0-23, we can safely put tag in bits + // 48-63. However, ibv_wc.wr_id is uint64_t, so this is safe. + // + // Actually, to keep it simple and avoid any potential issues with existing + // code, we'll just use the standard wr_id format (SEND_WR/RECV_WR, buff, + // rank) and NOT encode the tag in wr_id. The tag is implicit since + // blocking_sendrecv uses its own poll loop and won't see other CQ entries + // from other operations (we're in a blocking context inside enc.dispatch + // lambda). 
+ + auto send_ptr = static_cast(send_buf.data()); + auto recv_ptr = static_cast(recv_buf.data()); + + // Determine buffer size for RDMA frames + size_t max_bytes = std::max(send_nbytes, recv_nbytes); + if (max_bytes == 0) + max_bytes = 1; // avoid 0-size + auto [sz, N] = buffer_size_from_message(static_cast(max_bytes)); + + int in_flight = 0; + int64_t send_offset = 0; + int64_t recv_write_offset = 0; + int64_t send_total = static_cast(send_nbytes); + int64_t recv_total = static_cast(recv_nbytes); + + // Prefill: post recvs first, then sends + int buff = 0; + int recv_buffs_posted = 0; + int send_buffs_posted = 0; + + // Post recv buffers first (if we have data to receive) + if (recv_total > 0) { + buff = 0; + while (static_cast(N) * buff < recv_total && buff < PIPELINE) { + recv_from(sz, peer, buff); + in_flight++; + recv_buffs_posted++; + buff++; + } + } + + // Post sends + if (send_total > 0) { + buff = 0; + while (send_offset < send_total && buff < PIPELINE) { + auto chunk = std::min(send_total - send_offset, static_cast(N)); + std::copy( + send_ptr + send_offset, + send_ptr + send_offset + chunk, + send_buffer(sz, buff).begin()); + send_to(sz, peer, buff); + in_flight++; + send_offset += N; + send_buffs_posted++; + buff++; + } + } + + // Poll loop + auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5); + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[peer].poll(WC_NUM, wc); + if (n < 0) { + throw std::runtime_error( + "[jaccl] blocking_sendrecv: poll() returned " + std::to_string(n)); + } + if (n == 0 && std::chrono::steady_clock::now() > deadline) { + throw std::runtime_error( + "[jaccl] blocking_sendrecv: timeout (in_flight=" + + std::to_string(in_flight) + + " tag=" + std::to_string(static_cast(tag)) + ")"); + } + if (n > 0) { + deadline = std::chrono::steady_clock::now() + std::chrono::seconds(5); + } + + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id >> 8) & 0xff; + + if 
(wc[i].status != IBV_WC_SUCCESS) { + throw std::runtime_error( + "[jaccl] blocking_sendrecv: WC error status=" + + std::to_string(wc[i].status) + + " wr_id=" + std::to_string(wc[i].wr_id) + + " tag=" + std::to_string(static_cast(tag))); + } + + in_flight--; + + if (work_type == SEND_WR) { + // Send completed — post next chunk if available + if (send_offset < send_total) { + auto chunk = + std::min(send_total - send_offset, static_cast(N)); + std::copy( + send_ptr + send_offset, + send_ptr + send_offset + chunk, + send_buffer(sz, b).begin()); + send_to(sz, peer, b); + in_flight++; + send_offset += N; + } + } else if (work_type == RECV_WR) { + // Recv completed — copy data and post next recv if needed + auto chunk = + std::min(static_cast(N), recv_total - recv_write_offset); + if (chunk > 0) { + std::copy( + recv_buffer(sz, b, peer).begin(), + recv_buffer(sz, b, peer).begin() + chunk, + recv_ptr + recv_write_offset); + recv_write_offset += N; + } + if (recv_write_offset + (PIPELINE - 1) * static_cast(N) < + recv_total) { + recv_from(sz, peer, b); + in_flight++; + } + } + } + } +} + template void MeshGroup::all_reduce( const array& input, @@ -448,4 +805,106 @@ void MeshGroup::all_reduce( }); } +void MeshGroup::all_to_all(const array& input, array& output, Stream stream) { + if (size_ != 2) { + throw std::runtime_error( + "[jaccl] all_to_all currently supports size == 2, got " + + std::to_string(size_) + "."); + } + auto in_ptr = input.data(); + auto out_ptr = output.data(); + if (in_ptr == out_ptr) { + throw std::runtime_error( + "[jaccl] in-place all_to_all is not supported (input/output alias)."); + } + int64_t n_bytes = static_cast(input.nbytes()); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + encoder.dispatch([in_ptr, out_ptr, n_bytes, this]() { + constexpr int PIPELINE = 2; + constexpr int WC_NUM = PIPELINE * 2; + + int peer = 1 - rank_; + int64_t per_peer_bytes = n_bytes / size_; + + 
// Local chunk: input[rank] -> output[rank] + std::memcpy( + out_ptr + rank_ * per_peer_bytes, + in_ptr + rank_ * per_peer_bytes, + per_peer_bytes); + + if (per_peer_bytes == 0) + return; + + char* send_src = const_cast(in_ptr) + peer * per_peer_bytes; + char* recv_dst = out_ptr + peer * per_peer_bytes; + + auto [sz, N] = buffer_size_from_message(per_peer_bytes); + + int in_flight = 0; + int64_t read_offset = 0; + int64_t write_offset = 0; + + // Prefill: recv-first (deadlock prevention) + int buff = 0; + while (read_offset < per_peer_bytes && buff < PIPELINE) { + recv_from(sz, peer, buff); + in_flight++; + + std::copy( + send_src + read_offset, + send_src + + std::min(read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, buff).begin()); + send_to(sz, peer, buff); + in_flight++; + + read_offset += N; + buff++; + } + + // Single poll loop + while (in_flight > 0) { + ibv_wc wc[WC_NUM]; + int n = connections_[peer].poll(WC_NUM, wc); + + for (int i = 0; i < n; i++) { + int work_type = wc[i].wr_id >> 16; + int b = (wc[i].wr_id >> 8) & 0xff; + + in_flight--; + + if (work_type == SEND_WR) { + if (read_offset < per_peer_bytes) { + std::copy( + send_src + read_offset, + send_src + + std::min( + read_offset + static_cast(N), per_peer_bytes), + send_buffer(sz, b).begin()); + send_to(sz, peer, b); + in_flight++; + read_offset += N; + } + } else if (work_type == RECV_WR) { + std::copy( + recv_buffer(sz, b, peer).begin(), + recv_buffer(sz, b, peer).begin() + + std::min( + static_cast(N), per_peer_bytes - write_offset), + recv_dst + write_offset); + write_offset += N; + + if (write_offset + (PIPELINE - 1) * N < per_peer_bytes) { + recv_from(sz, peer, b); + in_flight++; + } + } + } + } + }); +} + } // namespace mlx::core::distributed::jaccl diff --git a/mlx/distributed/jaccl/mesh.h b/mlx/distributed/jaccl/mesh.h index 6f779e9ccb..69a661943e 100644 --- a/mlx/distributed/jaccl/mesh.h +++ b/mlx/distributed/jaccl/mesh.h @@ -41,9 +41,21 @@ class MeshGroup : public GroupImpl 
{ void all_max(const array& input, array& output, Stream stream) override; void all_min(const array& input, array& output, Stream stream) override; void all_gather(const array& input, array& output, Stream stream) override; + void all_to_all(const array& input, array& output, Stream stream) override; void send(const array& input, int dst, Stream stream) override; void recv(array& out, int src, Stream stream) override; + void blocking_send(const array& input, int dst) override; + void blocking_recv(array& out, int src) override; + void blocking_all_to_all(const array& input, array& output) override; + void blocking_sendrecv( + const array& send_buf, + size_t send_nbytes, + array& recv_buf, + size_t recv_nbytes, + int peer, + detail::ExchangeTag tag) override; + void sum_scatter(const array& input, array& output, Stream stream) override { throw std::runtime_error("[jaccl] sum_scatter not supported."); } diff --git a/mlx/distributed/jaccl/ring.h b/mlx/distributed/jaccl/ring.h index a59ceb3dd8..e7dfbbf5ab 100644 --- a/mlx/distributed/jaccl/ring.h +++ b/mlx/distributed/jaccl/ring.h @@ -52,6 +52,10 @@ class RingGroup : public GroupImpl { throw std::runtime_error("[jaccl] sum_scatter not supported."); } + void all_to_all(const array& input, array& output, Stream stream) override { + throw std::runtime_error("[jaccl] all_to_all not supported."); + } + std::shared_ptr split(int color, int key = -1) override { throw std::runtime_error("[jaccl] Group split not supported."); } diff --git a/mlx/distributed/mpi/mpi.cpp b/mlx/distributed/mpi/mpi.cpp index 3b176e6e67..82a980a8f9 100644 --- a/mlx/distributed/mpi/mpi.cpp +++ b/mlx/distributed/mpi/mpi.cpp @@ -1,6 +1,7 @@ // Copyright © 2024 Apple Inc. 
#include +#include #include #include @@ -131,6 +132,8 @@ struct MPIWrapper { LOAD_SYMBOL(MPI_Allgather, all_gather); LOAD_SYMBOL(MPI_Send, send); LOAD_SYMBOL(MPI_Recv, recv); + LOAD_SYMBOL(MPI_Alltoall, all_to_all); + LOAD_SYMBOL(MPI_Sendrecv, sendrecv); LOAD_SYMBOL(MPI_Type_contiguous, mpi_type_contiguous); LOAD_SYMBOL(MPI_Type_commit, mpi_type_commit); LOAD_SYMBOL(MPI_Op_create, mpi_op_create); @@ -156,6 +159,7 @@ struct MPIWrapper { LOAD_SYMBOL(ompi_mpi_float, mpi_float_); LOAD_SYMBOL(ompi_mpi_double, mpi_double_); LOAD_SYMBOL(ompi_mpi_c_complex, mpi_complex_); + LOAD_SYMBOL(ompi_mpi_byte, mpi_byte_); } bool is_available() { @@ -237,6 +241,10 @@ struct MPIWrapper { } } + MPI_Datatype mpi_byte() { + return mpi_byte_; + } + MPI_Op op_sum(const array& arr) { switch (arr.dtype()) { case float16: @@ -294,6 +302,27 @@ struct MPIWrapper { int (*comm_free)(MPI_Comm*); int (*send)(const void*, int, MPI_Datatype, int, int, MPI_Comm); int (*recv)(void*, int, MPI_Datatype, int, int, MPI_Comm, MPI_Status*); + int (*all_to_all)( + const void*, + int, + MPI_Datatype, + void*, + int, + MPI_Datatype, + MPI_Comm); + int (*sendrecv)( + const void*, + int, + MPI_Datatype, + int, + int, + void*, + int, + MPI_Datatype, + int, + int, + MPI_Comm, + MPI_Status*); // Objects MPI_Comm comm_world_; @@ -326,6 +355,7 @@ struct MPIWrapper { MPI_Datatype mpi_complex_; MPI_Datatype mpi_float16_; MPI_Datatype mpi_bfloat16_; + MPI_Datatype mpi_byte_; private: bool initialized_; @@ -476,6 +506,97 @@ class MPIGroup : public GroupImpl { throw std::runtime_error("[mpi] sum_scatter not yet implemented."); } + void all_to_all(const array& input, array& output, Stream stream) override { + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(input); + encoder.set_output_array(output); + int count = input.size() / size(); + encoder.dispatch( + mpi().all_to_all, + input.data(), + count, + mpi().datatype(input), + output.data(), + count, + mpi().datatype(output), + comm_); + } + + 
void blocking_send(const array& input, int dst) override { + mpi().send( + input.data(), input.size(), mpi().datatype(input), dst, 0, comm_); + } + + void blocking_recv(array& out, int src) override { + MPI_Status status; + mpi().recv( + out.data(), + out.size(), + mpi().datatype(out), + src, + MPI_ANY_TAG, + comm_, + &status); + } + + void blocking_all_to_all(const array& input, array& output) override { + int count = input.size() / size(); + mpi().all_to_all( + input.data(), + count, + mpi().datatype(input), + output.data(), + count, + mpi().datatype(output), + comm_); + } + + void blocking_sendrecv( + const array& send_buf, + size_t send_nbytes, + array& recv_buf, + size_t recv_nbytes, + int peer, + detail::ExchangeTag tag) override { + if (send_nbytes == 0 && recv_nbytes == 0) + return; + + int mpi_tag = static_cast(tag); + auto* sptr = static_cast(send_buf.data()); + auto* rptr = static_cast(recv_buf.data()); + + // INT_MAX asymmetric chunk loop for large messages + size_t so = 0, ro = 0; + while (so < send_nbytes || ro < recv_nbytes) { + int sc = (so < send_nbytes) + ? static_cast( + std::min(static_cast(INT_MAX), send_nbytes - so)) + : 0; + int rc = (ro < recv_nbytes) + ? static_cast( + std::min(static_cast(INT_MAX), recv_nbytes - ro)) + : 0; + + MPI_Status st; + mpi().sendrecv( + sc > 0 ? sptr + so : sptr, + sc, + mpi().mpi_byte(), + peer, + mpi_tag, + rc > 0 ? 
rptr + ro : rptr, + rc, + mpi().mpi_byte(), + peer, + mpi_tag, + comm_, + &st); + + so += sc; + ro += rc; + } + } + private: MPI_Comm comm_; bool global_; diff --git a/mlx/distributed/nccl/nccl.cpp b/mlx/distributed/nccl/nccl.cpp index d8244bf94f..5b3ed53c9a 100644 --- a/mlx/distributed/nccl/nccl.cpp +++ b/mlx/distributed/nccl/nccl.cpp @@ -356,6 +356,10 @@ class NCCLGroup : public GroupImpl { }); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error("[nccl] all_to_all not yet implemented."); + } + template void all_reduce_impl( const array& input, diff --git a/mlx/distributed/ops.cpp b/mlx/distributed/ops.cpp index 1762f0e6bc..3338c04697 100644 --- a/mlx/distributed/ops.cpp +++ b/mlx/distributed/ops.cpp @@ -1,5 +1,8 @@ // Copyright © 2024 Apple Inc. +#include +#include +#include #include #include "mlx/backend/cuda/cuda.h" @@ -20,6 +23,42 @@ Group to_group(std::optional group) { } } +// Auto mode: resolve to CPU (Phase 5 policy removed for PR1 simplification) +MoeBackend resolve_auto_backend( + int /* N */, + int /* top_k */, + int /* D */, + int /* elem_size */) { + return MoeBackend::Cpu; +} + +MoeBackend resolve_backend_str(const std::string& backend) { + if (backend == "auto") + return MoeBackend::Auto; + if (backend == "cpu") + return MoeBackend::Cpu; + if (backend == "metal") + return MoeBackend::Metal; + throw std::invalid_argument( + "[moe] invalid backend '" + backend + "', expected auto/cpu/metal"); +} + +// GPU stream selection with graceful fallback to CPU +Stream resolve_moe_stream(MoeBackend& backend, StreamOrDevice s) { + if (backend == MoeBackend::Metal) { + try { + return to_stream(s, Device::gpu); + } catch (...) { + static std::once_flag warn_once; + std::call_once(warn_once, []() { + std::cerr << "[MoE EP] GPU stream unavailable. 
Falling back to CPU.\n"; + }); + backend = MoeBackend::Cpu; + } + } + return to_stream(s, Device::cpu); +} + } // namespace array all_sum( @@ -183,4 +222,233 @@ array sum_scatter( std::make_shared(stream, group, ReduceScatter::Sum), {x}); } + +array all_to_all( + const array& x, + std::optional group_ /* = std::nullopt */, + StreamOrDevice s /* = {} */) { + auto group = to_group(group_); + if (group.size() == 1) { + return x; + } + if (x.ndim() < 1) { + throw std::invalid_argument("[all_to_all] Input must be at least 1-D."); + } + if (x.shape(0) % group.size() != 0) { + std::ostringstream msg; + msg << "[all_to_all] Invalid shape=" << x.shape() << " for a group of size " + << group.size() + << ". The first dimension (axis 0) must be divisible by the group size."; + throw std::invalid_argument(msg.str()); + } + auto stream = detail::communication_stream(group, s); + return array( + x.shape(), x.dtype(), std::make_shared(stream, group), {x}); +} + +std::pair moe_dispatch_exchange( + const array& tokens, + const array& expert_indices, + int num_experts, + int capacity, + std::optional group_, + bool deterministic, + const std::string& backend, + StreamOrDevice s) { + auto group = to_group(group_); + + // Validate inputs + if (tokens.ndim() != 2) { + throw std::invalid_argument( + "[moe_dispatch_exchange] tokens must be 2-D [N, D]."); + } + if (expert_indices.ndim() != 2) { + throw std::invalid_argument( + "[moe_dispatch_exchange] expert_indices must be 2-D [N, top_k]."); + } + if (tokens.shape(0) != expert_indices.shape(0)) { + throw std::invalid_argument( + "[moe_dispatch_exchange] tokens and expert_indices must have same N."); + } + if (expert_indices.dtype() != int32) { + throw std::invalid_argument( + "[moe_dispatch_exchange] expert_indices must have dtype int32."); + } + if (num_experts % group.size() != 0) { + throw std::invalid_argument( + "[moe_dispatch_exchange] num_experts must be divisible by group size."); + } + if (capacity <= 0) { + throw 
std::invalid_argument( + "[moe_dispatch_exchange] capacity must be positive."); + } + + int world_size = group.size(); + int experts_per_device = num_experts / world_size; + + if (experts_per_device > 65535 || capacity > 65535) { + throw std::invalid_argument( + "[moe_dispatch_exchange] meta32 overflow: " + "experts_per_device=" + + std::to_string(experts_per_device) + " capacity=" + + std::to_string(capacity) + " — both must be <= 65535 for v3 protocol"); + } + + int cap_total = world_size * capacity; + int D = tokens.shape(1); + int N = tokens.shape(0); + int top_k = expert_indices.shape(1); + + // Output shapes: + // dispatched: [experts_per_device, cap_total, D] + // route_indices: [N, top_k] int32 + auto dispatched_shape = Shape{experts_per_device, cap_total, D}; + auto route_indices_shape = Shape{N, top_k}; + + auto moe_backend = resolve_backend_str(backend); + + // Resolve Auto → CPU + if (moe_backend == MoeBackend::Auto) { + moe_backend = MoeBackend::Cpu; + } + + // ws > 2: Metal not yet optimized, fall back to CPU + if (moe_backend == MoeBackend::Metal && world_size > 2) { + static std::once_flag warned; + std::call_once(warned, []() { + std::cerr + << "[MoE EP] Metal backend not yet optimized for world_size > 2, " + << "falling back to CPU path." 
<< std::endl; + }); + moe_backend = MoeBackend::Cpu; + } + + auto stream = resolve_moe_stream(moe_backend, s); + + auto outputs = array::make_arrays( + {std::move(dispatched_shape), std::move(route_indices_shape)}, + {tokens.dtype(), int32}, + std::make_shared( + stream, group, num_experts, capacity, deterministic, moe_backend), + {tokens, expert_indices}); + + return {outputs[0], outputs[1]}; +} + +array moe_combine_exchange( + const array& expert_outputs, + const array& route_indices, + const array& weights, + const array& original_tokens, + int num_experts, + int capacity, + std::optional group_, + bool deterministic, + const std::string& backend, + StreamOrDevice s) { + auto group = to_group(group_); + + if (expert_outputs.ndim() != 3) { + throw std::invalid_argument( + "[moe_combine_exchange] expert_outputs must be 3-D [E_local, cap_total, D]."); + } + if (route_indices.ndim() != 2) { + throw std::invalid_argument( + "[moe_combine_exchange] route_indices must be 2-D [N, top_k]."); + } + if (weights.ndim() != 2) { + throw std::invalid_argument( + "[moe_combine_exchange] weights must be 2-D [N, top_k]."); + } + if (original_tokens.ndim() != 2) { + throw std::invalid_argument( + "[moe_combine_exchange] original_tokens must be 2-D [N, D]."); + } + if (route_indices.dtype() != int32) { + throw std::invalid_argument( + "[moe_combine_exchange] route_indices must have dtype int32."); + } + if (weights.dtype() != float32) { + throw std::invalid_argument( + "[moe_combine_exchange] weights must have dtype float32."); + } + + // Shape compatibility checks + if (route_indices.shape(0) != weights.shape(0) || + route_indices.shape(0) != original_tokens.shape(0)) { + std::ostringstream msg; + msg << "[moe_combine_exchange] N dimension mismatch: " + << "route_indices.shape(0)=" << route_indices.shape(0) + << " weights.shape(0)=" << weights.shape(0) + << " original_tokens.shape(0)=" << original_tokens.shape(0); + throw std::invalid_argument(msg.str()); + } + if 
(route_indices.shape(1) != weights.shape(1)) { + std::ostringstream msg; + msg << "[moe_combine_exchange] top_k dimension mismatch: " + << "route_indices.shape(1)=" << route_indices.shape(1) + << " weights.shape(1)=" << weights.shape(1); + throw std::invalid_argument(msg.str()); + } + if (original_tokens.shape(1) != expert_outputs.shape(2)) { + std::ostringstream msg; + msg << "[moe_combine_exchange] hidden dim D mismatch: " + << "original_tokens.shape(1)=" << original_tokens.shape(1) + << " expert_outputs.shape(2)=" << expert_outputs.shape(2); + throw std::invalid_argument(msg.str()); + } + if (original_tokens.dtype() != expert_outputs.dtype()) { + std::ostringstream msg; + msg << "[moe_combine_exchange] dtype mismatch: " + << "original_tokens.dtype=" << original_tokens.dtype() + << " expert_outputs.dtype=" << expert_outputs.dtype(); + throw std::invalid_argument(msg.str()); + } + + int world_size = group.size(); + if (expert_outputs.shape(1) != world_size * capacity) { + std::ostringstream msg; + msg << "[moe_combine_exchange] expert_outputs.shape(1)=" + << expert_outputs.shape(1) + << " must equal world_size * capacity = " << world_size << " * " + << capacity << " = " << (world_size * capacity) << "."; + throw std::invalid_argument(msg.str()); + } + if (capacity <= 0) { + throw std::invalid_argument( + "[moe_combine_exchange] capacity must be positive."); + } + + int N = original_tokens.shape(0); + int D = original_tokens.shape(1); + auto combined_shape = Shape{N, D}; + + auto moe_backend = resolve_backend_str(backend); + + // Resolve Auto → CPU + if (moe_backend == MoeBackend::Auto) { + moe_backend = MoeBackend::Cpu; + } + + // ws > 2: Metal not yet optimized, fall back to CPU + if (moe_backend == MoeBackend::Metal && world_size > 2) { + static std::once_flag warned; + std::call_once(warned, []() { + std::cerr + << "[MoE EP] Metal backend not yet optimized for world_size > 2, " + << "falling back to CPU path." 
<< std::endl; + }); + moe_backend = MoeBackend::Cpu; + } + + auto stream = resolve_moe_stream(moe_backend, s); + + return array( + std::move(combined_shape), + expert_outputs.dtype(), + std::make_shared( + stream, group, num_experts, capacity, deterministic, moe_backend), + {expert_outputs, route_indices, weights, original_tokens}); +} + } // namespace mlx::core::distributed diff --git a/mlx/distributed/ops.h b/mlx/distributed/ops.h index e223c5bea2..0ab43675dd 100644 --- a/mlx/distributed/ops.h +++ b/mlx/distributed/ops.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "mlx/api.h" #include "mlx/distributed/distributed.h" @@ -54,4 +55,31 @@ MLX_API array sum_scatter( std::optional group = std::nullopt, StreamOrDevice s = {}); +MLX_API array all_to_all( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +MLX_API std::pair moe_dispatch_exchange( + const array& tokens, + const array& expert_indices, + int num_experts, + int capacity, + std::optional group = std::nullopt, + bool deterministic = true, + const std::string& backend = "cpu", + StreamOrDevice s = {}); + +MLX_API array moe_combine_exchange( + const array& expert_outputs, + const array& route_indices, + const array& weights, + const array& original_tokens, + int num_experts, + int capacity, + std::optional group = std::nullopt, + bool deterministic = true, + const std::string& backend = "cpu", + StreamOrDevice s = {}); + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.cpp b/mlx/distributed/primitives.cpp index 5e8d5327a1..fda9191fb7 100644 --- a/mlx/distributed/primitives.cpp +++ b/mlx/distributed/primitives.cpp @@ -92,4 +92,45 @@ std::pair, std::vector> Send::vmap( return {{send(inputs[0], dst_, group(), stream())}, axes}; } +std::pair, std::vector> AllToAll::vmap( + const std::vector& inputs, + const std::vector& axes) { + return {{all_to_all(inputs[0], group(), stream())}, axes}; +} + +std::vector AllToAll::jvp( + const std::vector& 
primals, + const std::vector& tangents, + const std::vector&) { + return {all_to_all(tangents[0], group(), stream())}; +} + +std::vector AllToAll::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector&, + const std::vector&) { + return {all_to_all(cotangents[0], group(), stream())}; +} + +std::vector MoeDispatchExchange::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + throw std::runtime_error( + "[MoeDispatchExchange] VJP not implemented yet. " + "Use ep_impl='python' for training."); +} + +std::vector MoeCombineExchange::vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) { + throw std::runtime_error( + "[MoeCombineExchange] VJP not implemented yet. " + "Use ep_impl='python' for training."); +} + } // namespace mlx::core::distributed diff --git a/mlx/distributed/primitives.h b/mlx/distributed/primitives.h index 18a0d65f5f..0d8b24d8c0 100644 --- a/mlx/distributed/primitives.h +++ b/mlx/distributed/primitives.h @@ -8,6 +8,8 @@ namespace mlx::core::distributed { +enum class MoeBackend { Auto, Cpu, Metal }; + class DistPrimitive : public Primitive { public: DistPrimitive(Stream stream, Group group) @@ -153,4 +155,125 @@ class ReduceScatter : public DistPrimitive { private: ReduceType reduce_type_; }; + +class AllToAll : public DistPrimitive { + public: + AllToAll(Stream stream, Group group) : DistPrimitive(stream, group) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + 
const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(AllToAll); +}; + +class MoeDispatchExchange : public DistPrimitive { + public: + MoeDispatchExchange( + Stream stream, + Group group, + int num_experts, + int capacity, + bool deterministic, + MoeBackend backend = MoeBackend::Cpu) + : DistPrimitive(stream, group), + num_experts_(num_experts), + capacity_(capacity), + deterministic_(deterministic), + backend_(backend) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(MoeDispatchExchange); + + int num_experts() const { + return num_experts_; + } + int capacity() const { + return capacity_; + } + bool deterministic() const { + return deterministic_; + } + MoeBackend backend() const { + return backend_; + } + + private: + int num_experts_; + int capacity_; + bool deterministic_; + MoeBackend backend_; +}; + +class MoeCombineExchange : public DistPrimitive { + public: + MoeCombineExchange( + Stream stream, + Group group, + int num_experts, + int capacity, + bool deterministic, + MoeBackend backend = MoeBackend::Cpu) + : DistPrimitive(stream, group), + num_experts_(num_experts), + capacity_(capacity), + deterministic_(deterministic), + backend_(backend) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(MoeCombineExchange); + + int num_experts() const { + return num_experts_; + } + int capacity() const { + return capacity_; + } + bool deterministic() const { + return deterministic_; + } + 
MoeBackend backend() const { + return backend_; + } + + private: + int num_experts_; + int capacity_; + bool deterministic_; + MoeBackend backend_; +}; } // namespace mlx::core::distributed diff --git a/mlx/distributed/ring/ring.cpp b/mlx/distributed/ring/ring.cpp index ea40042844..78cb69d937 100644 --- a/mlx/distributed/ring/ring.cpp +++ b/mlx/distributed/ring/ring.cpp @@ -579,6 +579,10 @@ class RingGroup : public GroupImpl { throw std::runtime_error("[ring] sum_scatter not supported."); } + void all_to_all(const array&, array&, Stream) override { + throw std::runtime_error("[ring] all_to_all not supported."); + } + private: template void all_reduce( diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index c2fba58347..e8f28d3b92 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -71,6 +71,13 @@ from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear +from mlx.nn.layers.moe import ( + Expert, + MixtureOfExperts, + TopKRouter, + expert_combine, + expert_dispatch, +) from mlx.nn.layers.normalization import ( BatchNorm, GroupNorm, diff --git a/python/mlx/nn/layers/moe.py b/python/mlx/nn/layers/moe.py new file mode 100644 index 0000000000..cd8ec7ad52 --- /dev/null +++ b/python/mlx/nn/layers/moe.py @@ -0,0 +1,544 @@ +# Copyright © 2026 Apple Inc. 
+ +import math +from typing import NamedTuple, Optional + +import mlx.core as mx +from mlx.nn.layers.activations import silu +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear + + +class DispatchMeta(NamedTuple): + """Metadata for expert dispatch/combine round-trip.""" + + expert_indices: mx.array # [N, top_k] expert assignments + weights: mx.array # [N, top_k] routing weights + positions: mx.array # [N, top_k] slot positions in dispatch buffer + overflow_mask: mx.array # [N, 1] True if token overflowed capacity + num_experts: int + capacity: int + world_size: int + + +class TopKRouter(Module): + """Top-K expert router with load balancing auxiliary loss. + + Routes each token to the top-k experts based on a learned gate. + + Args: + hidden_dim: Input hidden dimension. + num_experts: Total number of experts. + top_k: Number of experts per token. Default: ``2``. + capacity_factor: Capacity scaling factor. Default: ``1.25``. + aux_loss_coeff: Coefficient for load balancing loss. Default: ``0.01``. + """ + + def __init__( + self, + hidden_dim: int, + num_experts: int, + top_k: int = 2, + capacity_factor: float = 1.25, + aux_loss_coeff: float = 0.01, + ): + super().__init__() + if top_k <= 0: + raise ValueError(f"top_k must be positive, got {top_k}") + if top_k > num_experts: + raise ValueError( + f"top_k ({top_k}) must not exceed num_experts ({num_experts})" + ) + self.gate = Linear(hidden_dim, num_experts, bias=False) + self.num_experts = num_experts + self.top_k = top_k + self.capacity_factor = capacity_factor + self.aux_loss_coeff = aux_loss_coeff + + def __call__(self, x: mx.array): + """Route tokens to experts. + + Args: + x: Input tensor of shape ``[N, hidden_dim]``. 
+ + Returns: + Tuple of (weights, indices, aux_loss): + - weights: ``[N, top_k]`` normalized routing weights + - indices: ``[N, top_k]`` expert indices (integers in [0, num_experts)) + - aux_loss: scalar load balancing loss + """ + # x: [N, D] -> logits: [N, num_experts] + logits = self.gate(x) + probs = mx.softmax(logits, axis=-1) + + # Get top-k experts per token + # mx.argpartition gives indices of top-k elements (unordered within top-k) + # We negate probs to get the largest values + neg_probs = -probs + top_k_indices = mx.argpartition(neg_probs, kth=self.top_k - 1, axis=-1)[ + :, : self.top_k + ] + + # Stop gradient on discrete routing decisions + expert_indices = mx.stop_gradient(top_k_indices) + + # Gather the weights for selected experts + # Use take_along_axis to gather probs at the top-k indices + weights = mx.take_along_axis(probs, expert_indices, axis=-1) + + # Normalize weights so they sum to 1 per token + weights = weights / mx.maximum(weights.sum(axis=-1, keepdims=True), 1e-6) + + # Compute auxiliary load balancing loss + aux_loss = self._load_balance_loss(probs, expert_indices) + + return weights, expert_indices, aux_loss + + def _load_balance_loss(self, probs: mx.array, expert_indices: mx.array) -> mx.array: + """GShard load balancing loss: num_experts * sum(f_e * P_e). + + Args: + probs: [N, num_experts] routing probabilities. + expert_indices: [N, top_k] selected expert indices. + + Returns: + Scalar auxiliary loss. 
+ """ + num_tokens = probs.shape[0] + if num_tokens == 0: + return mx.zeros((), dtype=probs.dtype) + + # f_e: fraction of tokens routed to each expert + # Create one-hot and sum across top_k selections + one_hot = mx.zeros_like(probs) + for k in range(self.top_k): + indices_k = expert_indices[:, k] # [N] + rows = mx.arange(num_tokens) + one_hot = one_hot.at[rows, indices_k].add(1.0) + f_e = mx.mean(one_hot, axis=0) / self.top_k # [num_experts] + + # P_e: mean routing probability per expert + P_e = mx.mean(probs, axis=0) # [num_experts] + + # GShard loss + loss = self.aux_loss_coeff * self.num_experts * mx.sum(f_e * P_e) + return loss + + +def _compute_capacity( + num_tokens: int, + top_k: int, + capacity_factor: float, + num_experts: int, +) -> int: + """Compute expert buffer capacity.""" + if num_experts <= 0: + raise ValueError(f"num_experts must be positive, got {num_experts}") + return max(1, math.ceil(num_tokens * top_k * capacity_factor / num_experts)) + + +class Expert(Module): + """Single expert network with SwiGLU activation. + + Args: + hidden_dim: Input/output hidden dimension. + expert_dim: Expert intermediate dimension. + """ + + def __init__(self, hidden_dim: int, expert_dim: int): + super().__init__() + self.w_gate = Linear(hidden_dim, expert_dim, bias=False) + self.w_up = Linear(hidden_dim, expert_dim, bias=False) + self.w_down = Linear(expert_dim, hidden_dim, bias=False) + + def __call__(self, x: mx.array) -> mx.array: + """Forward pass with SwiGLU activation. + + Args: + x: Input tensor of shape ``[..., hidden_dim]``. + + Returns: + Output tensor of shape ``[..., hidden_dim]``. + """ + return self.w_down(silu(self.w_gate(x)) * self.w_up(x)) + + +def expert_dispatch( + tokens: mx.array, + expert_indices: mx.array, + weights: mx.array, + num_experts: int, + capacity_factor: float, + group: Optional["mx.distributed.Group"] = None, +) -> tuple: + """Dispatch tokens to experts across devices. + + Args: + tokens: [N, D] input tokens. 
+ expert_indices: [N, top_k] expert assignments. + weights: [N, top_k] routing weights. + num_experts: Total number of experts across all devices. + capacity_factor: Capacity scaling factor. + group: Distributed group. If None, uses local-only dispatch. + + Returns: + Tuple of (dispatched, meta): + - dispatched: [experts_per_device, capacity, D] expert inputs for this device + - meta: DispatchMeta for use with expert_combine + """ + num_tokens, hidden_dim = tokens.shape + top_k = expert_indices.shape[1] + + world_size = group.size() if group is not None else 1 + if world_size > 1 and num_experts % world_size != 0: + raise ValueError( + f"num_experts ({num_experts}) must be divisible by " + f"world_size ({world_size})" + ) + experts_per_device = num_experts // world_size + capacity = _compute_capacity(num_tokens, top_k, capacity_factor, num_experts) + + # In distributed mode, synchronize capacity across ranks so all ranks + # use the same buffer size for all_to_all. Different ranks may have + # different token counts (e.g. uneven last batch). 
+ if world_size > 1 and group is not None: + cap_arr = mx.array(capacity, dtype=mx.int32) + cap_arr = mx.distributed.all_max(cap_arr, group=group) + mx.eval(cap_arr) + capacity = cap_arr.item() + + total_slots = world_size * experts_per_device * capacity + slots_per_device = experts_per_device * capacity + dispatch_flat = mx.zeros((total_slots, hidden_dim), dtype=tokens.dtype) + + expert_range = mx.arange(num_experts) + overflow_mask = mx.zeros((num_tokens, 1), dtype=mx.bool_) + expert_counts = mx.zeros((num_experts,), dtype=mx.int32) + pos_columns = [] + + zero_idx = mx.array(0, dtype=mx.int32) + neg_one = mx.array(-1, dtype=mx.int32) + zero_tokens = mx.zeros_like(tokens) + + for k in range(top_k): + indices_k = expert_indices[:, k] + + one_hot = indices_k.reshape(-1, 1) == expert_range.reshape(1, -1) + one_hot_int = one_hot.astype(mx.int32) + + # Per-expert cumulative position, offset by counts from prior k columns + cum = mx.cumsum(one_hot_int, axis=0) - 1 + expert_counts.reshape(1, -1) + + pos = mx.take_along_axis(cum, indices_k.reshape(-1, 1), axis=1).squeeze(1) + + valid = pos < capacity + overflow_mask = overflow_mask | (~valid).reshape(-1, 1) + + pos_columns.append(mx.where(valid, pos, neg_one)) + + d = indices_k // experts_per_device + le = indices_k % experts_per_device + flat_idx = d * slots_per_device + le * capacity + pos + flat_idx = mx.where(valid, flat_idx, zero_idx).astype(mx.int32) + scatter_vals = mx.where(valid.reshape(-1, 1), tokens, zero_tokens) + dispatch_flat = dispatch_flat.at[flat_idx].add(scatter_vals) + + expert_counts = expert_counts + one_hot_int.sum(axis=0) + + positions = mx.stack(pos_columns, axis=1) + + dispatch_buffer = dispatch_flat.reshape( + world_size, experts_per_device, capacity, hidden_dim + ) + + meta = DispatchMeta( + expert_indices=expert_indices, + weights=weights, + positions=positions, + overflow_mask=overflow_mask, + num_experts=num_experts, + capacity=capacity, + world_size=world_size, + ) + + # All-to-all exchange 
if distributed + if world_size > 1 and group is not None: + # Materialize scatter graph before distributed exchange + mx.eval(dispatch_buffer) + flat = dispatch_buffer.reshape(world_size, -1) + exchanged = mx.distributed.all_to_all(flat, group=group) + dispatched = exchanged.reshape( + world_size, experts_per_device, capacity, hidden_dim + ) + # Each device processes experts_per_device experts, + # data from all devices combined + # Reshape: [world_size, experts_per_device, capacity, D] -> [experts_per_device, world_size * capacity, D] + dispatched = mx.transpose(dispatched, axes=(1, 0, 2, 3)).reshape( + experts_per_device, world_size * capacity, hidden_dim + ) + else: + # Local only: [1, experts_per_device, capacity, D] -> [experts_per_device, capacity, D] + dispatched = dispatch_buffer.squeeze(0) + + return dispatched, meta + + +def expert_combine( + expert_outputs: mx.array, + meta: DispatchMeta, + original_tokens: mx.array, + group: Optional["mx.distributed.Group"] = None, +) -> mx.array: + """Combine expert outputs back to token order. + + Args: + expert_outputs: [experts_per_device, capacity_total, D] expert output tokens. + meta: DispatchMeta from expert_dispatch. + original_tokens: [N, D] original input tokens for residual. + group: Distributed group. + + Returns: + [N, D] combined output tokens. 
+ """ + world_size = meta.world_size + experts_per_device = meta.num_experts // world_size + capacity = meta.capacity + hidden_dim = original_tokens.shape[-1] + num_tokens = original_tokens.shape[0] + + if world_size > 1 and group is not None: + # Reshape back for all_to_all: [experts_per_device, world_size * capacity, D] + # -> [world_size, experts_per_device, capacity, D] + reshaped = expert_outputs.reshape( + experts_per_device, world_size, capacity, hidden_dim + ) + reshaped = mx.transpose(reshaped, axes=(1, 0, 2, 3)) + flat = reshaped.reshape(world_size, -1) + exchanged = mx.distributed.all_to_all(flat, group=group) + result_buffer = exchanged.reshape( + world_size, experts_per_device, capacity, hidden_dim + ) + else: + result_buffer = expert_outputs.reshape( + 1, experts_per_device, capacity, hidden_dim + ) + + result_flat = result_buffer.reshape(-1, hidden_dim) + combined = mx.zeros_like(original_tokens) + top_k = meta.expert_indices.shape[1] + slots_per_device = experts_per_device * capacity + zero_idx = mx.array(0, dtype=mx.int32) + + for k in range(top_k): + indices_k = meta.expert_indices[:, k] + positions_k = meta.positions[:, k] + weights_k = meta.weights[:, k] + device_idx = indices_k // experts_per_device + local_expert = indices_k % experts_per_device + + flat_idx = device_idx * slots_per_device + local_expert * capacity + positions_k + valid = positions_k >= 0 + flat_idx = mx.where(valid, flat_idx, zero_idx).astype(mx.int32) + + gathered = result_flat[flat_idx] + safe_gathered = mx.where( + valid.reshape(-1, 1), gathered, mx.zeros_like(gathered) + ) + combined = combined + weights_k.reshape(-1, 1) * safe_gathered + + has_valid_route = (meta.positions >= 0).any(axis=1, keepdims=True) + combined = mx.where(has_valid_route, combined, original_tokens) + + return combined + + +class MixtureOfExperts(Module): + """Mixture of Experts layer with Expert Parallelism support. + + Args: + hidden_dim: Input/output hidden dimension. 
+ expert_dim: Expert intermediate dimension. + num_experts: Total number of experts. + top_k: Number of experts per token. Default: ``2``. + capacity_factor: Capacity scaling factor. Default: ``1.25``. + aux_loss_coeff: Load balance loss coefficient. Default: ``0.01``. + ep_impl: Expert parallelism implementation to use. One of ``"auto"``, + ``"python"``, or ``"cpp"``. ``"auto"`` uses the Python vectorized + path (safe for training; C++ VJP not yet available). ``"cpp"`` + uses the fused C++ primitive (inference-only, no gradient support). + ``"python"`` always uses the Python vectorized path. Default: + ``"auto"``. + ep_backend: Backend for the C++ MoE exchange. One of ``"auto"``, + ``"cpu"``, or ``"metal"``. ``"auto"`` selects based on workload + size. Only used when ``ep_impl`` is ``"cpp"`` or ``"auto"`` with + C++ path active. Default: ``"auto"``. + + Limitations: + - Inference only: VJP/backward not implemented for C++ fused + dispatch/combine primitives. + - ws=2 optimized: world_size > 2 automatically downgrades Metal to + CPU backend and uses the fixed all_to_all CPU path (one-time + warning emitted). + - Expert parallelism is opt-in via ``ep_impl="cpp"`` parameter. 
+ """ + + def __init__( + self, + hidden_dim: int, + expert_dim: int, + num_experts: int, + top_k: int = 2, + capacity_factor: float = 1.25, + aux_loss_coeff: float = 0.01, + ep_impl: str = "auto", + ep_backend: str = "auto", + ): + super().__init__() + + if num_experts <= 0: + raise ValueError(f"num_experts must be positive, got {num_experts}") + if top_k <= 0: + raise ValueError(f"top_k must be positive, got {top_k}") + if top_k > num_experts: + raise ValueError( + f"top_k ({top_k}) must not exceed num_experts ({num_experts})" + ) + + # Determine distributed context + self._world_size = 1 + self._group = None + try: + group = mx.distributed.init(strict=False) + if group.size() > 1: + # Probe all_to_all support; some backends (ring, NCCL) + # do not implement it and would crash on every forward pass. + try: + test = mx.distributed.all_to_all( + mx.zeros((group.size(),)), group=group + ) + mx.eval(test) + self._world_size = group.size() + self._group = group + except RuntimeError: + # Backend doesn't support all_to_all, fall back to local-only + pass + except Exception: + pass + + if num_experts % self._world_size != 0: + raise ValueError( + f"num_experts ({num_experts}) must be divisible by " + f"world_size ({self._world_size})" + ) + + self.hidden_dim = hidden_dim + self.expert_dim = expert_dim + self.num_experts = num_experts + self.top_k = top_k + self.capacity_factor = capacity_factor + self.ep_impl = ep_impl + self.ep_backend = ep_backend + + # Router + self.router = TopKRouter( + hidden_dim, num_experts, top_k, capacity_factor, aux_loss_coeff + ) + + # Local experts for this device + experts_per_device = num_experts // self._world_size + self.experts = [ + Expert(hidden_dim, expert_dim) for _ in range(experts_per_device) + ] + + def __call__(self, x: mx.array): + """Forward pass. + + Args: + x: Input tensor of shape ``[N, hidden_dim]``. 
+ + Returns: + Tuple of (output, aux_loss): + - output: ``[N, hidden_dim]`` combined expert outputs + - aux_loss: scalar load balancing loss + """ + # Route + weights, expert_indices, aux_loss = self.router(x) + + # Determine implementation to use + # auto: use Python path (safe for training; C++ VJP not yet implemented) + # cpp: use C++ primitive (inference-only, no grad support) + # python: always use Python vectorized path + use_cpp = ( + self.ep_impl == "cpp" + and self._group is not None + and hasattr(mx.distributed, "moe_dispatch_exchange") + ) + + if use_cpp: + # C++ fused primitive path (inference-only) + capacity = _compute_capacity( + x.shape[0], self.router.top_k, self.capacity_factor, self.num_experts + ) + # Synchronize capacity across ranks + if self._world_size > 1: + cap_arr = mx.array(capacity, dtype=mx.int32) + cap_arr = mx.distributed.all_max(cap_arr, group=self._group) + mx.eval(cap_arr) + capacity = cap_arr.item() + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + x, + expert_indices.astype(mx.int32), + num_experts=self.num_experts, + capacity=capacity, + group=self._group, + backend=self.ep_backend, + ) + route_idx = mx.stop_gradient(route_idx) + expert_out = self._run_local_experts(dispatched) + weights_f32 = weights.astype(mx.float32) + output = mx.distributed.moe_combine_exchange( + expert_out, + route_idx, + weights_f32, + x, + num_experts=self.num_experts, + capacity=capacity, + group=self._group, + backend=self.ep_backend, + ) + else: + # Python vectorized path (supports grad) + dispatched, meta = expert_dispatch( + x, + expert_indices, + weights, + self.num_experts, + self.capacity_factor, + group=self._group, + ) + expert_out = self._run_local_experts(dispatched) + output = expert_combine( + expert_out, + meta, + x, + group=self._group, + ) + + return output, aux_loss + + def _run_local_experts(self, dispatched: mx.array) -> mx.array: + """Run local experts on dispatched tokens. 
+ + Args: + dispatched: [experts_per_device, capacity_total, D] dispatched inputs. + + Returns: + [experts_per_device, capacity_total, D] expert outputs. + """ + outputs = [] + for i, expert in enumerate(self.experts): + expert_input = dispatched[i] # [capacity_total, D] + expert_output = expert(expert_input) # [capacity_total, D] + outputs.append(expert_output) + return mx.stack(outputs, axis=0) diff --git a/python/src/distributed.cpp b/python/src/distributed.cpp index 9f4a7cb59e..c5588fd339 100644 --- a/python/src/distributed.cpp +++ b/python/src/distributed.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -349,4 +350,170 @@ void init_distributed(nb::module_& parent_module) { Returns: array: The output array with shape ``[x.shape[0] // group.size(), *x.shape[1:]]``. )pbdoc"); + + m.def( + "all_to_all", + [](const ScalarOrArray& x, + std::optional group, + mx::StreamOrDevice s) { + return mx::distributed::all_to_all(to_array(x), group, s); + }, + "x"_a, + nb::kw_only(), + "group"_a = nb::none(), + "stream"_a = nb::none(), + nb::sig( + "def all_to_all(x: array, *, group: Optional[Group] = None, " + "stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + All-to-all exchange of data between processes. + + Each process splits its input along the first axis into ``group.size()`` + chunks and sends chunk *i* to process *i*. All processes receive one chunk + from every other process and concatenate them in rank order. The output + has the same shape as the input. + + ``x.shape[0]`` must be divisible by the group size. + + Args: + x (array): Input array. + group (Group): The group of processes that will participate in the + exchange. If set to ``None`` the global group is used. Default: + ``None``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The result of the all-to-all exchange. 
+ )pbdoc"); + + m.def( + "moe_dispatch_exchange", + [](const ScalarOrArray& tokens, + const ScalarOrArray& expert_indices, + int num_experts, + int capacity, + std::optional group, + bool deterministic, + const std::string& backend, + mx::StreamOrDevice s) { + auto [dispatched, route_idx] = mx::distributed::moe_dispatch_exchange( + to_array(tokens), + to_array(expert_indices), + num_experts, + capacity, + group, + deterministic, + backend, + s); + return nb::make_tuple(dispatched, route_idx); + }, + "tokens"_a, + "expert_indices"_a, + nb::kw_only(), + "num_experts"_a, + "capacity"_a, + "group"_a = nb::none(), + "deterministic"_a = true, + "backend"_a = "cpu", + "stream"_a = nb::none(), + nb::sig( + "def moe_dispatch_exchange(tokens: array, expert_indices: array, " + "*, num_experts: int, capacity: int, group: Optional[Group] = None, " + "deterministic: bool = True, backend: str = \"cpu\", " + "stream: Union[None, Stream, Device] = None) -> tuple[array, array]"), + R"pbdoc( + Fused MoE dispatch and all-to-all exchange. + + Scatters tokens into the dispatch buffer and performs all-to-all + exchange in a single fused primitive. + + Args: + tokens (array): Input tokens of shape ``[N, D]``. + expert_indices (array): Expert assignments of shape ``[N, top_k]`` (int32). + num_experts (int): Total number of experts across all devices. + capacity (int): Per-expert capacity (max tokens per expert). + group (Group, optional): Distributed group. Default: global group. + deterministic (bool, optional): Use token-order based slot assignment. + Default: ``True``. + backend (str, optional): Compute backend for the fused primitive. + ``"auto"`` selects CPU for small workloads (N <= 64) and Metal + for large workloads (N >= 320). ``"cpu"`` forces the CPU path. + ``"metal"`` forces the Metal GPU path. Default: ``"cpu"``. + stream (Stream, optional): Stream or device. Default: ``None``. 
+ + Returns: + tuple[array, array]: ``(dispatched, route_indices)`` where + - ``dispatched``: ``[experts_per_device, world_size * capacity, D]`` + - ``route_indices``: ``[N, top_k]`` int32, -1 means overflow + )pbdoc"); + + m.def( + "moe_combine_exchange", + [](const ScalarOrArray& expert_outputs, + const ScalarOrArray& route_indices, + const ScalarOrArray& weights, + const ScalarOrArray& original_tokens, + int num_experts, + int capacity, + std::optional group, + bool deterministic, + const std::string& backend, + mx::StreamOrDevice s) { + return mx::distributed::moe_combine_exchange( + to_array(expert_outputs), + to_array(route_indices), + to_array(weights), + to_array(original_tokens), + num_experts, + capacity, + group, + deterministic, + backend, + s); + }, + "expert_outputs"_a, + "route_indices"_a, + "weights"_a, + "original_tokens"_a, + nb::kw_only(), + "num_experts"_a, + "capacity"_a, + "group"_a = nb::none(), + "deterministic"_a = true, + "backend"_a = "cpu", + "stream"_a = nb::none(), + nb::sig( + "def moe_combine_exchange(expert_outputs: array, route_indices: array, " + "weights: array, original_tokens: array, " + "*, num_experts: int, capacity: int, group: Optional[Group] = None, " + "deterministic: bool = True, backend: str = \"cpu\", " + "stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Fused MoE all-to-all exchange and combine. + + Performs all-to-all exchange and gathers expert outputs back to token + order with weighted summation in a single fused primitive. + + Args: + expert_outputs (array): Expert outputs of shape + ``[experts_per_device, world_size * capacity, D]``. + route_indices (array): Route indices from ``moe_dispatch_exchange``, + shape ``[N, top_k]`` int32. + weights (array): Routing weights, shape ``[N, top_k]``. + original_tokens (array): Original input tokens ``[N, D]`` used as + residual fallback for fully-overflowed tokens. + num_experts (int): Total number of experts across all devices. 
+ capacity (int): Per-expert capacity. + group (Group, optional): Distributed group. Default: global group. + deterministic (bool, optional): Default: ``True``. + backend (str, optional): Compute backend for the fused primitive. + ``"auto"`` selects CPU for small workloads (N <= 64) and Metal + for large workloads (N >= 320). ``"cpu"`` forces the CPU path. + ``"metal"`` forces the Metal GPU path. Default: ``"cpu"``. + stream (Stream, optional): Stream or device. Default: ``None``. + + Returns: + array: Combined tokens of shape ``[N, D]``. + )pbdoc"); } diff --git a/python/tests/jaccl_test_distributed.py b/python/tests/jaccl_test_distributed.py new file mode 100644 index 0000000000..1f4512f077 --- /dev/null +++ b/python/tests/jaccl_test_distributed.py @@ -0,0 +1,158 @@ +# Copyright © 2026 Apple Inc. + +import mlx.core as mx +import mlx_distributed_tests +import mlx_tests + + +class TestJACCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase): + @classmethod + def setUpClass(cls): + _ = mx.distributed.init(strict=True, backend="jaccl") + cls.atol = 1e-6 + cls.rtol = 1e-4 + + def test_groups(self): + world = mx.distributed.init() + self.assertEqual(world.size(), 2) + self.assertTrue(0 <= world.rank() < 2) + + world2 = mx.distributed.init() + self.assertEqual(world.size(), world2.size()) + self.assertEqual(world.rank(), world2.rank()) + + # ------------------------------------------------------------------ + # MoE Expert Parallelism C++ primitive tests (2-rank JACCL) + # ------------------------------------------------------------------ + + def test_moe_ep_cpp_roundtrip_2rank(self): + """Dispatch + combine roundtrip numerical correctness over 2 ranks.""" + world = mx.distributed.init() + if world.size() != 2: + self.skipTest("requires 2 ranks") + if not hasattr(mx.distributed, "moe_dispatch_exchange"): + self.skipTest("moe_dispatch_exchange not available") + + N = 16 + D = 64 + num_experts = 4 + top_k = 2 + capacity = 8 + experts_per_device = num_experts // 
world.size() # 2 + + for dtype in [mx.float32, mx.float16, mx.bfloat16]: + mx.random.seed(42 + world.rank()) + tokens = mx.random.normal((N, D)).astype(dtype) + + # Random expert indices in [0, num_experts) + expert_indices = mx.random.randint(0, num_experts, shape=(N, top_k)).astype( + mx.int32 + ) + + # Dispatch + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=num_experts, + capacity=capacity, + group=world, + ) + mx.eval(dispatched, route_idx) + + # Shape: [experts_per_device, world_size * capacity, D] + cap_total = world.size() * capacity + self.assertEqual( + dispatched.shape, + (experts_per_device, cap_total, D), + f"dtype={dtype}: dispatched shape mismatch", + ) + self.assertEqual(dispatched.dtype, dtype) + self.assertEqual(route_idx.shape, (N, top_k)) + self.assertEqual(route_idx.dtype, mx.int32) + + # Identity expert: expert_outputs = dispatched + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + original_tokens = mx.zeros((N, D), dtype=dtype) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + original_tokens, + num_experts=num_experts, + capacity=capacity, + group=world, + ) + mx.eval(combined) + + self.assertEqual(combined.shape, (N, D), f"dtype={dtype}: combined shape") + self.assertEqual(combined.dtype, dtype, f"dtype={dtype}: combined dtype") + self.assertTrue( + mx.all(mx.isfinite(combined)).item(), + f"dtype={dtype}: combined contains non-finite values", + ) + + def test_moe_ep_cpp_asymmetric_traffic_2rank(self): + """Asymmetric traffic completes without deadlock.""" + world = mx.distributed.init() + if world.size() != 2: + self.skipTest("requires 2 ranks") + if not hasattr(mx.distributed, "moe_dispatch_exchange"): + self.skipTest("moe_dispatch_exchange not available") + + N = 32 + D = 64 + num_experts = 4 + capacity = 16 + top_k = 2 + experts_per_device = num_experts // world.size() # 2 + + mx.random.seed(100 + world.rank()) + tokens = 
mx.random.normal((N, D)) + + # Rank 0: all tokens → experts 2,3 (owned by rank 1) + # Rank 1: all tokens → experts 0,1 (owned by rank 0) + if world.rank() == 0: + expert_indices = mx.random.randint(2, 4, shape=(N, top_k)).astype(mx.int32) + else: + expert_indices = mx.random.randint(0, 2, shape=(N, top_k)).astype(mx.int32) + + # Dispatch + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=num_experts, + capacity=capacity, + group=world, + ) + mx.eval(dispatched, route_idx) + + cap_total = world.size() * capacity + self.assertEqual( + dispatched.shape, + (experts_per_device, cap_total, D), + ) + + # Combine + weights = mx.ones((N, top_k)) / top_k + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=num_experts, + capacity=capacity, + group=world, + ) + mx.eval(combined) + + self.assertIsNotNone(combined) + self.assertEqual(combined.shape, (N, D)) + self.assertTrue( + mx.all(mx.isfinite(combined)).item(), + "combined contains non-finite values after asymmetric exchange", + ) + + +if __name__ == "__main__": + mlx_tests.MLXTestRunner() diff --git a/python/tests/mlx_distributed_tests.py b/python/tests/mlx_distributed_tests.py index 644da793be..6eed108344 100644 --- a/python/tests/mlx_distributed_tests.py +++ b/python/tests/mlx_distributed_tests.py @@ -1,9 +1,16 @@ # Copyright © 2025 Apple Inc. 
+import unittest + import mlx.core as mx import mlx.nn as nn import mlx_tests from mlx.nn.layers.distributed import shard_inplace, shard_linear +from mlx.nn.layers.moe import ( + MixtureOfExperts, + expert_combine, + expert_dispatch, +) from mlx.nn.utils import average_gradients @@ -317,6 +324,140 @@ def sharding(path, weight): y2 = smod(x) self.assertTrue(mx.allclose(y1, y2, atol=1e-6, rtol=1e-4)) + def _skip_if_all_to_all_unsupported(self, group): + try: + test = mx.distributed.all_to_all(mx.zeros((group.size(),)), group=group) + mx.eval(test) + except RuntimeError: + self.skipTest("all_to_all not supported on this backend") + + def test_all_to_all(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + rank = group.rank() + + # Test multiple dtypes + dtypes = [mx.float32, mx.float16, mx.bfloat16, mx.int32] + + for dt in dtypes: + # Create a [world_size * 4, 8] tensor with rank-specific values + rows = world_size * 4 + cols = 8 + x = (mx.ones((rows, cols), dtype=dt) * (rank * 100)) + mx.broadcast_to( + mx.arange(rows).reshape(-1, 1), (rows, cols) + ).astype(dt) + + y = mx.distributed.all_to_all(x, group=group) + mx.eval(y) + + # Output shape should equal input shape + self.assertEqual(y.shape, x.shape) + + if world_size == 1: + # For single process: all_to_all is identity + self.assertTrue(mx.array_equal(y, x).item()) + else: + # For multi-process: verify the all-to-all permutation + # Each rank's output chunk i should come from rank i's input chunk rank + chunk_size = rows // world_size + for src_rank in range(world_size): + out_chunk = y[src_rank * chunk_size : (src_rank + 1) * chunk_size] + # This chunk should be what src_rank sent to us (our rank-th chunk of src_rank's input) + expected_vals = ( + mx.ones((chunk_size, cols), dtype=dt) * (src_rank * 100) + ) + mx.broadcast_to( + mx.arange(rank * chunk_size, (rank + 1) * chunk_size).reshape( + -1, 1 + ), + (chunk_size, cols), + ).astype( + dt + ) + 
self.assertTrue(mx.array_equal(out_chunk, expected_vals).item()) + + def test_all_to_all_sizes(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Test various input sizes + sizes = [ + (world_size,), # minimal 1D + (world_size * 256, 64), # medium 2D + (world_size * 2, 3, 4, 5), # multi-dimensional + ] + + for sh in sizes: + x = mx.ones(sh, dtype=mx.float32) + y = mx.distributed.all_to_all(x, group=group) + mx.eval(y) + + self.assertEqual(y.shape, x.shape) + if world_size == 1: + self.assertTrue(mx.array_equal(y, x).item()) + + def test_all_to_all_non_contiguous(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Create a non-contiguous input via transpose then slice + base = mx.random.normal((8, world_size * 4)) + x_non_contig = base.T # shape (world_size * 4, 8), non-contiguous + + # Create contiguous copy + x_contig = mx.array(x_non_contig) + + y1 = mx.distributed.all_to_all(x_non_contig, group=group) + y2 = mx.distributed.all_to_all(x_contig, group=group) + mx.eval(y1, y2) + + self.assertTrue(mx.allclose(y1, y2).item()) + + def test_all_to_all_vjp(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + x = mx.random.normal((world_size * 4, 8)) + mx.eval(x) + + # Test mx.grad + grad_fn = mx.grad(lambda x: mx.distributed.all_to_all(x, group=group).sum()) + g = grad_fn(x) + mx.eval(g) + + if world_size == 1: + # For single process: gradient of identity + sum is all ones + self.assertTrue(mx.allclose(g, mx.ones_like(g)).item()) + + # Test mx.value_and_grad + val_grad_fn = mx.value_and_grad( + lambda x: mx.distributed.all_to_all(x, group=group).sum() + ) + val, g2 = val_grad_fn(x) + mx.eval(val, g2) + + self.assertEqual(g2.shape, x.shape) + + @unittest.skipIf(mx.distributed.init().size() == 1, "requires world_size > 1") + def test_all_to_all_shape_validation(self): + group 
= mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + world_size = group.size() + + # Test that scalar input raises an exception + scalar = mx.array(1.0) + with self.assertRaises(Exception): + mx.eval(mx.distributed.all_to_all(scalar, group=group)) + + # Test that x.shape[0] % world_size != 0 raises (only meaningful for world_size > 1) + if world_size > 1: + bad = mx.ones((world_size * 4 + 1, 8)) + with self.assertRaises(Exception): + mx.eval(mx.distributed.all_to_all(bad, group=group)) + def test_all_gather(self): world = mx.distributed.init() dtypes = [ @@ -333,3 +474,249 @@ def test_all_gather(self): y = mx.distributed.all_gather(x) self.assertEqual(y.shape, (world.size() * 2, 2, 4)) self.assertTrue(mx.all(y == 1)) + + def test_moe_ep_forward(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + if group.size() != 2: + self.skipTest("MoE EP tests require exactly 2 devices") + + world_size = group.size() + rank = group.rank() + + mx.random.seed(42) + hidden_dim = 16 + expert_dim = 32 + num_experts = 4 + top_k = 2 + num_tokens = 8 + + moe = MixtureOfExperts( + hidden_dim=hidden_dim, + expert_dim=expert_dim, + num_experts=num_experts, + top_k=top_k, + capacity_factor=2.0, + ) + mx.eval(moe.parameters()) + + x = mx.random.normal((num_tokens, hidden_dim)) + rank * 1000 + output, aux_loss = moe(x) + mx.eval(output, aux_loss) + + # Shape check + self.assertEqual(output.shape, (num_tokens, hidden_dim)) + # Finiteness check + self.assertTrue(mx.all(mx.isfinite(output)).item()) + self.assertTrue(mx.isfinite(aux_loss).item()) + # EP enabled check + self.assertEqual(moe._world_size, world_size) + # Local expert count + self.assertEqual(len(moe.experts), num_experts // world_size) + + def test_moe_ep_uneven_batch(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + if group.size() != 2: + self.skipTest("MoE EP tests require exactly 2 devices") + + rank = group.rank() + + mx.random.seed(42) + 
hidden_dim = 16 + expert_dim = 32 + num_experts = 4 + top_k = 2 + + moe = MixtureOfExperts( + hidden_dim=hidden_dim, + expert_dim=expert_dim, + num_experts=num_experts, + top_k=top_k, + capacity_factor=2.0, + ) + mx.eval(moe.parameters()) + + if rank == 0: + x = mx.random.normal((4, hidden_dim)) + rank * 1000 + else: + x = mx.zeros((0, hidden_dim)) + + output, aux_loss = moe(x) + mx.eval(output, aux_loss) + + if rank == 0: + self.assertEqual(output.shape, (4, hidden_dim)) + self.assertTrue(mx.all(mx.isfinite(output)).item()) + else: + # Empty rank checks + self.assertEqual(output.shape, (0, hidden_dim)) + self.assertTrue(mx.all(mx.isfinite(output)).item()) + self.assertEqual(aux_loss.item(), 0.0) + + def test_moe_ep_dispatch_combine_roundtrip(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + if group.size() != 2: + self.skipTest("MoE EP tests require exactly 2 devices") + + rank = group.rank() + hidden_dim = 8 + num_tokens = 4 + num_experts = 4 + top_k = 1 + capacity_factor = 2.0 + + # Rank-distinct input + x = (mx.arange(num_tokens).reshape(-1, 1) + 1 + rank * 1000).astype(mx.float32) + x = mx.broadcast_to(x, (num_tokens, hidden_dim)) + x = mx.array(x) # make contiguous + + # Build routing: half local, half remote + expert_indices = mx.zeros((num_tokens, top_k), dtype=mx.int32) + for i in range(num_tokens): + if i < num_tokens // 2: + # Local expert + expert_indices = expert_indices.at[i, 0].add(rank * 2 + i % 2) + else: + # Remote expert + expert_indices = expert_indices.at[i, 0].add((1 - rank) * 2 + i % 2) + + weights = mx.ones((num_tokens, top_k), dtype=mx.float32) + + dispatched, meta = expert_dispatch( + x, + expert_indices, + weights, + num_experts=num_experts, + capacity_factor=capacity_factor, + group=group, + ) + mx.eval(dispatched) + + # Identity: just pass through + combined = expert_combine(dispatched, meta, x, group=group) + mx.eval(combined) + + # Check valid-route tokens round-trip correctly + has_valid = 
(meta.positions >= 0).any(axis=1) + mx.eval(has_valid) + for i in range(num_tokens): + if has_valid[i].item(): + self.assertTrue( + mx.allclose(combined[i], x[i], atol=1e-5, rtol=1e-4).item(), + f"Token {i} on rank {rank} did not round-trip correctly", + ) + + def test_moe_ep_partial_overflow(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + if group.size() != 2: + self.skipTest("MoE EP tests require exactly 2 devices") + + rank = group.rank() + hidden_dim = 8 + num_tokens = 4 + num_experts = 4 + top_k = 2 + capacity_factor = 0.5 # Force overflow + + # Rank-distinct input + x = mx.ones((num_tokens, hidden_dim), dtype=mx.float32) * (rank + 1.0) * 100 + + # Route all tokens to experts 0 and 1 -> overflow with low capacity + expert_indices = mx.zeros((num_tokens, top_k), dtype=mx.int32) + expert_indices = expert_indices.at[:, 1].add( + 1 + ) # k=0 -> expert 0, k=1 -> expert 1 + weights = mx.ones((num_tokens, top_k), dtype=mx.float32) * 0.5 + + dispatched, meta = expert_dispatch( + x, + expert_indices, + weights, + num_experts=num_experts, + capacity_factor=capacity_factor, + group=group, + ) + mx.eval(dispatched) + + combined = expert_combine(dispatched, meta, x, group=group) + mx.eval(combined) + + has_valid_route = (meta.positions >= 0).any(axis=1) + mx.eval(has_valid_route) + + for i in range(num_tokens): + if has_valid_route[i].item(): + # Valid route tokens should be finite + self.assertTrue( + mx.all(mx.isfinite(combined[i])).item(), + f"Token {i} on rank {rank} has non-finite values", + ) + else: + # Overflow tokens should fall back to residual (original input) + self.assertTrue( + mx.array_equal(combined[i], x[i]).item(), + f"Overflow token {i} on rank {rank} did not use residual", + ) + + def test_moe_ep_gradient(self): + group = mx.distributed.init() + self._skip_if_all_to_all_unsupported(group) + if group.size() != 2: + self.skipTest("MoE EP tests require exactly 2 devices") + + rank = group.rank() + + mx.random.seed(42) 
+ hidden_dim = 16 + expert_dim = 32 + num_experts = 4 + top_k = 2 + num_tokens = 8 + + moe = MixtureOfExperts( + hidden_dim=hidden_dim, + expert_dim=expert_dim, + num_experts=num_experts, + top_k=top_k, + capacity_factor=2.0, + aux_loss_coeff=0.0, # Exclude aux_loss to test all_to_all VJP path + ) + mx.eval(moe.parameters()) + + x = mx.random.normal((num_tokens, hidden_dim)) + rank * 1000 + + def loss_fn(model, x): + output, _aux = model(x) + return output.sum() + + loss_and_grad = nn.value_and_grad(moe, loss_fn) + loss, grads = loss_and_grad(moe, x) + mx.eval(loss, grads) + + # Loss should be finite + self.assertTrue(mx.isfinite(loss).item()) + + # Router gate grad: must be finite + gate_grad = grads["router"]["gate"]["weight"] + self.assertTrue(mx.all(mx.isfinite(gate_grad)).item()) + gate_has_grad = mx.any(gate_grad != 0).item() + + # Check local expert grads: must be finite, at least one non-zero + any_expert_has_grad = False + for i in range(len(moe.experts)): + w_gate_grad = grads["experts"][i]["w_gate"]["weight"] + self.assertTrue( + mx.all(mx.isfinite(w_gate_grad)).item(), + f"Expert {i} w_gate grad has non-finite values on rank {rank}", + ) + if mx.any(w_gate_grad != 0).item(): + any_expert_has_grad = True + + # At least gate or one expert must receive non-zero gradients + self.assertTrue( + gate_has_grad or any_expert_has_grad, + f"Neither gate nor any expert received non-zero gradients on rank {rank}", + ) diff --git a/python/tests/test_moe.py b/python/tests/test_moe.py new file mode 100644 index 0000000000..566a3b82b4 --- /dev/null +++ b/python/tests/test_moe.py @@ -0,0 +1,1221 @@ +# Copyright © 2026 Apple Inc. + +import math +import unittest + +import mlx.core as mx +import mlx.nn as nn +import mlx_tests + +# Imports will be validated once the MoE module is implemented. +# If the module path changes, update accordingly. 
class TestTopKRouter(mlx_tests.MLXTestCase):
    def test_output_shapes(self):
        """Router emits [N, top_k] weights/indices and a scalar loss."""
        hidden, n_exp, k = 64, 8, 2
        router = TopKRouter(hidden, n_exp, k)
        weights, indices, aux = router(mx.random.normal((16, hidden)))
        mx.eval(weights, indices, aux)

        self.assertEqual(weights.shape, (16, k))
        self.assertEqual(indices.shape, (16, k))
        self.assertEqual(aux.shape, ())

    def test_index_range(self):
        """Every selected expert index lies in [0, num_experts)."""
        hidden, n_exp, k = 64, 8, 2
        router = TopKRouter(hidden, n_exp, k)
        _, indices, _ = router(mx.random.normal((32, hidden)))
        mx.eval(indices)

        self.assertTrue(mx.all(indices >= 0).item())
        self.assertTrue(mx.all(indices < n_exp).item())

    def test_weights_sum(self):
        """Per-token routing weights normalize to approximately one."""
        hidden, n_exp, k = 64, 8, 2
        router = TopKRouter(hidden, n_exp, k)
        weights, _, _ = router(mx.random.normal((16, hidden)))
        mx.eval(weights)

        totals = weights.sum(axis=-1)
        self.assertTrue(
            mx.allclose(totals, mx.ones_like(totals), atol=1e-5).item()
        )

    def test_gradient_flow(self):
        """Gradients should flow through the router gate."""
        hidden, n_exp, k = 32, 4, 2
        router = TopKRouter(hidden, n_exp, k)
        x = mx.random.normal((8, hidden))

        def loss_fn(model, inp):
            w, _, aux = model(inp)
            return w.sum() + aux

        loss, grads = nn.value_and_grad(router, loss_fn)(router, x)
        mx.eval(loss, grads)

        # Gate weight should have non-zero gradient
        self.assertTrue(mx.any(grads["gate"]["weight"] != 0).item())

    def test_aux_loss_positive(self):
        """Auxiliary loss should be non-negative."""
        router = TopKRouter(64, 8, 2)
        _, _, aux = router(mx.random.normal((16, 64)))
        mx.eval(aux)

        self.assertTrue(aux.item() >= 0)

    def test_different_top_k_values(self):
        """Router should work with different top_k values."""
        hidden, n_exp = 64, 8
        x = mx.random.normal((16, hidden))

        for k in (1, 2, 4):
            weights, indices, aux = TopKRouter(hidden, n_exp, k)(x)
            mx.eval(weights, indices, aux)

            self.assertEqual(weights.shape, (16, k))
            self.assertEqual(indices.shape, (16, k))

    def test_single_token(self):
        """A batch containing a single token is handled."""
        router = TopKRouter(64, 8, 2)
        weights, indices, aux = router(mx.random.normal((1, 64)))
        mx.eval(weights, indices, aux)

        self.assertEqual(weights.shape, (1, 2))
        self.assertEqual(indices.shape, (1, 2))

    def test_top_k_validation(self):
        """Should raise error for invalid top_k."""
        for bad_k in (0, 9):
            with self.assertRaises(ValueError):
                TopKRouter(64, 8, top_k=bad_k)

    def test_empty_batch(self):
        """Router should handle zero-token input without NaN."""
        router = TopKRouter(64, 8, top_k=2)
        weights, indices, aux = router(mx.zeros((0, 64)))
        mx.eval(weights, indices, aux)
        self.assertEqual(weights.shape, (0, 2))
        self.assertEqual(indices.shape, (0, 2))
        self.assertTrue(mx.isfinite(aux).item())
        self.assertEqual(aux.item(), 0.0)
class TestComputeCapacity(mlx_tests.MLXTestCase):
    def test_basic(self):
        """Capacity matches ceil(tokens * top_k * factor / experts)."""
        # 16 tokens, top_k=2, factor=1.25, 8 experts
        got = _compute_capacity(16, 2, 1.25, 8)
        want = max(1, math.ceil(16 * 2 * 1.25 / 8))  # ceil(5.0) = 5
        self.assertEqual(got, want)

    def test_minimum_one(self):
        """Capacity should be at least 1."""
        self.assertEqual(_compute_capacity(0, 2, 1.0, 8), 1)

    def test_exact_division(self):
        """Exact divisions are not rounded up."""
        # 8 tokens, top_k=1, factor=1.0, 4 experts -> ceil(2.0) = 2
        self.assertEqual(_compute_capacity(8, 1, 1.0, 4), 2)

    def test_large_capacity_factor(self):
        """Larger capacity factor yields larger capacity."""
        self.assertGreaterEqual(
            _compute_capacity(16, 2, 2.0, 8),
            _compute_capacity(16, 2, 1.0, 8),
        )
expert_dim, num_experts, top_k=top_k) + output, aux_loss = moe(x) + mx.eval(output, aux_loss) + + self.assertEqual(output.shape, (8, hidden_dim)) + + def test_deterministic_with_seed(self): + """Same seed should produce same results.""" + hidden_dim, expert_dim, num_experts = 32, 64, 4 + x = mx.random.normal((4, hidden_dim)) + + moe1 = MixtureOfExperts(hidden_dim, expert_dim, num_experts, top_k=2) + moe2 = MixtureOfExperts(hidden_dim, expert_dim, num_experts, top_k=2) + moe2.update(moe1.parameters()) + + out1, loss1 = moe1(x) + out2, loss2 = moe2(x) + mx.eval(out1, out2, loss1, loss2) + + self.assertTrue(mx.allclose(out1, out2).item()) + self.assertTrue(mx.allclose(loss1, loss2).item()) + + def test_large_batch(self): + """MoE should handle larger batch sizes.""" + hidden_dim, expert_dim, num_experts = 64, 128, 8 + moe = MixtureOfExperts(hidden_dim, expert_dim, num_experts, top_k=2) + x = mx.random.normal((128, hidden_dim)) + output, aux_loss = moe(x) + mx.eval(output, aux_loss) + + self.assertEqual(output.shape, (128, hidden_dim)) + + def test_partial_overflow_preserves_valid_routes(self): + """Tokens with at least one valid route should not be replaced by residual.""" + hidden_dim = 4 + num_experts = 2 + capacity = 1 + top_k = 2 + + # token0: expert 0 valid (pos=0), expert 1 overflow (pos=-1) + # token1: both routes overflow (pos=-1, -1) + positions = mx.array([[0, -1], [-1, -1]], dtype=mx.int32) + expert_indices = mx.array([[0, 1], [0, 1]], dtype=mx.int32) + weights = mx.array([[0.6, 0.4], [0.5, 0.5]]) + overflow_mask = mx.array([[True], [True]]) + + meta = DispatchMeta( + expert_indices=expert_indices, + weights=weights, + positions=positions, + overflow_mask=overflow_mask, + num_experts=num_experts, + capacity=capacity, + world_size=1, + ) + + # expert_outputs: [num_experts, capacity, hidden_dim] + expert_outputs = mx.ones((num_experts, capacity, hidden_dim)) * 10.0 + original_tokens = mx.zeros((2, hidden_dim)) + + combined = expert_combine(expert_outputs, 
meta, original_tokens) + mx.eval(combined) + + # Verify bug reproduction condition: token0 has overflow_mask=True + # but should still use expert output because it has a valid route. + self.assertTrue(meta.overflow_mask[0].item()) + has_valid = (meta.positions[0] >= 0).any().item() + self.assertTrue(has_valid) + + # token0: weight=0.6 * expert_output=10.0 → expected 6.0 per dim + expected_token0 = mx.full((hidden_dim,), 0.6 * 10.0) + self.assertTrue(mx.allclose(combined[0], expected_token0).item()) + # token1: all overflow → should be original (zeros) + self.assertTrue(mx.array_equal(combined[1], original_tokens[1]).item()) + + def test_ep_backend_parameter(self): + """Test ep_backend parameter is stored and used.""" + moe = MixtureOfExperts( + hidden_dim=32, + expert_dim=64, + num_experts=4, + top_k=2, + ep_impl="auto", + ep_backend="auto", + ) + self.assertEqual(moe.ep_backend, "auto") + + moe2 = MixtureOfExperts( + hidden_dim=32, + expert_dim=64, + num_experts=4, + top_k=2, + ep_impl="auto", + ep_backend="cpu", + ) + self.assertEqual(moe2.ep_backend, "cpu") + + +class TestVectorizedDispatchCombine(mlx_tests.MLXTestCase): + def test_dispatch_combine_duplicate_expert_across_k(self): + """Same expert selected by both top_k slots should not collide positions.""" + N, D = 4, 8 + num_experts = 4 + capacity_factor = 2.0 # generous capacity + + tokens = mx.random.normal((N, D)) + # Force token 0 and token 1 to route to the same expert (expert 0) for both k=0 and k=1 + expert_indices = mx.array( + [ + [0, 0], # token 0: expert 0 twice + [0, 0], # token 1: expert 0 twice + [1, 2], # token 2: different experts + [3, 1], # token 3: different experts + ], + dtype=mx.int32, + ) + weights = mx.array( + [ + [0.6, 0.4], + [0.5, 0.5], + [0.7, 0.3], + [0.8, 0.2], + ] + ) + + dispatched, meta = expert_dispatch( + tokens, + expert_indices, + weights, + num_experts=num_experts, + capacity_factor=capacity_factor, + ) + mx.eval(dispatched, *meta) + + # Positions for token 0 and token 1 
should be different across k + # (expert_counts accumulation ensures no collision) + pos_token0 = meta.positions[0] # [top_k] + pos_token1 = meta.positions[1] # [top_k] + + # All positions should be >= 0 (no overflow with generous capacity) + self.assertTrue( + mx.all(meta.positions >= 0).item(), + f"Expected all valid positions, got {meta.positions}", + ) + + # For tokens routed to same expert: k=0 and k=1 positions must differ + self.assertNotEqual( + pos_token0[0].item(), + pos_token0[1].item(), + "Same expert positions should differ across k", + ) + + # Round-trip test: dispatch then combine with identity expert + expert_outputs = dispatched # identity + combined = expert_combine(expert_outputs, meta, tokens) + mx.eval(combined) + # Combined should not contain NaN + self.assertTrue(mx.all(mx.isfinite(combined)).item()) + + def test_dispatch_combine_overflow_boundary(self): + """Capacity boundary: first 2 tokens fit, last 2 overflow.""" + N, D = 4, 8 + num_experts = 2 + + tokens = mx.ones((N, D)) # all-ones for easy verification + # All tokens go to expert 0 for k=0, expert 1 for k=1 + expert_indices = mx.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + ], + dtype=mx.int32, + ) + weights = mx.array( + [ + [0.6, 0.4], + [0.6, 0.4], + [0.6, 0.4], + [0.6, 0.4], + ] + ) + + # capacity = max(1, ceil(4 * 2 * capacity_factor / 2)) + # With capacity_factor = 0.5: ceil(4 * 2 * 0.5 / 2) = ceil(2.0) = 2 + dispatched, meta = expert_dispatch( + tokens, + expert_indices, + weights, + num_experts=num_experts, + capacity_factor=0.5, + ) + mx.eval(dispatched, *meta) + + capacity = meta.capacity + self.assertEqual(capacity, 2) + + # For k=0 (expert 0): tokens 0,1 should have positions 0,1; tokens 2,3 overflow + positions_k0 = meta.positions[:, 0] + mx.eval(positions_k0) + self.assertEqual(positions_k0[0].item(), 0) + self.assertEqual(positions_k0[1].item(), 1) + self.assertEqual(positions_k0[2].item(), -1) # overflow + self.assertEqual(positions_k0[3].item(), -1) # overflow + 
+ # Overflow mask should be True for tokens 2 and 3 + self.assertTrue(meta.overflow_mask[2].item()) + self.assertTrue(meta.overflow_mask[3].item()) + + def test_dispatch_combine_empty_batch(self): + """N=0 input should produce correct shapes without errors.""" + D = 8 + num_experts = 4 + + tokens = mx.zeros((0, D)) + expert_indices = mx.zeros((0, 2), dtype=mx.int32) + weights = mx.zeros((0, 2)) + + dispatched, meta = expert_dispatch( + tokens, + expert_indices, + weights, + num_experts=num_experts, + capacity_factor=1.25, + ) + mx.eval(dispatched, *meta) + + # Shape checks + self.assertEqual(meta.positions.shape, (0, 2)) + self.assertEqual(meta.overflow_mask.shape, (0, 1)) + self.assertEqual(dispatched.shape[0], num_experts) # experts_per_device + self.assertEqual(dispatched.shape[-1], D) + + # Round-trip with combine + expert_outputs = dispatched + combined = expert_combine(expert_outputs, meta, tokens) + mx.eval(combined) + self.assertEqual(combined.shape, (0, D)) + + def test_combine_all_invalid_residual(self): + """All routes invalid → combined should equal original_tokens.""" + N, D = 4, 8 + num_experts = 2 + capacity = 2 + + original_tokens = mx.random.normal((N, D)) + expert_outputs = mx.random.normal((num_experts, capacity, D)) + + # Manually construct meta with all-invalid positions + positions = mx.full((N, 2), -1, dtype=mx.int32) + expert_indices = mx.array([[0, 1]] * N, dtype=mx.int32) + weights = mx.array([[0.5, 0.5]] * N) + overflow_mask = mx.ones((N, 1), dtype=mx.bool_) + + meta = DispatchMeta( + expert_indices=expert_indices, + weights=weights, + positions=positions, + overflow_mask=overflow_mask, + num_experts=num_experts, + capacity=capacity, + world_size=1, + ) + + combined = expert_combine(expert_outputs, meta, original_tokens) + mx.eval(combined) + + self.assertTrue(mx.allclose(combined, original_tokens).item()) + + +class TestCppMoeExchange(unittest.TestCase): + """Tests for C++ moe_dispatch_exchange / moe_combine_exchange primitives.""" + + 
def setUp(self): + # Skip if C++ primitive not available + if not hasattr(mx.distributed, "moe_dispatch_exchange"): + self.skipTest("moe_dispatch_exchange not available") + # Detect actual world_size: try each backend explicitly so that + # mlx.launch-initialized backends (JACCL/MPI) are detected correctly. + self._world_size = 1 + for backend in ("jaccl", "mpi", "nccl"): + try: + g = mx.distributed.init(strict=True, backend=backend) + if g.size() > 1: + self._world_size = g.size() + break + except Exception: + pass + if self._world_size == 1: + try: + self._world_size = mx.distributed.init().size() + except Exception: + self._world_size = 1 + + def _python_dispatch_combine_ref( + self, tokens, expert_indices, weights, num_experts, capacity + ): + """Reference Python implementation for comparison.""" + N, D = tokens.shape + top_k = expert_indices.shape[1] + experts_per_device = num_experts # local only (world_size=1) + + dispatch_flat = mx.zeros((num_experts * capacity, D), dtype=tokens.dtype) + route_indices = mx.full((N, top_k), -1, dtype=mx.int32) + + expert_counts = [0] * num_experts + route_list = [[-1] * top_k for _ in range(N)] + for k in range(top_k): + for n in range(N): + eid = expert_indices[n, k].item() + pos = expert_counts[eid] + if pos < capacity: + flat_idx = eid * capacity + pos + route_list[n][k] = flat_idx + expert_counts[eid] += 1 + + route_np = mx.array(route_list, dtype=mx.int32) + # Build dispatch flat + disp = mx.zeros((num_experts * capacity, D), dtype=tokens.dtype) + for n in range(N): + for k in range(top_k): + flat_idx = route_list[n][k] + if flat_idx >= 0: + disp = disp.at[flat_idx].add(tokens[n]) + dispatched = disp.reshape(num_experts, capacity, D) + + # Combine + combined = mx.zeros((N, D), dtype=tokens.dtype) + result_flat = disp + for n in range(N): + has_valid = False + for k in range(top_k): + flat_idx = route_list[n][k] + if flat_idx >= 0: + has_valid = True + w = weights[n, k].item() + combined = combined.at[n].add(w * 
result_flat[flat_idx]) + if not has_valid: + combined = combined.at[n].add(tokens[n]) + + return dispatched, route_np, combined + + def test_dispatch_local_basic(self): + """Local dispatch matches reference for simple case.""" + if self._world_size > 1: + self.skipTest("local-only test") + mx.random.seed(42) + N, D, E, top_k = 8, 16, 4, 2 + capacity = 4 + + tokens = mx.random.normal((N, D)) + # Assign each token to experts deterministically + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k)) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched, route_idx) + + # Shape check + self.assertEqual(dispatched.shape, (E, capacity, D)) + self.assertEqual(route_idx.shape, (N, top_k)) + self.assertEqual(route_idx.dtype, mx.int32) + + def test_dispatch_combine_roundtrip(self): + """Dispatch -> identity expert -> combine = input for non-overflow case.""" + mx.random.seed(7) + N, D, E, top_k = 4, 8, 4, 2 + capacity = 4 # large enough for no overflow + + tokens = mx.random.normal((N, D)) + # Each token goes to a unique expert pair + expert_indices = mx.array([[0, 1], [2, 3], [0, 2], [1, 3]], dtype=mx.int32) + weights = mx.ones((N, top_k)) / top_k # uniform + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched, route_idx) + + # Identity expert: expert_outputs = dispatched + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + ) + mx.eval(combined) + + # Should reconstruct original tokens + self.assertEqual(combined.shape, (N, D)) + self.assertTrue(mx.allclose(combined, tokens, atol=1e-5).item()) + + def test_overflow_residual_fallback(self): + """Tokens with all-overflow routes get original_tokens as residual.""" + if 
self._world_size > 1: + self.skipTest("local-only test") + N, D, E, top_k = 4, 8, 1, 2 + capacity = 1 # only 1 slot for the single expert + + tokens = mx.random.normal((N, D)) + # All tokens go to expert 0, but capacity=1 -> most overflow + expert_indices = mx.zeros((N, top_k), dtype=mx.int32) + weights = mx.ones((N, top_k)) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched, route_idx) + + # Most route_indices should be -1 (overflow) + n_overflow = (route_idx == -1).sum().item() + self.assertGreater(n_overflow, 0) + + # Identity expert + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + ) + mx.eval(combined) + + # Token 0 (first to be dispatched) has valid route, rest are overflow + # Fully overflowed tokens should get original_tokens + route_idx_np = route_idx.tolist() + for n in range(N): + all_invalid = all(route_idx_np[n][k] < 0 for k in range(top_k)) + if all_invalid: + self.assertTrue( + mx.allclose(combined[n], tokens[n], atol=1e-5).item(), + f"Token {n} should be residual fallback", + ) + + def test_empty_batch(self): + """N=0 (empty batch) should produce empty outputs.""" + if self._world_size > 1: + self.skipTest("local-only test") + E, D, top_k = 4, 16, 2 + capacity = 4 + + tokens = mx.zeros((0, D)) + expert_indices = mx.zeros((0, top_k), dtype=mx.int32) + weights = mx.zeros((0, top_k)) + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched, route_idx) + + self.assertEqual(dispatched.shape, (E, capacity, D)) + self.assertEqual(route_idx.shape, (0, top_k)) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + ) + mx.eval(combined) + self.assertEqual(combined.shape, (0, D)) 
+ + def test_dispatch_deterministic(self): + """Same input always produces same route_indices (deterministic=True).""" + N, D, E, top_k = 16, 8, 4, 2 + capacity = 6 + + tokens = mx.random.normal((N, D)) + expert_indices = mx.array( + [[i % E, (i + 2) % E] for i in range(N)], dtype=mx.int32 + ) + + _, route1 = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + _, route2 = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(route1, route2) + self.assertTrue((route1 == route2).all().item()) + + def test_dtype_float16(self): + """float16 tokens are correctly dispatched and combined.""" + N, D, E, top_k = 8, 16, 4, 2 + capacity = 4 + + tokens = mx.random.normal((N, D)).astype(mx.float16) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched, route_idx) + self.assertEqual(dispatched.dtype, mx.float16) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + ) + mx.eval(combined) + self.assertEqual(combined.dtype, mx.float16) + self.assertEqual(combined.shape, (N, D)) + + def test_cpp_vs_python_consistency(self): + """C++ primitive matches Python expert_dispatch/combine for local mode.""" + if self._world_size > 1: + self.skipTest("local-only test") + mx.random.seed(123) + N, D, E, top_k = 12, 8, 4, 2 + capacity_factor = 1.5 + + tokens = mx.random.normal((N, D)) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.random.uniform(shape=(N, top_k)) + weights = weights / weights.sum(axis=-1, keepdims=True) + + # Python path + dispatched_py, meta = expert_dispatch( + 
tokens, expert_indices, weights, E, capacity_factor, group=None + ) + capacity = meta.capacity + expert_out_py = dispatched_py # identity expert + + # C++ path + dispatched_cpp, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices.astype(mx.int32), + num_experts=E, + capacity=capacity, + ) + mx.eval(dispatched_cpp, route_idx) + + # Both dispatch should have same shape + self.assertEqual( + dispatched_cpp.shape, + dispatched_py.shape, + f"Shape mismatch: cpp={dispatched_cpp.shape} py={dispatched_py.shape}", + ) + + # C++ combine + combined_cpp = mx.distributed.moe_combine_exchange( + dispatched_cpp, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + ) + # Python combine + combined_py = expert_combine(expert_out_py, meta, tokens, group=None) + + mx.eval(combined_cpp, combined_py) + + # Results should be close (same deterministic routing) + self.assertTrue( + mx.allclose(combined_cpp, combined_py, atol=1e-5).item(), + f"C++ and Python combine results differ.\n" + f"Max diff: {mx.abs(combined_cpp - combined_py).max().item()}", + ) + + def test_backend_auto(self): + """Test that backend='auto' works (resolves to cpu in current build).""" + N, D, top_k = 8, 16, 2 + num_experts = self._world_size * 2 + capacity = 4 + tokens = mx.random.normal((N, D)) + expert_indices = mx.random.randint(0, num_experts, shape=(N, top_k)).astype( + mx.int32 + ) + + # backend="auto" should work without error + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=num_experts, + capacity=capacity, + backend="auto", + ) + mx.eval(dispatched, route_idx) + + # Verify shapes are correct + world_size = self._world_size + experts_per_device = num_experts // world_size + self.assertEqual( + dispatched.shape, (experts_per_device, world_size * capacity, D) + ) + self.assertEqual(route_idx.shape, (N, top_k)) + + # Run combine with auto backend too + expert_out = dispatched * 2.0 # simulate expert 
computation + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + combined = mx.distributed.moe_combine_exchange( + expert_out, + route_idx, + weights, + tokens, + num_experts=num_experts, + capacity=capacity, + backend="auto", + ) + mx.eval(combined) + self.assertEqual(combined.shape, (N, D)) + + def test_metal_dispatch_combine_roundtrip(self): + """Metal backend: dispatch -> identity -> combine = input for non-overflow.""" + mx.random.seed(42) + N, D, E, top_k = 16, 32, 4, 2 + capacity = 8 # large enough for no overflow + + tokens = mx.random.normal((N, D)) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + + self.assertEqual(combined.shape, (N, D)) + self.assertTrue( + mx.allclose(combined, tokens, atol=1e-4).item(), + f"Metal roundtrip failed. 
Max diff: {mx.abs(combined - tokens).max().item()}", + ) + + def test_metal_vs_cpu_consistency(self): + """Metal backend produces same results as CPU backend.""" + mx.random.seed(77) + N, D, E, top_k = 32, 64, 4, 2 + capacity = 12 + + tokens = mx.random.normal((N, D)) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.random.uniform(shape=(N, top_k)) + weights = weights / weights.sum(axis=-1, keepdims=True) + + # CPU path + disp_cpu, ri_cpu = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="cpu", + ) + mx.eval(disp_cpu, ri_cpu) + + comb_cpu = mx.distributed.moe_combine_exchange( + disp_cpu, + ri_cpu, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="cpu", + ) + mx.eval(comb_cpu) + + # Metal path + disp_metal, ri_metal = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(disp_metal, ri_metal) + + comb_metal = mx.distributed.moe_combine_exchange( + disp_metal, + ri_metal, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(comb_metal) + + # Route indices must match exactly + self.assertTrue( + mx.array_equal(ri_cpu, ri_metal).item(), + "Route indices differ between CPU and Metal", + ) + # Dispatched must match + self.assertTrue( + mx.allclose(disp_cpu, disp_metal, atol=1e-5).item(), + f"Dispatch differs. Max diff: {mx.abs(disp_cpu - disp_metal).max().item()}", + ) + # Combined must match + self.assertTrue( + mx.allclose(comb_cpu, comb_metal, atol=1e-4).item(), + f"Combine differs. 
Max diff: {mx.abs(comb_cpu - comb_metal).max().item()}", + ) + + def test_metal_overflow_residual(self): + """Metal backend: all-overflow tokens get original_tokens as residual.""" + if self._world_size > 1: + self.skipTest("local-only test") + N, D, E, top_k = 4, 16, 1, 2 + capacity = 1 + + tokens = mx.random.normal((N, D)) + expert_indices = mx.zeros((N, top_k), dtype=mx.int32) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + + route_idx_list = route_idx.tolist() + for n in range(N): + all_invalid = all(route_idx_list[n][k] < 0 for k in range(top_k)) + if all_invalid: + self.assertTrue( + mx.allclose(combined[n], tokens[n], atol=1e-5).item(), + f"Token {n} should be residual fallback (Metal)", + ) + + def test_metal_empty_batch(self): + """Metal backend: N=0 should produce correct shapes.""" + if self._world_size > 1: + self.skipTest("local-only test") + E, D, top_k = 4, 16, 2 + capacity = 4 + + tokens = mx.zeros((0, D)) + expert_indices = mx.zeros((0, top_k), dtype=mx.int32) + weights = mx.zeros((0, top_k), dtype=mx.float32) + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + + self.assertEqual(dispatched.shape, (E, capacity, D)) + self.assertEqual(route_idx.shape, (0, top_k)) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + self.assertEqual(combined.shape, (0, D)) + + def test_metal_dtype_float16(self): + """Metal 
backend: float16 tokens dispatch and combine correctly.""" + N, D, E, top_k = 8, 32, 4, 2 + capacity = 4 + + tokens = mx.random.normal((N, D)).astype(mx.float16) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + self.assertEqual(dispatched.dtype, mx.float16) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + self.assertEqual(combined.dtype, mx.float16) + self.assertEqual(combined.shape, (N, D)) + + def test_metal_dtype_bfloat16(self): + """Metal backend: bfloat16 tokens dispatch and combine correctly.""" + N, D, E, top_k = 8, 32, 4, 2 + capacity = 4 + + tokens = mx.random.normal((N, D)).astype(mx.bfloat16) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + self.assertEqual(dispatched.dtype, mx.bfloat16) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + self.assertEqual(combined.dtype, mx.bfloat16) + self.assertEqual(combined.shape, (N, D)) + + def test_metal_large_batch(self): + """Metal backend: large batch N=256 with realistic parameters.""" + N, D, E, top_k = 256, 128, 8, 2 + capacity = 80 + + tokens = mx.random.normal((N, D)).astype(mx.float16) + expert_indices = mx.random.randint(0, E, shape=(N, top_k)).astype(mx.int32) + weights = 
mx.random.uniform(shape=(N, top_k)) + weights = weights / weights.sum(axis=-1, keepdims=True) + + dispatched, route_idx = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(dispatched, route_idx) + + combined = mx.distributed.moe_combine_exchange( + dispatched, + route_idx, + weights, + tokens, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(combined) + self.assertEqual(combined.shape, (N, D)) + self.assertTrue( + mx.all(mx.isfinite(combined)).item(), "NaN/Inf in combined output" + ) + + def test_metal_all_local_experts(self): + """Metal backend: all tokens routed to local experts (ws=1, no remote).""" + if self._world_size > 1: + self.skipTest("local-only test") + N, D, E, top_k = 16, 32, 4, 2 + capacity = 8 + + tokens = mx.random.normal((N, D)) + expert_indices = mx.array( + [[i % E, (i + 1) % E] for i in range(N)], dtype=mx.int32 + ) + weights = mx.ones((N, top_k), dtype=mx.float32) / top_k + + disp_cpu, ri_cpu = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="cpu", + ) + disp_metal, ri_metal = mx.distributed.moe_dispatch_exchange( + tokens, + expert_indices, + num_experts=E, + capacity=capacity, + backend="metal", + ) + mx.eval(disp_cpu, ri_cpu, disp_metal, ri_metal) + + self.assertTrue(mx.array_equal(ri_cpu, ri_metal).item()) + self.assertTrue(mx.allclose(disp_cpu, disp_metal, atol=1e-5).item()) + + +if __name__ == "__main__": + unittest.main()