From aecf6275b95f03c6681307f7c6e6532e997ee965 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 21 Feb 2026 02:05:25 +0800 Subject: [PATCH 1/5] feat(cuda): implement masked scatter kernel --- mlx/backend/cuda/CMakeLists.txt | 2 +- .../cuda/{indexing.cpp => indexing.cu} | 131 ++++++++++++++++++ mlx/backend/cuda/primitives.cpp | 1 - mlx/backend/cuda/scan.cu | 75 +++++----- mlx/backend/cuda/scan.h | 20 +++ 5 files changed, 196 insertions(+), 33 deletions(-) rename mlx/backend/cuda/{indexing.cpp => indexing.cu} (75%) create mode 100644 mlx/backend/cuda/scan.h diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 1b95116e0e..35fbd33e56 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -31,7 +31,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cu ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cu similarity index 75% rename from mlx/backend/cuda/indexing.cpp rename to mlx/backend/cuda/indexing.cu index 424566d258..7091ed3ec6 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/scan.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -54,6 +55,53 @@ void append_indices_arg( } // namespace +namespace cu { + +template +__global__ void masked_assign( + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT total, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim, + IdxT src_batch_size, + IdxT mask_batch_size) { + IdxT block_id = static_cast(blockIdx.x) + + static_cast(gridDim.x) * + (static_cast(blockIdx.y) + + static_cast(gridDim.y) * static_cast(blockIdx.z)); + IdxT thread_id = block_id * blockDim.x + threadIdx.x; + IdxT stride = + static_cast(blockDim.x) * gridDim.x * gridDim.y * gridDim.z; + + for (IdxT idx = thread_id; idx < total; idx += stride) { + if (!mask[idx]) { + continue; + } + + IdxT src_index = static_cast(scatter_offsets[idx]); + if (src_index >= src_batch_size) { + // Match Metal backend behavior by skipping out-of-range source reads. + continue; + } + + IdxT batch_idx = idx / mask_batch_size; + if constexpr (src_contiguous) { + out[idx] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = + elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[idx] = src[src_loc]; + } + } +} + +} // namespace cu + void Gather::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 0); @@ -435,4 +483,87 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { kernel, num_blocks, block_dims, {}, 0, args.args()); } +void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("MaskedScatter::eval_gpu"); + assert(inputs.size() == 3); + + const array& dst = inputs[0]; + const array& mask = inputs[1]; + const array& src = inputs[2]; + + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + const size_t total = mask.size(); + const CopyType copy_type = (total == 1) + ? CopyType::Scalar + : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy_gpu(dst, out, copy_type, s); + if (total == 0) { + return; + } + + array mask_flat = flatten_in_eval(mask, 1, -1, s); + if (mask_flat.data() != mask.data()) { + encoder.add_temporary(mask_flat); + } + if (!mask_flat.flags().row_contiguous) { + mask_flat = contiguous_copy_gpu(mask_flat, s); + encoder.add_temporary(mask_flat); + } + + array scatter_offsets(mask_flat.shape(), int32, nullptr, {}); + scatter_offsets.set_data(cu::malloc_async(scatter_offsets.nbytes(), encoder)); + encoder.add_temporary(scatter_offsets); + + scan_gpu_inplace( + mask_flat, + scatter_offsets, + Scan::Sum, + /* axis= */ 1, + /* reverse= */ false, + /* inclusive= */ false, + s); + + const size_t batch_count = mask.shape(0); + const size_t mask_batch_size = mask_flat.size() / batch_count; + const size_t src_batch_size = src.size() / src.shape(0); + + encoder.set_input_array(mask_flat); + encoder.set_input_array(scatter_offsets); + encoder.set_input_array(src); + encoder.set_output_array(out); + + dispatch_all_types(out.dtype(), [&](auto type_tag) { + using T = cuda_type_t; + dispatch_bool(src.flags().row_contiguous, [&](auto src_contiguous) { + dispatch_bool( + total > INT32_MAX || src.size() > INT32_MAX, [&](auto large) { + using IdxT = std::conditional_t; + auto [num_blocks, block_dims] = get_launch_args( + mask_flat.size(), + mask_flat.shape(), + mask_flat.strides(), + large()); + auto kernel = cu::masked_assign; + encoder.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(mask_flat), + gpu_ptr(scatter_offsets), + gpu_ptr(src), + gpu_ptr(out), + static_cast(mask_flat.size()), + const_param(src.shape()), + const_param(src.strides()), + static_cast(src.ndim()), + static_cast(src_batch_size), + static_cast(mask_batch_size)); + }); + }); + }); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/primitives.cpp b/mlx/backend/cuda/primitives.cpp index 0caac8de6e..cb9b69fbb0 100644 --- a/mlx/backend/cuda/primitives.cpp +++ b/mlx/backend/cuda/primitives.cpp @@ -36,7 +36,6 @@ NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU_MULTI(Eig) NO_GPU_MULTI(Eigh) -NO_GPU(MaskedScatter) namespace distributed { NO_GPU_MULTI(Send) diff --git a/mlx/backend/cuda/scan.cu b/mlx/backend/cuda/scan.cu index bd25084c1f..a07d056d89 100644 --- a/mlx/backend/cuda/scan.cu +++ b/mlx/backend/cuda/scan.cu @@ -4,6 +4,7 @@ #include "mlx/backend/cuda/device/binary_ops.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/reduce/reduce_ops.cuh" +#include "mlx/backend/cuda/scan.h" #include "mlx/backend/gpu/copy.h" #include "mlx/dtype_utils.h" #include "mlx/primitives.h" @@ -362,51 +363,38 @@ constexpr bool supports_scan_op() { } } -void Scan::eval_gpu(const std::vector& inputs, array& out) { - nvtx3::scoped_range r("Scan::eval_gpu"); - assert(inputs.size() == 1); - auto in = inputs[0]; - auto& s = stream(); +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s) { auto& encoder = cu::get_command_encoder(s); - - if (in.flags().contiguous && in.strides()[axis_] != 0) { - if (in.is_donatable() && in.itemsize() == out.itemsize()) { - out.copy_shared_buffer(in); - } else { - out.set_data( - cu::malloc_async(in.data_size() * out.itemsize(), encoder), - in.data_size(), - in.strides(), - in.flags()); - } - } else { - in = contiguous_copy_gpu(in, s); - out.copy_shared_buffer(in); - } - constexpr int N_READS = 4; - int32_t axis_size = in.shape(axis_); - bool contiguous = in.strides()[axis_] == 1; + int32_t axis_size = in.shape(axis); + bool contiguous = in.strides()[axis] == 1; encoder.set_input_array(in); encoder.set_output_array(out); dispatch_all_types(in.dtype(), [&](auto type_tag) { using T = cuda_type_t; - dispatch_scan_ops(reduce_type_, [&](auto scan_op_tag) { + dispatch_scan_ops(reduce_type, [&](auto scan_op_tag) { using Op = MLX_GET_TYPE(scan_op_tag); if constexpr (supports_scan_op()) { using U = typename cu::ScanResult::type; - dispatch_bool(inclusive_, [&](auto inclusive) { - dispatch_bool(reverse_, [&](auto reverse) { + dispatch_bool(inclusive, [&](auto inclusive_tag) { + dispatch_bool(reverse, [&](auto reverse_tag) { if (contiguous) { auto kernel = cu::contiguous_scan< T, U, Op, N_READS, - inclusive.value, - reverse.value>; + inclusive_tag.value, + reverse_tag.value>; int block_dim = cuda::ceil_div(axis_size, N_READS); block_dim = cuda::ceil_div(block_dim, WARP_SIZE) * WARP_SIZE; block_dim = std::min(block_dim, WARP_SIZE * WARP_SIZE); @@ -427,9 +415,9 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { N_READS, BM, BN, - inclusive.value, - reverse.value>; - int64_t stride = in.strides()[axis_]; + inclusive_tag.value, + reverse_tag.value>; + int64_t stride = in.strides()[axis]; int64_t stride_blocks = cuda::ceil_div(stride, BN); dim3 num_blocks = get_2d_grid_dims( in.shape(), in.strides(), axis_size * stride); @@ -463,4 +451,29 @@ void Scan::eval_gpu(const std::vector& inputs, array& out) { }); } +void Scan::eval_gpu(const std::vector& inputs, array& out) { + nvtx3::scoped_range r("Scan::eval_gpu"); + assert(inputs.size() == 1); + auto in = inputs[0]; + auto& s = stream(); + auto& encoder = cu::get_command_encoder(s); + + if (in.flags().contiguous && in.strides()[axis_] != 0) { + if (in.is_donatable() && in.itemsize() == out.itemsize()) { + out.copy_shared_buffer(in); + } else { + out.set_data( + cu::malloc_async(in.data_size() * out.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + in = contiguous_copy_gpu(in, s); + out.copy_shared_buffer(in); + } + + scan_gpu_inplace(in, out, reduce_type_, axis_, reverse_, inclusive_, s); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/scan.h b/mlx/backend/cuda/scan.h new file mode 100644 index 0000000000..ea233edfb1 --- /dev/null +++ b/mlx/backend/cuda/scan.h @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/primitives.h" +#include "mlx/stream.h" + +namespace mlx::core { + +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s); + +} // namespace mlx::core From 71cb5afc41fd6e0b016f360ea842b04b91ac70ab Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 21 Feb 2026 02:06:25 +0800 Subject: [PATCH 2/5] test(cuda): enable masked scatter test coverage --- python/tests/cuda_skip.py | 4 ---- tests/autograd_tests.cpp | 5 ----- tests/ops_tests.cpp | 5 ----- 3 files changed, 14 deletions(-) diff --git a/python/tests/cuda_skip.py b/python/tests/cuda_skip.py index 20793d5c91..fd2924a411 100644 --- a/python/tests/cuda_skip.py +++ b/python/tests/cuda_skip.py @@ -55,8 +55,4 @@ "TestQuantized.test_throw", "TestQuantized.test_vjp_scales_biases", "TestExportImport.test_export_quantized_model", - # Masked scatter - "TestOps.test_masked_scatter", - "TestVmap.test_vmap_masked_scatter", - "TestArray.test_setitem_with_boolean_mask", } diff --git a/tests/autograd_tests.cpp b/tests/autograd_tests.cpp index ff8d986bd9..25c871cdf9 100644 --- a/tests/autograd_tests.cpp +++ b/tests/autograd_tests.cpp @@ -1357,11 +1357,6 @@ TEST_CASE("test grad dynamic slices") { } TEST_CASE("test masked_scatter autograd") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda autograd tests"); - return; - } - // Test jvp { auto self = array({10.f, 20.f, 30.f, 40.f}, {4}); diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index 62fd8c5923..23740b7004 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2437,11 +2437,6 @@ TEST_CASE("test scatter") { } TEST_CASE("test masked_scatter") { - if (cu::is_available()) { - INFO("Skipping masked_scatter cuda ops tests"); - return; - } - // Wrong mask dtype CHECK_THROWS(masked_scatter(array({1, 2}), array({1, 2}), array({1, 2}))); From 50c0da5a699a591e5707b32aefbae5749c0faf4a Mon Sep 17 00:00:00 2001 From: Lyxot Date: Wed, 25 Feb 2026 04:07:58 +0800 Subject: [PATCH 3/5] refactor(cuda): align masked scatter jit with scatter kernels --- mlx/backend/cuda/CMakeLists.txt | 2 +- mlx/backend/cuda/device/scatter.cuh | 38 ++++++ .../cuda/{indexing.cu => indexing.cpp} | 121 +++++++----------- 3 files changed, 83 insertions(+), 78 deletions(-) rename mlx/backend/cuda/{indexing.cu => indexing.cpp} (84%) diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index 35fbd33e56..1b95116e0e 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -31,7 +31,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cu + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp diff --git a/mlx/backend/cuda/device/scatter.cuh b/mlx/backend/cuda/device/scatter.cuh index b2f6403505..9a124d5426 100644 --- a/mlx/backend/cuda/device/scatter.cuh +++ b/mlx/backend/cuda/device/scatter.cuh @@ -65,4 +65,42 @@ __global__ void scatter( Op{}(out + out_idx, upd[upd_loc]); } +template +__global__ void masked_scatter_assign( + const bool* mask, + const int32_t* scatter_offsets, + const T* src, + T* out, + IdxT size, + IdxT src_batch_size, + IdxT mask_batch_size, + const __grid_constant__ Shape src_shape, + const __grid_constant__ Strides src_strides, + int32_t src_ndim) { + IdxT index = cg::this_grid().thread_rank(); + if (index >= size) { + return; + } + + if (!mask[index]) { + return; + } + + IdxT src_index = static_cast(scatter_offsets[index]); + if (src_index >= src_batch_size) { + // Match Metal backend behavior by skipping out-of-range source reads. + return; + } + + IdxT batch_idx = index / mask_batch_size; + if constexpr (SrcContiguous) { + out[index] = src[batch_idx * src_batch_size + src_index]; + } else { + IdxT src_elem = batch_idx * src_batch_size + src_index; + IdxT src_loc = + elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); + out[index] = src[src_loc]; + } +} + } // namespace mlx::core::cu diff --git a/mlx/backend/cuda/indexing.cu b/mlx/backend/cuda/indexing.cpp similarity index 84% rename from mlx/backend/cuda/indexing.cu rename to mlx/backend/cuda/indexing.cpp index 7091ed3ec6..a3b5332270 100644 --- a/mlx/backend/cuda/indexing.cu +++ b/mlx/backend/cuda/indexing.cpp @@ -55,53 +55,6 @@ void append_indices_arg( } // namespace -namespace cu { - -template -__global__ void masked_assign( - const bool* mask, - const int32_t* scatter_offsets, - const T* src, - T* out, - IdxT total, - const __grid_constant__ Shape src_shape, - const __grid_constant__ Strides src_strides, - int32_t src_ndim, - IdxT src_batch_size, - IdxT mask_batch_size) { - IdxT block_id = static_cast(blockIdx.x) + - static_cast(gridDim.x) * - (static_cast(blockIdx.y) + - static_cast(gridDim.y) * static_cast(blockIdx.z)); - IdxT thread_id = block_id * blockDim.x + threadIdx.x; - IdxT stride = - static_cast(blockDim.x) * gridDim.x * gridDim.y * gridDim.z; - - for (IdxT idx = thread_id; idx < total; idx += stride) { - if (!mask[idx]) { - continue; - } - - IdxT src_index = static_cast(scatter_offsets[idx]); - if (src_index >= src_batch_size) { - // Match Metal backend behavior by skipping out-of-range source reads. - continue; - } - - IdxT batch_idx = idx / mask_batch_size; - if constexpr (src_contiguous) { - out[idx] = src[batch_idx * src_batch_size + src_index]; - } else { - IdxT src_elem = batch_idx * src_batch_size + src_index; - IdxT src_loc = - elem_to_loc(src_elem, src_shape.data(), src_strides.data(), src_ndim); - out[idx] = src[src_loc]; - } - } -} - -} // namespace cu - void Gather::eval_gpu(const std::vector& inputs, array& out) { nvtx3::scoped_range r("Gather::eval_gpu"); assert(inputs.size() > 0); @@ -528,42 +481,56 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { const size_t batch_count = mask.shape(0); const size_t mask_batch_size = mask_flat.size() / batch_count; const size_t src_batch_size = src.size() / src.shape(0); + bool large = total > INT32_MAX || src.size() > INT32_MAX; + + std::string module_name = + fmt::format("masked_scatter_assign_{}", dtype_to_string(out.dtype())); + cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + std::vector kernel_names; + for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { + for (int use_large = 0; use_large <= 1; ++use_large) { + kernel_names.push_back( + fmt::format( + "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src_contiguous ? "true" : "false", + use_large ? "int64_t" : "int32_t")); + } + } + return std::make_tuple(false, jit_source_scatter, std::move(kernel_names)); + }); + + cu::KernelArgs args; + args.append(mask_flat); + args.append(scatter_offsets); + args.append(src); + args.append(out); + if (large) { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } else { + args.append(mask_flat.size()); + args.append(src_batch_size); + args.append(mask_batch_size); + } + args.append_ndim(src.shape()); + args.append_ndim(src.strides()); + args.append(src.ndim()); encoder.set_input_array(mask_flat); encoder.set_input_array(scatter_offsets); encoder.set_input_array(src); encoder.set_output_array(out); - dispatch_all_types(out.dtype(), [&](auto type_tag) { - using T = cuda_type_t; - dispatch_bool(src.flags().row_contiguous, [&](auto src_contiguous) { - dispatch_bool( - total > INT32_MAX || src.size() > INT32_MAX, [&](auto large) { - using IdxT = std::conditional_t; - auto [num_blocks, block_dims] = get_launch_args( - mask_flat.size(), - mask_flat.shape(), - mask_flat.strides(), - large()); - auto kernel = cu::masked_assign; - encoder.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(mask_flat), - gpu_ptr(scatter_offsets), - gpu_ptr(src), - gpu_ptr(out), - static_cast(mask_flat.size()), - const_param(src.shape()), - const_param(src.strides()), - static_cast(src.ndim()), - static_cast(src_batch_size), - static_cast(mask_batch_size)); - }); - }); - }); + std::string kernel_name = fmt::format( + "mlx::core::cu::masked_scatter_assign<{}, {}, {}>", + dtype_to_cuda_type(out.dtype()), + src.flags().row_contiguous ? "true" : "false", + large ? "int64_t" : "int32_t"); + auto kernel = mod.get_kernel(kernel_name); + auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); + encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); } } // namespace mlx::core From f5693f7e891359e6845ed4d13c63a90c5a1eb394 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Fri, 27 Feb 2026 18:23:21 +0800 Subject: [PATCH 4/5] refactor(cuda): use add_kernel_node_raw for masked scatter launch --- mlx/backend/cuda/indexing.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index a3b5332270..2ccd420ae1 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -530,7 +530,8 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { large ? "int64_t" : "int32_t"); auto kernel = mod.get_kernel(kernel_name); auto [num_blocks, block_dims] = get_launch_args(mask_flat, large); - encoder.add_kernel_node(kernel, num_blocks, block_dims, 0, args.args()); + encoder.add_kernel_node_raw( + kernel, num_blocks, block_dims, {}, 0, args.args()); } } // namespace mlx::core From 19571bc4bf3785969d3878cdcc6c669e1788f455 Mon Sep 17 00:00:00 2001 From: Lyxot Date: Sat, 28 Feb 2026 15:52:33 +0800 Subject: [PATCH 5/5] test: update bench script --- benchmarks/python/masked_scatter.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/benchmarks/python/masked_scatter.py b/benchmarks/python/masked_scatter.py index 71857c5436..5a3310ad49 100644 --- a/benchmarks/python/masked_scatter.py +++ b/benchmarks/python/masked_scatter.py @@ -1,5 +1,6 @@ import math import os +import platform import subprocess import time from copy import copy @@ -17,9 +18,6 @@ if not os.path.isdir(RESULTS_DIR): os.mkdir(RESULTS_DIR) -DEVICE_NAME = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) -DEVICE_NAME = DEVICE_NAME.decode("utf-8").strip("\n") - TORCH_DEVICE = torch.device( "mps" if torch.backends.mps.is_available() @@ -27,6 +25,31 @@ ) +def get_device_name(): + if TORCH_DEVICE.type == "cuda": + try: + out = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").splitlines()[0].strip() + except Exception: + return "CUDA_GPU" + if TORCH_DEVICE.type == "mps": + try: + out = subprocess.check_output( + ["sysctl", "-n", "machdep.cpu.brand_string"], + stderr=subprocess.DEVNULL, + ) + return out.decode("utf-8").strip() + except Exception: + return "Apple_Silicon" + return platform.processor() or platform.machine() or "CPU" + + +DEVICE_NAME = get_device_name() + + N_WARMUP = 5 N_ITER_BENCH = 50 N_ITER_FUNC = 20