[Feature] support deepseek v3/r1/v3.2 (#78)
* [Feature] support deepseek v3/r1/v3.2
* fix gpt_oss
* update readme
* update readme

---------

Co-authored-by: hanhaowen <hanhaowen@baidu.com>
@@ -1,244 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
from enum import Enum
from typing import Callable, Optional, Union

import torch
from typing import Any, Literal, Optional, cast, Callable
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod

from compressed_tensors.config import (CompressionFormat,
                                       SparsityCompressionConfig,
                                       SparsityStructure)
from compressed_tensors.quantization import (ActivationOrdering,
                                             QuantizationStrategy)
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils import replace_parameter
# TODO: import position will change after 0.9.0
# vllm.model_executor.layers.fused_moe.fused_moe --> vllm.model_executor.layers.fused_moe


def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
    """Convert per-channel scales back to abs-max values (scale * 127)."""
    layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False)
    layer.w13_weight_scale = torch.nn.Parameter(
        layer.w13_weight_scale.data * 127, requires_grad=False
    )
    layer.w2_weight_scale = torch.nn.Parameter(
        layer.w2_weight_scale.data * 127, requires_grad=False
    )

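# Why the scales are multiplied by 127 above: with symmetric INT8 the
# per-channel scale is derived as abs_max / 127, so scale * 127 recovers the
# channel's absolute maximum, which is the form the Kunlun kernels appear to
# consume. Illustrative sketch only (not called by the patched path):
def _sketch_scale_vs_abs_max() -> None:
    w = torch.randn(8, 16)                   # rows of one expert's weight
    abs_max = w.abs().amax(dim=-1)           # per-output-channel abs max
    scale = abs_max / 127.0                  # symmetric INT8 scale
    # scale * 127 reconstructs abs_max up to float rounding.
    assert torch.allclose(scale * 127.0, abs_max)
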
from vllm.model_executor.utils import set_weight_attrs
import re
import xtorch_ops


def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    klx_process_weights_after_loading(layer)


def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    hidden_states = x
    global_num_experts, up_gate_size, _ = layer.w13_weight.shape
    M, N = hidden_states.shape
    hidden_dim = layer.w2_weight.shape[1]
    normed_score = torch.empty(M,
                               top_k,
                               dtype=torch.float32,
                               device=hidden_states.device)
    topk_ids = torch.empty(M,
                           top_k,
                           dtype=torch.int32,
                           device=hidden_states.device)
    num_blocks = 12
    block_statistic = torch.zeros(
        num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
    )

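# For reference: normed_score / topk_ids above hold, per token, the routing
# weights and the chosen expert ids. A plain-PyTorch equivalent of
# softmax + top-k + renormalize routing (illustrative sketch only, assuming
# standard MoE routing; the custom ops compute this on-device):
def _sketch_softmax_topk_routing(router_logits: torch.Tensor, top_k: int):
    probs = torch.softmax(router_logits.float(), dim=-1)             # [M, E]
    topk_score, topk_ids = probs.topk(top_k, dim=-1)                 # [M, top_k]
    normed_score = topk_score / topk_score.sum(-1, keepdim=True)     # renormalize
    return normed_score, topk_ids.to(torch.int32)
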
from safetensors.torch import load_file as safe_load_file


class CompressedTensorsMoEMethod(FusedMoEMethodBase):

    def get_moe_method(quant_config, layer) -> "CompressedTensorsMoEMethod":
        tsm = getattr(quant_config, "target_scheme_map", None) or {}
        linear_cfg = None
        for k in ("Linear", "FusedMoE", "MoE", "Moe", "Experts"):
            if k in tsm and isinstance(tsm[k], dict):
                linear_cfg = tsm[k]
                break
        if not linear_cfg:
            # print("target_scheme_map missing; fallback to INT8(W8A8) method")
            return CompressedTensorsW8A8Int8MoEMethod(quant_config)
        wq = linear_cfg.get("weights")
        aq = linear_cfg.get("input_activations")
        if not wq or not aq:
            # print("incomplete scheme; fallback to INT8(W8A8)")
            return CompressedTensorsW8A8Int8MoEMethod(quant_config)
        # Other dispatch branches can be added as needed; default fallback:
        return CompressedTensorsW8A8Int8MoEMethod(quant_config)

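# Shape of the config that get_moe_method inspects (hypothetical example; the
# real keys and values come from the compressed-tensors checkpoint config,
# not from this literal):
_example_target_scheme_map = {
    "Linear": {
        "weights": {"num_bits": 8, "strategy": "channel", "dynamic": False},
        "input_activations": {"num_bits": 8, "strategy": "token", "dynamic": True},
    }
}
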
# copied from vllm 0.9.0
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config

        # Build the default quantization configs directly to avoid
        # QuantizationArgs validation issues.
        # print("Creating default INT8 quantization config for MoE")

        # Default weight quantization config
        self.weight_quant = type('WeightQuant', (), {
            'type': 'int',
            'num_bits': 8,
            'strategy': 'channel',
            'group_size': 128,
            'symmetric': True,
            'dynamic': False,
            'actorder': 'none',
            'observer': None,
            'observer_kwargs': {},
            'block_structure': None
        })()

        # Default input-activation quantization config
        self.input_quant = type('InputQuant', (), {
            'type': 'int',
            'num_bits': 8,
            'strategy': 'token',
            'group_size': 128,
            'symmetric': True,
            'dynamic': True,
            'actorder': 'none',
            'observer': None,
            'observer_kwargs': {},
            'block_structure': None
        })()

        # Compare strategies as plain strings
        per_channel = (
            self.weight_quant.strategy == "channel"
            and self.input_quant.strategy == "token")
        if not per_channel:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found "
                f"{self.weight_quant}, {self.input_quant}")

        self.static_input_scales = not self.input_quant.dynamic
        if self.static_input_scales:
            raise ValueError(
                "For INT8 Fused MoE layers, we require channelwise, "
                "dynamic per token quantization. Found static input scales.")

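    # The scheme enforced above: channelwise (per output channel) weight
    # scales, dynamic per-token activation scales. Plain-PyTorch sketch of
    # what that means numerically (illustrative only):
    @staticmethod
    def _sketch_w8a8_channel_token(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # w: [out, in] weight, x: [tokens, in] activations
        w_scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0  # per output channel
        w_q = torch.round(w / w_scale).clamp(-128, 127).to(torch.int8)
        x_scale = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / 127.0  # per token (dynamic)
        x_q = torch.round(x / x_scale).clamp(-128, 127).to(torch.int8)
        # INT8 GEMM, then dequant: y ~= (x_q @ w_q.T) * x_scale * w_scale.T
        return (x_q.float() @ w_q.float().t()) * x_scale * w_scale.t()
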
    def create_weights1(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
        # Allocate weights in floating point first so the original checkpoint
        # weights can be loaded before quantization.
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),  # usually torch.bfloat16
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Per-channel scales: float32, 2-D [E, out] (aligned with fused_moe / unit tests)
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
            requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Input scales are computed dynamically
        layer.w13_input_scale = None
        layer.w2_input_scale = None

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=torch.int8),  # stored directly as int8
            requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=torch.int8),  # stored directly as int8
            requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # Scale factors
        w13_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32),
            requires_grad=False)
        w2_weight_scale = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, dtype=torch.float32),
            requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Input scales are computed dynamically
        layer.w13_input_scale = None
        layer.w2_input_scale = None

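    # Parameter layout registered above (E = num_experts, N = intermediate
    # size per partition, H = hidden size); illustrative shape check only:
    @staticmethod
    def _sketch_moe_weight_shapes(layer, num_experts, hidden_size, intermediate):
        assert layer.w13_weight.shape == (num_experts, 2 * intermediate, hidden_size)
        assert layer.w2_weight.shape == (num_experts, hidden_size, intermediate)
        assert layer.w13_weight_scale.shape == (num_experts, 2 * intermediate)
        assert layer.w2_weight_scale.shape == (num_experts, hidden_size)
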
    @torch.no_grad()
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        return  # early return: the re-quantization path below is currently disabled
        # Cast the original weights to float32 for more robust statistics
        w13_f = layer.w13_weight.float()
        w2_f = layer.w2_weight.float()

        # Per-column abs max -> per-column scale (out dim is dim=1, columns on dim=-1)
        qmax = 127.0
        w13_abs_max = torch.amax(torch.abs(w13_f), dim=-1)   # [E, 2N]
        w2_abs_max = torch.amax(torch.abs(w2_f), dim=-1)     # [E, H]

        w13_scale_2d = torch.clamp(w13_abs_max, min=1e-6) / qmax   # [E, 2N], float32
        w2_scale_2d = torch.clamp(w2_abs_max, min=1e-6) / qmax     # [E, H], float32

        # Quantize: broadcast with a 3-D view of the scale, store the 2-D scale
        w13_scale_3d = w13_scale_2d.unsqueeze(-1)   # [E, 2N, 1]
        w2_scale_3d = w2_scale_2d.unsqueeze(-1)     # [E, H, 1]

        w13_q = torch.round(w13_f / w13_scale_3d).clamp_(-128, 127).to(torch.int8)
        w2_q = torch.round(w2_f / w2_scale_3d).clamp_(-128, 127).to(torch.int8)

        # Optional: if the fused kernel expects scales pre-multiplied by 127
        # (consistent with some UT backends), enable the next two lines:
        w13_scale_2d = w13_scale_2d * 127.0
        w2_scale_2d = w2_scale_2d * 127.0

        # Write the parameters back: int8 weights; float32 2-D scales
        replace_parameter(layer, 'w13_weight', torch.nn.Parameter(w13_q, requires_grad=False))
        replace_parameter(layer, 'w2_weight', torch.nn.Parameter(w2_q, requires_grad=False))
        replace_parameter(layer, 'w13_weight_scale',
                          torch.nn.Parameter(w13_scale_2d.contiguous(), requires_grad=False))
        replace_parameter(layer, 'w2_weight_scale',
                          torch.nn.Parameter(w2_scale_2d.contiguous(), requires_grad=False))

        # Quick sanity check
        print(f"w13: {w13_q.shape}, w13_s: {w13_scale_2d.shape}, w2: {w2_q.shape}, w2_s: {w2_scale_2d.shape}")

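    # Round-trip relationship for the (currently disabled) re-quantization
    # path above: the int8 weight times its per-channel scale should
    # approximate the original float weight. Illustrative sketch only:
    @staticmethod
    def _sketch_dequant_roundtrip(w: torch.Tensor) -> torch.Tensor:
        scale = torch.clamp(w.abs().amax(dim=-1), min=1e-6) / 127.0       # [E, out]
        w_q = torch.round(w / scale.unsqueeze(-1)).clamp(-128, 127).to(torch.int8)
        w_hat = w_q.float() * scale.unsqueeze(-1)                         # dequantize
        return (w - w_hat).abs().max()  # bounded by roughly scale / 2 per element
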
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,  # newly added parameter
        expert_load_view: Optional[torch.Tensor] = None,  # newly added parameter
        logical_to_physical_map: Optional[torch.Tensor] = None,  # newly added parameter
        logical_replica_count: Optional[torch.Tensor] = None,  # newly added parameter
        linear_weights: Optional[torch.Tensor] = None,  # newly added parameter
    ) -> torch.Tensor:

        output = torch.empty_like(x)
        torch.ops._C.moe_ffn_per_token_block(
            x=x,
            inter_weight=layer.w13_weight,
            inter_scale=layer.w13_weight_scale,
            outer_weight=layer.w2_weight,
            outer_scale=layer.w2_weight_scale,
            top_k=top_k,
            global_num_experts=global_num_experts,
            linear_weights=linear_weights,
            expert_map=expert_map,
            activation=activation,
            output=output,
            use_expert_parallel=expert_map is not None,
            ep_size=expert_map.size(0) if expert_map is not None else 1,
            ep_rank=0,
        )
        router_logits = router_logits.float()
        if scoring_func == "softmax":
            torch.ops._C.moe_softmax_topk_norm(
                x=router_logits,
                normed_score=normed_score,
                topk_index=topk_ids,
                block_statistic=None,
                stable=True)
        elif scoring_func == "sigmoid":
            torch.ops._C.moe_sigmoid_group_topk_norm(
                x=router_logits,
                norm_score=normed_score,
                topk_index=topk_ids,
                block_static=block_statistic,
                bias=e_score_correction_bias,
                n_group=num_expert_group,
                topk_group=topk_group,
                scale=routed_scaling_factor,
            )
        return output

print("[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsMoEMethod \
    --> vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe.py:CompressedTensorsMoEMethod")
        moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device)  # [M * top_k, N], float
        expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device)  # [E]
        sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device)  # [E+1]
        sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device)

        torch.ops._C.gen_block_statistic(topk_ids, block_statistic)

        torch.ops._C.moe_pre_sorted(
            x=hidden_states,
            topk_index=topk_ids,
            block_statistic=block_statistic,
            moe_expand=moe_expand,
            moe_index=sorted_tokens_idx,
            expert_m=expert_m,
            sorted_tokens_num_lod=sorted_tokens_num_lod)

        y = torch.empty(M, top_k,
                        layer.w13_weight.shape[1],
                        dtype=hidden_states.dtype,
                        device=hidden_states.device)

        moe_expand = moe_expand.view(M * top_k, hidden_dim)

        x_shape = moe_expand.shape
        x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
        x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
        torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)

        torch.ops._C.moe_fc(
            x=x_q,
            x_perchannel_max=x_scale,
            weight=layer.w13_weight,
            w_perchannel_max=layer.w13_weight_scale,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=top_k,
            y=y,
            topk_ids=topk_ids,
            # sort_mode=False,
            act=None)

        d = y.shape[-1] // 2
        output_shape = (y.shape[:-1] + (d, ))
        out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
        torch.ops._C.silu_and_mul(out1, y)

        out = torch.empty(M, top_k,
                          layer.w2_weight.shape[1],
                          dtype=hidden_states.dtype,
                          device=hidden_states.device)

        out1 = out1.reshape(-1, out1.shape[-1])
        x_shape = out1.shape
        x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
        x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device)
        torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)

        torch.ops._C.moe_fc(
            x=x_q,
            x_perchannel_max=x_scale,
            weight=layer.w2_weight,
            w_perchannel_max=layer.w2_weight_scale,
            sorted_tokens_num_lod=sorted_tokens_num_lod,
            sorted_tokens_idx=sorted_tokens_idx,
            moe_topk=top_k,
            y=out,
            topk_ids=topk_ids,
            # sort_mode=False,
            act=None)

        dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
        output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device)
        sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)

        torch.ops._C.moe_post(
            x=out,
            moe_index=sorted_tokens_idx,
            normed_scale=normed_score,
            dequant_scale=dequant_scale,
            y=output
        )
        return output

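# The final moe_post step combines per-expert outputs back into one vector per
# token, weighted by the normalized routing scores. A plain-PyTorch sketch of
# that combine (illustrative; the exact gather order used by the custom op is
# an assumption, not documented here):
def _sketch_moe_combine(expert_out: torch.Tensor, normed_score: torch.Tensor) -> torch.Tensor:
    # expert_out:   [M, top_k, H] outputs of the selected experts, per token
    # normed_score: [M, top_k]    routing weights from the top-k step
    return (expert_out * normed_score.unsqueeze(-1)).sum(dim=1)  # [M, H]
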
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading
CompressedTensorsW8A8Int8MoEMethod.apply = apply
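# Note: the two assignments above monkey patch the upstream class in place, so
# simply importing this module switches vLLM's compressed-tensors INT8 MoE
# path to the Kunlun kernels. Hypothetical usage sketch (import path assumed,
# following the print message above rather than a documented API):
#
#     import vllm_xpu.model_executor.layers.quantization.compressed_tensors_moe  # noqa: F401
#     # MoE layers built by vLLM afterwards go through the patched apply()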
0   vllm_kunlun/ops/quantization/kernels/__init__.py (new file)
122 vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py (new file)
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import torch

from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise)


def can_implement_kunlun(
        cls, c: ScaledMMLinearLayerConfig = None) -> tuple[bool, Optional[str]]:
    return True, None


def klx_process_weights_after_loading(layer: torch.nn.Module) -> None:
    """Convert per-channel scales back to abs-max values (scale * 127)."""
    layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
    layer.weight_scale = torch.nn.Parameter(
        layer.weight_scale.data * 127, requires_grad=False)

def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None:
    # WEIGHT
    # Cutlass kernels need transposed weight.
    weight = getattr(layer, self.w_q_name)
    replace_parameter(
        layer, self.w_q_name,
        torch.nn.Parameter(weight.t().data, requires_grad=False))

    # WEIGHT SCALE
    # Cutlass kernels support only per-tensor and per-channel.
    # If we have a fused module (QKV, MLP) with per tensor scales (thus N
    # scales being passed to the kernel), convert to the per-channel case.
    is_fused_module = len(layer.logical_widths) > 1
    weight_scale = getattr(layer, self.w_s_name)
    if is_fused_module and not self.config.is_channelwise:
        weight_scale = convert_to_channelwise(weight_scale,
                                              layer.logical_widths)
    replace_parameter(
        layer, self.w_s_name,
        torch.nn.Parameter(weight_scale.data, requires_grad=False))

    # INPUT SCALE
    if self.config.is_static_input_scheme:
        input_scale = getattr(layer, self.i_s_name)

        if self.config.input_symmetric:
            replace_parameter(
                layer, self.i_s_name,
                torch.nn.Parameter(input_scale.max(), requires_grad=False))
            setattr(layer, self.i_zp_name, None)
        else:
            input_zero_point = getattr(layer, self.i_zp_name)

            # reconstruct the ranges
            int8_traits = torch.iinfo(torch.int8)
            azps = input_zero_point.to(dtype=torch.int32)
            range_max = (input_scale * (int8_traits.max - azps)).max()
            range_min = (input_scale * (int8_traits.min - azps)).min()

            scale = (range_max - range_min) / (int8_traits.max -
                                               int8_traits.min)
            replace_parameter(
                layer, self.i_s_name,
                torch.nn.Parameter(scale, requires_grad=False))

            # AZP loaded as int8 but used as int32
            azp = (int8_traits.min -
                   range_min / scale).to(dtype=torch.int32)
            replace_parameter(layer, self.i_zp_name,
                              torch.nn.Parameter(azp, requires_grad=False))

    else:
        setattr(layer, self.i_s_name, None)
        setattr(layer, self.i_zp_name, None)

    # azp_adj is the AZP adjustment term, used to account for weights.
    # It does not depend on scales or azp, so it is the same for
    # static and dynamic quantization.
    # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
    # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
    if not self.config.input_symmetric:
        weight = getattr(layer, self.w_q_name)
        azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
        if self.config.is_static_input_scheme:
            # cutlass_w8a8 requires azp to be folded into azp_adj
            # in the per-tensor case
            azp_adj = getattr(layer, self.i_zp_name) * azp_adj
        setattr(layer, self.azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False))
    else:
        setattr(layer, self.azp_adj_name, None)

    klx_process_weights_after_loading(layer)

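# Worked example of the asymmetric range reconstruction above (illustrative
# numbers only): with an input_scale of 0.02 and zero point 10,
#     range_max = 0.02 * (127 - 10)  =  2.34
#     range_min = 0.02 * (-128 - 10) = -2.76
# and the re-folded per-tensor parameters become
#     scale = (2.34 - (-2.76)) / 255     = 0.02
#     azp   = -128 - (-2.76 / 0.02)      = 10
# i.e. when all channels agree the fold reproduces the original scale and zero
# point; otherwise it covers the widest observed range.
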
def apply_weights_kunlun(self,
                         layer: torch.nn.Module,
                         x: torch.Tensor,
                         bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    x_q, x_scale, out = None, None, None
    w_t_shape = layer.weight.T.shape
    if isinstance(x, tuple):
        x_q, x_scale = x
        out = torch.empty((x_q.shape[0], w_t_shape[0]),
                          dtype=torch.bfloat16,
                          device=x_q.device)
    else:
        x_shape = x.shape
        x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device)
        x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device)
        out = torch.empty((x_shape[0], w_t_shape[0]),
                          dtype=x.dtype,
                          device=x.device)
        torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True)
    torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out)
    return out


CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun
CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun
CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun
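# What the quant2d call above appears to produce: dynamic per-token INT8
# quantization, i.e. an int8 tensor plus one float statistic per row of x.
# Judging by the x_perchannel_max parameter name in the GEMM/moe_fc calls, the
# float buffer likely holds the per-row abs max rather than abs_max / 127.
# Illustrative sketch under that assumption (the custom op's exact semantics
# are not documented here):
def _sketch_dynamic_per_token_quant(x: torch.Tensor):
    x_max = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-6)       # [rows, 1]
    x_q = torch.round(x / x_max * 127.0).clamp(-128, 127).to(torch.int8)
    return x_q, x_max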