[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No change at user-facing.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -21,6 +21,7 @@ def main():
tensor_parallel_size=2,
max_model_len=4096,
trust_remote_code=True,
enable_expert_parallel=True,
additional_config={
"torchair_graph_config": {
"enabled": False
@@ -28,7 +29,6 @@ def main():
"ascend_scheduler_config": {
"enabled": True
},
"expert_tensor_parallel_size": 1
})
# Generate texts from the prompts. The output is a list of RequestOutput

View File

@@ -1,3 +1,7 @@
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/*
export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0"
@@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100
export VLLM_USE_V1=0
export ASCEND_RT_VISIBLE_DEVICES=0,1
export VLLM_DP_SIZE=2
export VLLM_DP_RANK=0
export VLLM_DP_MASTER_IP="2.0.0.0"
export VLLM_DP_MASTER_PORT=40001
export VLLM_DP_PROXY_IP="2.0.0.0"
export VLLM_DP_PROXY_PORT=30002
export VLLM_DP_MONITOR_PORT=30003
export VLLM_HTTP_PORT=20001
export VLLM_USE_V1=1
export ASCEND_LAUNCH_BLOCKING=0
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
--host 0.0.0.0 \
--port 20001 \
--tensor-parallel-size 1 \
--seed 1024 \
--port 20002 \
--served-model-name Qwen \
--max-model-len 2000 \
--max-num-batched-tokens 2000 \
--data-parallel-size 4 \
--data-parallel-size-local 4 \
--data-parallel-address 2.0.0.0 \
--data-parallel-rpc-port 13389 \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--no-enable-prefix-caching \
--max-num-seqs 16 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \
--gpu-memory-utilization 0.9 \
--enforce-eager \
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}'

View File

@@ -114,7 +114,16 @@ def mock_distributed():
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group):
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
return_value=forward_context):
yield
@@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config):
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,

View File

@@ -18,16 +18,18 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm_ascend.ascend_forward_context import get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
adapt_patch(True)
def mock_ep_group(mocker):
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=MagicMock(
attn_metadata=MagicMock(max_num_tokens_across_dp=10),
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
)), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
yield
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
@@ -237,11 +247,16 @@ class TestAscendFusedMoe:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
@@ -288,15 +303,20 @@ class TestAscendUnquantizedFusedMoEMethod:
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size, select_softmax = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch(
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
select_softmax):
select_softmax), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
@@ -309,7 +329,7 @@ class TestAscendUnquantizedFusedMoEMethod:
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=False)
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
@@ -327,8 +347,13 @@ class TestAscendUnquantizedFusedMoEMethod:
4 test use_select_experts and fused_experts
"""
ep_size, alltoall_buffer = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
alltoall_buffer):
alltoall_buffer), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
@@ -347,7 +372,7 @@ class TestAscendUnquantizedFusedMoEMethod:
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=False)
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)

View File

@@ -0,0 +1,117 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.platforms import current_platform
import vllm_ascend.envs as envs
class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
AllGatherEP = 3
NaiveMulticast = 4
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return FusedMoEState.AllGatherEP
elif ep_size == 1:
if with_prefill:
return FusedMoEState.NaiveMulticast
else:
return FusedMoEState.AllGather
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
elif ep_size < 16 or with_prefill:
return FusedMoEState.All2All
else:
return FusedMoEState.MC2
@contextmanager
def set_ascend_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: Optional[int] = None,
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
num_actual_tokens: Optional[int] = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
"""
with set_forward_context(attn_metadata,
vllm_config,
virtual_engine=virtual_engine,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
forward_context = get_forward_context()
forward_context.with_prefill = with_prefill
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
is_deepseek_v3_r1 = hasattr(
vllm_config.model_config.hf_config, 'n_routed_experts'
) and vllm_config.model_config.hf_config.n_routed_experts == 256
fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
is_deepseek_v3_r1)
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run
# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
forward_context.capturing = False
if num_tokens is None and attn_metadata is not None:
if hasattr(attn_metadata, 'num_actual_tokens'):
# for v1 engine
num_tokens = attn_metadata.num_actual_tokens
else:
# for v0 engine
num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
if num_actual_tokens is None:
num_actual_tokens = num_tokens
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
)
else:
max_tokens_across_dp = num_tokens
forward_context.max_tokens_across_dp = max_tokens_across_dp
if num_tokens is not None:
tp_world_size = get_tp_group().world_size
# NOTE: token num which need to pad to when mc2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
dtype=torch.bool,
device=current_platform.device_type)
mc2_mask[:num_actual_tokens] = True
forward_context.mc2_mask = mc2_mask
try:
yield
finally:
pass

View File

@@ -119,6 +119,7 @@ class AscendAttentionState(Enum):
@dataclass
class AscendMetadata:
# **************************** Basic Properties ****************************
attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run.
@@ -149,11 +150,6 @@ class AscendMetadata:
# (num_tokens,)
slot_mapping: torch.Tensor = None
# ************************* DP Related Properties **************************
with_prefill_across_dp: bool = False
# Maximum number of tokens across dp group
max_num_tokens_across_dp: int = 0
class AscendAttentionMetadataBuilder:
@@ -164,12 +160,7 @@ class AscendAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
def build(self, num_reqs, num_actual_tokens, max_query_len):
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
@@ -196,18 +187,15 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ)
attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state)
return attn_metadata

View File

@@ -127,8 +127,6 @@ class AscendTorchairMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -139,7 +137,7 @@ class AscendTorchairMetadata:
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
with_prefill_across_dp: bool = False
decode: Optional[AscendDecodeMetadata] = None
@@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder:
return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
def build_torchair_graph_dummy(
self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
@@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder:
seq_lens=seq_lens,
slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly,
max_num_tokens_across_dp=num_reqs,
decode=decode_metadata)
return attn_metadata
@@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs,
num_actual_tokens,
max_query_len,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
graph_pad_size: int = -1):
device = self.runner.device
@@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder:
pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value
] * graph_pad_size
max_num_tokens_across_dp = len(padded_seq_lens)
seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32))
@@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder:
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
attn_state=attn_state)
return attn_metadata

View File

@@ -126,9 +126,6 @@ class AscendMLAMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
max_num_tokens_across_dp: int = 0
with_prefill_across_dp: bool = False
query_lens: Optional[list[int]] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
@@ -302,8 +299,8 @@ class AscendMLAMetadataBuilder:
return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int,
num_actual_tokens: int) -> AscendMLAMetadata:
def build_torchair_graph_dummy(
self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
@@ -353,8 +350,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens: int,
max_query_len: int,
graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False,
query_start_loc: torch.Tensor = None,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
@@ -498,8 +493,6 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp,
)

View File

@@ -0,0 +1,48 @@
from typing import Optional
import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
assert _MC2 is not None, ("mc2 group is not initialized")
return _MC2
def model_parallel_initialized():
return (_MC2 is not None)
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
if model_parallel_initialized():
return
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
backend = torch.distributed.get_backend(get_world_group().device_group)
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_MC2 = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="mc2")
def destroy_ascend_model_parallel():
global _MC2
if _MC2:
_MC2.destroy()
_MC2 = None

View File

@@ -174,20 +174,12 @@ class CustomDeepseekDBOMoE(nn.Module):
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
forward_context = get_forward_context()
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
enable_force_load_balance = forward_context.in_profile_run
is_prefill = forward_context.with_prefill
old_hidden_states = hidden_states.clone()

View File

@@ -377,20 +377,14 @@ class CustomDeepseekV2MoE(nn.Module):
attn_metadata: Optional[AttentionMetadata] = None,
replace_allreduce: bool = False) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
forward_context = get_forward_context()
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
enable_force_load_balance = forward_context.in_profile_run
is_prefill = forward_context.with_prefill
# If this node is kv_consumer, we force the moe always runs in decode path to make sure
# the behaviour aligned between dummy_run and normal model_execute.
if self.kv_consumer:
@@ -572,9 +566,10 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
enable_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None
and not attn_metadata.with_prefill_across_dp
and not forward_context.with_prefill
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:

View File

@@ -837,12 +837,8 @@ class PanguProMoEModel(nn.Module):
# if attn_meatadata is not passed, we try to get it from forward_context.
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
# when attn_meatadata is None, it is in profile_run. num_tokens on all dp ranks
# are same.
max_tokens_across_dp = hidden_states.shape[0]
else:
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
max_tokens_across_dp = get_forward_context().max_tokens_across_dp
tp_size = get_tp_group().world_size
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.

View File

@@ -223,7 +223,6 @@ def model_input_split_v1_mla_attn(
attn_mask=attn_mask_pre,
prefill=prefill_pre,
decode=decode_pre,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
)
attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
@@ -240,6 +239,5 @@ def model_input_split_v1_mla_attn(
attn_state=attn_state_post,
prefill=prefill_post,
decode=decode_post,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
)
return [attention_metadata_pre, attention_metadata_post]

View File

@@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_all_reduce_merge_state, get_fused_moe_state,
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -127,9 +130,23 @@ def fused_experts_with_mc2(
moe_parallel_config: FusedMoEParallelConfig,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
global_bs = 0
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
@@ -137,32 +154,35 @@ def fused_experts_with_mc2(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
"global_bs": 0,
}
rank = torch.distributed.get_rank()
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
tp_world_size = moe_parallel_config.tp_size
tp_rank = rank % tp_world_size
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
@@ -205,7 +225,6 @@ def fused_experts_with_mc2(
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
@@ -218,15 +237,33 @@ def fused_experts_with_mc2(
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
@@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
self.ep_group = get_ep_group()
self.ep_size = self.moe.moe_parallel_config.ep_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = self.ep_group.device_group
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
is_deepseek_v3_r1)
fused_moe_state = get_forward_context().fused_moe_state
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts)
shared_experts=shared_experts,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:
@@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill, is_deepseek_v3_r1)
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
if (tp_size > 1 and fused_moe_state not in [
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if num_tokens < tp_size:
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
if num_tokens < max_num_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
max_tokens_across_dp = forward_context.max_tokens_across_dp
if num_tokens < max_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
@@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None,
mc2_mask=mc2_mask,
)
if shared_experts:
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if (tp_size > 1 and fused_moe_state not in [
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < tp_size:
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
elif self.dp_size > 1:

View File

@@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional
from vllm.logger import logger
from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -81,9 +80,6 @@ class VLLMAscendQuantizer:
VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.layernorm.RMSNorm",
"forward_oot", [wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch(
"vllm_ascend.worker.model_runner.NPUModelRunnerBase",
"load_model", [wrapper_load_model])
break
VLLMAscendQuantizer.patched = True
logger.info("Using the vLLM Ascend Quantizer version now!")

View File

@@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed import GroupCoordinator, get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
dispose_tensor, get_fused_moe_state)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version)
def apply_mlp(hidden_states: torch.Tensor,
@@ -118,10 +120,29 @@ def fused_experts_with_mc2(
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert mc2_mask is not None
if log2phy is not None:
topk_ids = log2phy[topk_ids]
global_bs = 0
quant_mode = 2
ep_group = get_mc2_group()
ep_rank_id = ep_group.rank_in_group
ep_world_size = ep_group.world_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
@@ -133,47 +154,43 @@ def fused_experts_with_mc2(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
"expert_scales": topk_weights.to(torch.float32),
"global_bs": 0,
}
rank = torch.distributed.get_rank()
quant_mode = 2
ep_group = get_ep_group().device_group
local_rank = torch.distributed.get_rank(group=ep_group)
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
world_size = torch.distributed.get_world_size()
tp_size = world_size // all_to_all_group_size
tp_rank = rank % tp_size
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size,
"ep_rank_id": local_rank,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
0:7]
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
npu_wait_tensor(shared_gate_up[0], expand_x)
shared_act = shared_experts.act_fn(shared_gate_up)
npu_wait_tensor(quantized_x_for_share, expand_x)
shared_act_out = shared_experts.act_fn(
(quantized_x_for_share, dynamic_scale_for_share))
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
# `expand_x` will be disposed in the `apply_mlp` function
down_out_list = apply_mlp(expand_x,
w1,
w1_scale,
@@ -186,13 +203,11 @@ def fused_experts_with_mc2(
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"expand_scales": expand_scales,
}
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
@@ -200,24 +215,43 @@ def fused_experts_with_mc2(
stage3_kwargs = {
"ep_send_counts": ep_recv_counts,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size,
"ep_rank_id": local_rank,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
else:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_act[0], down_out_list)
shared_output, _ = shared_experts.down_proj(shared_act)
npu_wait_tensor(shared_act, down_out_list)
shared_output, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
return hidden_states, shared_output
@@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = self.ep_group.device_group
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill, is_deepseek_v3_r1)
fused_moe_state = get_forward_context().fused_moe_state
if fused_moe_state == FusedMoEState.AllGatherEP:
return fused_experts_with_allgather(
hidden_states=x,
@@ -782,7 +815,9 @@ class AscendW8A8DynamicFusedMoEMethod:
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts)
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:

View File

@@ -52,13 +52,3 @@ class NPUTorchairWorker(NPUWorker):
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory
def _get_max_num_tokens_and_with_prefill(self):
"""Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""
max_num_tokens, with_prefill = super(
)._get_max_num_tokens_and_with_prefill()
if not with_prefill:
max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
max_num_tokens)
return max_num_tokens, with_prefill

View File

@@ -429,15 +429,6 @@ def npu_prefetch(input: torch.Tensor,
torch_npu.npu_prefetch(input, dependency, max_size)
# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
AllGatherEP = 3
NaiveMulticast = 4
# TODO(ttanzhiqiang): rm_router_logits
# dp>1 will trigger
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
@@ -468,26 +459,6 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
return False
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return FusedMoEState.AllGatherEP
elif ep_size == 1:
if with_prefill:
return FusedMoEState.NaiveMulticast
else:
return FusedMoEState.AllGather
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
elif ep_size < 16 or with_prefill:
return FusedMoEState.All2All
else:
return FusedMoEState.MC2
def register_ascend_customop():
"""Register Ascend CustomOP
@@ -506,3 +477,30 @@ def register_ascend_customop():
# NOTE: Keep this at last to ensure all custom actions are registered
_ASCEND_CUSTOMOP_IS_REIGISTERED = True
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
class AscendSocVersion(Enum):
A2 = 0
A3 = 1
UNDEFINED = 2
_ascend_soc_version = None
def init_ascend_soc_version():
soc_version = torch_npu.npu.get_soc_version()
global _ascend_soc_version
if 220 <= soc_version <= 225:
_ascend_soc_version = AscendSocVersion.A2
elif 250 <= soc_version <= 255:
_ascend_soc_version = AscendSocVersion.A3
else:
_ascend_soc_version = AscendSocVersion.UNDEFINED
def get_ascend_soc_version():
global _ascend_soc_version
assert _ascend_soc_version is not None
return _ascend_soc_version

View File

@@ -7,13 +7,13 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -142,9 +142,9 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
@@ -239,9 +239,9 @@ class EagleProposer:
attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device)
# Run the model.
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
with set_ascend_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
@@ -344,8 +344,9 @@ class EagleProposer:
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
with set_ascend_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],

View File

@@ -34,7 +34,6 @@ import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.forward_context import get_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
# kv role
self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None:
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self,
maybe_padded_num_tokens: int,
num_tokens: int,
with_prefill: bool,
enable_dbo: bool = False,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
if self.dp_size == 1:
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
num_tokens_across_dp = [0] * self.dp_size * 2
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
forward_metadata = torch.tensor(num_tokens_across_dp +
[with_prefill, not enable_dbo],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
with_prefill = bool(forward_metadata[-2])
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
if with_prefill:
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
2]
maybe_padded_num_tokens = num_tokens
else:
num_tokens_across_dp = forward_metadata[:self.dp_size]
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
# `max_tokens_across_dp`, in other situation it is not necessary.
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
forward_metadata[-1])
def get_eagle_atten_dict(
self,
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
maybe_padded_num_tokens = total_num_scheduled_tokens
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_input_tokens]
if self.torchair_graph_enabled and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_num_tokens_across_dp]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
})
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {}
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size)
padded_num_tokens_across_dp)
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_forward_context(None, self.vllm_config):
with set_ascend_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfer(scheduler_output))
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run(
self,
num_tokens: int,
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
) -> torch.Tensor:
maybe_padded_num_tokens = num_tokens
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, num_tokens, with_prefill, False)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer:
with_prefill = True
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
model = self.model
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
num_actual_tokens=0,
):
model_kwargs = {}
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling
if is_compile:
if is_torchair_compile:
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(
attn_metadata.decode.block_table)
torch._dynamo.mark_static(
attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches:
assert isinstance(
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=None,
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
**model_kwargs,
)
else:
maybe_converting_weight_acl_format(self.model,
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens)
return hidden_states
@contextmanager
def set_in_profile_run(self):
self.in_profile_run = True
try:
yield
finally:
self.in_profile_run = False
def profile_run(self) -> None:
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
with self.set_in_profile_run():
hidden_states = self._dummy_run(self.max_num_tokens,
with_prefill=True)
output = None
if get_pp_group().is_last_rank:
if self.is_pooling_model:
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
is_compile=True,
with_prefill=False)
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
self._dummy_run(num_tokens, is_torchair_compile=True)
self._dummy_run(num_tokens, is_torchair_compile=True)
logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes))
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.

View File

@@ -2,12 +2,12 @@ import torch
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
@@ -117,7 +117,7 @@ class MtpProposer:
query_start_loc=cu_num_tokens,
)
with set_forward_context(attn_metadata, self.vllm_config):
with set_ascend_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,

View File

@@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib,
vllm_version_is)
from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
try_register_lib, vllm_version_is)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if not vllm_version_is("0.10.0"):
@@ -134,6 +135,7 @@ class NPUWorker(WorkerBase):
NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
init_ascend_soc_version()
# Initialize the distributed environment.
self._init_worker_distributed_environment()
# Set random seed.
@@ -272,20 +274,8 @@ class NPUWorker(WorkerBase):
def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id)
def _get_max_num_tokens_and_with_prefill(self):
max_num_tokens = 1
with_prefill = False
if self.model_runner.dp_size > 1:
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
return max_num_tokens, with_prefill
def execute_dummy_batch(self) -> None:
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill(
)
self.model_runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)
self.model_runner._dummy_run(1)
def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment."""
@@ -295,6 +285,7 @@ class NPUWorker(WorkerBase):
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size)
init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config)
def _init_profiler(self):