
Commit 0491221

Change amp autocast to torch.autocast(device, ...) (#2585)
1 parent 0743653 commit 0491221

File tree: 46 files changed, +246 -198 lines changed
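Every hunk in this commit applies the same mechanical rewrite: the per-device context managers torch.cpu.amp.autocast(...), torch.xpu.amp.autocast(...), and torch.cuda.amp.autocast(...) become the unified torch.autocast(device_type, ...) API, which takes the device as a string argument and is the form PyTorch recommends now that the per-device aliases are deprecated. A minimal before/after sketch (the toy model and input are illustrative, not from this repo):

import torch

model = torch.nn.Linear(8, 8).eval()  # placeholder module for illustration
x = torch.randn(1, 8)

# Old, deprecated spelling:
#   with torch.cpu.amp.autocast(dtype=torch.bfloat16):
#       y = model(x)

# New, unified spelling -- same semantics, device passed as a string:
with torch.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
    y = model(x)
print(y.dtype)  # torch.bfloat16 for autocast-eligible ops such as Linear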

Note: large commits have some content hidden by default, so only a subset of the 46 changed files appears below.

models/image_recognition/pytorch/common/inference.py

+3 -3

@@ -99,7 +99,7 @@ def inference(model, dataloader, datatype, args):
             break
         x = x.to(memory_format=torch.channels_last)
         if args.precision == "bf16":
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast("cpu", ), torch.no_grad():
                 model = torch.jit.trace(model, x, strict=False)
             model = torch.jit.freeze(model)
         else:
@@ -117,7 +117,7 @@ def inference(model, dataloader, datatype, args):
             else:
                 images = images.to(memory_format=torch.channels_last)
             if args.ipex and args.precision == "bf16" and not args.jit:
-                with torch.cpu.amp.autocast():
+                with torch.autocast("cpu", ):
                     if i == warmup_iters:
                         with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, record_function("model_inference"):
                             output = model(images)
@@ -141,7 +141,7 @@ def inference(model, dataloader, datatype, args):
             if not args.ipex and not args.jit:
                 images = images.to(datatype).to(memory_format=torch.channels_last)
             if args.ipex and args.precision == "bf16" and not args.jit:
-                with torch.cpu.amp.autocast():
+                with torch.autocast("cpu", ):
                     if i == warmup_iters:
                         with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, record_function("model_inference"):
                             output = model(images)
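The inference.py hunks above trace the model while the autocast context is active, so the bf16 casts are recorded into the TorchScript graph, then freeze the result. A self-contained sketch of that recipe under the new API (toy model and shapes are assumptions):

import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
x = torch.randn(1, 3, 224, 224).to(memory_format=torch.channels_last)

# torch.autocast("cpu") defaults to bfloat16; tracing inside the context
# bakes the casts into the graph, and freeze folds weights into constants.
with torch.autocast("cpu"), torch.no_grad():
    traced = torch.jit.trace(model, x, strict=False)
    traced = torch.jit.freeze(traced)

with torch.no_grad():
    out = traced(x)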

models/image_recognition/pytorch/common/main.py

+14 -14

@@ -475,11 +475,11 @@ def main_worker(gpu, ngpus_per_node, args):
         x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
         if args.bf16:
             x = x.to(torch.bfloat16)
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16), torch.no_grad():
+            with torch.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
                 model = torch.jit.trace(model, x).eval()
         elif args.fp16:
             x = x.to(torch.half)
-            with torch.cpu.amp.autocast(dtype=torch.half), torch.no_grad():
+            with torch.autocast("cpu", dtype=torch.half), torch.no_grad():
                 model = torch.jit.trace(model, x).eval()
         else:
             with torch.no_grad():
@@ -522,7 +522,7 @@ def main_worker(gpu, ngpus_per_node, args):
             print('[Info] Running torch.compile() with default backend')
             model = torch.compile(converted_model)
         elif args.bf16:
-            with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
                 x = x.to(torch.bfloat16)
                 if args.ipex:
                     print('[Info] Running torch.compile() BFloat16 with IPEX backend')
@@ -531,7 +531,7 @@ def main_worker(gpu, ngpus_per_node, args):
                     print('[Info] Running torch.compile() BFloat16 with default backend')
                     model = torch.compile(model)
         elif args.fp16:
-            with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.no_grad(), torch.autocast("cpu", dtype=torch.half):
                 x = x.to(torch.half)
                 if args.ipex:
                     print('[Info] Running torch.compile() FPloat16 with IPEX backend')
@@ -547,7 +547,7 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             print('[Info] Running torch.compile() Float32 with default backend')
             model = torch.compile(model)
-        with torch.no_grad(), torch.cpu.amp.autocast(enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
+        with torch.no_grad(), torch.autocast("cpu", enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
             y = model(x)
             y = model(x)
         validate(val_loader, model, criterion, args)
@@ -572,7 +572,7 @@ def main_worker(gpu, ngpus_per_node, args):
             model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.half, fuse_update_step=False)

         if args.inductor:
-            with torch.cpu.amp.autocast(enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
+            with torch.autocast("cpu", enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
                 if args.ipex:
                     print('[Info] Running training steps torch.compile() with IPEX backend')
                     model = torch.compile(model, backend="ipex")
@@ -647,11 +647,11 @@ def train(train_loader, val_loader, model, criterion, optimizer, lr_scheduler, a
             target = target.cuda(args.gpu, non_blocking=True)

         if args.bf16:
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.autocast("cpu", dtype=torch.bfloat16):
                 output = model(images)
             output = output.to(torch.float32)
         elif args.fp16:
-            with torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.autocast("cpu", dtype=torch.half):
                 output = model(images)
             output = output.to(torch.float32)

@@ -727,10 +727,10 @@ def run_weights_sharing_model(m, tid, args):
         while num_images < steps:
             start_time = time.time()
             if not args.jit and args.bf16:
-                with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                with torch.autocast("cpu", dtype=torch.bfloat16):
                     y = m(x)
             elif not args.jit and args.fp16:
-                with torch.cpu.amp.autocast(dtype=torch.half):
+                with torch.autocast("cpu", dtype=torch.half):
                     y = m(x)
             else:
                 y = m(x)
@@ -813,10 +813,10 @@ def validate(val_loader, model, criterion, args):
                 if i >= args.warmup_iterations:
                     end = time.time()
                 if not args.jit and args.bf16:
-                    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                    with torch.autocast("cpu", dtype=torch.bfloat16):
                         output = model(images)
                 elif not args.jit and args.fp16:
-                    with torch.cpu.amp.autocast(dtype=torch.half):
+                    with torch.autocast("cpu", dtype=torch.half):
                         output = model(images)
                 else:
                     output = model(images)
@@ -852,10 +852,10 @@ def validate(val_loader, model, criterion, args):
                 target = target.cuda(args.gpu, non_blocking=True)

                 if not args.jit and args.bf16:
-                    with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                    with torch.autocast("cpu", dtype=torch.bfloat16):
                         output = model(images)
                 elif not args.jit and args.fp16:
-                    with torch.cpu.amp.autocast(dtype=torch.half):
+                    with torch.autocast("cpu", dtype=torch.half):
                         output = model(images)

                 else:
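Besides the one-for-one renames, main.py shows how the call's enabled flag keeps a single code path for all precisions: enabled=False makes the context a no-op, and dtype selects half versus bfloat16 when it is on. A runnable sketch of that pattern (the bf16/fp16 flags stand in for args.bf16/args.fp16):

import torch

bf16, fp16 = True, False  # stand-ins for args.bf16 / args.fp16
model = torch.nn.Linear(16, 16).eval()
x = torch.randn(4, 16)

# With both flags False this context does nothing and the model runs in fp32.
with torch.no_grad(), torch.autocast(
        "cpu",
        enabled=bf16 or fp16,
        dtype=torch.half if fp16 else torch.bfloat16):
    y = model(x)
print(y.dtype)  # torch.bfloat16 here; torch.float32 if both flags are False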

models/image_recognition/pytorch/common/main_runtime_extension.py

+5 -5

@@ -411,7 +411,7 @@ def main_worker(gpu, ngpus_per_node, args):
         x = torch.randn(batch_per_stream, 3, 224, 224).contiguous(memory_format=torch.channels_last)
         if args.bf16:
             x = x.to(torch.bfloat16)
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast("cpu", ), torch.no_grad():
                 model = torch.jit.trace(model, x).eval()
         else:
             with torch.no_grad():
@@ -492,7 +492,7 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
         # compute output

         if args.bf16:
-            with torch.cpu.amp.autocast():
+            with torch.autocast("cpu", ):
                 output = model(images)
             output = output.to(torch.float32)
         else:
@@ -536,7 +536,7 @@ def run_weights_sharing_model(m, tid, args):
         while num_images < steps:
             start_time = time.time()
             if not args.jit and args.bf16:
-                with torch.cpu.amp.autocast():
+                with torch.autocast("cpu", ):
                     y = m(x)
             else:
                 y = m(x)
@@ -600,7 +600,7 @@ def validate(val_loader, model, criterion, args):
                 if i >= args.warmup_iterations:
                     end = time.time()
                 if not args.jit and args.bf16:
-                    with torch.cpu.amp.autocast():
+                    with torch.autocast("cpu", ):
                         output = model(images)
                 else:
                     output = model(images)
@@ -632,7 +632,7 @@ def validate(val_loader, model, criterion, args):
             if args.bf16:
                 images = images.to(torch.bfloat16)
             if not args.jit and args.bf16:
-                with torch.cpu.amp.autocast():
+                with torch.autocast("cpu", ):
                     output = model(images)
             else:
                 output = model(images)

models/image_recognition/pytorch/common/train.py

+4 -4

@@ -328,11 +328,11 @@ def main_worker(args):

         # Forward pass
         if args.ipex and args.bf16:
-            with torch.cpu.amp.autocast():
+            with torch.autocast("cpu", ):
                 output = model(images)
             output = output.to(torch.float32)
         elif args.ipex and args.fp16:
-            with torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.autocast("cpu", dtype=torch.half):
                 output = model(images)
             output = output.to(torch.float32)
         else:
@@ -472,12 +472,12 @@ def validate(val_loader, model, criterion, epoch, args):

         if args.ipex and args.bf16:
             images = images.to(torch.bfloat16)
-            with torch.cpu.amp.autocast():
+            with torch.autocast("cpu", ):
                 output = model(images)
             output = output.to(torch.float32)
         if args.ipex and args.fp16:
             images = images.to(torch.half)
-            with torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.autocast("cpu", dtype=torch.half):
                 output = model(images)
             output = output.to(torch.float32)
         else:
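train.py (like main.py above) casts the autocast output back to float32 before it reaches the criterion, so the loss and backward pass run in full precision while only the forward ops use bf16/fp16. A minimal sketch of that split (toy model, loss, and shapes are assumptions):

import torch

model = torch.nn.Linear(32, 10)
criterion = torch.nn.CrossEntropyLoss()
images = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

with torch.autocast("cpu", dtype=torch.bfloat16):
    output = model(images)          # runs in bf16 under autocast
output = output.to(torch.float32)   # cast back so the loss is computed in fp32
loss = criterion(output, target)
loss.backward()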

models_v2/pytorch/3d_unet/inference/gpu/predict.py

+3 -3

@@ -173,7 +173,7 @@ def inference_config(model):
     io_utils.write_info('Using JIT trace')
     with torch.inference_mode():
         if args.xpu and args.ipex:
-            with torch.xpu.amp.autocast(enabled=use_autocast, dtype=autocast_dtype, cache_enabled=False):
+            with torch.autocast("xpu", enabled=use_autocast, dtype=autocast_dtype, cache_enabled=False):
                 model = torch.jit.trace(model, trace_input)
         elif args.gpu or args.xpu:
             with torch.autocast(enabled=use_autocast, device_type=get_device_type(), dtype=autocast_dtype, cache_enabled=False):
@@ -226,7 +226,7 @@ def do_warmup(model, ds):
                 outputs += [model(images)]
         else:
             if args.ipex:
-                with torch.xpu.amp.autocast(enabled=use_autocast, dtype=autocast_dtype, cache_enabled=True):
+                with torch.autocast("xpu", enabled=use_autocast, dtype=autocast_dtype, cache_enabled=True):
                     for batch_repeat_index in range(min([args.batch_streaming, args.warm_up - len(outputs)])):
                         outputs += [model(images)]
             else:
@@ -316,7 +316,7 @@ def do_perf_benchmarking(model, ds, gt_data):
                 torch.xpu.synchronize(args.device)
                 statistics_utils.accuracy(args, outputs[0], target, overall, whole, core, enhancing, gt_data)
             else:
-                with torch.xpu.amp.autocast(enabled=use_autocast, dtype=autocast_dtype, cache_enabled=False):
+                with torch.autocast("xpu", enabled=use_autocast, dtype=autocast_dtype, cache_enabled=False):
                     start_time = time.time()
                     # compute output
                     for batch_repeat_index in range(args.batch_streaming):
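predict.py's generic branch already used the unified API with the device_type= keyword, which is the same parameter the rewritten lines pass positionally; the file also exercises cache_enabled, which controls autocast's internal cast cache and is typically disabled around torch.jit.trace. A sketch using "cpu" so it runs anywhere (the same keywords apply to "xpu", which needs Intel GPU hardware plus intel_extension_for_pytorch):

import torch

use_autocast = True
autocast_dtype = torch.bfloat16
model = torch.nn.Linear(8, 8).eval()
x = torch.randn(2, 8)

# Positional device string and device_type= keyword are interchangeable;
# cache_enabled=False avoids caching weight casts across the trace.
with torch.autocast(device_type="cpu", enabled=use_autocast,
                    dtype=autocast_dtype, cache_enabled=False), torch.no_grad():
    traced = torch.jit.trace(model, x)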

models_v2/pytorch/LCM/inference/cpu/inference.py

+10 -10

@@ -71,7 +71,7 @@ def run_weights_sharing_model(pipe, tid, args):
         # run model
         start = time.time()
         if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16":
-            with torch.cpu.amp.autocast(dtype=args.dtype), torch.no_grad():
+            with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                 output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images
         else:
             with torch.no_grad():
@@ -180,7 +180,7 @@ def main():
         if args.precision == "int8-fp32":
             pipe.unet = ipex.quantization.convert(pipe.unet)
         else:
-            with torch.cpu.amp.autocast():
+            with torch.autocast("cpu", ):
                 pipe.unet = ipex.quantization.convert(pipe.unet)
         pipe.vae = ipex.optimize(pipe.vae.eval(), dtype=args.dtype, inplace=True)
         print("running int8 evalation step\n")
@@ -231,14 +231,14 @@ def main():
             print("JIT trace ...")
             # from utils_vis import make_dot, draw
             if args.precision == "bf16" or args.precision == "fp16":
-                with torch.cpu.amp.autocast(dtype=args.dtype), torch.no_grad():
+                with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                     pipe.traced_unet = torch.jit.trace(pipe.unet, input, strict=False)
                     pipe.traced_unet = torch.jit.freeze(pipe.traced_unet)
                     pipe.traced_unet(*input)
                     pipe.traced_unet(*input)
                     # print(pipe.traced_unet.graph_for(input))
             elif args.precision == "int8-bf16":
-                with torch.cpu.amp.autocast(), torch.no_grad():
+                with torch.autocast("cpu", ), torch.no_grad():
                     pipe.traced_unet = torch.jit.trace(pipe.unet, input, strict=False)
                     pipe.traced_unet = torch.jit.freeze(pipe.traced_unet)
                     pipe.traced_unet(*input)
@@ -289,14 +289,14 @@ def main():
                 pipe.text_encoder = torch.compile(pipe.text_encoder)
                 pipe.vae.decode = torch.compile(pipe.vae.decode)
             elif args.precision == "bf16":
-                with torch.cpu.amp.autocast(), torch.no_grad():
+                with torch.autocast("cpu", ), torch.no_grad():
                     pipe.unet = torch.compile(pipe.unet)
                     pipe.unet(*input)
                     pipe.unet(*input)
                     pipe.text_encoder = torch.compile(pipe.text_encoder)
                     pipe.vae.decode = torch.compile(pipe.vae.decode)
             elif args.precision == "fp16":
-                with torch.cpu.amp.autocast(dtype=torch.half), torch.no_grad():
+                with torch.autocast("cpu", dtype=torch.half), torch.no_grad():
                     pipe.unet = torch.compile(pipe.unet)
                     pipe.unet(*input)
                     pipe.unet(*input)
@@ -465,7 +465,7 @@ def main():
             pipe(args.prompt)
             pipe.traced_unet = convert_pt2e(pipe.traced_unet)
             torch.ao.quantization.move_exported_model_to_eval(pipe.traced_unet)
-            with torch.cpu.amp.autocast(), torch.no_grad():
+            with torch.autocast("cpu", ), torch.no_grad():
                 pipe.traced_unet = torch.compile(pipe.traced_unet)
                 pipe.traced_unet(*input)
                 pipe.traced_unet(*input)
@@ -493,7 +493,7 @@ def main():
             # run model
             start = time.time()
             if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16":
-                with torch.cpu.amp.autocast(dtype=args.dtype), torch.no_grad():
+                with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                     output = pipe(args.prompt, generator=torch.manual_seed(args.seed)).images
             else:
                 with torch.no_grad():
@@ -517,7 +517,7 @@ def main():
             real_image = images[0]
             print("prompt: ", prompt)
             if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16":
-                with torch.cpu.amp.autocast(dtype=args.dtype), torch.no_grad():
+                with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                     output = pipe(prompt, generator=torch.manual_seed(args.seed), output_type="numpy").images
             else:
                 with torch.no_grad():
@@ -546,7 +546,7 @@ def main():
         print("Running profiling ...")
         with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU], record_shapes=True) as p:
             if args.precision == "bf16" or args.precision == "fp16" or args.precision == "int8-bf16":
-                with torch.cpu.amp.autocast(dtype=args.dtype), torch.no_grad():
+                with torch.autocast("cpu", dtype=args.dtype), torch.no_grad():
                     pipe(args.prompt, generator=torch.manual_seed(args.seed)).images
             else:
                 with torch.no_grad():
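The LCM hunks pass dtype=args.dtype, a torch.dtype resolved from the --precision string elsewhere in the script. A hypothetical version of that resolution, to make the mapping concrete (this helper is an illustration, not the script's actual code):

import torch

def resolve_dtype(precision: str) -> torch.dtype:
    # Illustrative mapping from the script's precision choices to the dtype
    # handed to torch.autocast("cpu", dtype=...).
    if precision in ("bf16", "int8-bf16"):
        return torch.bfloat16
    if precision == "fp16":
        return torch.half
    return torch.float32

print(resolve_dtype("bf16"))  # torch.bfloat16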

models_v2/pytorch/bert_large/inference/gpu/run_squad.py

+4 -4

@@ -151,7 +151,7 @@ def load_jit_model(model, inputs, dtype, device, jit_trace_path, use_jit_cache):
     in_1 = torch.unsqueeze(inputs["input_ids"][0].clone(), 0)
     in_2 = torch.unsqueeze(inputs["token_type_ids"][0].clone(), 0)
     in_3 = torch.unsqueeze(inputs["attention_mask"][0].clone(), 0)
-    with torch.xpu.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
+    with torch.autocast("xpu", enabled=True, dtype=dtype, cache_enabled=False):
         jit_model = torch.jit.trace(model,
                                     (in_1.to(device),
                                      in_2.to(device),
@@ -396,11 +396,11 @@ def train(args, train_dataset, model, tokenizer):

     # autocast context
     if args.device_choice == 'cuda':
-        autocast_context = torch.cuda.amp.autocast(enabled=use_autocast, dtype=autocast_dtype)
+        autocast_context = torch.autocast("cuda", enabled=use_autocast, dtype=autocast_dtype)
     elif args.device_choice == 'xpu':
-        autocast_context = torch.xpu.amp.autocast(enabled=use_autocast, dtype=autocast_dtype)
+        autocast_context = torch.autocast("xpu", enabled=use_autocast, dtype=autocast_dtype)
     else:
-        autocast_context = torch.cpu.amp.autocast(enabled=use_autocast, dtype=autocast_dtype)
+        autocast_context = torch.autocast("cpu", enabled=use_autocast, dtype=autocast_dtype)

     import contextlib
     profile_context = contextlib.nullcontext()
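A side effect of the unified API is visible in run_squad.py's train(): the three device branches now differ only in the device string, so they could in principle collapse into one call. A hedged sketch of that follow-up simplification (not part of this commit):

import torch

def make_autocast(device_choice: str, use_autocast: bool, autocast_dtype: torch.dtype):
    # device_choice is expected to be "cuda", "xpu", or "cpu", mirroring
    # args.device_choice in run_squad.py.
    return torch.autocast(device_choice, enabled=use_autocast, dtype=autocast_dtype)

# Usage mirroring the rewritten code:
autocast_context = make_autocast("cpu", True, torch.bfloat16)
with autocast_context:
    pass  # forward pass goes here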
