Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated 681 files
6 changes: 3 additions & 3 deletions src/fast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention(
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
queries, keys, values, scale, mask_str, {}, {}, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
queries, keys, values, scale, "", {mask_arr}, {}, s);
}

} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
queries, keys, values, scale, "", {}, {}, s);
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFTNOpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand All @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name,
std::optional<std::vector<int>> axes,
mx::StreamOrDevice s) {
if (n && axes) {
return mx::fft::fftn(a, std::move(*n), std::move(*axes), s);
mx::Shape shape_n(n->begin(), n->end());
return func1(a, shape_n, std::move(*axes), s);
} else if (axes) {
return mx::fft::fftn(a, std::move(*axes), s);
return func2(a, std::move(*axes), s);
} else if (n) {
std::ostringstream msg;
msg << "[" << name << "] "
<< "`axes` should not be `None` if `s` is not `None`.";
throw std::invalid_argument(msg.str());
} else {
return mx::fft::fftn(a, s);
return func3(a, s);
}
};
}
Expand All @@ -66,7 +67,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFT2OpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand Down
2 changes: 1 addition & 1 deletion src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a,
a->shape().begin() + non_none_indices, a->shape().end());
up = mx::reshape(std::move(up), std::move(up_reshape));

mx::Shape axes(arr_indices.size(), 0);
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return {std::move(arr_indices), std::move(up), std::move(axes)};
}
Expand Down
12 changes: 11 additions & 1 deletion src/metal.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
#include "src/bindings.h"
#include "mlx/backend/gpu/device_info.h"

namespace metal_ops {

const std::unordered_map<std::string, std::variant<std::string, size_t>>&
DeviceInfo() {
return mx::gpu::device_info(0);
}

} // namespace metal_ops

void InitMetal(napi_env env, napi_value exports) {
napi_value metal = ki::CreateObject(env);
Expand All @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) {
"isAvailable", &mx::metal::is_available,
"startCapture", &mx::metal::start_capture,
"stopCapture", &mx::metal::stop_capture,
"deviceInfo", &mx::metal::device_info);
"deviceInfo", &metal_ops::DeviceInfo);
}
41 changes: 27 additions & 14 deletions src/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,21 @@ mx::array Full(std::variant<int, mx::Shape> shape,
ScalarOrArray vals,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::full(PutIntoVector(std::move(shape)),
return mx::full(PutIntoShape(std::move(shape)),
ToArray(std::move(vals), std::move(dtype)),
s);
}

mx::array Zeros(std::variant<int, mx::Shape> shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s);
return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s);
}

mx::array Ones(std::variant<int, mx::Shape> shape,
std::optional<mx::Dtype> dtype,
mx::StreamOrDevice s) {
return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s);
return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s);
}

mx::array Eye(int n,
Expand Down Expand Up @@ -303,8 +303,9 @@ std::vector<mx::array> Split(const mx::array& a,
if (auto i = std::get_if<int>(&indices); i) {
return mx::split(a, *i, axis.value_or(0), s);
} else {
return mx::split(a, std::move(std::get<std::vector<int>>(indices)),
axis.value_or(0), s);
auto& v = std::get<std::vector<int>>(indices);
mx::Shape shape_indices(v.begin(), v.end());
return mx::split(a, std::move(shape_indices), axis.value_or(0), s);
}
}

Expand Down Expand Up @@ -544,7 +545,7 @@ mx::array ConvTranspose1d(
mx::StreamOrDevice s) {
return mx::conv_transpose1d(input, weight, stride.value_or(1),
padding.value_or(0), dilation.value_or(1),
groups.value_or(1), s);
/*output_padding=*/0, groups.value_or(1), s);
}

mx::array ConvTranspose2d(
Expand Down Expand Up @@ -574,7 +575,7 @@ mx::array ConvTranspose2d(
dilation_pair = std::move(*p);

return mx::conv_transpose2d(input, weight, stride_pair, padding_pair,
dilation_pair, groups.value_or(1), s);
dilation_pair, {0, 0}, groups.value_or(1), s);
}

mx::array ConvTranspose3d(
Expand Down Expand Up @@ -604,7 +605,7 @@ mx::array ConvTranspose3d(
dilation_tuple = std::move(*p);

return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple,
dilation_tuple, groups.value_or(1), s);
dilation_tuple, {0, 0, 0}, groups.value_or(1), s);
}

mx::array ConvGeneral(
Expand Down Expand Up @@ -811,7 +812,9 @@ void InitOps(napi_env env, napi_value exports) {
"stopGradient", &mx::stop_gradient,
"sigmoid", &mx::sigmoid,
"power", BinOpWrapper(&mx::power),
"arange", &ops::ARange,
"arange", &ops::ARange);

ki::Set(env, exports,
"linspace", &ops::Linspace,
"kron", &mx::kron,
"take", &ops::Take,
Expand Down Expand Up @@ -848,12 +851,16 @@ void InitOps(napi_env env, napi_value exports) {
"min", DimOpWrapper(&mx::min),
"max", DimOpWrapper(&mx::max),
"logcumsumexp", CumOpWrapper(&mx::logcumsumexp),
"logsumexp", DimOpWrapper(&mx::logsumexp),
"logsumexp", DimOpWrapper(&mx::logsumexp));

ki::Set(env, exports,
"mean", DimOpWrapper(&mx::mean),
"variance", &ops::Var,
"std", &ops::Std,
"split", &ops::Split,
"argmin", &ops::ArgMin,
"argmin", &ops::ArgMin);

ki::Set(env, exports,
"argmax", &ops::ArgMax,
"sort", &ops::Sort,
"argsort", &ops::ArgSort,
Expand All @@ -864,7 +871,9 @@ void InitOps(napi_env env, napi_value exports) {
"blockMaskedMM", &mx::block_masked_mm,
"gatherMM", &mx::gather_mm,
"gatherQMM", &mx::gather_qmm,
"softmax", &ops::Softmax,
"softmax", &ops::Softmax);

ki::Set(env, exports,
"concatenate", &ops::Concatenate,
"concat", &ops::Concatenate,
"stack", &ops::Stack,
Expand All @@ -876,7 +885,9 @@ void InitOps(napi_env env, napi_value exports) {
"cumsum", CumOpWrapper(&mx::cumsum),
"cumprod", CumOpWrapper(&mx::cumprod),
"cummax", CumOpWrapper(&mx::cummax),
"cummin", CumOpWrapper(&mx::cummin),
"cummin", CumOpWrapper(&mx::cummin));

ki::Set(env, exports,
"conj", &mx::conjugate,
"conjugate", &mx::conjugate,
"convolve", &ops::Convolve,
Expand Down Expand Up @@ -912,7 +923,9 @@ void InitOps(napi_env env, napi_value exports) {
"bitwiseXor", BinOpWrapper(&mx::bitwise_xor),
"leftShift", BinOpWrapper(&mx::left_shift),
"rightShift", BinOpWrapper(&mx::right_shift),
"view", &mx::view,
"view", &mx::view);

ki::Set(env, exports,
"hadamardTransform", &mx::hadamard_transform,
"einsumPath", &mx::einsum_path,
"einsum", &mx::einsum,
Expand Down
2 changes: 1 addition & 1 deletion src/utils.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "src/array.h"
#include "src/utils.h"

mx::Shape PutIntoVector(std::variant<int, mx::Shape> shape) {
mx::Shape PutIntoShape(std::variant<int, mx::Shape> shape) {
if (auto i = std::get_if<int>(&shape); i)
return {*i};
return std::move(std::get<mx::Shape>(shape));
Expand Down
48 changes: 44 additions & 4 deletions src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,45 @@

namespace mx = mlx::core;

// Teach kizunapi how to serialize/deserialize SmallVector<T> (used for Shape
// and other types in MLX >= 0.26). Mirrors the std::vector<T> specialization.
namespace ki {

template<typename T, unsigned N, typename A>
struct Type<mlx::core::SmallVector<T, N, A>> {
static constexpr const char* name = "Array";
static napi_status ToNode(napi_env env,
const mlx::core::SmallVector<T, N, A>& vec,
napi_value* result) {
napi_status s = napi_create_array_with_length(env, vec.size(), result);
if (s != napi_ok) return s;
for (size_t i = 0; i < vec.size(); ++i) {
napi_value el;
s = ConvertToNode(env, vec[i], &el);
if (s != napi_ok) return s;
s = napi_set_element(env, *result, i, el);
if (s != napi_ok) return s;
}
return napi_ok;
}
static std::optional<mlx::core::SmallVector<T, N, A>> FromNode(
napi_env env, napi_value value) {
// Read as std::vector then convert to SmallVector.
auto vec = Type<std::vector<T>>::FromNode(env, value);
if (!vec) return std::nullopt;
return mlx::core::SmallVector<T, N, A>(vec->begin(), vec->end());
}
};

} // namespace ki

using OptionalAxes = std::variant<std::monostate, int, std::vector<int>>;
using ScalarOrArray = std::variant<bool, float, mx::array>;

// Read args into a vector of types.
template<typename T>
bool ReadArgs(ki::Arguments* args, std::vector<T>* results) {
// Read args into a container of types (vector or SmallVector).
template<typename Container>
bool ReadArgs(ki::Arguments* args, Container* results) {
using T = typename Container::value_type;
while (args->RemainingsLength() > 0) {
std::optional<T> a = args->GetNext<T>();
if (!a) {
Expand Down Expand Up @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) {
symbol, ki::MemberFunction(&ToString<T>));
}

// If input is one int, put it into a Shape, otherwise just return the Shape.
mx::Shape PutIntoShape(std::variant<int, mx::Shape> shape);

// If input is one int, put it into a vector, otherwise just return the vector.
std::vector<int> PutIntoVector(std::variant<int, std::vector<int>> shape);
inline std::vector<int> PutIntoVector(std::variant<int, std::vector<int>> v) {
if (auto i = std::get_if<int>(&v); i)
return {*i};
return std::move(std::get<std::vector<int>>(v));
}

// Get axis arg from js value.
std::vector<int> GetReduceAxes(OptionalAxes value, int dims);
Expand Down