Bug description
Using the DeepSpeed configuration, I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1-2) where tensors registered with register_buffer() inside the sub-nn.Modules of my LightningModule's main nn.Module, lit_model.network, are not moved to the correct device when training the model. In particular, I register buffers within the various submodules of lit_model.network. When my optimizer tries to perform a step, I get the following error when these registered buffers are used, e.g., when they are multiplied by feature tensors loaded onto (in this case) cuda:6:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!
What version are you seeing the problem on?
How to reproduce the bug
No response
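No reproduction script was provided. As a hedged illustration only, the sketch below mirrors the layout described in the bug description: a submodule registers a buffer with register_buffer(), and the parent network multiplies it by an input feature tensor. The names FeatureScaler, Network and find_misplaced_buffers are hypothetical, and the code uses plain PyTorch without Lightning or DeepSpeed, so it does not reproduce the stage 3 failure itself. It shows a helper that lists buffers not on an expected device (which could confirm that some buffers stay on cpu) and one possible defensive workaround that aligns a buffer with its input's device inside forward(); neither is a confirmed fix for the underlying issue.

```python
import torch
from torch import nn


class FeatureScaler(nn.Module):
    """Hypothetical submodule that owns a non-trainable buffer."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        # Buffers are saved in state_dict and moved by Module.to(), but they are
        # not returned by parameters(), so the optimizer does not update them.
        self.register_buffer("scale", torch.ones(dim))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Assumed workaround, not a confirmed fix: if the buffer was left on a
        # different device than the input (e.g. cpu vs. cuda:6), copy it over
        # before the elementwise multiply. Tensor.to() returns the same tensor
        # when the devices already match.
        scale = self.scale.to(feats.device)
        return feats * scale


class Network(nn.Module):
    """Hypothetical stand-in for lit_model.network."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.scaler = FeatureScaler(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scaler(self.proj(x))


def find_misplaced_buffers(module: nn.Module, expected: torch.device) -> list[str]:
    """Return names of buffers (including those in submodules) not on `expected`.

    Pass a fully indexed device such as torch.device("cuda:6"); torch.device("cuda")
    has no index and will not compare equal to a tensor's cuda:N device.
    """
    return [name for name, buf in module.named_buffers() if buf.device != expected]


if __name__ == "__main__":
    net = Network(4)
    print(find_misplaced_buffers(net, torch.device("cpu")))  # -> []
    print(net(torch.randn(2, 4)).shape)  # -> torch.Size([2, 4])
```

Inside a LightningModule, one could call find_misplaced_buffers(self.network, self.device) from a hook such as on_train_start to list which buffers remain on cpu under DeepSpeed stage 3; that hook placement is an assumption and would need checking against an actual run.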
Error messages and logs
Current environment
* CUDA:
- GPU:
- available: True
- version: 11.8
* Lightning:
- adam-atan2-pytorch: 0.0.10
- alphafold3-pytorch: 0.0.41
- alphafold3-pytorch-lightning-hydra: 0.1.111
- frame-averaging-pytorch: 0.0.19
- lightning: 2.4.0
- lightning-utilities: 0.11.6
- pytorch-lightning: 2.4.0
- rotary-embedding-torch: 0.6.1
- torch: 2.3.0+cu118
- torch-geometric: 2.5.3
- torchaudio: 2.3.0+cu118
- torchmetrics: 1.4.1
- torchtyping: 0.1.4
- torchvision: 0.18.0+cu118
* Packages:
- adam-atan2-pytorch: 0.0.10
- aiofiles: 23.2.1
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- alembic: 1.13.1
- alphafold3-pytorch: 0.0.41
- alphafold3-pytorch-lightning-hydra: 0.1.111
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- appdirs: 1.4.4
- argcomplete: 3.3.0
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 23.2.0
- autopage: 0.5.2
- beartype: 0.18.5
- beautifulsoup4: 4.12.3
- biopandas: 0.5.1.dev0
- biopython: 1.83
- bioservices: 1.11.2
- cattrs: 23.2.3
- certifi: 2024.8.30
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- cliff: 4.7.0
- cmaes: 0.10.0
- cmd2: 2.4.3
- colorama: 0.4.6
- colorlog: 6.8.2
- colt5-attention: 0.11.0
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.1
- decorator: 5.1.1
- deepdiff: 7.0.1
- deepspeed: 0.15.0
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- easydev: 0.13.2
- einops: 0.8.0
- einx: 0.2.2
- environs: 11.0.0
- exceptiongroup: 1.2.1
- executing: 2.0.1
- fastapi: 0.112.2
- ffmpy: 0.4.0
- filelock: 3.13.1
- fonttools: 4.52.4
- frame-averaging-pytorch: 0.0.19
- freetype-py: 2.3.0
- frozendict: 2.4.4
- frozenlist: 1.4.1
- fsspec: 2024.2.0
- gemmi: 0.6.6
- gevent: 24.2.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- gradio: 4.43.0
- gradio-client: 1.3.0
- gradio-molecule3d: 0.0.5
- graphein: 1.7.6
- greenlet: 3.0.3
- grequests: 0.7.0
- h11: 0.14.0
- hjson: 3.1.0
- httpcore: 1.0.5
- httpx: 0.27.2
- huggingface-hub: 0.23.4
- hydra-colorlog: 1.2.0
- hydra-core: 1.3.2
- hydra-optuna-sweeper: 1.2.0
- identify: 2.5.36
- idna: 3.7
- importlib-resources: 6.4.4
- iniconfig: 2.0.0
- ipykernel: 6.29.4
- ipython: 8.24.0
- jaxtyping: 0.2.28
- jedi: 0.19.1
- jinja2: 3.1.3
- joblib: 1.4.2
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- kiwisolver: 1.4.5
- lightning: 2.4.0
- lightning-utilities: 0.11.6
- line-profiler: 4.1.3
- local-attention: 1.9.1
- loguru: 0.7.2
- looseversion: 1.1.2
- lxml: 5.2.2
- mako: 1.3.5
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- marshmallow: 3.21.3
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mdurl: 0.1.2
- mmtf-python: 1.1.3
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- multipledispatch: 1.0.0
- munkres: 1.1.4
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- ninja:
- nodeenv: 1.8.0
- numpy: 1.23.5
- nvidia-cublas-cu11:
- nvidia-cuda-cupti-cu11: 11.8.87
- nvidia-cuda-nvrtc-cu11: 11.8.89
- nvidia-cuda-runtime-cu11: 11.8.89
- nvidia-cudnn-cu11:
- nvidia-cufft-cu11:
- nvidia-curand-cu11:
- nvidia-cusolver-cu11:
- nvidia-cusparse-cu11:
- nvidia-ml-py: 12.560.30
- nvidia-nccl-cu11: 2.20.5
- nvidia-nvtx-cu11: 11.8.86
- omegaconf: 2.3.0
- optree: 0.11.0
- optuna: 2.10.1
- ordered-set: 4.1.0
- orjson: 3.10.7
- packaging: 24.0
- pandas: 1.5.3
- parso: 0.8.4
- pbr: 6.0.0
- pdbeccdutils: 0.8.5
- pexpect: 4.9.0
- pillow: 10.2.0
- pip: 24.0
- pipx: 1.5.0
- platformdirs: 4.2.2
- plotly: 5.22.0
- pluggy: 1.5.0
- polars: 1.3.0
- pre-commit: 3.7.1
- prettytable: 3.10.0
- prompt-toolkit: 3.0.45
- protobuf: 4.25.4
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pycairo: 1.26.0
- pydantic: 2.8.2
- pydantic-core: 2.20.1
- pydub: 0.25.1
- pygments: 2.18.0
- pyparsing: 3.1.2
- pyperclip: 1.8.2
- pytest: 8.2.1
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- python-multipart: 0.0.9
- pytorch-lightning: 2.4.0
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- rdkit: 2024.3.2
- reportlab: 4.1.0
- requests: 2.32.2
- requests-cache: 1.2.0
- retrying: 1.3.4
- rich: 13.7.1
- rich-click: 1.8.2
- rlpycairo: 0.2.0
- rootutils: 1.0.7
- rotary-embedding-torch: 0.6.1
- ruff: 0.6.4
- scikit-learn: 1.5.0
- scipy: 1.13.1
- seaborn: 0.13.2
- semantic-version: 2.10.0
- sentry-sdk: 2.12.0
- setproctitle: 1.3.3
- setuptools: 70.0.0
- sh: 2.0.7
- shellingham: 1.5.4
- shortuuid: 1.0.13
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.5
- sqlalchemy: 2.0.30
- stack-data: 0.6.3
- starlette: 0.38.4
- stevedore: 5.2.0
- suds-community: 1.1.2
- sympy: 1.12
- taylor-series-linear-attention: 0.1.12
- tenacity: 8.3.0
- threadpoolctl: 3.5.0
- timeout-decorator: 0.5.0
- tomli: 2.0.1
- tomlkit: 0.12.0
- torch: 2.3.0+cu118
- torch-geometric: 2.5.3
- torchaudio: 2.3.0+cu118
- torchmetrics: 1.4.1
- torchtyping: 0.1.4
- torchvision: 0.18.0+cu118
- tornado: 6.4
- tqdm: 4.66.4
- traitlets: 5.14.3
- triton: 2.3.0
- typeguard: 2.13.3
- typer: 0.12.5
- typing-extensions: 4.11.0
- tzdata: 2024.1
- unicodedata2: 15.1.0
- url-normalize: 1.4.3
- urllib3: 2.2.1
- userpath: 1.9.2
- uvicorn: 0.30.6
- virtualenv: 20.26.2
- wandb: 0.16.6
- wcwidth: 0.2.13
- websockets: 12.0
- wget: 3.2
- wheel: 0.43.0
- wrapt: 1.16.0
- xarray: 2024.3.0
- xmltodict: 0.13.0
- yarl: 1.9.4
- zope.event: 5.0
- zope.interface: 6.4.post2
* System:
- OS: Linux
- architecture: 64bit
- processor: x86_64
- python: 3.10.14
- release: 4.18.0-553.16.1.el8_10.x86_64
- version: #1 SMP Thu Aug 8 07:11:46 EDT 2024
More info
No response