Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions benchmarks/python/masked_scatter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import os
import platform
import subprocess
import time
from copy import copy
Expand All @@ -17,16 +18,38 @@
if not os.path.isdir(RESULTS_DIR):
os.mkdir(RESULTS_DIR)

DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n")

TORCH_DEVICE = torch.device(
"mps"
if torch.backends.mps.is_available()
else ("cuda" if torch.cuda.is_available() else "cpu")
)


def get_device_name(device_type=None):
    """Return a human-readable name for the benchmark device.

    Args:
        device_type: Optional device type string ("cuda", "mps", or anything
            else for CPU). Defaults to ``TORCH_DEVICE.type`` so existing
            zero-argument callers keep their behavior.

    Returns:
        A non-empty string identifying the device.
    """
    if device_type is None:
        device_type = TORCH_DEVICE.type
    if device_type == "cuda":
        # Prefer the torch API over shelling out to nvidia-smi: it does not
        # depend on the binary being on PATH and names the active device.
        try:
            return torch.cuda.get_device_name(torch.cuda.current_device())
        except Exception:
            pass
        try:
            out = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
                stderr=subprocess.DEVNULL,
            )
            return out.decode("utf-8").splitlines()[0].strip()
        except Exception:
            return "CUDA_GPU"
    if device_type == "mps":
        # On Apple silicon the CPU brand string identifies the whole SoC.
        try:
            out = subprocess.check_output(
                ["sysctl", "-n", "machdep.cpu.brand_string"],
                stderr=subprocess.DEVNULL,
            )
            return out.decode("utf-8").strip()
        except Exception:
            return "Apple_Silicon"
    # CPU fallback; platform.processor() may be empty on some OSes.
    return platform.processor() or platform.machine() or "CPU"


DEVICE_NAME = get_device_name()


N_WARMUP = 5
N_ITER_BENCH = 50
N_ITER_FUNC = 20
Expand Down
38 changes: 38 additions & 0 deletions mlx/backend/cuda/device/scatter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,42 @@ __global__ void scatter(
Op{}(out + out_idx, upd[upd_loc]);
}

// Applies a masked scatter: for every flat output position whose mask bit
// is set, copy one element of `src` into `out`.
//
// One thread handles one flat element of the (row-contiguous) mask/output.
// `scatter_offsets[i]` is the per-batch-row exclusive prefix sum of the
// mask, i.e. the index of the source element that position `i` consumes.
// `mask_batch_size` and `src_batch_size` are the per-row element counts of
// the flattened mask and of `src`, so `i / mask_batch_size` recovers the
// batch row. When `SrcContiguous` is false, `src_shape`/`src_strides`/
// `src_ndim` translate the linear source index into a strided location.
template <typename T, bool SrcContiguous, typename IdxT>
__global__ void masked_scatter_assign(
    const bool* mask,
    const int32_t* scatter_offsets,
    const T* src,
    T* out,
    IdxT size,
    IdxT src_batch_size,
    IdxT mask_batch_size,
    const __grid_constant__ Shape src_shape,
    const __grid_constant__ Strides src_strides,
    int32_t src_ndim) {
  IdxT tid = cg::this_grid().thread_rank();
  // Grid-tail guard plus mask test: threads past the end, or sitting on an
  // unset mask bit, have nothing to write. Short-circuit keeps the mask
  // read in-bounds.
  if (tid >= size || !mask[tid]) {
    return;
  }

  auto src_index = static_cast<IdxT>(scatter_offsets[tid]);
  if (src_index >= src_batch_size) {
    // Match Metal backend behavior by skipping out-of-range source reads.
    return;
  }

  IdxT row = tid / mask_batch_size;
  IdxT linear_src = row * src_batch_size + src_index;
  if constexpr (SrcContiguous) {
    out[tid] = src[linear_src];
  } else {
    out[tid] = src[elem_to_loc(
        linear_src, src_shape.data(), src_strides.data(), src_ndim)];
  }
}

} // namespace mlx::core::cu
99 changes: 99 additions & 0 deletions mlx/backend/cuda/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/jit_module.h"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/scan.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
Expand Down Expand Up @@ -435,4 +436,102 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel, num_blocks, block_dims, {}, 0, args.args());
}

// GPU implementation of masked_scatter: copy `dst` into `out`, then
// overwrite the positions where `mask` is true with consecutive elements
// taken from `src`.
void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("MaskedScatter::eval_gpu");
  assert(inputs.size() == 3);

  const array& dst = inputs[0];
  const array& mask = inputs[1];
  const array& src = inputs[2];

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // Start from a copy of the destination; the kernel then only touches the
  // masked positions.
  const size_t total = mask.size();
  const CopyType copy_type = (total == 1)
      ? CopyType::Scalar
      : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
  copy_gpu(dst, out, copy_type, s);
  if (total == 0) {
    // Empty mask: nothing to scatter; `out` is already a copy of `dst`.
    return;
  }

  // Flatten the mask to (batch, -1) and make it row contiguous so the scan
  // below and the kernel can index it linearly.
  array mask_flat = flatten_in_eval(mask, 1, -1, s);
  if (mask_flat.data<void>() != mask.data<void>()) {
    // Flattening produced a new buffer; keep it alive until the graph runs.
    encoder.add_temporary(mask_flat);
  }
  if (!mask_flat.flags().row_contiguous) {
    mask_flat = contiguous_copy_gpu(mask_flat, s);
    encoder.add_temporary(mask_flat);
  }

  // Exclusive prefix sum of the mask along each batch row: for every masked
  // position this yields the index of the `src` element it should consume.
  array scatter_offsets(mask_flat.shape(), int32, nullptr, {});
  scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder));
  encoder.add_temporary(scatter_offsets);

  scan_gpu_inplace(
      mask_flat,
      scatter_offsets,
      Scan::Sum,
      /* axis= */ 1,
      /* reverse= */ false,
      /* inclusive= */ false,
      s);

  // Per-row element counts let the kernel recover the batch row from a flat
  // index.
  const size_t batch_count = mask.shape(0);
  const size_t mask_batch_size = mask_flat.size() / batch_count;
  const size_t src_batch_size = src.size() / src.shape(0);
  // Use 64-bit indexing only when either extent could overflow int32.
  bool large = total > INT32_MAX || src.size() > INT32_MAX;

  // JIT-compile all four template variants (src contiguity x index width)
  // into one module keyed by output dtype.
  std::string module_name =
      fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype()));
  cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() {
    std::vector<std::string> kernel_names;
    for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) {
      for (int use_large = 0; use_large <= 1; ++use_large) {
        kernel_names.push_back(
            fmt::format(
                "mlx::core::cu::masked_scatter_assign<{}, {}, {}>",
                dtype_to_cuda_type(out.dtype()),
                src_contiguous ? "true" : "false",
                use_large ? "int64_t" : "int32_t"));
      }
    }
    return std::make_tuple(false, jit_source_scatter, std::move(kernel_names));
  });

  // Argument order must match the masked_scatter_assign kernel signature.
  cu::KernelArgs args;
  args.append(mask_flat);
  args.append(scatter_offsets);
  args.append(src);
  args.append(out);
  if (large) {
    args.append<int64_t>(mask_flat.size());
    args.append<int64_t>(src_batch_size);
    args.append<int64_t>(mask_batch_size);
  } else {
    args.append<int32_t>(mask_flat.size());
    args.append<int32_t>(src_batch_size);
    args.append<int32_t>(mask_batch_size);
  }
  args.append_ndim(src.shape());
  args.append_ndim(src.strides());
  args.append<int32_t>(src.ndim());

  encoder.set_input_array(mask_flat);
  encoder.set_input_array(scatter_offsets);
  encoder.set_input_array(src);
  encoder.set_output_array(out);

  // Pick the variant matching the actual source layout and index width.
  std::string kernel_name = fmt::format(
      "mlx::core::cu::masked_scatter_assign<{}, {}, {}>",
      dtype_to_cuda_type(out.dtype()),
      src.flags().row_contiguous ? "true" : "false",
      large ? "int64_t" : "int32_t");
  auto kernel = mod.get_kernel(kernel_name);
  // One thread per flattened mask element.
  auto [num_blocks, block_dims] = get_launch_args(mask_flat, large);
  encoder.add_kernel_node_raw(
      kernel, num_blocks, block_dims, {}, 0, args.args());
}

} // namespace mlx::core
1 change: 0 additions & 1 deletion mlx/backend/cuda/primitives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)

namespace distributed {
NO_GPU_MULTI(Send)
Expand Down
75 changes: 44 additions & 31 deletions mlx/backend/cuda/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlx/backend/cuda/device/binary_ops.cuh"
#include "mlx/backend/cuda/kernel_utils.cuh"
#include "mlx/backend/cuda/reduce/reduce_ops.cuh"
#include "mlx/backend/cuda/scan.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/dtype_utils.h"
#include "mlx/primitives.h"
Expand Down Expand Up @@ -362,51 +363,38 @@ constexpr bool supports_scan_op() {
}
}

void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
nvtx3::scoped_range r("Scan::eval_gpu");
assert(inputs.size() == 1);
auto in = inputs[0];
auto& s = stream();
void scan_gpu_inplace(
array in,
array& out,
Scan::ReduceType reduce_type,
int axis,
bool reverse,
bool inclusive,
const Stream& s) {
auto& encoder = cu::get_command_encoder(s);

if (in.flags().contiguous && in.strides()[axis_] != 0) {
if (in.is_donatable() && in.itemsize() == out.itemsize()) {
out.copy_shared_buffer(in);
} else {
out.set_data(
cu::malloc_async(in.data_size() * out.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());
}
} else {
in = contiguous_copy_gpu(in, s);
out.copy_shared_buffer(in);
}

constexpr int N_READS = 4;
int32_t axis_size = in.shape(axis_);
bool contiguous = in.strides()[axis_] == 1;
int32_t axis_size = in.shape(axis);
bool contiguous = in.strides()[axis] == 1;

encoder.set_input_array(in);
encoder.set_output_array(out);

dispatch_all_types(in.dtype(), [&](auto type_tag) {
using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) {
dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) {
using Op = MLX_GET_TYPE(scan_op_tag);
if constexpr (supports_scan_op<Op, T>()) {
using U = typename cu::ScanResult<Op, T>::type;
dispatch_bool(inclusive_, [&](auto inclusive) {
dispatch_bool(reverse_, [&](auto reverse) {
dispatch_bool(inclusive, [&](auto inclusive_tag) {
dispatch_bool(reverse, [&](auto reverse_tag) {
if (contiguous) {
auto kernel = cu::contiguous_scan<
T,
U,
Op,
N_READS,
inclusive.value,
reverse.value>;
inclusive_tag.value,
reverse_tag.value>;
int block_dim = cuda::ceil_div(axis_size, N_READS);
block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE;
block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE);
Expand All @@ -427,9 +415,9 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
N_READS,
BM,
BN,
inclusive.value,
reverse.value>;
int64_t stride = in.strides()[axis_];
inclusive_tag.value,
reverse_tag.value>;
int64_t stride = in.strides()[axis];
int64_t stride_blocks = cuda::ceil_div(stride, BN);
dim3 num_blocks = get_2d_grid_dims(
in.shape(), in.strides(), axis_size * stride);
Expand Down Expand Up @@ -463,4 +451,29 @@ void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
});
}

// Computes a scan (cumulative reduction) of the single input along `axis_`,
// first arranging for `out` to have a usable buffer, then delegating the
// kernel dispatch to scan_gpu_inplace.
void Scan::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("Scan::eval_gpu");
  assert(inputs.size() == 1);
  auto in = inputs[0];
  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);

  // The scan kernels require a contiguous input with a real (non-broadcast)
  // stride along the scan axis.
  if (in.flags().contiguous && in.strides()[axis_] != 0) {
    if (in.is_donatable() && in.itemsize() == out.itemsize()) {
      // Donate: reuse the input's buffer when element widths match.
      out.copy_shared_buffer(in);
    } else {
      // Otherwise allocate the output with the input's layout.
      out.set_data(
          cu::malloc_async(in.data_size() * out.itemsize(), encoder),
          in.data_size(),
          in.strides(),
          in.flags());
    }
  } else {
    // Materialize a contiguous copy and scan it in place (out aliases it).
    in = contiguous_copy_gpu(in, s);
    out.copy_shared_buffer(in);
  }

  scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s);
}

} // namespace mlx::core
20 changes: 20 additions & 0 deletions mlx/backend/cuda/scan.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/array.h"
#include "mlx/primitives.h"
#include "mlx/stream.h"

namespace mlx::core {

// Enqueues a scan (cumulative reduction of type `reduce_type`) of `in`
// along `axis` on stream `s`, writing the result into `out`.
//
// NOTE(review): `out` appears to be expected to already reference valid
// device memory — callers either donate `in`'s buffer or allocate it before
// calling — and this routine only dispatches the scan kernels; confirm
// against the definition in scan.cu.
void scan_gpu_inplace(
    array in,
    array& out,
    Scan::ReduceType reduce_type,
    int axis,
    bool reverse,
    bool inclusive,
    const Stream& s);

} // namespace mlx::core
4 changes: 0 additions & 4 deletions python/tests/cuda_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,4 @@
"TestQuantized.test_throw",
"TestQuantized.test_vjp_scales_biases",
"TestExportImport.test_export_quantized_model",
# Masked scatter
"TestOps.test_masked_scatter",
"TestVmap.test_vmap_masked_scatter",
"TestArray.test_setitem_with_boolean_mask",
}
5 changes: 0 additions & 5 deletions tests/autograd_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1357,11 +1357,6 @@ TEST_CASE("test grad dynamic slices") {
}

TEST_CASE("test masked_scatter autograd") {
if (cu::is_available()) {
INFO("Skipping masked_scatter cuda autograd tests");
return;
}

// Test jvp
{
auto self = array({10.f, 20.f, 30.f, 40.f}, {4});
Expand Down
5 changes: 0 additions & 5 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2437,11 +2437,6 @@ TEST_CASE("test scatter") {
}

TEST_CASE("test masked_scatter") {
if (cu::is_available()) {
INFO("Skipping masked_scatter cuda ops tests");
return;
}

// Wrong mask dtype
CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2})));

Expand Down