[Model][2/N] Remove deepseek_mtp modeling. (#3561)
This PR is step 2 of deepseek model refactoring and removes deepseek_mtp. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
@@ -29,6 +30,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
@@ -557,6 +559,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.prefill_mask = None
|
||||
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
|
||||
@@ -654,7 +657,17 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
if envs.VLLM_ASCEND_ENABLE_MLAPO:
|
||||
# Currently mlapo only supports W8A8 quantization in MLA scenario
|
||||
# TODO(whx): modify this limitation when mlapo supports floating point
|
||||
if self.fused_qkv_a_proj is None or not isinstance(
|
||||
getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
|
||||
None), AscendW8A8LinearMethod):
|
||||
self.enable_mlapo = False
|
||||
logger.warning(
|
||||
"Currently mlapo only supports W8A8 quantization in MLA scenario."
|
||||
"Some layers in your model are not quantized with W8A8,"
|
||||
"thus mlapo is disabled for these layers.")
|
||||
if self.enable_mlapo:
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
@@ -1229,7 +1242,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
# MLA Preprocess
|
||||
forward_context = get_forward_context()
|
||||
if (envs.VLLM_ASCEND_ENABLE_MLAPO and
|
||||
if (self.enable_mlapo and
|
||||
(attn_metadata is None or not forward_context.with_prefill)):
|
||||
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata)
|
||||
|
||||
Reference in New Issue
Block a user