
Commit a3d691c

Gulin7 and vfdev-5 authored
Fix deprecated statement (#3307)
* fix-deprecated-warning: replaced torch.cuda.amp.autocast with torch.amp.autocast("cuda", ...)
* autopep8 fix
* Update torch version to 1.12.0
* Address PR comments
* Revert unwanted changes
* Fix regex
* Revert change in CycleGAN_with_torch_cuda_amp
* Fix regex in test_create_supervised
* Update ignite/engine/__init__.py
* Update tests/ignite/engine/test_create_supervised.py

Co-authored-by: Gulin7 <[email protected]>
Co-authored-by: vfdev <[email protected]>
1 parent b636374 commit a3d691c

9 files changed: +48 −41 lines changed

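Every call site in this commit changes the same way: the CUDA-specific torch.cuda.amp.autocast context manager is replaced by the device-agnostic torch.amp.autocast, with the device type passed as the first argument. A minimal sketch of the before/after (illustrative only, assuming a CUDA device is available; the toy model and tensor below are not taken from this commit):

import torch
from torch.amp import autocast  # was: from torch.cuda.amp import autocast

model = torch.nn.Linear(4, 2).cuda()     # hypothetical toy model
x = torch.randn(8, 4, device="cuda")

# was: with autocast(enabled=True):
with autocast("cuda", enabled=True):     # device type is now explicit
    y = model(x)                         # forward pass runs in mixed precision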

examples/cifar10/main.py

+4 −3

@@ -7,7 +7,8 @@
 import torch.nn as nn
 import torch.optim as optim
 import utils
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast
+from torch.cuda.amp import GradScaler

 import ignite
 import ignite.distributed as idist
@@ -299,7 +300,7 @@ def train_step(engine, batch):

    model.train()

-    with autocast(enabled=with_amp):
+    with autocast("cuda", enabled=with_amp):
        y_pred = model(x)
        loss = criterion(y_pred, y)

@@ -355,7 +356,7 @@ def evaluate_step(engine: Engine, batch):
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)

-    with autocast(enabled=with_amp):
+    with autocast("cuda", enabled=with_amp):
        output = model(x)
    return output, y

examples/cifar100_amp_benchmark/benchmark_torch_cuda_amp.py

+3 −2

@@ -1,6 +1,7 @@
 import fire
 import torch
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast
+from torch.cuda.amp import GradScaler
 from torch.nn import CrossEntropyLoss
 from torch.optim import SGD
 from torchvision.models import wide_resnet50_2
@@ -34,7 +35,7 @@ def train_step(engine, batch):
    optimizer.zero_grad()

    # Runs the forward pass with autocasting.
-    with autocast():
+    with autocast("cuda"):
        y_pred = model(x)
        loss = criterion(y_pred, y)

examples/cifar10_qat/main.py

+3 −2

@@ -6,7 +6,8 @@
 import torch.nn as nn
 import torch.optim as optim
 import utils
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast
+from torch.cuda.amp import GradScaler

 import ignite
 import ignite.distributed as idist
@@ -283,7 +284,7 @@ def train_step(engine, batch):

    model.train()

-    with autocast(enabled=with_amp):
+    with autocast("cuda", enabled=with_amp):
        y_pred = model(x)
        loss = criterion(y_pred, y)

examples/notebooks/CycleGAN_with_torch_cuda_amp.ipynb

+3 −2

@@ -887,7 +887,7 @@
    "id": "JE8dLeEfIl_Z"
   },
   "source": [
-    "We will use [`torch.cuda.amp.autocast`](https://pytorch.org/docs/master/amp.html#torch.cuda.amp.autocast) and [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/master/amp.html#torch.cuda.amp.GradScaler) to perform automatic mixed precision training. Our code follows a [typical mixed precision training example](https://pytorch.org/docs/master/notes/amp_examples.html#typical-mixed-precision-training)."
+    "We will use [`torch.amp.autocast`](https://pytorch.org/docs/master/amp.html#torch.amp.autocast) and [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/master/amp.html#torch.cuda.amp.GradScaler) to perform automatic mixed precision training. Our code follows a [typical mixed precision training example](https://pytorch.org/docs/master/notes/amp_examples.html#typical-mixed-precision-training)."
   ]
  },
  {
@@ -896,7 +896,8 @@
    "id": "vrJls4p-FRcA"
   },
   "source": [
-    "from torch.cuda.amp import autocast, GradScaler\n",
+    "from torch.cuda.amp import GradScaler\n",
+    "from torch.amp import autocast\n",
    "\n",
    "from ignite.utils import convert_tensor\n",
    "import torch.nn.functional as F\n",

examples/references/classification/imagenet/main.py

+5 −4

@@ -6,9 +6,10 @@
 import torch

 try:
-    from torch.cuda.amp import autocast, GradScaler
+    from torch.amp import autocast
+    from torch.cuda.amp import GradScaler
 except ImportError:
-    raise RuntimeError("Please, use recent PyTorch version, e.g. >=1.6.0")
+    raise RuntimeError("Please, use recent PyTorch version, e.g. >=1.12.0")

 import dataflow as data
 import utils
@@ -144,7 +145,7 @@ def create_trainer(model, optimizer, criterion, train_sampler, config, logger, w
    def training_step(engine, batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
-        with autocast(enabled=with_amp):
+        with autocast("cuda", enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
@@ -235,7 +236,7 @@ def create_evaluator(model, metrics, config, with_clearml, tag="val"):
    @torch.no_grad()
    def evaluate_step(engine, batch):
        model.eval()
-        with autocast(enabled=with_amp):
+        with autocast("cuda", enabled=with_amp):
            x, y = prepare_batch(batch, device=config.device, non_blocking=True)
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)

examples/references/segmentation/pascal_voc2012/main.py

+5 −4

@@ -6,9 +6,10 @@
 import torch

 try:
-    from torch.cuda.amp import autocast, GradScaler
+    from torch.amp import autocast
+    from torch.cuda.amp import GradScaler
 except ImportError:
-    raise RuntimeError("Please, use recent PyTorch version, e.g. >=1.6.0")
+    raise RuntimeError("Please, use recent PyTorch version, e.g. >=1.12.0")

 import dataflow as data
 import utils
@@ -191,7 +192,7 @@ def create_trainer(model, optimizer, criterion, train_sampler, config, logger, w
    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
-        with autocast(enabled=with_amp):
+        with autocast("cuda", enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
@@ -272,7 +273,7 @@ def create_evaluator(model, metrics, config, with_clearml, tag="val"):
    @torch.no_grad()
    def evaluate_step(engine, batch):
        model.eval()
-        with autocast(enabled=with_amp):
+        with autocast("cuda", enabled=with_amp):
            x, y = prepare_batch(batch, device=config.device, non_blocking=True)
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)

examples/transformers/main.py

+4 −3

@@ -7,7 +7,8 @@
 import torch.nn as nn
 import torch.optim as optim
 import utils
-from torch.cuda.amp import autocast, GradScaler
+from torch.amp import autocast
+from torch.cuda.amp import GradScaler

 import ignite
 import ignite.distributed as idist
@@ -309,7 +310,7 @@ def train_step(engine, batch):

    model.train()

-    with autocast(enabled=with_amp):
+    with autocast("cuda", enabled=with_amp):
        y_pred = model(input_batch)
        loss = criterion(y_pred, labels)

@@ -373,7 +374,7 @@ def evaluate_step(engine, batch):
    input_batch = {k: v.to(device, non_blocking=True, dtype=torch.long) for k, v in batch[0].items()}
    labels = labels.to(device, non_blocking=True, dtype=torch.float)

-    with autocast(enabled=with_amp):
+    with autocast("cuda", enabled=with_amp):
        output = model(input_batch)
    return output, labels

ignite/engine/__init__.py

+6 −6

@@ -185,9 +185,9 @@ def supervised_training_step_amp(
    """

    try:
-        from torch.cuda.amp import autocast
+        from torch.amp import autocast
    except ImportError:
-        raise ImportError("Please install torch>=1.6.0 to use amp_mode='amp'.")
+        raise ImportError("Please install torch>=1.12.0 to use amp_mode='amp'.")

    if gradient_accumulation_steps <= 0:
        raise ValueError(
@@ -200,7 +200,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
        optimizer.zero_grad()
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-        with autocast(enabled=True):
+        with autocast("cuda", enabled=True):
            output = model_fn(model, x)
            y_pred = model_transform(output)
            loss = loss_fn(y_pred, y)
@@ -726,15 +726,15 @@ def supervised_evaluation_step_amp(
        Added `model_fn` to customize model's application on the sample
    """
    try:
-        from torch.cuda.amp import autocast
+        from torch.amp import autocast
    except ImportError:
-        raise ImportError("Please install torch>=1.6.0 to use amp_mode='amp'.")
+        raise ImportError("Please install torch>=1.12.0 to use amp_mode='amp'.")

    def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.eval()
        with torch.no_grad():
            x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-            with autocast(enabled=True):
+            with autocast("cuda", enabled=True):
                output = model_fn(model, x)
                y_pred = model_transform(output)
                return output_transform(x, y, y_pred)
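These two helpers sit behind ignite's create_supervised_trainer / create_supervised_evaluator when amp_mode="amp" is passed. A small usage sketch under that assumption (the toy model, optimizer and loss below are illustrative, not part of this commit):

import torch
import torch.nn as nn
from ignite.engine import create_supervised_trainer

model = nn.Linear(10, 2).to("cuda")                      # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# amp_mode="amp" routes through supervised_training_step_amp, so the forward pass
# now runs under torch.amp.autocast("cuda", enabled=True); scaler=True adds a GradScaler.
trainer = create_supervised_trainer(
    model, optimizer, criterion, device="cuda", amp_mode="amp", scaler=True
)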

tests/ignite/engine/test_create_supervised.py

+15 −15

@@ -168,7 +168,7 @@ def _():
    trainer.run(data)


-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
def test_create_supervised_training_scalar_assignment():
    with mock.patch("ignite.engine._check_arg") as check_arg_mock:
        check_arg_mock.return_value = None, torch.cuda.amp.GradScaler(enabled=False)
@@ -447,21 +447,21 @@ def test_create_supervised_trainer_apex_error():
def mock_torch_cuda_amp_module():
    with patch.dict(
        "sys.modules",
-        {"torch.cuda.amp": None, "torch.cuda.amp.grad_scaler": None, "torch.cuda.amp.autocast_mode": None},
+        {"torch.amp": None, "torch.cuda.amp": None, "torch.amp.autocast_mode": None},
    ):
        yield torch


def test_create_supervised_trainer_amp_error(mock_torch_cuda_amp_module):
-    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use amp_mode='amp'."):
+    with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
        _test_create_supervised_trainer_wrong_accumulation(trainer_device="cpu", amp_mode="amp")
-    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use amp_mode='amp'."):
+    with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
        _test_create_supervised_trainer(amp_mode="amp")
    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use scaler argument."):
        _test_create_supervised_trainer(amp_mode="amp", scaler=True)


-@pytest.mark.skipif(Version(torch.__version__) < Version("1.5.0"), reason="Skip if < 1.5.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
def test_create_supervised_trainer_scaler_not_amp():
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

@@ -501,7 +501,7 @@ def test_create_supervised_trainer_on_mps():
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)


-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_amp():
    model_device = trainer_device = "cuda"
@@ -517,7 +517,7 @@ def test_create_supervised_trainer_on_cuda_amp():
    _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device, amp_mode="amp")


-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_trainer_on_cuda_amp_scaler():
    model_device = trainer_device = "cuda"
@@ -630,8 +630,8 @@ def test_create_supervised_evaluator():
    _test_mocked_supervised_evaluator()

    # older versions didn't have the autocast method so we skip the test for older builds
-    if Version(torch.__version__) >= Version("1.6.0"):
-        with mock.patch("torch.cuda.amp.autocast") as mock_torch_cuda_amp_module:
+    if Version(torch.__version__) >= Version("1.12.0"):
+        with mock.patch("torch.amp.autocast") as mock_torch_cuda_amp_module:
            _test_create_evaluation_step_amp(mock_torch_cuda_amp_module)


@@ -640,8 +640,8 @@ def test_create_supervised_evaluator_on_cpu():
    _test_mocked_supervised_evaluator(evaluator_device="cpu")

    # older versions didn't have the autocast method so we skip the test for older builds
-    if Version(torch.__version__) >= Version("1.6.0"):
-        with mock.patch("torch.cuda.amp.autocast") as mock_torch_cuda_amp_module:
+    if Version(torch.__version__) >= Version("1.12.0"):
+        with mock.patch("torch.amp.autocast") as mock_torch_cuda_amp_module:
            _test_create_evaluation_step(mock_torch_cuda_amp_module, evaluator_device="cpu")
            _test_create_evaluation_step_amp(mock_torch_cuda_amp_module, evaluator_device="cpu")

@@ -651,8 +651,8 @@ def test_create_supervised_evaluator_traced_on_cpu():
    _test_mocked_supervised_evaluator(evaluator_device="cpu", trace=True)

    # older versions didn't have the autocast method so we skip the test for older builds
-    if Version(torch.__version__) >= Version("1.6.0"):
-        with mock.patch("torch.cuda.amp.autocast") as mock_torch_cuda_amp_module:
+    if Version(torch.__version__) >= Version("1.12.0"):
+        with mock.patch("torch.amp.autocast") as mock_torch_cuda_amp_module:
            _test_create_evaluation_step(mock_torch_cuda_amp_module, evaluator_device="cpu", trace=True)


@@ -682,7 +682,7 @@ def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
    _test_mocked_supervised_evaluator(evaluator_device="mps")


-@pytest.mark.skipif(Version(torch.__version__) < Version("1.6.0"), reason="Skip if < 1.6.0")
+@pytest.mark.skipif(Version(torch.__version__) < Version("1.12.0"), reason="Skip if < 1.12.0")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_create_supervised_evaluator_on_cuda_amp():
    model_device = evaluator_device = "cuda"
@@ -691,7 +691,7 @@ def test_create_supervised_evaluator_on_cuda_amp():


def test_create_supervised_evaluator_amp_error(mock_torch_cuda_amp_module):
-    with pytest.raises(ImportError, match="Please install torch>=1.6.0 to use amp_mode='amp'."):
+    with pytest.raises(ImportError, match="Please install torch>=1.12.0 to use amp_mode='amp'."):
        _test_create_supervised_evaluator(amp_mode="amp")
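For context, the mock_torch_cuda_amp_module fixture works because Python's import machinery raises ImportError for any module whose sys.modules entry is set to None, which is how these tests simulate a torch build without torch.amp. A standalone sketch of that mechanism (not code from the test suite):

from unittest.mock import patch

# A None entry in sys.modules makes any import of that name fail,
# so the fixture can force ignite's "Please install torch>=1.12.0 ..." ImportError.
with patch.dict("sys.modules", {"torch.amp": None}):
    try:
        from torch.amp import autocast  # noqa: F401
    except ImportError:
        print("torch.amp unavailable -> amp_mode='amp' would raise ImportError")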
