[1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it? This is the step 1 of refactoring code to adapt with vllm main, and this pr aligned with17c540a9931. refactor deepseek to the latest code arch as of17c540a9932. bunches of fixes due to vllm changes - Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075 - Fix `AscendScheduler` init got an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296 - Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485 - Fix `MLAAttention` import,caused by https://github.com/vllm-project/vllm/pull/25103 - Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145 - Fix `LazyLoader` improt, caused by https://github.com/vllm-project/vllm/pull/27022 - Fix `vllm.utils.swap_dict_values` improt, caused by https://github.com/vllm-project/vllm/pull/26990 - Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893 - Fix `CompilationLevel` renaming to `CompilationMode` issue introduced by https://github.com/vllm-project/vllm/pull/26355 - Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097 - Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922 - Fix MRope because of `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172 - Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845 - Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229 - Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807 - Fix structured ouput break introduced by https://github.com/vllm-project/vllm/issues/26737 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? CI passed with existing test. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -31,7 +31,7 @@ import torch
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -75,7 +75,12 @@ from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
TorchairAscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
|
||||
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable, vllm_version_is
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.attention import Attention
|
||||
else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
|
||||
|
||||
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
|
||||
@@ -561,30 +566,65 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
||||
# i.e.
|
||||
# kv_lora_rank + qk_rope_head_dim == head_size
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
# MLA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else None,
|
||||
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
if vllm_version_is("0.11.0"):
|
||||
self.mla_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=False,
|
||||
indexer=None,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj
|
||||
if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
decoder_layer=decoder_layer,
|
||||
)
|
||||
else:
|
||||
self.mla_attn = MLAAttention(
|
||||
num_heads=self.num_local_heads,
|
||||
scale=self.scaling,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_sparse=False,
|
||||
indexer=None,
|
||||
# MLA Args
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj
|
||||
if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -791,35 +831,65 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
indexer=self.indexer,
|
||||
decoder_layer=decoder_layer,
|
||||
)
|
||||
if vllm_version_is("0.11.0"):
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=True,
|
||||
indexer=self.indexer,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj
|
||||
if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
decoder_layer=decoder_layer,
|
||||
)
|
||||
else:
|
||||
self.sfa_attn = MLAAttention(
|
||||
num_heads=self.num_local_heads,
|
||||
scale=self.scaling,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_sparse=True,
|
||||
indexer=self.indexer,
|
||||
# MLA Args
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj
|
||||
if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user