# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from tqdm import tqdm

from vllm.config import ModelConfig
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm_mlu.mlu_hijack_utils import MluHijackObject


def initialize_dummy_weights_normal_dist(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    std: float = 0.5,
    seed: int = 1234,
) -> None:
"""
|
||
|
|
Initialize the weights of a PyTorch model with values drawn from a normal distribution.
|
||
|
|
Floating point parameters are initialized with a normal distribution whose mean is randomly
|
||
|
|
sampled from [low, high] and standard deviation is fixed at 0.5. Integer parameters are
|
||
|
|
initialized with random integers in [floor(low), ceil(high)). The initialization is performed
|
||
|
|
in a batched and efficient way for both floating point and integer parameters.
|
||
|
|
|
||
|
|
Optimized version: Uses shared pinned memory based on the largest parameter block size
|
||
|
|
to minimize H2D transfers, sacrificing global uniqueness for performance.
|
||
|
|
|
||
|
|
Args:
|
||
|
|
model (torch.nn.Module): The model whose weights will be initialized.
|
||
|
|
low (float): Lower bound for sampling the mean of the normal distribution (for float params).
|
||
|
|
high (float): Upper bound for sampling the mean of the normal distribution (for float params).
|
||
|
|
std (float): Standard deviation for the normal distribution (for float params).
|
||
|
|
seed (int): Random seed for reproducibility.
|
||
|
|
"""
|
||
|
|
    # Randomly sample the mean for the normal distribution from [low, high]
    rng = np.random.RandomState(seed)
    mean = float(rng.uniform(low, high, 1).item())

    # Create a CPU generator for reproducibility
    cpu_gen = torch.Generator(device="cpu")
    cpu_gen.manual_seed(seed)

    # Collect parameters: separate into floating point and integer types
    float_params: List[Tuple[str, torch.Tensor]] = []
    int_params: List[Tuple[str, torch.Tensor]] = []

    for name, t in tqdm(model.state_dict().items(), desc="Gen dummy weights: Collect params"):
        if not isinstance(t, torch.Tensor):
            continue
        if torch.is_floating_point(t):
            float_params.append((name, t))
        elif t.dtype in (torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64):
            int_params.append((name, t))

    # -------- Floating point parameters: optimized shared-memory initialization --------
    if float_params:
        # Find the largest parameter block size
        max_float_elems = max(p.numel() for _, p in float_params)

        # Create a shared pinned-memory buffer sized to the largest parameter
        shared_float_buffer = torch.empty(max_float_elems, dtype=torch.float32, device="cpu", pin_memory=True)
        shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)

        # Copy the shared buffer to the device once (assumes all float params live on one device)
        device_buffer = shared_float_buffer.to(next(iter(float_params))[1].device, non_blocking=True)

        for _, p in tqdm(float_params, desc="Gen dummy weights: Init float params"):
            n = p.numel()
            # Slice from the device buffer (different parameters may reuse the same values)
            view = device_buffer[:n].view(p.shape)

            # Sub-16-bit float dtypes (e.g. fp8) are not fully supported by all
            # ops, so route the cast through fp16 for them
            if torch.finfo(p.dtype).bits < 16:
                tmp = view.to(torch.float16)
                tmp = tmp.to(p.dtype)
            else:
                tmp = view.to(p.dtype)

            # Copy into the parameter (device-to-device copy, much faster than H2D)
            p.data.copy_(tmp)

    # -------- Integer parameters: optimized shared-memory initialization --------
    if int_params:
        # Find the largest parameter block size
        max_int_elems = max(p.numel() for _, p in int_params)

        int_low = int(np.floor(low))
        int_high = int(np.ceil(high))
        if int_high == int_low:
            int_high = int_low + 1  # Ensure at least one possible value

        # Create a shared pinned-memory buffer sized to the largest parameter
        shared_int_buffer = torch.randint(
            low=int_low,
            high=int_high,
            size=(max_int_elems,),
            dtype=torch.int64,
            generator=cpu_gen,
            device="cpu",
            pin_memory=True,
        )

        # Copy the shared buffer to the device once (assumes all int params live on one device)
        device_int_buffer = shared_int_buffer.to(next(iter(int_params))[1].device, non_blocking=True)

        for _, p in tqdm(int_params, desc="Gen dummy weights: Init int params"):
            n = p.numel()
            # Slice from the device buffer (different parameters may reuse the same values)
            view = device_int_buffer[:n].view(p.shape)
            tmp = view.to(p.dtype)
            # Copy into the parameter (device-to-device copy, much faster than H2D)
            p.data.copy_(tmp)

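# A minimal usage sketch of the initializer above. The Linear module is a
# hypothetical stand-in for a real vLLM model, and the pinned-memory allocation
# inside the function assumes an accelerator-enabled PyTorch build:
#
#     model = nn.Linear(16, 16)
#     initialize_dummy_weights_normal_dist(model, low=-1e-3, high=1e-3, std=0.5)
#     # Float params now hold N(mean, std) values, with mean drawn from [-1e-3, 1e-3];
#     # int params (none in this toy model) would get values in [floor(low), ceil(high)).
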
SMOOTHQUANT_METHOD = "smoothquant"
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}


def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(
        self, model: nn.Module, model_config: ModelConfig) -> None:
    # NOTE(woosuk): For accurate performance evaluation, we assign
    # random values to the weights.
    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: Use torch.normal_ instead of torch.uniform_ so that logits are
            distinguishable; std=0.5 gives better-separated logits.
    '''

    # === Default parameter setup (original values as fallback) ===
    low_val = -1e-3
    high_val = 1e-3
    std_val = 0.5

    # === Model and quantization check logic ===
    quant_method = getattr(model_config, "quantization", None)

    # Attempt to get the architectures list from model_config
    archs = getattr(model_config, "architectures", []) or []

    # Determine if the model is multimodal (based on architecture names)
    is_multimodal = any(
        keyword in arch
        for arch in archs
        for keyword in MULTIMODAL_ARCH_KEYWORDS
    )

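    # Illustrative example (architecture names here are only for illustration):
    # archs = ["Qwen2VLForConditionalGeneration"] contains the substring "VL",
    # so is_multimodal evaluates to True, while a text-only architecture such as
    # "LlamaForCausalLM" matches no keyword and yields False.
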
    # === Apply SmoothQuant + multimodal parameters ===
    if is_multimodal and quant_method == SMOOTHQUANT_METHOD:
        # SmoothQuant + multimodal specific value to mitigate NaN overflow
        std_val = 1e-4

    initialize_dummy_weights_normal_dist(
        model,
        low=low_val,
        high=high_val,
        std=std_val,
    )
    # Synchronize to make sure the weights are fully initialized on device
    torch.mlu.synchronize()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''


MluHijackObject.apply_hijack(
    DummyModelLoader,
    DummyModelLoader.load_weights,
    vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights,
)
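
# Rough sketch of what the hijack above accomplishes. The exact mechanics live
# in MluHijackObject.apply_hijack (defined in vllm_mlu); this is an assumed
# illustration, not the actual implementation:
#
#     DummyModelLoader.load_weights = \
#         vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
#
# After this module is imported, loading a model with load_format="dummy" goes
# through the normal-distribution initializer defined above.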