# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import os
import tempfile
import time
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

import torch
from torch import nn

from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig, TensorDeserializer, TensorizerArgs,
    _check_tensors_on_meta_device, _resize_lora_embeddings,
    is_valid_deserialization_uri)
from vllm.platforms import current_platform

# tensorizer is an optional dependency: when it is missing, install lazy
# placeholders that raise a helpful error only when actually used.
try:
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import (convert_bytes, get_mem_usage,
                                  no_init_or_tensor)
except ImportError:
    # FIX: the original referenced the name `tensorizer` here without ever
    # binding it, so this fallback path raised NameError instead of
    # producing placeholders. Bind the placeholder module first (this is
    # what upstream vLLM does with PlaceholderModule).
    from vllm.utils import PlaceholderModule

    tensorizer = PlaceholderModule("tensorizer")
    open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
    convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
    get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
    no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")

logger = init_logger(__name__)

def deserialize_tensorizer_model(model: nn.Module,
                                 tensorizer_config: TensorizerConfig) -> None:
    """Load serialized weights from ``tensorizer_config.tensorizer_uri``
    directly into ``model`` on the current accelerator device.

    Logs throughput and before/after memory usage, then runs the standard
    post-load checks (no tensors left on the meta device, LoRA embedding
    resize) and removes the ``vllm_tensorized_marker`` attribute.

    Args:
        model: The module whose parameters will be populated in place.
        tensorizer_config: Source URI, dtype and (de)serialization kwargs.

    Raises:
        ValueError: If the URI is neither an existing local file nor an
            S3/HTTP/HTTPS location.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
        raise ValueError(
            f"{tensorizer_config.tensorizer_uri} is not a valid "
            f"tensorizer URI. Please check that the URI is correct. "
            f"It must either point to a local existing file, or have a "
            f"S3, HTTP or HTTPS scheme.")
    before_mem = get_mem_usage()
    start = time.perf_counter()
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use mlu device
    '''
    # Pick the deserialization target device. The out-of-tree (MLU) check
    # must come first so MLU wins over the generic CUDA fallback.
    # FIX: dropped the dead `device = ''` pre-assignment — every branch
    # below assigns `device`.
    if current_platform.is_out_of_tree():
        device = f'mlu:{torch.mlu.current_device()}'
    elif current_platform.is_xpu():
        device = f'xpu:{torch.xpu.current_device()}'
    else:
        device = f'cuda:{torch.cuda.current_device()}'
    with open_stream(
            tensorizer_config.tensorizer_uri,
            mode="rb",
            **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                stream,
                dtype=tensorizer_config.dtype,
                device=device,
                **tensorizer_args.deserialization_kwargs) as deserializer:
        deserializer.load_into_module(model)
        end = time.perf_counter()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    deserializer.close()
    # FIX: log the precomputed `duration` instead of recomputing
    # `end - start` (same value; keeps the two uses consistent).
    logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
                duration, per_second)
    logger.info("Memory usage before: %s", before_mem)
    logger.info("Memory usage after: %s", after_mem)

    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    del model.vllm_tensorized_marker

def serialize_extra_artifacts(
        tensorizer_args: TensorizerArgs,
        served_model_name: Union[str, list[str], None]) -> None:
    """Copy the non-tensor artifacts (configs, tokenizer files, ...) of a
    local model directory into ``tensorizer_args.tensorizer_dir``.

    In this MLU offline variant, ``served_model_name`` must name an
    existing local model directory; its files are staged through a
    temporary directory and then streamed to the tensorizer destination.

    Args:
        tensorizer_args: Provides the destination ``tensorizer_dir`` and
            stream kwargs for ``open_stream``.
        served_model_name: Path to a local model directory (must be a
            single ``str``; lists are rejected).

    Raises:
        ValueError: If ``served_model_name`` is not a ``str``, or does not
            name an existing local directory.
    """
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}.")

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use local file
    '''
    import shutil
    from pathlib import Path
    local_model_path = Path(served_model_name)
    # FIX: `is_dir()` already implies existence, so the original
    # `not exists() or not is_dir()` check was redundant.
    if not local_model_path.is_dir():
        raise ValueError(
            f"served_model_name must be a valid local directory in offline mode, "
            f"but got: {served_model_name}"
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    with tempfile.TemporaryDirectory() as tmpdir:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: copy local file
        '''
        logger.info("Copying local model from %s to temporary directory %s",
                    local_model_path, tmpdir)
        shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        # Stream every regular file in the staged directory to the
        # tensorizer destination; subdirectories are skipped.
        for artifact in os.scandir(tmpdir):
            if not artifact.is_file():
                continue
            with open(artifact.path, "rb") as f, open_stream(
                    f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
                    mode="wb+",
                    **tensorizer_args.stream_kwargs) as stream:
                logger.info("Writing artifact %s", artifact.name)
                stream.write(f.read())