[Misc] Nit fix for disaggregated_prefill and ascend_forward_context (#2097)
We recently added the disaggregated_prefill and ascend_forward_context features in ba3dfbd59e and df0ec55162. This PR fixes some nits introduced by them to make the code clearer: 1. Drop `current_platform` usage — it can lead to circular import errors in some cases. 2. Update the `set_ascend_forward_context` function to make the logic clear; for example, remove V0 support from this function. 3. Remove the unused `self.local_rank_across_dp` in the worker. 4. Remove `soc_info.py` and use `get_ascend_soc_version` instead. - vLLM version: v0.10.0 - vLLM main: 02f82fe438 Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from vllm_ascend.soc_info import NPUSocInfo
|
from vllm_ascend.utils import AscendSocVersion, init_ascend_soc_version, get_ascend_soc_version
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Arguments of rank table generator", )
|
description="Arguments of rank table generator", )
|
||||||
@@ -33,7 +33,9 @@ local_rank = os.environ.get("LOCAL_RANK")
|
|||||||
# This variable is set by torchrun,
|
# This variable is set by torchrun,
|
||||||
# and is different from WORLD_SIZE in gen_rank_table.sh.
|
# and is different from WORLD_SIZE in gen_rank_table.sh.
|
||||||
world_size = os.environ.get("WORLD_SIZE")
|
world_size = os.environ.get("WORLD_SIZE")
|
||||||
soc_info = NPUSocInfo()
|
|
||||||
|
init_ascend_soc_version()
|
||||||
|
soc_info = get_ascend_soc_version()
|
||||||
|
|
||||||
|
|
||||||
def get_cmd_stdout(cmd):
|
def get_cmd_stdout(cmd):
|
||||||
@@ -59,7 +61,7 @@ if local_rank == "0":
|
|||||||
for card_id in range(num_cards):
|
for card_id in range(num_cards):
|
||||||
for chip_id in range(chips_per_card):
|
for chip_id in range(chips_per_card):
|
||||||
device_id = card_id * chips_per_card + chip_id
|
device_id = card_id * chips_per_card + chip_id
|
||||||
if soc_info.is_a3:
|
if soc_info == AscendSocVersion.A3:
|
||||||
device_ip = get_cmd_stdout(
|
device_ip = get_cmd_stdout(
|
||||||
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
|
f"{hccn_tool_path} -i {device_id} -vnic -g | grep ipaddr"
|
||||||
).split(":")[1].strip()
|
).split(":")[1].strip()
|
||||||
@@ -79,7 +81,7 @@ if local_rank == "0":
|
|||||||
"device_id": str(device_id),
|
"device_id": str(device_id),
|
||||||
"device_ip": str(device_ip),
|
"device_ip": str(device_ip),
|
||||||
}
|
}
|
||||||
if soc_info.is_a3:
|
if soc_info == AscendSocVersion.A3:
|
||||||
device_info.update({
|
device_info.update({
|
||||||
"super_pod_id": str(super_pod_id),
|
"super_pod_id": str(super_pod_id),
|
||||||
"super_device_id": str(super_device_id)
|
"super_device_id": str(super_device_id)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import get_fused_moe_state
|
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||||
AscendUnquantizedFusedMoEMethod)
|
AscendUnquantizedFusedMoEMethod)
|
||||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||||
@@ -310,7 +310,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
global_num_experts, ep_size = others_param
|
global_num_experts, ep_size = others_param
|
||||||
is_prefill = False
|
is_prefill = False
|
||||||
is_deepseek_v3_r1 = global_num_experts == 256
|
is_deepseek_v3_r1 = global_num_experts == 256
|
||||||
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
|
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||||
ep_size, is_prefill, is_deepseek_v3_r1))
|
ep_size, is_prefill, is_deepseek_v3_r1))
|
||||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||||
return_value=forward_context):
|
return_value=forward_context):
|
||||||
@@ -346,7 +346,7 @@ class TestAscendUnquantizedFusedMoEMethod:
|
|||||||
ep_size, alltoall_buffer = others_param
|
ep_size, alltoall_buffer = others_param
|
||||||
is_prefill = False
|
is_prefill = False
|
||||||
forward_context = MagicMock(
|
forward_context = MagicMock(
|
||||||
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
|
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
|
||||||
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
|
||||||
alltoall_buffer), \
|
alltoall_buffer), \
|
||||||
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ import torch
|
|||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
import vllm_ascend.envs as envs
|
import vllm_ascend.envs as envs
|
||||||
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|
||||||
|
|
||||||
class FusedMoEState(Enum):
|
class FusedMoEState(Enum):
|
||||||
@@ -22,8 +22,8 @@ class FusedMoEState(Enum):
|
|||||||
|
|
||||||
|
|
||||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||||
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
def _get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||||
is_deepseek_v3_r1: bool):
|
is_deepseek_v3_r1: bool):
|
||||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||||
# only supports deepseek v3/r1
|
# only supports deepseek v3/r1
|
||||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||||
@@ -73,11 +73,9 @@ def set_ascend_forward_context(
|
|||||||
is_deepseek_v3_r1 = hasattr(
|
is_deepseek_v3_r1 = hasattr(
|
||||||
vllm_config.model_config.hf_config, 'n_routed_experts'
|
vllm_config.model_config.hf_config, 'n_routed_experts'
|
||||||
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
) and vllm_config.model_config.hf_config.n_routed_experts == 256
|
||||||
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
|
fused_moe_state = _get_fused_moe_state(ep_size, with_prefill,
|
||||||
is_deepseek_v3_r1)
|
is_deepseek_v3_r1)
|
||||||
|
|
||||||
forward_context.fused_moe_state = fused_moe_state
|
forward_context.fused_moe_state = fused_moe_state
|
||||||
|
|
||||||
forward_context.in_profile_run = in_profile_run
|
forward_context.in_profile_run = in_profile_run
|
||||||
|
|
||||||
# NOTE: This cannot be set using set_forward_context
|
# NOTE: This cannot be set using set_forward_context
|
||||||
@@ -85,15 +83,7 @@ def set_ascend_forward_context(
|
|||||||
forward_context.capturing = False
|
forward_context.capturing = False
|
||||||
|
|
||||||
if num_tokens is None and attn_metadata is not None:
|
if num_tokens is None and attn_metadata is not None:
|
||||||
if hasattr(attn_metadata, 'num_actual_tokens'):
|
num_tokens = attn_metadata.num_actual_tokens
|
||||||
# for v1 engine
|
|
||||||
num_tokens = attn_metadata.num_actual_tokens
|
|
||||||
else:
|
|
||||||
# for v0 engine
|
|
||||||
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
|
||||||
|
|
||||||
if num_actual_tokens is None:
|
|
||||||
num_actual_tokens = num_tokens
|
|
||||||
|
|
||||||
dp_world_size = get_dp_group().world_size
|
dp_world_size = get_dp_group().world_size
|
||||||
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
if dp_world_size > 1 and forward_context.dp_metadata is not None:
|
||||||
@@ -105,6 +95,8 @@ def set_ascend_forward_context(
|
|||||||
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
forward_context.max_tokens_across_dp = max_tokens_across_dp
|
||||||
|
|
||||||
if num_tokens is not None:
|
if num_tokens is not None:
|
||||||
|
if num_actual_tokens is None:
|
||||||
|
num_actual_tokens = num_tokens
|
||||||
tp_world_size = get_tp_group().world_size
|
tp_world_size = get_tp_group().world_size
|
||||||
# NOTE: token num which need to pad to when mc2
|
# NOTE: token num which need to pad to when mc2
|
||||||
forward_context.padded_num_tokens = math.ceil(
|
forward_context.padded_num_tokens = math.ceil(
|
||||||
@@ -112,7 +104,7 @@ def set_ascend_forward_context(
|
|||||||
|
|
||||||
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=current_platform.device_type)
|
device=NPUPlatform.device_type)
|
||||||
mc2_mask[:num_actual_tokens] = True
|
mc2_mask[:num_actual_tokens] = True
|
||||||
forward_context.mc2_mask = mc2_mask
|
forward_context.mc2_mask = mc2_mask
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
from vllm_ascend import envs
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.soc_info import NPUSocInfo
|
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||||
|
|
||||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||||
@@ -336,7 +336,7 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
self.local_agent_metadata.cluster_id)
|
self.local_agent_metadata.cluster_id)
|
||||||
self.init_llm_datadist()
|
self.init_llm_datadist()
|
||||||
self.finished_reqs: set[str] = set()
|
self.finished_reqs: set[str] = set()
|
||||||
self.soc_info = NPUSocInfo()
|
self.soc_info = get_ascend_soc_version()
|
||||||
# Set hccl deterministic for model execute
|
# Set hccl deterministic for model execute
|
||||||
os.environ["HCCL_DETERMINISTIC"] = "true"
|
os.environ["HCCL_DETERMINISTIC"] = "true"
|
||||||
self.done_receiving_counts: defaultdict[str,
|
self.done_receiving_counts: defaultdict[str,
|
||||||
@@ -681,7 +681,7 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
rank_table["server_list"].append( # type: ignore[attr-defined]
|
rank_table["server_list"].append( # type: ignore[attr-defined]
|
||||||
decode_server_device_info)
|
decode_server_device_info)
|
||||||
|
|
||||||
if self.soc_info.is_a3:
|
if self.soc_info == AscendSocVersion.A3:
|
||||||
# generate super_pod_list for rank table
|
# generate super_pod_list for rank table
|
||||||
super_pod_list = []
|
super_pod_list = []
|
||||||
prefill_super_pod_info = {
|
prefill_super_pod_info = {
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import torch_npu
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NPUSocInfo:
|
|
||||||
is_a3: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
torch_npu.npu._lazy_init()
|
|
||||||
self.soc_version = torch_npu._C._npu_get_soc_version()
|
|
||||||
if self.soc_version in (250, 251, 252, 253, 254, 255):
|
|
||||||
self.is_a3 = True
|
|
||||||
@@ -479,7 +479,8 @@ def register_ascend_customop():
|
|||||||
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
|
||||||
|
|
||||||
|
|
||||||
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
|
# TODO(zzzzwwjj): Currently there is no clear SOC_VERSION policy for A2 and A3 in CANN.
|
||||||
|
# So we get the version dynamically. In the future, we should get the version info from _build_info like 310p does.
|
||||||
class AscendSocVersion(Enum):
|
class AscendSocVersion(Enum):
|
||||||
A2 = 0
|
A2 = 0
|
||||||
A3 = 1
|
A3 = 1
|
||||||
|
|||||||
@@ -71,8 +71,10 @@ class NPUWorker(WorkerBase):
|
|||||||
from vllm_ascend import ops
|
from vllm_ascend import ops
|
||||||
ops.register_dummy_fusion_op()
|
ops.register_dummy_fusion_op()
|
||||||
_register_atb_extensions()
|
_register_atb_extensions()
|
||||||
# init ascend config
|
|
||||||
|
# init ascend config and soc version
|
||||||
init_ascend_config(vllm_config)
|
init_ascend_config(vllm_config)
|
||||||
|
init_ascend_soc_version()
|
||||||
|
|
||||||
super().__init__(vllm_config=vllm_config,
|
super().__init__(vllm_config=vllm_config,
|
||||||
local_rank=local_rank,
|
local_rank=local_rank,
|
||||||
@@ -81,9 +83,6 @@ class NPUWorker(WorkerBase):
|
|||||||
is_driver_worker=is_driver_worker)
|
is_driver_worker=is_driver_worker)
|
||||||
|
|
||||||
# Try to import mindie_turbo to accelerate vLLM inference.
|
# Try to import mindie_turbo to accelerate vLLM inference.
|
||||||
local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
|
|
||||||
world_size = self.vllm_config.parallel_config.world_size
|
|
||||||
self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
|
|
||||||
try_register_lib(
|
try_register_lib(
|
||||||
"mindie_turbo",
|
"mindie_turbo",
|
||||||
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
|
"MindIE Turbo is installed. vLLM inference will be accelerated with MindIE Turbo."
|
||||||
@@ -137,7 +136,6 @@ class NPUWorker(WorkerBase):
|
|||||||
NPUPlatform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||||
|
|
||||||
init_ascend_soc_version()
|
|
||||||
# Initialize the distributed environment.
|
# Initialize the distributed environment.
|
||||||
self._init_worker_distributed_environment()
|
self._init_worker_distributed_environment()
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
|
|||||||
Reference in New Issue
Block a user