support deepseek quant & mix-parallel with graphmode (#585)

### What this PR does / why we need it?
1. support deepseek with w8a8 quant;
2. support deepseek with mix-parallel(multi-DP, EP+TP);
3. support deepseek with graphmode.
---------

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
This commit is contained in:
zzzzwwjj
2025-04-23 16:23:25 +08:00
committed by GitHub
parent e74331a1ed
commit 5c6d05a59e
13 changed files with 520 additions and 221 deletions

View File

@@ -26,13 +26,13 @@
# """Inference-only DeepseekV2/DeepseekV3 model."""
import os
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import (
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE
class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module):
vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1
params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
if VLLM_ENABLE_GRAPH_MODE == "1":
self.forward = self.forward_torchair
else:
self.forward = self.forward_eager # type: ignore
self.enable_graph_mode = False
additional_config = get_current_vllm_config().additional_config
if additional_config:
self.enable_graph_mode = additional_config.get(
"enable_graph_mode", False)
def forward_torchair(self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor = None,
attn_metadata=None):
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
kv_cache, attn_metadata)
def forward_eager(self, positions: torch.Tensor,
hidden_states: torch.Tensor):
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
if self.enable_graph_mode:
return self.mla_attn.impl.forward(self.mla_attn,
hidden_states_or_q_c,
hidden_states, None, kv_cache,
attn_metadata)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
if hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
class CustomDeepseekV2Model(nn.Module):
@@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(positions, hidden_states, residual)
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
kv_caches[i -
self.start_layer] if kv_caches is not None else None,
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
@@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[List[torch.Tensor]] = None,
attn_metadata: Optional[AttentionMetadata] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass