[Misc] Nit fix for disaggregated_prefill and ascend_forward_context (#2097)
we recently added disaggregated_prefill and ascend_forward_context feature byba3dfbd59eanddf0ec55162. This PR fix some nit introduced by them to make the code clear. 1. drop `current_platform` usage. It'll lead unknown circular import error in some case 2. update `set_ascend_forward_context` function to make the logic clear. for example, remove V0 support in this function. 3. Remove useless `self.local_rank_across_dp` in worker 4. Remove `soc_info.py` to use `get_ascend_soc_version` instead. - vLLM version: v0.10.0 - vLLM main:02f82fe438Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -7,9 +7,9 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
@@ -22,8 +22,8 @@ class FusedMoEState(Enum):
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
@@ -73,11 +73,9 @@ def set_ascend_forward_context(
|
||||
is_deepseek_v3_r1 = hasattr(
|
||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
forward_context.fused_moe_state = fused_moe_state
|
||||
|
||||
forward_context.in_profile_run = in_profile_run
|
||||
|
||||
# NOTE: This cannot be set using set_forward_context
|
||||
@@ -85,15 +83,7 @@ def set_ascend_forward_context(
|
||||
forward_context.capturing = False
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
if hasattr(attn_metadata, 'num_actual_tokens'):
|
||||
# for v1 engine
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
else:
|
||||
# for v0 engine
|
||||
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
||||
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
@@ -105,6 +95,8 @@ def set_ascend_forward_context(
|
||||
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
tp_world_size = get_tp_group().world_size
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
@@ -112,7 +104,7 @@ def set_ascend_forward_context(
|
||||
|
||||
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
||||
dtype=torch.bool,
|
||||
device=current_platform.device_type)
|
||||
device=NPUPlatform.device_type)
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
|
||||
|
||||
Reference in New Issue
Block a user