[Model] Support DeepSeek-V4
This commit is contained in:
3
vllm_mlu/model_executor/model_loader/__init__.py
Normal file
3
vllm_mlu/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
173
vllm_mlu/model_executor/model_loader/dummy_loader.py
Normal file
173
vllm_mlu/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from typing import List, Tuple
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def initialize_dummy_weights_normal_dist(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    std: float = 0.5,
    seed: int = 1234,
) -> None:
    """Initialize model weights in place from a normal distribution.

    Floating point parameters are filled from a normal distribution whose
    mean is sampled once from [low, high] and whose standard deviation is
    ``std``.  Integer parameters are filled with random integers in
    [floor(low), ceil(high)).

    Optimized version: to minimize host-to-device transfers, one pinned
    CPU buffer sized for the largest parameter is generated per dtype
    class (float / int) and copied to the device once; every parameter is
    then filled from a prefix view of that device buffer.  Different
    parameters may therefore share the same values — global uniqueness is
    deliberately traded for performance.

    Args:
        model (torch.nn.Module): The model whose weights will be initialized.
        low (float): Lower bound for sampling the mean (float params) and
            for the integer range (int params).
        high (float): Upper bound for sampling the mean (float params) and
            for the integer range (int params).
        std (float): Standard deviation of the normal distribution (float params).
        seed (int): Random seed for reproducibility.
    """
    # Sample the mean of the normal distribution once from [low, high].
    rng = np.random.RandomState(seed)
    mean = float(rng.uniform(low, high, 1).item())

    # Dedicated CPU generator so the draws are reproducible regardless of
    # the global RNG state.
    cpu_gen = torch.Generator(device="cpu")
    cpu_gen.manual_seed(seed)

    # Split state-dict tensors into float and integer groups; tensors of
    # any other dtype (e.g. bool) are left untouched.
    float_params: List[Tuple[str, torch.Tensor]] = []
    int_params: List[Tuple[str, torch.Tensor]] = []

    for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
        if not isinstance(t, torch.Tensor):
            continue
        if torch.is_floating_point(t):
            float_params.append((name, t))
        elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
            int_params.append((name, t))

    # -------- Floating point parameters: optimized shared memory initialization --------
    if float_params:
        # Size the shared buffer for the largest float parameter.
        max_float_elems = max(p.numel() for _, p in float_params)

        # One pinned fp32 host buffer, filled once with the target distribution.
        shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True)
        shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)

        # Single H2D transfer.  NOTE(review): assumes every float param
        # lives on the same device as the first one — confirm for
        # multi-device models.
        device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True)

        for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"):
            n = p.numel()
            # Prefix view of the shared device buffer (values may repeat
            # across parameters).
            view = device_buffer[:n].view(p.shape)

            # Sub-16-bit float dtypes cannot be produced by normal_
            # directly, so route the cast through fp16 first.
            if torch.finfo(p.dtype).bits < 16:
                tmp = view.to(torch.float16)
                tmp = tmp.to(p.dtype)
            else:
                tmp = view.to(p.dtype)

            # Device-to-device copy into the parameter storage (fast path).
            p.data.copy_(tmp)

    # -------- Integer parameters: optimized shared memory initialization --------
    if int_params:
        # Size the shared buffer for the largest integer parameter.
        max_int_elems = max(p.numel() for _, p in int_params)

        int_low = int(np.floor(low))
        int_high = int(np.ceil(high))
        if int_high == int_low:
            int_high = int_low + 1  # Ensure at least one possible value

        # One pinned int64 host buffer, filled once with random integers.
        shared_int_buffer = torch.randint(
            low=int_low,
            high=int_high,
            size=(max_int_elems,),
            dtype=torch.int64,
            generator=cpu_gen,
            device="cpu",
            pin_memory=True
        )

        # Single H2D transfer; same single-device assumption as above.
        device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True)

        for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"):
            n = p.numel()
            # Prefix view of the shared device buffer (values may repeat).
            view = device_int_buffer[:n].view(p.shape)
            tmp = view.to(p.dtype)
            # Device-to-device copy into the parameter storage (fast path).
            p.data.copy_(tmp)
|
||||
|
||||
|
||||
# Quantization method name that, combined with a multimodal architecture,
# triggers the reduced-std initialization in the hijacked load_weights.
SMOOTHQUANT_METHOD = "smoothquant"
# Substrings of architecture names used to classify a model as multimodal.
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}
|
||||
def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(self, model: nn.Module,
                                                    model_config: ModelConfig) -> None:
    # NOTE(woosuk): For accurate performance evaluation, we assign
    # random values to the weights.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits
    std=0.5 is used for better distinguishable logits
    '''

    # Baseline initialization bounds (original fallback values).
    init_low = -1e-3
    init_high = 1e-3

    # Inspect the model config for quantization method and architectures.
    quantization = getattr(model_config, "quantization", None)
    architectures = getattr(model_config, "architectures", []) or []

    # A model counts as multimodal when any architecture name contains one
    # of the known multimodal keywords.
    multimodal = any(
        keyword in arch
        for keyword in MULTIMODAL_ARCH_KEYWORDS
        for arch in architectures
    )

    # smoothquant + multimodal: use a much smaller std to mitigate NaN
    # overflow; otherwise keep the default std of 0.5.
    init_std = 1e-4 if (multimodal and quantization == SMOOTHQUANT_METHOD) else 0.5

    initialize_dummy_weights_normal_dist(
        model,
        low=init_low,
        high=init_high,
        std=init_std
    )
    # add a sync to make sure the weights are initialized
    torch.mlu.synchronize()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
|
||||
|
||||
# Replace the upstream DummyModelLoader.load_weights with the MLU-aware
# implementation above; registering the original lets the hijack be undone.
MluHijackObject.apply_hijack(
    DummyModelLoader,
    DummyModelLoader.load_weights,
    vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
)
|
||||
137
vllm_mlu/model_executor/model_loader/tensorizer.py
Normal file
137
vllm_mlu/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import time
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, TensorDeserializer, TensorizerArgs,
|
||||
_check_tensors_on_meta_device, _resize_lora_embeddings,
|
||||
is_valid_deserialization_uri)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# Optional dependency: the tensorizer package provides streaming
# (de)serialization helpers used by this module.
try:
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import (convert_bytes, get_mem_usage,
                                  no_init_or_tensor)

except ImportError:
    # NOTE(review): the name `tensorizer` is never bound in this module,
    # so this fallback path raises NameError instead of installing
    # placeholder attributes — confirm whether a PlaceholderModule import
    # (as in upstream vLLM) is missing here.
    open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
    convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
    get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
    no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")

logger = init_logger(__name__)
|
||||
|
||||
|
||||
def deserialize_tensorizer_model(model: nn.Module,
                                 tensorizer_config: TensorizerConfig) -> None:
    """Load tensorizer-serialized weights into ``model`` in place.

    Validates the configured tensorizer URI, deserializes the tensors
    directly onto the current accelerator device (MLU on out-of-tree
    platforms, otherwise XPU or CUDA), logs throughput and memory usage,
    and performs post-load fixups.

    Args:
        model: Module to populate; must carry ``vllm_tensorized_marker``,
            which is removed on success.
        tensorizer_config: Source URI, dtype, and stream/deserialization
            kwargs.

    Raises:
        ValueError: If the tensorizer URI is not a local file or an
            S3/HTTP/HTTPS URL.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
        raise ValueError(
            f"{tensorizer_config.tensorizer_uri} is not a valid "
            f"tensorizer URI. Please check that the URI is correct. "
            f"It must either point to a local existing file, or have a "
            f"S3, HTTP or HTTPS scheme.")
    before_mem = get_mem_usage()
    start = time.perf_counter()
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use mlu device
    '''
    # Pick the deserialization target device for the current platform.
    # Out-of-tree here means the MLU platform plugin is active.
    device = ''
    if current_platform.is_out_of_tree():
        device = f'mlu:{torch.mlu.current_device()}'
    elif current_platform.is_xpu():
        device = f'xpu:{torch.xpu.current_device()}'
    else:
        device = f'cuda:{torch.cuda.current_device()}'
    with open_stream(
            tensorizer_config.tensorizer_uri,
            mode="rb",
            **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                stream,
                dtype=tensorizer_config.dtype,
                device=device,
                **tensorizer_args.deserialization_kwargs) as deserializer:
        # Stream tensors straight into the module's parameters.
        deserializer.load_into_module(model)
        end = time.perf_counter()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Report deserialization throughput and memory delta.
    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    deserializer.close()
    logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
                end - start, per_second)
    logger.info("Memory usage before: %s", before_mem)
    logger.info("Memory usage after: %s", after_mem)

    # Post-load fixups: no tensor may remain on the meta device, and LoRA
    # embeddings are resized to match the vocabulary.
    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    # The marker has served its purpose once weights are loaded.
    del model.vllm_tensorized_marker
|
||||
|
||||
def serialize_extra_artifacts(
        tensorizer_args: TensorizerArgs,
        served_model_name: Union[str, list[str], None]) -> None:
    """Copy the non-tensor artifacts of a local model to the tensorizer dir.

    Stages the local model directory into a temporary directory, then
    streams every regular file from it to
    ``tensorizer_args.tensorizer_dir`` via ``open_stream``.

    Args:
        tensorizer_args: Provides the destination ``tensorizer_dir`` and
            the stream kwargs used to open output streams.
        served_model_name: Path to a local model directory (offline mode);
            must be a single ``str``.

    Raises:
        ValueError: If ``served_model_name`` is not a ``str`` or does not
            name an existing local directory.
    """
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}.")

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use local file
    '''
    # FIX: `tempfile` and `os` are used below but were never imported in
    # this module (NameError at runtime); import them locally, consistent
    # with the existing local shutil/pathlib imports.
    import os
    import shutil
    import tempfile
    from pathlib import Path
    local_model_path = Path(served_model_name)
    if not local_model_path.exists() or not local_model_path.is_dir():
        raise ValueError(
            f"served_model_name must be a valid local directory in offline mode, "
            f"but got: {served_model_name}"
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    with tempfile.TemporaryDirectory() as tmpdir:
        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: copy local file
        '''
        logger.info("Copying local model from %s to temporary directory %s",
                    local_model_path, tmpdir)
        shutil.copytree(local_model_path, tmpdir, dirs_exist_ok=True)
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        # Stream every regular file in the staged copy to the tensorizer
        # output directory; subdirectories are skipped.
        for artifact in os.scandir(tmpdir):
            if not artifact.is_file():
                continue
            with open(artifact.path, "rb") as f, open_stream(
                    f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
                    mode="wb+",
                    **tensorizer_args.stream_kwargs) as stream:
                logger.info("Writing artifact %s", artifact.name)
                stream.write(f.read())
|
||||
|
||||
35
vllm_mlu/model_executor/model_loader/tensorizer_loader.py
Normal file
35
vllm_mlu/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.model_loader.tensorizer import is_vllm_tensorized
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
|
||||
from vllm_mlu.model_executor.model_loader.tensorizer import deserialize_tensorizer_model
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights(
        self,
        model: nn.Module,
        model_config: ModelConfig
) -> None:
    """Load serialized model weights with tensorizer.

    Expects a vLLM-tensorized model. See the
    examples/others/tensorize_vllm_model.py example script
    for serializing vLLM models."""
    # Checkpoints that are not vLLM-tensorized fall back to the generic
    # weight iterator.
    if not is_vllm_tensorized(self.tensorizer_config):
        model.load_weights(self._get_weights_iterator())
        return
    # vLLM-tensorized: patch the config for this model, then deserialize
    # with the MLU-aware implementation.
    patched_config = self._patch_tensorizer_config(model_config)
    deserialize_tensorizer_model(model, patched_config)
|
||||
|
||||
|
||||
# Replace the upstream TensorizerLoader.load_weights with the MLU-aware
# implementation above; registering the original lets the hijack be undone.
MluHijackObject.apply_hijack(
    TensorizerLoader,
    TensorizerLoader.load_weights,
    vllm__model_executor__model_loader__tensorizer_loader__TensorizerLoader__load_weights
)
|
||||
Reference in New Issue
Block a user