@@ -275,24 +275,6 @@ async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str:
     return config_map[inference_framework]


-def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]:
-    """
-    This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files
-    based on which file type is present most often in the checkpoint folder. The most
-    frequently present file type is included.
-    In case of ties, priority is given to "*.safetensors", then "*.bin", then "*.pt".
-    """
-    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
-    num_bin = len([f for f in model_files if f.endswith(".bin")])
-    num_pt = len([f for f in model_files if f.endswith(".pt")])
-    maximum = max(num_safetensors, num_bin, num_pt)
-    if num_safetensors == maximum:
-        return "*.safetensors"
-    if num_bin == maximum:
-        return "*.bin"
-    return "*.pt"
-
-
 def _model_endpoint_entity_to_get_llm_model_endpoint_response(
     model_endpoint: ModelEndpoint,
 ) -> GetLLMModelEndpointV1Response:
@@ -354,6 +336,10 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
         raise ObjectHasInvalidValueException(
             f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
         )
+    if checkpoint_path.endswith(".tar"):
+        raise ObjectHasInvalidValueException(
+            f"Tar files are not supported. Given checkpoint path: {checkpoint_path}."
+        )


 def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str:
@@ -370,6 +356,14 @@ def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]
     return checkpoint_path


+def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
+    """Require safetensors in the checkpoint path."""
+    model_files = [f for f in checkpoint_files if "model" in f]
+    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
+    if num_safetensors == 0:
+        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")
+
+
 class CreateLLMModelBundleV1UseCase:
     def __init__(
         self,
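As a rough standalone sketch of how the added validate_checkpoint_files helper behaves (the exception class below is a stand-in for the repo's ObjectHasInvalidValueException, and the file listings are hypothetical), a checkpoint listing passes only if at least one file containing "model" in its name ends in .safetensors:

from typing import List

class ObjectHasInvalidValueException(Exception):
    """Stand-in for the repo's exception type, for illustration only."""

def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
    """Require safetensors in the checkpoint path."""
    model_files = [f for f in checkpoint_files if "model" in f]
    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
    if num_safetensors == 0:
        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")

# Hypothetical listings, not taken from a real bucket:
validate_checkpoint_files(["config.json", "model-00001-of-00002.safetensors"])  # passes
try:
    validate_checkpoint_files(["config.json", "pytorch_model.bin"])  # only .bin weights
except ObjectHasInvalidValueException as e:
    print(e)  # No safetensors found in the checkpoint path.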
@@ -557,27 +551,14 @@ def load_model_weights_sub_commands(
         else:
             s5cmd = "./s5cmd"

-        base_path = checkpoint_path.split("/")[-1]
-        if base_path.endswith(".tar"):
-            # If the checkpoint file is a tar file, extract it into final_weights_folder
-            subcommands.extend(
-                [
-                    f"{s5cmd} cp {checkpoint_path} .",
-                    f"mkdir -p {final_weights_folder}",
-                    f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
-                ]
-            )
-        else:
-            # Let's check whether to exclude "*.safetensors" or "*.bin" files
-            checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
-            model_files = [f for f in checkpoint_files if "model" in f]
-
-            include_str = _include_safetensors_bin_or_pt(model_files)
-            file_selection_str = f"--include '*.model' --include '*.json' --include '{include_str}' --exclude 'optimizer*'"
-            subcommands.append(
-                f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
-            )
+        checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
+        validate_checkpoint_files(checkpoint_files)

+        # filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
+        file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'"
+        subcommands.append(
+            f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+        )
         return subcommands

     def load_model_files_sub_commands_trt_llm(
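To make the new download step concrete, the short sketch below rebuilds the same s5cmd command string for a made-up checkpoint; the bucket, path, and target folder names are hypothetical and only the filter string mirrors the change above:

import os

s5cmd = "./s5cmd"
checkpoint_path = "s3://my-bucket/checkpoints/llama-2-7b"  # hypothetical S3 prefix
final_weights_folder = "model_files"                       # hypothetical local target

# Same filter as above: tokenizer/config files plus safetensors weights, no optimizer state.
file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'"
command = (
    f"{s5cmd} --numworkers 512 cp --concurrency 10 "
    f"{file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
print(command)
# prints (one line):
# ./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json'
#   --include '*.safetensors' --exclude 'optimizer*' s3://my-bucket/checkpoints/llama-2-7b/* model_files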
@@ -591,19 +572,9 @@ def load_model_files_sub_commands_trt_llm(
         See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
         and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt
         """
-        subcommands = []
-
-        base_path = checkpoint_path.split("/")[-1]
-
-        if base_path.endswith(".tar"):
-            raise ObjectHasInvalidValueException(
-                "Checkpoint for TensorRT-LLM models must be a folder, not a tar file."
-            )
-        else:
-            subcommands.append(
-                f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
-            )
-
+        subcommands = [
+            f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
+        ]
         return subcommands

     async def create_deepspeed_bundle(