diff --git a/classify/train.py b/classify/train.py index d454c7187339..6233033e4449 100644 --- a/classify/train.py +++ b/classify/train.py @@ -27,7 +27,6 @@ import torch.hub as hub import torch.optim.lr_scheduler as lr_scheduler import torchvision -from torch.cuda import amp from tqdm import tqdm FILE = Path(__file__).resolve() @@ -48,6 +47,7 @@ check_git_info, check_git_status, check_requirements, + check_version, colorstr, download, increment_path, @@ -198,7 +198,13 @@ def lf(x): t0 = time.time() criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing) # loss function best_fitness = 0.0 - scaler = amp.GradScaler(enabled=cuda) + + scaler = None + if check_version(torch.__version__, "2.4.0"): + scaler = torch.amp.GradScaler("cuda", enabled=cuda) + else: + scaler = torch.cuda.amp.GradScaler(enabled=cuda) + val = test_dir.stem # 'val' or 'test' LOGGER.info( f"Image sizes {imgsz} train, {imgsz} test\n" @@ -218,8 +224,14 @@ def lf(x): for i, (images, labels) in pbar: # progress bar images, labels = images.to(device, non_blocking=True), labels.to(device) + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=device.type != "cpu") + else: + amp_autocast = torch.cuda.amp.autocast(enabled=device.type != "cpu") + # Forward - with amp.autocast(enabled=cuda): # stability issues when enabled + with amp_autocast: # stability issues when enabled loss = criterion(model(images), labels) # Backward diff --git a/classify/val.py b/classify/val.py index 72bd0e14e2c5..b350d2b005c3 100644 --- a/classify/val.py +++ b/classify/val.py @@ -42,6 +42,7 @@ Profile, check_img_size, check_requirements, + check_version, colorstr, increment_path, print_args, @@ -108,7 +109,14 @@ def run( action = "validating" if dataloader.dataset.root.stem == "val" else "testing" desc = f"{pbar.desc[:-36]}{action:>36}" if pbar else f"{action}" bar = tqdm(dataloader, desc, n, not training, bar_format=TQDM_BAR_FORMAT, position=0) - with torch.cuda.amp.autocast(enabled=device.type != "cpu"): + + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=device.type != "cpu") + else: + amp_autocast = torch.cuda.amp.autocast(enabled=device.type != "cpu") + + with amp_autocast: for images, labels in bar: with dt[0]: images, labels = images.to(device, non_blocking=True), labels.to(device) diff --git a/models/common.py b/models/common.py index ea893db4b66f..1e3b00934765 100644 --- a/models/common.py +++ b/models/common.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn from PIL import Image -from torch.cuda import amp # Import 'ultralytics' package or install if missing try: @@ -864,7 +863,12 @@ def forward(self, ims, size=640, augment=False, profile=False): p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param autocast = self.amp and (p.device.type != "cpu") # Automatic Mixed Precision (AMP) inference if isinstance(ims, torch.Tensor): # torch - with amp.autocast(autocast): + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=autocast) + else: + amp_autocast = torch.cuda.amp.autocast(enabled=autocast) + with amp_autocast: return self.model(ims.to(p.device).type_as(p), augment=augment) # inference # Pre-process @@ -891,7 +895,13 @@ def forward(self, ims, size=640, augment=False, profile=False): x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32 - with amp.autocast(autocast): + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=autocast) + else: + amp_autocast = torch.cuda.amp.autocast(enabled=autocast) + + with amp_autocast: # Inference with dt[1]: y = self.model(x, augment=augment) # forward diff --git a/segment/train.py b/segment/train.py index 815c97ce1d48..1a2e034e50ba 100644 --- a/segment/train.py +++ b/segment/train.py @@ -58,6 +58,7 @@ check_img_size, check_requirements, check_suffix, + check_version, check_yaml, colorstr, get_latest_run, @@ -320,7 +321,13 @@ def lf(x): maps = np.zeros(nc) # mAP per class results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move - scaler = torch.cuda.amp.GradScaler(enabled=amp) + + scaler = None + if check_version(torch.__version__, "2.4.0"): + scaler = torch.amp.GradScaler("cuda", enabled=amp) + else: + scaler = torch.cuda.amp.GradScaler(enabled=amp) + stopper, stop = EarlyStopping(patience=opt.patience), False compute_loss = ComputeLoss(model, overlap=overlap) # init loss class # callbacks.run('on_train_start') @@ -379,8 +386,14 @@ def lf(x): ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=amp) + else: + amp_autocast = torch.cuda.amp.autocast(enabled=amp) + # Forward - with torch.cuda.amp.autocast(amp): + with amp_autocast: pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float()) if RANK != -1: diff --git a/train.py b/train.py index 1401ccb969b4..ff93b687fb65 100644 --- a/train.py +++ b/train.py @@ -63,6 +63,7 @@ check_img_size, check_requirements, check_suffix, + check_version, check_yaml, colorstr, get_latest_run, @@ -352,7 +353,13 @@ def lf(x): maps = np.zeros(nc) # mAP per class results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move - scaler = torch.cuda.amp.GradScaler(enabled=amp) + + scaler = None + if check_version(torch.__version__, "2.4.0"): + scaler = torch.amp.GradScaler("cuda", enabled=amp) + else: + scaler = torch.cuda.amp.GradScaler(enabled=amp) + stopper, stop = EarlyStopping(patience=opt.patience), False compute_loss = ComputeLoss(model) # init loss class callbacks.run("on_train_start") @@ -408,8 +415,14 @@ def lf(x): ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False) + amp_autocast = None + if check_version(torch.__version__, "2.4.0"): + amp_autocast = torch.amp.autocast("cuda", enabled=amp) + else: + amp_autocast = torch.cuda.amp.autocast(amp) + # Forward - with torch.cuda.amp.autocast(amp): + with amp_autocast: pred = model(imgs) # forward loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size if RANK != -1: diff --git a/utils/autobatch.py b/utils/autobatch.py index 9d5ea0a94296..8f67497c7bcb 100644 --- a/utils/autobatch.py +++ b/utils/autobatch.py @@ -6,12 +6,15 @@ import numpy as np import torch -from utils.general import LOGGER, colorstr +from utils.general import LOGGER, check_version, colorstr from utils.torch_utils import profile def check_train_batch_size(model, imgsz=640, amp=True): """Checks and computes optimal training batch size for YOLOv5 model, given image size and AMP setting.""" + if check_version(torch.__version__, "2.4.0"): + with torch.amp.autocast("cuda", enabled=amp): + return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size with torch.cuda.amp.autocast(amp): return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size