[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)

This is the follow-up to PR #3189. It continues refactoring sfa into mla
and finally removes deepseek_v3_2.py. This is the last PR of the
DeepSeek modeling refactoring. After this, all DeepSeek-related model
code is removed from vllm_ascend.

Furthermore, after this PR, DeepSeek v3.2 can run chunked prefill with
correct accuracy.

- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-30 17:06:38 +08:00
committed by GitHub
parent eff3e5fc6f
commit f6149f3894
10 changed files with 751 additions and 1935 deletions

View File

@@ -52,6 +52,35 @@ else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
class IndexerWrapper(nn.Module):
    """Thin wrapper around vLLM's Indexer for DeepSeek v3.2.

    Currently used to work around the hard-coded fp8 handling in vLLM's
    deepseek_v2.py. The wrapper re-uses the wrapped indexer's weight-bearing
    submodules (wq_b, wk, weights_proj, k_norm) and copies its scalar
    settings, while clearing the unused ``topk_indices_buffer`` and
    ``k_cache`` on the wrapped indexer to save memory.

    TODO: Remove once the original Indexer supports different quantization
    methods.
    """

    # Submodules shared with the wrapped indexer, in registration order.
    _SHARED_MODULES = ("wq_b", "wk", "weights_proj", "k_norm")
    # Attributes of the wrapped indexer that are unused and get cleared.
    _UNUSED_ATTRS = ("topk_indices_buffer", "k_cache")

    def __init__(self, vllm_indexer: nn.Module) -> None:
        """Adopt the weights and settings of ``vllm_indexer``.

        Args:
            vllm_indexer: The original indexer module. It is mutated: its
                ``topk_indices_buffer`` and ``k_cache`` are set to ``None``.
        """
        super().__init__()

        # Scalar configuration copied from the wrapped indexer.
        self.n_head: int = vllm_indexer.n_head  # 64
        self.head_dim: int = vllm_indexer.head_dim  # 128
        self.topk_tokens: int = vllm_indexer.topk_tokens  # 2048
        self.q_lora_rank: int = vllm_indexer.q_lora_rank  # 1536

        # Share (not copy) the submodules so their weights are inherited.
        for module_name in self._SHARED_MODULES:
            setattr(self, module_name, getattr(vllm_indexer, module_name))

        self.softmax_scale = vllm_indexer.softmax_scale

        # Drop the wrapped indexer's references to state this wrapper
        # does not use.
        for unused_name in self._UNUSED_ATTRS:
            setattr(vllm_indexer, unused_name, None)

    def forward(self) -> None:
        """Do nothing and return ``None``."""
        return None
# TODO(whx): adapt v0.11.0 and DSA
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
@@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if mla_modules.indexer is not None:
ascend_indexer = IndexerWrapper(mla_modules.indexer)
else:
ascend_indexer = None
if vllm_version_is("0.11.0"):
self.mla_attn = Attention(
@@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
indexer=ascend_indexer,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
indexer=ascend_indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,