update
This commit is contained in:
44
vllm/renderers/embed_utils.py
Normal file
44
vllm/renderers/embed_utils.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pybase64
|
||||
import torch
|
||||
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
def safe_load_prompt_embeds(
|
||||
model_config: "ModelConfig",
|
||||
embed: bytes,
|
||||
) -> torch.Tensor:
|
||||
if not model_config.enable_prompt_embeds:
|
||||
raise VLLMValidationError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`.",
|
||||
parameter="prompt_embeds",
|
||||
)
|
||||
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(
|
||||
BytesIO(pybase64.b64decode(embed, validate=True)),
|
||||
weights_only=True,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
)
|
||||
tensor = tensor.to_dense()
|
||||
|
||||
if tensor.dim() > 2:
|
||||
tensor = tensor.squeeze(0)
|
||||
assert tensor.dim() == 2
|
||||
|
||||
return tensor
|
||||
Reference in New Issue
Block a user