Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yongx/test #191

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions code_style.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Clean up Python codes using Pylint & Pyink
# Googlers: please run `sudo apt install pipx; pipx install pylint --force; pipx install pyink==23.10.0` in advance

set -e # Exit immediately if any command fails

FOLDERS_TO_FORMAT=("jetstream" "experimental")
LINE_LENGTH=$(grep -E "^max-line-length=" pylintrc | cut -d '=' -f 2)

# Check for --check flag
CHECK_ONLY_PYINK_FLAGS=""
if [[ "$1" == "--check" ]]; then
CHECK_ONLY_PYINK_FLAGS="--check --diff --color"
fi

for folder in "${FOLDERS_TO_FORMAT[@]}"
do
pyink "$folder" ${CHECK_ONLY_PYINK_FLAGS} --pyink-indentation=2 --line-length=${LINE_LENGTH}
done

for folder in "${FOLDERS_TO_FORMAT[@]}"
do
# pylint doesn't change files, only reports errors.
pylint "./$folder"
done

echo "Successfully clean up all codes."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is existing make format sufficient? In case you are unaware of that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will remove the code style script.


Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,20 @@

class CollectiveMatmulTest(absltest.TestCase):

def create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)

def test_all_gather_collective_matmul(self):
key1, key2 = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
lhs = jax.random.normal(key1, shape=(1, 32), dtype=jnp.float32)
rhs = jax.random.normal(key2, shape=(32, 16), dtype=jnp.float32)
expect = lhs @ rhs

mesh = parallel.create_device_mesh(jax.devices(), (2, 4))
mesh = self.create_device_mesh()
axis_names = mesh.axis_names
rhs = jax.device_put(rhs, NamedSharding(mesh, P(None, axis_names)))
rhs = kernel.prepare_rhs_for_all_gather_collective_matmul(rhs, mesh)
Expand All @@ -61,7 +68,7 @@ def test_collective_matmul_reduce_scatter(self):
rhs = jax.random.uniform(key2, shape=(64, 64), dtype=jnp.float32)
expect = lhs @ rhs

mesh = parallel.create_device_mesh(jax.devices(), (2, 4))
mesh = self.create_device_mesh()
axis_names = mesh.axis_names
rhs = jax.device_put(rhs, NamedSharding(mesh, P(axis_names, None)))

Expand Down
11 changes: 7 additions & 4 deletions experimental/jax/tests/model/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@

class LlamaModelTest(absltest.TestCase):

@classmethod
def setUpClass(cls):
super().setUpClass()
def create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)

def test_llama(self):
# TODO: make it as an accuracy test.
mesh = self.create_device_mesh()
model_id = "meta-llama/Llama-2-7b-chat-hf"
mesh = parallel.create_device_mesh(jax.devices(), len(jax.devices()))
model_registry = ModelRegistry()

config, tokenizer = model_registry.load_model_config(
Expand Down
15 changes: 9 additions & 6 deletions experimental/jax/tests/nn/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@

class EmbeddingTest(absltest.TestCase):

def test_embedding(self):
mesh = parallel.create_device_mesh(
devices=jax.devices(),
shape=len(jax.devices()),
def create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)
vocal_size = 2048
emb_dim = 8192

def test_embedding(self):
mesh = self.create_device_mesh()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: self._create_device_mesh. Guide line is _ for private method

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

vocal_size, emb_dim = 2048, 8192
embedding_layer = nn.Embedding(
vocal_size,
emb_dim,
Expand Down
103 changes: 50 additions & 53 deletions experimental/jax/tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,80 +22,78 @@

class ModuleTest(absltest.TestCase):

def test_random_code_initialize(self):
w0, w1, w2, w3 = (
def setUp(self):
super().setUp()
self.w0, self.w1, self.w2, self.w3 = (
jnp.ones((1,)),
jnp.ones((2,)),
jnp.ones((3,)),
jnp.ones((4,)),
)
parent_module = Module()
parent_module.w0 = Parameter(w0)

h1_child_0_module = Module()
h1_child_0_module.w1 = Parameter(w1)

h1_child_1_module = Module()
h1_child_1_module.w2 = Parameter(w2)
self.parent_module = Module()
self.parent_module.w0 = Parameter(self.w0)
self.h1_child_0_module = Module()
self.h1_child_0_module.w1 = Parameter(self.w1)
self.h1_child_1_module = Module()
self.h1_child_1_module.w2 = Parameter(self.w2)
self.h2_child_0_module = Module()
self.h2_child_0_module.w3 = Parameter(self.w3)

h2_child_0_module = Module()
h2_child_0_module.w3 = Parameter(w3)
self.parent_module.child0 = self.h1_child_0_module
self.parent_module.child1 = self.h1_child_1_module
self.h1_child_0_module.child0 = self.h2_child_0_module

parent_module.child0 = h1_child_0_module
parent_module.child1 = h1_child_1_module
h1_child_0_module.child0 = h2_child_0_module

parent_module.init_weights()
print(self.parent_module)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


def test_random_code_initialize(self):
self.parent_module.init_weights()
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w0,
parent_module.w0.value,
self.w0,
self.parent_module.w0.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w1,
h1_child_0_module.w1.value,
self.w1,
self.h1_child_0_module.w1.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w2,
h1_child_1_module.w2.value,
self.w2,
self.h1_child_1_module.w2.value,
)
np.testing.assert_raises(
AssertionError,
np.testing.assert_array_equal,
w3,
h2_child_0_module.w3.value,
self.w3,
self.h2_child_0_module.w3.value,
)

def test_load_weights_dict(self):
w0, w1, w2, w3 = (
jnp.ones((1,)),
jnp.ones((2,)),
jnp.ones((3,)),
jnp.ones((4,)),
)
parent_module = Module()
parent_module.w0 = Parameter(w0)

h1_child_0_module = Module()
h1_child_0_module.w1 = Parameter(w1)

h1_child_1_module = Module()
h1_child_1_module.w2 = Parameter(w2)
parent_weight_dict = {
"w0": jnp.ones((1,)),
"child0": {
"w1": jnp.ones((2,)),
"child0": {
"w3": jnp.ones((4,)),
},
},
"child1": {
"w2": jnp.ones((3,)),
},
}

h2_child_0_module = Module()
h2_child_0_module.w3 = Parameter(w3)
self.parent_module.load_weights_dict(parent_weight_dict)

parent_module.child0 = h1_child_0_module
parent_module.child1 = h1_child_1_module
h1_child_0_module.child0 = h2_child_0_module
print(parent_module)
np.testing.assert_array_equal(self.parent_module.w0, self.w0)
np.testing.assert_array_equal(self.h1_child_0_module.w1, self.w1)
np.testing.assert_array_equal(self.h1_child_1_module.w2, self.w2)
np.testing.assert_array_equal(self.h2_child_0_module.w3, self.w3)

def test_load_weights_dict_error(self):
partial_parent_weight_dict = {
"w0": jnp.zeros((1,)),
"child0": {
Expand All @@ -105,21 +103,20 @@ def test_load_weights_dict(self):
},
},
}

child1_weight_dict = {
"w2": jnp.zeros((2,)),
"wrong_weight_not_load": jnp.zeros((2,)),
}

parent_module.load_weights_dict(partial_parent_weight_dict)
h1_child_1_module.load_weights_dict(child1_weight_dict)
self.parent_module.load_weights_dict(partial_parent_weight_dict)
self.h1_child_1_module.load_weights_dict(child1_weight_dict)

np.testing.assert_array_equal(parent_module.w0, 0)
np.testing.assert_array_equal(h1_child_0_module.w1, 0)
np.testing.assert_array_equal(h1_child_1_module.w2, 0)
np.testing.assert_array_equal(h2_child_0_module.w3, 0)
np.testing.assert_array_equal(self.parent_module.w0, 0)
np.testing.assert_array_equal(self.h1_child_0_module.w1, 0)
np.testing.assert_array_equal(self.h1_child_1_module.w2, 0)
np.testing.assert_array_equal(self.h2_child_0_module.w3, 0)

assert not h1_child_1_module.wrong_weight_not_load
assert not self.h1_child_1_module.wrong_weight_not_load


if __name__ == "__main__":
Expand Down
12 changes: 8 additions & 4 deletions experimental/jax/tests/nn/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@

class NormTest(absltest.TestCase):

def test_rmsnorm_per_device_forward(self):
mesh = parallel.create_device_mesh(
devices=jax.devices(),
shape=len(jax.devices()),
def create_device_mesh(self):
devices = jax.devices()
return parallel.create_device_mesh(
devices=devices,
shape=len(devices),
)

def test_rmsnorm_per_device_forward(self):
mesh = self.create_device_mesh()
hidden_state_size = 128
eps = 1e-6
rmsnorm_layer = nn.RMSNorm(
Expand Down
7 changes: 4 additions & 3 deletions experimental/jax/tests/parallel/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
class CollectiveOperationsTest(absltest.TestCase):

def _build_mesh(self):
axis = "x"
devices = jax.devices()
device_mesh = jax.experimental.mesh_utils.create_device_mesh(
len(jax.devices()), jax.devices()
(len(devices),), devices
)
mesh = jax.sharding.Mesh(device_mesh, ("x"))
axis = "x"
mesh = jax.sharding.Mesh(device_mesh, (axis,))
return mesh, axis

def test_reduce_scatter(self):
Expand Down
Loading