Skip to content

Commit 2c43eba

Browse files
authored
Updated test tolerances for H100 (linkedin#55)
## Summary - Updated test tolerances to pass on H100 ## Testing Done - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ``` $ make checkstyle flake8 . --exclude=.venv; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 2 files All done! ✨ 🍰 ✨ 51 files left unchanged. $ make test pytest --disable-warnings test/ --ignore=test/convergence ================================================================================ test session starts ================================================================================ platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-32.0.27, lipy-fabric-35.2.19, lipy-test-8.0.66, datadir-1.3.1 collected 114 items test/transformers/test_cross_entropy.py .......................................................... [ 50%] test/transformers/test_fused_linear_cross_entropy.py ...... [ 56%] test/transformers/test_geglu.py ........ [ 63%] test/transformers/test_rms_norm.py ................ [ 77%] test/transformers/test_rope.py ............ [ 87%] test/transformers/test_swiglu.py ........ [ 94%] test/transformers/test_trainer_integration.py ... [ 97%] test/transformers/test_transformers_monkey_patch.py . [ 98%] test/triton/test_triton_monkey_patch.py .. [100%] =============================================================================== 114 passed in 35.25s ================================================================================ (.venv) jobuser [ ~/Liger-Kernel ]$ make test-convergence HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence ================================================================================ test session starts ================================================================================ platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-32.0.27, lipy-fabric-35.2.19, lipy-test-8.0.66, datadir-1.3.1 collected 6 items test/convergence/test_mini_models.py .... [ 66%] test/convergence/test_mini_models_no_logits.py .. [100%] ================================================================================ 6 passed in 23.04s ================================================================================= ```
1 parent b418557 commit 2c43eba

File tree

5 files changed

+16
-17
lines changed

5 files changed

+16
-17
lines changed

.flake8

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ max-line-length = 120
44
exclude =
55
.git,
66
__pycache__,
7-
benchmark_internal/others
7+
benchmark_internal/others,
8+
.venv
89
# E203: https://github.com/psf/black/issues/315
910
extend-ignore=E501,B006,E731,A002,E203

examples/medusa/train.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@
3737

3838
@dataclass
3939
class ModelArguments:
40-
model_name_or_path: Optional[str] = field(
41-
default="meta-llama/Meta-Llama-3-8B"
42-
)
40+
model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B")
4341

4442

4543
@dataclass

setup.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@
1313
package_dir={"": "src"},
1414
packages=find_namespace_packages(where="src"),
1515
classifiers=[
16-
'Development Status :: 4 - Beta',
17-
'Intended Audience :: Developers',
18-
'Intended Audience :: Science/Research',
19-
'Intended Audience :: Education',
20-
'License :: OSI Approved :: BSD License',
21-
'Programming Language :: Python :: 3',
22-
'Programming Language :: Python :: 3.8',
23-
'Programming Language :: Python :: 3.9',
24-
'Programming Language :: Python :: 3.10',
25-
'Topic :: Software Development :: Libraries',
26-
'Topic :: Scientific/Engineering :: Artificial Intelligence',
16+
"Development Status :: 4 - Beta",
17+
"Intended Audience :: Developers",
18+
"Intended Audience :: Science/Research",
19+
"Intended Audience :: Education",
20+
"License :: OSI Approved :: BSD License",
21+
"Programming Language :: Python :: 3",
22+
"Programming Language :: Python :: 3.8",
23+
"Programming Language :: Python :: 3.9",
24+
"Programming Language :: Python :: 3.10",
25+
"Topic :: Software Development :: Libraries",
26+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
2727
],
2828
keywords="triton,kernels,LLM training,deep learning,Hugging Face,PyTorch,GPU optimization",
2929
include_package_data=True,

test/transformers/test_cross_entropy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _test_correctness_not_last_layer_once(
101101
[
102102
(0.1, torch.bfloat16, 1e-8, 5e-2),
103103
(1.0, torch.bfloat16, 1e-8, 5e-2),
104-
(10.0, torch.bfloat16, 1e-8, 5e-2),
104+
(10.0, torch.bfloat16, 1e-7, 5e-2),
105105
(0.1, torch.float32, 1e-8, 1e-6),
106106
(1.0, torch.float32, 1e-8, 1e-6),
107107
(10.0, torch.float32, 1e-8, 1e-6),

test/transformers/test_rms_norm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def forward(self, hidden_states):
4141
@pytest.mark.parametrize(
4242
"dtype, atol, rtol",
4343
[
44-
(torch.float32, 1e-4, 1e-7),
44+
(torch.float32, 1e-4, 1e-6),
4545
(torch.bfloat16, 5.0, 1e-5),
4646
],
4747
)

0 commit comments

Comments
 (0)