Skip to content

Commit

Permalink
Update to MLX 0.20.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zcbenz committed Nov 18, 2024
1 parent 6d8b629 commit 6a4182a
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 8 deletions.
8 changes: 6 additions & 2 deletions src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,12 @@ std::pair<bool, mx::array> SliceUpdate(
// Pre process tuple.
mx::array up = ToArray(std::move(vals), a->dtype());

// Remove leading singletons dimensions from the update.
std::vector<int> up_shape = GetUpShape(up);
// Remove extra leading singletons dimensions from the update.
int s = 0;
while (s < up.ndim() && up.shape(s) == 1 && (up.ndim() - s) > a->ndim()) {
s++;
}
std::vector<int> up_shape(up.shape().begin() + s, up.shape().end());
up = mx::reshape(std::move(up), up_shape.empty() ? std::vector<int>{1}
: std::move(up_shape));

Expand Down
18 changes: 14 additions & 4 deletions tests/fast.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ describe('fast', () => {
const rx = rmsNorm(x, weight, eps);
const rxFast = mx.fast.rmsNorm(x, weight, eps);
assert.isBelow(mx.abs(mx.subtract(rx, rxFast)).max().item() as number, 1e-6);

assert.throws(() => {
const x = mx.random.uniform(0, 1, [1, 5]);
mx.fast,rmsNorm(x, mx.ones([4]), 1e-5);
});
});

it('rmsNormGrad', () => {
Expand Down Expand Up @@ -213,15 +218,16 @@ describe('fast', () => {
assert.isBelow((mx.subtract(gw1, gw2).abs().max().item() as number) / (gw1.abs().mean().item() as number), 1e-5);
});

it('layerNorm', () => {
it('layerNorm', function() {
// This test is unreliable in CPU.
if (!mx.metal.isAvailable())
this.retries(4);

const tolerances = [
{dtype: mx.float32, eps: 1e-5},
{dtype: mx.float16, eps: 5e-3},
{dtype: mx.bfloat16, eps: 5e-2},
];
if (process.env.CI == 'true') {
tolerances[0].eps = 1e-4;
}

const dtypes = [mx.float32, mx.float16, mx.bfloat16];
const epss = [1e-3, 1e-5];
Expand Down Expand Up @@ -346,6 +352,10 @@ describe('fast', () => {
it('layerNormGrad', function() {
this.timeout(10 * 1000); // slow in QEMU

// This test is unreliable in CPU.
if (!mx.metal.isAvailable())
this.retries(4);

let D = 32;
const eps = 1e-5;
const f1 = (x, w, b, y) => mx.multiply(layerNorm(x, w, b, eps), y).sum();
Expand Down
48 changes: 47 additions & 1 deletion tests/vmap.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,53 @@ describe('vmap', () => {
}
});

// FIXME(zcbenz): mx.vmap currently must have all args being mx.array.
xit('vmapGather', () => {
const gather = (a: mx.array, idx: any) => {
return a.index(idx);
};

let a = mx.array([[1, 2], [3, 4]]);
let idx = mx.array(0);
let out = mx.vmap(gather, [0, null])(a, idx);
assert.deepEqual(out.tolist(), [1, 3]);
out = mx.vmap(gather, [1, null])(a, idx);
assert.deepEqual(out.tolist(), [1, 2]);

idx = mx.array([0, 1]);
out = mx.vmap(gather, [0, 0])(a, idx);
assert.deepEqual(out.tolist(), [1, 4]);

a = mx.ones([2, 3, 4]);
idx = mx.zeros(4, mx.int32);
out = mx.vmap(gather, [2, 0])(a, idx);
assert.deepEqual(out.shape, [4, 3]);

let f = mx.vmap(gather, [0, 0]);
out = f(mx.ones([2, 3, 4]), mx.zeros(2, mx.int32));
assert.deepEqual(out.shape, [2, 4]);

const gather2 = (a: mx.array, idxa: any, idxb: any) => {
return a.index(idxa, idxb);
};

a = mx.ones([2, 3, 4]);
let idxa = mx.zeros([2, 3], mx.int32);
let idxb = mx.zeros(3, mx.int32);
out = mx.vmap(gather2, [0, 0, null])(a, idxa, idxb);
assert.deepEqual(out.shape, [2, 3]);

idxa = mx.zeros([3, 1, 2], mx.int32);
idxb = mx.zeros([2, 3, 1, 2], mx.int32);
out = mx.vmap(gather2, [0, null, 0])(a, idxa, idxb);
assert.deepEqual(out.shape, [2, 3, 1, 2]);

idxa = mx.zeros([3, 1, 2], mx.int32);
idxb = mx.zeros([3, 1, 2, 2], mx.int32);
out = mx.vmap(gather2, [0, null, 3])(a, idxa, idxb);
assert.deepEqual(out.shape, [2, 3, 1, 2]);
});

it('vmapScatter', () => {
const scatter = (a: mx.array) => {
a.indexPut_(0, mx.array(0.0));
Expand Down Expand Up @@ -409,7 +456,6 @@ describe('vmap', () => {
// out = mx.vmap(constFunc, [0, null])(a, b);
// assertArrayAllTrue(mx.arrayEqual(out, mx.full([2], 2)));
// out = mx.vmap(constFunc, [null, 0])(a, b);
// FIXME(zcbenz): mx.vmap currently must have all args being mx.array.
// assertArrayAllTrue(mx.arrayEqual(out, mx.full([4], 2)));
out = mx.vmap(constFunc, [1, 1])(a, b);
assertArrayAllTrue(mx.arrayEqual(out, mx.full([3], 2)));
Expand Down

0 comments on commit 6a4182a

Please sign in to comment.