35
35
TorchTensorRTModule ,
36
36
)
37
37
from torch_tensorrt .dynamo .utils import (
38
+ CPU_DEVICE ,
38
39
check_module_output ,
40
+ delete_module ,
39
41
get_model_device ,
40
42
get_torch_inputs ,
41
43
set_log_level ,
@@ -312,9 +314,6 @@ def refit_module_weights(
312
314
get_decompositions (settings .enable_experimental_decompositions )
313
315
)
314
316
new_gm = new_weight_module .module ()
315
- # TODO: Memory control prototyping. Under discussion
316
- if settings .offload_module_to_cpu :
317
- new_weight_module .module ().to ("cpu" )
318
317
319
318
logger .debug ("Input graph: " + str (new_gm .graph ))
320
319
# Apply lowering on the graph module
@@ -393,7 +392,7 @@ def refit_module_weights(
393
392
394
393
# Iterate over all components that can be accelerated
395
394
# Generate the corresponding TRT Module for those
396
-
395
+ new_weight_module . module (). to ( CPU_DEVICE )
397
396
for name , new_submodule in new_partitioned_module .named_children ():
398
397
# Refit each submodule
399
398
# Extract engine from the submodule
@@ -496,11 +495,7 @@ def refit_module_weights(
496
495
settings = settings ,
497
496
weight_name_map = None ,
498
497
)
499
- # TODO: Memory control prototyping. Under discussion
500
- if settings .offload_module_to_cpu :
501
- del new_submodule
502
- gc .collect ()
503
- torch .cuda .empty_cache ()
498
+ delete_module (new_submodule )
504
499
505
500
# clear EXCLUDE_WEIGHTS flag
506
501
serialization_config = engine .create_serialization_config ()
@@ -523,20 +518,18 @@ def refit_module_weights(
523
518
gc .collect ()
524
519
torch .cuda .empty_cache ()
525
520
526
- # TODO: Memory control prototyping. Under discussion
527
- if settings .offload_module_to_cpu :
528
- del new_partitioned_module
529
- gc .collect ()
530
- torch .cuda .empty_cache ()
521
+ delete_module (new_partitioned_module )
531
522
532
523
if verify_output and arg_inputs is not None :
524
+ new_gm .to (torch .cuda .current_device ())
533
525
if check_module_output (
534
526
new_module = new_gm ,
535
527
refitted_module = compiled_module ,
536
528
arg_inputs = torch_inputs ,
537
529
kwarg_inputs = torch_kwarg_inputs ,
538
530
):
539
531
logger .info ("Refitting Succeed!" )
532
+ new_gm .to (CPU_DEVICE )
540
533
else :
541
534
if weight_name_map :
542
535
logger .warning (
@@ -552,6 +545,7 @@ def refit_module_weights(
552
545
in_place = in_place ,
553
546
)
554
547
logger .error ("Refitting Failed! The outputs do not match." )
548
+ new_gm .to (CPU_DEVICE )
555
549
else :
556
550
logger .info ("Refitting Completed! Output verification skipped." )
557
551
0 commit comments