@@ -475,11 +475,11 @@ def main_worker(gpu, ngpus_per_node, args):
     x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
     if args.bf16:
         x = x.to(torch.bfloat16)
-        with torch.cpu.amp.autocast(dtype=torch.bfloat16), torch.no_grad():
+        with torch.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
             model = torch.jit.trace(model, x).eval()
     elif args.fp16:
         x = x.to(torch.half)
-        with torch.cpu.amp.autocast(dtype=torch.half), torch.no_grad():
+        with torch.autocast("cpu", dtype=torch.half), torch.no_grad():
             model = torch.jit.trace(model, x).eval()
     else:
         with torch.no_grad():
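
Note: every hunk in this patch makes the same substitution, replacing the deprecated `torch.cpu.amp.autocast(...)` context manager with the device-generic `torch.autocast("cpu", ...)` API. A minimal runnable sketch of the BFloat16 tracing path above, with a toy Conv2d standing in for the script's real model:

    import torch

    # Toy stand-in for the model traced in main.py (illustrative only).
    model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
    x = torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
    x = x.to(torch.bfloat16)

    # Old (deprecated): torch.cpu.amp.autocast(dtype=torch.bfloat16)
    # New (device-generic): torch.autocast("cpu", dtype=torch.bfloat16)
    with torch.autocast("cpu", dtype=torch.bfloat16), torch.no_grad():
        traced = torch.jit.trace(model, x).eval()
        y = traced(x)
    print(y.dtype)  # torch.bfloat16: the conv ran in the autocast dtype
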
@@ -522,7 +522,7 @@ def main_worker(gpu, ngpus_per_node, args):
             print('[Info] Running torch.compile() with default backend')
             model = torch.compile(converted_model)
     elif args.bf16:
-        with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.bfloat16):
+        with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
             x = x.to(torch.bfloat16)
             if args.ipex:
                 print('[Info] Running torch.compile() BFloat16 with IPEX backend')
@@ -531,7 +531,7 @@ def main_worker(gpu, ngpus_per_node, args):
                 print('[Info] Running torch.compile() BFloat16 with default backend')
                 model = torch.compile(model)
     elif args.fp16:
-        with torch.no_grad(), torch.cpu.amp.autocast(dtype=torch.half):
+        with torch.no_grad(), torch.autocast("cpu", dtype=torch.half):
             x = x.to(torch.half)
             if args.ipex:
                 print('[Info] Running torch.compile() Float16 with IPEX backend')
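
The same swap applies on the torch.compile() paths; note that the model is compiled while the autocast context is already active, so the first call compiles a graph specialized for the low-precision region. A minimal sketch, assuming a toy Linear model in place of the repo's network:

    import torch

    model = torch.nn.Linear(16, 4).eval()
    x = torch.randn(8, 16)

    with torch.no_grad(), torch.autocast("cpu", dtype=torch.bfloat16):
        compiled = torch.compile(model)  # default (Inductor) backend
        y = compiled(x)                  # first call triggers compilation
    print(y.dtype)  # torch.bfloat16
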
@@ -547,7 +547,7 @@ def main_worker(gpu, ngpus_per_node, args):
         else:
             print('[Info] Running torch.compile() Float32 with default backend')
             model = torch.compile(model)
-    with torch.no_grad(), torch.cpu.amp.autocast(enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
+    with torch.no_grad(), torch.autocast("cpu", enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
         y = model(x)
         y = model(x)
     validate(val_loader, model, criterion, args)
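
The warm-up block above (the two successive `y = model(x)` calls) folds all three precision modes into one context manager via autocast's `enabled` flag: when both flags are false the context is effectively a no-op and inference stays in FP32. A sketch of the same selection logic, with plain booleans standing in for the script's args:

    import torch

    bf16, fp16 = False, True  # stand-ins for args.bf16 / args.fp16
    model = torch.nn.Linear(16, 4).eval()
    x = torch.randn(8, 16)

    with torch.no_grad(), torch.autocast(
        "cpu",
        enabled=bf16 or fp16,
        dtype=torch.half if fp16 else torch.bfloat16,
    ):
        y = model(x)  # FP16 here; plain FP32 when both flags are False
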
@@ -572,7 +572,7 @@ def main_worker(gpu, ngpus_per_node, args):
             model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.half, fuse_update_step=False)
 
     if args.inductor:
-        with torch.cpu.amp.autocast(enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
+        with torch.autocast("cpu", enabled=args.bf16 or args.fp16, dtype=torch.half if args.fp16 else torch.bfloat16):
            if args.ipex:
                 print('[Info] Running training steps torch.compile() with IPEX backend')
                 model = torch.compile(model, backend="ipex")
@@ -647,11 +647,11 @@ def train(train_loader, val_loader, model, criterion, optimizer, lr_scheduler, args):
             target = target.cuda(args.gpu, non_blocking=True)
 
         if args.bf16:
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.autocast("cpu", dtype=torch.bfloat16):
                 output = model(images)
             output = output.to(torch.float32)
         elif args.fp16:
-            with torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.autocast("cpu", dtype=torch.half):
                 output = model(images)
             output = output.to(torch.float32)
 
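
In the training loop only the forward pass runs under autocast, and the output is cast back to float32 before the loss, keeping the criterion and gradient accumulation in full precision. A self-contained sketch of that pattern, using a toy model and loss:

    import torch

    model = torch.nn.Linear(16, 4)
    criterion = torch.nn.CrossEntropyLoss()
    images = torch.randn(8, 16)
    target = torch.randint(0, 4, (8,))

    with torch.autocast("cpu", dtype=torch.bfloat16):
        output = model(images)         # forward in bf16
    output = output.to(torch.float32)  # cast back before the loss
    loss = criterion(output, target)   # loss and backward in fp32
    loss.backward()
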
@@ -727,10 +727,10 @@ def run_weights_sharing_model(m, tid, args):
     while num_images < steps:
         start_time = time.time()
         if not args.jit and args.bf16:
-            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+            with torch.autocast("cpu", dtype=torch.bfloat16):
                 y = m(x)
         elif not args.jit and args.fp16:
-            with torch.cpu.amp.autocast(dtype=torch.half):
+            with torch.autocast("cpu", dtype=torch.half):
                 y = m(x)
         else:
             y = m(x)
@@ -813,10 +813,10 @@ def validate(val_loader, model, criterion, args):
             if i >= args.warmup_iterations:
                 end = time.time()
             if not args.jit and args.bf16:
-                with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                with torch.autocast("cpu", dtype=torch.bfloat16):
                     output = model(images)
             elif not args.jit and args.fp16:
-                with torch.cpu.amp.autocast(dtype=torch.half):
+                with torch.autocast("cpu", dtype=torch.half):
                     output = model(images)
             else:
                 output = model(images)
@@ -852,10 +852,10 @@ def validate(val_loader, model, criterion, args):
             target = target.cuda(args.gpu, non_blocking=True)
 
             if not args.jit and args.bf16:
-                with torch.cpu.amp.autocast(dtype=torch.bfloat16):
+                with torch.autocast("cpu", dtype=torch.bfloat16):
                     output = model(images)
             elif not args.jit and args.fp16:
-                with torch.cpu.amp.autocast(dtype=torch.half):
+                with torch.autocast("cpu", dtype=torch.half):
                     output = model(images)
 
             else: