# xc-llm-kunlun/vllm_kunlun/ops/linear.py
# Commit b3c30a3cb9: [Feature] Support XiaoMi MIMO Flash V2 (#62)
#   Xinyu Dong, 2025-12-31 10:16:33 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.linear import (
    WEIGHT_LOADER_V2_SUPPORTED,
    ColumnParallelLinear,
    ReplicatedLinear,
    UnquantizedLinearMethod,
    adjust_bitsandbytes_4bit_shard,
    adjust_marlin_shard,
    adjust_scalar_to_fused_array,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (
BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter,
)
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def get_weights(self):
    """Return an fp32 copy of the layer weight, cached on first use."""
    if hasattr(self, "kunlun_linear_weights"):
        return self.kunlun_linear_weights
    weights = torch.nn.Parameter(self.weight.to(torch.float32))
    self.register_parameter("kunlun_linear_weights", weights)
    return self.kunlun_linear_weights


def get_weights_half(self):
    """Return an fp16 copy of the layer weight, cached on first use."""
    if hasattr(self, "kunlun_linear_weights_half"):
        return self.kunlun_linear_weights_half
    weights = torch.nn.Parameter(self.weight.to(torch.float16))
    self.register_parameter("kunlun_linear_weights_half", weights)
    return self.kunlun_linear_weights_half
ReplicatedLinear.get_weights = get_weights
ReplicatedLinear.get_weights_half = get_weights_half
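# Usage sketch (shapes are illustrative, not taken from this repo): after the
# two assignments above, every ReplicatedLinear lazily caches dtype-cast
# copies of its weight under dedicated parameter names.
#
#     layer = ReplicatedLinear(1024, 1024, bias=False)
#     w32 = layer.get_weights()       # cached as "kunlun_linear_weights"
#     w16 = layer.get_weights_half()  # cached as "kunlun_linear_weights_half"
#     assert w32.dtype == torch.float32 and w16.dtype == torch.float16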
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight = Parameter(
torch.empty(
sum(output_partition_sizes), input_size_per_partition, dtype=params_dtype
),
requires_grad=False,
)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
# Rewrite create_weights and remove weight_loader_v2 to support CUDA graphs.
UnquantizedLinearMethod.create_weights = create_weights
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
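# Removing the method name from WEIGHT_LOADER_V2_SUPPORTED makes vLLM fall
# back to the v1 weight_loader path for these layers, which matches the plain
# torch Parameter created in the rewritten create_weights above.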
class QKVParallelLinear(ColumnParallelLinear):
"""
    Based on v0.11.0 QKVParallelLinear, with an added v_head_size so the
    value heads can be narrower than the q/k heads (for SWA in MIMO V2).
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: int | None = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: torch.dtype | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
v_head_size: int | None = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.v_head_size = v_head_size if v_head_size is not None else head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
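        # Illustrative numbers: with tp_size=8 and total_num_kv_heads=2, each
        # rank keeps num_kv_heads=1 and the same kv head is replicated across
        # num_kv_head_replicas=4 consecutive ranks.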
input_size = self.hidden_size
output_size = (
self.num_heads * self.head_size
+ self.num_kv_heads * self.head_size
+ self.num_kv_heads * self.v_head_size
) * tp_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.v_head_size * tp_size, # v_proj
]
super().__init__(
input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=False,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
disable_tp=disable_tp,
)
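    # Worked example (hypothetical MIMO-like shapes, TP=1): hidden_size=4096,
    # head_size=128, v_head_size=64, total_num_heads=32, total_num_kv_heads=8
    # gives output_sizes = [32*128, 8*128, 8*64] = [4096, 1024, 512], i.e. an
    # output_size of 5632 laid out as [q | k | v].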
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.v_head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def _load_fused_module_from_checkpoint(
self, param: BasevLLMParameter, loaded_weight: torch.Tensor
):
"""
Handle special case for models where QKV layers are already
fused on disk. In this case, we have no shard id. This function
determines the shard id by splitting these layers and then calls
the weight loader using the shard id.
An example of a model with these fused layers:
https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
"""
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if (
isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
and param.packed_dim == param.output_dim
):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id in ["q", "k", "v"]
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
# Note(simon): This is needed for Qwen3's fp8 quantization.
if isinstance(param, BlockQuantScaleParameter):
assert self.quant_method is not None
# Assume the weight block size has been set by quant method
assert hasattr(self, "weight_block_size")
weight_block_size = self.weight_block_size
assert weight_block_size is not None
            block_n = weight_block_size[0]
shard_offset = (shard_offset + block_n - 1) // block_n
shard_size = (shard_size + block_n - 1) // block_n
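            # Illustrative numbers: with block_n=128, a 4096-wide q shard maps
            # to 4096 // 128 = 32 scale rows, and the ceil-division keeps any
            # partial trailing block (e.g. width 96 -> 1 row).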
param.load_qkv_weight(
loaded_weight=loaded_weight,
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
)
def weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: str | None = None,
):
        # Special case for GGUF:
        # initialize the GGUF param after the quantization type is known.
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
idx_map = {"q": 0, "k": 1, "v": 2}
if loaded_shard_id is not None:
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
else:
param.shard_weight_type = {k: loaded_weight.item() for k in idx_map}
return
if is_gguf_weight:
output_dim = getattr(param, "output_dim", None)
shard_size = loaded_weight.size(output_dim) // self.tp_size
start_idx = self.tp_rank * shard_size
if loaded_shard_id is not None:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
return
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, 0
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
]
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset
)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.total_num_heads * self.head_size),
"k": (
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
"v": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size,
self.total_num_kv_heads * self.v_head_size,
),
"total": (
(self.total_num_heads + self.total_num_kv_heads)
* self.head_size
+ self.total_num_kv_heads * self.v_head_size,
0,
),
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, shard_id
)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id in ["q", "k", "v"]
# If output dim is defined, use the default loading process.
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.v_head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.packed_factor
shard_offset = shard_offset // param.packed_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset
)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (
self.num_heads * self.head_size,
self.num_kv_heads * self.head_size,
),
"v": (
(self.num_heads + self.num_kv_heads) * self.head_size,
self.num_kv_heads * self.v_head_size,
),
"total": (
(self.num_heads + self.num_kv_heads) * self.head_size
+ self.num_kv_heads * self.v_head_size,
0,
),
}
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id
)
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
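            # With kv-head replication (tp_size > total_num_kv_heads), several
            # consecutive TP ranks share one k/v shard, so k/v index into the
            # checkpoint by tp_rank // num_kv_head_replicas; q always uses
            # tp_rank directly.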
if loaded_shard_id == "q":
shard_rank = self.tp_rank
else:
shard_rank = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for per-tensor scales in fused case.
elif needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
param_data, loaded_weight, loaded_shard_id
)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
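
# Minimal usage sketch (hypothetical shapes, TP=1): a QKV projection whose
# value heads are narrower than its query/key heads, as in MIMO V2 SWA layers.
#
#     qkv = QKVParallelLinear(
#         hidden_size=4096,
#         head_size=128,
#         total_num_heads=32,
#         total_num_kv_heads=8,
#         v_head_size=64,
#         bias=False,
#     )
#     out, _ = qkv(torch.randn(2, 4096))
#     q, k, v = out.split([4096, 1024, 512], dim=-1)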