LearningRateMonitor broken on MPS backend with Apple silicon #20250

Open
MalteEbner opened this issue Sep 6, 2024 · 0 comments
Labels
bug (Something isn't working) · needs triage (Waiting to be triaged by maintainers) · ver: 2.4.x

Comments


MalteEbner commented Sep 6, 2024

Bug description

When the optimizer contains any hyperparameter of type float64, adding a LearningRateMonitor causes a TypeError on the MPS backend with Apple silicon. See the self-contained, minimal example in "How to reproduce the bug" below.

The error is:

  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/lr_monitor.py", line 219, in <dictcomp>
    name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
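
The root cause is that the MPS backend has no float64 support: torch.tensor infers float64 from the NumPy scalar and then fails to place it on the MPS device. A minimal sketch (assuming an Apple silicon machine where MPS is available) that reproduces the failure outside of Lightning:

import numpy as np
import torch

# torch.tensor infers dtype float64 from the NumPy scalar; MPS cannot
# hold float64 tensors, so this raises the same TypeError as above.
if torch.backends.mps.is_available():
    torch.tensor(np.float64(0.01), device="mps")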

When the LearningRateMonitor is removed, the code runs through, so the optimizer itself is fine.

Note that the quick fix of removing lr=np.float64(0.01) works only for this minimal example. In my case, the optimizer is imported from an external module and has many more parameters, making it much harder to change. A user-side workaround is sketched below.
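
One such workaround (a sketch with a hypothetical helper name, assuming the optimizer stores its hyperparameters in param_groups as stock PyTorch optimizers do) is to cast NumPy floats to Python floats after the optimizer is constructed, without touching the external module:

import numpy as np

def cast_numpy_floats(optimizer):
    # Hypothetical helper: replace NumPy float scalars (e.g. np.float64)
    # in every param group with plain Python floats, so that the
    # LearningRateMonitor never sees a float64 value.
    for group in optimizer.param_groups:
        for key, value in group.items():
            if isinstance(value, np.floating):
                group[key] = float(value)
    return optimizer

In configure_optimizers this becomes return [cast_numpy_floats(optimizer)], [scheduler].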

I tried out 4 fixes in the PyTorch Lightning source code. All of them fix the problem, but each might have side effects or not work on other devices or in other configurations.

Each replaces torch.tensor(value, device=trainer.strategy.root_device) in this line with one of the following (a sketch of the third option follows the list):

  • torch.tensor(value, device="cpu")
  • torch.tensor(value, device=value.device)
  • torch.tensor(value, device=trainer.strategy.root_device, dtype=torch.float32)
  • value
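
As a sketch of the third option (based on the dict comprehension visible in the traceback below, in pytorch_lightning/callbacks/lr_monitor.py), the update in _extract_stats would become:

# Forcing float32 avoids the unsupported float64 dtype on MPS, at the
# cost of silently downcasting on devices that do support float64.
trainer.callback_metrics.update({
    name: torch.tensor(value, device=trainer.strategy.root_device, dtype=torch.float32)
    for name, value in latest_stat.items()
})

The fourth option (keeping value as-is) avoids the device transfer entirely, but it changes the type stored in callback_metrics from a tensor to a plain float, which downstream consumers may not expect.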

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import numpy as np
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self(x), y)
        return loss

    def configure_optimizers(self):
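        # np.float64 here is what triggers the MPS TypeError in LearningRateMonitor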
        optimizer = Adam(self.parameters(), lr=np.float64(0.01))
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [scheduler]

# Data
x = torch.randn(100, 2)
y = torch.randn(100, 1)
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=2)

# Training
model = SimpleModel()
lr_monitor = LearningRateMonitor(logging_interval='step')
trainer = pl.Trainer(max_epochs=10, callbacks=[lr_monitor])
trainer.fit(model, dataloader)

Error messages and logs

Epoch 0:  98%|█████████▊| 49/50 [00:00<00:00, 188.15it/s, v_num=13]Traceback (most recent call last):
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/test_lr_monitor.py", line 37, in <module>
    trainer.fit(model, dataloader)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 233, in advance
    call._call_callback_hooks(trainer, "on_train_batch_start", batch, batch_idx)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 218, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/lr_monitor.py", line 173, in on_train_batch_start
    latest_stat = self._extract_stats(trainer, interval)
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/lr_monitor.py", line 216, in _extract_stats
    trainer.callback_metrics.update({
  File "/Users/malteebnerlightly/Documents/GitHub/lightly-train/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/lr_monitor.py", line 217, in <dictcomp>
    name: torch.tensor(value, device=trainer.strategy.root_device) for name, value in latest_stat.items()
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
Epoch 0:  98%|█████████▊| 49/50 [00:00<00:00, 129.92it/s, v_num=13]  

Environment

The machine is a MacBook Pro with an Apple M1 Pro chip.

Current environment
  • CUDA:
    - GPU: None
    - available: False
    - version: None
  • Lightning:
    - lightning-utilities: 0.11.7
    - pytorch-lightning: 2.4.0
    - torch: 2.4.1
    - torchmetrics: 1.4.1
    - torchvision: 0.19.1
  • Packages:
    - absl-py: 2.1.0
    - aenum: 3.1.15
    - aiohappyeyeballs: 2.4.0
    - aiohttp: 3.10.5
    - aiosignal: 1.3.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - certifi: 2024.7.4
    - charset-normalizer: 3.3.2
    - exceptiongroup: 1.2.2
    - filelock: 3.15.4
    - frozenlist: 1.4.1
    - fsspec: 2024.9.0
    - grpcio: 1.65.5
    - huggingface-hub: 0.24.6
    - hydra-core: 1.3.2
    - idna: 3.8
    - importlib-metadata: 8.0.0
    - importlib-resources: 6.4.0
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jinja2: 3.1.4
    - licenseheaders: 0.8.8
    - lightning-utilities: 0.11.7
    - markdown: 3.7
    - markupsafe: 2.1.5
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - mypy: 1.11.1
    - mypy-extensions: 1.0.0
    - networkx: 3.3
    - numpy: 2.1.1
    - omegaconf: 2.3.0
    - packaging: 24.1
    - pillow: 10.4.0
    - platformdirs: 4.2.2
    - pluggy: 1.5.0
    - protobuf: 5.27.3
    - psutil: 6.0.0
    - pydantic: 1.10.18
    - pydantic-core: 2.20.1
    - pydeprecate: 0.3.2
    - pytest: 8.3.2
    - pytest-mock: 3.14.0
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 2.4.0
    - pyyaml: 6.0.2
    - regex: 2024.7.24
    - requests: 2.32.3
    - ruff: 0.6.1
    - safetensors: 0.4.4
    - setuptools: 74.1.2
    - six: 1.16.0
    - sympy: 1.13.2
    - tensorboard: 2.17.1
    - tensorboard-data-server: 0.7.2
    - timm: 1.0.8
    - tomli: 2.0.1
    - torch: 2.4.1
    - torchmetrics: 1.4.1
    - torchvision: 0.19.1
    - tqdm: 4.66.5
    - typeguard: 4.3.0
    - types-tqdm: 4.66.0.20240417
    - typing-extensions: 4.12.2
    - urllib3: 2.2.2
    - werkzeug: 3.0.3
    - wheel: 0.43.0
    - yarl: 1.9.11
    - zipp: 3.19.2
  • System:
    - OS: Darwin
    - architecture: 64bit
    - processor: arm
    - python: 3.10.8
    - release: 23.6.0
    - version: Darwin Kernel Version 23.6.0: Mon Jul 29 21:14:30 PDT 2024; root:xnu-10063.141.2~1/RELEASE_ARM64_T6000

More info

No response

MalteEbner added the bug and needs triage labels on Sep 6, 2024