# Patch overview: adds an MLX_METAL_GPU_ARCH environment override for the Metal
# GPU architecture string and routes architecture queries through the Device.
#
# mlx/utils.h, mlx/utils.cpp
#   - New overload env::get_var(const char* name, const char* default_value)
#     returning std::string: yields std::getenv(name) when the variable is set,
#     otherwise default_value.
#   - New env::metal_gpu_arch(): reads MLX_METAL_GPU_ARCH (default "") once into
#     a function-local static std::string and returns a const reference to it.
#
# mlx/backend/metal/device.cpp, Device::Device()
#   - arch_ is taken from env::metal_gpu_arch() first; only when that is empty
#     is it filled from device_->architecture()->name().
#   - The generation digits are the third- and second-to-last characters of
#     arch_. They are read only when arch_ has at least 3 characters, and a
#     character outside '0'..'9' contributes 0 to arch_gen_ instead of a
#     garbage value.
#   - NOTE(review): the unchanged `arch_.back()` line still requires arch_ to
#     be non-empty; confirm the device never reports an empty name.
#
# mlx/backend/metal/device_info.cpp, device_info()
#   - The reported `arch` now comes from device.get_architecture() rather than
#     from raw_device->architecture() directly. Presumably this makes the
#     override visible here too; get_architecture() is not shown, so verify.
#
# mlx/backend/metal/quantized.cpp, get_qmv_batch_limit()
#   - Compares d.get_architecture_gen() against the integers 13 and 14 instead
#     of taking a 2-character substring of the architecture string and
#     comparing it with "13"/"14".
#
# NOTE(review): this copy of the patch has its newlines collapsed and some
# angle-bracket contents missing from context lines (`std::unordered_map>`,
# bare `template`); it will not apply as-is. Regenerate it from the commit.
diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp index a2e78d3842..15824d6c5c 100644 --- a/mlx/backend/metal/device.cpp +++ b/mlx/backend/metal/device.cpp @@ -323,9 +323,18 @@ Device::Device() { auto pool = new_scoped_memory_pool(); device_ = load_device(); default_library_ = load_default_library(device_); - arch_ = std::string(device_->architecture()->name()->utf8String()); - int ag_tens = arch_[arch_.size() - 3] - '0'; - int ag_ones = arch_[arch_.size() - 2] - '0'; + arch_ = env::metal_gpu_arch(); + if (arch_.empty()) { + arch_ = std::string(device_->architecture()->name()->utf8String()); + } + int ag_tens = 0; + int ag_ones = 0; + if (arch_.size() >= 3) { + ag_tens = arch_[arch_.size() - 3] - '0'; + ag_ones = arch_[arch_.size() - 2] - '0'; + ag_tens = (ag_tens < 10 && ag_tens >= 0) ? ag_tens : 0; + ag_ones = (ag_ones < 10 && ag_ones >= 0) ? ag_ones : 0; + } arch_gen_ = ag_tens * 10 + ag_ones; auto arch = arch_.back(); switch (arch) { diff --git a/mlx/backend/metal/device_info.cpp b/mlx/backend/metal/device_info.cpp index dd18fc6c6f..b8f5f0e752 100644 --- a/mlx/backend/metal/device_info.cpp +++ b/mlx/backend/metal/device_info.cpp @@ -21,9 +21,10 @@ device_info(int device_index) { auto init_device_info = []() -> std::unordered_map> { auto pool = metal::new_scoped_memory_pool(); - auto raw_device = metal::device(mlx::core::Device::gpu).mtl_device(); + auto& device = metal::device(mlx::core::Device::gpu); + auto raw_device = device.mtl_device(); auto name = std::string(raw_device->name()->utf8String()); - auto arch = std::string(raw_device->architecture()->name()->utf8String()); + auto arch = device.get_architecture(); size_t memsize = 0; size_t length = sizeof(memsize); diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index cb67c74f73..bd197937c5 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -82,10 +82,9 @@ inline array ensure_row_contiguous_matrix( } inline int 
get_qmv_batch_limit(int D, int O, metal::Device& d) { - auto arch = d.get_architecture(); - auto arch_size = arch.back(); - auto arch_gen = arch.substr(arch.size() - 3, 2); - if (arch_gen == "13" || arch_gen == "14") { + auto arch_size = d.get_architecture().back(); + auto arch_gen = d.get_architecture_gen(); + if (arch_gen == 13 || arch_gen == 14) { switch (arch_size) { case 'd': if (D <= 2048 && O <= 2048) { diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 1ae259b1ea..cf0e0f38db 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -258,6 +258,14 @@ int get_var(const char* name, int default_value) { } } +std::string get_var(const char* name, const char* default_value) { + if (const char* buff_str = std::getenv(name)) { + return buff_str; + } else { + return default_value; + } +} + } // namespace env template diff --git a/mlx/utils.h b/mlx/utils.h index bb2de466b5..62aa82b658 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -136,6 +136,7 @@ inline int next_power_of_2(int n) { namespace env { int get_var(const char* name, int default_value); +std::string get_var(const char* name, const char* default_value); inline int bfs_max_width() { static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); @@ -169,6 +170,11 @@ inline int nccl_timeout(int default_value) { return nccl_timeout; } +inline const std::string& metal_gpu_arch() { + static std::string gpu_arch_ = get_var("MLX_METAL_GPU_ARCH", ""); + return gpu_arch_; +} + } // namespace env } // namespace mlx::core