[Feature] Supports DSv3.1 PD separation and C8 quantization (#7222)

Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>

### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled, and DeepSeek V3.1 W8A8C8
supports only the PD (prefill/decode) separation scenario. C8 refers to
quantizing the KV cache to int8, which reduces the memory footprint of the KV
cache and improves inference throughput.
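
C8 stores the KV cache in int8 with a statically calibrated scale. Below is a minimal, framework-agnostic sketch of that idea (it is not the kernel path added by this PR; all names and shapes are illustrative): values are quantized with an offline scale when written to the cache, and dequantized, or consumed directly by a quantized attention kernel, when read.

```python
# Minimal sketch of static int8 ("C8") KV-cache quantization. Illustrative only.
import torch

def quantize_kv_static(kv: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Quantize a float KV tensor to int8 using a pre-calibrated (static) scale."""
    return torch.round(kv / scale).clamp(-128, 127).to(torch.int8)

def dequantize_kv(kv_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Recover an approximate float KV tensor for use inside attention."""
    return kv_int8.to(torch.float32) * scale

kv = torch.randn(4, 128, 512)                    # [blocks, tokens, head_dim], toy shapes
scale = kv.abs().max() / 127.0                   # stand-in for an offline-calibrated scale
kv_cache_int8 = quantize_kv_static(kv, scale)    # stored at 1 byte per element
kv_restored = dequantize_kv(kv_cache_int8, scale)
```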
Constraints:
1. The model can only be run in PD separation mode, using the
MooncakeLayerwiseConnector (see the launch sketch after this list).
2. Currently, activations support only dynamic quantization and the KV cache
supports only static quantization. C8 quantization combined with MTP is not
supported. You can use ModelSlim for quantization; the procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path <path/quant_weight> \
  --anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json \
  --calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json \
  --rot --trust_remote_code True --fa_quant --dynamic --anti_method m6
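
For reference, a hypothetical producer-side launch sketch for constraint 1 is shown below. The `KVTransferConfig` fields, the `"ascend"` quantization value, and the model path are assumptions for illustration, not part of this PR; the exact connector configuration should follow the vLLM Ascend PD separation documentation.

```python
# Hypothetical sketch: serve the ModelSlim W8A8C8 output in PD separation mode
# with MooncakeLayerwiseConnector (constraint 1). Field names and values are
# assumptions; consult the vLLM Ascend PD separation docs for the real config.
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

llm = LLM(
    model="<path/quant_weight>",    # quantized weights produced by the steps above
    quantization="ascend",          # assumed flag for loading ModelSlim-quantized weights
    kv_transfer_config=KVTransferConfig(
        kv_connector="MooncakeLayerwiseConnector",  # required connector (constraint 1)
        kv_role="kv_producer",      # prefill node; the decode side would use "kv_consumer"
    ),
)
print(llm.generate(["hello"], SamplingParams(max_tokens=8)))
```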

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e

---------

Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
Commit: 3f39ac9c8d (parent: a6f6e919e6)
Author: pichangping
Date: 2026-03-16 22:49:05 +08:00
15 changed files with 1112 additions and 161 deletions


@@ -612,6 +612,7 @@ class AscendMlaCPImpl(AscendMLAImpl):
k_pe: torch.Tensor,
block_size: int,
attn_metadata: AscendMLAMetadata,
dequant_scale_q_nope=None,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None


@@ -47,7 +47,7 @@ from vllm_ascend.ops.layer_shard_linear import (
register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -658,6 +658,7 @@ class DecodeMLAPreprocessResult(NamedTuple):
k_nope: torch.Tensor | None = None
k_pe: torch.Tensor | None = None
decode_q_wo_k_up: torch.Tensor | None = None
dequant_scale_q_nope: torch.Tensor | None = None
class PrefillMLAPreprocessResult(NamedTuple):
@@ -725,6 +726,12 @@ class AscendMLAImpl(MLAAttentionImpl):
self.is_kv_producer = (
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
)
self.layer_name = kwargs.get("layer_name")
quant_config = self.vllm_config.quant_config
self.fa_quant_layer = (
quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False
)
self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
self.layer_sharding_kwargs = []
for layer_name in get_ascend_config().layer_sharding or []:
if layer_name in kwargs:
@@ -775,6 +782,8 @@ class AscendMLAImpl(MLAAttentionImpl):
actual_seq_lengths,
attn_output,
softmax_lse,
dequant_scale_q_nope,
fak_descale_float,
) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model:
@@ -793,26 +802,35 @@ class AscendMLAImpl(MLAAttentionImpl):
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu.npu_fused_infer_attention_score.out(
extra_args = {}
if dequant_scale_q_nope is not None:
extra_args = {
"query_quant_mode": 3,
"key_quant_mode": 0,
"value_quant_mode": 0,
"dequant_scale_query": dequant_scale_q_nope,
"dequant_scale_key": fak_descale_float,
"dequant_scale_value": fak_descale_float,
}
torch_npu.npu_fused_infer_attention_score_v2.out(
q_nope,
k_nope,
k_nope,
query_rope=q_pe,
key_rope=k_pe,
num_heads=num_heads,
num_query_heads=num_heads,
num_key_value_heads=num_kv_heads,
input_layout=input_layout,
atten_mask=attn_mask,
sparse_mode=sparse_mode,
scale=scale,
antiquant_mode=0,
antiquant_scale=None,
softmax_scale=scale,
block_table=block_table,
block_size=block_size,
actual_seq_lengths_kv=seq_lens_list,
actual_seq_lengths=actual_seq_lengths,
actual_seq_kvlen=seq_lens_list,
actual_seq_qlen=actual_seq_lengths,
workspace=graph_params.workspaces.get(num_tokens),
out=[attn_output, softmax_lse],
**extra_args,
)
torch.npu.graph_task_update_end(update_stream)
@@ -887,6 +905,8 @@ class AscendMLAImpl(MLAAttentionImpl):
)
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
elif self.fa_quant_layer:
self._process_weights_for_fused_fa_quant()
else:
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
@@ -895,6 +915,32 @@ class AscendMLAImpl(MLAAttentionImpl):
if is_hidden_layer(layer):
post_process_after_loading_for_shard_weight_series(layer)
def _process_weights_for_fused_fa_quant(self):
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
wu_q = self.q_proj.weight.data
self.wu_q = wu_q
q_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr]
self.wd_q = q_a_proj_fa3
kv_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr]
self.wd_kv = kv_a_proj_fa3
self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float)
q_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
self.dequant_scale_w_dq = q_a_proj_deq_scl.view(1, -1).to(torch.float)
kv_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
self.dequant_scale_w_dkv_kr = kv_a_proj_deq_scl.view(1, -1).to(torch.float)
layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name]
self.quant_kscale = layer.quant_kscale
self.fak_descale_float = layer.fak_descale_float
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
assert self.fused_qkv_a_proj is not None
assert self.q_a_layernorm is not None
@@ -1236,6 +1282,7 @@ class AscendMLAImpl(MLAAttentionImpl):
k_pe: torch.Tensor,
block_size: int,
attn_metadata: AscendMLAMetadata,
dequant_scale_q_nope=None,
) -> torch.Tensor:
decode_meta = attn_metadata.decode
assert decode_meta is not None
@@ -1243,7 +1290,15 @@ class AscendMLAImpl(MLAAttentionImpl):
# shape of knope/k_pe for npu graph mode should be:
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
actual_seq_lengths = None
if self.enable_kv_nz:
if self.fa_quant_layer:
nz_fmt_last_dim = 16
k_nope = k_nope.view(
-1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2
)
k_pe = k_pe.view(
-1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim
)
elif self.enable_kv_nz:
nz_fmt_last_dim = 16
k_nope = k_nope.view(
-1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim
@@ -1278,6 +1333,15 @@ class AscendMLAImpl(MLAAttentionImpl):
sparse_mode = 3
attn_mask = attn_metadata.decode.attn_mask # type:ignore
actual_seq_lengths = decode_meta.actual_seq_lengths_q
elif self.fa_quant_layer:
attn_mask = None
input_layout = "BSND_NBSD"
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous()
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous()
dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads)
sparse_mode = 0
actual_seq_lengths = None
attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank)
else:
# The output layout is set to NBSD to eliminate the need for a
# transpose operation after attention.
@@ -1299,19 +1363,27 @@ class AscendMLAImpl(MLAAttentionImpl):
common_kwargs = {
"query_rope": q_pe,
"key_rope": k_pe,
"num_heads": self.num_heads,
"num_query_heads": self.num_heads,
"num_key_value_heads": self.num_kv_heads,
"input_layout": input_layout,
"atten_mask": attn_mask,
"sparse_mode": sparse_mode,
"scale": self.scale,
"antiquant_mode": 0,
"antiquant_scale": None,
"softmax_scale": self.scale,
"block_table": decode_meta.block_table,
"block_size": block_size,
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
"actual_seq_qlen": actual_seq_lengths,
"actual_seq_kvlen": decode_meta.seq_lens_list,
}
if self.fa_quant_layer:
extra_fa_args = {
"query_quant_mode": 3,
"key_quant_mode": 0,
"value_quant_mode": 0,
"dequant_scale_query": dequant_scale_q_nope,
"dequant_scale_key": self.fak_descale_float,
"dequant_scale_value": self.fak_descale_float,
}
common_kwargs.update(extra_fa_args)
if _EXTRA_CTX.is_draft_model:
graph_params = get_draft_graph_params()
else:
@@ -1325,8 +1397,33 @@ class AscendMLAImpl(MLAAttentionImpl):
graph_params.events[num_tokens].append(event)
workspace = graph_params.workspaces.get(num_tokens)
attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device)
softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device)
attn_params = (
weak_ref_tensors(q_nope),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
self.num_heads,
self.num_kv_heads,
input_layout,
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
sparse_mode,
self.scale,
decode_meta.block_table,
block_size,
decode_meta.seq_lens_list,
actual_seq_lengths,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
if self.fa_quant_layer:
attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore
else:
attn_params = attn_params + (None, None) # type: ignore
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace(
q_nope, k_nope, k_nope, **common_kwargs
)
if _EXTRA_CTX.is_draft_model:
@@ -1334,38 +1431,16 @@ class AscendMLAImpl(MLAAttentionImpl):
else:
update_graph_params_workspaces(num_tokens, workspace)
attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device)
softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)
graph_params.attn_params[num_tokens].append(
(
weak_ref_tensors(q_nope),
weak_ref_tensors(k_nope),
weak_ref_tensors(q_pe),
weak_ref_tensors(k_pe),
self.num_heads,
self.num_kv_heads,
input_layout,
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
sparse_mode,
self.scale,
decode_meta.block_table,
block_size,
decode_meta.seq_lens_list,
actual_seq_lengths,
weak_ref_tensors(attn_output),
weak_ref_tensors(softmax_lse),
)
)
graph_params.attn_params[num_tokens].append(attn_params)
torch.npu.graph_task_group_begin(stream)
torch_npu.npu_fused_infer_attention_score.out(
torch_npu.npu_fused_infer_attention_score_v2.out(
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:
attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs)
attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs)
return self._v_up_proj(attn_output)
@@ -1381,55 +1456,81 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
decode_q_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
decode_q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
dequant_scale_q_nope = None
if self.fa_quant_layer:
quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2(
quantized_x,
self.wd_q,
self.wu_q,
self.W_UK_T,
self.wd_kv,
self.gamma1,
self.gamma2,
sin,
cos,
attn_metadata.slot_mapping[:bsz].to(torch.int64),
decode_k_nope,
decode_k_pe,
dequant_scale_x=pertoken_scale.view(-1, 1),
dequant_scale_w_dq=self.dequant_scale_w_dq,
dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr,
dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr,
quant_scale_ckv=self.quant_kscale,
cache_mode="PA_NZ",
)
else:
decode_q_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
decode_q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
cos,
sin,
self.W_UK_T,
decode_k_nope,
decode_k_pe,
attn_metadata.slot_mapping[:bsz],
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope,
q_out1=decode_q_pe,
kv_cache_out1=decode_k_pe,
enable_inner_out=False,
inner_out=torch.tensor([], device=hidden_states.device),
)
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
cos,
sin,
self.W_UK_T,
decode_k_nope,
decode_k_pe,
attn_metadata.slot_mapping[:bsz],
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope,
q_out1=decode_q_pe,
kv_cache_out1=decode_k_pe,
enable_inner_out=False,
inner_out=torch.tensor([], device=hidden_states.device),
)
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe)
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope
)
return decode_preprocess_res, None
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
@@ -1576,7 +1677,7 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)
# MLA Preprocess
if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS):
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states.contiguous(), need_gather_q_kv
)
@@ -1596,6 +1697,7 @@ class AscendMLAImpl(MLAAttentionImpl):
decode_preprocess_res.k_pe,
kv_cache[0].shape[1],
attn_metadata,
decode_preprocess_res.dequant_scale_q_nope,
)
o_proj_input[:num_decode_tokens] = output_decode