[feature] Support W8A8 PD-Mix Quantization (#4235)
In PD-separated deployment scenarios: * MoE layers use dynamic quantization exclusively. * For the Attention module, Prefill (P) nodes use **dynamic** quantization, while Decode (D) nodes use **static** quantization. In PD-mixed deployment scenarios: * **All components fall back to dynamic quantization**, as it is difficult to distinguish between Prefill and Decode tokens. ___ - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: SlightwindSec <slightwindsec@gmail.com> Signed-off-by: Slightwind <slightwindsec@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -12,6 +12,8 @@ from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
|
||||
AscendW8A8PDMixLinearMethod)
|
||||
|
||||
ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
"W4A8_DYNAMIC": {
|
||||
@@ -30,6 +32,10 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
"linear": AscendW8A8DynamicLinearMethod,
|
||||
"moe": AscendW8A8DynamicFusedMoEMethod,
|
||||
},
|
||||
"W8A8_MIX": {
|
||||
"linear": AscendW8A8PDMixLinearMethod,
|
||||
"moe": AscendW8A8PDMixFusedMoeMethod,
|
||||
},
|
||||
"C8": {
|
||||
"attention": AscendC8KVCacheMethod,
|
||||
},
|
||||
|
||||
@@ -87,6 +87,7 @@ class AscendW8A8LinearMethod:
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
@@ -192,6 +193,7 @@ class AscendW8A8LinearMethod:
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
|
||||
if getattr(layer, "ascend_quant_method",
|
||||
"") == COMPRESSED_TENSORS_METHOD:
|
||||
deq_scale = layer.input_scale.data * layer.weight_scale.data
|
||||
|
||||
@@ -60,6 +60,7 @@ class AscendW8A8DynamicLinearMethod:
|
||||
params_dict["weight_offset"] = torch.empty(output_size,
|
||||
1,
|
||||
dtype=params_dtype)
|
||||
params_dict["bias"] = torch.zeros(output_size, dtype=torch.float32)
|
||||
return params_dict
|
||||
|
||||
def get_pergroup_param(self,
|
||||
@@ -110,6 +111,7 @@ class AscendW8A8DynamicLinearMethod:
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
layer.bias.data = layer.bias.data.to(layer.weight_scale.data.dtype)
|
||||
|
||||
|
||||
class AscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
70
vllm_ascend/quantization/w8a8_pdmix.py
Normal file
70
vllm_ascend/quantization/w8a8_pdmix.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
from .w8a8 import AscendW8A8LinearMethod
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class AscendW8A8PDMixLinearMethod(AscendW8A8DynamicLinearMethod):
|
||||
|
||||
def __init__(self):
|
||||
self.kv_transfer_config = get_current_vllm_config().kv_transfer_config
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def apply(layer, x, bias=None, tp_rank=0):
|
||||
if layer.is_kv_consumer:
|
||||
return AscendW8A8LinearMethod.apply(layer, x, bias, tp_rank)
|
||||
else:
|
||||
return AscendW8A8DynamicLinearMethod.apply(layer, x, bias, tp_rank)
|
||||
|
||||
@staticmethod
|
||||
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
return AscendW8A8LinearMethod.get_pertensor_param(params_dtype)
|
||||
|
||||
@staticmethod
|
||||
def get_perchannel_param(
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
return AscendW8A8LinearMethod.get_perchannel_param(
|
||||
output_size, params_dtype)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
AscendW8A8LinearMethod.process_weights_after_loading(
|
||||
cast(AscendW8A8LinearMethod, self), layer)
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.is_kv_consumer = self.kv_transfer_config is not None and self.kv_transfer_config.is_kv_consumer
|
||||
|
||||
|
||||
class AscendW8A8PDMixFusedMoeMethod(AscendW8A8DynamicFusedMoEMethod):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = AscendW8A8DynamicFusedMoEMethod.get_dynamic_quant_param(
|
||||
num_experts, intermediate_size_per_partition, hidden_sizes,
|
||||
params_dtype)
|
||||
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_deq_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
|
||||
return param_dict
|
||||
Reference in New Issue
Block a user