Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it? This PR added the initial DeepSeek V3.2 support with [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete vLLM adaptation as soon as possible. This feature will be ready in recent 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 . ### Does this PR introduce _any_ user-facing change? Yes! ### How was this patch tested? CI passed and Run deepseek doc soon. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wxsIcey <1790571317@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -67,7 +67,9 @@ from vllm.model_executor.models.utils import (
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.layers.sfa import Indexer
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
@@ -435,6 +437,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
decoder_layer=None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
@@ -630,6 +633,225 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
output_shape=output_shape)
|
||||
|
||||
|
||||
class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
decoder_layer=None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
|
||||
self.num_heads = num_heads
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
assert num_heads % self.tp_size == 0
|
||||
self.num_local_heads = num_heads // self.tp_size
|
||||
self.layers = config.num_hidden_layers
|
||||
self.first_k_dense_replace = config.first_k_dense_replace
|
||||
|
||||
self.scaling = self.qk_head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.prefix = prefix
|
||||
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
self.q_a_proj = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.q_lora_rank,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_a_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.q_b_proj = ColumnParallelLinear(
|
||||
q_lora_rank,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
else:
|
||||
self.q_proj = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.num_heads * self.qk_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_a_proj_with_mqa",
|
||||
return_bias=False,
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
||||
eps=config.rms_norm_eps)
|
||||
self.kv_b_proj = ColumnParallelLinear(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.kv_b_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and self.debug_layer_idx >= config.first_k_dense_replace
|
||||
and self.debug_layer_idx % config.moe_layer_freq == 0
|
||||
and (ascend_config.multistream_overlap_shared_expert
|
||||
or self.enable_shared_expert_dp)):
|
||||
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
else:
|
||||
self.o_proj = TorchairDeepseekV2RowParallelLinear(
|
||||
self.num_heads * self.v_head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = 'deepseek_yarn'
|
||||
self.rotary_emb = get_rope(qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
is_neox_style=False)
|
||||
if rope_scaling:
|
||||
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
self.dim: int = config.hidden_size # 7168
|
||||
# TODO(zzzzwwjj): wait transformers add these params
|
||||
self.n_heads: int = 64 # 64
|
||||
self.head_dim: int = 128 # 128
|
||||
self.index_topk: int = 2048 # 2048
|
||||
self.indexer = Indexer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
dim=self.dim,
|
||||
n_heads=self.n_heads,
|
||||
head_dim=self.head_dim,
|
||||
index_topk=self.index_topk,
|
||||
prefix=f"{prefix}.indexer",
|
||||
)
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=self.scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sfa=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=self.rotary_emb,
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
indexer=self.indexer,
|
||||
decoder_layer=decoder_layer,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
if not self.torchair_graph_enabled:
|
||||
if forward_context.attn_metadata is not None and isinstance(
|
||||
forward_context.attn_metadata, dict):
|
||||
attn_metadata = next(
|
||||
iter(forward_context.attn_metadata.values()), None)
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if kv_cache is None:
|
||||
kv_cache = self.sfa_attn.kv_cache[
|
||||
forward_context.virtual_engine]
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
# if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||
# # Simulate all gather to calculate output shape
|
||||
# num_tokens = num_tokens * self.tp_size
|
||||
# need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
if self.enable_shared_expert_dp and (
|
||||
self.debug_layer_idx == self.first_k_dense_replace
|
||||
or self.debug_layer_idx == self.layers):
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv, output)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
@@ -654,9 +876,16 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
ascend_config = get_ascend_config()
|
||||
self.use_mla = False
|
||||
self.use_sfa = False
|
||||
# TODO: enable mla in vllm-ascend
|
||||
if model_config.use_mla:
|
||||
attn_cls = TorchairDeepseekV2MLAAttention
|
||||
if ascend_config.use_sfa:
|
||||
attn_cls = TorchairDeepseekV2SFAAttention
|
||||
self.use_sfa = True
|
||||
else:
|
||||
attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment]
|
||||
self.use_mla = True
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
@@ -675,6 +904,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
decoder_layer=self,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
@@ -715,21 +945,34 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
replace_allreduce: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if attn_metadata is not None and attn_metadata.num_decodes > 0:
|
||||
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
||||
if attn_metadata is not None:
|
||||
decoding_condition_met = (
|
||||
not attn_metadata.is_prefill if self.use_sfa else
|
||||
attn_metadata.num_decodes > 0 if self.use_mla else False)
|
||||
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
|
||||
else:
|
||||
mla_moe_communication = False
|
||||
if residual is None:
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if (envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
|
||||
and attn_metadata is not None
|
||||
and not forward_context.with_prefill):
|
||||
if residual is not None:
|
||||
hidden_states = hidden_states + residual
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
previous_hidden_states, previous_residual = hidden_states, residual
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
# Dispose hidden_states and residual from the previous layer
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(previous_hidden_states)
|
||||
dispose_tensor(previous_residual)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
previous_hidden_states, previous_residual = hidden_states, residual
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
# Dispose hidden_states and residual from the previous layer
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(previous_hidden_states)
|
||||
dispose_tensor(previous_residual)
|
||||
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
|
||||
hidden_states = tensor_model_parallel_all_gather(hidden_states,
|
||||
dim=0)
|
||||
|
||||
Reference in New Issue
Block a user