feat: support compile torchair graph while warming up (#839)
### What this PR does / why we need it? feat: support compile torchair graph while warming up Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -36,9 +36,10 @@ from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed import (get_dp_group, get_pp_group,
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -211,8 +212,12 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
if attn_metadata is None:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
# TODO: need a better flag to indicate whether in profile run or not.
|
||||
@@ -547,7 +552,11 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(self.mlp, CustomDeepseekV2MoE):
|
||||
hidden_states = self.mlp(hidden_states, attn_metadata)
|
||||
else:
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
if isinstance(
|
||||
self.mlp,
|
||||
|
||||
Reference in New Issue
Block a user