[1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vllm main. This PR is aligned with 17c540a993.
1. Refactor deepseek to the latest code arch as of 17c540a993.
2. Bunches of fixes due to vllm changes:
   - Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075
   - Fix `AscendScheduler` init receiving an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296
   - Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485
   - Fix `MLAAttention` import, caused by https://github.com/vllm-project/vllm/pull/25103
   - Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145
   - Fix `LazyLoader` import, caused by https://github.com/vllm-project/vllm/pull/27022
   - Fix `vllm.utils.swap_dict_values` import, caused by https://github.com/vllm-project/vllm/pull/26990
   - Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893
   - Fix the `CompilationLevel` to `CompilationMode` renaming issue introduced by https://github.com/vllm-project/vllm/pull/26355
   - Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097
   - Fix the bert model because of `inputs_embeds` changes, caused by https://github.com/vllm-project/vllm/pull/25922
   - Fix MRope because of the rename from `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172
   - Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845
   - Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229
   - Fix the lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807
   - Fix the structured output break introduced by https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with existing tests.
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
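Many of the fixes listed above share one shape: a symbol moved or was renamed upstream (`LazyLoader`, `swap_dict_values`, `Backend`, `CompilationLevel` → `CompilationMode`), and the plugin must keep importing it across vllm versions. Below is a minimal sketch of a probe-style compatibility shim; it is demonstrated with stdlib modules only, since the exact vllm module paths differ between the versions this PR bridges and are not restated here.

```python
import importlib


def resolve_attr(candidates):
    """Return the first attribute that resolves from (module, name) pairs.

    Lets one codebase track an upstream API that moves between releases:
    probe the newest location first, fall back to older ones.
    """
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, attr_name)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"could not resolve any of: {candidates}")


# Demo with stdlib names: the first location lacks the attribute and is
# skipped; the second resolves. A real shim would list vllm paths instead.
OrderedDict = resolve_attr([
    ("collections.abc", "OrderedDict"),  # not here -> falls through
    ("collections", "OrderedDict"),      # found
])
```

A try/except around a single import works for one rename; the probe list scales better when a symbol has moved more than once across the supported vllm versions.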
```diff
@@ -839,6 +839,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
        wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data.clone()),
                           dim=-1)

        wd_qkv = wd_qkv.t().contiguous()
        wd_qkv = transdata(wd_qkv,
                           block_size=(16, 32)).unsqueeze(0).contiguous()
```
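The hunk above fuses the q and kv down-projection weights into a single `wd_qkv` matrix so that one matmul serves both paths. A generic numpy sketch of that weight-fusion idea follows; the shapes are small illustrative assumptions rather than the real deepseek dimensions, and the Ascend-specific `transdata` layout transform is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, out_q, out_kv = 8, 4, 3  # assumed toy sizes

# nn.Linear stores weight as (out_features, in_features); x @ W.T applies it.
w_q = rng.standard_normal((out_q, hidden))
w_kv = rng.standard_normal((out_kv, hidden))
x = rng.standard_normal((2, hidden))

# Fuse along the output dim: a single matmul yields both projections.
w_fused = np.concatenate((w_kv, w_q), axis=0)  # (out_kv + out_q, hidden)
y = x @ w_fused.T
y_kv, y_q = y[:, :out_kv], y[:, out_kv:]

# The fused result matches applying each projection separately.
assert np.allclose(y_kv, x @ w_kv.T)
assert np.allclose(y_q, x @ w_q.T)
```

Fusing is profitable because one larger GEMM typically utilizes the hardware better than two small ones, at the cost of slicing the output afterwards.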
```diff
@@ -951,6 +952,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)

        hidden_states = self.decoder_layer.input_layernorm(hidden_states)

        decode_kq = self.q_a_proj(hidden_states)  # q down
        decode_q_c = self.q_a_layernorm(decode_kq)  # q down layernorm
```
```diff
@@ -982,7 +984,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
-            return output
+            return output.fill_(0)

        if attn_metadata.prefill is not None:
            assert attn_metadata.num_decodes is not None and \
```
```diff
@@ -993,10 +995,12 @@ class AscendSFATorchairImpl(MLAAttentionImpl):

        hidden_states_prefill = hidden_states
        prefill_slot_mapping = attn_metadata.slot_mapping

        prefill_kq = self.q_a_proj(hidden_states_prefill)  # q down
        prefill_q_c = self.q_a_layernorm(prefill_kq)  # q down layernorm
        prefill_kv_no_split = self.kv_a_proj_with_mqa(
            hidden_states_prefill)  # c_kv

        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
            prefill_kv_no_split = get_tp_group().all_gather(
                prefill_kv_no_split,
```
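In the hunk above, when shared-expert DP is enabled, each rank's slice of `prefill_kv_no_split` is gathered across the tensor-parallel group so every rank sees the full tensor. A toy pure-Python sketch of the all_gather semantics (not the torch.distributed implementation, and the helper name is only illustrative):

```python
def all_gather(shards_by_rank):
    """Toy all_gather: every rank receives the concatenation of all
    ranks' shards, in rank order."""
    return [row for shard in shards_by_rank for row in shard]


rank0_shard = [[1.0, 2.0]]  # rows held by rank 0
rank1_shard = [[3.0, 4.0]]  # rows held by rank 1
gathered = all_gather([rank0_shard, rank1_shard])
assert gathered == [[1.0, 2.0], [3.0, 4.0]]
```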
```diff
@@ -1110,6 +1114,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
        else:
            q_len = 1
        hidden_states_decode = hidden_states

        decode_kq = self.q_a_proj(hidden_states_decode)  # q down
        decode_q_c = self.q_a_layernorm(decode_kq)  # q down layernorm
        decode_kv_no_split = self.kv_a_proj_with_mqa(
```