[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to the latest commit of the vLLM main branch.
- provide a unified extra-forward-context interface for both model
runner v1 and model runner v2.
- implement graph mode for the main model.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -26,10 +26,10 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||
@@ -242,8 +242,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
"""
|
||||
self.replace_allreduce = replace_allreduce
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
forward_context = get_forward_context()
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
mc2_mask = _EXTRA_CTX.mc2_mask
|
||||
if self.tp_size > 1:
|
||||
# Also slice mc2_mask
|
||||
split_mc2_mask = torch.tensor_split(mc2_mask, self.tp_size, dim=0)
|
||||
@@ -252,7 +251,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
if not self.replace_allreduce:
|
||||
self.num_tokens, _ = hidden_states.shape
|
||||
target_pad_length = forward_context.padded_num_tokens
|
||||
target_pad_length = _EXTRA_CTX.padded_num_tokens
|
||||
pad_size = target_pad_length - self.num_tokens
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
@@ -367,8 +366,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
"""
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
if self.moe_config.dp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
max_tokens_across_dp = _EXTRA_CTX.max_tokens_across_dp
|
||||
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
@@ -381,8 +379,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
forward_context = get_forward_context()
|
||||
max_tokens_across_pcp = forward_context.max_tokens_across_pcp
|
||||
max_tokens_across_pcp = _EXTRA_CTX.max_tokens_across_pcp
|
||||
|
||||
self.num_tokens_pcp = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_pcp - self.num_tokens_pcp
|
||||
|
||||
Reference in New Issue
Block a user