Print local sizes and strides of each nn.Linear #4067

Draft: wants to merge 9 commits into main

Conversation

@wujingyue (Collaborator) commented Mar 12, 2025

$ mpirun -np 2 --output-filename /tmp/test_deepseek_v3 pytest tests/python/test_deepseek_v3.py -k backward -v -s --only-mpi --runxfail
$ cat /tmp/test_deepseek_v3/1/rank.0/stdout 

self_attn.q_a_proj: inp = ((torch.Size([1, 2048, 7168]), (14680064, 7168, 1)),), weight = (torch.Size([1536, 7168]), (7168, 1)), out = (torch.Size([1, 2048, 1536]), (3145728, 1536, 1))
self_attn.q_b_proj: inp = ((torch.Size([1, 2048, 1536]), (3145728, 1536, 1)),), weight = (torch.Size([12288, 1536]), (1536, 1)), out = (torch.Size([1, 2048, 12288]), (25165824, 12288, 1))
self_attn.kv_a_proj_with_mqa: inp = ((torch.Size([1, 2048, 7168]), (14680064, 7168, 1)),), weight = (torch.Size([576, 7168]), (7168, 1)), out = (torch.Size([1, 2048, 576]), (1179648, 576, 1))
self_attn.kv_b_proj: inp = ((torch.Size([1, 2048, 512]), (1048576, 512, 1)),), weight = (torch.Size([16384, 512]), (512, 1)), out = (torch.Size([1, 2048, 16384]), (33554432, 16384, 1))
self_attn.o_proj: inp = ((torch.Size([1, 2048, 8192]), (16777216, 8192, 1)),), weight = (torch.Size([7168, 8192]), (8192, 1)), out = (torch.Size([1, 2048, 7168]), (14680064, 7168, 1))
mlp.experts.1.gate_proj: inp = ((torch.Size([872, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([872, 1024]), (1024, 1))
mlp.experts.1.up_proj: inp = ((torch.Size([872, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([872, 1024]), (1024, 1))
mlp.experts.1.down_proj: inp = ((torch.Size([872, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([872, 7168]), (7168, 1))
mlp.experts.3.gate_proj: inp = ((torch.Size([874, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([874, 1024]), (1024, 1))
mlp.experts.3.up_proj: inp = ((torch.Size([874, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([874, 1024]), (1024, 1))
mlp.experts.3.down_proj: inp = ((torch.Size([874, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([874, 7168]), (7168, 1))
mlp.experts.5.gate_proj: inp = ((torch.Size([876, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([876, 1024]), (1024, 1))
mlp.experts.5.up_proj: inp = ((torch.Size([876, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([876, 1024]), (1024, 1))
mlp.experts.5.down_proj: inp = ((torch.Size([876, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([876, 7168]), (7168, 1))
mlp.experts.7.gate_proj: inp = ((torch.Size([903, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([903, 1024]), (1024, 1))
mlp.experts.7.up_proj: inp = ((torch.Size([903, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([903, 1024]), (1024, 1))
mlp.experts.7.down_proj: inp = ((torch.Size([903, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([903, 7168]), (7168, 1))
mlp.experts.9.gate_proj: inp = ((torch.Size([906, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([906, 1024]), (1024, 1))
mlp.experts.9.up_proj: inp = ((torch.Size([906, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([906, 1024]), (1024, 1))
mlp.experts.9.down_proj: inp = ((torch.Size([906, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([906, 7168]), (7168, 1))
mlp.experts.11.gate_proj: inp = ((torch.Size([883, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([883, 1024]), (1024, 1))
mlp.experts.11.up_proj: inp = ((torch.Size([883, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([883, 1024]), (1024, 1))
mlp.experts.11.down_proj: inp = ((torch.Size([883, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([883, 7168]), (7168, 1))
mlp.experts.13.gate_proj: inp = ((torch.Size([887, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([887, 1024]), (1024, 1))
mlp.experts.13.up_proj: inp = ((torch.Size([887, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([887, 1024]), (1024, 1))
mlp.experts.13.down_proj: inp = ((torch.Size([887, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([887, 7168]), (7168, 1))
mlp.experts.15.gate_proj: inp = ((torch.Size([799, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([799, 1024]), (1024, 1))
mlp.experts.15.up_proj: inp = ((torch.Size([799, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([799, 1024]), (1024, 1))
mlp.experts.15.down_proj: inp = ((torch.Size([799, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([799, 7168]), (7168, 1))
mlp.experts.17.gate_proj: inp = ((torch.Size([888, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([888, 1024]), (1024, 1))
mlp.experts.17.up_proj: inp = ((torch.Size([888, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([888, 1024]), (1024, 1))
mlp.experts.17.down_proj: inp = ((torch.Size([888, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([888, 7168]), (7168, 1))
mlp.experts.19.gate_proj: inp = ((torch.Size([827, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([827, 1024]), (1024, 1))
mlp.experts.19.up_proj: inp = ((torch.Size([827, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([827, 1024]), (1024, 1))
mlp.experts.19.down_proj: inp = ((torch.Size([827, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([827, 7168]), (7168, 1))
mlp.experts.21.gate_proj: inp = ((torch.Size([803, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([803, 1024]), (1024, 1))
mlp.experts.21.up_proj: inp = ((torch.Size([803, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([803, 1024]), (1024, 1))
mlp.experts.21.down_proj: inp = ((torch.Size([803, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([803, 7168]), (7168, 1))
mlp.experts.22.gate_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.22.up_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.22.down_proj: inp = ((torch.Size([2048, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([2048, 7168]), (7168, 1))
mlp.experts.23.gate_proj: inp = ((torch.Size([480, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([480, 1024]), (1024, 1))
mlp.experts.23.up_proj: inp = ((torch.Size([480, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([480, 1024]), (1024, 1))
mlp.experts.23.down_proj: inp = ((torch.Size([480, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([480, 7168]), (7168, 1))
mlp.experts.25.gate_proj: inp = ((torch.Size([174, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([174, 1024]), (1024, 1))
mlp.experts.25.up_proj: inp = ((torch.Size([174, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([174, 1024]), (1024, 1))
mlp.experts.25.down_proj: inp = ((torch.Size([174, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([174, 7168]), (7168, 1))
mlp.experts.27.gate_proj: inp = ((torch.Size([52, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([52, 1024]), (1024, 1))
mlp.experts.27.up_proj: inp = ((torch.Size([52, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([52, 1024]), (1024, 1))
mlp.experts.27.down_proj: inp = ((torch.Size([52, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([52, 7168]), (7168, 1))
mlp.experts.29.gate_proj: inp = ((torch.Size([14, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([14, 1024]), (1024, 1))
mlp.experts.29.up_proj: inp = ((torch.Size([14, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([14, 1024]), (1024, 1))
mlp.experts.29.down_proj: inp = ((torch.Size([14, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([14, 7168]), (7168, 1))
mlp.experts.31.gate_proj: inp = ((torch.Size([2, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2, 1024]), (1024, 1))
mlp.experts.31.up_proj: inp = ((torch.Size([2, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2, 1024]), (1024, 1))
mlp.experts.31.down_proj: inp = ((torch.Size([2, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([2, 7168]), (7168, 1))
mlp.experts.34.gate_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.34.up_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.34.down_proj: inp = ((torch.Size([2048, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([2048, 7168]), (7168, 1))
mlp.experts.40.gate_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.40.up_proj: inp = ((torch.Size([2048, 7168]), (7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([2048, 1024]), (1024, 1))
mlp.experts.40.down_proj: inp = ((torch.Size([2048, 1024]), (1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([2048, 7168]), (7168, 1))
mlp.shared_experts.gate_proj: inp = ((torch.Size([1, 2048, 7168]), (14680064, 7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([1, 2048, 1024]), (2097152, 1024, 1))
mlp.shared_experts.up_proj: inp = ((torch.Size([1, 2048, 7168]), (14680064, 7168, 1)),), weight = (torch.Size([1024, 7168]), (7168, 1)), out = (torch.Size([1, 2048, 1024]), (2097152, 1024, 1))
mlp.shared_experts.down_proj: inp = ((torch.Size([1, 2048, 1024]), (2097152, 1024, 1)),), weight = (torch.Size([7168, 1024]), (1024, 1)), out = (torch.Size([1, 2048, 7168]), (14680064, 7168, 1))

The test requires a patch to run and therefore xfails by default. However,
it's still useful to run backprop with the patch applied to collect shapes
and layouts.
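
For reference, a minimal sketch of how output like the above can be produced: a forward hook is registered on every nn.Linear submodule, and DTensors are unwrapped to their local shards before sizes and strides are reported. The wrapper names (print_size_hook, register_print_hooks) and the DTensor import path are illustrative assumptions; only the inner get_size_and_strides logic is quoted from the review excerpt below, and the import path may differ on older PyTorch versions.

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


def print_size_hook(name: str):
    # Build a forward hook that reports the local sizes and strides of the
    # module's inputs, weight, and output, keyed by the module's name.
    def hook(module: nn.Module, inp, out):
        def get_size_and_strides(x):
            # Unwrap DTensors to their local shard before reporting.
            if isinstance(x, DTensor):
                return get_size_and_strides(x.to_local())
            if isinstance(x, torch.Tensor):
                return x.size(), x.stride()
            return x

        inp_sizes = tuple(get_size_and_strides(x) for x in inp)
        weight_sizes = get_size_and_strides(module.weight)
        out_sizes = get_size_and_strides(out)
        print(f"{name}: inp = {inp_sizes}, weight = {weight_sizes}, out = {out_sizes}")

    return hook


def register_print_hooks(layer: nn.Module):
    # Attach the hook to every nn.Linear submodule, using its qualified name.
    for name, module in layer.named_modules():
        if isinstance(module, nn.Linear):
            module.register_forward_hook(print_size_hook(name))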

github-actions bot commented Mar 12, 2025

Review updated until commit d886082

Description

  • Adds a new test for transformer layer parallelization

  • Prints local sizes and strides of each nn.Linear module

  • Sets up the process group and the CUDA profiler

  • Parametrizes tests for forward, backward, and inference


Changes walkthrough 📝

Relevant files

Tests: tests/python/test_deepseek_v3.py (+225/-0)
Add transformer layer parallelization test and logging

  • Added a new test file for transformer layer parallelization
  • Implemented setup for the process group and the CUDA profiler (see the sketch after this list)
  • Parametrized tests for different compute types (forward, backward,
    inference)
  • Registered forward hooks to print the sizes and strides of nn.Linear
    modules
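
A rough sketch of the scaffolding described in the bullets above, assuming plain torch.distributed and pytest primitives; the real test builds and parallelizes a full DeepSeek V3 transformer layer, and the repository's actual fixture and marker names may differ:

import os

import pytest
import torch
import torch.distributed as dist


@pytest.fixture
def setup_process_group():
    # One process per GPU under mpirun; rank and world size are read from the
    # environment (the exact variable names depend on the launcher).
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    yield
    dist.destroy_process_group()


@pytest.mark.parametrize("compute_type", ["forward", "backward", "inference"])
def test_linear_shapes(setup_process_group, compute_type):
    # Stand-in for the real transformer layer: a single nn.Linear with the
    # q_a_proj dimensions seen in the output above.
    layer = torch.nn.Linear(7168, 1536, bias=False, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(1, 2048, 7168, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    torch.cuda.profiler.start()  # bracket the region of interest for the CUDA profiler
    if compute_type == "inference":
        with torch.no_grad():
            layer(x)
    else:
        out = layer(x)
        if compute_type == "backward":
            out.sum().backward()
    torch.cuda.profiler.stop()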

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Impact

    The PR introduces print statements that could impact performance, especially in a distributed environment. Consider using logging with a higher level (e.g., DEBUG) or conditional printing.

        f"{name}: inp = {inp_sizes}, weight = {weight_sizes}, out = {out_sizes}"
    )
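
    One way to address this, sketched here rather than prescribed, is to route the message through Python's logging module at DEBUG level so it is silent unless explicitly enabled (the variable names assume the surrounding hook is otherwise unchanged):

    import logging

    logger = logging.getLogger(__name__)

    # Inside the hook body: log at DEBUG level instead of printing
    # unconditionally, so the output is only produced when collecting
    # shapes and layouts.
    logger.debug(
        "%s: inp = %s, weight = %s, out = %s",
        name, inp_sizes, weight_sizes, out_sizes,
    )
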
    Code Duplication

    The get_size_and_strides function is defined inside the print_size_hook function, which can lead to code duplication if used elsewhere. Consider moving it to a higher scope or a utility module.

    def hook(module, inp, out):
        def get_size_and_strides(x):
            if isinstance(x, DTensor):
                return get_size_and_strides(x.to_local())
    
            if isinstance(x, torch.Tensor):
                return x.size(), x.stride()
    
            return x
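
    A possible refactor, sketched under the assumption that the helper has no other callers yet, is to lift get_size_and_strides to module scope so hooks and future utilities can share one definition:

    import torch
    from torch.distributed.tensor import DTensor  # import path may vary by PyTorch version


    def get_size_and_strides(x):
        # Shared module-level helper: unwrap DTensors to their local shard,
        # report (size, stride) for tensors, and pass anything else through.
        if isinstance(x, DTensor):
            return get_size_and_strides(x.to_local())
        if isinstance(x, torch.Tensor):
            return x.size(), x.stride()
        return x

    The hook body would then call this shared helper instead of defining its own nested copy.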
    Error Handling

    The print_size_hook function does not handle potential errors, such as when x is neither a DTensor nor a torch.Tensor. Consider adding error handling or assertions.

    def hook(module, inp, out):
        def get_size_and_strides(x):
            if isinstance(x, DTensor):
                return get_size_and_strides(x.to_local())
    
            if isinstance(x, torch.Tensor):
                return x.size(), x.stride()
    
            return x
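
    One hedged option, assuming a non-tensor value reaching the helper indicates a bug rather than a legitimate input, is to fail loudly instead of passing the value through:

    def get_size_and_strides(x):
        if isinstance(x, DTensor):
            return get_size_and_strides(x.to_local())
        if isinstance(x, torch.Tensor):
            return x.size(), x.stride()
        # Surface unexpected types instead of silently returning them.
        raise TypeError(f"expected DTensor or torch.Tensor, got {type(x)}")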
