Registered buffers not moved to correct device when using DeepSpeed Stage 3 #20258

Open
amorehead opened this issue Sep 6, 2024 · 2 comments
Labels
bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.4.x

Comments

@amorehead
Contributor

Bug description

Using the following DeepSpeedStrategy configuration

_target_: lightning.pytorch.strategies.DeepSpeedStrategy
zero_optimization: true
stage: 3
allgather_bucket_size: 2e8
reduce_bucket_size: 2e8
offload_optimizer: false
offload_parameters: false
partition_activations: false
cpu_checkpointing: false
contiguous_gradients: false
overlap_comm: false

I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1-2) where buffers registered via register_buffer() within the nn.Module submodules of my LightningModule's main lit_module.network nn.Module are not moved to the correct device when training lit_module.network. In particular, I am trying to register buffers as

distance_bins_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0])
self.register_buffer("distance_bins", distance_bins_tensor)

within the various submodules of my lit_module.network. When my optimizer tries to perform a step, I get the error

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!

when trying to use these registered buffers, e.g., by multiplying them by feature tensors loaded onto (in this case) cuda:6.
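
The pattern described above can be sketched in plain PyTorch, along with a small diagnostic for spotting buffers left on the wrong device. This is only an illustration under stated assumptions: `DistanceBinner`, `find_misplaced_buffers`, and `move_buffers_to` are hypothetical names that do not come from the original code, and the helpers are a possible debugging aid and workaround idea, not a confirmed fix for the DeepSpeed stage 3 behavior.

```python
import torch
from torch import nn


class DistanceBinner(nn.Module):
    """Hypothetical submodule that registers a buffer as in the report."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)
        distance_bins_tensor = torch.tensor([0.0, 1.0, 2.0, 3.0])
        self.register_buffer("distance_bins", distance_bins_tensor)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Fails with a device-mismatch RuntimeError if the buffer and the
        # features live on different devices.
        return self.proj(features) * self.distance_bins


def _same_device(a: torch.device, b: torch.device) -> bool:
    # Treat a missing index (e.g. "cuda") as matching any index of that type.
    if a.type != b.type:
        return False
    return a.index is None or b.index is None or a.index == b.index


def find_misplaced_buffers(module: nn.Module, device: torch.device) -> list[str]:
    """Return qualified names of buffers that are not on `device`."""
    return [
        name
        for name, buf in module.named_buffers()
        if not _same_device(buf.device, device)
    ]


def move_buffers_to(module: nn.Module, device: torch.device) -> None:
    """Move every registered buffer in `module` (recursively) to `device`."""
    for submodule in module.modules():
        for name, buf in list(submodule.named_buffers(recurse=False)):
            # Assigning a tensor to an existing buffer name keeps it registered.
            setattr(submodule, name, buf.to(device))
```

As an assumption, one could call `find_misplaced_buffers(self.network, self.device)` from a LightningModule hook such as `on_train_start` to list the affected buffers, and `move_buffers_to(self.network, self.device)` to try relocating them. Whether this resolves the stage 3 case has not been verified here.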

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

Current environment
  • CUDA:
    - GPU:
      - NVIDIA A100 80GB PCIe
      - NVIDIA A100 80GB PCIe
    - available: True
    - version: 11.8
  • Lightning:
    - adam-atan2-pytorch: 0.0.10
    - alphafold3-pytorch: 0.0.41
    - alphafold3-pytorch-lightning-hydra: 0.1.111
    - frame-averaging-pytorch: 0.0.19
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - pytorch-lightning: 2.4.0
    - rotary-embedding-torch: 0.6.1
    - torch: 2.3.0+cu118
    - torch-geometric: 2.5.3
    - torchaudio: 2.3.0+cu118
    - torchmetrics: 1.4.1
    - torchtyping: 0.1.4
    - torchvision: 0.18.0+cu118
  • Packages:
    - adam-atan2-pytorch: 0.0.10
    - aiofiles: 23.2.1
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - alembic: 1.13.1
    - alphafold3-pytorch: 0.0.41
    - alphafold3-pytorch-lightning-hydra: 0.1.111
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.4.0
    - appdirs: 1.4.4
    - argcomplete: 3.3.0
    - asttokens: 2.4.1
    - async-timeout: 4.0.3
    - attrs: 23.2.0
    - autopage: 0.5.2
    - beartype: 0.18.5
    - beautifulsoup4: 4.12.3
    - biopandas: 0.5.1.dev0
    - biopython: 1.83
    - bioservices: 1.11.2
    - cattrs: 23.2.3
    - certifi: 2024.8.30
    - cfgv: 3.4.0
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - cliff: 4.7.0
    - cmaes: 0.10.0
    - cmd2: 2.4.3
    - colorama: 0.4.6
    - colorlog: 6.8.2
    - colt5-attention: 0.11.0
    - comm: 0.2.2
    - contourpy: 1.2.1
    - cycler: 0.12.1
    - debugpy: 1.8.1
    - decorator: 5.1.1
    - deepdiff: 7.0.1
    - deepspeed: 0.15.0
    - distlib: 0.3.8
    - docker-pycreds: 0.4.0
    - easydev: 0.13.2
    - einops: 0.8.0
    - einx: 0.2.2
    - environs: 11.0.0
    - exceptiongroup: 1.2.1
    - executing: 2.0.1
    - fastapi: 0.112.2
    - ffmpy: 0.4.0
    - filelock: 3.13.1
    - fonttools: 4.52.4
    - frame-averaging-pytorch: 0.0.19
    - freetype-py: 2.3.0
    - frozendict: 2.4.4
    - frozenlist: 1.4.1
    - fsspec: 2024.2.0
    - gemmi: 0.6.6
    - gevent: 24.2.1
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - gradio: 4.43.0
    - gradio-client: 1.3.0
    - gradio-molecule3d: 0.0.5
    - graphein: 1.7.6
    - greenlet: 3.0.3
    - grequests: 0.7.0
    - h11: 0.14.0
    - hjson: 3.1.0
    - httpcore: 1.0.5
    - httpx: 0.27.2
    - huggingface-hub: 0.23.4
    - hydra-colorlog: 1.2.0
    - hydra-core: 1.3.2
    - hydra-optuna-sweeper: 1.2.0
    - identify: 2.5.36
    - idna: 3.7
    - importlib-resources: 6.4.4
    - iniconfig: 2.0.0
    - ipykernel: 6.29.4
    - ipython: 8.24.0
    - jaxtyping: 0.2.28
    - jedi: 0.19.1
    - jinja2: 3.1.3
    - joblib: 1.4.2
    - jupyter-client: 8.6.2
    - jupyter-core: 5.7.2
    - kiwisolver: 1.4.5
    - lightning: 2.4.0
    - lightning-utilities: 0.11.6
    - line-profiler: 4.1.3
    - local-attention: 1.9.1
    - loguru: 0.7.2
    - looseversion: 1.1.2
    - lxml: 5.2.2
    - mako: 1.3.5
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - marshmallow: 3.21.3
    - matplotlib: 3.8.4
    - matplotlib-inline: 0.1.7
    - mdurl: 0.1.2
    - mmtf-python: 1.1.3
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - multipledispatch: 1.0.0
    - munkres: 1.1.4
    - nest-asyncio: 1.6.0
    - networkx: 3.2.1
    - ninja: 1.11.1.1
    - nodeenv: 1.8.0
    - numpy: 1.23.5
    - nvidia-cublas-cu11: 11.11.3.6
    - nvidia-cuda-cupti-cu11: 11.8.87
    - nvidia-cuda-nvrtc-cu11: 11.8.89
    - nvidia-cuda-runtime-cu11: 11.8.89
    - nvidia-cudnn-cu11: 8.7.0.84
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.3.0.86
    - nvidia-cusolver-cu11: 11.4.1.48
    - nvidia-cusparse-cu11: 11.7.5.86
    - nvidia-ml-py: 12.560.30
    - nvidia-nccl-cu11: 2.20.5
    - nvidia-nvtx-cu11: 11.8.86
    - omegaconf: 2.3.0
    - optree: 0.11.0
    - optuna: 2.10.1
    - ordered-set: 4.1.0
    - orjson: 3.10.7
    - packaging: 24.0
    - pandas: 1.5.3
    - parso: 0.8.4
    - pbr: 6.0.0
    - pdbeccdutils: 0.8.5
    - pexpect: 4.9.0
    - pillow: 10.2.0
    - pip: 24.0
    - pipx: 1.5.0
    - platformdirs: 4.2.2
    - plotly: 5.22.0
    - pluggy: 1.5.0
    - polars: 1.3.0
    - pre-commit: 3.7.1
    - prettytable: 3.10.0
    - prompt-toolkit: 3.0.45
    - protobuf: 4.25.4
    - psutil: 5.9.8
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py-cpuinfo: 9.0.0
    - pycairo: 1.26.0
    - pydantic: 2.8.2
    - pydantic-core: 2.20.1
    - pydub: 0.25.1
    - pygments: 2.18.0
    - pyparsing: 3.1.2
    - pyperclip: 1.8.2
    - pytest: 8.2.1
    - python-dateutil: 2.9.0
    - python-dotenv: 1.0.1
    - python-multipart: 0.0.9
    - pytorch-lightning: 2.4.0
    - pytz: 2024.1
    - pyyaml: 6.0.1
    - pyzmq: 26.0.3
    - rdkit: 2024.3.2
    - reportlab: 4.1.0
    - requests: 2.32.2
    - requests-cache: 1.2.0
    - retrying: 1.3.4
    - rich: 13.7.1
    - rich-click: 1.8.2
    - rlpycairo: 0.2.0
    - rootutils: 1.0.7
    - rotary-embedding-torch: 0.6.1
    - ruff: 0.6.4
    - scikit-learn: 1.5.0
    - scipy: 1.13.1
    - seaborn: 0.13.2
    - semantic-version: 2.10.0
    - sentry-sdk: 2.12.0
    - setproctitle: 1.3.3
    - setuptools: 70.0.0
    - sh: 2.0.7
    - shellingham: 1.5.4
    - shortuuid: 1.0.13
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - sqlalchemy: 2.0.30
    - stack-data: 0.6.3
    - starlette: 0.38.4
    - stevedore: 5.2.0
    - suds-community: 1.1.2
    - sympy: 1.12
    - taylor-series-linear-attention: 0.1.12
    - tenacity: 8.3.0
    - threadpoolctl: 3.5.0
    - timeout-decorator: 0.5.0
    - tomli: 2.0.1
    - tomlkit: 0.12.0
    - torch: 2.3.0+cu118
    - torch-geometric: 2.5.3
    - torchaudio: 2.3.0+cu118
    - torchmetrics: 1.4.1
    - torchtyping: 0.1.4
    - torchvision: 0.18.0+cu118
    - tornado: 6.4
    - tqdm: 4.66.4
    - traitlets: 5.14.3
    - triton: 2.3.0
    - typeguard: 2.13.3
    - typer: 0.12.5
    - typing-extensions: 4.11.0
    - tzdata: 2024.1
    - unicodedata2: 15.1.0
    - url-normalize: 1.4.3
    - urllib3: 2.2.1
    - userpath: 1.9.2
    - uvicorn: 0.30.6
    - virtualenv: 20.26.2
    - wandb: 0.16.6
    - wcwidth: 0.2.13
    - websockets: 12.0
    - wget: 3.2
    - wheel: 0.43.0
    - wrapt: 1.16.0
    - xarray: 2024.3.0
    - xmltodict: 0.13.0
    - yarl: 1.9.4
    - zope.event: 5.0
    - zope.interface: 6.4.post2
  • System:
    - OS: Linux
    - architecture:
      - 64bit
      - ELF
    - processor: x86_64
    - python: 3.10.14
    - release: 4.18.0-553.16.1.el8_10.x86_64
    - version: #1 SMP Thu Aug 8 07:11:46 EDT 2024

More info

No response

@amorehead added the bug and needs triage labels on Sep 6, 2024
@cstsunfu

I have a similar issue. Any updates on this?

@amorehead
Contributor Author

Not yet.
