Bug description
When using the DeepSpeed `Strategy` configuration, I am experiencing an issue (specifically with DeepSpeed stage 3, not stages 1-2) where tensors registered with `register_buffer()` within the sub-`nn.Module`s of my `LightningModule`'s main `lit_model.network` `nn.Module` are not moved to the correct device when training `lit_module.network`. In particular, I register buffers within the various submodules of my `lit_module.network`. When my optimizer tries to perform a step, I get the error `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:6 and cpu!` whenever these registered buffers are used, e.g., when they are multiplied by feature tensors loaded onto (in this case) `cuda:6`.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
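A minimal sketch of the buffer pattern described above, assuming plain PyTorch on CPU (no Lightning or DeepSpeed involved): submodules register a buffer with `register_buffer()` and multiply it with incoming features, and a small helper lists every buffer, including those in nested submodules, that is not on an expected device. `ScaledProjection`, `Network`, and `find_misplaced_buffers` are hypothetical names, and the sketch does not by itself reproduce the stage 3 behavior.

```python
import torch
from torch import nn


class ScaledProjection(nn.Module):
    """Hypothetical submodule that keeps a non-trainable scale as a buffer."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Buffers are not parameters, but module.to(device) normally moves
        # them together with the parameters.
        self.register_buffer("scale", torch.ones(dim))

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # Fails with "Expected all tensors to be on the same device" if
        # `scale` was left on a different device than `feats`.
        return self.proj(feats) * self.scale


class Network(nn.Module):
    """Hypothetical stand-in for the LightningModule's `network`."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.blocks = nn.ModuleList([ScaledProjection(dim) for _ in range(2)])

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            feats = block(feats)
        return feats


def find_misplaced_buffers(module: nn.Module, device: torch.device) -> list[str]:
    """Return the qualified names of all buffers that are not on `device`."""
    return [name for name, buf in module.named_buffers() if buf.device != device]


if __name__ == "__main__":
    net = Network(dim=8)
    out = net(torch.randn(4, 8))
    print(find_misplaced_buffers(net, torch.device("cpu")))  # []
```

Calling the helper with the fully qualified training device (e.g. `torch.device("cuda:6")`; note that `torch.device("cuda")` does not compare equal to `cuda:6`) just before the failing step would show which buffers, if any, were left on `cpu`.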
Error messages and logs
Environment
Current environment
- GPU:
    - NVIDIA A100 80GB PCIe
    - NVIDIA A100 80GB PCIe
- available: True
- version: 11.8
- adam-atan2-pytorch: 0.0.10
- aiofiles: 23.2.1
- aiohttp: 3.9.5
- aiosignal: 1.3.1
- alembic: 1.13.1
- alphafold3-pytorch: 0.0.41
- alphafold3-pytorch-lightning-hydra: 0.1.111
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.4.0
- appdirs: 1.4.4
- argcomplete: 3.3.0
- asttokens: 2.4.1
- async-timeout: 4.0.3
- attrs: 23.2.0
- autopage: 0.5.2
- beartype: 0.18.5
- beautifulsoup4: 4.12.3
- biopandas: 0.5.1.dev0
- biopython: 1.83
- bioservices: 1.11.2
- cattrs: 23.2.3
- certifi: 2024.8.30
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- cliff: 4.7.0
- cmaes: 0.10.0
- cmd2: 2.4.3
- colorama: 0.4.6
- colorlog: 6.8.2
- colt5-attention: 0.11.0
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.1
- decorator: 5.1.1
- deepdiff: 7.0.1
- deepspeed: 0.15.0
- distlib: 0.3.8
- docker-pycreds: 0.4.0
- easydev: 0.13.2
- einops: 0.8.0
- einx: 0.2.2
- environs: 11.0.0
- exceptiongroup: 1.2.1
- executing: 2.0.1
- fastapi: 0.112.2
- ffmpy: 0.4.0
- filelock: 3.13.1
- fonttools: 4.52.4
- frame-averaging-pytorch: 0.0.19
- freetype-py: 2.3.0
- frozendict: 2.4.4
- frozenlist: 1.4.1
- fsspec: 2024.2.0
- gemmi: 0.6.6
- gevent: 24.2.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- gradio: 4.43.0
- gradio-client: 1.3.0
- gradio-molecule3d: 0.0.5
- graphein: 1.7.6
- greenlet: 3.0.3
- grequests: 0.7.0
- h11: 0.14.0
- hjson: 3.1.0
- httpcore: 1.0.5
- httpx: 0.27.2
- huggingface-hub: 0.23.4
- hydra-colorlog: 1.2.0
- hydra-core: 1.3.2
- hydra-optuna-sweeper: 1.2.0
- identify: 2.5.36
- idna: 3.7
- importlib-resources: 6.4.4
- iniconfig: 2.0.0
- ipykernel: 6.29.4
- ipython: 8.24.0
- jaxtyping: 0.2.28
- jedi: 0.19.1
- jinja2: 3.1.3
- joblib: 1.4.2
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- kiwisolver: 1.4.5
- lightning: 2.4.0
- lightning-utilities: 0.11.6
- line-profiler: 4.1.3
- local-attention: 1.9.1
- loguru: 0.7.2
- looseversion: 1.1.2
- lxml: 5.2.2
- mako: 1.3.5
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- marshmallow: 3.21.3
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mdurl: 0.1.2
- mmtf-python: 1.1.3
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- multipledispatch: 1.0.0
- munkres: 1.1.4
- nest-asyncio: 1.6.0
- networkx: 3.2.1
- ninja: 1.11.1.1
- nodeenv: 1.8.0
- numpy: 1.23.5
- nvidia-cublas-cu11: 11.11.3.6
- nvidia-cuda-cupti-cu11: 11.8.87
- nvidia-cuda-nvrtc-cu11: 11.8.89
- nvidia-cuda-runtime-cu11: 11.8.89
- nvidia-cudnn-cu11: 8.7.0.84
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-curand-cu11: 10.3.0.86
- nvidia-cusolver-cu11: 11.4.1.48
- nvidia-cusparse-cu11: 11.7.5.86
- nvidia-ml-py: 12.560.30
- nvidia-nccl-cu11: 2.20.5
- nvidia-nvtx-cu11: 11.8.86
- omegaconf: 2.3.0
- optree: 0.11.0
- optuna: 2.10.1
- ordered-set: 4.1.0
- orjson: 3.10.7
- packaging: 24.0
- pandas: 1.5.3
- parso: 0.8.4
- pbr: 6.0.0
- pdbeccdutils: 0.8.5
- pexpect: 4.9.0
- pillow: 10.2.0
- pip: 24.0
- pipx: 1.5.0
- platformdirs: 4.2.2
- plotly: 5.22.0
- pluggy: 1.5.0
- polars: 1.3.0
- pre-commit: 3.7.1
- prettytable: 3.10.0
- prompt-toolkit: 3.0.45
- protobuf: 4.25.4
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pycairo: 1.26.0
- pydantic: 2.8.2
- pydantic-core: 2.20.1
- pydub: 0.25.1
- pygments: 2.18.0
- pyparsing: 3.1.2
- pyperclip: 1.8.2
- pytest: 8.2.1
- python-dateutil: 2.9.0
- python-dotenv: 1.0.1
- python-multipart: 0.0.9
- pytorch-lightning: 2.4.0
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- rdkit: 2024.3.2
- reportlab: 4.1.0
- requests: 2.32.2
- requests-cache: 1.2.0
- retrying: 1.3.4
- rich: 13.7.1
- rich-click: 1.8.2
- rlpycairo: 0.2.0
- rootutils: 1.0.7
- rotary-embedding-torch: 0.6.1
- ruff: 0.6.4
- scikit-learn: 1.5.0
- scipy: 1.13.1
- seaborn: 0.13.2
- semantic-version: 2.10.0
- sentry-sdk: 2.12.0
- setproctitle: 1.3.3
- setuptools: 70.0.0
- sh: 2.0.7
- shellingham: 1.5.4
- shortuuid: 1.0.13
- six: 1.16.0
- smmap: 5.0.1
- sniffio: 1.3.1
- soupsieve: 2.5
- sqlalchemy: 2.0.30
- stack-data: 0.6.3
- starlette: 0.38.4
- stevedore: 5.2.0
- suds-community: 1.1.2
- sympy: 1.12
- taylor-series-linear-attention: 0.1.12
- tenacity: 8.3.0
- threadpoolctl: 3.5.0
- timeout-decorator: 0.5.0
- tomli: 2.0.1
- tomlkit: 0.12.0
- torch: 2.3.0+cu118
- torch-geometric: 2.5.3
- torchaudio: 2.3.0+cu118
- torchmetrics: 1.4.1
- torchtyping: 0.1.4
- torchvision: 0.18.0+cu118
- tornado: 6.4
- tqdm: 4.66.4
- traitlets: 5.14.3
- triton: 2.3.0
- typeguard: 2.13.3
- typer: 0.12.5
- typing-extensions: 4.11.0
- tzdata: 2024.1
- unicodedata2: 15.1.0
- url-normalize: 1.4.3
- urllib3: 2.2.1
- userpath: 1.9.2
- uvicorn: 0.30.6
- virtualenv: 20.26.2
- wandb: 0.16.6
- wcwidth: 0.2.13
- websockets: 12.0
- wget: 3.2
- wheel: 0.43.0
- wrapt: 1.16.0
- xarray: 2024.3.0
- xmltodict: 0.13.0
- yarl: 1.9.4
- zope.event: 5.0
- zope.interface: 6.4.post2
- OS: Linux
- architecture:
    - 64bit
    - ELF
- processor: x86_64
- python: 3.10.14
- release: 4.18.0-553.16.1.el8_10.x86_64
- version: #1 SMP Thu Aug 8 07:11:46 EDT 2024
More info
No response