[Misc] Nit fix for disaggregated_prefill and ascend_forward_context (#2097)

We recently added the disaggregated_prefill and ascend_forward_context
features in commits
ba3dfbd59e
and
df0ec55162.
This PR fixes some nits introduced by them to make the code clearer.
1. Drop the `current_platform` usage. It can lead to obscure circular
import errors in some cases.
2. update `set_ascend_forward_context` function to make the logic clear.
for example, remove V0 support in this function.
3. Remove the unused `self.local_rank_across_dp` from the worker.
4. Remove `soc_info.py`; use `get_ascend_soc_version` instead.
 

- vLLM version: v0.10.0
- vLLM main:
02f82fe438

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-08-05 08:39:02 +08:00
committed by GitHub
parent ad366bf908
commit 36e450eb0f
7 changed files with 26 additions and 47 deletions

View File

@@ -4,7 +4,7 @@ import os
import torch.distributed as dist
from vllm_ascend.soc_info import NPUSocInfo
from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version
parser = argparse.ArgumentParser(
description="Arguments of rank table generator", )
@@ -33,7 +33,9 @@ local_rank = os.environ.get("LOCAL_RANK")
# This variable is set by torchrun,
# and is different from WORLD_SIZE in gen_rank_table.sh.
world_size = os.environ.get("WORLD_SIZE")
soc_info = NPUSocInfo()
init_ascend_soc_version()
soc_info = get_ascend_soc_version()
def get_cmd_stdout(cmd):
@@ -59,7 +61,7 @@ if local_rank == "0":
for card_id in range(num_cards):
for chip_id in range(chips_per_card):
device_id = card_id * chips_per_card + chip_id
if soc_info.is_a3:
if soc_info == AscendSocVersion.A3:
device_ip = get_cmd_stdout(
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
).split(":")[1].strip()
@@ -79,7 +81,7 @@ if local_rank == "0":
"device_id": str(device_id),
"device_ip": str(device_ip),
}
if soc_info.is_a3:
if soc_info == AscendSocVersion.A3:
device_info.update({
"super_pod_id": str(super_pod_id),
"super_device_id": str(super_device_id)

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm_ascend.ascend_forward_context import get_fused_moe_state
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
@@ -310,7 +310,7 @@ class TestAscendUnquantizedFusedMoEMethod:
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
@@ -346,7 +346,7 @@ class TestAscendUnquantizedFusedMoEMethod:
ep_size, alltoall_buffer = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
alltoall_buffer), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \

View File

@@ -7,9 +7,9 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.platforms import current_platform
import vllm_ascend.envs as envs
from vllm_ascend.platform import NPUPlatform
class FusedMoEState(Enum):
@@ -22,8 +22,8 @@ class FusedMoEState(Enum):
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
@@ -73,11 +73,9 @@ def set_ascend_forward_context(
is_deepseek_v3_r1 = hasattr(
vllm_config.model_config.hf_config, 'n_routed_experts'
) and vllm_config.model_config.hf_config.n_routed_experts == 256
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
is_deepseek_v3_r1)
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
is_deepseek_v3_r1)
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run
# NOTE: This cannot be set using set_forward_context
@@ -85,15 +83,7 @@ def set_ascend_forward_context(
forward_context.capturing = False
if num_tokens is None and attn_metadata is not None:
if hasattr(attn_metadata, 'num_actual_tokens'):
# for v1 engine
num_tokens = attn_metadata.num_actual_tokens
else:
# for v0 engine
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
if num_actual_tokens is None:
num_actual_tokens = num_tokens
num_tokens = attn_metadata.num_actual_tokens
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
@@ -105,6 +95,8 @@ def set_ascend_forward_context(
forward_context.max_tokens_across_dp = max_tokens_across_dp
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tp_group().world_size
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
@@ -112,7 +104,7 @@ def set_ascend_forward_context(
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
dtype=torch.bool,
device=current_platform.device_type)
device=NPUPlatform.device_type)
mc2_mask[:num_actual_tokens] = True
forward_context.mc2_mask = mc2_mask

View File

@@ -28,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request, RequestStatus
from vllm_ascend import envs
from vllm_ascend.soc_info import NPUSocInfo
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
TORCH_DTYPE_TO_NPU_DTYPE = {
torch.half: llm_datadist.DataType.DT_FLOAT16,
@@ -336,7 +336,7 @@ class LLMDataDistCMgrConnectorWorker():
self.local_agent_metadata.cluster_id)
self.init_llm_datadist()
self.finished_reqs: set[str] = set()
self.soc_info = NPUSocInfo()
self.soc_info = get_ascend_soc_version()
# Set hccl deterministic for model execute
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
@@ -681,7 +681,7 @@ class LLMDataDistCMgrConnectorWorker():
rank_table["server_list"].append( # type: ignore[attr-defined]
decode_server_device_info)
if self.soc_info.is_a3:
if self.soc_info == AscendSocVersion.A3:
# generate super_pod_list for rank table
super_pod_list = []
prefill_super_pod_info = {

View File

@@ -1,14 +0,0 @@
from dataclasses import dataclass
import torch_npu
@dataclass
class NPUSocInfo:
    """Probe the Ascend NPU SoC version via torch_npu at construction time.

    NOTE(review): this diff deletes the module; callers are migrated to
    ``get_ascend_soc_version()`` in ``vllm_ascend.utils``.
    """

    # True when the detected SoC version code belongs to the A3 series.
    is_a3: bool = False

    def __post_init__(self) -> None:
        # Force NPU runtime initialization before querying the SoC —
        # the version query is only valid after lazy init.
        torch_npu.npu._lazy_init()
        # Private torch_npu API returning an integer SoC version code.
        self.soc_version = torch_npu._C._npu_get_soc_version()
        # Version codes 250-255 are treated as A3-series chips —
        # TODO(review): confirm the code range against torch_npu docs.
        if self.soc_version in (250, 251, 252, 253, 254, 255):
            self.is_a3 = True

View File

@@ -479,7 +479,8 @@ def register_ascend_customop():
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN.
# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does.
class AscendSocVersion(Enum):
    """Ascend SoC generations that vllm-ascend distinguishes at runtime.

    The integer values are ordinal tags; code elsewhere compares members
    directly (e.g. ``soc == AscendSocVersion.A3``) rather than by value.
    """

    A2 = 0
    A3 = 1

View File

@@ -71,8 +71,10 @@ class NPUWorker(WorkerBase):
from vllm_ascend import ops
ops.register_dummy_fusion_op()
_register_atb_extensions()
# init ascend config
# init ascend config and soc version
init_ascend_config(vllm_config)
init_ascend_soc_version()
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,
@@ -81,9 +83,6 @@ class NPUWorker(WorkerBase):
is_driver_worker=is_driver_worker)
# Try to import mindie_turbo to accelerate vLLM inference.
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
world_size = self.vllm_config.parallel_config.world_size
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
try_register_lib(
"mindie_turbo",
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
@@ -137,7 +136,6 @@ class NPUWorker(WorkerBase):
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
init_ascend_soc_version()
# Initialize the distributed environment.
self._init_worker_distributed_environment()
# Set random seed.