[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)
This is the follow-up to PR #3189: it continues refactoring sfa
into mla and finally removes deepseek_v3_2.py. This is the last PR of
the DeepSeek modeling refactoring; after it, all DeepSeek-related model
code is removed from vllm_ascend.
Furthermore, after this PR DeepSeek V3.2 can run chunked prefill with
correct accuracy.
- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
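As a hedged smoke-test sketch for the chunked-prefill claim above (not part of this commit): the checkpoint name, parallelism, and token budget below are illustrative assumptions; the flags themselves are standard vLLM offline-inference options.

# Hedged smoke-test sketch: run DeepSeek V3.2 with chunked prefill enabled.
# Checkpoint name and sizes are assumptions for illustration only.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumed checkpoint name
    trust_remote_code=True,
    tensor_parallel_size=16,                # illustrative parallelism
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,            # small budget so long prompts get chunked
)
outputs = llm.generate(
    ["Explain multi-head latent attention in one paragraph."],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)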
@@ -52,6 +52,35 @@ else:
 from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
 
 
+class IndexerWrapper(nn.Module):
+    '''
+    A wrapper of Indexer for DeepSeek V3.2.
+    This wrapper is currently used to work around the fp8 hard-coding issue in vLLM's deepseek_v2.py.
+    It wraps the original Indexer and inherits its module weights
+    (including wq_b, wk, weights_proj and k_norm),
+    while deleting the unused topk_indices_buffer and k_cache to save memory.
+    TODO: Will be removed once the original Indexer supports different quantization methods.
+    '''
+
+    def __init__(self, vllm_indexer: nn.Module) -> None:
+        super().__init__()
+
+        self.n_head: int = vllm_indexer.n_head  # 64
+        self.head_dim: int = vllm_indexer.head_dim  # 128
+        self.topk_tokens: int = vllm_indexer.topk_tokens  # 2048
+        self.q_lora_rank: int = vllm_indexer.q_lora_rank  # 1536
+        self.wq_b = vllm_indexer.wq_b
+        self.wk = vllm_indexer.wk
+        self.weights_proj = vllm_indexer.weights_proj
+        self.k_norm = vllm_indexer.k_norm
+        self.softmax_scale = vllm_indexer.softmax_scale
+        vllm_indexer.topk_indices_buffer = None  # delete topk_indices_buffer
+        vllm_indexer.k_cache = None  # delete k_cache
+
+    def forward(self):
+        return
+
+
 # TODO(whx): adapt v0.11.0 and DSA
 class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
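For orientation, a minimal usage sketch of the new wrapper follows. The _FakeIndexer class below is a hypothetical stand-in that only mimics the attributes IndexerWrapper reads (its layer shapes are illustrative, not DeepSeek's real ones); it is not vLLM's actual Indexer.

# Hedged usage sketch for IndexerWrapper (defined in the hunk above).
# _FakeIndexer is a hypothetical stand-in, not vLLM's real Indexer.
import torch
from torch import nn

class _FakeIndexer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Scalar attributes read by IndexerWrapper (values mirror the comments above).
        self.n_head = 64
        self.head_dim = 128
        self.topk_tokens = 2048
        self.q_lora_rank = 1536
        self.softmax_scale = self.head_dim**-0.5
        # Weight submodules that the wrapper inherits; shapes are illustrative.
        self.wq_b = nn.Linear(self.q_lora_rank, self.n_head * self.head_dim, bias=False)
        self.wk = nn.Linear(1024, self.head_dim, bias=False)
        self.weights_proj = nn.Linear(1024, self.n_head, bias=False)
        self.k_norm = nn.LayerNorm(self.head_dim)
        # Buffers that the wrapper releases to save memory.
        self.topk_indices_buffer = torch.zeros(1, self.topk_tokens, dtype=torch.int32)
        self.k_cache = torch.zeros(1, self.head_dim)

fake = _FakeIndexer()
wrapped = IndexerWrapper(fake)

# The wrapper shares the original weight modules instead of copying them ...
assert wrapped.wq_b is fake.wq_b and wrapped.k_norm is fake.k_norm
# ... and the unused buffers on the wrapped module are released.
assert fake.topk_indices_buffer is None and fake.k_cache is None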
@@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
         self.first_k_dense_replace = hf_config.first_k_dense_replace
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layers = hf_config.num_hidden_layers
+        if mla_modules.indexer is not None:
+            ascend_indexer = IndexerWrapper(mla_modules.indexer)
+        else:
+            ascend_indexer = None
 
         if vllm_version_is("0.11.0"):
             self.mla_attn = Attention(
@@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_mla=True,
+                indexer=ascend_indexer,
+                use_sparse=mla_modules.is_sparse,
                 # MLA Args
                 q_lora_rank=self.q_lora_rank,
                 kv_lora_rank=self.kv_lora_rank,
@@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
                 quant_config=quant_config,
                 prefix=f"{prefix}.attn",
                 use_sparse=mla_modules.is_sparse,
-                indexer=mla_modules.indexer,
+                indexer=ascend_indexer,
                 # extra args
                 rotary_emb=mla_modules.rotary_emb,
                 fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,