[Feature] Supports DSv3.1 PD separation and C8 quantization (#7222)
Co-authored-by: kunpengW-code <1289706727@qq.com>
Co-authored-by: linsheng1 <1950916997@qq.com>
### What this PR does / why we need it?
Currently, chunked prefill is forcibly enabled. DeepSeek V3.1 W8A8C8
supports only the PD separation scenario. C8 refers to quantizing the KV
cache to int8, which aims to reduce the GPU memory usage of the KV cache
and improve the inference throughput.
Constraints:
1. Only the PD separation mode is supported, and the
MooncakeLayerwiseConnector must be used to run the model.
2. Currently, only the activation value supports dynamic quantization,
and the KV cache supports static quantization. C8 quantization with MTP
is not supported. You can use ModelSlim for quantization. The
quantization procedure is as follows:
pip install transformers==4.48.2
git clone https://gitcode.com/Ascend/msmodelslim.git
cd msmodelslim
bash install.sh
cd example/DeepSeek/
python3 quant_deepseek_w8a8.py --model_path <path/weight> --save_path
<path/quant_weight>
--anti_dataset ../common/deepseek_anti_prompt_50_v3_1.json
--calib_dataset ../common/deepseek_calib_prompt_50_v3_1.json --rot
--trust_remote_code True --fa_quant --dynamic --anti_method m6
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: pichangping <1337510399@qq.com>
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Co-authored-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -47,7 +47,7 @@ from vllm_ascend.ops.layer_shard_linear import (
|
||||
register_all_layers_to_shard_weight_series,
|
||||
)
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
|
||||
from vllm_ascend.quantization.methods.w8a8_static import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors
|
||||
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
|
||||
|
||||
@@ -658,6 +658,7 @@ class DecodeMLAPreprocessResult(NamedTuple):
|
||||
k_nope: torch.Tensor | None = None
|
||||
k_pe: torch.Tensor | None = None
|
||||
decode_q_wo_k_up: torch.Tensor | None = None
|
||||
dequant_scale_q_nope: torch.Tensor | None = None
|
||||
|
||||
|
||||
class PrefillMLAPreprocessResult(NamedTuple):
|
||||
@@ -725,6 +726,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.is_kv_producer = (
|
||||
self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||
)
|
||||
self.layer_name = kwargs.get("layer_name")
|
||||
quant_config = self.vllm_config.quant_config
|
||||
self.fa_quant_layer = (
|
||||
quant_config.enabling_fa_quant(self.vllm_config, self.layer_name) if quant_config is not None else False
|
||||
)
|
||||
self.dtype = torch.int8 if self.fa_quant_layer else self.vllm_config.model_config.dtype
|
||||
self.layer_sharding_kwargs = []
|
||||
for layer_name in get_ascend_config().layer_sharding or []:
|
||||
if layer_name in kwargs:
|
||||
@@ -775,6 +782,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
actual_seq_lengths,
|
||||
attn_output,
|
||||
softmax_lse,
|
||||
dequant_scale_q_nope,
|
||||
fak_descale_float,
|
||||
) = param
|
||||
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
|
||||
if speculative_config and speculative_config.method == "mtp" and not _EXTRA_CTX.is_draft_model:
|
||||
@@ -793,26 +802,35 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list))
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
extra_args = {}
|
||||
if dequant_scale_q_nope is not None:
|
||||
extra_args = {
|
||||
"query_quant_mode": 3,
|
||||
"key_quant_mode": 0,
|
||||
"value_quant_mode": 0,
|
||||
"dequant_scale_query": dequant_scale_q_nope,
|
||||
"dequant_scale_key": fak_descale_float,
|
||||
"dequant_scale_value": fak_descale_float,
|
||||
}
|
||||
torch_npu.npu_fused_infer_attention_score_v2.out(
|
||||
q_nope,
|
||||
k_nope,
|
||||
k_nope,
|
||||
query_rope=q_pe,
|
||||
key_rope=k_pe,
|
||||
num_heads=num_heads,
|
||||
num_query_heads=num_heads,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
input_layout=input_layout,
|
||||
atten_mask=attn_mask,
|
||||
sparse_mode=sparse_mode,
|
||||
scale=scale,
|
||||
antiquant_mode=0,
|
||||
antiquant_scale=None,
|
||||
softmax_scale=scale,
|
||||
block_table=block_table,
|
||||
block_size=block_size,
|
||||
actual_seq_lengths_kv=seq_lens_list,
|
||||
actual_seq_lengths=actual_seq_lengths,
|
||||
actual_seq_kvlen=seq_lens_list,
|
||||
actual_seq_qlen=actual_seq_lengths,
|
||||
workspace=graph_params.workspaces.get(num_tokens),
|
||||
out=[attn_output, softmax_lse],
|
||||
**extra_args,
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
@@ -887,6 +905,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
)
|
||||
if self.enable_mlapo:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
elif self.fa_quant_layer:
|
||||
self._process_weights_for_fused_fa_quant()
|
||||
else:
|
||||
# if mlapo, W_UK_T can't trans nz
|
||||
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
|
||||
@@ -895,6 +915,32 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
if is_hidden_layer(layer):
|
||||
post_process_after_loading_for_shard_weight_series(layer)
|
||||
|
||||
def _process_weights_for_fused_fa_quant(self):
|
||||
self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr]
|
||||
self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr]
|
||||
|
||||
wu_q = self.q_proj.weight.data
|
||||
|
||||
self.wu_q = wu_q
|
||||
|
||||
q_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr]
|
||||
|
||||
self.wd_q = q_a_proj_fa3
|
||||
|
||||
kv_a_proj_fa3 = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr]
|
||||
|
||||
self.wd_kv = kv_a_proj_fa3
|
||||
|
||||
self.dequant_scale_w_uq_qr = self.q_proj.weight_scale.data.view(1, -1).to(torch.float)
|
||||
q_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr]
|
||||
self.dequant_scale_w_dq = q_a_proj_deq_scl.view(1, -1).to(torch.float)
|
||||
kv_a_proj_deq_scl = self.fused_qkv_a_proj.weight_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr]
|
||||
self.dequant_scale_w_dkv_kr = kv_a_proj_deq_scl.view(1, -1).to(torch.float)
|
||||
|
||||
layer = self.vllm_config.compilation_config.static_forward_context[self.layer_name]
|
||||
self.quant_kscale = layer.quant_kscale
|
||||
self.fak_descale_float = layer.fak_descale_float
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
assert self.fused_qkv_a_proj is not None
|
||||
assert self.q_a_layernorm is not None
|
||||
@@ -1236,6 +1282,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
k_pe: torch.Tensor,
|
||||
block_size: int,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
dequant_scale_q_nope=None,
|
||||
) -> torch.Tensor:
|
||||
decode_meta = attn_metadata.decode
|
||||
assert decode_meta is not None
|
||||
@@ -1243,7 +1290,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# shape of knope/k_pe for npu graph mode should be:
|
||||
# [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
|
||||
actual_seq_lengths = None
|
||||
if self.enable_kv_nz:
|
||||
if self.fa_quant_layer:
|
||||
nz_fmt_last_dim = 16
|
||||
k_nope = k_nope.view(
|
||||
-1, self.num_kv_heads, self.kv_lora_rank // (nz_fmt_last_dim * 2), block_size, nz_fmt_last_dim * 2
|
||||
)
|
||||
k_pe = k_pe.view(
|
||||
-1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim
|
||||
)
|
||||
elif self.enable_kv_nz:
|
||||
nz_fmt_last_dim = 16
|
||||
k_nope = k_nope.view(
|
||||
-1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim
|
||||
@@ -1278,6 +1333,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sparse_mode = 3
|
||||
attn_mask = attn_metadata.decode.attn_mask # type:ignore
|
||||
actual_seq_lengths = decode_meta.actual_seq_lengths_q
|
||||
elif self.fa_quant_layer:
|
||||
attn_mask = None
|
||||
input_layout = "BSND_NBSD"
|
||||
q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous()
|
||||
q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1).contiguous()
|
||||
dequant_scale_q_nope = dequant_scale_q_nope.view(num_tokens, 1, self.num_heads)
|
||||
sparse_mode = 0
|
||||
actual_seq_lengths = None
|
||||
attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank)
|
||||
else:
|
||||
# The output layout is set to NBSD to eliminate the need for a
|
||||
# transpose operation after attention.
|
||||
@@ -1299,19 +1363,27 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
common_kwargs = {
|
||||
"query_rope": q_pe,
|
||||
"key_rope": k_pe,
|
||||
"num_heads": self.num_heads,
|
||||
"num_query_heads": self.num_heads,
|
||||
"num_key_value_heads": self.num_kv_heads,
|
||||
"input_layout": input_layout,
|
||||
"atten_mask": attn_mask,
|
||||
"sparse_mode": sparse_mode,
|
||||
"scale": self.scale,
|
||||
"antiquant_mode": 0,
|
||||
"antiquant_scale": None,
|
||||
"softmax_scale": self.scale,
|
||||
"block_table": decode_meta.block_table,
|
||||
"block_size": block_size,
|
||||
"actual_seq_lengths": actual_seq_lengths,
|
||||
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
|
||||
"actual_seq_qlen": actual_seq_lengths,
|
||||
"actual_seq_kvlen": decode_meta.seq_lens_list,
|
||||
}
|
||||
if self.fa_quant_layer:
|
||||
extra_fa_args = {
|
||||
"query_quant_mode": 3,
|
||||
"key_quant_mode": 0,
|
||||
"value_quant_mode": 0,
|
||||
"dequant_scale_query": dequant_scale_q_nope,
|
||||
"dequant_scale_key": self.fak_descale_float,
|
||||
"dequant_scale_value": self.fak_descale_float,
|
||||
}
|
||||
common_kwargs.update(extra_fa_args)
|
||||
if _EXTRA_CTX.is_draft_model:
|
||||
graph_params = get_draft_graph_params()
|
||||
else:
|
||||
@@ -1325,8 +1397,33 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
attn_output = torch.empty(attn_output_shape, dtype=q_pe.dtype, device=q_pe.device)
|
||||
softmax_lse = torch.empty(num_tokens, dtype=q_pe.dtype, device=q_pe.device)
|
||||
attn_params = (
|
||||
weak_ref_tensors(q_nope),
|
||||
weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(q_pe),
|
||||
weak_ref_tensors(k_pe),
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
input_layout,
|
||||
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
|
||||
sparse_mode,
|
||||
self.scale,
|
||||
decode_meta.block_table,
|
||||
block_size,
|
||||
decode_meta.seq_lens_list,
|
||||
actual_seq_lengths,
|
||||
weak_ref_tensors(attn_output),
|
||||
weak_ref_tensors(softmax_lse),
|
||||
)
|
||||
if self.fa_quant_layer:
|
||||
attn_params = attn_params + (dequant_scale_q_nope, self.fak_descale_float) # type: ignore
|
||||
else:
|
||||
attn_params = attn_params + (None, None) # type: ignore
|
||||
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_v2_get_max_workspace(
|
||||
q_nope, k_nope, k_nope, **common_kwargs
|
||||
)
|
||||
if _EXTRA_CTX.is_draft_model:
|
||||
@@ -1334,38 +1431,16 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
else:
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device)
|
||||
softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device)
|
||||
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(
|
||||
weak_ref_tensors(q_nope),
|
||||
weak_ref_tensors(k_nope),
|
||||
weak_ref_tensors(q_pe),
|
||||
weak_ref_tensors(k_pe),
|
||||
self.num_heads,
|
||||
self.num_kv_heads,
|
||||
input_layout,
|
||||
weak_ref_tensors(attn_mask) if attn_mask is not None else None,
|
||||
sparse_mode,
|
||||
self.scale,
|
||||
decode_meta.block_table,
|
||||
block_size,
|
||||
decode_meta.seq_lens_list,
|
||||
actual_seq_lengths,
|
||||
weak_ref_tensors(attn_output),
|
||||
weak_ref_tensors(softmax_lse),
|
||||
)
|
||||
)
|
||||
graph_params.attn_params[num_tokens].append(attn_params)
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
torch_npu.npu_fused_infer_attention_score_v2.out(
|
||||
q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse]
|
||||
)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs)
|
||||
attn_output, _ = torch_npu.npu_fused_infer_attention_score_v2(q_nope, k_nope, k_nope, **common_kwargs)
|
||||
|
||||
return self._v_up_proj(attn_output)
|
||||
|
||||
@@ -1381,55 +1456,81 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
|
||||
|
||||
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
|
||||
decode_q_nope = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
decode_q_pe = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
dequant_scale_q_nope = None
|
||||
if self.fa_quant_layer:
|
||||
quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope = torch_npu.npu_mla_prolog_v2(
|
||||
quantized_x,
|
||||
self.wd_q,
|
||||
self.wu_q,
|
||||
self.W_UK_T,
|
||||
self.wd_kv,
|
||||
self.gamma1,
|
||||
self.gamma2,
|
||||
sin,
|
||||
cos,
|
||||
attn_metadata.slot_mapping[:bsz].to(torch.int64),
|
||||
decode_k_nope,
|
||||
decode_k_pe,
|
||||
dequant_scale_x=pertoken_scale.view(-1, 1),
|
||||
dequant_scale_w_dq=self.dequant_scale_w_dq,
|
||||
dequant_scale_w_uq_qr=self.dequant_scale_w_uq_qr,
|
||||
dequant_scale_w_dkv_kr=self.dequant_scale_w_dkv_kr,
|
||||
quant_scale_ckv=self.quant_kscale,
|
||||
cache_mode="PA_NZ",
|
||||
)
|
||||
else:
|
||||
decode_q_nope = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
decode_q_pe = torch.empty(
|
||||
(hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
torch.ops._C_ascend.mla_preprocess(
|
||||
hidden_states,
|
||||
self.wd_qkv,
|
||||
self.deq_scale_qkv,
|
||||
self.gamma1,
|
||||
self.beta1,
|
||||
self.wu_q,
|
||||
self.qb_deq_scl,
|
||||
self.gamma2,
|
||||
cos,
|
||||
sin,
|
||||
self.W_UK_T,
|
||||
decode_k_nope,
|
||||
decode_k_pe,
|
||||
attn_metadata.slot_mapping[:bsz],
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
quant_scale1=self.quant_scale1,
|
||||
quant_offset1=self.quant_offset1,
|
||||
bias1=self.qb_qt_bias,
|
||||
ctkv_scale=self.ctkv_scale,
|
||||
q_nope_scale=self.q_nope_scale,
|
||||
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
q_out0=decode_q_nope,
|
||||
kv_cache_out0=decode_k_nope,
|
||||
q_out1=decode_q_pe,
|
||||
kv_cache_out1=decode_k_pe,
|
||||
enable_inner_out=False,
|
||||
inner_out=torch.tensor([], device=hidden_states.device),
|
||||
)
|
||||
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
torch.ops._C_ascend.mla_preprocess(
|
||||
hidden_states,
|
||||
self.wd_qkv,
|
||||
self.deq_scale_qkv,
|
||||
self.gamma1,
|
||||
self.beta1,
|
||||
self.wu_q,
|
||||
self.qb_deq_scl,
|
||||
self.gamma2,
|
||||
cos,
|
||||
sin,
|
||||
self.W_UK_T,
|
||||
decode_k_nope,
|
||||
decode_k_pe,
|
||||
attn_metadata.slot_mapping[:bsz],
|
||||
quant_scale0=self.quant_scale0,
|
||||
quant_offset0=self.quant_offset0,
|
||||
bias0=self.quant_bias_qkv,
|
||||
quant_scale1=self.quant_scale1,
|
||||
quant_offset1=self.quant_offset1,
|
||||
bias1=self.qb_qt_bias,
|
||||
ctkv_scale=self.ctkv_scale,
|
||||
q_nope_scale=self.q_nope_scale,
|
||||
cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv",
|
||||
quant_mode="per_tensor_quant_asymm",
|
||||
q_out0=decode_q_nope,
|
||||
kv_cache_out0=decode_k_nope,
|
||||
q_out1=decode_q_pe,
|
||||
kv_cache_out1=decode_k_pe,
|
||||
enable_inner_out=False,
|
||||
inner_out=torch.tensor([], device=hidden_states.device),
|
||||
)
|
||||
decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank)
|
||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||
|
||||
decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe)
|
||||
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe, dequant_scale_q_nope=dequant_scale_q_nope
|
||||
)
|
||||
return decode_preprocess_res, None
|
||||
|
||||
def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata):
|
||||
@@ -1576,7 +1677,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
# MLA Preprocess
|
||||
if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
|
||||
if self.fa_quant_layer or (self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS):
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states.contiguous(), need_gather_q_kv
|
||||
)
|
||||
@@ -1596,6 +1697,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
decode_preprocess_res.k_pe,
|
||||
kv_cache[0].shape[1],
|
||||
attn_metadata,
|
||||
decode_preprocess_res.dequant_scale_q_nope,
|
||||
)
|
||||
|
||||
o_proj_input[:num_decode_tokens] = output_decode
|
||||
|
||||
Reference in New Issue
Block a user