From 2870f7c8ad20754f5cc09cc5ea25044ffb3c6515 Mon Sep 17 00:00:00 2001
From: Li-Yongwen <63399187+Li-Yongwen@users.noreply.github.com>
Date: Thu, 26 Feb 2026 10:22:47 +0800
Subject: [PATCH] [Feat] Support routing replay (#6696)

### What this PR does / why we need it?
[Feat] Support routing replay: capture the expert IDs that each token is routed to in the MoE layers and return them with the request output.
Same as https://github.com/vllm-project/vllm-ascend/pull/6666; resubmitted because of a DOC CI failure.
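
At a high level (see the diff below), the runner allocates a per-step device buffer of shape `(max_num_batched_tokens, num_hidden_layers, num_experts_per_tok)`, each MoE layer writes its `topk_ids` into that buffer during the forward pass, and after execution the captured rows are saved per request via the CPU slot mapping and surfaced as `routed_experts`. A rough illustration of the buffer layout only; the variable names and write pattern here are illustrative, not the actual `RoutedExpertsCapturer` internals:

```python
import torch

# Illustrative layout; the real buffer is created by the patched
# RoutedExpertsCapturer.init_buffer() further down in this diff.
max_num_batched_tokens, num_hidden_layers, num_experts_per_tok = 8, 2, 4
device_buffer = torch.zeros(
    (max_num_batched_tokens, num_hidden_layers, num_experts_per_tok),
    dtype=torch.int32,
)

# During the forward pass, each MoE layer records the top-k expert ids chosen
# for the tokens in the current batch (here: layer 0, three scheduled tokens).
topk_ids = torch.tensor([[1, 5, 2, 7], [0, 3, 6, 2], [4, 1, 5, 0]], dtype=torch.int32)
device_buffer[: topk_ids.shape[0], 0] = topk_ids

# After the step, the runner saves these rows per request (indexed by the slot
# mapping) so they can be returned as `routed_experts` in the request output.
print(device_buffer[:3, 0])
```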
### Does this PR introduce _any_ user-facing change?
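Yes: with this patch, vLLM's `enable_return_routed_experts` option works on Ascend, and each completion then exposes a `routed_experts` array with the captured expert IDs (see the new e2e test below). A minimal usage sketch, assuming the plain `LLM` entry point accepts the flag the same way the test's `VllmRunner` forwards it:

```python
from vllm import LLM, SamplingParams

# Sketch only; mirrors tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    tensor_parallel_size=2,
    enable_expert_parallel=True,
    enable_return_routed_experts=True,  # routing-replay option exercised by this PR
)
outputs = llm.generate(
    ["Hello, please introduce yourself."],
    SamplingParams(max_tokens=5, temperature=0.8, top_p=0.95),
)

# With the feature enabled, each completion exposes the routed expert IDs.
assert outputs[0].outputs[0].routed_experts.size > 0
```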
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/9562912cead1f11e8540fb91306c5cbda66f0007

---------

Signed-off-by: liyongwen <1310439159@qq.com>
Signed-off-by: Li-Yongwen <63399187+Li-Yongwen@users.noreply.github.com>
Co-authored-by: wangxiyuan
---
 .github/workflows/scripts/config.yaml         |  2 +
 .../2-cards/test_qwen3_moe_routing_replay.py  | 32 +++++++++
 .../test_patch_routed_experts_capturer.py     | 58 ++++++++++++++++
 vllm_ascend/ops/fused_moe/fused_moe.py        |  8 +++
 vllm_ascend/patch/worker/__init__.py          |  1 +
 .../worker/patch_routed_experts_capturer.py   | 68 +++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py         | 21 ++++++
 7 files changed, 190 insertions(+)
 create mode 100644 tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py
 create mode 100644 tests/ut/patch/worker/patch_common/test_patch_routed_experts_capturer.py
 create mode 100644 vllm_ascend/patch/worker/patch_routed_experts_capturer.py

diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml
index 02092d55..9784cbf7 100644
--- a/.github/workflows/scripts/config.yaml
+++ b/.github/workflows/scripts/config.yaml
@@ -126,6 +126,8 @@ e2e-multicard-2-cards:
     estimated_time: 70
   - name: tests/e2e/multicard/2-cards/test_qwen3_moe.py
     estimated_time: 1050
+  - name: tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py
+    estimated_time: 1050
   - name: tests/e2e/multicard/2-cards/test_single_request_aclgraph.py
     estimated_time: 215
   - name: tests/e2e/multicard/2-cards/test_disaggregated_encoder.py
diff --git a/tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py b/tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py
new file mode 100644
index 00000000..0876eb57
--- /dev/null
+++ b/tests/e2e/multicard/2-cards/test_qwen3_moe_routing_replay.py
@@ -0,0 +1,32 @@
+import os
+from unittest.mock import patch
+
+from tests.e2e.conftest import VllmRunner
+from vllm import SamplingParams
+from vllm.sampling_params import RequestOutputKind
+
+
+@patch.dict(os.environ, {"OMP_NUM_THREADS": "1"})
+def test_qwen3_moe_routing_replay():
+    prompts = [
+        "Hello, please introduce yourself.",
+    ]
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            tensor_parallel_size=2,
+            enable_expert_parallel=True,
+            cudagraph_capture_sizes=[1, 2, 4, 8],
+            distributed_executor_backend="mp",
+            enable_return_routed_experts=True,
+    ) as vllm_model:
+        sampling_params = SamplingParams(
+            max_tokens=5,
+            temperature=0.8,
+            top_p=0.95,
+            output_kind=RequestOutputKind.FINAL_ONLY
+        )
+        inputs = vllm_model.get_inputs(prompts=prompts)
+        outputs = vllm_model.model.generate(prompts=inputs, sampling_params=sampling_params)
+        assert outputs[0].finished
+        assert len(outputs[0].outputs[0].text) > 0
+        assert outputs[0].outputs[0].routed_experts.size > 0
diff --git a/tests/ut/patch/worker/patch_common/test_patch_routed_experts_capturer.py b/tests/ut/patch/worker/patch_common/test_patch_routed_experts_capturer.py
new file mode 100644
index 00000000..79e74962
--- /dev/null
+++ b/tests/ut/patch/worker/patch_common/test_patch_routed_experts_capturer.py
@@ -0,0 +1,58 @@
+from unittest.mock import patch, MagicMock
+import uuid
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.patch.worker.patch_routed_experts_capturer import RoutedExpertsCapturer
+from vllm.config import ModelConfig, VllmConfig
+from vllm.config.parallel import ParallelConfig
+from transformers import PretrainedConfig
+from vllm.platforms import current_platform
+
+
+class MockVllmConfig:
+
+    def __init__(self):
+        self.model_config = MagicMock()
+        self.model_config.hf_text_config.num_hidden_layers = 1
+        self.model_config.hf_text_config.num_experts_per_tok = 1
+        self.parallel_config = MagicMock()
+        self.parallel_config.data_parallel_rank = 0
+        self.instance_id = uuid.uuid4().hex
+
+
+class TestPatchRoutedExpertsCapturer(TestBase):
+
+    def setUp(self):
+        RoutedExpertsCapturer.create()
+        self.capturer = RoutedExpertsCapturer.get_instance()
+        self.vllm_config = MockVllmConfig()
+
+    def test_init_buffer(self):
+        max_num_batched_tokens = 1
+        max_num_kv_tokens = 1
+        with patch(
+                target="vllm_ascend.patch.worker.patch_routed_experts_capturer.get_tensor_model_parallel_rank",
+                return_value=True
+        ):
+            current_platform.device_name = "cpu"
+            self.capturer.init_buffer(
+                max_num_batched_tokens=max_num_batched_tokens,
+                max_num_kv_tokens=max_num_kv_tokens,
+                vllm_config=self.vllm_config,
+            )
+            self.assertEqual(
+                self.capturer._device_buffer.shape,
+                (
+                    max_num_batched_tokens,
+                    self.vllm_config.model_config.hf_text_config.num_hidden_layers,
+                    self.vllm_config.model_config.hf_text_config.num_experts_per_tok,
+                )
+            )
+            self.assertEqual(self.capturer._device_buffer.dtype, torch.int32)
+            self.assertEqual(self.capturer._device_buffer.device.type, current_platform.device_name)
+
+    def tearDown(self):
+        self.capturer.clear_buffer()
+        self.capturer.cleanup()
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index e2300a07..d039d1fc 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
 from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 
 from vllm_ascend.utils import vllm_version_is
@@ -122,6 +123,13 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             e_score_correction_bias=e_score_correction_bias,
             global_num_experts=global_num_experts,
         )
+        if layer.vllm_config.model_config is not None and layer.vllm_config.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.capture(
+                    layer_id=layer.layer_id,
+                    topk_ids=topk_ids,
+                )
 
         if zero_expert_num > 0 and zero_expert_type is not None:
             topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index e916aee8..d735dcf2 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -34,5 +34,6 @@ import vllm_ascend.patch.worker.patch_qwen3_next  # noqa
 import vllm_ascend.patch.worker.patch_v2_eagle  # noqa
 import vllm_ascend.patch.worker.patch_v2_uva  # noqa
 import vllm_ascend.patch.worker.patch_huanyuan_vl  # noqa
+import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa
diff --git a/vllm_ascend/patch/worker/patch_routed_experts_capturer.py b/vllm_ascend/patch/worker/patch_routed_experts_capturer.py
new file mode 100644
index 00000000..c65f5863
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_routed_experts_capturer.py
@@ -0,0 +1,68 @@
+import numpy as np
+import torch
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
+    _BUFFER_PREFIX,
+    _LOCK_FILE_PREFIX,
+    RoutedExpertsCapturer,
+    _create_or_attach_shared_memory,
+    logger,
+)
+from vllm.platforms import current_platform
+
+
+def init_buffer(
+    self,
+    max_num_batched_tokens: int,
+    max_num_kv_tokens: int,
+    vllm_config: VllmConfig,
+) -> None:
+    """
+    Initialize the device buffer and optionally shared memory buffer.
+
+    Args:
+        max_num_batched_tokens: Maximum number of tokens in a batch.
+        max_num_kv_tokens: Maximum number of KV tokens for shared memory.
+        vllm_config: vllm configuration containing layer and expert info.
+    """
+
+    if self._device_buffer is not None:
+        raise RuntimeError("Device buffer has already been initialized")
+
+    hf_config = vllm_config.model_config.hf_text_config
+    num_layers = hf_config.num_hidden_layers
+    num_experts_per_tok = hf_config.num_experts_per_tok
+
+    # Initialize device buffer
+    self._device_buffer = torch.zeros(
+        (max_num_batched_tokens, num_layers, num_experts_per_tok),
+        dtype=torch.int32,
+        device=current_platform.device_name,
+    )
+    self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+
+    if get_tensor_model_parallel_rank() != 0:
+        return
+
+    # Initialize shared memory
+    shape = (max_num_kv_tokens, num_layers, num_experts_per_tok)
+    buffer_size = int(np.prod(shape)) * np.dtype(np.int32).itemsize
+    instance_id = vllm_config.instance_id
+    self._lock_file = f"{_LOCK_FILE_PREFIX}_{instance_id}_{self.dp_rank}.lock"
+    shm_name = f"{_BUFFER_PREFIX}_{instance_id}_{self.dp_rank}"
+
+    self._shm = _create_or_attach_shared_memory(shm_name, buffer_size, self._lock_file)
+    self._host_buffer_view = np.ndarray(shape, dtype=np.int32, buffer=self._shm.buf)
+    self._host_buffer_view.fill(0)
+
+    logger.debug(
+        "Created shared memory buffer '%s' with shape %s",
+        shm_name,
+        shape,
+    )
+
+
+# Patch for _device_buffer's initialization(device="cuda" -> device=current_platform.device_name).
+# TODO Remove this patch when pr(https://github.com/vllm-project/vllm/pull/34336) is merged.
+RoutedExpertsCapturer.init_buffer = init_buffer
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a489b69d..8d80b719 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -129,6 +129,7 @@ from vllm_ascend.ascend_forward_context import (  # isort: skip
     set_mc2_mask,
     set_mc2_tokens_capacity,
 )
+from vllm.model_executor.layers.fused_moe.routed_experts_capturer import RoutedExpertsCapturer
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
@@ -373,6 +374,7 @@ class NPUModelRunner(GPUModelRunner):
         self.intermediate_tensors: IntermediateTensors | None = None
         self.reorder_batch_threshold: int | None = None
         self.long_seq_metadata = None
+        self.cpu_slot_mapping = None
 
     @property
     def use_cp(self) -> bool:
@@ -1050,6 +1052,12 @@ class NPUModelRunner(GPUModelRunner):
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: IntermediateTensors | None = None,
     ) -> ModelRunnerOutput | IntermediateTensors | None:
+        if self.vllm_config.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.clear_buffer()
+            else:
+                logger.warning("RoutedExpertsCapturer is not initialized.")
         if self.execute_model_state is not None:
             raise RuntimeError("State error: sample_tokens() must be called after execute_model() returns None.")
         # self._draft_token_ids is None when `input_fits_in_drafter=False`
@@ -1428,6 +1436,14 @@ class NPUModelRunner(GPUModelRunner):
 
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
+
+        if self.model_config.enable_return_routed_experts:
+            capturer = RoutedExpertsCapturer.get_instance()
+            if capturer is not None:
+                capturer.save_captured_experts(indices=self.cpu_slot_mapping)
+            else:
+                logger.warning("RoutedExpertsCapturer is not initialized.")
+
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids_output_copy,
             req_id_to_index=req_id_to_index_output_copy,
@@ -1902,6 +1918,8 @@ class NPUModelRunner(GPUModelRunner):
                 num_tokens_padded,
                 slot_mapping,
             )
+            if self.model_config.enable_return_routed_experts and kv_cache_gid == 0:
+                self.cpu_slot_mapping = slot_mapping.cpu().numpy()
             return blk_table_tensor, slot_mapping
 
         block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
@@ -2364,6 +2382,9 @@ class NPUModelRunner(GPUModelRunner):
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
 
+        if self.model_config.enable_return_routed_experts:
+            self.init_routed_experts_capturer()
+
     def _align_memory(self, tensor: torch.Tensor, alignment: int) -> torch.Tensor:
         data_ptr = tensor.data_ptr()
         aligned_addr = (data_ptr + alignment - 1) // alignment * alignment