[Model][1/N] Delete deepseek v2/v3 modeling codes. (#3189)

This PR deletes model codes of deepseek_v2 and deepseek_v3 to reuse the
model file from vLLM.

vLLM Ascend now uses custom ops register way instead of model file
hard-coding.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-10-20 15:31:34 +08:00
committed by GitHub
parent 918ded9155
commit f8b52fe950
13 changed files with 143 additions and 843 deletions

View File

@@ -536,12 +536,13 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj']
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
'q_b_proj']
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_proj = kwargs.get('q_a_proj', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
@@ -648,36 +649,46 @@ class AscendMLAImpl(MLAAttentionImpl):
self._process_weights_for_fused_mlapo(act_dtype)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., self.q_lora_rank:].contiguous()
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
..., :self.q_lora_rank].contiguous()
kv_a_proj_wt = kv_a_proj_wt.contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
kv_a_proj_wt = kv_a_proj_wt.contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
self.q_lora_rank:].contiguous()
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
q_lora_rank].contiguous(
)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat(
(kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
dim=-1).contiguous()
kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
self.q_lora_rank:].contiguous()
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
q_lora_rank].contiguous(
)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat(
(kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
dim=-1).contiguous()
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
dim=-1).contiguous()
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
@@ -704,22 +715,22 @@ class AscendMLAImpl(MLAAttentionImpl):
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_a_proj.weight.device
device = self.q_proj.weight.device
self.gamma0 = torch.ones(
[self.q_a_proj.weight.shape[-1]],
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.beta0 = torch.zeros(
[self.q_a_proj.weight.shape[-1]],
[self.fused_qkv_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.q_a_proj.input_scale.data
self.quant_offset0 = self.q_a_proj.input_offset.data
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
@@ -1122,21 +1133,26 @@ class AscendMLAImpl(MLAAttentionImpl):
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
if self.fused_qkv_a_proj is not None:
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
else:
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@@ -1264,14 +1280,18 @@ class AscendMLAImpl(MLAAttentionImpl):
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input)[0]
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(o_proj_input)[0]
output[...] = self.o_proj(o_proj_input,
is_prefill=prefill_preprocess_res
is not None)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input