feat: support compile torchair graph while warming up (#839)

### What this PR does / why we need it?
feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-31 06:03:03 +08:00
committed by GitHub
parent d9fb027068
commit 507ae627ca
7 changed files with 242 additions and 234 deletions

View File

@@ -36,9 +36,10 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed import (get_dp_group, get_pp_group,
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -211,8 +212,12 @@ class CustomDeepseekV2MoE(nn.Module):
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
@@ -547,7 +552,11 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, attn_metadata)
else:
hidden_states = self.mlp(hidden_states)
if isinstance(
self.mlp,