# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from tqdm import tqdm
from vllm.config import ModelConfig
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm_mlu.mlu_hijack_utils import MluHijackObject


def initialize_dummy_weights_normal_dist(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    std: float = 0.5,
    seed: int = 1234,
) -> None:
    """
    Initialize the weights of a PyTorch model with random dummy values.

    Floating point parameters are drawn from a normal distribution whose
    mean is randomly sampled from [low, high] and whose standard deviation
    is ``std``. Integer parameters are drawn uniformly from
    [floor(low), ceil(high)).

    Optimized version: a single pinned host buffer, sized to the largest
    parameter of each group, is generated once and transferred to the
    device once; every parameter is then filled from that device buffer
    via cheap device-to-device copies. This sacrifices global uniqueness
    of the random values (parameters may share values) for load speed.

    Args:
        model: The model whose weights will be initialized in place.
        low: Lower bound for sampling the normal distribution's mean
            (float params); floored to the inclusive lower bound for
            integer params.
        high: Upper bound for sampling the normal distribution's mean
            (float params); ceiled to the exclusive upper bound for
            integer params.
        std: Standard deviation of the normal distribution used for
            floating point parameters.
        seed: Random seed for reproducibility (seeds both the NumPy RNG
            used for the mean and the torch CPU generator).
    """
    # Sample the mean of the normal distribution once, reproducibly.
    rng = np.random.RandomState(seed)
    mean = float(rng.uniform(low, high, 1).item())

    # CPU generator so the generated buffers are reproducible across runs.
    cpu_gen = torch.Generator(device="cpu")
    cpu_gen.manual_seed(seed)

    # Partition parameters into floating point and integer groups; they
    # use different random sources (normal_ vs. randint).
    float_params: List[Tuple[str, torch.Tensor]] = []
    int_params: List[Tuple[str, torch.Tensor]] = []
    for name, t in tqdm(model.state_dict().items(),
                        desc="Gen dummy weights: Collect params"):
        if not isinstance(t, torch.Tensor):
            continue
        if torch.is_floating_point(t):
            float_params.append((name, t))
        elif t.dtype in (torch.int8, torch.uint8, torch.int16,
                         torch.int32, torch.int64):
            int_params.append((name, t))

    # ---- Floating point parameters: shared-buffer initialization ----
    if float_params:
        # One pinned host buffer sized to the largest parameter; a single
        # H2D transfer then serves every float parameter.
        max_float_elems = max(p.numel() for _, p in float_params)
        shared_float_buffer = torch.empty(max_float_elems,
                                          dtype=torch.float32,
                                          device="cpu",
                                          pin_memory=True)
        shared_float_buffer.normal_(mean=mean, std=std, generator=cpu_gen)

        # NOTE(review): all float params are assumed to live on the same
        # device as the first one — confirm for multi-device models.
        device_buffer = shared_float_buffer.to(float_params[0][1].device,
                                               non_blocking=True)

        for _, p in tqdm(float_params,
                         desc="Gen dummy weights: Init float params"):
            n = p.numel()
            # Prefix of the shared buffer; different parameters may
            # therefore reuse the same random values.
            view = device_buffer[:n].view(p.shape)
            # Direct fp32 -> sub-16-bit casts (e.g. fp8) are not
            # supported by the normal path, so go through fp16 first.
            if torch.finfo(p.dtype).bits < 16:
                tmp = view.to(torch.float16).to(p.dtype)
            else:
                tmp = view.to(p.dtype)
            # D2D copy into the parameter storage — much faster than a
            # per-parameter host-to-device transfer.
            p.data.copy_(tmp)

    # ---- Integer parameters: shared-buffer initialization ----
    if int_params:
        max_int_elems = max(p.numel() for _, p in int_params)
        int_low = int(np.floor(low))
        int_high = int(np.ceil(high))
        if int_high == int_low:
            int_high = int_low + 1  # Ensure at least one possible value.

        # Single pinned buffer for all integer parameters (int64 source,
        # narrowed per-parameter below).
        shared_int_buffer = torch.randint(low=int_low,
                                          high=int_high,
                                          size=(max_int_elems,),
                                          dtype=torch.int64,
                                          generator=cpu_gen,
                                          device="cpu",
                                          pin_memory=True)

        # NOTE(review): same single-device assumption as the float path.
        device_int_buffer = shared_int_buffer.to(int_params[0][1].device,
                                                 non_blocking=True)

        for _, p in tqdm(int_params,
                         desc="Gen dummy weights: Init int params"):
            n = p.numel()
            # Prefix of the shared buffer; values may repeat across params.
            view = device_int_buffer[:n].view(p.shape)
            # Narrow to the target integer dtype, then D2D copy.
            p.data.copy_(view.to(p.dtype))


SMOOTHQUANT_METHOD = "smoothquant"
MULTIMODAL_ARCH_KEYWORDS = {"VL", "Vision", "Multimodal"}


def vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights(
        self, model: nn.Module, model_config: ModelConfig) -> None:
    # NOTE(woosuk): For accurate performance evaluation, we assign
    # random values to the weights.
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: use torch.normal_ instead of torch.uniform_ for distinguishable
            logits; std=0.5 is used for better distinguishable logits
    '''
    # === Default parameter setup (original values as fallback) ===
    low_val = -1e-3
    high_val = 1e-3
    std_val = 0.5

    # === Model and quantization check ===
    quant_method = getattr(model_config, "quantization", None)
    # Architectures list may be absent or None on some configs.
    archs = getattr(model_config, "architectures", []) or []

    # Determine whether the model is multimodal from its architecture names.
    is_multimodal = any(
        keyword in arch
        for arch in archs
        for keyword in MULTIMODAL_ARCH_KEYWORDS
    )

    # === SmoothQuant + multimodal: shrink std to mitigate NaN overflow ===
    if is_multimodal and quant_method == SMOOTHQUANT_METHOD:
        std_val = 1e-4

    initialize_dummy_weights_normal_dist(
        model,
        low=low_val,
        high=high_val,
        std=std_val
    )
    # Sync to make sure the weights are fully initialized before use.
    torch.mlu.synchronize()
    '''
    ================== End of MLU Hijack ==================
    '''


# Register the hijack so DummyModelLoader.load_weights is replaced by the
# MLU-aware implementation above.
MluHijackObject.apply_hijack(
    DummyModelLoader,
    DummyModelLoader.load_weights,
    vllm__model_executor__model_loader__dummy_loader__DummyModelLoader__load_weights
)