[1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it? This is the step 1 of refactoring code to adapt with vllm main, and this pr aligned with17c540a9931. refactor deepseek to the latest code arch as of17c540a9932. bunches of fixes due to vllm changes - Fix `AscendScheduler` `__post_init__`, caused by https://github.com/vllm-project/vllm/pull/25075 - Fix `AscendScheduler` init got an unexpected arg `block_size`, caused by https://github.com/vllm-project/vllm/pull/26296 - Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by https://github.com/vllm-project/vllm/pull/23485 - Fix `MLAAttention` import,caused by https://github.com/vllm-project/vllm/pull/25103 - Fix `SharedFusedMoE` import, caused by https://github.com/vllm-project/vllm/pull/26145 - Fix `LazyLoader` improt, caused by https://github.com/vllm-project/vllm/pull/27022 - Fix `vllm.utils.swap_dict_values` improt, caused by https://github.com/vllm-project/vllm/pull/26990 - Fix `Backend` enum import, caused by https://github.com/vllm-project/vllm/pull/25893 - Fix `CompilationLevel` renaming to `CompilationMode` issue introduced by https://github.com/vllm-project/vllm/pull/26355 - Fix fused_moe ops, caused by https://github.com/vllm-project/vllm/pull/24097 - Fix bert model because of `inputs_embeds`, caused by https://github.com/vllm-project/vllm/pull/25922 - Fix MRope because of `get_input_positions_tensor` to `get_mrope_input_positions`, caused by https://github.com/vllm-project/vllm/pull/24172 - Fix `splitting_ops` changes introduced by https://github.com/vllm-project/vllm/pull/25845 - Fix multi-modality changes introduced by https://github.com/vllm-project/vllm/issues/16229 - Fix lora bias dropping issue introduced by https://github.com/vllm-project/vllm/pull/25807 - Fix structured ouput break introduced by https://github.com/vllm-project/vllm/issues/26737 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? CI passed with existing test. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: MengqingCao <cmq0113@163.com> Signed-off-by: Icey <1790571317@qq.com> Co-authored-by: Icey <1790571317@qq.com>
This commit is contained in:
@@ -42,6 +42,14 @@ else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.attention import Attention
|
||||
from vllm.model_executor.layers.mla import \
|
||||
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
|
||||
else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
|
||||
|
||||
|
||||
# TODO(whx): adapt v0.11.0 and DSA
|
||||
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
@@ -107,22 +115,20 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
|
||||
)
|
||||
else:
|
||||
self.mla_attn = MLAAttention(
|
||||
num_heads=self.num_heads,
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
indexer=mla_modules.indexer,
|
||||
# extra args
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
|
||||
q_b_proj=mla_modules.q_b_proj,
|
||||
|
||||
@@ -24,18 +24,29 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.mla import MLAModules
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.11.0"):
|
||||
from vllm.attention import Attention
|
||||
from vllm.model_executor.layers.mla import \
|
||||
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
|
||||
else:
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAModules:
|
||||
q_a_proj: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: torch.nn.Module
|
||||
@@ -44,73 +55,103 @@ class AscendSFAModules:
|
||||
o_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
indexer: torch.nn.Module
|
||||
is_sparse: bool
|
||||
fused_qkv_a_proj: Optional[torch.nn.Module]
|
||||
q_b_proj: Optional[torch.nn.Module]
|
||||
topk_indices_buffer: Optional[torch.Tensor]
|
||||
|
||||
|
||||
class AscendSparseFlashAttention(MultiHeadLatentAttention):
|
||||
class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
enable_shared_expert_dp: bool,
|
||||
debug_layer_idx: int,
|
||||
first_k_dense_replace: int,
|
||||
tp_size: int,
|
||||
sfa_modules: AscendSFAModules,
|
||||
num_local_heads: int,
|
||||
scaling: float,
|
||||
layers: int,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
num_heads: int,
|
||||
scale: float,
|
||||
qk_nope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
mla_modules: MLAModules,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
self.debug_layer_idx = debug_layer_idx
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.tp_size = tp_size
|
||||
self.num_local_heads = num_local_heads
|
||||
self.layers = layers
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
self.scaling = scale
|
||||
self.indexer = mla_modules.indexer
|
||||
self.is_sparse = mla_modules.is_sparse
|
||||
hf_config = get_current_vllm_config().model_config.hf_config
|
||||
self.enable_shared_expert_dp = get_ascend_config(
|
||||
).enable_shared_expert_dp
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
self.first_k_dense_replace = hf_config.first_k_dense_replace
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.layers = hf_config.num_hidden_layers
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=sfa_modules.rotary_emb,
|
||||
q_a_proj=sfa_modules.q_a_proj,
|
||||
q_a_layernorm=sfa_modules.q_a_layernorm,
|
||||
q_proj=sfa_modules.q_proj,
|
||||
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=sfa_modules.kv_a_layernorm,
|
||||
kv_b_proj=sfa_modules.kv_b_proj,
|
||||
o_proj=sfa_modules.o_proj,
|
||||
indexer=sfa_modules.indexer)
|
||||
if vllm_version_is("0.11.0"):
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=num_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sparse=True,
|
||||
indexer=self.indexer,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
|
||||
q_b_proj=mla_modules.q_b_proj,
|
||||
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||
q_proj=mla_modules.q_proj,
|
||||
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
o_proj=mla_modules.o_proj,
|
||||
)
|
||||
else:
|
||||
self.sfa_attn = MLAAttention(
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
kv_b_proj=mla_modules.kv_b_proj,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_sparse=mla_modules.is_sparse,
|
||||
indexer=mla_modules.indexer,
|
||||
# extra args
|
||||
rotary_emb=mla_modules.rotary_emb,
|
||||
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
|
||||
q_b_proj=mla_modules.q_b_proj,
|
||||
q_a_layernorm=mla_modules.q_a_layernorm,
|
||||
q_proj=mla_modules.q_proj,
|
||||
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=mla_modules.kv_a_layernorm,
|
||||
o_proj=mla_modules.o_proj,
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
|
||||
Reference in New Issue
Block a user