# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
    WEIGHT_LOADER_V2_SUPPORTED,
    ColumnParallelLinear,
    ReplicatedLinear,
    UnquantizedLinearMethod,
    adjust_bitsandbytes_4bit_shard,
    adjust_marlin_shard,
    adjust_scalar_to_fused_array,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (
    BasevLLMParameter,
    BlockQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    PerTensorScaleParameter,
    RowvLLMParameter,
)
from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


def get_weights(self):
    """Return (and cache) an fp32 copy of the layer weight."""
    if hasattr(self, "kunlun_linear_weights"):
        return self.kunlun_linear_weights
    weights = torch.nn.Parameter(self.weight.to(torch.float32))
    self.register_parameter("kunlun_linear_weights", weights)
    return self.kunlun_linear_weights


def get_weights_half(self):
    """Return (and cache) an fp16 copy of the layer weight."""
    if hasattr(self, "kunlun_linear_weights_half"):
        return self.kunlun_linear_weights_half
    weights = torch.nn.Parameter(self.weight.to(torch.float16))
    self.register_parameter("kunlun_linear_weights_half", weights)
    return self.kunlun_linear_weights_half


ReplicatedLinear.get_weights = get_weights
ReplicatedLinear.get_weights_half = get_weights_half


def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    weight = Parameter(
        torch.empty(
            sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype
        ),
        requires_grad=False,
    )
    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)


# Rewrite create_weights and drop weight_loader_v2 to support CUDA graphs.
UnquantizedLinearMethod.create_weights = create_weights
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
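# Illustrative usage sketch for the patches above (not part of the original
# source; the sizes are made up). Both helpers cache their cast copy as a
# registered parameter on first call, so repeated calls return the same
# tensor instead of re-casting the weight:
#
#     layer = ReplicatedLinear(1024, 1024, bias=False)
#     w_fp32 = layer.get_weights()       # cached fp32 copy of layer.weight
#     w_fp16 = layer.get_weights_half()  # cached fp16 copy of layer.weight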
class QKVParallelLinear(ColumnParallelLinear):
    """
    Based on QKVParallelLinear from vLLM v0.11.0, with a separate v_head_size
    added for the SWA (sliding-window attention) layers of MIMO V2.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
        v_head_size: int | None = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.v_head_size = v_head_size if v_head_size is not None else head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (
            self.num_heads * self.head_size
            + self.num_kv_heads * self.head_size
            + self.num_kv_heads * self.v_head_size
        ) * tp_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.v_head_size * tp_size,  # v_proj
        ]
        super().__init__(
            input_size=input_size,
            output_size=output_size,
            bias=bias,
            gather_output=False,
            skip_bias_add=skip_bias_add,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            return_bias=return_bias,
            disable_tp=disable_tp,
        )
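    # Worked example for the sizes above (illustrative numbers, not MIMO V2's
    # actual config): with total_num_heads=32, total_num_kv_heads=8,
    # head_size=128, v_head_size=64 and tp_size=4, each rank holds num_heads=8
    # and num_kv_heads=2, so output_sizes = [4096, 1024, 512] and each rank's
    # shard of the fused QKV weight has 8*128 + 2*128 + 2*64 = 1408 output
    # rows.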
    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + self.num_kv_heads) * self.head_size
            + self.num_kv_heads * self.v_head_size,
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.v_head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(
        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
    ):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.v_head_size,
            ),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if (
                isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
                and param.packed_dim == param.output_dim
            ):
                shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset
                )

            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(
        self,
        param: BasevLLMParameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(
                    loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        # Note(simon): This is needed for Qwen3's fp8 quantization.
        if isinstance(param, BlockQuantScaleParameter):
            assert self.quant_method is not None
            # Assume the weight block size has been set by the quant method.
            assert hasattr(self, "weight_block_size")
            weight_block_size = self.weight_block_size
            assert weight_block_size is not None
            block_n = weight_block_size[0]
            shard_offset = (shard_offset + block_n - 1) // block_n
            shard_size = (shard_size + block_n - 1) // block_n

        param.load_qkv_weight(
            loaded_weight=loaded_weight,
            num_heads=self.num_kv_head_replicas,
            shard_id=loaded_shard_id,
            shard_offset=shard_offset,
            shard_size=shard_size,
            tp_rank=self.tp_rank,
        )
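    # Worked example for the block-scale rounding above (illustrative
    # numbers): with block_n=128 and a "k" shard at offset 4096 with size
    # 1024, the scale tensor is indexed in units of blocks, so the shard maps
    # to offset ceil(4096/128) = 32 and size ceil(1024/128) = 8 rows of
    # scales.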
    def weight_loader(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: str | None = None,
    ):
        # Special case for GGUF:
        # initialize the GGUF param after we know the quantize type.
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            idx_map = {"q": 0, "k": 1, "v": 2}
            if loaded_shard_id is not None:
                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            else:
                param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
            return

        if is_gguf_weight:
            output_dim = getattr(param, "output_dim", None)
            shard_size = loaded_weight.size(output_dim) // self.tp_size
            start_idx = self.tp_rank * shard_size

            if loaded_shard_id is not None:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
                param.shard_id.append(loaded_shard_id)
                param.shard_id_map[loaded_shard_id] = len(param.data_container)
                param.data_container.append(loaded_weight)
            return

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv).
            # (e.g., Phi-3's qkv_proj).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0
                    )

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                (
                    "k",
                    self.total_num_heads * self.head_size,
                    self.total_num_kv_heads * self.head_size,
                ),
                (
                    "v",
                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                    self.total_num_kv_heads * self.v_head_size,
                ),
            ]
            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.packed_factor
                    shard_offset = shard_offset // param.packed_factor

                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset
                    )

                if use_bitsandbytes_4bit:
                    orig_qkv_offsets = {
                        "q": (0, self.total_num_heads * self.head_size),
                        "k": (
                            self.total_num_heads * self.head_size,
                            self.total_num_kv_heads * self.head_size,
                        ),
                        "v": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size,
                            self.total_num_kv_heads * self.v_head_size,
                        ),
                        "total": (
                            (self.total_num_heads + self.total_num_kv_heads)
                            * self.head_size
                            + self.total_num_kv_heads * self.v_head_size,
                            0,
                        ),
                    }

                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                        param, orig_qkv_offsets, shard_id
                    )

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size
                )
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return
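        # Per-shard path. Illustrative example of KV-head replication (made-up
        # numbers, not from the original source): with total_num_kv_heads=4
        # and tp_size=8, num_kv_head_replicas=2, so ranks 0 and 1 both load KV
        # shard 0 (shard_rank = tp_rank // 2), while every rank loads its own
        # distinct Q shard (shard_rank = tp_rank).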
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.v_head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.packed_factor
                shard_offset = shard_offset // param.packed_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset
                )

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
            is_sharded_weight = getattr(param, "is_sharded_weight", False)
            # bitsandbytes loads the weights of the specific portion,
            # so there is no need to narrow.
            is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit

            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (
                        self.num_heads * self.head_size,
                        self.num_kv_heads * self.head_size,
                    ),
                    "v": (
                        (self.num_heads + self.num_kv_heads) * self.head_size,
                        self.num_kv_heads * self.v_head_size,
                    ),
                    "total": (
                        (self.num_heads + self.num_kv_heads) * self.head_size
                        + self.num_kv_heads * self.v_head_size,
                        0,
                    ),
                }
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id
                )

            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
            if loaded_shard_id == "q":
                shard_rank = self.tp_rank
            else:
                shard_rank = self.tp_rank // self.num_kv_head_replicas
            start_idx = shard_rank * shard_size
            if not is_sharded_weight:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id
            )
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions."
                )

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
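# Usage sketch (illustrative, not part of the original module). It assumes a
# model definition running under an initialized vLLM tensor-parallel context;
# the sizes are made-up numbers, not MIMO V2's real configuration:
#
#     qkv_proj = QKVParallelLinear(
#         hidden_size=4096,
#         head_size=128,            # q/k head dim
#         total_num_heads=32,
#         total_num_kv_heads=8,
#         bias=False,
#         quant_config=quant_config,
#         prefix=f"{prefix}.qkv_proj",
#         v_head_size=64,           # smaller v head dim for the SWA layers
#     )
#     qkv, _ = qkv_proj(hidden_states)
#     q, k, v = qkv.split(
#         [
#             qkv_proj.num_heads * qkv_proj.head_size,
#             qkv_proj.num_kv_heads * qkv_proj.head_size,
#             qkv_proj.num_kv_heads * qkv_proj.v_head_size,
#         ],
#         dim=-1,
#     )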