244 lines
11 KiB
Python
244 lines
11 KiB
Python
import torch
|
||
from typing import Any, Literal, Optional, cast, Callable, Optional
|
||
|
||
from compressed_tensors.config import (CompressionFormat,
|
||
SparsityCompressionConfig,
|
||
SparsityStructure)
|
||
from compressed_tensors.quantization import (ActivationOrdering,
|
||
QuantizationStrategy)
|
||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||
FusedMoeWeightScaleSupported)
|
||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||
# TODO: import position will be changed after 0.9.0
|
||
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe
|
||
|
||
from vllm.model_executor.utils import set_weight_attrs
|
||
import re
|
||
import xtorch_ops
|
||
|
||
|
||
from safetensors.torch import load_file as safe_load_file
|
||
|
||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||
|
||
def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
|
||
tsm = getattr(quant_config, "target_scheme_map", None) or {}
|
||
linear_cfg = None
|
||
for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
|
||
if k in tsm and isinstance(tsm[k], dict):
|
||
linear_cfg = tsm[k]; break
|
||
if not linear_cfg:
|
||
# print("target_scheme_map missing; fallback to INT8(W8A8) method")
|
||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||
wq = linear_cfg.get("weights"); aq = linear_cfg.get("input_activations")
|
||
if not wq or not aq:
|
||
# print("incomplete scheme; fallback to INT8(W8A8)")
|
||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||
# 其它分流按需;默认回落:
|
||
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
|
||
|
||
# copied from vllm 0.9.0
|
||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||
|
||
def __init__(
|
||
self,
|
||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||
):
|
||
self.quant_config = quant_config
|
||
|
||
# 直接创建默认的量化配置字典,避免 QuantizationArgs 的验证问题
|
||
# print("Creating default INT8 quantization config for MoE")
|
||
|
||
# 创建默认的权重量化配置字典
|
||
self.weight_quant = type('WeightQuant', (), {
|
||
'type': 'int',
|
||
'num_bits': 8,
|
||
'strategy': 'channel',
|
||
'group_size': 128,
|
||
'symmetric': True,
|
||
'dynamic': False,
|
||
'actorder': 'none',
|
||
'observer': None,
|
||
'observer_kwargs': {},
|
||
'block_structure': None
|
||
})()
|
||
|
||
# 创建默认的输入激活量化配置字典
|
||
self.input_quant = type('InputQuant', (), {
|
||
'type': 'int',
|
||
'num_bits': 8,
|
||
'strategy': 'token',
|
||
'group_size': 128,
|
||
'symmetric': True,
|
||
'dynamic': True,
|
||
'actorder': 'none',
|
||
'observer': None,
|
||
'observer_kwargs': {},
|
||
'block_structure': None
|
||
})()
|
||
|
||
# 修改比较方式,直接比较字符串
|
||
per_channel = (
|
||
self.weight_quant.strategy == "channel"
|
||
and self.input_quant.strategy == "token")
|
||
if not per_channel:
|
||
raise ValueError(
|
||
"For INT8 Fused MoE layers, we require channelwise, "
|
||
"dynamic per token quantization. Found "
|
||
f"{self.weight_quant}, {self.input_quant}")
|
||
|
||
self.static_input_scales = not self.input_quant.dynamic
|
||
if self.static_input_scales:
|
||
raise ValueError(
|
||
"For INT8 Fused MoE layers, we require channelwise, "
|
||
"dynamic per token quantization. Found static input scales.")
|
||
|
||
def create_weights1(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||
# 权重先用浮点占位,便于从 ckpt 加载原始权重
|
||
w13_weight = torch.nn.Parameter(torch.empty(
|
||
num_experts,
|
||
2 * intermediate_size_per_partition,
|
||
hidden_size,
|
||
dtype=params_dtype), # 通常是 torch.bfloat16
|
||
requires_grad=False)
|
||
layer.register_parameter("w13_weight", w13_weight)
|
||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||
|
||
w2_weight = torch.nn.Parameter(torch.empty(
|
||
num_experts,
|
||
hidden_size,
|
||
intermediate_size_per_partition,
|
||
dtype=params_dtype),
|
||
requires_grad=False)
|
||
layer.register_parameter("w2_weight", w2_weight)
|
||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||
|
||
# 通道 scale:float32 + 二维 [E, out](与 fused_moe/UT 对齐)
|
||
w13_weight_scale = torch.nn.Parameter(
|
||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||
requires_grad=False)
|
||
w2_weight_scale = torch.nn.Parameter(
|
||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||
requires_grad=False)
|
||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||
|
||
# 输入 scale 动态计算即可
|
||
layer.w13_input_scale = None
|
||
layer.w2_input_scale = None
|
||
|
||
def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs):
|
||
w13_weight = torch.nn.Parameter(torch.empty(
|
||
num_experts,
|
||
2 * intermediate_size_per_partition,
|
||
hidden_size,
|
||
dtype=torch.int8), # 直接使用 int8
|
||
requires_grad=False)
|
||
layer.register_parameter("w13_weight", w13_weight)
|
||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||
|
||
w2_weight = torch.nn.Parameter(torch.empty(
|
||
num_experts,
|
||
hidden_size,
|
||
intermediate_size_per_partition,
|
||
dtype=torch.int8), # 直接使用 int8
|
||
requires_grad=False)
|
||
layer.register_parameter("w2_weight", w2_weight)
|
||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||
|
||
# 缩放因子
|
||
w13_weight_scale = torch.nn.Parameter(
|
||
torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
|
||
requires_grad=False)
|
||
w2_weight_scale = torch.nn.Parameter(
|
||
torch.empty(num_experts, hidden_size, dtype=torch.float32),
|
||
requires_grad=False)
|
||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||
|
||
# 输入 scale 动态计算
|
||
layer.w13_input_scale = None
|
||
layer.w2_input_scale = None
|
||
|
||
@torch.no_grad()
|
||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||
return
|
||
#原始权重转 float32 做统计更稳健
|
||
w13_f = layer.w13_weight.float()
|
||
w2_f = layer.w2_weight.float()
|
||
|
||
# 每列(abs_max) -> per-column scale(out 维在 dim=1,列在 dim=-1)
|
||
qmax = 127.0
|
||
w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1) # [E, 2N]
|
||
w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1) # [E, H]
|
||
|
||
w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax # [E, 2N], float32
|
||
w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax # [E, H], float32
|
||
|
||
# 量化:用 3D scale 广播,存回 2D scale
|
||
w13_scale_3d = w13_scale_2d.unsqueeze(-1) # [E, 2N, 1]
|
||
w2_scale_3d = w2_scale_2d.unsqueeze(-1) # [E, H, 1]
|
||
|
||
w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
|
||
w2_q = torch.round(w2_f / w2_scale_3d ).clamp_(-128, 127).to(torch.int8)
|
||
|
||
# 可选:若你的 fused/kernel 期望 scale 预乘 127(与某些 UT 后端一致),打开下面两行:
|
||
w13_scale_2d = w13_scale_2d * 127.0
|
||
w2_scale_2d = w2_scale_2d * 127.0
|
||
|
||
# 回写参数:权重 int8;scale 用 float32 + 2D
|
||
replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
|
||
replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
|
||
replace_parameter(layer, 'w13_weight_scale',
|
||
torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
|
||
replace_parameter(layer, 'w2_weight_scale',
|
||
torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))
|
||
|
||
# 简要检查
|
||
print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")
|
||
|
||
def apply(
|
||
self,
|
||
layer: torch.nn.Module,
|
||
x: torch.Tensor,
|
||
router_logits: torch.Tensor,
|
||
top_k: int,
|
||
renormalize: bool,
|
||
use_grouped_topk: bool = False,
|
||
topk_group: Optional[int] = None,
|
||
num_expert_group: Optional[int] = None,
|
||
global_num_experts: int = -1,
|
||
expert_map: Optional[torch.Tensor] = None,
|
||
custom_routing_function: Optional[Callable] = None,
|
||
scoring_func: str = "softmax",
|
||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||
apply_router_weight_on_input: bool = False,
|
||
activation: str = "silu",
|
||
enable_eplb: bool = False, # 添加这个参数
|
||
expert_load_view: Optional[torch.Tensor] = None, # 添加这个参数
|
||
logical_to_physical_map: Optional[torch.Tensor] = None, # 添加这个参数
|
||
logical_replica_count: Optional[torch.Tensor] = None, # 添加这个参数
|
||
linear_weights: Optional[torch.Tensor] = None, # 添加这个参数
|
||
) -> torch.Tensor:
|
||
|
||
output = torch.empty_like(x)
|
||
torch.ops._C.moe_ffn_per_token_block(
|
||
x=x,
|
||
inter_weight=layer.w13_weight,
|
||
inter_scale=layer.w13_weight_scale,
|
||
outer_weight=layer.w2_weight,
|
||
outer_scale=layer.w2_weight_scale,
|
||
top_k=top_k,
|
||
global_num_experts=global_num_experts,
|
||
linear_weights=linear_weights,
|
||
expert_map=expert_map,
|
||
activation=activation,
|
||
output=output,
|
||
use_expert_parallel=expert_map is not None,
|
||
ep_size=expert_map.size(0) if expert_map is not None else 1,
|
||
ep_rank=0,
|
||
)
|
||
return output
|
||
|
||
print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
|
||
--> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod") |