[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)
This is the follow-up to PR #3189: it continues the refactoring of sfa
into mla and finally removes deepseek_v3_2.py. It is the last PR of the
deepseek modeling refactoring; after it, all deepseek-related model
code is removed from vllm_ascend.
Furthermore, after this PR deepseek v3.2 can run chunked prefill with
correct accuracy.
- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
@@ -52,6 +52,8 @@ if prefill_context_parallel_enable():
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+
 
 class AscendMLABackend(AttentionBackend):
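For context, a byte cap such as MAX_O_PROJ_PREFETCH_SIZE is typically applied by clamping how much of a weight tensor gets prefetched. The snippet below is a minimal sketch of that clamping only; it is an assumption for illustration and deliberately does not name the actual NPU prefetch call.

# Minimal sketch (assumption, not taken from this diff): clamp the number of
# bytes prefetched for a projection weight to MAX_O_PROJ_PREFETCH_SIZE.
import torch

MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16 MiB


def prefetch_bytes(weight: torch.Tensor) -> int:
    # Return how many bytes of `weight` should be prefetched: the full
    # tensor if it fits under the cap, otherwise the cap itself.
    nbytes = weight.numel() * weight.element_size()
    return min(nbytes, MAX_O_PROJ_PREFETCH_SIZE)


w = torch.empty(7168, 16384, dtype=torch.bfloat16)  # ~224 MiB weight
print(prefetch_bytes(w))  # 16777216, i.e. the 16 MiB cap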
@@ -808,16 +810,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
         # Currently mlapo only supports W8A8 quantization in MLA scenario
         # TODO(whx): modify this limitation when mlapo supports floating point
         if self.fused_qkv_a_proj is None or not isinstance(
                 getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
                         None), AscendW8A8LinearMethod):
             self.enable_mlapo = False
             logger.warning_once(
                 "Currently mlapo only supports W8A8 quantization in MLA scenario."
                 "Some layers in your model are not quantized with W8A8,"
                 "thus mlapo is disabled for these layers.")
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
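The gate above relies on a quant method nested inside the fused qkv_a projection's own quant method. The sketch below illustrates that pattern with stand-in classes; `_FakeLinear`, `_FakeQuantMethod` and `mlapo_supported` are hypothetical names for this example only, not vllm_ascend APIs.

# Illustrative sketch of the W8A8 gate, with stand-in classes.

class AscendW8A8LinearMethod:      # stand-in for the real W8A8 method
    pass

class _FakeQuantMethod:            # outer quant method wrapping an inner one
    def __init__(self, inner):
        self.quant_method = inner

class _FakeLinear:                 # stand-in for the fused qkv_a projection
    def __init__(self, quant_method):
        self.quant_method = quant_method

def mlapo_supported(fused_qkv_a_proj) -> bool:
    # mlapo stays enabled only when the layer exists and its nested
    # quant method is the Ascend W8A8 linear method.
    if fused_qkv_a_proj is None:
        return False
    inner = getattr(fused_qkv_a_proj.quant_method, "quant_method", None)
    return isinstance(inner, AscendW8A8LinearMethod)

print(mlapo_supported(None))                                  # False
print(mlapo_supported(_FakeLinear(_FakeQuantMethod(None))))   # False: not W8A8
print(mlapo_supported(
    _FakeLinear(_FakeQuantMethod(AscendW8A8LinearMethod()))))  # True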
@@ -1282,12 +1285,13 @@ class AscendMLAImpl(MLAAttentionImpl):
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
-        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
-        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
-        # 3. If need_gather_q_kv, perform all_gather.
-        # 4. Preprocess decode tokens, write kv cache and get:
+        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
+        #    or
+        #    Perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 2. If need_gather_q_kv, perform all_gather.
+        # 3. Preprocess decode tokens, write kv cache and get:
         #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
-        # 5. Preprocess prefill tokens, write kv cache and get:
+        # 4. Preprocess prefill tokens, write kv cache and get:
         #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
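To make the new step 1 concrete, here is a minimal sketch of the "project once, then split into q_c and kv_no_split" idea. The dimensions (q_lora_rank, kv_lora_rank, qk_rope_head_dim) are illustrative placeholders rather than the real DeepSeek configuration, and plain LayerNorm stands in for the model's RMSNorm.

# Illustrative sketch of step 1 above; sizes are made up for the example.
import torch

hidden_size = 1024
q_lora_rank = 256        # assumed latent query rank
kv_lora_rank = 128       # assumed latent kv rank
qk_rope_head_dim = 64    # assumed rotary sub-dimension carried with kv

fused_qkv_a_proj = torch.nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)
q_a_layernorm = torch.nn.LayerNorm(q_lora_rank)  # real model uses RMSNorm

hidden_states = torch.randn(4, hidden_size)      # 4 tokens
q_c, kv_no_split = fused_qkv_a_proj(hidden_states).split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)
q_c = q_a_layernorm(q_c)                         # normalized latent queries
print(q_c.shape, kv_no_split.shape)              # (4, 256) and (4, 192)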