# Repository file-viewer header (scraped metadata, not code):
#   Files
#   enginex-mlu590-vllm/vllm_mlu/model_executor/model_loader/dummy_loader.py
#   2026-04-24 09:58:03 +08:00
#   173 lines, 6.7 KiB, Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from tqdm import tqdm
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def _fill_params_from_device_buffer(
    params: List[Tuple[str, torch.Tensor]],
    device_buffer: torch.Tensor,
    desc: str,
) -> None:
    """Copy a leading slice of ``device_buffer`` into every tensor in ``params``.

    Each tensor receives the first ``numel()`` elements of the shared buffer,
    reshaped to its own shape and cast to its own dtype. Different tensors may
    therefore hold identical values (the documented speed/uniqueness trade-off).

    Args:
        params: ``(name, tensor)`` pairs whose tensors are overwritten in place.
        device_buffer: 1-D source buffer with at least as many elements as the
            largest tensor in ``params``.
        desc: Label shown on the tqdm progress bar.
    """
    for _, p in tqdm(params, desc=desc):
        view = device_buffer[:p.numel()].view(p.shape)
        if torch.is_floating_point(p) and torch.finfo(p.dtype).bits < 16:
            # Sub-16-bit float dtypes (e.g. fp8) are reached through an fp16
            # intermediate instead of a direct fp32 cast.
            # NOTE(review): presumably a direct fp32 -> fp8 cast is not
            # supported on the target backend; confirm before removing.
            view = view.to(torch.float16)
        p.data.copy_(view.to(p.dtype))


def initialize_dummy_weights_normal_dist(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    std: float = 0.5,
    seed: int = 1234,
) -> None:
    """
    Overwrite the floating-point and integer tensors of a model with dummy values.

    Every tensor in ``model.state_dict()`` is handled in place:

    * Floating-point tensors get samples from a normal distribution. Its mean
      is drawn once, uniformly from ``[low, high]`` (seeded by ``seed``), and
      its standard deviation is ``std``.
    * int8 / uint8 / int16 / int32 / int64 tensors get random integers in
      ``[floor(low), ceil(high))``. If that range is empty it is widened to
      the single value ``floor(low)``. Values are generated as int64 and then
      cast to the tensor's dtype. No clamping is done, so out-of-range values
      follow torch's integer cast rules.
    * Tensors of any other dtype (e.g. bool) and non-tensor entries are left
      untouched.

    Performance trade-off: for each of the two kinds, one host buffer sized to
    the largest tensor is filled once. It is transferred to the device of the
    first tensor of that kind, and a prefix of it is copied into every tensor.
    Values are therefore shared between tensors rather than globally unique.

    Args:
        model (torch.nn.Module): The model whose weights will be initialized.
        low (float): Lower bound for sampling the mean (float tensors) and,
            floored, the inclusive lower bound for integer tensors.
        high (float): Upper bound for sampling the mean (float tensors) and,
            ceiled, the exclusive upper bound for integer tensors.
        std (float): Standard deviation of the normal distribution (float tensors).
        seed (int): Random seed for reproducibility.
    """
    # The mean of the float distribution is sampled once for the whole model.
    rng = np.random.RandomState(seed)
    mean = float(rng.uniform(low, high, 1).item())
    # Dedicated CPU generator: results do not depend on the global torch RNG state.
    cpu_gen = torch.Generator(device="cpu")
    cpu_gen.manual_seed(seed)

    # Split tensors into floating-point and supported integer dtypes.
    float_params: List[Tuple[str, torch.Tensor]] = []
    int_params: List[Tuple[str, torch.Tensor]] = []
    for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
        if not isinstance(t, torch.Tensor):
            continue
        if torch.is_floating_point(t):
            float_params.append((name, t))
        elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
            int_params.append((name, t))

    # Float buffer is filled before the int buffer, so cpu_gen is always
    # consumed in the same order (keeps runs reproducible).
    if float_params:
        float_device = float_params[0][1].device
        max_float_elems = max(p.numel() for _, p in float_params)
        # Pinned memory only benefits host-to-device copies; skip it when the
        # destination is the CPU itself.
        float_host_buffer = torch.empty(
            max_float_elems,
            dtype=torch.float32,
            device="cpu",
            pin_memory=float_device.type != "cpu",
        )
        float_host_buffer.normal_(mean=mean, std=std, generator=cpu_gen)
        # One host-to-device transfer; every tensor is then filled from device memory.
        _fill_params_from_device_buffer(
            float_params,
            float_host_buffer.to(float_device, non_blocking=True),
            "Gen dummy weights: Init float params",
        )

    if int_params:
        int_device = int_params[0][1].device
        max_int_elems = max(p.numel() for _, p in int_params)
        int_low = int(np.floor(low))
        int_high = int(np.ceil(high))
        if int_high == int_low:
            int_high = int_low + 1  # randint needs a non-empty [low, high) range
        int_host_buffer = torch.randint(
            low=int_low,
            high=int_high,
            size=(max_int_elems,),
            dtype=torch.int64,
            generator=cpu_gen,
            device="cpu",
            pin_memory=int_device.type != "cpu",
        )
        _fill_params_from_device_buffer(
            int_params,
            int_host_buffer.to(int_device, non_blocking=True),
            "Gen dummy weights: Init int params",
        )
# Value of ModelConfig.quantization that selects the SmoothQuant-specific std.
SMOOTHQUANT_METHOD = "smoothquant"
# Substrings that mark an architecture name as multimodal (case-sensitive match).
# frozenset: this is a constant and must not be mutated at runtime.
MULTIMODAL_ARCH_KEYWORDS = frozenset({"VL", "Vision", "Multimodal"})
def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(
    self, model: nn.Module, model_config: ModelConfig
) -> None:
    # NOTE(woosuk): For accurate performance evaluation, we assign
    # random values to the weights.
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use torch.normal_ instead of torch.uniform_ for distinguishable logits
    std=0.5 is used for better distinguishable logits
    '''
    # MLU replacement for DummyModelLoader.load_weights. It initializes the
    # model's weights via initialize_dummy_weights_normal_dist, then calls
    # torch.mlu.synchronize(). `self` is not used.
    quantization = getattr(model_config, "quantization", None)
    architectures = getattr(model_config, "architectures", []) or []

    # A model counts as multimodal when any architecture name contains one of
    # the MULTIMODAL_ARCH_KEYWORDS substrings.
    multimodal = any(
        keyword in arch_name
        for arch_name in architectures
        for keyword in MULTIMODAL_ARCH_KEYWORDS
    )

    # SmoothQuant + multimodal uses a much narrower distribution to mitigate
    # NaN overflow; every other model keeps std=0.5.
    if multimodal and quantization == SMOOTHQUANT_METHOD:
        init_std = 1e-4
    else:
        init_std = 0.5

    initialize_dummy_weights_normal_dist(model, low=-1e-3, high=1e-3, std=init_std)
    # Block until the device has finished writing the initialized weights.
    torch.mlu.synchronize()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
# Hook the MLU variant of load_weights (defined above) into vLLM's
# DummyModelLoader; the patching itself is performed by
# MluHijackObject.apply_hijack.
MluHijackObject.apply_hijack(
    DummyModelLoader,
    DummyModelLoader.load_weights,
    vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights,
)