[Model][1/N] Delete deepseek v2/v3 modeling codes. (#3189)
This PR deletes the deepseek_v2 and deepseek_v3 model code so that the model files from vLLM are reused instead. vLLM Ascend now relies on custom-op registration rather than hard-coded model files. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -536,12 +536,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.qk_head_dim = kwargs['qk_head_dim']
|
||||
self.v_head_dim = kwargs['v_head_dim']
|
||||
self.rotary_emb = kwargs['rotary_emb']
|
||||
self.q_proj = kwargs['q_proj']
|
||||
self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
|
||||
self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
|
||||
'q_b_proj']
|
||||
self.kv_b_proj = kwargs['kv_b_proj']
|
||||
self.o_proj = kwargs['o_proj']
|
||||
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
|
||||
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
|
||||
self.q_a_proj = kwargs.get('q_a_proj', None)
|
||||
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -648,36 +649,46 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self._process_weights_for_fused_mlapo(act_dtype)
|
||||
|
||||
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
|
||||
kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., self.q_lora_rank:].contiguous()
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., :self.q_lora_rank].contiguous()
|
||||
kv_a_proj_wt = kv_a_proj_wt.contiguous()
|
||||
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
|
||||
kv_a_proj_wt = kv_a_proj_wt.contiguous()
|
||||
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
|
||||
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
|
||||
kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.deq_scale_qkv = torch.cat(
|
||||
(kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()
|
||||
self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl),
|
||||
dim=-1).contiguous()
|
||||
|
||||
kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
|
||||
kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[
|
||||
self.q_lora_rank:].contiguous()
|
||||
q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self.
|
||||
q_lora_rank].contiguous(
|
||||
)
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
|
||||
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
|
||||
self.qk_rope_head_dim)
|
||||
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
|
||||
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
|
||||
self.quant_bias_qkv = torch.cat(
|
||||
(kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
|
||||
dim=-1).contiguous()
|
||||
self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias),
|
||||
dim=-1).contiguous()
|
||||
|
||||
wu_q = self.q_proj.weight.data
|
||||
wu_q = wu_q.t().reshape(self.num_heads,
|
||||
@@ -704,22 +715,22 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.qb_qt_bias = qb_qt_bias.reshape(
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
|
||||
|
||||
device = self.q_a_proj.weight.device
|
||||
device = self.q_proj.weight.device
|
||||
self.gamma0 = torch.ones(
|
||||
[self.q_a_proj.weight.shape[-1]],
|
||||
[self.fused_qkv_a_proj.weight.shape[-1]],
|
||||
dtype=act_dtype,
|
||||
device=device,
|
||||
)
|
||||
self.beta0 = torch.zeros(
|
||||
[self.q_a_proj.weight.shape[-1]],
|
||||
[self.fused_qkv_a_proj.weight.shape[-1]],
|
||||
dtype=act_dtype,
|
||||
device=device,
|
||||
)
|
||||
self.gamma1 = self.q_a_layernorm.weight.data
|
||||
self.beta1 = self.q_a_layernorm.bias.data
|
||||
self.gamma2 = self.kv_a_layernorm.weight.data
|
||||
self.quant_scale0 = self.q_a_proj.input_scale.data
|
||||
self.quant_offset0 = self.q_a_proj.input_offset.data
|
||||
self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
|
||||
self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
|
||||
self.quant_scale1 = self.q_proj.input_scale.data
|
||||
self.quant_offset1 = self.q_proj.input_offset.data
|
||||
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
|
||||
@@ -1122,21 +1133,26 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
if self.q_a_proj is not None:
|
||||
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
|
||||
if self.fused_qkv_a_proj is not None:
|
||||
maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
|
||||
dependency=hidden_states,
|
||||
enabled=self.enable_prefetch)
|
||||
ckq = self.q_a_proj(hidden_states)[0]
|
||||
q_c = self.q_a_layernorm(ckq)
|
||||
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
|
||||
q_c, kv_no_split = qkv_lora.split(
|
||||
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
|
||||
dim=-1,
|
||||
)
|
||||
q_c = self.q_a_layernorm(q_c)
|
||||
else:
|
||||
q_c = hidden_states
|
||||
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
|
||||
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
|
||||
# Process for Flash Comm V1
|
||||
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
q_c, need_gather_q_kv)
|
||||
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
kv_no_split, need_gather_q_kv)
|
||||
|
||||
decode_preprocess_res = None
|
||||
prefill_preprocess_res = None
|
||||
if has_prefill:
|
||||
@@ -1264,14 +1280,18 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
|
||||
output[...] = self.o_proj(o_proj_input)[0]
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
else:
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
maybe_npu_prefetch(inputs=self.o_proj.weight,
|
||||
dependency=o_proj_input,
|
||||
max_size=MAX_O_PROJ_PREFETCH_SIZE,
|
||||
enabled=self.enable_prefetch)
|
||||
output[...] = self.o_proj(o_proj_input)[0]
|
||||
output[...] = self.o_proj(o_proj_input,
|
||||
is_prefill=prefill_preprocess_res
|
||||
is not None)[0]
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
del o_proj_input
|
||||
|
||||
|
||||
Reference in New Issue
Block a user