[Misc] Nit fix for disaggregated_prefill and ascend_forward_context (#2097)
we recently added disaggregated_prefill and ascend_forward_context feature byba3dfbd59eanddf0ec55162. This PR fix some nit introduced by them to make the code clear. 1. drop `current_platform` usage. It'll lead unknown circular import error in some case 2. update `set_ascend_forward_context` function to make the logic clear. for example, remove V0 support in this function. 3. Remove useless `self.local_rank_across_dp` in worker 4. Remove `soc_info.py` to use `get_ascend_soc_version` instead. - vLLM version: v0.10.0 - vLLM main:02f82fe438Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm_ascend.soc_info import NPUSocInfo
|
||||
from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Arguments of rank table generator", )
|
||||
@@ -33,7 +33,9 @@ local_rank = os.environ.get("LOCAL_RANK")
|
||||
# This variable is set by torchrun,
|
||||
# and is different from WORLD_SIZE in gen_rank_table.sh.
|
||||
world_size = os.environ.get("WORLD_SIZE")
|
||||
soc_info = NPUSocInfo()
|
||||
|
||||
init_ascend_soc_version()
|
||||
soc_info = get_ascend_soc_version()
|
||||
|
||||
|
||||
def get_cmd_stdout(cmd):
|
||||
@@ -59,7 +61,7 @@ if local_rank == "0":
|
||||
for card_id in range(num_cards):
|
||||
for chip_id in range(chips_per_card):
|
||||
device_id = card_id * chips_per_card + chip_id
|
||||
if soc_info.is_a3:
|
||||
if soc_info == AscendSocVersion.A3:
|
||||
device_ip = get_cmd_stdout(
|
||||
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
|
||||
).split(":")[1].strip()
|
||||
@@ -79,7 +81,7 @@ if local_rank == "0":
|
||||
"device_id": str(device_id),
|
||||
"device_ip": str(device_ip),
|
||||
}
|
||||
if soc_info.is_a3:
|
||||
if soc_info == AscendSocVersion.A3:
|
||||
device_info.update({
|
||||
"super_pod_id": str(super_pod_id),
|
||||
"super_device_id": str(super_device_id)
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from vllm_ascend.ascend_forward_context import get_fused_moe_state
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
@@ -310,7 +310,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
@@ -346,7 +346,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
ep_size, alltoall_buffer = others_param
|
||||
is_prefill = False
|
||||
forward_context = MagicMock(
|
||||
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
|
||||
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
|
||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||
alltoall_buffer), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
|
||||
@@ -7,9 +7,9 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
@@ -22,8 +22,8 @@ class FusedMoEState(Enum):
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
@@ -73,11 +73,9 @@ def set_ascend_forward_context(
|
||||
is_deepseek_v3_r1 = hasattr(
|
||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
|
||||
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
forward_context.fused_moe_state = fused_moe_state
|
||||
|
||||
forward_context.in_profile_run = in_profile_run
|
||||
|
||||
# NOTE: This cannot be set using set_forward_context
|
||||
@@ -85,15 +83,7 @@ def set_ascend_forward_context(
|
||||
forward_context.capturing = False
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
if hasattr(attn_metadata, 'num_actual_tokens'):
|
||||
# for v1 engine
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
else:
|
||||
# for v0 engine
|
||||
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
||||
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
num_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
dp_world_size = get_dp_group().world_size
|
||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||
@@ -105,6 +95,8 @@ def set_ascend_forward_context(
|
||||
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
tp_world_size = get_tp_group().world_size
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
@@ -112,7 +104,7 @@ def set_ascend_forward_context(
|
||||
|
||||
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
||||
dtype=torch.bool,
|
||||
device=current_platform.device_type)
|
||||
device=NPUPlatform.device_type)
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
from vllm_ascend import envs
|
||||
from vllm_ascend.soc_info import NPUSocInfo
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
@@ -336,7 +336,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
self.local_agent_metadata.cluster_id)
|
||||
self.init_llm_datadist()
|
||||
self.finished_reqs: set[str] = set()
|
||||
self.soc_info = NPUSocInfo()
|
||||
self.soc_info = get_ascend_soc_version()
|
||||
# Set hccl deterministic for model execute
|
||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||
self.done_receiving_counts: defaultdict[str,
|
||||
@@ -681,7 +681,7 @@ class LLMDataDistCMgrConnectorWorker():
|
||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||
decode_server_device_info)
|
||||
|
||||
if self.soc_info.is_a3:
|
||||
if self.soc_info == AscendSocVersion.A3:
|
||||
# generate super_pod_list for rank table
|
||||
super_pod_list = []
|
||||
prefill_super_pod_info = {
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch_npu
|
||||
|
||||
|
||||
@dataclass
|
||||
class NPUSocInfo:
|
||||
is_a3: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
torch_npu.npu._lazy_init()
|
||||
self.soc_version = torch_npu._C._npu_get_soc_version()
|
||||
if self.soc_version in (250, 251, 252, 253, 254, 255):
|
||||
self.is_a3 = True
|
||||
@@ -479,7 +479,8 @@ def register_ascend_customop():
|
||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
|
||||
# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN.
|
||||
# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does.
|
||||
class AscendSocVersion(Enum):
|
||||
A2 = 0
|
||||
A3 = 1
|
||||
|
||||
@@ -71,8 +71,10 @@ class NPUWorker(WorkerBase):
|
||||
from vllm_ascend import ops
|
||||
ops.register_dummy_fusion_op()
|
||||
_register_atb_extensions()
|
||||
# init ascend config
|
||||
|
||||
# init ascend config and soc version
|
||||
init_ascend_config(vllm_config)
|
||||
init_ascend_soc_version()
|
||||
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
local_rank=local_rank,
|
||||
@@ -81,9 +83,6 @@ class NPUWorker(WorkerBase):
|
||||
is_driver_worker=is_driver_worker)
|
||||
|
||||
# Try to import mindie_turbo to accelerate vLLM inference.
|
||||
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
|
||||
world_size = self.vllm_config.parallel_config.world_size
|
||||
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
|
||||
try_register_lib(
|
||||
"mindie_turbo",
|
||||
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
|
||||
@@ -137,7 +136,6 @@ class NPUWorker(WorkerBase):
|
||||
NPUPlatform.empty_cache()
|
||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||
|
||||
init_ascend_soc_version()
|
||||
# Initialize the distributed environment.
|
||||
self._init_worker_distributed_environment()
|
||||
# Set random seed.
|
||||
|
||||
Reference in New Issue
Block a user