diff --git a/.gitmodules b/.gitmodules index 40a4a0f..2e0fdd9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "deps/mlx"] path = deps/mlx - url = https://github.com/ml-explore/mlx + url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi url = https://github.com/photoionization/kizunapi diff --git a/deps/mlx b/deps/mlx index b529515..3b30ffc 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit b529515eb158edd0919746ce4e545fe0879d6437 +Subproject commit 3b30ffc5e8c16d116b57458b0d91e75a29f98bf5 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 2000c8a..fcaf69a 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -240,6 +240,10 @@ declare module '*node_mlx.node' { function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array; function erf(array: ScalarOrArray, s?: StreamOrDevice): array; function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; + function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; diff --git a/src/fast.cc b/src/fast.cc index 6cfe1e0..24aa539 100644 --- a/src/fast.cc +++ b/src/fast.cc @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention( throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, {}, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, {}, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, {}, s); } } diff --git a/src/fft.cc b/src/fft.cc index 808889a..51cccca 100644 --- a/src/fft.cc +++ b/src/fft.cc @@ -32,7 +32,7 @@ std::function FFTNOpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name, std::optional> axes, mx::StreamOrDevice s) { if (n && axes) { - return mx::fft::fftn(a, std::move(*n), std::move(*axes), s); + mx::Shape shape_n(n->begin(), n->end()); + return func1(a, shape_n, std::move(*axes), s); } else if (axes) { - return mx::fft::fftn(a, std::move(*axes), s); + return func2(a, std::move(*axes), s); } else if (n) { std::ostringstream msg; msg << "[" << name << "] " << "`axes` should not be `None` if `s` is not `None`."; throw std::invalid_argument(msg.str()); } else { - return mx::fft::fftn(a, s); + return func3(a, s); } }; } @@ -66,7 +67,7 @@ std::function FFT2OpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, diff --git a/src/indexing.cc b/src/indexing.cc index e5d7286..b61dc12 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a, a->shape().begin() + non_none_indices, a->shape().end()); up = mx::reshape(std::move(up), std::move(up_reshape)); - mx::Shape axes(arr_indices.size(), 0); + std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); return {std::move(arr_indices), std::move(up), std::move(axes)}; } diff --git a/src/metal.cc b/src/metal.cc index 91f79ae..777f1b4 100644 --- a/src/metal.cc +++ b/src/metal.cc @@ -1,4 +1,14 @@ #include "src/bindings.h" +#include "mlx/backend/gpu/device_info.h" + +namespace metal_ops { + +const std::unordered_map>& +DeviceInfo() { + return mx::gpu::device_info(0); +} + +} // namespace metal_ops void InitMetal(napi_env env, napi_value exports) { napi_value metal = ki::CreateObject(env); @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) { "isAvailable", &mx::metal::is_available, "startCapture", &mx::metal::start_capture, "stopCapture", &mx::metal::stop_capture, - "deviceInfo", &mx::metal::device_info); + "deviceInfo", &metal_ops::DeviceInfo); } diff --git a/src/ops.cc b/src/ops.cc index aad9d05..7707dd4 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -191,7 +191,7 @@ mx::array Full(std::variant shape, ScalarOrArray vals, std::optional dtype, mx::StreamOrDevice s) { - return mx::full(PutIntoVector(std::move(shape)), + return mx::full(PutIntoShape(std::move(shape)), ToArray(std::move(vals), std::move(dtype)), s); } @@ -199,13 +199,13 @@ mx::array Full(std::variant shape, mx::array Zeros(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Ones(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Eye(int n, @@ -303,8 +303,9 @@ std::vector Split(const mx::array& a, if (auto i = std::get_if(&indices); i) { return mx::split(a, *i, axis.value_or(0), s); } else { - return mx::split(a, std::move(std::get>(indices)), - axis.value_or(0), s); + auto& v = std::get>(indices); + mx::Shape shape_indices(v.begin(), v.end()); + return mx::split(a, std::move(shape_indices), axis.value_or(0), s); } } @@ -544,7 +545,7 @@ mx::array ConvTranspose1d( mx::StreamOrDevice s) { return mx::conv_transpose1d(input, weight, stride.value_or(1), padding.value_or(0), dilation.value_or(1), - groups.value_or(1), s); + /*output_padding=*/0, groups.value_or(1), s); } mx::array ConvTranspose2d( @@ -574,7 +575,7 @@ mx::array ConvTranspose2d( dilation_pair = std::move(*p); return mx::conv_transpose2d(input, weight, stride_pair, padding_pair, - dilation_pair, groups.value_or(1), s); + dilation_pair, {0, 0}, groups.value_or(1), s); } mx::array ConvTranspose3d( @@ -604,7 +605,7 @@ mx::array ConvTranspose3d( dilation_tuple = std::move(*p); return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple, - dilation_tuple, groups.value_or(1), s); + dilation_tuple, {0, 0, 0}, groups.value_or(1), s); } mx::array ConvGeneral( @@ -789,6 +790,10 @@ void InitOps(napi_env env, napi_value exports) { "expm1", &mx::expm1, "erf", &mx::erf, "erfinv", &mx::erfinv, + "lgamma", &mx::lgamma, + "digamma", &mx::digamma, + "besselI0e", &mx::bessel_i0e, + "besselI1e", &mx::bessel_i1e, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, @@ -811,7 +816,9 @@ void InitOps(napi_env env, napi_value exports) { "stopGradient", &mx::stop_gradient, "sigmoid", &mx::sigmoid, "power", BinOpWrapper(&mx::power), - "arange", &ops::ARange, + "arange", &ops::ARange); + + ki::Set(env, exports, "linspace", &ops::Linspace, "kron", &mx::kron, "take", &ops::Take, @@ -848,12 +855,16 @@ void InitOps(napi_env env, napi_value exports) { "min", DimOpWrapper(&mx::min), "max", DimOpWrapper(&mx::max), "logcumsumexp", CumOpWrapper(&mx::logcumsumexp), - "logsumexp", DimOpWrapper(&mx::logsumexp), + "logsumexp", DimOpWrapper(&mx::logsumexp)); + + ki::Set(env, exports, "mean", DimOpWrapper(&mx::mean), "variance", &ops::Var, "std", &ops::Std, "split", &ops::Split, - "argmin", &ops::ArgMin, + "argmin", &ops::ArgMin); + + ki::Set(env, exports, "argmax", &ops::ArgMax, "sort", &ops::Sort, "argsort", &ops::ArgSort, @@ -864,7 +875,9 @@ void InitOps(napi_env env, napi_value exports) { "blockMaskedMM", &mx::block_masked_mm, "gatherMM", &mx::gather_mm, "gatherQMM", &mx::gather_qmm, - "softmax", &ops::Softmax, + "softmax", &ops::Softmax); + + ki::Set(env, exports, "concatenate", &ops::Concatenate, "concat", &ops::Concatenate, "stack", &ops::Stack, @@ -876,7 +889,9 @@ void InitOps(napi_env env, napi_value exports) { "cumsum", CumOpWrapper(&mx::cumsum), "cumprod", CumOpWrapper(&mx::cumprod), "cummax", CumOpWrapper(&mx::cummax), - "cummin", CumOpWrapper(&mx::cummin), + "cummin", CumOpWrapper(&mx::cummin)); + + ki::Set(env, exports, "conj", &mx::conjugate, "conjugate", &mx::conjugate, "convolve", &ops::Convolve, @@ -912,7 +927,9 @@ void InitOps(napi_env env, napi_value exports) { "bitwiseXor", BinOpWrapper(&mx::bitwise_xor), "leftShift", BinOpWrapper(&mx::left_shift), "rightShift", BinOpWrapper(&mx::right_shift), - "view", &mx::view, + "view", &mx::view); + + ki::Set(env, exports, "hadamardTransform", &mx::hadamard_transform, "einsumPath", &mx::einsum_path, "einsum", &mx::einsum, diff --git a/src/utils.cc b/src/utils.cc index b827c53..f309c48 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "src/array.h" #include "src/utils.h" -mx::Shape PutIntoVector(std::variant shape) { +mx::Shape PutIntoShape(std::variant shape) { if (auto i = std::get_if(&shape); i) return {*i}; return std::move(std::get(shape)); diff --git a/src/utils.h b/src/utils.h index 8e9fd41..3cf2fb2 100644 --- a/src/utils.h +++ b/src/utils.h @@ -8,12 +8,45 @@ namespace mx = mlx::core; +// Teach kizunapi how to serialize/deserialize SmallVector (used for Shape +// and other types in MLX >= 0.26). Mirrors the std::vector specialization. +namespace ki { + +template +struct Type> { + static constexpr const char* name = "Array"; + static napi_status ToNode(napi_env env, + const mlx::core::SmallVector& vec, + napi_value* result) { + napi_status s = napi_create_array_with_length(env, vec.size(), result); + if (s != napi_ok) return s; + for (size_t i = 0; i < vec.size(); ++i) { + napi_value el; + s = ConvertToNode(env, vec[i], &el); + if (s != napi_ok) return s; + s = napi_set_element(env, *result, i, el); + if (s != napi_ok) return s; + } + return napi_ok; + } + static std::optional> FromNode( + napi_env env, napi_value value) { + // Read as std::vector then convert to SmallVector. + auto vec = Type>::FromNode(env, value); + if (!vec) return std::nullopt; + return mlx::core::SmallVector(vec->begin(), vec->end()); + } +}; + +} // namespace ki + using OptionalAxes = std::variant>; using ScalarOrArray = std::variant; -// Read args into a vector of types. -template -bool ReadArgs(ki::Arguments* args, std::vector* results) { +// Read args into a container of types (vector or SmallVector). +template +bool ReadArgs(ki::Arguments* args, Container* results) { + using T = typename Container::value_type; while (args->RemainingsLength() > 0) { std::optional a = args->GetNext(); if (!a) { @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) { symbol, ki::MemberFunction(&ToString)); } +// If input is one int, put it into a Shape, otherwise just return the Shape. +mx::Shape PutIntoShape(std::variant shape); + // If input is one int, put it into a vector, otherwise just return the vector. -std::vector PutIntoVector(std::variant> shape); +inline std::vector PutIntoVector(std::variant> v) { + if (auto i = std::get_if(&v); i) + return {*i}; + return std::move(std::get>(v)); +} // Get axis arg from js value. std::vector GetReduceAxes(OptionalAxes value, int dims);