[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

This PR refactors forward_context and model_runner_v1: it moves context
that is necessary during model inference into forward_context, and reworks
the dummy_run logic to make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No user-facing changes.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -21,6 +21,7 @@ def main():
tensor_parallel_size=2, tensor_parallel_size=2,
max_model_len=4096, max_model_len=4096,
trust_remote_code=True, trust_remote_code=True,
enable_expert_parallel=True,
additional_config={ additional_config={
"torchair_graph_config": { "torchair_graph_config": {
"enabled": False "enabled": False
@@ -28,7 +29,6 @@ def main():
"ascend_scheduler_config": { "ascend_scheduler_config": {
"enabled": True "enabled": True
}, },
"expert_tensor_parallel_size": 1
}) })
# Generate texts from the prompts. The output is a list of RequestOutput # Generate texts from the prompts. The output is a list of RequestOutput

View File

@@ -1,3 +1,7 @@
rm -rf ./.torchair_cache/
rm -rf ./dynamo_*
rm -rf /root/ascend/log/debug/plog/*
export HCCL_IF_IP=2.0.0.0 export HCCL_IF_IP=2.0.0.0
export GLOO_SOCKET_IFNAME="enp189s0f0" export GLOO_SOCKET_IFNAME="enp189s0f0"
export TP_SOCKET_IFNAME="enp189s0f0" export TP_SOCKET_IFNAME="enp189s0f0"
@@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0"
export OMP_PROC_BIND=false export OMP_PROC_BIND=false
export OMP_NUM_THREADS=100 export OMP_NUM_THREADS=100
export VLLM_USE_V1=0 export VLLM_USE_V1=1
export ASCEND_LAUNCH_BLOCKING=0
export ASCEND_RT_VISIBLE_DEVICES=0,1
export VLLM_DP_SIZE=2
export VLLM_DP_RANK=0
export VLLM_DP_MASTER_IP="2.0.0.0"
export VLLM_DP_MASTER_PORT=40001
export VLLM_DP_PROXY_IP="2.0.0.0"
export VLLM_DP_PROXY_PORT=30002
export VLLM_DP_MONITOR_PORT=30003
export VLLM_HTTP_PORT=20001
vllm serve /data/weights/Qwen2.5-0.5B-Instruct \ vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
--host 0.0.0.0 \ --host 0.0.0.0 \
--port 20001 \ --port 20002 \
--tensor-parallel-size 1 \
--seed 1024 \
--served-model-name Qwen \ --served-model-name Qwen \
--max-model-len 2000 \ --data-parallel-size 4 \
--max-num-batched-tokens 2000 \ --data-parallel-size-local 4 \
--data-parallel-address 2.0.0.0 \
--data-parallel-rpc-port 13389 \
--tensor-parallel-size 4 \
--enable-expert-parallel \
--no-enable-prefix-caching \
--max-num-seqs 16 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--gpu-memory-utilization 0.9 \
--trust-remote-code \ --trust-remote-code \
--gpu-memory-utilization 0.9 \ --enforce-eager \
--additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}'

View File

@@ -114,7 +114,16 @@ def mock_distributed():
return_value=Mock(is_first_rank=False, is_last_rank=False)), \ return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group): _PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
return_value=forward_context):
yield yield
@@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
quant_config=None) quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config): def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1 base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config, moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None, quant_config=None,

View File

@@ -18,16 +18,18 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from vllm_ascend.ascend_forward_context import get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE, from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod) AscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import adapt_patch # noqa E402 from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
adapt_patch(True) adapt_patch(True)
def mock_ep_group(mocker): def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock() mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0 mock_group.rank_in_group = 0
mock_group.rank = 0 mock_group.rank = 0
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
with patch('torch.distributed.get_rank', return_value=0), \ with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \ patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context', patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=MagicMock( return_value=MagicMock(
attn_metadata=MagicMock(max_num_tokens_across_dp=10), max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
)), \ )), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
patch("torch_npu.npu_moe_finalize_routing", return_value=( patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2) torch.randn(16, 2)
)): )):
yield if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture @pytest.fixture
@@ -237,11 +247,16 @@ class TestAscendFusedMoe:
moe.moe_parallel_config.ep_size = 1 moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens) moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
output = moe.forward(inputs, forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
router_logits, dtype=torch.bool),
is_prefill=is_prefill, padded_num_tokens=num_tokens)
top_k=top_k, with patch("vllm_ascend.ops.fused_moe.get_forward_context",
shared_experts=shared_experts) return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once() moe.quant_method.apply.assert_called_once()
@@ -288,15 +303,20 @@ class TestAscendUnquantizedFusedMoEMethod:
def test_apply_without_expert_map(self, moe_method, mock_dist_env, def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param): mock_moe_env, others_param):
""" """
1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all 1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts 2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts 3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer 4 test use select_experts and fused_experts_with_all2all_buffer
""" """
global_num_experts, ep_size, select_softmax = others_param global_num_experts, ep_size, select_softmax = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch( with patch(
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS", "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
select_softmax): select_softmax), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
moe_method.ep_size = ep_size moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2) x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8) router_logits = torch.randn(8, 8)
@@ -309,7 +329,7 @@ class TestAscendUnquantizedFusedMoEMethod:
top_k=2, top_k=2,
renormalize=True, renormalize=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
is_prefill=False) is_prefill=is_prefill)
if ep_size == 1: if ep_size == 1:
assert result.shape == (16, 2) assert result.shape == (16, 2)
@@ -327,8 +347,13 @@ class TestAscendUnquantizedFusedMoEMethod:
4 test use_select_experts and fused_experts 4 test use_select_experts and fused_experts
""" """
ep_size, alltoall_buffer = others_param ep_size, alltoall_buffer = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER", with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
alltoall_buffer): alltoall_buffer), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2) x = torch.randn(8, 2, 2)
@@ -347,7 +372,7 @@ class TestAscendUnquantizedFusedMoEMethod:
renormalize=True, renormalize=True,
global_num_experts=128, global_num_experts=128,
expert_map=expert_map, expert_map=expert_map,
is_prefill=False) is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1: if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2) assert result.shape == (16, 2)

View File

@@ -0,0 +1,117 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.platforms import current_platform
import vllm_ascend.envs as envs
class FusedMoEState(Enum):
    """Execution strategy for the fused MoE layer.

    The value is selected once per forward pass by get_fused_moe_state()
    and stashed on the forward context for the MoE ops to read.
    """
    AllGather = 0
    All2All = 1
    MC2 = 2
    AllGatherEP = 3
    NaiveMulticast = 4
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
                        is_deepseek_v3_r1: bool):
    """Pick the fused-MoE execution strategy for this forward pass.

    Args:
        ep_size: expert-parallel world size (1 means EP is disabled).
        with_prefill: whether this step contains prefill tokens.
        is_deepseek_v3_r1: whether the model is DeepSeek V3/R1
            (256 routed experts).
    """
    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing
    # called by allgather ep only supports deepseek v3/r1.
    use_allgather_ep = (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP
                        and ep_size > 1 and is_deepseek_v3_r1)
    if use_allgather_ep:
        return FusedMoEState.AllGatherEP
    if ep_size == 1:
        return (FusedMoEState.NaiveMulticast
                if with_prefill else FusedMoEState.AllGather)
    # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
    if ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    return FusedMoEState.MC2
@contextmanager
def set_ascend_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    virtual_engine: int = 0,
    num_tokens: Optional[int] = None,
    num_tokens_across_dp: Optional[torch.Tensor] = None,
    with_prefill: bool = True,
    in_profile_run: bool = False,
    num_actual_tokens: Optional[int] = None,
):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.

    Wraps vLLM's set_forward_context and attaches Ascend-specific fields
    (with_prefill, fused_moe_state, in_profile_run, capturing,
    max_tokens_across_dp, padded_num_tokens, mc2_mask) onto the active
    forward context for the duration of the `with` block.
    """
    with set_forward_context(attn_metadata,
                             vllm_config,
                             virtual_engine=virtual_engine,
                             num_tokens=num_tokens,
                             num_tokens_across_dp=num_tokens_across_dp):
        forward_context = get_forward_context()
        forward_context.with_prefill = with_prefill
        # EP is only meaningful when expert parallelism is enabled;
        # otherwise treat the EP world size as 1.
        ep_size = (get_ep_group().world_size if
                   vllm_config.parallel_config.enable_expert_parallel else 1)
        # DeepSeek V3/R1 is identified by its 256 routed experts.
        is_deepseek_v3_r1 = hasattr(
            vllm_config.model_config.hf_config, 'n_routed_experts'
        ) and vllm_config.model_config.hf_config.n_routed_experts == 256
        fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
                                              is_deepseek_v3_r1)
        forward_context.fused_moe_state = fused_moe_state
        forward_context.in_profile_run = in_profile_run
        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False
        if num_tokens is None and attn_metadata is not None:
            if hasattr(attn_metadata, 'num_actual_tokens'):
                # for v1 engine
                num_tokens = attn_metadata.num_actual_tokens
            else:
                # for v0 engine
                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
        if num_actual_tokens is None:
            num_actual_tokens = num_tokens
        dp_world_size = get_dp_group().world_size
        # Largest token count on any DP rank; falls back to the local
        # count when DP is off or no dp_metadata is available.
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
            )
        else:
            max_tokens_across_dp = num_tokens
        forward_context.max_tokens_across_dp = max_tokens_across_dp
        if num_tokens is not None:
            tp_world_size = get_tp_group().world_size
            # NOTE: token num which need to pad to when mc2
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size
            # Boolean mask marking real (non-padding) tokens for the mc2
            # dispatch/combine ops (x_active_mask).
            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
                                   dtype=torch.bool,
                                   device=current_platform.device_type)
            mc2_mask[:num_actual_tokens] = True
            forward_context.mc2_mask = mc2_mask
        try:
            yield
        finally:
            pass

View File

@@ -119,6 +119,7 @@ class AscendAttentionState(Enum):
@dataclass @dataclass
class AscendMetadata: class AscendMetadata:
# **************************** Basic Properties **************************** # **************************** Basic Properties ****************************
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
# Current state of this attention run. # Current state of this attention run.
@@ -149,11 +150,6 @@ class AscendMetadata:
# (num_tokens,) # (num_tokens,)
slot_mapping: torch.Tensor = None slot_mapping: torch.Tensor = None
# ************************* DP Related Properties **************************
with_prefill_across_dp: bool = False
# Maximum number of tokens across dp group
max_num_tokens_across_dp: int = 0
class AscendAttentionMetadataBuilder: class AscendAttentionMetadataBuilder:
@@ -164,12 +160,7 @@ class AscendAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool: scheduler_output: "SchedulerOutput") -> bool:
return False return False
def build(self, def build(self, num_reqs, num_actual_tokens, max_query_len):
num_reqs,
num_actual_tokens,
max_query_len,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
block_table = self.runner.input_batch.block_table[0].get_device_tensor( block_table = self.runner.input_batch.block_table[0].get_device_tensor(
) )
@@ -196,18 +187,15 @@ class AscendAttentionMetadataBuilder:
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
ACL_FORMAT_FRACTAL_NZ) ACL_FORMAT_FRACTAL_NZ)
attn_metadata = AscendMetadata( attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens, block_tables=block_table,
block_tables=block_table, query_start_loc=query_start_loc,
query_start_loc=query_start_loc, query_lens=query_lens,
query_lens=query_lens, seq_lens=seq_lens,
seq_lens=seq_lens, max_query_len=max_query_len,
max_query_len=max_query_len, slot_mapping=slot_mapping,
slot_mapping=slot_mapping, attn_mask=attn_mask,
attn_mask=attn_mask, attn_state=attn_state)
attn_state=attn_state,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
return attn_metadata return attn_metadata

View File

@@ -127,8 +127,6 @@ class AscendTorchairMetadata:
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
query_lens: torch.Tensor query_lens: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
# max value of number of tokens across dp group
max_num_tokens_across_dp: int = 0
# Maximum query length in the batch. None for decoding. # Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be # (num_tokens,). The indices of the token slots that input tokens will be
@@ -139,7 +137,7 @@ class AscendTorchairMetadata:
# Current state of this attention run. # Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None attn_mask: Optional[torch.Tensor] = None
with_prefill_across_dp: bool = False
decode: Optional[AscendDecodeMetadata] = None decode: Optional[AscendDecodeMetadata] = None
@@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder:
return graph_block_tables[:num_seqs, :max_blocks] return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int, def build_torchair_graph_dummy(
num_actual_tokens: int) -> AscendTorchairMetadata: self, num_reqs: int,
num_actual_tokens: int) -> AscendTorchairMetadata:
device = self.runner.device device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape _, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks), block_table = torch.zeros((num_reqs, max_blocks),
@@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder:
seq_lens=seq_lens, seq_lens=seq_lens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
attn_state=AscendAttentionState.DecodeOnly, attn_state=AscendAttentionState.DecodeOnly,
max_num_tokens_across_dp=num_reqs,
decode=decode_metadata) decode=decode_metadata)
return attn_metadata return attn_metadata
@@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder:
num_reqs, num_reqs,
num_actual_tokens, num_actual_tokens,
max_query_len, max_query_len,
graph_pad_size: int = -1, graph_pad_size: int = -1):
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False):
device = self.runner.device device = self.runner.device
@@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder:
pad_value = 1 pad_value = 1
padded_seq_lens = seq_lens.tolist() + [pad_value padded_seq_lens = seq_lens.tolist() + [pad_value
] * graph_pad_size ] * graph_pad_size
max_num_tokens_across_dp = len(padded_seq_lens)
seq_lens = torch.from_numpy( seq_lens = torch.from_numpy(
np.array(padded_seq_lens).astype(np.int32)) np.array(padded_seq_lens).astype(np.int32))
@@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder:
max_query_len=max_query_len, max_query_len=max_query_len,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
attn_mask=attn_mask, attn_mask=attn_mask,
attn_state=attn_state, attn_state=attn_state)
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp)
return attn_metadata return attn_metadata

View File

@@ -126,9 +126,6 @@ class AscendMLAMetadata:
# For logging. # For logging.
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
max_num_tokens_across_dp: int = 0
with_prefill_across_dp: bool = False
query_lens: Optional[list[int]] = None query_lens: Optional[list[int]] = None
# The dimension of the attention heads # The dimension of the attention heads
head_dim: Optional[int] = None head_dim: Optional[int] = None
@@ -302,8 +299,8 @@ class AscendMLAMetadataBuilder:
return graph_block_tables[:num_seqs, :max_blocks] return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int, def build_torchair_graph_dummy(
num_actual_tokens: int) -> AscendMLAMetadata: self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
device = self.runner.device device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape _, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks), block_table = torch.zeros((num_reqs, max_blocks),
@@ -353,8 +350,6 @@ class AscendMLAMetadataBuilder:
num_actual_tokens: int, num_actual_tokens: int,
max_query_len: int, max_query_len: int,
graph_pad_size: int = -1, graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False,
query_start_loc: torch.Tensor = None, query_start_loc: torch.Tensor = None,
) -> AscendMLAMetadata: ) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs assert self._num_decodes + self._num_prefills == num_reqs
@@ -498,8 +493,6 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
block_tables=block_table, block_tables=block_table,
seq_lens=seq_lens, seq_lens=seq_lens,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp,
) )

View File

@@ -0,0 +1,48 @@
from typing import Optional
import torch
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
init_model_parallel_group)
# Currently, mc2 op need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
    """Return the group coordinator dedicated to mc2 ops.

    Fails with an AssertionError when init_ascend_model_parallel() has
    not been called yet.
    """
    group = _MC2
    assert group is not None, ("mc2 group is not initialized")
    return group
def model_parallel_initialized():
    """Report whether the mc2 group coordinator has been created."""
    return _MC2 is not None
def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
    """Create the Ascend-specific mc2 group coordinator.

    Idempotent: returns immediately if the group already exists.
    Requires torch.distributed to be initialized first.
    """
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    backend = torch.distributed.get_backend(get_world_group().device_group)

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # Each row of `all_ranks` becomes one mc2 group of size DP * TP.
    all_ranks = torch.arange(world_size).reshape(
        -1, parallel_config.data_parallel_size *
        parallel_config.tensor_parallel_size)
    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")
def destroy_ascend_model_parallel():
    """Tear down the mc2 group coordinator and clear the module global."""
    global _MC2
    # Detach the global first, then destroy the old coordinator if any.
    group, _MC2 = _MC2, None
    if group:
        group.destroy()

View File

@@ -174,20 +174,12 @@ class CustomDeepseekDBOMoE(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
if attn_metadata is None: forward_context = get_forward_context()
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens # when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank. # to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not. enable_force_load_balance = forward_context.in_profile_run
if attn_metadata is None:
# for profile run is_prefill = forward_context.with_prefill
is_prefill = True
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
old_hidden_states = hidden_states.clone() old_hidden_states = hidden_states.clone()

View File

@@ -377,20 +377,14 @@ class CustomDeepseekV2MoE(nn.Module):
attn_metadata: Optional[AttentionMetadata] = None, attn_metadata: Optional[AttentionMetadata] = None,
replace_allreduce: bool = False) -> torch.Tensor: replace_allreduce: bool = False) -> torch.Tensor:
if attn_metadata is None: forward_context = get_forward_context()
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens # when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank. # to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None: enable_force_load_balance = forward_context.in_profile_run
# for profile run
is_prefill = True is_prefill = forward_context.with_prefill
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
# If this node is kv_consumer, we force the moe always runs in decode path to make sure # If this node is kv_consumer, we force the moe always runs in decode path to make sure
# the behaviour aligned between dummy_run and normal model_execute. # the behaviour aligned between dummy_run and normal model_execute.
if self.kv_consumer: if self.kv_consumer:
@@ -572,9 +566,10 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None, kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
enable_multistream_mla = (self.enable_multistream_mla enable_multistream_mla = (self.enable_multistream_mla
and attn_metadata is not None and attn_metadata is not None
and not attn_metadata.with_prefill_across_dp and not forward_context.with_prefill
and attn_metadata.num_decodes > 0) and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None: if self.q_lora_rank is not None:

View File

@@ -837,12 +837,8 @@ class PanguProMoEModel(nn.Module):
# if attn_meatadata is not passed, we try to get it from forward_context. # if attn_meatadata is not passed, we try to get it from forward_context.
if attn_metadata is None: if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
# when attn_meatadata is None, it is in profile_run. num_tokens on all dp ranks max_tokens_across_dp = get_forward_context().max_tokens_across_dp
# are same.
max_tokens_across_dp = hidden_states.shape[0]
else:
max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
tp_size = get_tp_group().world_size tp_size = get_tp_group().world_size
# reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks.

View File

@@ -223,7 +223,6 @@ def model_input_split_v1_mla_attn(
attn_mask=attn_mask_pre, attn_mask=attn_mask_pre,
prefill=prefill_pre, prefill=prefill_pre,
decode=decode_pre, decode=decode_pre,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
) )
attention_metadata_post = _metadata_cls( attention_metadata_post = _metadata_cls(
num_actual_tokens=attn_metadata.num_actual_tokens - token_index, num_actual_tokens=attn_metadata.num_actual_tokens - token_index,
@@ -240,6 +239,5 @@ def model_input_split_v1_mla_attn(
attn_state=attn_state_post, attn_state=attn_state_post,
prefill=prefill_post, prefill=prefill_post,
decode=decode_post, decode=decode_post,
with_prefill_across_dp=attn_metadata.with_prefill_across_dp,
) )
return [attention_metadata_pre, attention_metadata_post] return [attention_metadata_pre, attention_metadata_post]

View File

@@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \
import vllm_ascend.envs as envs_ascend import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \ from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (FusedMoEState, dispose_tensor, from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state, get_fused_moe_state, get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p) get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -127,9 +130,23 @@ def fused_experts_with_mc2(
moe_parallel_config: FusedMoEParallelConfig, moe_parallel_config: FusedMoEParallelConfig,
expert_map: torch.Tensor = None, expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None, moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None shared_experts: Optional[Any] = None,
is_torchair: bool = False,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
global_bs = 0 quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
moe_expert_num = len(expert_map) moe_expert_num = len(expert_map)
kwargs_mc2 = { kwargs_mc2 = {
"x": hidden_states, "x": hidden_states,
@@ -137,32 +154,35 @@ def fused_experts_with_mc2(
"expert_shard_type": 0, "expert_shard_type": 0,
"shared_expert_rank_num": 0, "shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num, "moe_expert_num": moe_expert_num,
"global_bs": global_bs, "global_bs": 0,
} }
rank = torch.distributed.get_rank()
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
tp_world_size = moe_parallel_config.tp_size
tp_rank = rank % tp_world_size
stage1_kwargs = { stage1_kwargs = {
"scales": None, "scales": None,
"quant_mode": quant_mode, "quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name, "group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size, "ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id, "ep_rank_id": ep_rank_id,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
} }
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs) kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) output = torch_npu.npu_moe_distribute_dispatch_v2(
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ **kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5] 0:5]
if shared_experts is not None: if shared_experts is not None:
@@ -205,7 +225,6 @@ def fused_experts_with_mc2(
kwargs_mc2 = { kwargs_mc2 = {
"expand_x": down_out_list, "expand_x": down_out_list,
"expert_ids": topk_ids, "expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32), "expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0, "expert_shard_type": 0,
"shared_expert_rank_num": 0, "shared_expert_rank_num": 0,
@@ -218,15 +237,33 @@ def fused_experts_with_mc2(
"group_ep": moe_all_to_all_group_name, "group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size, "ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id, "ep_rank_id": ep_rank_id,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
} }
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None: if shared_experts is None:
return hidden_states return hidden_states
@@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
super().__init__(moe=moe) super().__init__(moe=moe)
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.ep_group = get_ep_group()
self.ep_size = self.moe.moe_parallel_config.ep_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try: try:
device_group = self.ep_group.device_group device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group) local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu")) backend = device_group._get_backend(torch.device("npu"))
@@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance: if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill, fused_moe_state = get_forward_context().fused_moe_state
is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.MC2: if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2( return fused_experts_with_mc2(
hidden_states=x, hidden_states=x,
@@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
top_k=top_k, top_k=top_k,
expert_map=expert_map, expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name, moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts) shared_experts=shared_experts,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [ elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]: ]:
@@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, forward_context = get_forward_context()
is_prefill, is_deepseek_v3_r1) fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
if shared_experts: if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states) shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if (tp_size > 1 and fused_moe_state not in [ if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast FusedMoEState.NaiveMulticast
] and not replace_allreduce): ] and not replace_allreduce):
if num_tokens < tp_size: if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size:
hidden_states = nn.functional.pad( hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens)) hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad( router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens)) router_logits, (0, 0, 0, padding_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states, if tp_size > 1:
tp_size, chunk_hidden_states = torch.tensor_split(hidden_states,
dim=0) tp_size,
chunk_router_logits = torch.tensor_split(router_logits, dim=0)
tp_size, chunk_router_logits = torch.tensor_split(router_logits,
dim=0) tp_size,
tp_rank = get_tensor_model_parallel_rank() dim=0)
hidden_states = chunk_hidden_states[tp_rank] chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
router_logits = chunk_router_logits[tp_rank] tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1: if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather: if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1 # NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled: if not self.torchair_graph_enabled:
attn_metadata = get_forward_context().attn_metadata max_tokens_across_dp = forward_context.max_tokens_across_dp
if attn_metadata is not None: if num_tokens < max_tokens_across_dp:
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp hidden_states = nn.functional.pad(
if num_tokens < max_num_tokens_across_dp: hidden_states,
hidden_states = nn.functional.pad( (0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states, if not self.rm_router_logits:
(0, 0, 0, router_logits = nn.functional.pad(
max_num_tokens_across_dp - num_tokens)) router_logits,
if not self.rm_router_logits: (0, 0, 0, max_tokens_across_dp - num_tokens))
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0) hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits: if self.rm_router_logits:
router_logits, _ = gate(hidden_states) router_logits, _ = gate(hidden_states)
@@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num, global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts if self.torchair_graph_enabled shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None, and self.enable_multistream_moe and not is_prefill else None,
mc2_mask=mc2_mask,
) )
if shared_experts: if shared_experts:
if isinstance(e_hidden_states, tuple): if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states e_hidden_states, shared_hidden_states = e_hidden_states
if (tp_size > 1 and fused_moe_state not in [ if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast FusedMoEState.NaiveMulticast
] and not replace_allreduce): ] and not replace_allreduce):
dist.all_gather(list(chunk_hidden_states), e_hidden_states, if tp_size > 1:
self.tp_group) dist.all_gather(list(chunk_hidden_states), e_hidden_states,
final_hidden_states = torch.cat(chunk_hidden_states, dim=0) self.tp_group)
if num_tokens < tp_size: final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens] final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states) dispose_tensor(e_hidden_states)
elif self.dp_size > 1: elif self.dp_size > 1:

View File

@@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional
from vllm.logger import logger from vllm.logger import logger
from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot, from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
wrapper_rmsnorm_init)
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod) AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -81,9 +80,6 @@ class VLLMAscendQuantizer:
VLLMAscendQuantizer.apply_patch( VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.layernorm.RMSNorm", "vllm.model_executor.layers.layernorm.RMSNorm",
"forward_oot", [wrapper_rmsnorm_forward_oot]) "forward_oot", [wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch(
"vllm_ascend.worker.model_runner.NPUModelRunnerBase",
"load_model", [wrapper_load_model])
break break
VLLMAscendQuantizer.patched = True VLLMAscendQuantizer.patched = True
logger.info("Using the vLLM Ascend Quantizer version now!") logger.info("Using the vLLM Ascend Quantizer version now!")

View File

@@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch_npu import torch_npu
from vllm.distributed import GroupCoordinator from vllm.distributed import GroupCoordinator, get_ep_group
from vllm.distributed.parallel_state import get_ep_group from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState, from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_fused_moe_state) dispose_tensor, get_ascend_soc_version)
def apply_mlp(hidden_states: torch.Tensor, def apply_mlp(hidden_states: torch.Tensor,
@@ -118,10 +120,29 @@ def fused_experts_with_mc2(
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None, shared_experts: Optional[Any] = None,
is_torchair: bool = False,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert mc2_mask is not None
if log2phy is not None: if log2phy is not None:
topk_ids = log2phy[topk_ids] topk_ids = log2phy[topk_ids]
global_bs = 0
quant_mode = 2
ep_group = get_mc2_group()
ep_rank_id = ep_group.rank_in_group
ep_world_size = ep_group.world_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
if (expert_map is not None): if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num moe_expert_num = len(expert_map) + global_redundant_expert_num
else: else:
@@ -133,47 +154,43 @@ def fused_experts_with_mc2(
"expert_shard_type": 0, "expert_shard_type": 0,
"shared_expert_rank_num": 0, "shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num, "moe_expert_num": moe_expert_num,
"global_bs": global_bs, "global_bs": 0,
"expert_scales": topk_weights.to(torch.float32),
} }
rank = torch.distributed.get_rank()
quant_mode = 2
ep_group = get_ep_group().device_group
local_rank = torch.distributed.get_rank(group=ep_group)
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
world_size = torch.distributed.get_world_size()
tp_size = world_size // all_to_all_group_size
tp_rank = rank % tp_size
stage1_kwargs = { stage1_kwargs = {
"scales": None, "scales": None,
"quant_mode": quant_mode, "quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name, "group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size, "ep_world_size": ep_world_size,
"ep_rank_id": local_rank, "ep_rank_id": ep_rank_id,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
} }
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs) kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream()) # comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[ expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:7] 0:5]
if shared_experts is not None: if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0): with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights) npu_wait_tensor(quantized_x_for_share, expand_x)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) shared_act_out = shared_experts.act_fn(
npu_wait_tensor(shared_gate_up[0], expand_x) (quantized_x_for_share, dynamic_scale_for_share))
shared_act = shared_experts.act_fn(shared_gate_up) shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
# `expand_x` will be disposed in the `apply_mlp` function
down_out_list = apply_mlp(expand_x, down_out_list = apply_mlp(expand_x,
w1, w1,
w1_scale, w1_scale,
@@ -186,13 +203,11 @@ def fused_experts_with_mc2(
kwargs_mc2 = { kwargs_mc2 = {
"expand_x": down_out_list, "expand_x": down_out_list,
"expert_ids": topk_ids, "expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32), "expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0, "expert_shard_type": 0,
"shared_expert_rank_num": 0, "shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num, "moe_expert_num": moe_expert_num,
"global_bs": 0, "global_bs": 0,
"expand_scales": expand_scales,
} }
tp_recv_counts = torch.empty(1, tp_recv_counts = torch.empty(1,
dtype=torch.int32, dtype=torch.int32,
@@ -200,24 +215,43 @@ def fused_experts_with_mc2(
stage3_kwargs = { stage3_kwargs = {
"ep_send_counts": ep_recv_counts, "ep_send_counts": ep_recv_counts,
"group_ep": moe_all_to_all_group_name, "group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size, "ep_world_size": ep_world_size,
"ep_rank_id": local_rank, "ep_rank_id": ep_rank_id,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
} }
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None: if shared_experts is None:
return hidden_states return hidden_states
else: else:
with npu_stream_switch("moe_secondary", 0): with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_act[0], down_out_list) npu_wait_tensor(shared_act, down_out_list)
shared_output, _ = shared_experts.down_proj(shared_act) shared_output, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
return hidden_states, shared_output return hidden_states, shared_output
@@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try: try:
device_group = self.ep_group.device_group device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group) local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu")) backend = device_group._get_backend(torch.device("npu"))
@@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size, fused_moe_state = get_forward_context().fused_moe_state
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.AllGatherEP: if fused_moe_state == FusedMoEState.AllGatherEP:
return fused_experts_with_allgather( return fused_experts_with_allgather(
hidden_states=x, hidden_states=x,
@@ -782,7 +815,9 @@ class AscendW8A8DynamicFusedMoEMethod:
moe_all_to_all_group_name=self.moe_all_to_all_group_name, moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy, log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num, global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts) shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [ elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]: ]:

View File

@@ -52,13 +52,3 @@ class NPUTorchairWorker(NPUWorker):
self.model_runner.new_kv_cache_bytes = available_kv_cache_memory self.model_runner.new_kv_cache_bytes = available_kv_cache_memory
return available_kv_cache_memory return available_kv_cache_memory
def _get_max_num_tokens_and_with_prefill(self):
"""Override _get_max_num_tokens_and_with_prefill to update max_num_tokens."""
max_num_tokens, with_prefill = super(
)._get_max_num_tokens_and_with_prefill()
if not with_prefill:
max_num_tokens = self.model_runner.select_torchair_padded_batch_size(
max_num_tokens)
return max_num_tokens, with_prefill

View File

@@ -429,15 +429,6 @@ def npu_prefetch(input: torch.Tensor,
torch_npu.npu_prefetch(input, dependency, max_size) torch_npu.npu_prefetch(input, dependency, max_size)
# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
AllGatherEP = 3
NaiveMulticast = 4
# TODO(ttanzhiqiang): rm_router_logits # TODO(ttanzhiqiang): rm_router_logits
# dp>1 will trigger # dp>1 will trigger
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. # In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
@@ -468,26 +459,6 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool):
return False return False
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return FusedMoEState.AllGatherEP
elif ep_size == 1:
if with_prefill:
return FusedMoEState.NaiveMulticast
else:
return FusedMoEState.AllGather
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
elif ep_size < 16 or with_prefill:
return FusedMoEState.All2All
else:
return FusedMoEState.MC2
def register_ascend_customop(): def register_ascend_customop():
"""Register Ascend CustomOP """Register Ascend CustomOP
@@ -506,3 +477,30 @@ def register_ascend_customop():
# NOTE: Keep this at last to ensure all custom actions are registered # NOTE: Keep this at last to ensure all custom actions are registered
_ASCEND_CUSTOMOP_IS_REIGISTERED = True _ASCEND_CUSTOMOP_IS_REIGISTERED = True
# TODO(zzzzwwjj): It will be judged with _build_info afterwards.
class AscendSocVersion(Enum):
A2 = 0
A3 = 1
UNDEFINED = 2
_ascend_soc_version = None
def init_ascend_soc_version():
soc_version = torch_npu.npu.get_soc_version()
global _ascend_soc_version
if 220 <= soc_version <= 225:
_ascend_soc_version = AscendSocVersion.A2
elif 250 <= soc_version <= 255:
_ascend_soc_version = AscendSocVersion.A3
else:
_ascend_soc_version = AscendSocVersion.UNDEFINED
def get_ascend_soc_version():
global _ascend_soc_version
assert _ascend_soc_version is not None
return _ascend_soc_version

View File

@@ -7,13 +7,13 @@ from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig, from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config) get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -142,9 +142,9 @@ class EagleProposer:
self.positions[:num_tokens] = target_positions.to(device) self.positions[:num_tokens] = target_positions.to(device)
self.hidden_states[:num_tokens] = target_hidden_states self.hidden_states[:num_tokens] = target_hidden_states
attn_metadata.block_tables = block_table.to(device) attn_metadata.block_tables = block_table.to(device)
with set_forward_context(attn_metadata, with set_ascend_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens): num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens], positions=self.positions[:num_input_tokens],
@@ -239,9 +239,9 @@ class EagleProposer:
attn_metadata.attn_mask = attn_mask attn_metadata.attn_mask = attn_mask
attn_metadata.block_tables = block_table.to(device) attn_metadata.block_tables = block_table.to(device)
# Run the model. # Run the model.
with set_forward_context(attn_metadata, with set_ascend_forward_context(attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch_size): num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size], input_ids=self.input_ids[:input_batch_size],
@@ -344,8 +344,9 @@ class EagleProposer:
self, self,
num_tokens: int, num_tokens: int,
) -> None: ) -> None:
with set_forward_context(None, self.vllm_config, with set_ascend_forward_context(None,
num_tokens=num_tokens): self.vllm_config,
num_tokens=num_tokens):
self.model( self.model(
input_ids=self.input_ids[:num_tokens], input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens], positions=self.positions[:num_tokens],

View File

@@ -34,7 +34,6 @@ import torch
import torch._dynamo.cache_size import torch._dynamo.cache_size
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
@@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group) get_tp_group)
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
scatter_mm_placeholders) scatter_mm_placeholders)
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import (AscendAttentionState, from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata) AscendMetadata)
@@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._logging.set_logs( torch._logging.set_logs(
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
self.in_profile_run = False
# kv role # kv role
self.is_kv_producer = False self.is_kv_producer = False
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
@@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata() self.input_batch.refresh_sampling_metadata()
def _get_forward_metadata_across_dp( def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int, self,
with_prefill: bool) -> tuple[int, bool]: maybe_padded_num_tokens: int,
forward_metadata = torch.tensor( num_tokens: int,
[total_num_scheduled_tokens, with_prefill], with_prefill: bool,
device="cpu", enable_dbo: bool = False,
dtype=torch.int32) ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
dist.all_reduce(forward_metadata, if self.dp_size == 1:
op=ReduceOp.MAX, return maybe_padded_num_tokens, None, with_prefill, enable_dbo
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0) num_tokens_across_dp = [0] * self.dp_size * 2
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
forward_metadata = torch.tensor(num_tokens_across_dp +
[with_prefill, not enable_dbo],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
with_prefill = bool(forward_metadata[-2])
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
if with_prefill:
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
2]
maybe_padded_num_tokens = num_tokens
else:
num_tokens_across_dp = forward_metadata[:self.dp_size]
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
# `max_tokens_across_dp`, in other situation it is not necessary.
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
self.dp_size,
device="cpu",
dtype=torch.int32)
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
forward_metadata[-1])
def get_eagle_atten_dict( def get_eagle_atten_dict(
self, self,
@@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
] ]
if self.dp_size > 1: maybe_padded_num_tokens = total_num_scheduled_tokens
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here
if self.torchair_graph_enabled and not with_prefill: if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1: maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
padded_batch_size = self.select_torchair_padded_batch_size( total_num_scheduled_tokens)
max_num_tokens) (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
else: enable_dbo) = self._get_forward_metadata_across_dp(
padded_batch_size = self.select_torchair_padded_batch_size( maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
total_num_scheduled_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
extra_builder_kwargs['graph_pad_size'] = graph_pad_size extra_builder_kwargs['graph_pad_size'] = graph_pad_size
@@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_input_tokens] positions = self.mrope_positions[:, :num_input_tokens]
if self.torchair_graph_enabled and not with_prefill: if self.torchair_graph_enabled and not with_prefill:
input_ids = self.input_ids[:padded_batch_size] input_ids = self.input_ids[:padded_num_tokens_across_dp]
positions = self.positions[:padded_batch_size] positions = self.positions[:padded_num_tokens_across_dp]
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
intermediate_tensors = None intermediate_tensors = None
@@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
}) })
# Run forward pass # Run forward pass
with set_forward_context(attn_metadata, with set_ascend_forward_context(
self.vllm_config, attn_metadata,
num_tokens=num_input_tokens): self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"): with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output) self.maybe_setup_kv_connector(scheduler_output)
model_kwargs = {} model_kwargs = {}
@@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
ACL_FORMAT_FRACTAL_NZ) ACL_FORMAT_FRACTAL_NZ)
compiled_model = self._get_torchair_lazy_compiled_model( compiled_model = self._get_torchair_lazy_compiled_model(
padded_batch_size) padded_num_tokens_across_dp)
hidden_states = compiled_model( hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
@@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def kv_connector_no_forward( def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
with set_forward_context(None, self.vllm_config): with set_ascend_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output) self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = ( finished_sending, finished_recving = (
self.get_finished_kv_transfer(scheduler_output)) self.get_finished_kv_transfer(scheduler_output))
@@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
is_compile: bool = False, skip_attn: bool = True,
with_prefill: bool = True, with_prefill: bool = False,
is_torchair_compile: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
maybe_padded_num_tokens = num_tokens
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
enable_dbo) = self._get_forward_metadata_across_dp(
maybe_padded_num_tokens, num_tokens, with_prefill, False)
# Set num_scheduled_tokens based on num_tokens and max_num_seqs # Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively # for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total. # has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens assert num_tokens <= self.scheduler_config.max_num_batched_tokens
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs
@@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.is_kv_producer: if self.is_kv_producer:
with_prefill = True with_prefill = True
# NOTE: If torchair graph mode and not with_prefill,
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
# TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata
attn_metadata = None
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
model = self.model model = self.model
@@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items() for k, v in self.intermediate_tensors.items()
}) })
with set_forward_context(None, with set_ascend_forward_context(
self.vllm_config, attn_metadata,
num_tokens=num_tokens): self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
num_actual_tokens=0,
):
model_kwargs = {}
if self.torchair_graph_enabled and not with_prefill: if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
# Only mark static while compiling # Only mark static while compiling
if is_compile: if is_torchair_compile:
torch._dynamo.mark_static(input_ids) torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions) torch._dynamo.mark_static(positions)
torch._dynamo.mark_static( torch._dynamo.mark_static(
attn_metadata.decode.block_table) attn_metadata.decode.block_table)
torch._dynamo.mark_static( torch._dynamo.mark_static(
attn_metadata.decode.input_positions) attn_metadata.decode.input_positions)
torch._dynamo.mark_static(
get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping) torch._dynamo.mark_static(attn_metadata.slot_mapping)
for kv in self.kv_caches: for kv in self.kv_caches:
assert isinstance( assert isinstance(
@@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
compiled_model = self._get_torchair_lazy_compiled_model( compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens) num_tokens)
model_kwargs["kv_caches"] = self.kv_caches
model_kwargs["attn_metadata"] = attn_metadata
hidden_states = compiled_model( hidden_states = compiled_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=None, inputs_embeds=None,
kv_caches=self.kv_caches, **model_kwargs,
attn_metadata=attn_metadata,
) )
else: else:
maybe_converting_weight_acl_format(self.model, maybe_converting_weight_acl_format(self.model,
@@ -1787,9 +1848,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.drafter.dummy_run(num_tokens) self.drafter.dummy_run(num_tokens)
return hidden_states return hidden_states
@contextmanager
def set_in_profile_run(self):
self.in_profile_run = True
try:
yield
finally:
self.in_profile_run = False
def profile_run(self) -> None: def profile_run(self) -> None:
# Trigger compilation for general shape. # Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens) with self.set_in_profile_run():
hidden_states = self._dummy_run(self.max_num_tokens,
with_prefill=True)
output = None output = None
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if self.is_pooling_model: if self.is_pooling_model:
@@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)): for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)):
for _ in range(self.vllm_config.compilation_config. for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups): cudagraph_num_of_warmups):
self._dummy_run(num_tokens, self._dummy_run(num_tokens, is_torchair_compile=True)
is_compile=True, self._dummy_run(num_tokens, is_torchair_compile=True)
with_prefill=False)
self._dummy_run(num_tokens, is_compile=True, with_prefill=False)
logger.info("Batchsize %d is compiled successfully: %d/%d.", logger.info("Batchsize %d is compiled successfully: %d/%d.",
num_tokens, idx + 1, len(torchair_graph_batch_sizes)) num_tokens, idx + 1, len(torchair_graph_batch_sizes))
@@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes. # Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes # Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes. # can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device): with graph_capture(device=self.device):
for num_tokens in reversed(self.aclgraph_batch_sizes): for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config. for _ in range(self.vllm_config.compilation_config.

View File

@@ -2,12 +2,12 @@ import torch
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config, from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config) set_current_vllm_config)
from vllm.forward_context import set_forward_context
from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype) process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
@@ -117,7 +117,7 @@ class MtpProposer:
query_start_loc=cu_num_tokens, query_start_loc=cu_num_tokens,
) )
with set_forward_context(attn_metadata, self.vllm_config): with set_ascend_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=target_positions, positions=target_positions,

View File

@@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase
from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib, from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
vllm_version_is) try_register_lib, vllm_version_is)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if not vllm_version_is("0.10.0"): if not vllm_version_is("0.10.0"):
@@ -134,6 +135,7 @@ class NPUWorker(WorkerBase):
NPUPlatform.empty_cache() NPUPlatform.empty_cache()
self.init_npu_memory = NPUPlatform.mem_get_info()[0] self.init_npu_memory = NPUPlatform.mem_get_info()[0]
init_ascend_soc_version()
# Initialize the distributed environment. # Initialize the distributed environment.
self._init_worker_distributed_environment() self._init_worker_distributed_environment()
# Set random seed. # Set random seed.
@@ -272,20 +274,8 @@ class NPUWorker(WorkerBase):
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
return self.model_runner.pin_lora(lora_id) return self.model_runner.pin_lora(lora_id)
def _get_max_num_tokens_and_with_prefill(self):
max_num_tokens = 1
with_prefill = False
if self.model_runner.dp_size > 1:
max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
return max_num_tokens, with_prefill
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill( self.model_runner._dummy_run(1)
)
self.model_runner._dummy_run(max_num_tokens,
is_compile=False,
with_prefill=with_prefill)
def _init_worker_distributed_environment(self) -> None: def _init_worker_distributed_environment(self) -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
@@ -295,6 +285,7 @@ class NPUWorker(WorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size) self.parallel_config.pipeline_parallel_size)
init_ascend_model_parallel(self.parallel_config)
ensure_kv_transfer_initialized(self.vllm_config) ensure_kv_transfer_initialized(self.vllm_config)
def _init_profiler(self): def _init_profiler(self):