
import itertools
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
# yapf: disable
from vllm.model_executor.parameter import (BasevLLMParameter,
BlockQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter,
RowvLLMParameter)
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
WEIGHT_LOADER_V2_SUPPORTED,
LinearBase,
RowParallelLinear)
def ReplicatedLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super(ReplicatedLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
    # All linear layers support a quant method.
assert self.quant_method is not None
    self.scale_k = 1  # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 gives an effective quant_block_k of 64
self.scale_k_slice = 1
self.scale_n = 1
self.scale_n_slice = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[1]
import math
if input_size % quant_config.weight_block_size[1]:
gcd_value = math.gcd(input_size % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
self.scale_k =self.scale_k * quant_config.weight_block_size[1] // gcd_value
self.scale_k_slice = input_size // gcd_value
if output_size % quant_config.weight_block_size[0]:
gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.scale_n_slice = output_size // gcd_value
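    # A worked example with assumed (illustrative) sizes: for weight_block_size =
    # [128, 128] and input_size = 576, 576 % 128 = 64 and gcd(64, 128) = 64, so
    # scale_k = 128 // 64 = 2 and scale_k_slice = 576 // 64 = 9; the loader below
    # replicates each per-block scale scale_k times and keeps the first
    # scale_k_slice entries.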
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def ReplicatedLinear_weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    # If the weight on disk does not have a shape, give it one
    # (such as scales for AutoFP8).
    if len(loaded_weight.shape) == 0:
        assert loaded_weight.numel() == 1
        loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])[:, :self.scale_k_slice]
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
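    # Shape sketch under assumed sizes (scale_k = 2, scale_k_slice = 9, as when
    # input_size = 576 and the k-block is 128): a scale tensor of shape [n_blk, 5]
    # becomes [1, n_blk, 5] -> [2, n_blk, 5] -> [n_blk, 5, 2] -> [n_blk, 10]
    # (each column repeated) and is finally sliced to [n_blk, 9].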
assert param.size() == loaded_weight.size(), f'{param.size()}, {loaded_weight.size()}'
param.data.copy_(loaded_weight)
def refine_block(block_size: List[int],
                 weight_size: List[int],
                 dim: int = 0,
                 pingpong_size: float = 2.5 * 1024 * 1024,  # bytes
                 core_number: int = 4,
                 data_type: int = 2,  # bytes per element (bfloat16)
                 max_iter_number: int = 2):
    '''
    With an uneven split across cores, each core must hold <= 2.5 MiB so that
    ping-pong buffering is still possible. The per-core difference is
    block_size[dim] * weight_size[1 - dim] elements. Shrinking block_size narrows
    the gap between cores until the largest core no longer exceeds the budget.
    If an even split already exceeds the budget, or no core exceeds it, there is
    nothing to adjust.
    '''
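    # A worked example with assumed sizes (bf16, 4 cores, 2.5 MiB budget):
    # refine_block([128, 128], [640, 7168], dim=0) splits 5 row-blocks over 4 cores;
    # the largest core holds 2 blocks (~1.84M elements), which exceeds the ~1.31M
    # element budget, so the block edge is halved twice and 32 is returned.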
if dim < 0:
dim = 2 + dim
pingpong_size = pingpong_size / data_type # number of data
block_size_refine = block_size[dim]
all_block_number = weight_size[dim] // block_size_refine
if all_block_number % core_number == 0:
        # Even split: no adjustment is needed whether or not the budget is exceeded.
return block_size_refine
block_number_tiny = all_block_number // core_number
block_number_big = all_block_number // core_number + 1
if block_number_tiny * block_size_refine * weight_size[1-dim] >= pingpong_size or \
block_number_big * block_size_refine * weight_size[1-dim] <= pingpong_size :
        # Either the small cores already exceed the budget (no further adjustment is
        # possible) or the big cores stay within it (no adjustment is needed).
return block_size_refine
all_block_number_tmp = all_block_number
block_size_refine_tmp = block_size_refine
for iter_index in range(max_iter_number):
all_block_number_tmp = all_block_number_tmp * 2
block_size_refine_tmp = block_size_refine_tmp // 2
if all_block_number_tmp % core_number == 0:
            block_number_tiny = all_block_number_tmp // core_number
if block_number_tiny * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
return block_size_refine_tmp
else:
                # Even an even split still exceeds the budget; keep the original block size.
return block_size_refine
else:
block_number_big = all_block_number_tmp // core_number + 1
if block_number_big * block_size_refine_tmp * weight_size[1-dim] <= pingpong_size:
return block_size_refine_tmp
return block_size_refine
def ColumnParallelLinear__init__(self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,):
# Divide the weight matrix along the last dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super(ColumnParallelLinear,self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
self.scale_n = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[0]
import math
if hasattr(self, "output_sizes"):
            # For a merged ColumnParallelLinear, the GCD has to be computed from the
            # shape of each partial linear layer.
output_size_no_merge = self.output_partition_sizes
block_values = [o % quant_config.weight_block_size[0] for o in output_size_no_merge]
is_gcd_recompute = sum(block_values)
            if is_gcd_recompute:
                block_values.append(quant_config.weight_block_size[0])
                gcd_value = math.gcd(*block_values)
            # Notice:
            # The flow for misaligned partial weights may still need verification.
            # For DeepSeek, the merged column linears (found only in the MLP / MoE)
            # all have partial weights of identical shape.
            # For Qwen3, the QKV column linear has partial weights of different
            # shapes, but Qwen3's current partitioning scheme leaves gcd_value
            # unchanged, so no recomputation is needed and this branch is not
            # reached for now.
if hasattr(self, "output_sizes") and len(output_size_no_merge) == 2 and output_size_no_merge[0] == output_size_no_merge[1]:
#only refine mlp w13
gcd_value = refine_block([gcd_value, quant_config.weight_block_size[1]], [output_size_no_merge[0], input_size])
self.scale_n =self.scale_n * quant_config.weight_block_size[0] // gcd_value
else:
            # For a non-merged ColumnParallelLinear, compute the GCD from the current
            # partition shape only.
output_size_no_merge = self.output_size_per_partition
is_gcd_recompute = output_size_no_merge % quant_config.weight_block_size[0]
if is_gcd_recompute:
gcd_value = math.gcd(output_size_no_merge % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n =self.scale_n * quant_config.weight_block_size[0] // gcd_value
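    # A worked example with assumed sizes for the non-merged branch above: with
    # output_size_per_partition = 576 and weight_block_size = [128, 128],
    # 576 % 128 = 64 and gcd(64, 128) = 64, so scale_n = 128 // 64 = 2 and the
    # per-block scales are replicated twice along n when loaded.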
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def ColumnParallelLinear_weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[-1]])
            # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
param.load_column_parallel_weight(loaded_weight=loaded_weight)
class MergedColumnParallelLinear(ColumnParallelLinear):
def weight_loader_v2(self,
param: BasevLLMParameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[-1]])
                # [1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n, k]
if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]:
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
                # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
if loaded_shard_id is None:
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0] // self.scale_n, weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
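        # A worked example with assumed sizes: with output_sizes = [576, 576],
        # weight_block_size[0] = 128 and scale_n = 2, the effective block_n is 64;
        # for loaded_shard_id = 1 and tp_size = 1 the scale shard starts at block
        # row (576 + 63) // 64 = 9 and spans 9 block rows.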
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
def RowParallelLinear__init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super(RowParallelLinear, self).__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
    # tp_rank, tp_size and input_size_per_partition were already set above with
    # disable_tp taken into account, so they are not recomputed here.
assert self.quant_method is not None
    self.scale_k = 1  # quant_block_k (128) is divided by scale_k; e.g. scale_k = 2 gives an effective quant_block_k of 64
self.scale_n = 1
self.scale_n_slice = 1
if quant_config is not None and hasattr(quant_config, "weight_block_size") and quant_config.weight_block_size is not None:
gcd_value = quant_config.weight_block_size[1]
import math
if self.input_size_per_partition % quant_config.weight_block_size[1]:
gcd_value = math.gcd(self.input_size_per_partition % quant_config.weight_block_size[1], quant_config.weight_block_size[1])
self.scale_k =self.scale_k * quant_config.weight_block_size[1] // gcd_value
if output_size % quant_config.weight_block_size[0]:
gcd_value = math.gcd(output_size % quant_config.weight_block_size[0], quant_config.weight_block_size[0])
self.scale_n = self.scale_n * quant_config.weight_block_size[0] // gcd_value
self.scale_n_slice = output_size // gcd_value
            # Example: N = 576, block = 128. Expanding the scales along n requires
            # two pieces of information: 1) how many copies to make (scale_n) and
            # 2) how many entries remain valid after slicing (scale_n_slice).
            # scale = [s0,s1,s2,s3,s4], replicated scale_n = 2 times:
            # [s0,s0,s1,s1,s2,s2,s3,s3,s4,s4], sliced to scale_n_slice = 9 entries
            # => [s0,s0,s1,s1,s2,s2,s3,s3,s4]
if self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
gcd_value = quant_config.group_size
import math
if self.input_size_per_partition % quant_config.group_size:
gcd_value = math.gcd(self.input_size_per_partition % quant_config.group_size, quant_config.group_size)
self.quant_method.scale_k = self.quant_method.scale_k * quant_config.group_size // gcd_value
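    # A worked example with assumed sizes: with group_size = 128 and
    # input_size_per_partition = 576, 576 % 128 = 64 and gcd(64, 128) = 64, so the
    # GPTQ per-group scales are replicated with scale_k = 128 // 64 = 2 at load time.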
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
scale_k = self.scale_k,
scale_n = self.scale_n,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def RowParallelLinear_weight_loader_v2_vacc(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
if self.quant_method.__class__.__name__ in ['Fp8LinearMethod', 'Fp8MoEMethod'] and torch.finfo(loaded_weight.dtype).bits > 8:
if self.scale_k > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_k, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,2,0).reshape([loaded_weight.shape[1], -1])
#[1,n,k] -> [scale_k,n,k] -> [n,k,scale_k] -> [n, k*scale_k]
if self.scale_n > 1 and len(loaded_weight.shape) == 2:
loaded_weight = loaded_weight.unsqueeze(0) #[1,n,k]
loaded_weight = loaded_weight.expand(self.scale_n, loaded_weight.shape[1], loaded_weight.shape[2]).permute(1,0,2).reshape([-1, loaded_weight.shape[2]])[:self.scale_n_slice]
#[1,n,k] -> [scale_n,n,k] -> [n,scale_n,k] -> [n*scale_n,k]
elif self.quant_method.__class__.__name__ in ['GPTQLinearMethod']:
        # Broadcast the scales. TODO: also broadcast the zero points.
if self.quant_method.scale_k > 1 and len(loaded_weight.shape) == 2 and loaded_weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
loaded_weight = loaded_weight.unsqueeze(1) #[k,1,n]
loaded_weight = loaded_weight.expand(loaded_weight.shape[0], self.quant_method.scale_k, loaded_weight.shape[2]).reshape([-1, loaded_weight.shape[2]])
            # [k,1,n] -> [k,scale_k,n] -> [k*scale_k, n]
param.load_row_parallel_weight(loaded_weight=loaded_weight)
class UnquantizedLinearMethod():
"""Linear method without quantization."""
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
if bias is not None:
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
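        # No-bias path (vacc-specific): if the memory recycler exposes an embedding
        # output buffer whose leading dimension matches x, reuse it as the matmul
        # output to avoid an extra allocation.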
from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import memory_recycler
parallel_embedding_output = None
if memory_recycler is not None:
if memory_recycler.EMBEDDING_OUT_BUFFER.size(0) == x.size(0):
parallel_embedding_output = memory_recycler.EMBEDDING_OUT_BUFFER
return torch.mm(x.view(-1, x.shape[-1]), layer.weight.transpose(1,0), out=parallel_embedding_output).view(*(x.shape[:-1]), layer.weight.shape[0])