From ba3dfbd59e43b9071895f483d12c034d8538ced0 Mon Sep 17 00:00:00 2001 From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Date: Mon, 28 Jul 2025 14:06:20 +0800 Subject: [PATCH] [main][refactor] Refactoring forward_context and model_runner_v1 (#1979) ### What this PR does / why we need it? This PR refactors forward_context and model_runner_v1: it adds some context that is necessary for model inference into forward_context, and refactors the dummy_run logic to make it more reasonable. Some details for this PR: Add `ascend_forward_context`; Update mc2_v2 op, and support `active_mask` param; Update scripts in examples dir; refactor `dummy_run` logic; Add soc_version for A2 and A3; ### Does this PR introduce _any_ user-facing change? No user-facing change. ### How was this patch tested? - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/57c22e57f989b466a46a990243bb7f072a668b7f Signed-off-by: zzzzwwjj <1183291235@qq.com> --- examples/offline_dualbatch_overlap_npu.py | 2 +- examples/run_dp_server.sh | 37 ++-- tests/ut/models/test_deepseek_v2.py | 14 +- tests/ut/ops/test_fused_ops.py | 55 ++++-- vllm_ascend/ascend_forward_context.py | 117 ++++++++++++ vllm_ascend/attention/attention_v1.py | 34 ++-- .../attention/attention_v1_torchair.py | 19 +- vllm_ascend/attention/mla_v1.py | 11 +- vllm_ascend/distributed/parallel_state.py | 48 +++++ vllm_ascend/models/deepseek_dbo.py | 16 +- vllm_ascend/models/deepseek_v2.py | 21 +-- vllm_ascend/models/pangu_moe.py | 8 +- vllm_ascend/multistream/ms_split.py | 2 - vllm_ascend/ops/fused_moe.py | 177 +++++++++++------- vllm_ascend/quantization/quantizer.py | 6 +- vllm_ascend/quantization/w8a8_dynamic.py | 131 ++++++++----- vllm_ascend/torchair/torchair_worker.py | 10 - vllm_ascend/utils.py | 56 +++--- vllm_ascend/worker/eagle_proposer_v1.py | 19 +- vllm_ascend/worker/model_runner_v1.py | 168 ++++++++++++----- vllm_ascend/worker/mtp_proposer_v1.py | 4 +- vllm_ascend/worker/worker_v1.py | 21 +-- 22 files
changed, 629 insertions(+), 347 deletions(-) create mode 100644 vllm_ascend/ascend_forward_context.py create mode 100644 vllm_ascend/distributed/parallel_state.py diff --git a/examples/offline_dualbatch_overlap_npu.py b/examples/offline_dualbatch_overlap_npu.py index 2cc5213..3829d6a 100644 --- a/examples/offline_dualbatch_overlap_npu.py +++ b/examples/offline_dualbatch_overlap_npu.py @@ -21,6 +21,7 @@ def main(): tensor_parallel_size=2, max_model_len=4096, trust_remote_code=True, + enable_expert_parallel=True, additional_config={ "torchair_graph_config": { "enabled": False @@ -28,7 +29,6 @@ def main(): "ascend_scheduler_config": { "enabled": True }, - "expert_tensor_parallel_size": 1 }) # Generate texts from the prompts. The output is a list of RequestOutput diff --git a/examples/run_dp_server.sh b/examples/run_dp_server.sh index e2bf4c8..eb3cfbf 100644 --- a/examples/run_dp_server.sh +++ b/examples/run_dp_server.sh @@ -1,3 +1,7 @@ +rm -rf ./.torchair_cache/ +rm -rf ./dynamo_* +rm -rf /root/ascend/log/debug/plog/* + export HCCL_IF_IP=2.0.0.0 export GLOO_SOCKET_IFNAME="enp189s0f0" export TP_SOCKET_IFNAME="enp189s0f0" @@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0" export OMP_PROC_BIND=false export OMP_NUM_THREADS=100 -export VLLM_USE_V1=0 - -export ASCEND_RT_VISIBLE_DEVICES=0,1 -export VLLM_DP_SIZE=2 -export VLLM_DP_RANK=0 -export VLLM_DP_MASTER_IP="2.0.0.0" -export VLLM_DP_MASTER_PORT=40001 -export VLLM_DP_PROXY_IP="2.0.0.0" -export VLLM_DP_PROXY_PORT=30002 -export VLLM_DP_MONITOR_PORT=30003 -export VLLM_HTTP_PORT=20001 +export VLLM_USE_V1=1 +export ASCEND_LAUNCH_BLOCKING=0 vllm serve /data/weights/Qwen2.5-0.5B-Instruct \ --host 0.0.0.0 \ - --port 20001 \ - --tensor-parallel-size 1 \ - --seed 1024 \ + --port 20002 \ --served-model-name Qwen \ - --max-model-len 2000 \ - --max-num-batched-tokens 2000 \ + --data-parallel-size 4 \ + --data-parallel-size-local 4 \ + --data-parallel-address 2.0.0.0 \ + --data-parallel-rpc-port 13389 \ + --tensor-parallel-size 
4 \ + --enable-expert-parallel \ + --no-enable-prefix-caching \ + --max-num-seqs 16 \ + --max-model-len 4096 \ + --max-num-batched-tokens 4096 \ + --gpu-memory-utilization 0.9 \ --trust-remote-code \ - --gpu-memory-utilization 0.9 \ \ No newline at end of file + --enforce-eager \ + --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}' diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index f3a7d1a..3902b5b 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ b/tests/ut/models/test_deepseek_v2.py @@ -114,7 +114,16 @@ def mock_distributed(): return_value=Mock(is_first_rank=False, is_last_rank=False)), \ patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, - _PP=pp_group): + _PP=pp_group), \ + patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group): + yield + + +@pytest.fixture +def mock_forward_context(): + forward_context = Mock(in_profile_run=False, with_prefill=False) + with patch("vllm_ascend.models.deepseek_v2.get_forward_context", + return_value=forward_context): yield @@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config): quant_config=None) -def test_custom_deepseek_v2_moe(mock_distributed, base_config): +def test_custom_deepseek_v2_moe(mock_distributed, base_config, + mock_forward_context): base_config.n_shared_experts = 1 moe = CustomDeepseekV2MoE(config=base_config, quant_config=None, diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index eb265b9..2b6f1aa 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -18,16 +18,18 @@ from unittest.mock import MagicMock, patch import pytest import torch import torch.nn as nn +import torch_npu from pytest_mock import MockerFixture 
+from vllm_ascend.ascend_forward_context import get_fused_moe_state from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) -from vllm_ascend.utils import adapt_patch # noqa E402 +from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 adapt_patch(True) -def mock_ep_group(mocker): +def mock_ep_and_mc2_group(mocker): mock_group = mocker.MagicMock() mock_group.rank_in_group = 0 mock_group.rank = 0 @@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture): with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ - patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \ + patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ @@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture): return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ patch('vllm_ascend.ops.fused_moe.get_forward_context', return_value=MagicMock( - attn_metadata=MagicMock(max_num_tokens_across_dp=10), + max_tokens_across_dp=10, dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) )), \ patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', @@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture): patch("torch_npu.npu_moe_finalize_routing", return_value=( torch.randn(16, 2) )): - yield + if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'): + with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=( + torch.randn(16, 2))), \ + patch("torch_npu.npu_moe_distribute_combine_v2", return_value=( + 
torch.randn(16, 2))): + yield + else: + yield @pytest.fixture @@ -237,11 +247,16 @@ class TestAscendFusedMoe: moe.moe_parallel_config.ep_size = 1 moe.quant_method = MockQuantMethod(shared_experts, num_tokens) - output = moe.forward(inputs, - router_logits, - is_prefill=is_prefill, - top_k=top_k, - shared_experts=shared_experts) + forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens, + dtype=torch.bool), + padded_num_tokens=num_tokens) + with patch("vllm_ascend.ops.fused_moe.get_forward_context", + return_value=forward_context): + output = moe.forward(inputs, + router_logits, + is_prefill=is_prefill, + top_k=top_k, + shared_experts=shared_experts) moe.quant_method.apply.assert_called_once() @@ -288,15 +303,20 @@ class TestAscendUnquantizedFusedMoEMethod: def test_apply_without_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): """ - 1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all + 1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all 2 test use_select_experts and fused_experts 3 test use select_gating_topk_softmax_experts and fused_experts 4 test use select_experts and fused_experts_with_all2all_buffer """ global_num_experts, ep_size, select_softmax = others_param + is_prefill = False + is_deepseek_v3_r1 = global_num_experts == 256 + forward_context = MagicMock(fused_moe_state=get_fused_moe_state( + ep_size, is_prefill, is_deepseek_v3_r1)) with patch( "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS", - select_softmax): + select_softmax), \ + patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) router_logits = torch.randn(8, 8) @@ -309,7 +329,7 @@ class TestAscendUnquantizedFusedMoEMethod: top_k=2, renormalize=True, global_num_experts=global_num_experts, - is_prefill=False) + is_prefill=is_prefill) if ep_size == 1: assert result.shape == (16, 2) @@ -327,8 +347,13 @@ class TestAscendUnquantizedFusedMoEMethod: 4 
test use_select_experts and fused_experts """ ep_size, alltoall_buffer = others_param + is_prefill = False + forward_context = MagicMock( + fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True)) with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER", - alltoall_buffer): + alltoall_buffer), \ + patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ + patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3): expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) @@ -347,7 +372,7 @@ class TestAscendUnquantizedFusedMoEMethod: renormalize=True, global_num_experts=128, expert_map=expert_map, - is_prefill=False) + is_prefill=is_prefill) if ep_size == 16 or ep_size == 1: assert result.shape == (16, 2) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py new file mode 100644 index 0000000..83e4ee8 --- /dev/null +++ b/vllm_ascend/ascend_forward_context.py @@ -0,0 +1,117 @@ +import math +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional + +import torch +from vllm.config import VllmConfig +from vllm.distributed import get_dp_group, get_ep_group, get_tp_group +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.platforms import current_platform + +import vllm_ascend.envs as envs + + +class FusedMoEState(Enum): + AllGather = 0 + All2All = 1 + MC2 = 2 + AllGatherEP = 3 + NaiveMulticast = 4 + + +# TODO(zzzzwwjj): add soc_version to choose branch +def get_fused_moe_state(ep_size: int, with_prefill: bool, + is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return FusedMoEState.AllGatherEP + elif ep_size == 1: + if with_prefill: + return 
FusedMoEState.NaiveMulticast + else: + return FusedMoEState.AllGather + # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. + elif ep_size < 16 or with_prefill: + return FusedMoEState.All2All + else: + return FusedMoEState.MC2 + + +@contextmanager +def set_ascend_forward_context( + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False, + num_actual_tokens: Optional[int] = None, +): + """A context manager that stores the current forward context, + can be attention metadata, etc. + We add some additional param into forward_context. + """ + with set_forward_context(attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp): + forward_context = get_forward_context() + forward_context.with_prefill = with_prefill + ep_size = (get_ep_group().world_size if + vllm_config.parallel_config.enable_expert_parallel else 1) + + is_deepseek_v3_r1 = hasattr( + vllm_config.model_config.hf_config, 'n_routed_experts' + ) and vllm_config.model_config.hf_config.n_routed_experts == 256 + fused_moe_state = get_fused_moe_state(ep_size, with_prefill, + is_deepseek_v3_r1) + + forward_context.fused_moe_state = fused_moe_state + + forward_context.in_profile_run = in_profile_run + + # NOTE: This cannot be set using set_forward_context + # due to multiple warmups before actual capturing + forward_context.capturing = False + + if num_tokens is None and attn_metadata is not None: + if hasattr(attn_metadata, 'num_actual_tokens'): + # for v1 engine + num_tokens = attn_metadata.num_actual_tokens + else: + # for v0 engine + num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + + if num_actual_tokens is None: + num_actual_tokens = num_tokens + + dp_world_size = get_dp_group().world_size + if dp_world_size > 1 and 
forward_context.dp_metadata is not None: + max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( + ) + else: + max_tokens_across_dp = num_tokens + + forward_context.max_tokens_across_dp = max_tokens_across_dp + + if num_tokens is not None: + tp_world_size = get_tp_group().world_size + # NOTE: token num which need to pad to when mc2 + forward_context.padded_num_tokens = math.ceil( + max_tokens_across_dp / tp_world_size) * tp_world_size + + mc2_mask = torch.zeros(forward_context.padded_num_tokens, + dtype=torch.bool, + device=current_platform.device_type) + mc2_mask[:num_actual_tokens] = True + forward_context.mc2_mask = mc2_mask + + try: + yield + finally: + pass diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index b29017e..dc92136 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -119,6 +119,7 @@ class AscendAttentionState(Enum): @dataclass class AscendMetadata: + # **************************** Basic Properties **************************** attn_mask: Optional[torch.Tensor] = None # Current state of this attention run. 
@@ -149,11 +150,6 @@ class AscendMetadata: # (num_tokens,) slot_mapping: torch.Tensor = None - # ************************* DP Related Properties ************************** - with_prefill_across_dp: bool = False - # Maximum number of tokens across dp group - max_num_tokens_across_dp: int = 0 - class AscendAttentionMetadataBuilder: @@ -164,12 +160,7 @@ class AscendAttentionMetadataBuilder: scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, - num_reqs, - num_actual_tokens, - max_query_len, - max_num_tokens_across_dp: int = 0, - with_prefill_across_dp: bool = False): + def build(self, num_reqs, num_actual_tokens, max_query_len): block_table = self.runner.input_batch.block_table[0].get_device_tensor( ) @@ -196,18 +187,15 @@ class AscendAttentionMetadataBuilder: attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), ACL_FORMAT_FRACTAL_NZ) - attn_metadata = AscendMetadata( - num_actual_tokens=num_actual_tokens, - block_tables=block_table, - query_start_loc=query_start_loc, - query_lens=query_lens, - seq_lens=seq_lens, - max_query_len=max_query_len, - slot_mapping=slot_mapping, - attn_mask=attn_mask, - attn_state=attn_state, - max_num_tokens_across_dp=max_num_tokens_across_dp, - with_prefill_across_dp=with_prefill_across_dp) + attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens, + block_tables=block_table, + query_start_loc=query_start_loc, + query_lens=query_lens, + seq_lens=seq_lens, + max_query_len=max_query_len, + slot_mapping=slot_mapping, + attn_mask=attn_mask, + attn_state=attn_state) return attn_metadata diff --git a/vllm_ascend/attention/attention_v1_torchair.py b/vllm_ascend/attention/attention_v1_torchair.py index 9bfc038..3b4c7a9 100644 --- a/vllm_ascend/attention/attention_v1_torchair.py +++ b/vllm_ascend/attention/attention_v1_torchair.py @@ -127,8 +127,6 @@ class AscendTorchairMetadata: query_start_loc: torch.Tensor query_lens: torch.Tensor seq_lens: torch.Tensor - # max value of number of tokens across dp group 
- max_num_tokens_across_dp: int = 0 # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None # (num_tokens,). The indices of the token slots that input tokens will be @@ -139,7 +137,7 @@ class AscendTorchairMetadata: # Current state of this attention run. attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill attn_mask: Optional[torch.Tensor] = None - with_prefill_across_dp: bool = False + decode: Optional[AscendDecodeMetadata] = None @@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder: return graph_block_tables[:num_seqs, :max_blocks] - def build_dummy(self, num_reqs: int, - num_actual_tokens: int) -> AscendTorchairMetadata: + def build_torchair_graph_dummy( + self, num_reqs: int, + num_actual_tokens: int) -> AscendTorchairMetadata: device = self.runner.device _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), @@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder: seq_lens=seq_lens, slot_mapping=slot_mapping, attn_state=AscendAttentionState.DecodeOnly, - max_num_tokens_across_dp=num_reqs, decode=decode_metadata) return attn_metadata @@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder: num_reqs, num_actual_tokens, max_query_len, - graph_pad_size: int = -1, - max_num_tokens_across_dp: int = 0, - with_prefill_across_dp: bool = False): + graph_pad_size: int = -1): device = self.runner.device @@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder: pad_value = 1 padded_seq_lens = seq_lens.tolist() + [pad_value ] * graph_pad_size - max_num_tokens_across_dp = len(padded_seq_lens) seq_lens = torch.from_numpy( np.array(padded_seq_lens).astype(np.int32)) @@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder: max_query_len=max_query_len, slot_mapping=slot_mapping, attn_mask=attn_mask, - attn_state=attn_state, - max_num_tokens_across_dp=max_num_tokens_across_dp, - with_prefill_across_dp=with_prefill_across_dp) + 
attn_state=attn_state) return attn_metadata diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 60c281a..537cd8e 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -126,9 +126,6 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. - max_num_tokens_across_dp: int = 0 - with_prefill_across_dp: bool = False - query_lens: Optional[list[int]] = None # The dimension of the attention heads head_dim: Optional[int] = None @@ -302,8 +299,8 @@ class AscendMLAMetadataBuilder: return graph_block_tables[:num_seqs, :max_blocks] - def build_dummy(self, num_reqs: int, - num_actual_tokens: int) -> AscendMLAMetadata: + def build_torchair_graph_dummy( + self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata: device = self.runner.device _, max_blocks = self.runner.graph_block_tables.shape block_table = torch.zeros((num_reqs, max_blocks), @@ -353,8 +350,6 @@ class AscendMLAMetadataBuilder: num_actual_tokens: int, max_query_len: int, graph_pad_size: int = -1, - max_num_tokens_across_dp: int = 0, - with_prefill_across_dp: bool = False, query_start_loc: torch.Tensor = None, ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs @@ -498,8 +493,6 @@ class AscendMLAMetadataBuilder: query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, - max_num_tokens_across_dp=max_num_tokens_across_dp, - with_prefill_across_dp=with_prefill_across_dp, ) diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py new file mode 100644 index 0000000..1e31359 --- /dev/null +++ b/vllm_ascend/distributed/parallel_state.py @@ -0,0 +1,48 @@ +from typing import Optional + +import torch +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group, + init_model_parallel_group) + +# Currently, mc2 op need their own group coordinator. 
+_MC2: Optional[GroupCoordinator] = None + + +def get_mc2_group() -> GroupCoordinator: + assert _MC2 is not None, ("mc2 group is not initialized") + return _MC2 + + +def model_parallel_initialized(): + return (_MC2 is not None) + + +def init_ascend_model_parallel(parallel_config: ParallelConfig, ): + if model_parallel_initialized(): + return + assert torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() + backend = torch.distributed.get_backend(get_world_group().device_group) + + # The layout of all ranks: ExternalDP * EP + # ExternalDP is the data parallel group that is not part of the model, + # every dp rank can generate independently (in verl integration). + all_ranks = torch.arange(world_size).reshape( + -1, parallel_config.data_parallel_size * + parallel_config.tensor_parallel_size) + global _MC2 + group_ranks = all_ranks.unbind(0) + group_ranks = [x.tolist() for x in group_ranks] + + _MC2 = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + group_name="mc2") + + +def destroy_ascend_model_parallel(): + global _MC2 + if _MC2: + _MC2.destroy() + _MC2 = None diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 3d77785..13e5efa 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -174,20 +174,12 @@ class CustomDeepseekDBOMoE(nn.Module): self, hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + forward_context = get_forward_context() # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. 
- if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill old_hidden_states = hidden_states.clone() diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 9b28278..cb0649a 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -377,20 +377,14 @@ class CustomDeepseekV2MoE(nn.Module): attn_metadata: Optional[AttentionMetadata] = None, replace_allreduce: bool = False) -> torch.Tensor: - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata + forward_context = get_forward_context() # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure # the behaviour aligned between dummy_run and normal model_execute. 
if self.kv_consumer: @@ -572,9 +566,10 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() enable_multistream_mla = (self.enable_multistream_mla and attn_metadata is not None - and not attn_metadata.with_prefill_across_dp + and not forward_context.with_prefill and attn_metadata.num_decodes > 0) forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} if self.q_lora_rank is not None: diff --git a/vllm_ascend/models/pangu_moe.py b/vllm_ascend/models/pangu_moe.py index 0d2d9a6..f31650f 100644 --- a/vllm_ascend/models/pangu_moe.py +++ b/vllm_ascend/models/pangu_moe.py @@ -837,12 +837,8 @@ class PanguProMoEModel(nn.Module): # if attn_meatadata is not passed, we try to get it from forward_context. if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata - if attn_metadata is None: - # when attn_meatadata is None, it is in profile_run. num_tokens on all dp ranks - # are same. - max_tokens_across_dp = hidden_states.shape[0] - else: - max_tokens_across_dp = attn_metadata.max_num_tokens_across_dp + + max_tokens_across_dp = get_forward_context().max_tokens_across_dp tp_size = get_tp_group().world_size # reduce scatter will split the input tensor into equal sizes and then scatter them on all ranks. 
diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 430f57b..3af6337 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -223,7 +223,6 @@ def model_input_split_v1_mla_attn( attn_mask=attn_mask_pre, prefill=prefill_pre, decode=decode_pre, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, ) attention_metadata_post = _metadata_cls( num_actual_tokens=attn_metadata.num_actual_tokens - token_index, @@ -240,6 +239,5 @@ def model_input_split_v1_mla_attn( attn_state=attn_state_post, prefill=prefill_post, decode=decode_post, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, ) return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 8697477..61205ff 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.communication_op import \ data_parallel_reduce_scatter +from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_all_reduce_merge_state, get_fused_moe_state, +from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, + get_all_reduce_merge_state, + get_ascend_soc_version, get_rm_router_logits_state, is_310p) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -127,9 +130,23 @@ def fused_experts_with_mc2( moe_parallel_config: FusedMoEParallelConfig, expert_map: torch.Tensor = None, moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None + 
shared_experts: Optional[Any] = None, + is_torchair: bool = False, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - global_bs = 0 + quant_mode = 0 + ep_rank_id = moe_parallel_config.ep_rank + ep_world_size = moe_parallel_config.ep_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + + enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") + moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, @@ -137,32 +154,35 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": global_bs, + "global_bs": 0, } - rank = torch.distributed.get_rank() - - quant_mode = 0 - ep_rank_id = moe_parallel_config.ep_rank - ep_world_size = moe_parallel_config.ep_size - - tp_world_size = moe_parallel_config.tp_size - tp_rank = rank % tp_world_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, "ep_world_size": ep_world_size, "ep_rank_id": ep_rank_id, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_world_size, - "tp_rank_id": tp_rank, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) + kwargs_mc2.update(stage1_kwargs) - output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 
else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ 0:5] if shared_experts is not None: @@ -205,7 +225,6 @@ def fused_experts_with_mc2( kwargs_mc2 = { "expand_x": down_out_list, "expert_ids": topk_ids, - "expand_idx": expand_idx, "expert_scales": topk_weights.to(torch.float32), "expert_shard_type": 0, "shared_expert_rank_num": 0, @@ -218,15 +237,33 @@ def fused_experts_with_mc2( "group_ep": moe_all_to_all_group_name, "ep_world_size": ep_world_size, "ep_rank_id": ep_rank_id, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_world_size, - "tp_rank_id": tp_rank, } + if enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": assist_info_for_combine, + }) + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) - hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) if shared_experts is None: return hidden_states @@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): super().__init__(moe=moe) vllm_config = get_current_vllm_config() - self.ep_group = get_ep_group() - self.ep_size = self.moe.moe_parallel_config.ep_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - self.local_batch_size = self.global_batch_size // self.ep_size self.max_model_len = 
vllm_config.model_config.max_model_len ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: - device_group = self.ep_group.device_group + device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill, - is_deepseek_v3_r1) + fused_moe_state = get_forward_context().fused_moe_state + if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, @@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, - shared_experts=shared_experts) + shared_experts=shared_experts, + mc2_mask=kwargs.get("mc2_mask", None)) elif fused_moe_state in [ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast ]: @@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE): real_top_k = self.top_k num_tokens, hidden_size = hidden_states.shape - is_deepseek_v3_r1 = self.global_num_experts == 256 - fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill, is_deepseek_v3_r1) + forward_context = get_forward_context() + fused_moe_state = forward_context.fused_moe_state + mc2_mask = forward_context.mc2_mask if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if (tp_size > 
1 and fused_moe_state not in [ + if (fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.NaiveMulticast ] and not replace_allreduce): - if num_tokens < tp_size: + if fused_moe_state in {FusedMoEState.MC2}: + padding_size = forward_context.padded_num_tokens + else: + # TODO: Determine if we can remove the padding + padding_size = tp_size + if num_tokens < padding_size: hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, tp_size - num_tokens)) + hidden_states, (0, 0, 0, padding_size - num_tokens)) router_logits = nn.functional.pad( - router_logits, (0, 0, 0, tp_size - num_tokens)) - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - tp_rank = get_tensor_model_parallel_rank() - hidden_states = chunk_hidden_states[tp_rank] - router_logits = chunk_router_logits[tp_rank] + router_logits, (0, 0, 0, padding_size - num_tokens)) + if tp_size > 1: + chunk_hidden_states = torch.tensor_split(hidden_states, + tp_size, + dim=0) + chunk_router_logits = torch.tensor_split(router_logits, + tp_size, + dim=0) + chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) + tp_rank = get_tensor_model_parallel_rank() + hidden_states = chunk_hidden_states[tp_rank] + router_logits = chunk_router_logits[tp_rank] + mc2_mask = chunk_mc2_mask[tp_rank] if self.dp_size > 1: if fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, - max_num_tokens_across_dp - num_tokens)) - if not self.rm_router_logits: - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, - 
max_num_tokens_across_dp - num_tokens)) + max_tokens_across_dp = forward_context.max_tokens_across_dp + if num_tokens < max_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_tokens_across_dp - num_tokens)) + if not self.rm_router_logits: + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) if self.rm_router_logits: router_logits, _ = gate(hidden_states) @@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE): global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, + mc2_mask=mc2_mask, ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if (tp_size > 1 and fused_moe_state not in [ + if (fused_moe_state not in [ FusedMoEState.AllGather, FusedMoEState.AllGatherEP, FusedMoEState.NaiveMulticast ] and not replace_allreduce): - dist.all_gather(list(chunk_hidden_states), e_hidden_states, - self.tp_group) - final_hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens < tp_size: + if tp_size > 1: + dist.all_gather(list(chunk_hidden_states), e_hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + else: + final_hidden_states = e_hidden_states + if num_tokens < padding_size: final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) elif self.dp_size > 1: diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py index 02f4486..8178d5e 100644 --- a/vllm_ascend/quantization/quantizer.py +++ b/vllm_ascend/quantization/quantizer.py @@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional from vllm.logger import logger -from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot, - wrapper_rmsnorm_init) +from 
.func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod, AscendW8A8LinearMethod) from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod, @@ -81,9 +80,6 @@ class VLLMAscendQuantizer: VLLMAscendQuantizer.apply_patch( "vllm.model_executor.layers.layernorm.RMSNorm", "forward_oot", [wrapper_rmsnorm_forward_oot]) - VLLMAscendQuantizer.apply_patch( - "vllm_ascend.worker.model_runner.NPUModelRunnerBase", - "load_model", [wrapper_load_model]) break VLLMAscendQuantizer.patched = True logger.info("Using the vLLM Ascend Quantizer version now!") diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 261e43b..f1667d0 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu -from vllm.distributed import GroupCoordinator -from vllm.distributed.parallel_state import get_ep_group +from vllm.distributed import GroupCoordinator, get_ep_group +from vllm.forward_context import get_forward_context import vllm_ascend.envs as envs from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState, - dispose_tensor, get_fused_moe_state) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, + dispose_tensor, get_ascend_soc_version) def apply_mlp(hidden_states: torch.Tensor, @@ -118,10 +120,29 @@ def fused_experts_with_mc2( log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, shared_experts: Optional[Any] = None, + is_torchair: bool = 
False, + quantized_x_for_share: Optional[Any] = None, + dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + assert mc2_mask is not None if log2phy is not None: topk_ids = log2phy[topk_ids] - global_bs = 0 + + quant_mode = 2 + ep_group = get_mc2_group() + ep_rank_id = ep_group.rank_in_group + ep_world_size = ep_group.world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + + enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") + if (expert_map is not None): moe_expert_num = len(expert_map) + global_redundant_expert_num else: @@ -133,47 +154,43 @@ def fused_experts_with_mc2( "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, - "global_bs": global_bs, - "expert_scales": topk_weights.to(torch.float32), + "global_bs": 0, } - rank = torch.distributed.get_rank() - - quant_mode = 2 - ep_group = get_ep_group().device_group - local_rank = torch.distributed.get_rank(group=ep_group) - all_to_all_group_size = torch.distributed.get_world_size(ep_group) - - world_size = torch.distributed.get_world_size() - tp_size = world_size // all_to_all_group_size - tp_rank = rank % tp_size - stage1_kwargs = { "scales": None, "quant_mode": quant_mode, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - # "group_tp": self.moe_rs_group_name, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": 
moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage1_kwargs) - output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) + output = torch_npu.npu_moe_distribute_dispatch_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( + **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[ - 0:7] + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(hidden_states, topk_weights) - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - npu_wait_tensor(shared_gate_up[0], expand_x) - shared_act = shared_experts.act_fn(shared_gate_up) + npu_wait_tensor(quantized_x_for_share, expand_x) + shared_act_out = shared_experts.act_fn( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] - # `expand_x` will be disposed in the `apply_mlp` function down_out_list = apply_mlp(expand_x, w1, w1_scale, @@ -186,13 +203,11 @@ def fused_experts_with_mc2( kwargs_mc2 = { "expand_x": down_out_list, "expert_ids": topk_ids, - "expand_idx": expand_idx, "expert_scales": topk_weights.to(torch.float32), "expert_shard_type": 0, "shared_expert_rank_num": 0, "moe_expert_num": moe_expert_num, "global_bs": 0, - "expand_scales": expand_scales, } tp_recv_counts = torch.empty(1, dtype=torch.int32, @@ -200,24 +215,43 @@ def fused_experts_with_mc2( stage3_kwargs = { "ep_send_counts": ep_recv_counts, "group_ep": moe_all_to_all_group_name, - "ep_world_size": all_to_all_group_size, - "ep_rank_id": local_rank, - "tp_send_counts": tp_recv_counts, - # "group_tp": self.moe_rs_group_name, - "group_tp": 
moe_all_to_all_group_name, - "tp_world_size": tp_size, - "tp_rank_id": tp_rank, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, } + if enable_dispatch_v2: + stage3_kwargs.update({ + "assist_info_for_combine": + assist_info_for_combine, + }) + else: + stage3_kwargs.update({ + "expand_idx": assist_info_for_combine, + }) + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + if a3_need_extra_args and enable_dispatch_v2: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) - hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + hidden_states = torch_npu.npu_moe_distribute_combine_v2( + **kwargs_mc2 + ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( + **kwargs_mc2) if shared_experts is None: return hidden_states else: with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_act[0], down_out_list) - shared_output, _ = shared_experts.down_proj(shared_act) + npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) return hidden_states, shared_output @@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod: self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: - device_group = self.ep_group.device_group + device_group = get_mc2_group().device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill, is_deepseek_v3_r1) + fused_moe_state = get_forward_context().fused_moe_state if fused_moe_state == FusedMoEState.AllGatherEP: return fused_experts_with_allgather( hidden_states=x, @@ -782,7 
+815,9 @@ class AscendW8A8DynamicFusedMoEMethod: moe_all_to_all_group_name=self.moe_all_to_all_group_name, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts) + shared_experts=shared_experts, + is_torchair=self.torchair_graph_enabled, + mc2_mask=kwargs.get("mc2_mask", None)) elif fused_moe_state in [ FusedMoEState.AllGather, FusedMoEState.NaiveMulticast ]: diff --git a/vllm_ascend/torchair/torchair_worker.py b/vllm_ascend/torchair/torchair_worker.py index f5e2f21..f74bc02 100644 --- a/vllm_ascend/torchair/torchair_worker.py +++ b/vllm_ascend/torchair/torchair_worker.py @@ -52,13 +52,3 @@ class NPUTorchairWorker(NPUWorker): self.model_runner.new_kv_cache_bytes = available_kv_cache_memory return available_kv_cache_memory - - def _get_max_num_tokens_and_with_prefill(self): - """Override _get_max_num_tokens_and_with_prefill to update max_num_tokens.""" - - max_num_tokens, with_prefill = super( - )._get_max_num_tokens_and_with_prefill() - if not with_prefill: - max_num_tokens = self.model_runner.select_torchair_padded_batch_size( - max_num_tokens) - return max_num_tokens, with_prefill diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 3e1785f..a5e4984 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -429,15 +429,6 @@ def npu_prefetch(input: torch.Tensor, torch_npu.npu_prefetch(input, dependency, max_size) -# TODO(zzzzwwjj): move this into forward_context -class FusedMoEState(Enum): - AllGather = 0 - All2All = 1 - MC2 = 2 - AllGatherEP = 3 - NaiveMulticast = 4 - - # TODO(ttanzhiqiang): rm_router_logits # dp>1 will trigger # In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. 
In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors. @@ -468,26 +459,6 @@ def get_all_reduce_merge_state(ep_size: int, is_deepseek_v3_r1: bool): return False -# TODO(zzzzwwjj): add soc_version to choose branch -def get_fused_moe_state(ep_size: int, with_prefill: bool, - is_deepseek_v3_r1: bool): - # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep - # only supports deepseek v3/r1 - if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 - and is_deepseek_v3_r1): - return FusedMoEState.AllGatherEP - elif ep_size == 1: - if with_prefill: - return FusedMoEState.NaiveMulticast - else: - return FusedMoEState.AllGather - # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. - elif ep_size < 16 or with_prefill: - return FusedMoEState.All2All - else: - return FusedMoEState.MC2 - - def register_ascend_customop(): """Register Ascend CustomOP @@ -506,3 +477,30 @@ def register_ascend_customop(): # NOTE: Keep this at last to ensure all custom actions are registered _ASCEND_CUSTOMOP_IS_REIGISTERED = True + + +# TODO(zzzzwwjj): It will be judged with _build_info afterwards. 
+class AscendSocVersion(Enum): + A2 = 0 + A3 = 1 + UNDEFINED = 2 + + +_ascend_soc_version = None + + +def init_ascend_soc_version(): + soc_version = torch_npu.npu.get_soc_version() + global _ascend_soc_version + if 220 <= soc_version <= 225: + _ascend_soc_version = AscendSocVersion.A2 + elif 250 <= soc_version <= 255: + _ascend_soc_version = AscendSocVersion.A3 + else: + _ascend_soc_version = AscendSocVersion.UNDEFINED + + +def get_ascend_soc_version(): + global _ascend_soc_version + assert _ascend_soc_version is not None + return _ascend_soc_version diff --git a/vllm_ascend/worker/eagle_proposer_v1.py b/vllm_ascend/worker/eagle_proposer_v1.py index 3ce0d87..660f0f3 100644 --- a/vllm_ascend/worker/eagle_proposer_v1.py +++ b/vllm_ascend/worker/eagle_proposer_v1.py @@ -7,13 +7,13 @@ from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.v1.sample.metadata import SamplingMetadata +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState @@ -142,9 +142,9 @@ class EagleProposer: self.positions[:num_tokens] = target_positions.to(device) self.hidden_states[:num_tokens] = target_hidden_states attn_metadata.block_tables = block_table.to(device) - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens): last_hidden_states, hidden_states = self.model( 
input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -239,9 +239,9 @@ class EagleProposer: attn_metadata.attn_mask = attn_mask attn_metadata.block_tables = block_table.to(device) # Run the model. - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): + with set_ascend_forward_context(attn_metadata, + self.vllm_config, + num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:input_batch_size], @@ -344,8 +344,9 @@ class EagleProposer: self, num_tokens: int, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_ascend_forward_context(None, + self.vllm_config, + num_tokens=num_tokens): self.model( input_ids=self.input_ids[:num_tokens], positions=self.positions[:num_tokens], diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 706ecdb..8ea680c 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -34,7 +34,6 @@ import torch import torch._dynamo.cache_size import torch.distributed as dist import torch.nn as nn -from torch.distributed import ReduceOp from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig @@ -44,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group, get_tp_group) -from vllm.forward_context import get_forward_context, set_forward_context +from vllm.forward_context import get_forward_context from vllm.inputs import INPUT_REGISTRY from vllm.logger import logger from vllm.model_executor.layers.fused_moe import FusedMoE @@ -77,6 +76,7 @@ from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders, scatter_mm_placeholders) from 
vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import (AscendAttentionState, AscendMetadata) @@ -347,6 +347,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): torch._logging.set_logs( recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) + # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True + self.in_profile_run = False + # kv role self.is_kv_producer = False if vllm_config.kv_transfer_config is not None: @@ -566,16 +569,44 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.refresh_sampling_metadata() def _get_forward_metadata_across_dp( - self, total_num_scheduled_tokens: int, - with_prefill: bool) -> tuple[int, bool]: - forward_metadata = torch.tensor( - [total_num_scheduled_tokens, with_prefill], - device="cpu", - dtype=torch.int32) - dist.all_reduce(forward_metadata, - op=ReduceOp.MAX, - group=get_dp_group().cpu_group) - return int(forward_metadata[0]), bool(forward_metadata[1] > 0) + self, + maybe_padded_num_tokens: int, + num_tokens: int, + with_prefill: bool, + enable_dbo: bool = False, + ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + if self.dp_size == 1: + return maybe_padded_num_tokens, None, with_prefill, enable_dbo + + num_tokens_across_dp = [0] * self.dp_size * 2 + num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens + num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens + forward_metadata = torch.tensor(num_tokens_across_dp + + [with_prefill, not enable_dbo], + device="cpu", + dtype=torch.int32) + dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group) + with_prefill = bool(forward_metadata[-2]) + + # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad. 
+ if with_prefill: + num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size * + 2] + maybe_padded_num_tokens = num_tokens + else: + num_tokens_across_dp = forward_metadata[:self.dp_size] + + # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to + # `max_tokens_across_dp`, in other situation it is not necessary. + if self.torchair_graph_enabled and not with_prefill: + maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item() + num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] * + self.dp_size, + device="cpu", + dtype=torch.int32) + + return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool( + forward_metadata[-1]) def get_eagle_atten_dict( self, @@ -1052,21 +1083,16 @@ class NPUModelRunner(LoRAModelRunnerMixin): AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding ] - if self.dp_size > 1: - max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( - total_num_scheduled_tokens, with_prefill) - extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens - extra_builder_kwargs['with_prefill_across_dp'] = with_prefill - - # Add graph_pad_size here + maybe_padded_num_tokens = total_num_scheduled_tokens if self.torchair_graph_enabled and not with_prefill: - if self.dp_size > 1: - padded_batch_size = self.select_torchair_padded_batch_size( - max_num_tokens) - else: - padded_batch_size = self.select_torchair_padded_batch_size( - total_num_scheduled_tokens) - graph_pad_size = padded_batch_size - total_num_scheduled_tokens + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + total_num_scheduled_tokens) + (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill) + + if self.torchair_graph_enabled and not with_prefill: + graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens extra_builder_kwargs['graph_pad_size'] = 
graph_pad_size @@ -1134,8 +1160,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): positions = self.mrope_positions[:, :num_input_tokens] if self.torchair_graph_enabled and not with_prefill: - input_ids = self.input_ids[:padded_batch_size] - positions = self.positions[:padded_batch_size] + input_ids = self.input_ids[:padded_num_tokens_across_dp] + positions = self.positions[:padded_num_tokens_across_dp] if get_pp_group().is_first_rank: intermediate_tensors = None @@ -1151,9 +1177,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): }) # Run forward pass - with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=padded_num_tokens_across_dp, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + num_actual_tokens=total_num_scheduled_tokens): with ProfileExecuteDuration().capture_async("forward"): self.maybe_setup_kv_connector(scheduler_output) model_kwargs = {} @@ -1165,7 +1195,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): ACL_FORMAT_FRACTAL_NZ) compiled_model = self._get_torchair_lazy_compiled_model( - padded_batch_size) + padded_num_tokens_across_dp) hidden_states = compiled_model( input_ids=input_ids, positions=positions, @@ -1643,7 +1673,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): def kv_connector_no_forward( self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: - with set_forward_context(None, self.vllm_config): + with set_ascend_forward_context(None, self.vllm_config): self.maybe_setup_kv_connector(scheduler_output) finished_sending, finished_recving = ( self.get_finished_kv_transfer(scheduler_output)) @@ -1688,14 +1718,26 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _dummy_run( self, num_tokens: int, - is_compile: bool = False, - with_prefill: bool = True, + skip_attn: bool = True, + with_prefill: bool = False, + is_torchair_compile: bool = False, ) -> torch.Tensor: + maybe_padded_num_tokens = 
num_tokens + if self.torchair_graph_enabled and not with_prefill: + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + num_tokens) + + # Padding for DP + (num_tokens, num_tokens_across_dp, with_prefill, + enable_dbo) = self._get_forward_metadata_across_dp( + maybe_padded_num_tokens, num_tokens, with_prefill, False) + # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. assert num_tokens <= self.scheduler_config.max_num_batched_tokens - num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) min_tokens_per_req = num_tokens // num_reqs num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs num_scheduled_tokens_list[-1] += num_tokens % num_reqs @@ -1706,6 +1748,17 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.is_kv_producer: with_prefill = True + # NOTE: If torchair graph mode and not with_prefill, + # we can't skip_attn, it will cause graph recompile. 
+ if self.torchair_graph_enabled and not with_prefill: + attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1) + elif skip_attn: + attn_metadata = None + else: + # TODO(zzzzwwjj): when aclgraph and full graph mode, we need build attn_metadata + attn_metadata = None + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1735,20 +1788,27 @@ class NPUModelRunner(LoRAModelRunnerMixin): for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): + with set_ascend_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + with_prefill=with_prefill, + in_profile_run=self.in_profile_run, + num_actual_tokens=0, + ): + model_kwargs = {} if self.torchair_graph_enabled and not with_prefill: - attn_metadata = self.attn_metadata_builder.build_dummy( - num_reqs=num_tokens, num_actual_tokens=1) # Only mark static while compiling - if is_compile: + if is_torchair_compile: torch._dynamo.mark_static(input_ids) torch._dynamo.mark_static(positions) torch._dynamo.mark_static( attn_metadata.decode.block_table) torch._dynamo.mark_static( attn_metadata.decode.input_positions) + torch._dynamo.mark_static( + get_forward_context().mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance( @@ -1761,13 +1821,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): compiled_model = self._get_torchair_lazy_compiled_model( num_tokens) + model_kwargs["kv_caches"] = self.kv_caches + model_kwargs["attn_metadata"] = attn_metadata hidden_states = compiled_model( input_ids=input_ids, positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=None, - kv_caches=self.kv_caches, - attn_metadata=attn_metadata, + **model_kwargs, ) else: maybe_converting_weight_acl_format(self.model, @@ -1787,9 +1848,19 @@ 
class NPUModelRunner(LoRAModelRunnerMixin): self.drafter.dummy_run(num_tokens) return hidden_states + @contextmanager + def set_in_profile_run(self): + self.in_profile_run = True + try: + yield + finally: + self.in_profile_run = False + def profile_run(self) -> None: # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.max_num_tokens) + with self.set_in_profile_run(): + hidden_states = self._dummy_run(self.max_num_tokens, + with_prefill=True) output = None if get_pp_group().is_last_rank: if self.is_pooling_model: @@ -2159,10 +2230,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): for idx, num_tokens in enumerate(reversed(torchair_graph_batch_sizes)): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens, - is_compile=True, - with_prefill=False) - self._dummy_run(num_tokens, is_compile=True, with_prefill=False) + self._dummy_run(num_tokens, is_torchair_compile=True) + self._dummy_run(num_tokens, is_torchair_compile=True) logger.info("Batchsize %d is compiled successfully: %d/%d.", num_tokens, idx + 1, len(torchair_graph_batch_sizes)) @@ -2205,6 +2274,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # Trigger ACL graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. + # TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode with graph_capture(device=self.device): for num_tokens in reversed(self.aclgraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. 
diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 6577bb8..08438ec 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -2,12 +2,12 @@ import torch from vllm.attention.layer import Attention from vllm.config import (VllmConfig, get_layers_from_vllm_config, set_current_vllm_config) -from vllm.forward_context import set_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, set_default_torch_dtype) from vllm.v1.sample.metadata import SamplingMetadata +from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP @@ -117,7 +117,7 @@ class MtpProposer: query_start_loc=cu_num_tokens, ) - with set_forward_context(attn_metadata, self.vllm_config): + with set_ascend_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( input_ids=input_ids, positions=target_positions, diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 9066fc7..15dba8b 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -40,9 +40,10 @@ from vllm.v1.worker.worker_base import WorkerBase from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator +from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (sleep_mode_enabled, try_register_lib, - vllm_version_is) +from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled, + try_register_lib, vllm_version_is) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner if not vllm_version_is("0.10.0"): @@ -134,6 +135,7 @@ class NPUWorker(WorkerBase): NPUPlatform.empty_cache() self.init_npu_memory = NPUPlatform.mem_get_info()[0] + 
init_ascend_soc_version() # Initialize the distributed environment. self._init_worker_distributed_environment() # Set random seed. @@ -272,20 +274,8 @@ class NPUWorker(WorkerBase): def pin_lora(self, lora_id: int) -> bool: return self.model_runner.pin_lora(lora_id) - def _get_max_num_tokens_and_with_prefill(self): - max_num_tokens = 1 - with_prefill = False - if self.model_runner.dp_size > 1: - max_num_tokens, with_prefill = self.model_runner._get_forward_metadata_across_dp( - max_num_tokens, with_prefill) - return max_num_tokens, with_prefill - def execute_dummy_batch(self) -> None: - max_num_tokens, with_prefill = self._get_max_num_tokens_and_with_prefill( - ) - self.model_runner._dummy_run(max_num_tokens, - is_compile=False, - with_prefill=with_prefill) + self.model_runner._dummy_run(1) def _init_worker_distributed_environment(self) -> None: """Initialize the distributed environment.""" @@ -295,6 +285,7 @@ class NPUWorker(WorkerBase): ensure_model_parallel_initialized( self.parallel_config.tensor_parallel_size, self.parallel_config.pipeline_parallel_size) + init_ascend_model_parallel(self.parallel_config) ensure_kv_transfer_initialized(self.vllm_config) def _init_profiler(self):