support deepseek quant & mix-parallel with graphmode (#585)
### What this PR does / why we need it?

1. Support DeepSeek with W8A8 quantization.
2. Support DeepSeek with mixed parallelism (multiple data-parallel groups, EP + TP).
3. Support DeepSeek with graph mode.

Signed-off-by: wen-jie666 <wenjie39@huawei.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wen-jie666 <wenjie39@huawei.com>
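For context, a minimal launch sketch showing how the three features are switched on together. This is an illustrative assumption, not part of the commit: the model path and the `tensor_parallel_size` / `enable_expert_parallel` arguments are placeholders whose availability depends on the vllm / vllm-ascend versions in use; the code in this PR itself only reads `VLLM_ENABLE_MC2` from the environment and `enable_graph_mode` from `additional_config`.

```python
# Hypothetical usage sketch, not part of this commit. Only VLLM_ENABLE_MC2
# and additional_config["enable_graph_mode"] are consumed by the code in this
# PR; the other arguments are placeholders for a typical mixed-parallel setup.
import os

from vllm import LLM, SamplingParams

# Parsed in CustomDeepseekV2MoE.__init__ via
# int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1.
os.environ["VLLM_ENABLE_MC2"] = "1"

llm = LLM(
    model="/path/to/deepseek-v3-w8a8",  # placeholder: a W8A8-quantized checkpoint
    tensor_parallel_size=8,             # TP within each data-parallel replica
    enable_expert_parallel=True,        # EP for the MoE experts (EP + TP mix)
    # Read in CustomDeepseekV2MLAAttention.__init__ to select the graph-mode
    # attention path (mla_attn.impl.forward) instead of the eager one.
    additional_config={"enable_graph_mode": True},
)

out = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```

When graph mode is off, the attention module instead splits the latent KV projection and calls `mla_attn` directly, as the diff below shows.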
@@ -26,13 +26,13 @@
 # """Inference-only DeepseekV2/DeepseekV3 model."""

 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as dist
 from torch import nn
 from transformers import PretrainedConfig

-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.distributed import (get_dp_group, get_pp_group,
@@ -64,7 +64,6 @@ from vllm.model_executor.models.utils import (
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import VLLM_ENABLE_GRAPH_MODE


 class CustomDeepseekV2MoE(nn.Module):
@@ -133,7 +132,7 @@ class CustomDeepseekV2MoE(nn.Module):
         vllm_config = get_current_vllm_config()
         self.dp_size = get_dp_group().world_size
         batch_size = vllm_config.scheduler_config.max_num_seqs
-        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", 0)) == 1
+        self.enable_mc2 = int(os.environ.get("VLLM_ENABLE_MC2", '0')) == 1

         params_dtype = torch.get_default_dtype()
         self.final_hidden_states = torch.zeros(
@@ -309,38 +308,36 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):

         self.prefix = prefix
         self.debug_layer_idx = int(self.prefix.split(".")[-2])
-        if VLLM_ENABLE_GRAPH_MODE == "1":
-            self.forward = self.forward_torchair
-        else:
-            self.forward = self.forward_eager  # type: ignore
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)

-    def forward_torchair(self,
-                         positions: torch.Tensor,
-                         hidden_states: torch.Tensor,
-                         kv_cache: torch.Tensor = None,
-                         attn_metadata=None):
+    def forward(
+            self,
+            positions: torch.Tensor,
+            hidden_states: torch.Tensor,
+            kv_cache: Optional[torch.Tensor] = None,
+            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
         if self.q_lora_rank is not None:
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
         else:
             hidden_states_or_q_c = hidden_states
-        return self.mla_attn(hidden_states_or_q_c, hidden_states, None,
-                             kv_cache, attn_metadata)
-
-    def forward_eager(self, positions: torch.Tensor,
-                      hidden_states: torch.Tensor):
-        if self.q_lora_rank is not None:
-            ckq = self.q_a_proj(hidden_states)[0]
-            hidden_states_or_q_c = self.q_a_layernorm(ckq)
+        if self.enable_graph_mode:
+            return self.mla_attn.impl.forward(self.mla_attn,
+                                              hidden_states_or_q_c,
+                                              hidden_states, None, kv_cache,
+                                              attn_metadata)
         else:
-            hidden_states_or_q_c = hidden_states
-        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
-            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
-        return self.mla_attn(hidden_states_or_q_c,
-                             kv_c_normed,
-                             k_pe,
-                             output_shape=hidden_states.shape)
+            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            return self.mla_attn(hidden_states_or_q_c,
+                                 kv_c_normed,
+                                 k_pe,
+                                 output_shape=hidden_states.shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
@@ -408,6 +405,54 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                             eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor

+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        kv_cache: Optional[torch.Tensor] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # We scale both hidden_states and residual before
+            # rmsnorm, and rmsnorm result would not affect by scale.
+            hidden_states *= 1. / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers, we only scale it on
+                # first layer.
+                residual *= 1. / self.routed_scaling_factor
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+
+        if isinstance(self.mlp,
+                      DeepseekV2MLP) and hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow
+            # Scaling the DeepseekV2MLP output, it is the input of
+            # input_layernorm of next decoder layer.
+            # The scaling of DeepseekV2MOE output would be done in the forward
+            # of DeepseekV2MOE
+            hidden_states *= 1. / self.routed_scaling_factor
+
+        return hidden_states, residual
+

 class CustomDeepseekV2Model(nn.Module):

@@ -459,7 +504,9 @@ class CustomDeepseekV2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors],
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
@@ -473,8 +520,13 @@ class CustomDeepseekV2Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(positions, hidden_states, residual)
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, residual,
+                kv_caches[i -
+                          self.start_layer] if kv_caches is not None else None,
+                attn_metadata)

         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
@@ -514,6 +566,20 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: Optional[List[torch.Tensor]] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+

 class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
     pass