
Commit a87e5aa — require safetensors (#510)

1 parent: 13da4c1

File tree: 4 files changed (+43 −84 lines)

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py (+22 −51)
@@ -275,24 +275,6 @@ async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str:
     return config_map[inference_framework]
 
 
-def _include_safetensors_bin_or_pt(model_files: List[str]) -> Optional[str]:
-    """
-    This function is used to determine whether to include "*.safetensors", "*.bin", or "*.pt" files
-    based on which file type is present most often in the checkpoint folder. The most
-    frequently present file type is included.
-    In case of ties, priority is given to "*.safetensors", then "*.bin", then "*.pt".
-    """
-    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
-    num_bin = len([f for f in model_files if f.endswith(".bin")])
-    num_pt = len([f for f in model_files if f.endswith(".pt")])
-    maximum = max(num_safetensors, num_bin, num_pt)
-    if num_safetensors == maximum:
-        return "*.safetensors"
-    if num_bin == maximum:
-        return "*.bin"
-    return "*.pt"
-
-
 def _model_endpoint_entity_to_get_llm_model_endpoint_response(
     model_endpoint: ModelEndpoint,
 ) -> GetLLMModelEndpointV1Response:
@@ -354,6 +336,10 @@ def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
         raise ObjectHasInvalidValueException(
             f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
         )
+    if checkpoint_path.endswith(".tar"):
+        raise ObjectHasInvalidValueException(
+            f"Tar files are not supported. Given checkpoint path: {checkpoint_path}."
+        )
 
 
 def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str:
@@ -370,6 +356,14 @@ def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]
     return checkpoint_path
 
 
+def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
+    """Require safetensors in the checkpoint path."""
+    model_files = [f for f in checkpoint_files if "model" in f]
+    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
+    if num_safetensors == 0:
+        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")
+
+
 class CreateLLMModelBundleV1UseCase:
     def __init__(
         self,
@@ -557,27 +551,14 @@ def load_model_weights_sub_commands(
         else:
             s5cmd = "./s5cmd"
 
-        base_path = checkpoint_path.split("/")[-1]
-        if base_path.endswith(".tar"):
-            # If the checkpoint file is a tar file, extract it into final_weights_folder
-            subcommands.extend(
-                [
-                    f"{s5cmd} cp {checkpoint_path} .",
-                    f"mkdir -p {final_weights_folder}",
-                    f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
-                ]
-            )
-        else:
-            # Let's check whether to exclude "*.safetensors" or "*.bin" files
-            checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
-            model_files = [f for f in checkpoint_files if "model" in f]
-
-            include_str = _include_safetensors_bin_or_pt(model_files)
-            file_selection_str = f"--include '*.model' --include '*.json' --include '{include_str}' --exclude 'optimizer*'"
-            subcommands.append(
-                f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
-            )
+        checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path)
+        validate_checkpoint_files(checkpoint_files)
 
+        # filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
+        file_selection_str = "--include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*'"
+        subcommands.append(
+            f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+        )
         return subcommands
 
     def load_model_files_sub_commands_trt_llm(
@@ -591,19 +572,9 @@ def load_model_files_sub_commands_trt_llm(
         See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt
         and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt
         """
-        subcommands = []
-
-        base_path = checkpoint_path.split("/")[-1]
-
-        if base_path.endswith(".tar"):
-            raise ObjectHasInvalidValueException(
-                "Checkpoint for TensorRT-LLM models must be a folder, not a tar file."
-            )
-        else:
-            subcommands.append(
-                f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
-            )
-
+        subcommands = [
+            f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./"
+        ]
         return subcommands
 
     async def create_deepspeed_bundle(
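The net effect of this file's changes: tar checkpoints are rejected up front and every checkpoint must ship safetensors weights. A minimal standalone sketch of the new gate, with the domain exception stubbed so it runs outside the repo (only the stub is hypothetical; the validator body is verbatim from the diff):

from typing import List


class ObjectHasInvalidValueException(Exception):
    """Stand-in for model_engine_server's domain exception."""


def validate_checkpoint_files(checkpoint_files: List[str]) -> None:
    """Require safetensors in the checkpoint path."""
    model_files = [f for f in checkpoint_files if "model" in f]
    num_safetensors = len([f for f in model_files if f.endswith(".safetensors")])
    if num_safetensors == 0:
        raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.")


validate_checkpoint_files(["model-00001-of-00002.safetensors", "config.json"])  # passes
try:
    validate_checkpoint_files(["model-fake.bin", "config.json"])  # bin-only: rejected
except ObjectHasInvalidValueException as e:
    print(e)  # No safetensors found in the checkpoint path.
# Note the "model" substring filter: a shard named e.g. "weights.safetensors"
# would also be rejected, since only filenames containing "model" are counted.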

model-engine/tests/unit/conftest.py (+5 −1)
@@ -757,21 +757,25 @@ class FakeLLMArtifactGateway(LLMArtifactGateway):
     def __init__(self):
         self.existing_models = []
         self.s3_bucket = {
-            "fake-checkpoint": ["fake.bin, fake2.bin", "fake3.safetensors"],
+            "fake-checkpoint": ["model-fake.bin, model-fake2.bin", "model-fake.safetensors"],
             "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"],
             "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"],
             "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"],
+            "llama-2-7b": ["model-fake.safetensors"],
+            "mpt-7b": ["model-fake.safetensors"],
         }
         self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"}
 
     def _add_model(self, owner: str, model_name: str):
         self.existing_models.append((owner, model_name))
 
     def list_files(self, path: str, **kwargs) -> List[str]:
+        path = path.lstrip("s3://")
         if path in self.s3_bucket:
             return self.s3_bucket[path]
 
     def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
+        path = path.lstrip("s3://")
         if path in self.s3_bucket:
             return self.s3_bucket[path]
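A side note on the scheme-stripping added to the fake gateway: str.lstrip takes a character set, not a literal prefix, so it behaves as intended for the fixture paths in this commit but has a known edge case. A quick standalone sketch (not part of the commit):

# lstrip("s3://") strips any leading run of the characters {'s', '3', ':', '/'},
# not the literal "s3://" prefix.
print("s3://mpt-7b".lstrip("s3://"))            # mpt-7b           (as intended)
print("s3://llama-2-7b".lstrip("s3://"))        # llama-2-7b       (as intended)
print("s3://safetensors-repo".lstrip("s3://"))  # afetensors-repo  (leading 's' also eaten)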

model-engine/tests/unit/domain/conftest.py (+6 −3)
@@ -196,6 +196,7 @@ def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request:
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
+        checkpoint_path="s3://mpt-7b",
     )
 
 
@@ -222,7 +223,7 @@ def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
-        checkpoint_path="s3://test-s3.tar",
+        checkpoint_path="s3://llama-2-7b",
     )
 
 
@@ -249,14 +250,15 @@ def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Req
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
+        checkpoint_path="s3://mpt-7b",
     )
 
 
 @pytest.fixture
 def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request:
     return UpdateLLMModelEndpointV1Request(
         inference_framework_image_tag="latest",
-        checkpoint_path="s3://test_checkpoint_path",
+        checkpoint_path="s3://mpt-7b",
         memory="4G",
         min_workers=0,
         max_workers=1,
@@ -286,7 +288,7 @@ def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Reque
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
-        checkpoint_path="s3://test-s3.tar",
+        checkpoint_path="s3://llama-2-7b",
     )
 
 
@@ -315,6 +317,7 @@ def create_llm_model_endpoint_text_generation_inference_request_streaming() -> (
         labels={"team": "infra", "product": "my_product"},
         aws_role="test_aws_role",
         results_s3_bucket="test_s3_bucket",
+        checkpoint_path="s3://mpt-7b",
     )
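The fixture checkpoints move from a tar file ("s3://test-s3.tar") to folder-style paths the fake gateway can resolve ("s3://llama-2-7b", "s3://mpt-7b"), matching the tar rejection added to validate_checkpoint_path_uri. A hedged sketch of that check (the s3:// prefix guard is assumed from the diff's context lines, which elide the if; exception stubbed as in the earlier sketch):

class ObjectHasInvalidValueException(Exception):
    """Stand-in for the domain exception."""


def validate_checkpoint_path_uri(checkpoint_path: str) -> None:
    # assumed guard; the diff shows only the raise inside it
    if not checkpoint_path.startswith("s3://"):
        raise ObjectHasInvalidValueException(
            f"Only S3 paths are supported. Given checkpoint path: {checkpoint_path}."
        )
    # new in this commit
    if checkpoint_path.endswith(".tar"):
        raise ObjectHasInvalidValueException(
            f"Tar files are not supported. Given checkpoint path: {checkpoint_path}."
        )


validate_checkpoint_path_uri("s3://llama-2-7b")  # ok: folder-style path
try:
    validate_checkpoint_path_uri("s3://test-s3.tar")  # the old fixture value
except ObjectHasInvalidValueException as e:
    print(e)  # Tar files are not supported. ...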

model-engine/tests/unit/domain/test_llm_use_cases.py (+10 −29)
@@ -49,9 +49,9 @@
     GpuType,
     ModelDownloadV1UseCase,
     UpdateLLMModelEndpointV1UseCase,
-    _include_safetensors_bin_or_pt,
     infer_hardware_from_model_name,
     validate_and_update_completion_params,
+    validate_checkpoint_files,
 )
 from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase
 
@@ -141,7 +141,7 @@ async def test_create_model_endpoint_use_case_success(
             "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag,
             "num_shards": create_llm_model_endpoint_request_sync.num_shards,
             "quantize": None,
-            "checkpoint_path": None,
+            "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path,
         }
     }
 
@@ -166,7 +166,7 @@ async def test_create_model_endpoint_use_case_success(
             "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag,
             "num_shards": create_llm_model_endpoint_request_streaming.num_shards,
             "quantize": None,
-            "checkpoint_path": None,
+            "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path,
         }
     }
 
@@ -295,7 +295,6 @@ async def test_create_model_bundle_inference_framework_image_tag_validation(
     request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy()
     request.inference_framework = inference_framework
     request.inference_framework_image_tag = inference_framework_image_tag
-    request.checkpoint_path = "s3://test-s3.tar"
     user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True)
     if valid:
         await use_case.execute(user=user, request=request)
@@ -1755,34 +1754,16 @@ async def test_delete_public_inference_model_raises_not_authorized(
 
 
 @pytest.mark.asyncio
-async def test_include_safetensors_bin_or_pt_majority_safetensors():
-    fake_model_files = ["fake.bin", "fake2.safetensors", "model.json", "optimizer.pt"]
-    assert _include_safetensors_bin_or_pt(fake_model_files) == "*.safetensors"
-
-
-@pytest.mark.asyncio
-async def test_include_safetensors_bin_or_pt_majority_bin():
-    fake_model_files = [
-        "fake.bin",
-        "fake2.bin",
-        "fake3.safetensors",
-        "model.json",
-        "optimizer.pt",
-        "fake4.pt",
-    ]
-    assert _include_safetensors_bin_or_pt(fake_model_files) == "*.bin"
+async def test_validate_checkpoint_files_no_safetensors():
+    fake_model_files = ["model-fake.bin", "model.json", "optimizer.pt"]
+    with pytest.raises(ObjectHasInvalidValueException):
+        validate_checkpoint_files(fake_model_files)
 
 
 @pytest.mark.asyncio
-async def test_include_safetensors_bin_or_pt_majority_pt():
-    fake_model_files = [
-        "fake.bin",
-        "fake2.safetensors",
-        "model.json",
-        "optimizer.pt",
-        "fake3.pt",
-    ]
-    assert _include_safetensors_bin_or_pt(fake_model_files) == "*.pt"
+async def test_validate_checkpoint_files_safetensors_with_other_files():
+    fake_model_files = ["model-fake.bin", "model-fake2.safetensors", "model.json", "optimizer.pt"]
+    validate_checkpoint_files(fake_model_files)  # No exception should be raised
 
 
 def test_infer_hardware_from_model_name():
