Skip to content

Commit

Permalink
[test] fix unit test breaks
Browse files: browse the repository at this point in the history
  • Loading branch information
xy12181 committed Feb 19, 2025
1 parent 2527fe4 commit 93022b9
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@

class CollectiveMatmulTest(absltest.TestCase):

def _create_device_mesh(self):
    """Return a mesh from ``parallel.create_device_mesh`` over all devices.

    Every device reported by ``jax.devices()`` is handed to the helper, and
    the number of those devices is passed as ``shape``.
    """
    available = jax.devices()
    mesh_args = {"devices": available, "shape": len(available)}
    return parallel.create_device_mesh(**mesh_args)

def test_all_gather_collective_matmul(self):
key1, key2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
lhs = jax.random.normal(key1, shape=(1, 32), dtype=jnp.float32)
rhs = jax.random.normal(key2, shape=(32, 16), dtype=jnp.float32)
expect = lhs @ rhs

mesh = parallel.create_device_mesh(jax.devices(), (2, 4))
mesh = self._create_device_mesh()
axis_names = mesh.axis_names
rhs = jax.device_put(rhs, NamedSharding(mesh, P(None, axis_names)))
rhs = kernel.prepare_rhs_for_all_gather_collective_matmul(rhs, mesh)
Expand All @@ -61,7 +68,7 @@ def test_collective_matmul_reduce_scatter(self):
rhs = jax.random.uniform(key2, shape=(64, 64), dtype=jnp.float32)
expect = lhs @ rhs

mesh = parallel.create_device_mesh(jax.devices(), (2, 4))
mesh = self._create_device_mesh()
axis_names = mesh.axis_names
rhs = jax.device_put(rhs, NamedSharding(mesh, P(axis_names, None)))

Expand Down
11 changes: 7 additions & 4 deletions experimental/jax/tests/model/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@

class LlamaModelTest(absltest.TestCase):

@classmethod
def setUpClass(cls):
super().setUpClass()
def _create_device_mesh(self):
    """Build the device mesh used by this test case.

    Queries ``jax.devices()`` once, then forwards that device list together
    with its length (as ``shape``) to ``parallel.create_device_mesh``.
    """
    local_devices = jax.devices()
    device_count = len(local_devices)
    return parallel.create_device_mesh(
        devices=local_devices,
        shape=device_count,
    )

def test_llama(self):
# TODO: make it as an accuracy test.
mesh = self._create_device_mesh()
model_id = "meta-llama/Llama-2-7b-chat-hf"
mesh = parallel.create_device_mesh(jax.devices(), len(jax.devices()))
model_registry = ModelRegistry()

config, tokenizer = model_registry.load_model_config(
Expand Down
15 changes: 9 additions & 6 deletions experimental/jax/tests/nn/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@

class EmbeddingTest(absltest.TestCase):

def test_embedding(self):
mesh = parallel.create_device_mesh(
devices=jax.devices(),
shape=len(jax.devices()),
def _create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)
vocal_size = 2048
emb_dim = 8192

def test_embedding(self):
mesh = self._create_device_mesh()
vocal_size, emb_dim = 2048, 8192
embedding_layer = nn.Embedding(
vocal_size,
emb_dim,
Expand Down
105 changes: 50 additions & 55 deletions experimental/jax/tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,80 +22,76 @@

class ModuleTest(absltest.TestCase):

def test_random_code_initialize(self):
w0, w1, w2, w3 = (
def setUp(self):
super().setUp()
self.w0, self.w1, self.w2, self.w3 = (
jnp.ones((1,)),
jnp.ones((2,)),
jnp.ones((3,)),
jnp.ones((4,)),
)
parent_module = Module()
parent_module.w0 = Parameter(w0)

h1_child_0_module = Module()
h1_child_0_module.w1 = Parameter(w1)

h1_child_1_module = Module()
h1_child_1_module.w2 = Parameter(w2)

h2_child_0_module = Module()
h2_child_0_module.w3 = Parameter(w3)

parent_module.child0 = h1_child_0_module
parent_module.child1 = h1_child_1_module
h1_child_0_module.child0 = h2_child_0_module

parent_module.init_weights()
self.parent_module = Module()
self.parent_module.w0 = Parameter(self.w0)
self.h1_child_0_module = Module()
self.h1_child_0_module.w1 = Parameter(self.w1)
self.h1_child_1_module = Module()
self.h1_child_1_module.w2 = Parameter(self.w2)
self.h2_child_0_module = Module()
self.h2_child_0_module.w3 = Parameter(self.w3)

self.parent_module.child0 = self.h1_child_0_module
self.parent_module.child1 = self.h1_child_1_module
self.h1_child_0_module.child0 = self.h2_child_0_module

def test_random_code_initialize(self):
self.parent_module.init_weights()
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w0,
parent_module.w0.value,
self.w0,
self.parent_module.w0.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w1,
h1_child_0_module.w1.value,
self.w1,
self.h1_child_0_module.w1.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w2,
h1_child_1_module.w2.value,
self.w2,
self.h1_child_1_module.w2.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w3,
h2_child_0_module.w3.value,
self.w3,
self.h2_child_0_module.w3.value,
)

def test_load_weights_dict(self):
w0, w1, w2, w3 = (
jnp.ones((1,)),
jnp.ones((2,)),
jnp.ones((3,)),
jnp.ones((4,)),
)
parent_module = Module()
parent_module.w0 = Parameter(w0)

h1_child_0_module = Module()
h1_child_0_module.w1 = Parameter(w1)

h1_child_1_module = Module()
h1_child_1_module.w2 = Parameter(w2)
parent_weight_dict = {
"w0": jnp.ones((1,)),
"child0": {
"w1": jnp.ones((2,)),
"child0": {
"w3": jnp.ones((4,)),
},
},
"child1": {
"w2": jnp.ones((3,)),
},
}

h2_child_0_module = Module()
h2_child_0_module.w3 = Parameter(w3)
self.parent_module.load_weights_dict(parent_weight_dict)

parent_module.child0 = h1_child_0_module
parent_module.child1 = h1_child_1_module
h1_child_0_module.child0 = h2_child_0_module
print(parent_module)
np.testing.assert_array_equal(self.parent_module.w0, self.w0)
np.testing.assert_array_equal(self.h1_child_0_module.w1, self.w1)
np.testing.assert_array_equal(self.h1_child_1_module.w2, self.w2)
np.testing.assert_array_equal(self.h2_child_0_module.w3, self.w3)

def test_load_weights_dict_error(self):
partial_parent_weight_dict = {
"w0": jnp.zeros((1,)),
"child0": {
Expand All @@ -105,21 +101,20 @@ def test_load_weights_dict(self):
},
},
}

child1_weight_dict = {
"w2": jnp.zeros((2,)),
"wrong_weight_not_load": jnp.zeros((2,)),
}

parent_module.load_weights_dict(partial_parent_weight_dict)
h1_child_1_module.load_weights_dict(child1_weight_dict)
self.parent_module.load_weights_dict(partial_parent_weight_dict)
self.h1_child_1_module.load_weights_dict(child1_weight_dict)

np.testing.assert_array_equal(parent_module.w0, 0)
np.testing.assert_array_equal(h1_child_0_module.w1, 0)
np.testing.assert_array_equal(h1_child_1_module.w2, 0)
np.testing.assert_array_equal(h2_child_0_module.w3, 0)
np.testing.assert_array_equal(self.parent_module.w0, 0)
np.testing.assert_array_equal(self.h1_child_0_module.w1, 0)
np.testing.assert_array_equal(self.h1_child_1_module.w2, 0)
np.testing.assert_array_equal(self.h2_child_0_module.w3, 0)

assert not h1_child_1_module.wrong_weight_not_load
assert not self.h1_child_1_module.wrong_weight_not_load


if __name__ == "__main__":
Expand Down
12 changes: 8 additions & 4 deletions experimental/jax/tests/nn/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@

class NormTest(absltest.TestCase):

def test_rmsnorm_per_device_forward(self):
mesh = parallel.create_device_mesh(
devices=jax.devices(),
shape=len(jax.devices()),
def _create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)

def test_rmsnorm_per_device_forward(self):
mesh = self._create_device_mesh()
hidden_state_size = 128
eps = 1e-6
rmsnorm_layer = nn.RMSNorm(
Expand Down
7 changes: 4 additions & 3 deletions experimental/jax/tests/parallel/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
class CollectiveOperationsTest(absltest.TestCase):

def _build_mesh(self):
axis = "x"
devices = jax.devices()
device_mesh = jax.experimental.mesh_utils.create_device_mesh(
len(jax.devices()), jax.devices()
(len(devices),), devices
)
mesh = jax.sharding.Mesh(device_mesh, ("x"))
axis = "x"
mesh = jax.sharding.Mesh(device_mesh, (axis,))
return mesh, axis

def test_reduce_scatter(self):
Expand Down

0 comments on commit 93022b9

Please sign in to comment.