From 36e450eb0f4fb5e8a262384fb302a5b7dc83b582 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Tue, 5 Aug 2025 08:39:02 +0800
Subject: [PATCH] [Misc] Nit fix for disaggregated_prefill and
 ascend_forward_context (#2097)

We recently added the disaggregated_prefill and ascend_forward_context
features via
https://github.com/vllm-project/vllm-ascend/commit/ba3dfbd59e43b9071895f483d12c034d8538ced0
and
https://github.com/vllm-project/vllm-ascend/commit/df0ec55162339c08b0bfcb79a2dcfac89f8c6e33.
This PR fixes a few nits introduced by those changes to make the code
clearer.

1. Drop the `current_platform` usage. It can lead to obscure circular-import
   errors in some cases.
2. Update the `set_ascend_forward_context` function to make its logic
   clearer; for example, remove V0 support from this function.
3. Remove the unused `self.local_rank_across_dp` from the worker.
4. Remove `soc_info.py` and use `get_ascend_soc_version` instead.

- vLLM version: v0.10.0
- vLLM main: https://github.com/vllm-project/vllm/commit/02f82fe4386b3e84eb0f06bfaf7744c5b4fdba4f

Signed-off-by: wangxiyuan
---
 .../disaggregated_prefill_v1/gen_ranktable.py | 10 ++++---
 tests/ut/ops/test_fused_ops.py                |  6 ++---
 vllm_ascend/ascend_forward_context.py         | 26 +++++++------------
 .../llmdatadist_c_mgr_connector.py            |  6 ++---
 vllm_ascend/soc_info.py                       | 14 ----------
 vllm_ascend/utils.py                          |  3 ++-
 vllm_ascend/worker/worker_v1.py               |  8 +++---
 7 files changed, 26 insertions(+), 47 deletions(-)
 delete mode 100644 vllm_ascend/soc_info.py

diff --git a/examples/disaggregated_prefill_v1/gen_ranktable.py b/examples/disaggregated_prefill_v1/gen_ranktable.py
index d170f3b..52db3ee 100644
--- a/examples/disaggregated_prefill_v1/gen_ranktable.py
+++ b/examples/disaggregated_prefill_v1/gen_ranktable.py
@@ -4,7 +4,7 @@ import os
 
 import torch.distributed as dist
 
-from vllm_ascend.soc_info import NPUSocInfo
+from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version
 
 parser = argparse.ArgumentParser(
     description="Arguments of rank table generator", )
@@ -33,7 +33,9 @@ local_rank = os.environ.get("LOCAL_RANK")
 # This variable is set by torchrun,
 # and is different from WORLD_SIZE in gen_rank_table.sh.
 world_size = os.environ.get("WORLD_SIZE")
-soc_info = NPUSocInfo()
+
+init_ascend_soc_version()
+soc_info = get_ascend_soc_version()
 
 
 def get_cmd_stdout(cmd):
@@ -59,7 +61,7 @@ if local_rank == "0":
     for card_id in range(num_cards):
         for chip_id in range(chips_per_card):
             device_id = card_id * chips_per_card + chip_id
-            if soc_info.is_a3:
+            if soc_info == AscendSocVersion.A3:
                 device_ip = get_cmd_stdout(
                     f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
                 ).split(":")[1].strip()
@@ -79,7 +81,7 @@ if local_rank == "0":
                 "device_id": str(device_id),
                 "device_ip": str(device_ip),
             }
-            if soc_info.is_a3:
+            if soc_info == AscendSocVersion.A3:
                 device_info.update({
                     "super_pod_id": str(super_pod_id),
                     "super_device_id": str(super_device_id)
diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 2a4bd59..6c89f6f 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -21,7 +21,7 @@ import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
 
-from vllm_ascend.ascend_forward_context import get_fused_moe_state
+from vllm_ascend.ascend_forward_context import _get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
 from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
@@ -310,7 +310,7 @@ class TestAscendUnquantizedFusedMoEMethod:
         global_num_experts, ep_size = others_param
         is_prefill = False
         is_deepseek_v3_r1 = global_num_experts == 256
-        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
+        forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
             ep_size, is_prefill, is_deepseek_v3_r1))
         with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                    return_value=forward_context):
@@ -346,7 +346,7 @@ class TestAscendUnquantizedFusedMoEMethod:
         ep_size, alltoall_buffer = others_param
         is_prefill = False
         forward_context = MagicMock(
-            fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
+            fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
         with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
                    alltoall_buffer), \
             patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context), \
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 2d08079..c862534 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -7,9 +7,9 @@ import torch
 from vllm.config import VllmConfig
 from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context, set_forward_context
-from vllm.platforms import current_platform
 
 import vllm_ascend.envs as envs
+from vllm_ascend.platform import NPUPlatform
 
 
 class FusedMoEState(Enum):
@@ -22,8 +22,8 @@ class FusedMoEState(Enum):
 
 
 # TODO(zzzzwwjj): add soc_version to choose branch
-def get_fused_moe_state(ep_size: int, with_prefill: bool,
-                        is_deepseek_v3_r1: bool):
+def _get_fused_moe_state(ep_size: int, with_prefill: bool,
+                         is_deepseek_v3_r1: bool):
     # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
     # only supports deepseek v3/r1
     if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
@@ -73,11 +73,9 @@ def set_ascend_forward_context(
         is_deepseek_v3_r1 = hasattr(
             vllm_config.model_config.hf_config, 'n_routed_experts'
         ) and vllm_config.model_config.hf_config.n_routed_experts == 256
-        fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
-                                              is_deepseek_v3_r1)
-
+        fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
+                                               is_deepseek_v3_r1)
         forward_context.fused_moe_state = fused_moe_state
-
         forward_context.in_profile_run = in_profile_run
 
         # NOTE: This cannot be set using set_forward_context
@@ -85,15 +83,7 @@ def set_ascend_forward_context(
         forward_context.capturing = False
 
         if num_tokens is None and attn_metadata is not None:
-            if hasattr(attn_metadata, 'num_actual_tokens'):
-                # for v1 engine
-                num_tokens = attn_metadata.num_actual_tokens
-            else:
-                # for v0 engine
-                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
-
-        if num_actual_tokens is None:
-            num_actual_tokens = num_tokens
+            num_tokens = attn_metadata.num_actual_tokens
 
         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
@@ -105,6 +95,8 @@ def set_ascend_forward_context(
             forward_context.max_tokens_across_dp = max_tokens_across_dp
 
         if num_tokens is not None:
+            if num_actual_tokens is None:
+                num_actual_tokens = num_tokens
             tp_world_size = get_tp_group().world_size
             # NOTE: token num which need to pad to when mc2
             forward_context.padded_num_tokens = math.ceil(
@@ -112,7 +104,7 @@ def set_ascend_forward_context(
 
             mc2_mask = torch.zeros(forward_context.padded_num_tokens,
                                    dtype=torch.bool,
-                                   device=current_platform.device_type)
+                                   device=NPUPlatform.device_type)
             mc2_mask[:num_actual_tokens] = True
             forward_context.mc2_mask = mc2_mask
 
diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index 84b2435..7631a09 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -28,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import Request, RequestStatus
 
 from vllm_ascend import envs
-from vllm_ascend.soc_info import NPUSocInfo
+from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
 
 TORCH_DTYPE_TO_NPU_DTYPE = {
     torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -336,7 +336,7 @@ class LLMDataDistCMgrConnectorWorker():
                           self.local_agent_metadata.cluster_id)
         self.init_llm_datadist()
         self.finished_reqs: set[str] = set()
-        self.soc_info = NPUSocInfo()
+        self.soc_info = get_ascend_soc_version()
         # Set hccl deterministic for model execute
         os.environ["HCCL_DETERMINISTIC"] = "true"
         self.done_receiving_counts: defaultdict[str,
@@ -681,7 +681,7 @@ class LLMDataDistCMgrConnectorWorker():
             rank_table["server_list"].append(  # type: ignore[attr-defined]
                 decode_server_device_info)
 
-        if self.soc_info.is_a3:
+        if self.soc_info == AscendSocVersion.A3:
             # generate super_pod_list for rank table
             super_pod_list = []
             prefill_super_pod_info = {
diff --git a/vllm_ascend/soc_info.py b/vllm_ascend/soc_info.py
deleted file mode 100644
index ac1317e..0000000
--- a/vllm_ascend/soc_info.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from dataclasses import dataclass
-
-import torch_npu
-
-
-@dataclass
-class NPUSocInfo:
-    is_a3: bool = False
-
-    def __post_init__(self):
-        torch_npu.npu._lazy_init()
-        self.soc_version = torch_npu._C._npu_get_soc_version()
-        if self.soc_version in (250, 251, 252, 253, 254, 255):
-            self.is_a3 = True
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a5e4984..ee620b4 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -479,7 +479,8 @@ def register_ascend_customop():
     _ASCEND_CUSTOMOP_IS_REIGISTERED = True
 
 
-# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
+# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN,
+# so we get the version dynamically. In the future, we should read the version info from _build_info, as 310P does.
 class AscendSocVersion(Enum):
     A2 = 0
     A3 = 1
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 65d2f51..c13238a 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -71,8 +71,10 @@ class NPUWorker(WorkerBase):
         from vllm_ascend import ops
         ops.register_dummy_fusion_op()
         _register_atb_extensions()
-        # init ascend config
+
+        # init ascend config and soc version
         init_ascend_config(vllm_config)
+        init_ascend_soc_version()
 
         super().__init__(vllm_config=vllm_config,
                          local_rank=local_rank,
@@ -81,9 +83,6 @@ class NPUWorker(WorkerBase):
                          is_driver_worker=is_driver_worker)
 
         # Try to import mindie_turbo to accelerate vLLM inference.
-        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
-        world_size = self.vllm_config.parallel_config.world_size
-        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
         try_register_lib(
             "mindie_turbo",
             "MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
@@ -137,7 +136,6 @@ class NPUWorker(WorkerBase):
         NPUPlatform.empty_cache()
         self.init_npu_memory = NPUPlatform.mem_get_info()[0]
 
-        init_ascend_soc_version()
         # Initialize the distributed environment.
         self._init_worker_distributed_environment()
         # Set random seed.
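
For reviewers, here is the replacement pattern from item 4 in one place: a minimal sketch of the new SoC-version API that supersedes `NPUSocInfo`. The calls mirror the hunks above; anything not shown in this patch (such as further enum members) is intentionally left out.

```python
# Sketch only: mirrors the calls introduced by this patch.
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
                               init_ascend_soc_version)

# Must run once per process before any lookup; NPUWorker.__init__ now does
# this right after init_ascend_config(vllm_config), instead of deferring it
# to device initialization.
init_ascend_soc_version()

# Old: soc_info = NPUSocInfo(); if soc_info.is_a3: ...
# New: compare the cached enum value directly.
if get_ascend_soc_version() == AscendSocVersion.A3:
    # A3-only work, e.g. filling super_pod_id / super_device_id
    # in the generated rank table.
    pass
```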
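Item 2 moves the `num_actual_tokens` fallback inside the `if num_tokens is not None:` branch, next to the mc2 padding. Below is a minimal sketch of that flow. Note the assumptions: the hunk above is truncated right after `math.ceil(`, so the round-up-to-a-multiple-of-`tp_world_size` expression is an inference from the surrounding code, and the helper name `build_mc2_mask` is hypothetical.

```python
import math

import torch


def build_mc2_mask(num_tokens: int, num_actual_tokens: int,
                   tp_world_size: int, device: str) -> torch.Tensor:
    # Assumed padding rule: round num_tokens up to a multiple of
    # tp_world_size (the diff cuts off after `math.ceil(`).
    padded_num_tokens = math.ceil(num_tokens / tp_world_size) * tp_world_size
    # True for real tokens, False for padding, as in the hunk above.
    mc2_mask = torch.zeros(padded_num_tokens, dtype=torch.bool, device=device)
    mc2_mask[:num_actual_tokens] = True
    return mc2_mask


mask = build_mc2_mask(10, 7, 4, "cpu")  # 12 slots; the first 7 are True
```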
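Since `get_fused_moe_state` becomes the private `_get_fused_moe_state`, the unit tests call it directly and fake the forward context rather than going through `set_ascend_forward_context`. A condensed sketch of the pattern from `tests/ut/ops/test_fused_ops.py`; the parameter values here are illustrative only.

```python
from unittest.mock import MagicMock, patch

from vllm_ascend.ascend_forward_context import _get_fused_moe_state

# Build a fake forward context carrying the fused MoE state under test.
ep_size, is_prefill, is_deepseek_v3_r1 = 4, False, True
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
    ep_size, is_prefill, is_deepseek_v3_r1))

# The op reads the state via get_forward_context(), so patch it to return
# the fake context.
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
           return_value=forward_context):
    ...  # invoke AscendUnquantizedFusedMoEMethod.apply and assert on outputs
```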